diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py
new file mode 100644
index 00000000000..7207299c9d4
--- /dev/null
+++ b/test/test_prototype_datasets_utils.py
@@ -0,0 +1,47 @@
+import sys
+
+import numpy as np
+import pytest
+import torch
+from datasets_utils import make_fake_flo_file
+from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
+from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
+
+
+@pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning")
+@pytest.mark.parametrize(
+    ("np_dtype", "torch_dtype", "byte_order"),
+    [
+        (">f4", torch.float32, "big"),
+        ("<f8", torch.float64, "little"),
+        ("<i4", torch.int32, "little"),
+        (">i8", torch.int64, "big"),
+        ("|u1", torch.uint8, sys.byteorder),
+    ],
+)
+@pytest.mark.parametrize("count", (-1, 2))
+@pytest.mark.parametrize("mode", ("rb", "r+b"))
+def test_fromfile(tmpdir, np_dtype, torch_dtype, byte_order, count, mode):
+    path = tmpdir / "data.bin"
+    rng = np.random.RandomState(0)
+    rng.randn(5 if count == -1 else count + 1).astype(np_dtype).tofile(path)
+
+    for count_ in (-1, count // 2):
+        expected = torch.from_numpy(np.fromfile(path, dtype=np_dtype, count=count_).astype(np_dtype[1:]))
+
+        with open(path, mode) as file:
+            actual = fromfile(file, dtype=torch_dtype, byte_order=byte_order, count=count_)
+
+        torch.testing.assert_close(actual, expected)
+
+
+def test_read_flo(tmpdir):
+    path = tmpdir / "test.flo"
+    make_fake_flo_file(3, 4, path)
+
+    with open(path, "rb") as file:
+        actual = read_flo(file)
+
+    expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False))
+
+    torch.testing.assert_close(actual, expected)
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index 5c22521612c..c242207b7d7 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -1,12 +1,10 @@
 import abc
-import codecs
 import functools
 import io
 import operator
 import pathlib
 import string
-import sys
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO
 
 import torch
 from torchdata.datapipes.iter import (
@@ -30,6 +28,7 @@
     image_buffer_from_array,
     Decompressor,
     INFINITE_BUFFER_SIZE,
+    fromfile,
 )
 from torchvision.prototype.features import Image, Label
 
@@ -50,50 +49,33 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
     }
 
     def __init__(
-        self, datapipe: IterDataPipe[Tuple[Any, io.IOBase]], *, start: Optional[int], stop: Optional[int]
+        self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, start: Optional[int], stop: Optional[int]
     ) -> None:
         self.datapipe = datapipe
         self.start = start
         self.stop = stop
 
-    @staticmethod
-    def _decode(input: bytes) -> int:
-        return int(codecs.encode(input, "hex"), 16)
-
-    @staticmethod
-    def _to_tensor(chunk: bytes, *, dtype: torch.dtype, shape: List[int], reverse_bytes: bool) -> torch.Tensor:
-        # As is, the chunk is not writeable, because it is read from a file and not from memory. Thus, we copy here to
-        # avoid the warning that torch.frombuffer would emit otherwise. This also enables inplace operations on the
-        # contents, which would otherwise fail.
-        chunk = bytearray(chunk)
-        if reverse_bytes:
-            chunk.reverse()
-            tensor = torch.frombuffer(chunk, dtype=dtype).flip(0)
-        else:
-            tensor = torch.frombuffer(chunk, dtype=dtype)
-        return tensor.reshape(shape)
-
     def __iter__(self) -> Iterator[torch.Tensor]:
         for _, file in self.datapipe:
-            magic = self._decode(file.read(4))
+            read = functools.partial(fromfile, file, byte_order="big")
+
+            magic = int(read(dtype=torch.int32, count=1))
             dtype = self._DTYPE_MAP[magic // 256]
             ndim = magic % 256 - 1
-            num_samples = self._decode(file.read(4))
-            shape = [self._decode(file.read(4)) for _ in range(ndim)]
-
-            num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
-            # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
-            # we need to reverse the bytes before we can read them with torch.frombuffer().
-            reverse_bytes = sys.byteorder == "little" and num_bytes_per_value > 1
-            chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
+
+            num_samples = int(read(dtype=torch.int32, count=1))
+            shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
+            count = prod(shape) if shape else 1
 
             start = self.start or 0
             stop = min(self.stop, num_samples) if self.stop else num_samples
 
-            file.seek(start * chunk_size, 1)
+            if start:
+                num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
+                file.seek(num_bytes_per_value * count * start, 1)
+
             for _ in range(stop - start):
-                yield self._to_tensor(file.read(chunk_size), dtype=dtype, shape=shape, reverse_bytes=reverse_bytes)
+                yield read(dtype=dtype, count=count).reshape(shape)
 
 
 class _MNISTBase(Dataset):
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index a86447bfc2b..3db10183f68 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -3,10 +3,12 @@
 import gzip
 import io
 import lzma
+import mmap
 import os
 import os.path
 import pathlib
 import pickle
+from typing import BinaryIO
 from typing import (
     Sequence,
     Callable,
@@ -24,6 +26,7 @@
 
 import numpy as np
 import PIL.Image
+import torch
 import torch.distributed as dist
 import torch.utils.data
 from torch.utils.data import IterDataPipe
@@ -43,6 +46,8 @@
     "path_accessor",
     "path_comparator",
     "Decompressor",
+    "fromfile",
+    "read_flo",
 ]
 
 K = TypeVar("K")
@@ -253,3 +258,66 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
     #     dp = dp.cycle(2)
     dp = TakerDataPipe(dp, dataset_size)
     return dp
+
+
+def fromfile(
+    file: BinaryIO,
+    *,
+    dtype: torch.dtype,
+    byte_order: str,
+    count: int = -1,
+) -> torch.Tensor:
+    """Construct a tensor from a binary file.
+
+    .. note::
+
+        This function is similar to :func:`numpy.fromfile` with two notable differences:
+
+        1. This function only accepts an open binary file, but not a path to it.
+        2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``s do not support that
+           concept.
+
+    .. note::
+
+        If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
+        long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
+
+    Args:
+        file (IO): Open binary file.
+        dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
+        byte_order (str): Byte order of the data. Can be "little" or "big" endian.
+        count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
+    """
+    byte_order = "<" if byte_order == "little" else ">"
+    char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
+    item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
+    np_dtype = byte_order + char + str(item_size)
+
+    # PyTorch does not support tensors with underlying read-only memory. In case
+    # - the file has a .fileno(),
+    # - the file was opened for updating, i.e. 'r+b' or 'w+b',
+    # - the file is seekable
+    # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to
+    # a mutable location afterwards.
+    buffer: Union[memoryview, bytearray]
+    try:
+        buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
+        # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
+        file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
+    except (PermissionError, io.UnsupportedOperation):
+        # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
+        buffer = bytearray(file.read(-1 if count == -1 else count * item_size))
+
+    # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
+    # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
+    # successive .astype() call.
+    return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
+
+
+def read_flo(file: BinaryIO) -> torch.Tensor:
+    if file.read(4) != b"PIEH":
+        raise ValueError("Magic number incorrect. Invalid .flo file")
+
+    width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2)
+    flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
+    return flow.reshape((height, width, 2)).permute((2, 0, 1))
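For reference, here is a minimal usage sketch of the two helpers this diff introduces. The file names ("data.bin", "flow.flo") are hypothetical and only serve as illustration; the import path, the function signatures, and the "r+b" fast path follow the code above.

    import torch
    from torchvision.prototype.datasets.utils._internal import fromfile, read_flo

    # Read a complete file of big-endian float32 values into a tensor in native byte order.
    with open("data.bin", "rb") as file:
        values = fromfile(file, dtype=torch.float32, byte_order="big")

    # Read a Middlebury .flo optical flow file; the result has shape (2, height, width).
    # Opening in update mode ("r+b") lets fromfile() take the faster mmap-backed path.
    with open("flow.flo", "r+b") as file:
        flow = read_flo(file)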