Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f61a0b9
add FloReader datapipe
pmeier Nov 8, 2021
675eaa0
add NumericBinaryReader
pmeier Nov 8, 2021
e0157bc
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 8, 2021
05e934f
revert unrelated change
pmeier Nov 8, 2021
3a2d812
cleanup
pmeier Nov 8, 2021
2d7111d
cleanup
pmeier Nov 8, 2021
f984983
add comment for byte reversal
pmeier Nov 8, 2021
c4b46b7
use numpy after all
pmeier Nov 8, 2021
ba362a7
appease mypy
pmeier Nov 8, 2021
263b454
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 8, 2021
3bb9256
use .astype() with copy=False
pmeier Nov 9, 2021
5e029a9
add docstring and cleanuo
pmeier Nov 9, 2021
e9c5584
reuse current _read_flo and revert MNIST changes
pmeier Nov 10, 2021
fa4fafb
cleanup
pmeier Nov 10, 2021
61a71a1
revert demonstration
pmeier Nov 16, 2021
950dc49
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 16, 2021
68f2d95
refactor
pmeier Nov 16, 2021
a3823ba
cleanup
pmeier Nov 16, 2021
de865cf
add support for mutable memory
pmeier Nov 18, 2021
c3fd445
add test
pmeier Nov 18, 2021
7c3a33f
add comments
pmeier Nov 18, 2021
2c62670
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 18, 2021
aa780fd
catch more exceptions
pmeier Nov 18, 2021
e9031af
fix mypy
pmeier Nov 18, 2021
1d55fc0
fix variable names
pmeier Nov 18, 2021
5ebb5ae
hardcode flow sizes in test
pmeier Nov 18, 2021
ac3e4c2
add fix dtype docstring
pmeier Nov 18, 2021
507681a
expand comment on different reading modes
pmeier Nov 18, 2021
c52c547
add comment about files in update mode
pmeier Nov 18, 2021
80e8f25
add tests for fromfile
pmeier Nov 18, 2021
1979d17
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 18, 2021
2bb491b
cleanup
pmeier Nov 19, 2021
388ccb1
Merge branch 'main' into flow-reader-datapipe
pmeier Nov 19, 2021
0969cf9
cleanup
pmeier Nov 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions test/test_prototype_datasets_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import sys

import numpy as np
import pytest
import torch
from datasets_utils import make_fake_flo_file
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile


@pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning")
@pytest.mark.parametrize(
    ("np_dtype", "torch_dtype", "byte_order"),
    [
        (">f4", torch.float32, "big"),
        ("<f8", torch.float64, "little"),
        ("<i4", torch.int32, "little"),
        (">i8", torch.int64, "big"),
        ("|u1", torch.uint8, sys.byteorder),
    ],
)
@pytest.mark.parametrize("count", (-1, 2))
@pytest.mark.parametrize("mode", ("rb", "r+b"))
def test_fromfile(tmpdir, np_dtype, torch_dtype, byte_order, count, mode):
    """Check ``fromfile`` against ``numpy.fromfile`` for several dtypes, byte orders, counts and file modes.

    The "not writeable" warning emitted by ``torch.from_numpy`` is promoted to an error by the
    ``filterwarnings`` mark, so the test also fails if ``fromfile`` hands out read-only memory.
    """
    path = tmpdir / "data.bin"

    # Write one value more than requested (or five values when everything is read), so that a partial
    # read never coincides with the end of the file.
    num_values = 5 if count == -1 else count + 1
    values = np.random.RandomState(0).randn(num_values)
    values.astype(np_dtype).tofile(path)

    # Dropping the leading byte order character gives the dtype in native byte order.
    native_dtype = np_dtype[1:]

    for count_ in (-1, count // 2):
        reference = np.fromfile(path, dtype=np_dtype, count=count_)
        expected = torch.from_numpy(reference.astype(native_dtype))

        with open(path, mode) as file:
            actual = fromfile(file, dtype=torch_dtype, byte_order=byte_order, count=count_)

        torch.testing.assert_close(actual, expected)


def test_read_flo(tmpdir):
    """Check that ``read_flo`` agrees with the reference ``_read_flo`` implementation on a fake .flo file."""
    path = tmpdir / "test.flo"
    make_fake_flo_file(3, 4, path)

    # The reference implementation works on the path and returns a NumPy array, which is converted to
    # float32 (without copying if it already has that dtype) before wrapping it in a tensor.
    reference = read_flo_ref(path).astype("f4", copy=False)
    expected = torch.from_numpy(reference)

    with open(path, "rb") as file:
        actual = read_flo(file)

    torch.testing.assert_close(actual, expected)
46 changes: 14 additions & 32 deletions torchvision/prototype/datasets/_builtin/mnist.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import abc
import codecs
import functools
import io
import operator
import pathlib
import string
import sys
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO

import torch
from torchdata.datapipes.iter import (
Expand All @@ -30,6 +28,7 @@
image_buffer_from_array,
Decompressor,
INFINITE_BUFFER_SIZE,
fromfile,
)
from torchvision.prototype.features import Image, Label

Expand All @@ -50,50 +49,33 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
}

def __init__(
self, datapipe: IterDataPipe[Tuple[Any, io.IOBase]], *, start: Optional[int], stop: Optional[int]
self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, start: Optional[int], stop: Optional[int]
) -> None:
self.datapipe = datapipe
self.start = start
self.stop = stop

@staticmethod
def _decode(input: bytes) -> int:
return int(codecs.encode(input, "hex"), 16)

