From 3f57690471b1ab5dcc6c1d0c6f31587187bba7fa Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 18 Oct 2021 12:14:36 +0200 Subject: [PATCH 1/6] replace np.frombuffer with torch.frombuffer in MNIST prototype --- .../prototype/datasets/_builtin/mnist.py | 67 ++++++++++--------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index b20c3ed6266..478697bb246 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -5,9 +5,9 @@ import operator import pathlib import string +import sys from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union -import numpy as np import torch from torch.utils.data import IterDataPipe from torch.utils.data.datapipes.iter import ( @@ -40,12 +40,12 @@ class MNISTFileReader(IterDataPipe): _DTYPE_MAP = { - 8: "u1", # uint8 - 9: "i1", # int8 - 11: "i2", # int16 - 12: "i4", # int32 - 13: "f4", # float32 - 14: "f8", # float64 + 8: torch.uint8, + 9: torch.int8, + 11: torch.int16, + 12: torch.int32, + 13: torch.float32, + 14: torch.float64, } def __init__(self, datapipe: IterDataPipe, *, start: Optional[int], stop: Optional[int]) -> None: @@ -57,30 +57,33 @@ def __init__(self, datapipe: IterDataPipe, *, start: Optional[int], stop: Option def _decode(bytes): return int(codecs.encode(bytes, "hex"), 16) - def __iter__(self) -> Iterator[np.ndarray]: + def __iter__(self) -> Iterator[torch.Tensor]: for _, file in self.datapipe: magic = self._decode(file.read(4)) - dtype_type = self._DTYPE_MAP[magic // 256] + dtype = self._DTYPE_MAP[magic // 256] ndim = magic % 256 - 1 num_samples = self._decode(file.read(4)) shape = [self._decode(file.read(4)) for _ in range(ndim)] - in_dtype = np.dtype(f">{dtype_type}") - out_dtype = np.dtype(dtype_type) - chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize + num_bytes_per_value = torch.iinfo(dtype).bits // 8 + # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, + # we need to reverse the bytes before we can read them with torch.frombuffer(). + needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 + chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value start = self.start or 0 stop = self.stop or num_samples file.seek(start * chunk_size, 1) for _ in range(stop - start): - chunk = file.read(chunk_size) - yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape) + chunk = bytearray(file.read(chunk_size)) + if needs_byte_reversal: + chunk.reverse() + yield torch.frombuffer(chunk, dtype=dtype).reshape(shape) class _MNISTBase(Dataset): - _FORMAT = "png" _URL_BASE: str @abc.abstractmethod @@ -103,24 +106,22 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional def _collate_and_decode( self, - data: Tuple[np.ndarray, np.ndarray], + data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ): - image_array, label_array = data + image, label = data image: Union[torch.Tensor, io.BytesIO] - if decoder is raw: - image = torch.from_numpy(image_array) - else: - image_buffer = image_buffer_from_array(image_array) + if decoder is not raw: + image_buffer = image_buffer_from_array(image.numpy()) # type: ignore[union-attr] image = decoder(image_buffer) if decoder else image_buffer - label = torch.tensor(label_array, dtype=torch.int64) category = self.info.categories[int(label)] + label = label.to(torch.int64) - return dict(image=image, label=label, category=category) + return dict(image=image, category=category, label=label) def _make_datapipe( self, @@ -291,12 +292,11 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> def _collate_and_decode( self, - data: Tuple[np.ndarray, np.ndarray], + data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ): - image_array, label_array = data # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example, @@ -306,8 +306,9 @@ def _collate_and_decode( # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) # in self.categories. Thus, we need to add 1 to the label to correct this. if config.image_set in ("Balanced", "By_Merge"): - label_array += np.array(self._LABEL_OFFSETS.get(int(label_array), 0), dtype=label_array.dtype) - return super()._collate_and_decode((image_array, label_array), config=config, decoder=decoder) + label = data[1] + label += self._LABEL_OFFSETS.get(int(label), 0) + return super()._collate_and_decode(data, config=config, decoder=decoder) def _make_datapipe( self, @@ -377,22 +378,22 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional def _collate_and_decode( self, - data: Tuple[np.ndarray, np.ndarray], + data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig, decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ): - image_array, label_array = data - label_parts = label_array.tolist() - sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder) + image, ann = data + label, *extra_anns = ann + sample = super()._collate_and_decode((image, label), config=config, decoder=decoder) sample.update( dict( zip( ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"), - label_parts[1:6], + [int(value) for value in extra_anns[:5]], ) ) ) - sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]]))) + sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]]))) return sample From 723236c173f4618a2ea5926f7712bbc80bfd10e3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 20 Oct 2021 15:08:15 +0200 Subject: [PATCH 2/6] cleanup --- torchvision/prototype/datasets/_builtin/mnist.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 478697bb246..c7e7680303f 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -80,7 +80,10 @@ def __iter__(self) -> Iterator[torch.Tensor]: chunk = bytearray(file.read(chunk_size)) if needs_byte_reversal: chunk.reverse() - yield torch.frombuffer(chunk, dtype=dtype).reshape(shape) + data = torch.frombuffer(chunk, dtype=dtype) + if needs_byte_reversal: + data = data.flip(0) + yield data.reshape(shape) class _MNISTBase(Dataset): @@ -114,7 +117,9 @@ def _collate_and_decode( image, label = data image: Union[torch.Tensor, io.BytesIO] - if decoder is not raw: + if decoder is raw: + image = image.unsqueeze(0) + else: image_buffer = image_buffer_from_array(image.numpy()) # type: ignore[union-attr] image = decoder(image_buffer) if decoder else image_buffer From 8791c4cce852c72520ff214b4ee2fc3602e1a422 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 20 Oct 2021 15:20:22 +0200 Subject: [PATCH 3/6] appease mypy --- torchvision/prototype/datasets/_builtin/mnist.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index a87a440fcfa..8fa3f9cc938 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -6,7 +6,7 @@ import pathlib import string import sys -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast import torch from torchdata.datapipes.iter import ( @@ -116,12 +116,11 @@ def _collate_and_decode( ): image, label = data - image: Union[torch.Tensor, io.BytesIO] if decoder is raw: image = image.unsqueeze(0) else: - image_buffer = image_buffer_from_array(image.numpy()) # type: ignore[union-attr] - image = decoder(image_buffer) if decoder else image_buffer + image_buffer = image_buffer_from_array(image.numpy()) + image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] category = self.info.categories[int(label)] label = label.to(torch.int64) From 98a51f98c322e96e7ed4562e60b4a331f4b9f532 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 21 Oct 2021 05:52:00 +0200 Subject: [PATCH 4/6] more cleanup --- torchvision/prototype/datasets/_builtin/mnist.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 8fa3f9cc938..5717c4a4fbd 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -77,13 +77,13 @@ def __iter__(self) -> Iterator[torch.Tensor]: file.seek(start * chunk_size, 1) for _ in range(stop - start): - chunk = bytearray(file.read(chunk_size)) - if needs_byte_reversal: - chunk.reverse() - data = torch.frombuffer(chunk, dtype=dtype) - if needs_byte_reversal: - data = data.flip(0) - yield data.reshape(shape) + chunk = file.read(chunk_size) + if not needs_byte_reversal: + yield torch.frombuffer(chunk, dtype=dtype).reshape(shape) + + chunk = bytearray(chunk) + chunk.reverse() + yield torch.frombuffer(chunk, dtype=dtype).flip(0).reshape(shape) class _MNISTBase(Dataset): From 3a85b22f8285450200a16e1813db274f0fccc3b7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 21 Oct 2021 14:12:11 +0200 Subject: [PATCH 5/6] clarify inplace offset --- torchvision/prototype/datasets/_builtin/mnist.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 5717c4a4fbd..b86ee42cc2d 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -310,8 +310,7 @@ def _collate_and_decode( # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) # in self.categories. Thus, we need to add 1 to the label to correct this. if config.image_set in ("Balanced", "By_Merge"): - label = data[1] - label += self._LABEL_OFFSETS.get(int(label), 0) + data[1] += self._LABEL_OFFSETS.get(int(data[1]), 0) return super()._collate_and_decode(data, config=config, decoder=decoder) def _make_datapipe( From e04be6f63e1c8fb9d7beb2622ac10184ad4537ed Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 26 Oct 2021 11:48:31 +0200 Subject: [PATCH 6/6] fix num bytes for floating point data --- torchvision/prototype/datasets/_builtin/mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 6034b061dd6..02cc0e0e3d4 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -68,7 +68,7 @@ def __iter__(self) -> Iterator[torch.Tensor]: num_samples = self._decode(file.read(4)) shape = [self._decode(file.read(4)) for _ in range(ndim)] - num_bytes_per_value = torch.iinfo(dtype).bits // 8 + num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, # we need to reverse the bytes before we can read them with torch.frombuffer(). needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1