diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index af22199ce39..02cc0e0e3d4 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -5,9 +5,9 @@ import operator import pathlib import string -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union +import sys +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast -import numpy as np import torch from torchdata.datapipes.iter import ( IterDataPipe, @@ -38,14 +38,14 @@ prod = functools.partial(functools.reduce, operator.mul) -class MNISTFileReader(IterDataPipe[np.ndarray]): +class MNISTFileReader(IterDataPipe[torch.Tensor]): _DTYPE_MAP = { - 8: "u1", # uint8 - 9: "i1", # int8 - 11: "i2", # int16 - 12: "i4", # int32 - 13: "f4", # float32 - 14: "f8", # float64 + 8: torch.uint8, + 9: torch.int8, + 11: torch.int16, + 12: torch.int32, + 13: torch.float32, + 14: torch.float64, } def __init__( @@ -59,18 +59,20 @@ def __init__( def _decode(bytes: bytes) -> int: return int(codecs.encode(bytes, "hex"), 16) - def __iter__(self) -> Iterator[np.ndarray]: + def __iter__(self) -> Iterator[torch.Tensor]: for _, file in self.datapipe: magic = self._decode(file.read(4)) - dtype_type = self._DTYPE_MAP[magic // 256] + dtype = self._DTYPE_MAP[magic // 256] ndim = magic % 256 - 1 num_samples = self._decode(file.read(4)) shape = [self._decode(file.read(4)) for _ in range(ndim)] - in_dtype = np.dtype(f">{dtype_type}") - out_dtype = np.dtype(dtype_type) - chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize + num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 + # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, + # we need to reverse the bytes before we can read them with torch.frombuffer(). 
+            needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
+            chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
 
             start = self.start or 0
             stop = self.stop or num_samples
@@ -78,11 +80,15 @@ def __iter__(self) -> Iterator[np.ndarray]:
                 file.seek(start * chunk_size, 1)
             for _ in range(stop - start):
                 chunk = file.read(chunk_size)
-                yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)
+                if not needs_byte_reversal:
+                    yield torch.frombuffer(chunk, dtype=dtype).reshape(shape)
+                    continue  # fix: without this, the byte-reversed tensor below is also yielded
+                chunk = bytearray(chunk)
+                chunk.reverse()
+                yield torch.frombuffer(chunk, dtype=dtype).flip(0).reshape(shape)
 
 
 class _MNISTBase(Dataset):
-    _FORMAT = "png"
     _URL_BASE: str
 
     @abc.abstractmethod
@@ -105,24 +111,23 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional
 
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
+        image, label = data
 
-        image: Union[torch.Tensor, io.BytesIO]
         if decoder is raw:
-            image = torch.from_numpy(image_array)
+            image = image.unsqueeze(0)
         else:
-            image_buffer = image_buffer_from_array(image_array)
-            image = decoder(image_buffer) if decoder else image_buffer
+            image_buffer = image_buffer_from_array(image.numpy())
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
 
-        label = torch.tensor(label_array, dtype=torch.int64)
         category = self.info.categories[int(label)]
+        label = label.to(torch.int64)
 
-        return dict(image=image, label=label, category=category)
+        return dict(image=image, category=category, label=label)
 
     def _make_datapipe(
         self,
@@ -293,12 +298,11 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) ->
 
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: - image_array, label_array = data # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example, @@ -308,8 +312,8 @@ def _collate_and_decode( # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) # in self.categories. Thus, we need to add 1 to the label to correct this. if config.image_set in ("Balanced", "By_Merge"): - label_array += np.array(self._LABEL_OFFSETS.get(int(label_array), 0), dtype=label_array.dtype) - return super()._collate_and_decode((image_array, label_array), config=config, decoder=decoder) + data[1] += self._LABEL_OFFSETS.get(int(data[1]), 0) + return super()._collate_and_decode(data, config=config, decoder=decoder) def _make_datapipe( self, @@ -379,22 +383,22 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional def _collate_and_decode( self, - data: Tuple[np.ndarray, np.ndarray], + data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: - image_array, label_array = data - label_parts = label_array.tolist() - sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder) + image, ann = data + label, *extra_anns = ann + sample = super()._collate_and_decode((image, label), config=config, decoder=decoder) sample.update( dict( zip( ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"), - label_parts[1:6], + [int(value) for value in extra_anns[:5]], ) ) ) - sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]]))) + 
sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]]))) return sample