@staticmethod
def _to_tensor(chunk: bytes, *, dtype: torch.dtype, shape: List[int], reverse_bytes: bool) -> torch.Tensor:
# As is, the chunk is not writeable, because it is read from a file and not from memory. Thus, we copy here to
# avoid the warning that torch.frombuffer would emit otherwise. This also enables inplace operations on the
# contents, which would otherwise fail.
chunk = bytearray(chunk)
if reverse_bytes:
chunk.reverse()
tensor = torch.frombuffer(chunk, dtype=dtype).flip(0)
else:
tensor = torch.frombuffer(chunk, dtype=dtype)
return tensor.reshape(shape)

def __iter__(self) -> Iterator[torch.Tensor]:
for _, file in self.datapipe:
magic = self._decode(file.read(4))
read = functools.partial(fromfile, file, byte_order="big")

magic = int(read(dtype=torch.int32, count=1))
dtype = self._DTYPE_MAP[magic // 256]
ndim = magic % 256 - 1

num_samples = self._decode(file.read(4))
shape = [self._decode(file.read(4)) for _ in range(ndim)]

num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
reverse_bytes = sys.byteorder == "little" and num_bytes_per_value > 1
chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
num_samples = int(read(dtype=torch.int32, count=1))
shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
count = prod(shape) if shape else 1

start = self.start or 0
stop = min(self.stop, num_samples) if self.stop else num_samples

file.seek(start * chunk_size, 1)
if start:
num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
file.seek(num_bytes_per_value * count * start, 1)

for _ in range(stop - start):
yield self._to_tensor(file.read(chunk_size), dtype=dtype, shape=shape, reverse_bytes=reverse_bytes)
yield read(dtype=dtype, count=count).reshape(shape)


class _MNISTBase(Dataset):
Expand Down
68 changes: 68 additions & 0 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import gzip
import io
import lzma
import mmap
import os
import os.path
import pathlib
import pickle
from typing import BinaryIO
from typing import (
Sequence,
Callable,
Expand All @@ -24,6 +26,7 @@

import numpy as np
import PIL.Image
import torch
import torch.distributed as dist
import torch.utils.data
from torch.utils.data import IterDataPipe
Expand All @@ -43,6 +46,8 @@
"path_accessor",
"path_comparator",
"Decompressor",
"fromfile",
"read_flo",
]

K = TypeVar("K")
Expand Down Expand Up @@ -253,3 +258,66 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
# dp = dp.cycle(2)
dp = TakerDataPipe(dp, dataset_size)
return dp


def fromfile(
    file: BinaryIO,
    *,
    dtype: torch.dtype,
    byte_order: str,
    count: int = -1,
) -> torch.Tensor:
    """Construct a tensor from a binary file.

    .. note::

        This function is similar to :func:`numpy.fromfile` with two notable differences:

        1. This function only accepts an open binary file, but not a path to it.
        2. This function has an additional ``byte_order`` parameter, since PyTorch dtypes do not support that
           concept.

    .. note::

        If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
        long as the file is still open, inplace operations on the returned tensor will reflect back to the file.

    Args:
        file (BinaryIO): Open binary file.
        dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
        byte_order (str): Byte order of the data. Can be "little" or "big" endian.
        count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.

    Returns:
        torch.Tensor: One-dimensional tensor of ``dtype`` holding the values in the native byte order of the system.

    Raises:
        ValueError: If ``byte_order`` is neither "little" nor "big".
    """
    # Reject unknown values instead of silently interpreting them as big endian.
    if byte_order not in ("little", "big"):
        raise ValueError(f"byte_order can be 'little' or 'big', but got {byte_order!r}")

    order_char = "<" if byte_order == "little" else ">"
    kind_char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
    item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
    np_dtype = order_char + kind_char + str(item_size)

    # PyTorch does not support tensors with underlying read-only memory. In case
    # - the file has a .fileno(),
    # - the file was opened for updating, i.e. 'r+b' or 'w+b',
    # - the file is seekable
    # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to
    # a mutable location afterwards.
    buffer: Union[memoryview, bytearray]
    try:
        buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
        # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
        if count == -1:
            file.seek(0, io.SEEK_END)
        else:
            file.seek(count * item_size, io.SEEK_CUR)
    except (AttributeError, OSError, ValueError):
        # Memory mapping is only an optimization, so every way it can fail has to end up here:
        # - AttributeError: the file-like object has no .fileno() at all,
        # - OSError: covers PermissionError (file not opened for updating), io.UnsupportedOperation (no file
        #   descriptor or not seekable) and descriptors that cannot be mapped,
        # - ValueError: the file is empty, which mmap refuses to map.
        # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
        buffer = bytearray(file.read(-1 if count == -1 else count * item_size))

    # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
    # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
    # successive .astype() call.
    return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))


def read_flo(file: BinaryIO) -> torch.Tensor:
    """Read optical flow data from an open binary .flo file.

    The data is read in this order: the 4 byte magic number ``b"PIEH"``, two little endian int32 values (width and
    height), and ``height * width * 2`` little endian float32 values.

    Args:
        file (BinaryIO): Open binary file positioned at the magic number.

    Returns:
        torch.Tensor: float32 tensor of shape ``(2, height, width)``.

    Raises:
        ValueError: If the first four bytes are not ``b"PIEH"``.
    """
    if file.read(4) != b"PIEH":
        raise ValueError("Magic number incorrect. Invalid .flo file")

    # Convert the header to plain Python ints. Unpacking the tensor directly would yield 0-dim int32 tensors, so the
    # product below would be computed in int32 and could wrap around for large sizes.
    width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2).tolist()
    flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
    # The values are reshaped as (height, width, 2) and permuted so that the size-2 dimension comes first.
    return flow.reshape((height, width, 2)).permute((2, 0, 1))