72 changes: 38 additions & 34 deletions torchvision/prototype/datasets/_builtin/mnist.py
@@ -5,9 +5,9 @@
import operator
import pathlib
import string
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union
import sys
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast

import numpy as np
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
@@ -38,14 +38,14 @@
prod = functools.partial(functools.reduce, operator.mul)


class MNISTFileReader(IterDataPipe[np.ndarray]):
class MNISTFileReader(IterDataPipe[torch.Tensor]):
_DTYPE_MAP = {
8: "u1", # uint8
9: "i1", # int8
11: "i2", # int16
12: "i4", # int32
13: "f4", # float32
14: "f8", # float64
8: torch.uint8,
9: torch.int8,
11: torch.int16,
12: torch.int32,
13: torch.float32,
14: torch.float64,
}

def __init__(
@@ -59,30 +59,36 @@ def __init__(
def _decode(bytes: bytes) -> int:
return int(codecs.encode(bytes, "hex"), 16)

def __iter__(self) -> Iterator[np.ndarray]:
def __iter__(self) -> Iterator[torch.Tensor]:
for _, file in self.datapipe:
magic = self._decode(file.read(4))
dtype_type = self._DTYPE_MAP[magic // 256]
dtype = self._DTYPE_MAP[magic // 256]
ndim = magic % 256 - 1

num_samples = self._decode(file.read(4))
shape = [self._decode(file.read(4)) for _ in range(ndim)]

in_dtype = np.dtype(f">{dtype_type}")
out_dtype = np.dtype(dtype_type)
chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize
num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
Comment on lines +72 to +74 (Collaborator Author):

@lc0 you also need to take care of that in #4598.
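For context, a minimal standalone sketch of the byte-order handling discussed here (hypothetical sample values; assumes a little-endian machine): reversing the whole buffer fixes the per-value byte order but also reverses the value order, which .flip(0) then undoes.

```python
import sys

import torch

# Hypothetical: three int32 values stored big endian, as in the MNIST/IDX format.
values = [1, 2, 3]
big_endian_bytes = b"".join(value.to_bytes(4, byteorder="big") for value in values)

buffer = bytearray(big_endian_bytes)
if sys.byteorder == "little":
    # torch.frombuffer() only understands the native byte order, so reverse the raw
    # bytes (which also reverses the value order) and undo that with flip(0).
    buffer.reverse()
    tensor = torch.frombuffer(buffer, dtype=torch.int32).flip(0)
else:
    tensor = torch.frombuffer(buffer, dtype=torch.int32)

print(tensor)  # tensor([1, 2, 3], dtype=torch.int32)
```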

chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value

start = self.start or 0
stop = self.stop or num_samples

file.seek(start * chunk_size, 1)
for _ in range(stop - start):
chunk = file.read(chunk_size)
yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)
if not needs_byte_reversal:
    yield torch.frombuffer(chunk, dtype=dtype).reshape(shape)
    continue

# Reversing the buffer also reverses the value order; flip(0) restores it.
chunk = bytearray(chunk)
chunk.reverse()
yield torch.frombuffer(chunk, dtype=dtype).flip(0).reshape(shape)
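As a supplementary note, a quick worked example of the IDX header parsing in __iter__ above, using the magic number of the standard MNIST image files (which I believe is 0x00000803; treat the value as illustrative):

```python
import codecs

def _decode(data: bytes) -> int:
    # Same trick as MNISTFileReader._decode: read 4 raw bytes as a big endian integer.
    return int(codecs.encode(data, "hex"), 16)

magic = _decode(b"\x00\x00\x08\x03")
assert magic == 2051
assert magic // 256 == 8     # dtype code 8 -> torch.uint8 in _DTYPE_MAP
assert magic % 256 - 1 == 2  # per-sample shape has 2 entries (rows, cols);
                             # the first of the 3 dimensions is the sample count
```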


class _MNISTBase(Dataset):
_FORMAT = "png"
_URL_BASE: str

@abc.abstractmethod
@@ -105,24 +111,23 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional

def _collate_and_decode(
self,
data: Tuple[np.ndarray, np.ndarray],
data: Tuple[torch.Tensor, torch.Tensor],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
image_array, label_array = data
image, label = data

image: Union[torch.Tensor, io.BytesIO]
if decoder is raw:
image = torch.from_numpy(image_array)
image = image.unsqueeze(0)
else:
image_buffer = image_buffer_from_array(image_array)
image = decoder(image_buffer) if decoder else image_buffer
image_buffer = image_buffer_from_array(image.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]

label = torch.tensor(label_array, dtype=torch.int64)
category = self.info.categories[int(label)]
label = label.to(torch.int64)

return dict(image=image, label=label, category=category)
return dict(image=image, category=category, label=label)

def _make_datapipe(
self,
@@ -293,12 +298,11 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->

def _collate_and_decode(
self,
data: Tuple[np.ndarray, np.ndarray],
data: Tuple[torch.Tensor, torch.Tensor],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
image_array, label_array = data
# In these two splits, some lowercase letters are merged into their uppercase ones (see Fig. 2 in the paper).
# That means, for example, that there are 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
# i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example,
@@ -308,8 +312,8 @@ def _collate_and_decode(
# index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
# in self.categories. Thus, we need to add 1 to the label to correct this.
if config.image_set in ("Balanced", "By_Merge"):
label_array += np.array(self._LABEL_OFFSETS.get(int(label_array), 0), dtype=label_array.dtype)
return super()._collate_and_decode((image_array, label_array), config=config, decoder=decoder)
data[1] += self._LABEL_OFFSETS.get(int(data[1]), 0)
return super()._collate_and_decode(data, config=config, decoder=decoder)
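As an aside, a small worked illustration of the offset logic described in the comment in _collate_and_decode above. The real _LABEL_OFFSETS table is defined elsewhere on the EMNIST class; the single entry below is reconstructed from the comment and should be treated as illustrative: with 10 digits and 26 uppercase letters in front, the merged splits emit 'd' as the dense label 38, while it sits at index 39 in self.categories, so an offset of 1 is added.

```python
import torch

# Illustrative excerpt of the offset table: raw dense label -> offset to add.
LABEL_OFFSETS = {38: 1}

raw_label = torch.tensor(38)  # 'd' as emitted by the Balanced/By_Merge splits
corrected = raw_label + LABEL_OFFSETS.get(int(raw_label), 0)
assert int(corrected) == 39  # 10 digits + 26 uppercase + 4th lowercase - 1 for zero indexing
```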

def _make_datapipe(
self,
@@ -379,22 +383,22 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional

def _collate_and_decode(
self,
data: Tuple[np.ndarray, np.ndarray],
data: Tuple[torch.Tensor, torch.Tensor],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
image_array, label_array = data
label_parts = label_array.tolist()
sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
image, ann = data
label, *extra_anns = ann
sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)

sample.update(
dict(
zip(
("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
label_parts[1:6],
[int(value) for value in extra_anns[:5]],
)
)
)
sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]])))
sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
return sample
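Finally, a hedged sketch of how a QMNIST annotation row is consumed by the method above. The values are made up; the column meaning simply mirrors the key names used in the code:

```python
import torch

# Made-up QMNIST annotation row: the first entry is the class label, the remaining
# seven are the extra annotations unpacked above.
ann = torch.tensor([5, 1, 2871, 3, 53, 123456, 0, 0])

label, *extra_anns = ann
extras = dict(
    zip(
        ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
        [int(value) for value in extra_anns[:5]],
    )
)
extras.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))

print(int(label))  # 5
print(extras)      # {'nist_hsf_series': 1, ..., 'duplicate': False, 'unused': False}
```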