diff --git a/mypy.ini b/mypy.ini index c2012102143..6d7863b627e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,6 +6,36 @@ pretty = True allow_redefinition = True warn_redundant_casts = True +[mypy-torchvision.prototype.features.*] + +; untyped definitions and calls +disallow_untyped_defs = True + +; None and Optional handling +no_implicit_optional = True + +; warnings +warn_unused_ignores = True +warn_return_any = True + +; miscellaneous strictness flags +allow_redefinition = True + +[mypy-torchvision.prototype.transforms.*] + +; untyped definitions and calls +disallow_untyped_defs = True + +; None and Optional handling +no_implicit_optional = True + +; warnings +warn_unused_ignores = True +warn_return_any = True + +; miscellaneous strictness flags +allow_redefinition = True + [mypy-torchvision.prototype.datasets.*] ; untyped definitions and calls diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index d79cd78d1ff..123d8f29d3f 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -432,50 +432,52 @@ def caltech256(info, root, config): @register_mock def imagenet(info, root, config): - wnids = tuple(info.extra.wnid_to_category.keys()) - if config.split == "train": - images_root = root / "ILSVRC2012_img_train" + from scipy.io import savemat + categories = info.categories + wnids = [info.extra.category_to_wnid[category] for category in categories] + if config.split == "train": num_samples = len(wnids) + archive_name = "ILSVRC2012_img_train.tar" + files = [] for wnid in wnids: - files = create_image_folder( - root=images_root, + create_image_folder( + root=root, name=wnid, file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG", num_examples=1, ) - make_tar(images_root, f"{wnid}.tar", files[0].parent) + files.append(make_tar(root, f"{wnid}.tar")) elif config.split == "val": num_samples = 3 - files = create_image_folder( - root=root, - name="ILSVRC2012_img_val", - file_name_fn=lambda image_idx: f"ILSVRC2012_val_{image_idx + 1:08d}.JPEG", - num_examples=num_samples, - ) - images_root = files[0].parent - else: # config.split == "test" - images_root = root / "ILSVRC2012_img_test_v10102019" + archive_name = "ILSVRC2012_img_val.tar" + files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)] - num_samples = 3 + devkit_root = root / "ILSVRC2012_devkit_t12" + data_root = devkit_root / "data" + data_root.mkdir(parents=True) - create_image_folder( - root=images_root, - name="test", - file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG", - num_examples=num_samples, - ) - make_tar(root, f"{images_root.name}.tar", images_root) + with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file: + for label in torch.randint(0, len(wnids), (num_samples,)).tolist(): + file.write(f"{label}\n") + + num_children = 0 + synsets = [ + (idx, wnid, category, "", num_children, [], 0, 0) + for idx, (category, wnid) in enumerate(zip(categories, wnids), 1) + ] + num_children = 1 + synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5)) + savemat(data_root / "meta.mat", dict(synsets=synsets)) + + make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz") + else: # config.split == "test" + num_samples = 5 + archive_name = "ILSVRC2012_img_test_v10102019.tar" + files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)] - devkit_root = root / "ILSVRC2012_devkit_t12" - devkit_root.mkdir() - data_root = devkit_root / "data" 
- data_root.mkdir() - with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file: - for label in torch.randint(0, len(wnids), (num_samples,)).tolist(): - file.write(f"{label}\n") - make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz") + make_tar(root, archive_name, *files) return num_samples @@ -667,14 +669,15 @@ def sbd(info, root, config): @register_mock def semeion(info, root, config): num_samples = 3 + num_categories = len(info.categories) images = torch.rand(num_samples, 256) - labels = one_hot(torch.randint(len(info.categories), size=(num_samples,))) + labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories) with open(root / "semeion.data", "w") as fh: for image, one_hot_label in zip(images, labels): image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image]) labels_columns = " ".join([str(label.item()) for label in one_hot_label]) - fh.write(f"{image_columns} {labels_columns}\n") + fh.write(f"{image_columns} {labels_columns} \n") return num_samples @@ -729,32 +732,33 @@ def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples): def _make_detection_ann_file(cls, root, name): def add_child(parent, name, text=None): child = ET.SubElement(parent, name) - child.text = text + child.text = str(text) return child def add_name(obj, name="dog"): add_child(obj, "name", name) - return name - def add_bndbox(obj, bndbox=None): - if bndbox is None: - bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"} + def add_size(obj): + obj = add_child(obj, "size") + size = {"width": 0, "height": 0, "depth": 3} + for name, text in size.items(): + add_child(obj, name, text) + def add_bndbox(obj): obj = add_child(obj, "bndbox") + bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4} for name, text in bndbox.items(): add_child(obj, name, text) - return bndbox - annotation = ET.Element("annotation") + add_size(annotation) obj = add_child(annotation, "object") - data = dict(name=add_name(obj), bndbox=add_bndbox(obj)) + add_name(obj) + add_bndbox(obj) with open(root / name, "wb") as fh: fh.write(ET.tostring(annotation)) - return data - @classmethod def generate(cls, root, *, year, trainval): archive_folder = root diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index d8e07314e00..067359cac2b 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -1,9 +1,11 @@ +import functools import io from pathlib import Path import pytest import torch from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS +from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.graph import traverse from torchdata.datapipes.iter import IterDataPipe, Shuffler @@ -11,6 +13,11 @@ from torchvision.prototype.utils._internal import sequence_to_str +assert_samples_equal = functools.partial( + assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True +) + + @pytest.fixture def test_home(mocker, tmp_path): mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path)) @@ -92,6 +99,7 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config): f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors." 
) + @pytest.mark.xfail @parametrize_dataset_mocks(DATASET_MOCKS) def test_transformable(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) @@ -137,6 +145,17 @@ def scan(graph): if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))): raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_save_load(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + sample = next(iter(dataset)) + + with io.BytesIO() as buffer: + torch.save(sample, buffer) + buffer.seek(0) + assert_samples_equal(torch.load(buffer), sample) + @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) class TestQMNIST: @@ -171,5 +190,5 @@ def test_label_matches_path(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) for sample in dataset: - label_from_path = int(Path(sample["image_path"]).parent.name) + label_from_path = int(Path(sample["path"]).parent.name) assert sample["label"] == label_from_path diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py index 3a6126f4990..70a2707d050 100644 --- a/test/test_prototype_datasets_api.py +++ b/test/test_prototype_datasets_api.py @@ -5,8 +5,8 @@ from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch -def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs): - return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs) +def make_minimal_dataset_info(name="name", categories=None, **kwargs): + return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs) class TestFrozenMapping: @@ -176,7 +176,7 @@ def resources(self, config): # This method is just defined to appease the ABC, but will be overwritten at instantiation pass - def _make_datapipe(self, resource_dps, *, config, decoder): + def _make_datapipe(self, resource_dps, *, config): # This method is just defined to appease the ABC, but will be overwritten at instantiation pass @@ -229,12 +229,3 @@ def test_resources(self, mocker): (call_args, _) = dataset._make_datapipe.call_args assert call_args[0][0] is sentinel - - def test_decoder(self): - dataset = self.DatasetMock() - - sentinel = object() - dataset.load("", decoder=sentinel) - - (_, call_kwargs) = dataset._make_datapipe.call_args - assert call_kwargs["decoder"] is sentinel diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py deleted file mode 100644 index 147243286d4..00000000000 --- a/test/test_prototype_features.py +++ /dev/null @@ -1,185 +0,0 @@ -import functools -import itertools - -import pytest -import torch -from torch.testing import make_tensor as _make_tensor, assert_close -from torchvision.prototype import features -from torchvision.prototype.utils._internal import sequence_to_str - - -make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32) - - -def make_image(**kwargs): - data = make_tensor((3, *torch.randint(16, 33, (2,)).tolist())) - return features.Image(data, **kwargs) - - -def make_bounding_box(*, format="xyxy", image_size=(10, 10)): - if isinstance(format, str): - format = features.BoundingBoxFormat[format] - - height, width = image_size - - if format == features.BoundingBoxFormat.XYXY: - x1 = torch.randint(0, width // 2, ()) - y1 = torch.randint(0, height // 2, ()) - x2 = torch.randint(int(x1) + 1, 
width - int(x1), ()) + x1 - y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1 - parts = (x1, y1, x2, y2) - elif format == features.BoundingBoxFormat.XYWH: - x = torch.randint(0, width // 2, ()) - y = torch.randint(0, height // 2, ()) - w = torch.randint(1, width - int(x), ()) - h = torch.randint(1, height - int(y), ()) - parts = (x, y, w, h) - elif format == features.BoundingBoxFormat.CXCYWH: - cx = torch.randint(1, width - 1, ()) - cy = torch.randint(1, height - 1, ()) - w = torch.randint(1, min(int(cx), width - int(cx)), ()) - h = torch.randint(1, min(int(cy), height - int(cy)), ()) - parts = (cx, cy, w, h) - else: # format == features.BoundingBoxFormat._SENTINEL: - parts = make_tensor((4,)).unbind() - - return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size) - - -MAKE_DATA_MAP = { - features.Image: make_image, - features.BoundingBox: make_bounding_box, -} - - -def make_feature(feature_type, **meta_data): - maker = MAKE_DATA_MAP.get(feature_type, lambda **meta_data: feature_type(make_tensor(()), **meta_data)) - return maker(**meta_data) - - -class TestCommon: - FEATURE_TYPES, NON_DEFAULT_META_DATA = zip( - *( - (features.Image, dict(color_space=features.ColorSpace._SENTINEL)), - (features.Label, dict(category="category")), - (features.BoundingBox, dict(format=features.BoundingBoxFormat._SENTINEL, image_size=(-1, -1))), - ) - ) - feature_types = pytest.mark.parametrize( - "feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__ - ) - features = pytest.mark.parametrize( - "feature", - [ - pytest.param(make_feature(feature_type, **meta_data), id=feature_type.__name__) - for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA) - ], - ) - - def test_consistency(self): - builtin_feature_types = { - name - for name, feature_type in features.__dict__.items() - if not name.startswith("_") - and isinstance(feature_type, type) - and issubclass(feature_type, features.Feature) - and feature_type is not features.Feature - } - untested_feature_types = builtin_feature_types - {feature_type.__name__ for feature_type in self.FEATURE_TYPES} - if untested_feature_types: - raise AssertionError( - f"The feature(s) {sequence_to_str(sorted(untested_feature_types), separate_last='and ')} " - f"is/are exposed at `torchvision.prototype.features`, but is/are not tested by `TestCommon`. " - f"Please add it/them to `TestCommon.FEATURE_TYPES`." 
- ) - - @features - def test_meta_data_attribute_access(self, feature): - for name, value in feature._meta_data.items(): - assert getattr(feature, name) == feature._meta_data[name] - - @feature_types - def test_torch_function(self, feature_type): - input = make_feature(feature_type) - # This can be any Tensor operation besides clone - output = input + 1 - - assert type(output) is torch.Tensor - assert_close(output, input + 1) - - @feature_types - def test_clone(self, feature_type): - input = make_feature(feature_type) - output = input.clone() - - assert type(output) is feature_type - assert_close(output, input) - assert output._meta_data == input._meta_data - - @features - def test_serialization(self, tmpdir, feature): - file = tmpdir / "test_serialization.pt" - - torch.save(feature, str(file)) - loaded_feature = torch.load(str(file)) - - assert isinstance(loaded_feature, type(feature)) - assert_close(loaded_feature, feature) - assert loaded_feature._meta_data == feature._meta_data - - @features - def test_repr(self, feature): - assert type(feature).__name__ in repr(feature) - - -class TestBoundingBox: - @pytest.mark.parametrize(("format", "intermediate_format"), itertools.permutations(("xyxy", "xywh"), 2)) - def test_cycle_consistency(self, format, intermediate_format): - input = make_bounding_box(format=format) - output = input.convert(intermediate_format).convert(format) - assert_close(input, output) - - -# For now, tensor subclasses with additional meta data do not work with torchscript. -# See https://github.com/pytorch/vision/pull/4721#discussion_r741676037. -@pytest.mark.xfail -class TestJit: - def test_bounding_box(self): - def resize(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox: - old_height, old_width = input.image_size - new_height, new_width = size - - height_scale = new_height / old_height - width_scale = new_width / old_width - - old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts() - - new_x1 = old_x1 * width_scale - new_y1 = old_y1 * height_scale - - new_x2 = old_x2 * width_scale - new_y2 = old_y2 * height_scale - - return features.BoundingBox.from_parts( - new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=tuple(size.tolist()) - ) - - def horizontal_flip(input: features.BoundingBox) -> features.BoundingBox: - x, y, w, h = input.convert("xywh").to_parts() - x = input.image_size[1] - (x + w) - return features.BoundingBox.from_parts(x, y, w, h, like=input, format="xywh") - - def compose(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox: - return horizontal_flip(resize(input, size)).convert("xyxy") - - image_size = (8, 6) - input = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=image_size) - size = torch.tensor((4, 12)) - expected = features.BoundingBox([6, 1, 10, 3], format="xyxy", image_size=image_size) - - actual_eager = compose(input, size) - assert_close(actual_eager, expected) - - sample_inputs = (features.BoundingBox(torch.zeros((4,)), image_size=(10, 10)), torch.tensor((20, 5))) - actual_jit = torch.jit.trace(compose, sample_inputs, check_trace=False)(input, size) - assert_close(actual_jit, expected) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py deleted file mode 100644 index 2bcd6692e81..00000000000 --- a/test/test_prototype_transforms.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest -from torchvision.prototype import transforms, features -from torchvision.prototype.utils._internal import sequence_to_str - - -FEATURE_TYPES = { - 
feature_type - for name, feature_type in features.__dict__.items() - if not name.startswith("_") - and isinstance(feature_type, type) - and issubclass(feature_type, features.Feature) - and feature_type is not features.Feature -} - -TRANSFORM_TYPES = tuple( - transform_type - for name, transform_type in transforms.__dict__.items() - if not name.startswith("_") - and isinstance(transform_type, type) - and issubclass(transform_type, transforms.Transform) - and transform_type is not transforms.Transform -) - - -def test_feature_type_support(): - missing_feature_types = FEATURE_TYPES - set(transforms.Transform._BUILTIN_FEATURE_TYPES) - if missing_feature_types: - names = sorted([feature_type.__name__ for feature_type in missing_feature_types]) - raise AssertionError( - f"The feature(s) {sequence_to_str(names, separate_last='and ')} is/are exposed at " - f"`torchvision.prototype.features`, but are missing in Transform._BUILTIN_FEATURE_TYPES. " - f"Please add it/them to the collection." - ) - - -@pytest.mark.parametrize( - "transform_type", - [transform_type for transform_type in TRANSFORM_TYPES if transform_type is not transforms.Identity], - ids=lambda transform_type: transform_type.__name__, -) -def test_feature_no_op_coverage(transform_type): - unsupported_features = ( - FEATURE_TYPES - transform_type.supported_feature_types() - set(transform_type.NO_OP_FEATURE_TYPES) - ) - if unsupported_features: - names = sorted([feature_type.__name__ for feature_type in unsupported_features]) - raise AssertionError( - f"The feature(s) {sequence_to_str(names, separate_last='and ')} are neither supported nor declared as " - f"no-op for transform `{transform_type.__name__}`. Please either implement a feature transform for them, " - f"or add them to the the `{transform_type.__name__}.NO_OP_FEATURE_TYPES` collection." 
- ) - - -def test_non_feature_no_op(): - class TestTransform(transforms.Transform): - @staticmethod - def image(input): - return input - - no_op_sample = dict(int=0, float=0.0, bool=False, str="str") - assert TestTransform()(no_op_sample) == no_op_sample diff --git a/test/test_prototype_transforms_kernels.py b/test/test_prototype_transforms_kernels.py new file mode 100644 index 00000000000..b83febd8915 --- /dev/null +++ b/test/test_prototype_transforms_kernels.py @@ -0,0 +1,197 @@ +import functools +import itertools + +import pytest +import torch.testing +import torchvision.prototype.transforms.kernels as K +from torch import jit +from torchvision.prototype import features + +make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") + + +def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32): + size = size or torch.randint(16, 33, (2,)).tolist() + + if isinstance(color_space, str): + color_space = features.ColorSpace[color_space] + num_channels = { + features.ColorSpace.GRAYSCALE: 1, + features.ColorSpace.RGB: 3, + }[color_space] + + shape = (*extra_dims, num_channels, *size) + if dtype.is_floating_point: + data = torch.rand(shape, dtype=dtype) + else: + data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype) + return features.Image(data, color_space=color_space) + + +make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE) +make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB) + + +def make_images( + sizes=((16, 16), (7, 33), (31, 9)), + color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB), + dtypes=(torch.float32, torch.uint8), + extra_dims=((4,), (2, 3)), +): + for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes): + yield make_image(size, color_space=color_space) + + for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims): + yield make_image(color_space=color_space, extra_dims=extra_dims_) + + +def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): + low, high = torch.broadcast_tensors( + *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))] + ) + try: + return torch.stack( + [ + torch.randint(low_scalar, high_scalar, (), **kwargs) + for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist()) + ] + ).reshape(low.shape) + except RuntimeError as error: + raise error + + +def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64): + if isinstance(format, str): + format = features.BoundingBoxFormat[format] + + height, width = image_size + + if format == features.BoundingBoxFormat.XYXY: + x1 = torch.randint(0, width // 2, extra_dims) + y1 = torch.randint(0, height // 2, extra_dims) + x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1 + y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1 + parts = (x1, y1, x2, y2) + elif format == features.BoundingBoxFormat.XYWH: + x = torch.randint(0, width // 2, extra_dims) + y = torch.randint(0, height // 2, extra_dims) + w = randint_with_tensor_bounds(1, width - x) + h = randint_with_tensor_bounds(1, height - y) + parts = (x, y, w, h) + elif format == features.BoundingBoxFormat.CXCYWH: + cx = torch.randint(1, width - 1, ()) + cy = torch.randint(1, height - 1, ()) + w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1) + h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1) + parts = (cx, cy, w, h) + else: # format == 
features.BoundingBoxFormat._SENTINEL: + raise ValueError() + + return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size) + + +make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY) + + +def make_bounding_boxes( + formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH), + image_sizes=((32, 32),), + dtypes=(torch.int64, torch.float32), + extra_dims=((4,), (2, 3)), +): + for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes): + yield make_bounding_box(format=format, image_size=image_size, dtype=dtype) + + for format, extra_dims_ in itertools.product(formats, extra_dims): + yield make_bounding_box(format=format, extra_dims=extra_dims_) + + +class SampleInput: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +class KernelInfo: + def __init__(self, name, *, sample_inputs_fn): + self.name = name + self.kernel = getattr(K, name) + self._sample_inputs_fn = sample_inputs_fn + + def sample_inputs(self): + yield from self._sample_inputs_fn() + + def __call__(self, *args, **kwargs): + if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput): + sample_input = args[0] + return self.kernel(*sample_input.args, **sample_input.kwargs) + + return self.kernel(*args, **kwargs) + + +KERNEL_INFOS = [] + + +def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn): + KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn)) + return sample_inputs_fn + + +@register_kernel_info_from_sample_inputs_fn +def horizontal_flip_image(): + for image in make_images(): + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def horizontal_flip_bounding_box(): + for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): + yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) + + +@register_kernel_info_from_sample_inputs_fn +def resize_image(): + for image, interpolation in itertools.product( + make_images(), + [ + K.InterpolationMode.BILINEAR, + K.InterpolationMode.NEAREST, + ], + ): + height, width = image.shape[-2:] + for size in [ + (height, width), + (int(height * 0.75), int(width * 1.25)), + ]: + yield SampleInput(image, size=size, interpolation=interpolation) + + +@register_kernel_info_from_sample_inputs_fn +def resize_bounding_box(): + for bounding_box in make_bounding_boxes(): + height, width = bounding_box.image_size + for size in [ + (height, width), + (int(height * 0.75), int(width * 1.25)), + ]: + yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) + + +class TestKernelsCommon: + @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name) + def test_scriptable(self, kernel_info): + jit.script(kernel_info.kernel) + + @pytest.mark.parametrize( + ("kernel_info", "sample_input"), + [ + pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}") + for kernel_info in KERNEL_INFOS + for idx, sample_input in enumerate(kernel_info.sample_inputs()) + ], + ) + def test_eager_vs_scripted(self, kernel_info, sample_input): + eager = kernel_info(sample_input) + scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs) + + torch.testing.assert_close(eager, scripted) diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index 
270c75ba330..bf99e175d36 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -7,9 +7,9 @@ "Note that you cannot install it with `pip install torchdata`, since this is another package." ) from error -from . import decoder, utils +from . import utils from ._home import home # Load this last, since some parts depend on the above being loaded first -from ._api import register, list_datasets, info, load # usort: skip +from ._api import list_datasets, info, load # usort: skip from ._folder import from_data_folder, from_image_folder diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index f3c398d5552..13ee920cea2 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -1,12 +1,9 @@ -import io import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List -import torch from torch.utils.data import IterDataPipe from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.decoder import raw, pil -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo from torchvision.prototype.utils._internal import add_suggestion from . import _builtin @@ -48,27 +45,15 @@ def info(name: str) -> DatasetInfo: return find(name).info -DEFAULT_DECODER = object() - -DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = { - DatasetType.RAW: raw, - DatasetType.IMAGE: pil, -} - - def load( name: str, *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER, # type: ignore[assignment] skip_integrity_check: bool = False, **options: Any, ) -> IterDataPipe[Dict[str, Any]]: dataset = find(name) - if decoder is DEFAULT_DECODER: - decoder = DEFAULT_DECODER_MAP.get(dataset.info.type) - config = dataset.info.make_config(**options) root = os.path.join(home(), dataset.name) - return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check) + return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check) diff --git a/torchvision/prototype/datasets/_builtin/README.md b/torchvision/prototype/datasets/_builtin/README.md index c5a27d5d1c9..fbe84856aeb 100644 --- a/torchvision/prototype/datasets/_builtin/README.md +++ b/torchvision/prototype/datasets/_builtin/README.md @@ -19,10 +19,8 @@ that module create a class that inherits from `datasets.utils.Dataset` and overwrites at minimum three methods that will be discussed in detail below: ```python -import io -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List -import torch from torchdata.datapipes.iter import IterDataPipe from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource @@ -34,11 +32,7 @@ class MyDataset(Dataset): ... def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, ) -> IterDataPipe[Dict[str, Any]]: ... ``` @@ -49,10 +43,6 @@ The `DatasetInfo` carries static information about the dataset. There are two required fields: - `name`: Name of the dataset. This will be used to load the dataset with `datasets.load(name)`. Should only contain lowercase characters. -- `type`: Field of the `datasets.utils.DatasetType` enum. 
This is used to select - the default decoder in case the user doesn't pass one. There are currently - only two options: `IMAGE` and `RAW` ([see - below](#what-is-the-datasettyperaw-and-when-do-i-use-it) for details). There are more optional parameters that can be passed: @@ -105,7 +95,7 @@ def sha256sum(path, chunk_size=1024 * 1024): print(checksum.hexdigest()) ``` -### `_make_datapipe(resource_dps, *, config, decoder)` +### `_make_datapipe(resource_dps, *, config)` This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared to the current stable datasets is @@ -178,28 +168,6 @@ contains. You can also do that with `resources_dp[1]` or `resources_dp[2]` (etc.) if they exist. Then follow the instructions above to manipulate these datapipes and return the appropriate dictionary format. -### What is the `DatasetType.RAW` and when do I use it? - -`DatasetType.RAW` marks dataset that provides decoded, i.e. raw pixel values, -rather than encoded image files such as `.jpg` or `.png`. This is usually only -the case for small datasets, since it requires a lot more disk space. The -default decoder `datasets.decoder.raw` is only a sentinel and should not be -called directly. The decoding should look something like - -```python -from torchvision.prototype.datasets.decoder import raw - -image = ... - -if decoder is raw: - image = Image(image) -else: - image_buffer = image_buffer_from_raw(image) - image = decoder(image_buffer) if decoder else image_buffer -``` - -For examples, have a look at the MNIST, CIFAR, or SEMEION datasets. - ### How do I handle a dataset that defines many categories? As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index be19b7c240f..1a052860ebf 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -1,11 +1,8 @@ -import functools -import io import pathlib import re -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple, BinaryIO import numpy as np -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, @@ -18,17 +15,15 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling -from torchvision.prototype.features import Label, BoundingBox, Feature +from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage class Caltech101(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "caltech101", - type=DatasetType.IMAGE, dependencies=("scipy",), homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101", ) @@ -81,33 +76,26 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]: return category, id - def _collate_and_decode_sample( - self, - data: Tuple[Tuple[str, str], Tuple[Tuple[str, io.IOBase], Tuple[str, io.IOBase]]], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + def _prepare_sample( + self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]] ) -> Dict[str, Any]: key, (image_data, ann_data) = data category, _ = key image_path, image_buffer = image_data ann_path, ann_buffer = ann_data - label = self.info.categories.index(category) - - image = decoder(image_buffer) if decoder else image_buffer - + image = 
EncodedImage.from_file(image_buffer) ann = read_mat(ann_buffer) - bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy") - contour = Feature(ann["obj_contour"].T) return dict( - category=category, - label=label, - image=image, + label=Label.from_category(category, categories=self.categories), image_path=image_path, - bbox=bbox, - contour=contour, + image=image, ann_path=ann_path, + bounding_box=BoundingBox( + ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size + ), + contour=_Feature(ann["obj_contour"].T), ) def _make_datapipe( @@ -115,7 +103,6 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: images_dp, anns_dp = resource_dps @@ -133,7 +120,7 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, keep_key=True, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) def _generate_categories(self, root: pathlib.Path) -> List[str]: resources = self.resources(self.default_config) @@ -148,7 +135,6 @@ class Caltech256(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "caltech256", - type=DatasetType.IMAGE, homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256", ) @@ -164,32 +150,26 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool: path = pathlib.Path(data[0]) return path.name != "RENAME2" - def _collate_and_decode_sample( - self, - data: Tuple[str, io.IOBase], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: path, buffer = data - dir_name = pathlib.Path(path).parent.name - label_str, category = dir_name.split(".") - label = Label(int(label_str), category=category) - - return dict(label=label, image=decoder(buffer) if decoder else buffer) + return dict( + path=path, + image=EncodedImage.from_file(buffer), + label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self.categories), + ) def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Filter(dp, self._is_not_rogue_file) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) def _generate_categories(self, root: pathlib.Path) -> List[str]: resources = self.resources(self.default_config) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index b59959b49f1..191f49e9d53 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,9 +1,7 @@ import csv import functools -import io -from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence +from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, @@ -17,7 +15,6 @@ DatasetInfo, GDriveResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -26,7 +23,8 @@ hint_sharding, hint_shuffling, ) -from torchvision.prototype.features import Feature, Label, 
BoundingBox +from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox + csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) @@ -34,7 +32,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]): def __init__( self, - datapipe: IterDataPipe[Tuple[Any, io.IOBase]], + datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, fieldnames: Optional[Sequence[str]] = None, ) -> None: @@ -66,7 +64,6 @@ class CelebA(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "celeba", - type=DatasetType.IMAGE, homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html", valid_options=dict(split=("train", "val", "test")), ) @@ -92,7 +89,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0", file_name="list_attr_celeba.txt", ) - bboxes = GDriveResource( + bounding_boxes = GDriveResource( "0B7EVK8r0v71pbThiMVRxWXZ4dU0", sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b", file_name="list_bbox_celeba.txt", @@ -102,7 +99,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b", file_name="list_landmarks_align_celeba.txt", ) - return [splits, images, identities, attributes, bboxes, landmarks] + return [splits, images, identities, attributes, bounding_boxes, landmarks] _SPLIT_ID_TO_NAME = { "0": "train", @@ -113,38 +110,39 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool: return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split - def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]: - (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data - return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks) - - def _collate_and_decode_sample( + def _prepare_sample( self, - data: Tuple[Tuple[str, Tuple[Tuple[str, Dict[str, Any]], Tuple[str, io.IOBase]]], Tuple[str, Dict[str, Any]]], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + data: Tuple[ + Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]], + Tuple[ + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + Tuple[str, Dict[str, str]], + ], + ], ) -> Dict[str, Any]: split_and_image_data, ann_data = data _, (_, image_data) = split_and_image_data path, buffer = image_data - _, ann = ann_data - - image = decoder(buffer) if decoder else buffer - identity = Label(int(ann["identity"]["identity"])) - attributes = {attr: value == "1" for attr, value in ann["attributes"].items()} - bbox = BoundingBox([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")]) - landmarks = { - landmark: Feature((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"]))) - for landmark in {key[:-2] for key in ann["landmarks"].keys()} - } + image = EncodedImage.from_file(buffer) + (_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data return dict( path=path, image=image, - identity=identity, - attributes=attributes, - bbox=bbox, - landmarks=landmarks, + identity=Label(int(identity["identity"])), + attributes={attr: value == "1" for attr, value in attributes.items()}, + bounding_box=BoundingBox( + [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")], + format="xywh", + 
image_size=image.image_size, + ), + landmarks={ + landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) + for landmark in {key[:-2] for key in landmarks.keys()} + }, ) def _make_datapipe( @@ -152,9 +150,8 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: - splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps + splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split)) @@ -167,12 +164,11 @@ def _make_datapipe( for dp, fieldnames in ( (identities_dp, ("image_id", "identity")), (attributes_dp, None), - (bboxes_dp, None), + (bounding_boxes_dp, None), (landmarks_dp, None), ) ] ) - anns_dp = Mapper(anns_dp, self._collate_anns) dp = IterKeyZipper( splits_dp, @@ -182,5 +178,11 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, keep_key=True, ) - dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + dp = IterKeyZipper( + dp, + anns_dp, + key_fn=getitem(0), + ref_key_fn=getitem(0, 0), + buffer_size=INFINITE_BUFFER_SIZE, + ) + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 6ac2de3c9e6..f15ed9e9782 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -3,34 +3,28 @@ import io import pathlib import pickle -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast +from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO import numpy as np -import torch from torchdata.datapipes.iter import ( IterDataPipe, Filter, Mapper, ) -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, - image_buffer_from_array, path_comparator, hint_sharding, ) from torchvision.prototype.features import Label, Image -__all__ = ["Cifar10", "Cifar100"] - class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]): def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None: @@ -52,13 +46,12 @@ class _CifarBase(Dataset): _CATEGORIES_KEY: str @abc.abstractmethod - def _is_data_file(self, data: Tuple[str, io.IOBase], *, split: str) -> Optional[int]: + def _is_data_file(self, data: Tuple[str, BinaryIO], *, split: str) -> Optional[int]: pass def _make_info(self) -> DatasetInfo: return DatasetInfo( type(self).__name__.lower(), - type=DatasetType.RAW, homepage="https://www.cs.toronto.edu/~kriz/cifar.html", valid_options=dict(split=("train", "test")), ) @@ -75,31 +68,18 @@ def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]: _, file = data return cast(Dict[str, Any], pickle.load(file, encoding="latin1")) - def _collate_and_decode( - self, - data: Tuple[np.ndarray, int], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]: image_array, 
category_idx = data - - image: Union[Image, io.BytesIO] - if decoder is raw: - image = Image(image_array) - else: - image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0))) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - - label = Label(category_idx, category=self.categories[category_idx]) - - return dict(image=image, label=label) + return dict( + image=Image(image_array), + label=Label(category_idx, categories=self.categories), + ) def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Filter(dp, functools.partial(self._is_data_file, split=config.split)) @@ -107,7 +87,7 @@ def _make_datapipe( dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder)) + return Mapper(dp, self._prepare_sample) def _generate_categories(self, root: pathlib.Path) -> List[str]: resources = self.resources(self.default_config) diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py index 447c1b5190d..af5a49a9822 100644 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -1,9 +1,6 @@ -import functools -import io import pathlib -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, BinaryIO -import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher from torchvision.prototype.datasets.utils import ( Dataset, @@ -11,7 +8,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -21,14 +17,13 @@ path_accessor, getitem, ) -from torchvision.prototype.features import Label +from torchvision.prototype.features import Label, EncodedImage class CLEVR(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "clevr", - type=DatasetType.IMAGE, homepage="https://cs.stanford.edu/people/jcjohns/clevr/", valid_options=dict(split=("train", "val", "test")), ) @@ -53,21 +48,16 @@ def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool: key, _ = data return key == "scenes" - def _add_empty_anns(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[str, io.IOBase], None]: + def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]: return data, None - def _collate_and_decode_sample( - self, - data: Tuple[Tuple[str, io.IOBase], Optional[Dict[str, Any]]], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]: image_data, scenes_data = data path, buffer = image_data return dict( path=path, - image=decoder(buffer) if decoder else buffer, + image=EncodedImage.from_file(buffer), label=Label(len(scenes_data["objects"])) if scenes_data else None, ) @@ -76,7 +66,6 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, scenes_dp = Demultiplexer( @@ -107,4 +96,4 @@ def _make_datapipe( else: dp = Mapper(images_dp, self._add_empty_anns) - return 
Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 6fde966402c..74232eb714d 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -1,9 +1,8 @@ import functools -import io import pathlib import re from collections import OrderedDict -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO import torch from torchdata.datapipes.iter import ( @@ -22,7 +21,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( MappingIterator, @@ -33,7 +31,7 @@ hint_sharding, hint_shuffling, ) -from torchvision.prototype.features import BoundingBox, Label, Feature +from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage from torchvision.prototype.utils._internal import FrozenMapping @@ -44,7 +42,6 @@ def _make_info(self) -> DatasetInfo: return DatasetInfo( name, - type=DatasetType.IMAGE, dependencies=("pycocotools",), categories=categories, homepage="https://cocodataset.org/", @@ -96,10 +93,9 @@ def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, image_size def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: image_size = (image_meta["height"], image_meta["width"]) labels = [ann["category_id"] for ann in anns] - categories = [self.info.categories[label] for label in labels] return dict( # TODO: create a segmentation feature - segmentations=Feature( + segmentations=_Feature( torch.stack( [ self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size) @@ -107,16 +103,17 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st ] ) ), - areas=Feature([ann["area"] for ann in anns]), - crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool), + areas=_Feature([ann["area"] for ann in anns]), + crowds=_Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool), bounding_boxes=BoundingBox( [ann["bbox"] for ann in anns], format="xywh", image_size=image_size, ), - labels=Label(labels), - categories=categories, - super_categories=[self.info.extra.category_to_super_category[category] for category in categories], + labels=Label(labels, categories=self.categories), + super_categories=[ + self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels + ], ann_ids=[ann["id"] for ann in anns], ) @@ -150,26 +147,24 @@ def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: else: return None - def _collate_and_decode_image( - self, data: Tuple[str, io.IOBase], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] - ) -> Dict[str, Any]: + def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: path, buffer = data - return dict(path=path, image=decoder(buffer) if decoder else buffer) + return dict( + path=path, + image=EncodedImage.from_file(buffer), + ) - def _collate_and_decode_sample( + def _prepare_sample( self, - data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, io.IOBase]], + data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]], *, - annotations: Optional[str], - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + annotations: str, ) -> Dict[str, Any]: 
ann_data, image_data = data anns, image_meta = ann_data - sample = self._collate_and_decode_image(image_data, decoder=decoder) - if annotations: - sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) - + sample = self._prepare_image(image_data) + sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) return sample def _make_datapipe( @@ -177,14 +172,13 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: images_dp, meta_dp = resource_dps if config.annotations is None: dp = hint_sharding(images_dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder)) + return Mapper(dp, self._prepare_image) meta_dp = Filter( meta_dp, @@ -230,9 +224,8 @@ def _make_datapipe( ref_key_fn=path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper( - dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder) - ) + + return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations)) def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: config = self.default_config diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index facd909f468..ae34b48d191 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -1,10 +1,8 @@ import csv import functools -import io import pathlib -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, @@ -21,7 +19,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -32,7 +29,7 @@ path_comparator, path_accessor, ) -from torchvision.prototype.features import Label, BoundingBox, Feature +from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage csv.register_dialect("cub200", delimiter=" ") @@ -41,7 +38,6 @@ class CUB200(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "cub200", - type=DatasetType.IMAGE, homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html", dependencies=("scipy",), valid_options=dict( @@ -105,58 +101,55 @@ def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str: path = pathlib.Path(data[0]) return path.with_suffix(".jpg").name - def _2011_load_ann( - self, - data: Tuple[str, Tuple[List[str], Tuple[str, io.IOBase]]], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + def _2011_prepare_ann( + self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], image_size: Tuple[int, int] ) -> Dict[str, Any]: _, (bounding_box_data, segmentation_data) = data segmentation_path, segmentation_buffer = segmentation_data return dict( - bounding_box=BoundingBox([float(part) for part in bounding_box_data[1:]], format="xywh"), + bounding_box=BoundingBox( + [float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size + ), segmentation_path=segmentation_path, - segmentation=Feature(decoder(segmentation_buffer)) if decoder else segmentation_buffer, + segmentation=EncodedImage.from_file(segmentation_buffer), ) def _2010_split_key(self, data: str) -> str: return data.rsplit("/", maxsplit=1)[1] - def 
_2010_anns_key(self, data: Tuple[str, io.IOBase]) -> Tuple[str, Tuple[str, io.IOBase]]: + def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, BinaryIO]]: path = pathlib.Path(data[0]) return path.with_suffix(".jpg").name, data - def _2010_load_ann( - self, data: Tuple[str, Tuple[str, io.IOBase]], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] - ) -> Dict[str, Any]: + def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]: _, (path, buffer) = data content = read_mat(buffer) return dict( ann_path=path, bounding_box=BoundingBox( - [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], format="xyxy" + [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], + format="xyxy", + image_size=image_size, ), - segmentation=Feature(content["seg"]), + segmentation=_Feature(content["seg"]), ) - def _collate_and_decode_sample( + def _prepare_sample( self, - data: Tuple[Tuple[str, Tuple[str, io.IOBase]], Any], + data: Tuple[Tuple[str, Tuple[str, BinaryIO]], Any], *, - year: str, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + prepare_ann_fn: Callable[[Any, Tuple[int, int]], Dict[str, Any]], ) -> Dict[str, Any]: data, anns_data = data _, image_data = data path, buffer = image_data - dir_name = pathlib.Path(path).parent.name - label_str, category = dir_name.split(".") + image = EncodedImage.from_file(buffer) return dict( - (self._2011_load_ann if year == "2011" else self._2010_load_ann)(anns_data, decoder=decoder), - image=decoder(buffer) if decoder else buffer, - label=Label(int(label_str), category=category), + prepare_ann_fn(anns_data, image.image_size), + image=image, + label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories), ) def _make_datapipe( @@ -164,8 +157,8 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: + prepare_ann_fn: Callable if config.year == "2011": archive_dp, segmentations_dp = resource_dps images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer( @@ -193,6 +186,8 @@ def _make_datapipe( keep_key=True, buffer_size=INFINITE_BUFFER_SIZE, ) + + prepare_ann_fn = self._2011_prepare_ann else: # config.year == "2010" split_dp, images_dp, anns_dp = resource_dps @@ -202,6 +197,8 @@ def _make_datapipe( anns_dp = Mapper(anns_dp, self._2010_anns_key) + prepare_ann_fn = self._2010_prepare_ann + split_dp = hint_sharding(split_dp) split_dp = hint_shuffling(split_dp) @@ -218,7 +215,7 @@ def _make_datapipe( getitem(0), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, year=config.year, decoder=decoder)) + return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn)) def _generate_categories(self, root: pathlib.Path) -> List[str]: config = self.info.make_config(year="2011") diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index fc3ec61efc7..171861454d4 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -1,10 +1,7 @@ import enum -import functools -import io import pathlib -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, BinaryIO -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, @@ 
-21,7 +18,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -29,7 +25,7 @@ path_comparator, getitem, ) -from torchvision.prototype.features import Label +from torchvision.prototype.features import Label, EncodedImage class DTDDemux(enum.IntEnum): @@ -42,7 +38,6 @@ class DTD(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "dtd", - type=DatasetType.IMAGE, homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", valid_options=dict( split=("train", "test", "val"), @@ -75,12 +70,7 @@ def _image_key_fn(self, data: Tuple[str, Any]) -> str: # The split files contain hardcoded posix paths for the images, e.g. banded/banded_0001.jpg return str(path.relative_to(path.parents[1]).as_posix()) - def _collate_and_decode_sample( - self, - data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]: (_, joint_categories_data), image_data = data _, *joint_categories = joint_categories_data path, buffer = image_data @@ -89,9 +79,9 @@ def _collate_and_decode_sample( return dict( joint_categories={category for category in joint_categories if category}, - label=Label(self.info.categories.index(category), category=category), + label=Label.from_category(category, categories=self.categories), path=path, - image=decoder(buffer) if decoder else buffer, + image=EncodedImage.from_file(buffer), ) def _make_datapipe( @@ -99,7 +89,6 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] @@ -128,7 +117,7 @@ def _make_datapipe( ref_key_fn=self._image_key_fn, buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) def _filter_images(self, data: Tuple[str, Any]) -> bool: return self._classify_archive(data) == DTDDemux.IMAGES diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index 2d9bd713990..47d2ddc9acc 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -1,22 +1,17 @@ -import functools -import io -from typing import Any, Callable, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, cast import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, OnlineResource, - DatasetType, KaggleDownloadResource, ) from torchvision.prototype.datasets.utils._internal import ( hint_sharding, hint_shuffling, - image_buffer_from_array, ) from torchvision.prototype.features import Label, Image @@ -25,7 +20,6 @@ class FER2013(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "fer2013", - type=DatasetType.RAW, homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge", categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"), valid_options=dict(split=("train", "test")), @@ -44,26 +38,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) return 
[archive] - def _collate_and_decode_sample( - self, - data: Dict[str, Any], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: - raw_image = torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48) + def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: label_id = data.get("emotion") - label_idx = int(label_id) if label_id is not None else None - - image: Union[Image, io.BytesIO] - if decoder is raw: - image = Image(raw_image) - else: - image_buffer = image_buffer_from_array(raw_image.numpy()) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] return dict( - image=image, - label=Label(label_idx, category=self.info.categories[label_idx]) if label_idx is not None else None, + image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)), + label=Label(int(label_id), categories=self.categories) if label_id is not None else None, ) def _make_datapipe( @@ -71,10 +51,9 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = CSVDictParser(dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index 08855b3a2bd..2288766c10f 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -1,16 +1,12 @@ -import io import pathlib -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple -import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, OnlineResource, - DatasetType, HttpResource, ) from torchvision.prototype.datasets.utils._internal import ( @@ -19,14 +15,13 @@ hint_shuffling, INFINITE_BUFFER_SIZE, ) -from torchvision.prototype.features import Label, BoundingBox +from torchvision.prototype.features import Label, BoundingBox, EncodedImage class GTSRB(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "gtsrb", - type=DatasetType.IMAGE, homepage="https://benchmark.ini.rub.de", categories=[f"{label:05d}" for label in range(43)], valid_options=dict(split=("train", "test")), @@ -66,33 +61,26 @@ def _classify_train_archive(self, data: Tuple[str, Any]) -> Optional[int]: else: return None - def _collate_and_decode( - self, data: Tuple[Tuple[str, Any], Dict[str, Any]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]] - ) -> Dict[str, Any]: - (image_path, image_buffer), csv_info = data + def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[str, Any]: + (path, buffer), csv_info = data label = int(csv_info["ClassId"]) - bbox = BoundingBox( - torch.tensor([int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")]), + bounding_box = BoundingBox( + [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")], format="xyxy", image_size=(int(csv_info["Height"]), int(csv_info["Width"])), ) return { - "image_path": image_path, - "image": decoder(image_buffer) if decoder else image_buffer, - "label": Label(label, 
category=self.categories[label]), - "bbox": bbox, + "path": path, + "image": EncodedImage.from_file(buffer), + "label": Label(label, categories=self.categories), + "bounding_box": bounding_box, } def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: - if config.split == "train": images_dp, ann_dp = Demultiplexer( resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE @@ -101,13 +89,12 @@ def _make_datapipe( images_dp, ann_dp = resource_dps images_dp = Filter(images_dp, path_comparator("suffix", ".ppm")) - # The order of the image files in the the .zip archives perfectly match the order of the entries in - # the (possibly concatenated) .csv files. So we're able to use Zipper here instead of a IterKeyZipper. + # The order of the image files in the .zip archives perfectly match the order of the entries in the + # (possibly concatenated) .csv files. So we're able to use Zipper here instead of a IterKeyZipper. ann_dp = CSVDictParser(ann_dp, delimiter=";") dp = Zipper(images_dp, ann_dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - dp = Mapper(dp, partial(self._collate_and_decode, decoder=decoder)) - return dp + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index ac3649c8839..0d11b642c13 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,18 +1,16 @@ import functools -import io import pathlib import re -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast -import torch -from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter +from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer +from torchdata.datapipes.iter import TarArchiveReader from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, OnlineResource, ManualDownloadResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -24,7 +22,7 @@ hint_sharding, hint_shuffling, ) -from torchvision.prototype.features import Label +from torchvision.prototype.features import Label, EncodedImage from torchvision.prototype.utils._internal import FrozenMapping @@ -40,7 +38,6 @@ def _make_info(self) -> DatasetInfo: return DatasetInfo( name, - type=DatasetType.IMAGE, dependencies=("scipy",), categories=categories, homepage="https://www.image-net.org/", @@ -61,14 +58,6 @@ def _make_info(self) -> DatasetInfo: def supports_sharded(self) -> bool: return True - @property - def category_to_wnid(self) -> Dict[str, str]: - return cast(Dict[str, str], self.info.extra.category_to_wnid) - - @property - def wnid_to_category(self) -> Dict[str, str]: - return cast(Dict[str, str], self.info.extra.wnid_to_category) - _IMAGES_CHECKSUMS = { "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", "val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0", @@ -77,23 +66,56 @@ def wnid_to_category(self) -> Dict[str, str]: def resources(self, config: DatasetConfig) -> List[OnlineResource]: name = "test_v10102019" if config.split == "test" 
else config.split - images = ImageNetResource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name]) - - devkit = ImageNetResource( - file_name="ILSVRC2012_devkit_t12.tar.gz", - sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953", + images = ImageNetResource( + file_name=f"ILSVRC2012_img_{name}.tar", + sha256=self._IMAGES_CHECKSUMS[name], ) + resources: List[OnlineResource] = [images] + + if config.split == "val": + devkit = ImageNetResource( + file_name="ILSVRC2012_devkit_t12.tar.gz", + sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953", + ) + resources.append(devkit) + + return resources - return [images, devkit] + def num_samples(self, config: DatasetConfig) -> int: + return { + "train": 1_281_167, + "val": 50_000, + "test": 100_000, + }[config.split] _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG") - def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]: + def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: path = pathlib.Path(data[0]) - wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid") # type: ignore[union-attr] - category = self.wnid_to_category[wnid] - label_data = (Label(self.categories.index(category)), category, wnid) - return label_data, data + wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"] + label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories) + return (label, wnid), data + + def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]: + return None, data + + def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]: + return { + "meta.mat": 0, + "ILSVRC2012_validation_ground_truth.txt": 1, + }.get(pathlib.Path(data[0]).name) + + def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]: + synsets = read_mat(data[1], squeeze_me=True)["synsets"] + return [ + (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid) + for _, wnid, category, _, num_children, *_ in synsets + # if num_children > 0, we are looking at a superclass that has no direct instance + if num_children == 0 + ] + + def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: List[str]) -> str: + return wnids[int(imagenet_label) - 1] _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG") @@ -101,72 +123,65 @@ def _val_test_image_key(self, data: Tuple[str, Any]) -> int: path = pathlib.Path(data[0]) return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr] - def _collate_val_data( - self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]] - ) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]: + def _prepare_val_data( + self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]] + ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: label_data, image_data = data - _, label = label_data - category = self.categories[label] - wnid = self.category_to_wnid[category] - return (Label(label), category, wnid), image_data - - def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]: - return None, data + _, wnid = label_data + label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories) + return (label, wnid), image_data - def _collate_and_decode_sample( + def _prepare_sample( self, - data:
Tuple[Optional[Tuple[Label, str, str]], Tuple[str, io.IOBase]], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]], ) -> Dict[str, Any]: label_data, (path, buffer) = data - sample = dict( + return dict( + dict(zip(("label", "wnid"), label_data if label_data else (None, None))), path=path, - image=decoder(buffer) if decoder else buffer, + image=EncodedImage.from_file(buffer), ) - if label_data: - sample.update(dict(zip(("label", "category", "wnid"), label_data))) - - return sample def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: - images_dp, devkit_dp = resource_dps + if config.split in {"train", "test"}: + dp = resource_dps[0] - if config.split == "train": # the train archive is a tar of tars - dp = TarArchiveReader(images_dp) + if config.split == "train": + dp = TarArchiveReader(dp) + dp = hint_sharding(dp) dp = hint_shuffling(dp) - dp = Mapper(dp, self._collate_train_data) - elif config.split == "val": - devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt")) - devkit_dp = LineReader(devkit_dp, return_path=False) - devkit_dp = Mapper(devkit_dp, int) - devkit_dp = Enumerator(devkit_dp, 1) - devkit_dp = hint_sharding(devkit_dp) - devkit_dp = hint_shuffling(devkit_dp) + dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data) + else: # config.split == "val": + images_dp, devkit_dp = resource_dps + + meta_dp, label_dp = Demultiplexer( + devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE + ) + + meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) + _, wnids = zip(*next(iter(meta_dp))) + + label_dp = LineReader(label_dp, decode=True, return_path=False) + label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) + label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) + label_dp = hint_sharding(label_dp) + label_dp = hint_shuffling(label_dp) dp = IterKeyZipper( - devkit_dp, + label_dp, images_dp, key_fn=getitem(0), ref_key_fn=self._val_test_image_key, buffer_size=INFINITE_BUFFER_SIZE, ) - dp = Mapper(dp, self._collate_val_data) - else: # config.split == "test" - dp = hint_sharding(images_dp) - dp = hint_shuffling(dp) - dp = Mapper(dp, self._collate_test_data) + dp = Mapper(dp, self._prepare_val_data) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. 
For example, both n02012849 # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment @@ -176,22 +191,13 @@ def _make_datapipe( } def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]: - resources = self.resources(self.default_config) + config = self.info.make_config(split="val") + resources = self.resources(config) devkit_dp = resources[1].load(root) - devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) - - meta = next(iter(devkit_dp))[1] - synsets = read_mat(meta, squeeze_me=True)["synsets"] - categories_and_wnids = cast( - List[Tuple[str, ...]], - [ - (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid) - for _, wnid, category, _, num_children, *_ in synsets - # if num_children > 0, we are looking at a superclass that has no direct instance - if num_children == 0 - ], - ) - categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) + meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) + meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) + categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) + categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) return categories_and_wnids diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 0d7fe36a3fd..e5b9fa84b0d 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -1,10 +1,9 @@ import abc import functools -import io import operator import pathlib import string -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence +from typing import Any, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO, Union, Sequence import torch from torchdata.datapipes.iter import ( @@ -13,24 +12,21 @@ Mapper, Zipper, ) -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, - DatasetType, DatasetConfig, DatasetInfo, HttpResource, OnlineResource, ) from torchvision.prototype.datasets.utils._internal import ( - image_buffer_from_array, Decompressor, INFINITE_BUFFER_SIZE, - fromfile, hint_sharding, hint_shuffling, ) from torchvision.prototype.features import Image, Label +from torchvision.prototype.utils._internal import fromfile __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"] @@ -105,31 +101,15 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]: return None, None - def _collate_and_decode( - self, - data: Tuple[torch.Tensor, torch.Tensor], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: image, label = data - - if decoder is raw: - image = Image(image) - else: - image_buffer = image_buffer_from_array(image.numpy()) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - - label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)]) - - return dict(image=image, label=label) + return dict( + image=Image(image), + label=Label(label, dtype=torch.int64, categories=self.categories), + ) def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: 
Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: images_dp, labels_dp = resource_dps start, stop = self.start_and_stop(config) @@ -143,14 +123,13 @@ def _make_datapipe( dp = Zipper(images_dp, labels_dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder)) + return Mapper(dp, functools.partial(self._prepare_sample, config=config)) class MNIST(_MNISTBase): def _make_info(self) -> DatasetInfo: return DatasetInfo( "mnist", - type=DatasetType.RAW, categories=10, homepage="http://yann.lecun.com/exdb/mnist", valid_options=dict( @@ -183,7 +162,6 @@ class FashionMNIST(MNIST): def _make_info(self) -> DatasetInfo: return DatasetInfo( "fashionmnist", - type=DatasetType.RAW, categories=( "T-shirt/top", "Trouser", @@ -215,7 +193,6 @@ class KMNIST(MNIST): def _make_info(self) -> DatasetInfo: return DatasetInfo( "kmnist", - type=DatasetType.RAW, categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], homepage="http://codh.rois.ac.jp/kmnist/index.html.en", valid_options=dict( @@ -236,7 +213,6 @@ class EMNIST(_MNISTBase): def _make_info(self) -> DatasetInfo: return DatasetInfo( "emnist", - type=DatasetType.RAW, categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist", valid_options=dict( @@ -291,13 +267,7 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> 46: 9, } - def _collate_and_decode( - self, - data: Tuple[torch.Tensor, torch.Tensor], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. 
For example, @@ -310,14 +280,10 @@ def _collate_and_decode( image, label = data label += self._LABEL_OFFSETS.get(int(label), 0) data = (image, label) - return super()._collate_and_decode(data, config=config, decoder=decoder) + return super()._prepare_sample(data, config=config) def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, labels_dp = Demultiplexer( @@ -327,14 +293,13 @@ def _make_datapipe( drop_none=True, buffer_size=INFINITE_BUFFER_SIZE, ) - return super()._make_datapipe([images_dp, labels_dp], config=config, decoder=decoder) + return super()._make_datapipe([images_dp, labels_dp], config=config) class QMNIST(_MNISTBase): def _make_info(self) -> DatasetInfo: return DatasetInfo( "qmnist", - type=DatasetType.RAW, categories=10, homepage="https://github.com/facebookresearch/qmnist", valid_options=dict( @@ -376,16 +341,10 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional return start, stop - def _collate_and_decode( - self, - data: Tuple[torch.Tensor, torch.Tensor], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: image, ann = data label, *extra_anns = ann - sample = super()._collate_and_decode((image, label), config=config, decoder=decoder) + sample = super()._prepare_sample((image, label), config=config) sample.update( dict( diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 59a28796cbc..1780b8829f4 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -1,10 +1,7 @@ import enum -import functools -import io import pathlib -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, BinaryIO -import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser from torchvision.prototype.datasets.utils import ( Dataset, @@ -12,7 +9,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -22,7 +18,7 @@ path_accessor, path_comparator, ) -from torchvision.prototype.features import Label +from torchvision.prototype.features import Label, EncodedImage class OxfordIITPetDemux(enum.IntEnum): @@ -34,7 +30,6 @@ class OxfordIITPet(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "oxford-iiit-pet", - type=DatasetType.IMAGE, homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", valid_options=dict( split=("trainval", "test"), @@ -66,18 +61,8 @@ def _filter_images(self, data: Tuple[str, Any]) -> bool: def _filter_segmentations(self, data: Tuple[str, Any]) -> bool: return not pathlib.Path(data[0]).name.startswith(".") - def _decode_classification_data(self, data: Dict[str, str]) -> Dict[str, Any]: - label_idx = int(data["label"]) - 1 - return dict( - label=Label(label_idx, category=self.info.categories[label_idx]), - species="cat" if data["species"] == "1" else "dog", - ) - - def _collate_and_decode_sample( - self, - data: Tuple[Tuple[Dict[str, str], Tuple[str, 
io.IOBase]], Tuple[str, io.IOBase]], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + def _prepare_sample( + self, data: Tuple[Tuple[Dict[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]] ) -> Dict[str, Any]: ann_data, image_data = data classification_data, segmentation_data = ann_data @@ -85,19 +70,16 @@ def _collate_and_decode_sample( image_path, image_buffer = image_data return dict( - self._decode_classification_data(classification_data), + label=Label(int(classification_data["label"]) - 1, categories=self.categories), + species="cat" if classification_data["species"] == "1" else "dog", segmentation_path=segmentation_path, - segmentation=decoder(segmentation_buffer) if decoder else segmentation_buffer, + segmentation=EncodedImage.from_file(segmentation_buffer), image_path=image_path, - image=decoder(image_buffer) if decoder else image_buffer, + image=EncodedImage.from_file(image_buffer), ) def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: images_dp, anns_dp = resource_dps @@ -137,7 +119,7 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index bf43b30d650..988ff8d4138 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -1,8 +1,7 @@ import io from collections import namedtuple -from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator +from typing import Any, Dict, List, Optional, Tuple, Iterator -import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features from torchvision.prototype.datasets.utils import ( @@ -10,7 +9,6 @@ DatasetConfig, DatasetInfo, OnlineResource, - DatasetType, GDriveResource, ) from torchvision.prototype.datasets.utils._internal import ( @@ -46,7 +44,6 @@ class PCAM(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "pcam", - type=DatasetType.RAW, homepage="https://github.com/basveeling/pcam", categories=2, valid_options=dict(split=("train", "test", "val")), @@ -98,7 +95,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: for file_name, gdrive_id, sha256 in self._RESOURCES[config.split] ] - def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]: image, target = data # They're both numpy arrays at this point return { @@ -107,11 +104,7 @@ def _collate_and_decode(self, data: Tuple[Any, Any]) -> Dict[str, Any]: } def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: images_dp, targets_dp = resource_dps @@ -122,4 +115,4 @@ def _make_datapipe( dp = Zipper(images_dp, targets_dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, 
self._collate_and_decode) + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 27b27b2745b..2619dff67eb 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -1,11 +1,8 @@ -import functools -import io import pathlib import re -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO import numpy as np -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, @@ -20,7 +17,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -31,20 +27,17 @@ hint_sharding, hint_shuffling, ) -from torchvision.prototype.features import Feature +from torchvision.prototype.features import _Feature, EncodedImage class SBD(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "sbd", - type=DatasetType.IMAGE, dependencies=("scipy",), homepage="http://home.bharathh.info/pubs/codes/SBD/download.html", valid_options=dict( split=("train", "val", "train_noval"), - boundaries=(True, False), - segmentation=(False, True), ), ) @@ -75,50 +68,21 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: else: return None - def _decode_ann( - self, data: Dict[str, Any], *, decode_boundaries: bool, decode_segmentation: bool - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - raw_anns = data["GTcls"][0] - raw_boundaries = raw_anns["Boundaries"][0] - raw_segmentation = raw_anns["Segmentation"][0] - - # the boundaries are stored in sparse CSC format, which is not supported by PyTorch - boundaries = ( - Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries])) - if decode_boundaries - else None - ) - segmentation = Feature(raw_segmentation) if decode_segmentation else None - - return boundaries, segmentation - - def _collate_and_decode_sample( - self, - data: Tuple[Tuple[Any, Tuple[str, io.IOBase]], Tuple[str, io.IOBase]], - *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]: split_and_image_data, ann_data = data _, image_data = split_and_image_data image_path, image_buffer = image_data ann_path, ann_buffer = ann_data - image = decoder(image_buffer) if decoder else image_buffer - - if config.boundaries or config.segmentation: - boundaries, segmentation = self._decode_ann( - read_mat(ann_buffer), decode_boundaries=config.boundaries, decode_segmentation=config.segmentation - ) - else: - boundaries = segmentation = None + anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"] return dict( image_path=image_path, - image=image, + image=EncodedImage.from_file(image_buffer), ann_path=ann_path, - boundaries=boundaries, - segmentation=segmentation, + # the boundaries are stored in sparse CSC format, which is not supported by PyTorch + boundaries=_Feature(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])), + segmentation=_Feature(anns["Segmentation"].item()), ) def _make_datapipe( @@ -126,7 +90,6 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: archive_dp, extra_split_dp = resource_dps @@ -138,10 
+101,10 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, drop_none=True, ) - if config.split == "train_noval": split_dp = extra_split_dp - split_dp = Filter(split_dp, path_comparator("stem", config.split)) + + split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = hint_sharding(split_dp) split_dp = hint_shuffling(split_dp) @@ -155,7 +118,7 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder)) + return Mapper(dp, self._prepare_sample) def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: resources = self.resources(self.default_config) diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index d153debcefd..a6fc1098fda 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -1,6 +1,4 @@ -import functools -import io -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple import torch from torchdata.datapipes.iter import ( @@ -8,24 +6,21 @@ Mapper, CSVParser, ) -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) -from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling -from torchvision.prototype.features import Image, Label +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Image, OneHotLabel class SEMEION(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "semeion", - type=DatasetType.RAW, categories=10, homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", ) @@ -37,34 +32,22 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) return [data] - def _collate_and_decode_sample( - self, - data: Tuple[str, ...], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: - image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16) - label_data = [int(label) for label in data[256:] if label] + def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]: + image_data, label_data = data[:256], data[256:-1] - if decoder is raw: - image = Image(image_data.unsqueeze(0)) - else: - image_buffer = image_buffer_from_array(image_data.numpy()) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - - label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label)) - return dict(image=image, label=Label(label_idx, category=self.info.categories[label_idx])) + return dict( + image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)), + label=OneHotLabel([int(label) for label in label_data], categories=self.categories), + ) def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = CSVParser(dp, delimiter=" ") dp = hint_sharding(dp) dp = hint_shuffling(dp) - dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) - 
return dp + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 7f9c019e92e..21af4add909 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -1,28 +1,22 @@ -import functools -import io -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple, BinaryIO import numpy as np -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, UnBatcher, ) -from torchvision.prototype.datasets.decoder import raw from torchvision.prototype.datasets.utils import ( Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( read_mat, hint_sharding, hint_shuffling, - image_buffer_from_array, ) from torchvision.prototype.features import Label, Image @@ -31,7 +25,6 @@ class SVHN(Dataset): def _make_info(self) -> DatasetInfo: return DatasetInfo( "svhn", - type=DatasetType.RAW, dependencies=("scipy",), categories=10, homepage="http://ufldl.stanford.edu/housenumbers/", @@ -52,7 +45,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [data] - def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np.ndarray, np.ndarray]]: + def _read_images_and_labels(self, data: Tuple[str, BinaryIO]) -> List[Tuple[np.ndarray, np.ndarray]]: _, buffer = data content = read_mat(buffer) return list( @@ -62,23 +55,12 @@ def _read_images_and_labels(self, data: Tuple[str, io.IOBase]) -> List[Tuple[np. ) ) - def _collate_and_decode_sample( - self, - data: Tuple[np.ndarray, np.ndarray], - *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], - ) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]: image_array, label_array = data - if decoder is raw: - image = Image(image_array.transpose((2, 0, 1))) - else: - image_buffer = image_buffer_from_array(image_array) - image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - return dict( - image=image, - label=Label(int(label_array) % 10), + image=Image(image_array.transpose((2, 0, 1))), + label=Label(int(label_array) % 10, categories=self.categories), ) def _make_datapipe( @@ -86,11 +68,10 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Mapper(dp, self._read_images_and_labels) dp = UnBatcher(dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) + return Mapper(dp, self._prepare_sample) diff --git a/torchvision/prototype/datasets/_builtin/voc.categories b/torchvision/prototype/datasets/_builtin/voc.categories new file mode 100644 index 00000000000..8420ab35ede --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/voc.categories @@ -0,0 +1,20 @@ +aeroplane +bicycle +bird +boat +bottle +bus +car +cat +chair +cow +diningtable +dog +horse +motorbike +person +pottedplant +sheep +sofa +train +tvmonitor diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index da145ab1e1c..6ba2186853d 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -1,10 +1,8 @@ import functools -import io import pathlib 
-from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable from xml.etree import ElementTree -import torch from torchdata.datapipes.iter import ( IterDataPipe, Mapper, @@ -20,7 +18,6 @@ DatasetInfo, HttpResource, OnlineResource, - DatasetType, ) from torchvision.prototype.datasets.utils._internal import ( path_accessor, @@ -30,7 +27,7 @@ hint_sharding, hint_shuffling, ) -from torchvision.prototype.features import BoundingBox +from torchvision.prototype.features import BoundingBox, Label, EncodedImage class VOCDatasetInfo(DatasetInfo): @@ -50,7 +47,6 @@ class VOC(Dataset): def _make_info(self) -> DatasetInfo: return VOCDatasetInfo( "voc", - type=DatasetType.IMAGE, homepage="http://host.robots.ox.ac.uk/pascal/VOC/", valid_options=dict( split=("train", "val", "trainval", "test"), @@ -99,40 +95,52 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> else: return None - def _decode_detection_ann(self, buffer: io.IOBase) -> torch.Tensor: - result = VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot()) # type: ignore[arg-type] - objects = result["annotation"]["object"] - bboxes = [obj["bndbox"] for obj in objects] - bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes] - return BoundingBox(bboxes) + def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: + return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"]) + + def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: + anns = self._parse_detection_ann(buffer) + instances = anns["object"] + return dict( + bounding_boxes=BoundingBox( + [ + [int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")] + for instance in instances + ], + format="xyxy", + image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), + ), + labels=Label( + [self.categories.index(instance["name"]) for instance in instances], categories=self.categories + ), + ) + + def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]: + return dict(segmentation=EncodedImage.from_file(buffer)) - def _collate_and_decode_sample( + def _prepare_sample( self, - data: Tuple[Tuple[Tuple[str, str], Tuple[str, io.IOBase]], Tuple[str, io.IOBase]], + data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]], *, - config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], + prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]], ) -> Dict[str, Any]: split_and_image_data, ann_data = data _, image_data = split_and_image_data image_path, image_buffer = image_data ann_path, ann_buffer = ann_data - image = decoder(image_buffer) if decoder else image_buffer - - if config.task == "detection": - ann = self._decode_detection_ann(ann_buffer) - else: # config.task == "segmentation": - ann = decoder(ann_buffer) if decoder else ann_buffer # type: ignore[assignment] - - return dict(image_path=image_path, image=image, ann_path=ann_path, ann=ann) + return dict( + prepare_ann_fn(ann_buffer), + image_path=image_path, + image=EncodedImage.from_file(image_buffer), + ann_path=ann_path, + ) def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] split_dp, images_dp, anns_dp = Demultiplexer( @@ -158,4 +166,25 @@ def 
_make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder)) + return Mapper( + dp, + functools.partial( + self._prepare_sample, + prepare_ann_fn=self._prepare_detection_ann + if config.task == "detection" + else self._prepare_segmentation_ann, + ), + ) + + def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: + return self._classify_archive(data, config=config) == 2 + + def _generate_categories(self, root: pathlib.Path) -> List[str]: + config = self.info.make_config(task="detection") + + resource = self.resources(config)[0] + dp = resource.load(pathlib.Path(root) / self.name) + dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config)) + dp = Mapper(dp, self._parse_detection_ann, input_col=1) + + return sorted({instance["name"] for _, anns in dp for instance in anns["object"]}) diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index fbca8b07b1a..c3a38becb6c 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -1,15 +1,12 @@ import functools -import io import os import os.path import pathlib -from typing import Callable, Optional, Collection -from typing import Union, Tuple, List, Dict, Any +from typing import BinaryIO, Optional, Collection, Union, Tuple, List, Dict, Any -import torch -from torchdata.datapipes.iter import IterDataPipe, FileLister, FileOpener, Mapper, Shuffler, Filter -from torchvision.prototype.datasets.decoder import pil -from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding +from torchdata.datapipes.iter import IterDataPipe, FileLister, Mapper, Filter, FileOpener +from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling +from torchvision.prototype.features import Label, EncodedImage, EncodedData __all__ = ["from_data_folder", "from_image_folder"] @@ -20,29 +17,24 @@ def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool: return rel_path.is_dir() or rel_path.parent != pathlib.Path(".") -def _collate_and_decode_data( - data: Tuple[str, io.IOBase], +def _prepare_sample( + data: Tuple[str, BinaryIO], *, root: pathlib.Path, categories: List[str], - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> Dict[str, Any]: path, buffer = data - data = decoder(buffer) if decoder else buffer category = pathlib.Path(path).relative_to(root).parts[0] - label = torch.tensor(categories.index(category)) return dict( path=path, - data=data, - label=label, - category=category, + data=EncodedData.from_file(buffer), + label=Label.from_category(category, categories=categories), ) def from_data_folder( root: Union[str, pathlib.Path], *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None, valid_extensions: Optional[Collection[str]] = None, recursive: bool = True, ) -> Tuple[IterDataPipe, List[str]]: @@ -52,26 +44,22 @@ def from_data_folder( dp = FileLister(str(root), recursive=recursive, masks=masks) dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root)) dp = hint_sharding(dp) - dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) + dp = hint_shuffling(dp) dp = FileOpener(dp, mode="rb") - return ( - Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)), - categories, - ) + return Mapper(dp, functools.partial(_prepare_sample, 
root=root, categories=categories)), categories def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["image"] = sample.pop("data") + sample["image"] = EncodedImage(sample.pop("data").data) return sample def from_image_folder( root: Union[str, pathlib.Path], *, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil, valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"), **kwargs: Any, ) -> Tuple[IterDataPipe, List[str]]: valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())] - dp, categories = from_data_folder(root, decoder=decoder, valid_extensions=valid_extensions, **kwargs) + dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs) return Mapper(dp, _data_to_image_key), categories diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py deleted file mode 100644 index 530a357f239..00000000000 --- a/torchvision/prototype/datasets/decoder.py +++ /dev/null @@ -1,16 +0,0 @@ -import io - -import PIL.Image -import torch -from torchvision.prototype import features -from torchvision.transforms.functional import pil_to_tensor - -__all__ = ["raw", "pil"] - - -def raw(buffer: io.IOBase) -> torch.Tensor: - raise RuntimeError("This is just a sentinel and should never be called.") - - -def pil(buffer: io.IOBase) -> features.Image: - return features.Image(pil_to_tensor(PIL.Image.open(buffer))) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index bde05c49cb1..9423b65a8ee 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ -from . import _internal -from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset +from . import _internal # usort: skip +from ._dataset import DatasetConfig, DatasetInfo, Dataset from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 04fb4312728..5ee7c5ccc60 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,28 +1,19 @@ import abc import csv -import enum import importlib -import io import itertools import os import pathlib -from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple, Collection +from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection -import torch from torch.utils.data import IterDataPipe -from torchvision.prototype.utils._internal import FrozenBunch, make_repr -from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str +from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str from .._home import use_sharded_dataset from ._internal import BUILTIN_DIR, _make_sharded_datapipe from ._resource import OnlineResource -class DatasetType(enum.Enum): - RAW = enum.auto() - IMAGE = enum.auto() - - class DatasetConfig(FrozenBunch): # This needs to be Frozen because we often pass configs as partial(func, config=config) # and partial() requires the parameters to be hashable. 
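With the decoder argument removed above, folder datapipes now yield encoded features (EncodedImage / EncodedData) plus a rich Label, and decoding is deferred to the consumer or a later transform. The following is a minimal usage sketch, not part of this diff: it assumes from_image_folder is re-exported at the torchvision.prototype.datasets package level and that the encoded feature stores the raw file bytes as a uint8 tensor; the PIL decode at the end is purely illustrative.

    import io

    import PIL.Image
    from torchvision.prototype.datasets import from_image_folder  # assumed package-level re-export

    # Samples now carry encoded bytes instead of eagerly decoded image tensors.
    datapipe, categories = from_image_folder("path/to/image_folder")

    for sample in datapipe:
        encoded = sample["image"]  # EncodedImage built via EncodedData.from_file (assumption: uint8 byte tensor)
        label = sample["label"]    # Label created with Label.from_category(..., categories=categories)
        # Decoding is left to the caller (or a downstream transform), e.g. with PIL:
        image = PIL.Image.open(io.BytesIO(bytes(encoded.numpy())))
        print(sample["path"], categories[int(label)], image.size)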
@@ -34,7 +25,6 @@ def __init__( self, name: str, *, - type: Union[str, DatasetType], dependencies: Collection[str] = (), categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, citation: Optional[str] = None, @@ -44,7 +34,6 @@ def __init__( extra: Optional[Dict[str, Any]] = None, ) -> None: self.name = name.lower() - self.type = DatasetType[type.upper()] if isinstance(type, str) else type self.dependecies = dependencies @@ -163,7 +152,6 @@ def _make_datapipe( resource_dps: List[IterDataPipe], *, config: DatasetConfig, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]], ) -> IterDataPipe[Dict[str, Any]]: pass @@ -175,7 +163,6 @@ def load( root: Union[str, pathlib.Path], *, config: Optional[DatasetConfig] = None, - decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None, skip_integrity_check: bool = False, ) -> IterDataPipe[Dict[str, Any]]: if not config: @@ -190,7 +177,7 @@ def load( resource_dps = [ resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config) ] - return self._make_datapipe(resource_dps, config=config, decoder=decoder) + return self._make_datapipe(resource_dps, config=config) def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: raise NotImplementedError diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 1b437d50b85..3ed40f63ff0 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -1,14 +1,11 @@ import enum import functools import gzip -import io import lzma -import mmap import os import os.path import pathlib import pickle -import platform from typing import BinaryIO from typing import ( Sequence, @@ -25,27 +22,24 @@ ) from typing import cast -import numpy as np -import PIL.Image import torch import torch.distributed as dist import torch.utils.data from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler from torchdata.datapipes.utils import StreamWrapper +from torchvision.prototype.utils._internal import fromfile __all__ = [ "INFINITE_BUFFER_SIZE", "BUILTIN_DIR", "read_mat", - "image_buffer_from_array", "MappingIterator", "Enumerator", "getitem", "path_accessor", "path_comparator", "Decompressor", - "fromfile", "read_flo", "hint_sharding", ] @@ -59,7 +53,7 @@ BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin" -def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: +def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any: try: import scipy.io as sio except ImportError as error: @@ -71,14 +65,6 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: return sio.loadmat(buffer, **kwargs) -def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO: - image = PIL.Image.fromarray(array) - buffer = io.BytesIO() - image.save(buffer, format=format) - buffer.seek(0) - return buffer - - class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]): def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None: self.datapipe = datapipe @@ -142,17 +128,17 @@ class CompressionType(enum.Enum): LZMA = "lzma" -class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]): +class Decompressor(IterDataPipe[Tuple[str, BinaryIO]]): types = CompressionType - _DECOMPRESSORS = { - types.GZIP: lambda file: gzip.GzipFile(fileobj=file), - types.LZMA: lambda file: lzma.LZMAFile(file), + _DECOMPRESSORS: Dict[CompressionType, Callable[[BinaryIO], 
BinaryIO]] = { + types.GZIP: lambda file: cast(BinaryIO, gzip.GzipFile(fileobj=file)), + types.LZMA: lambda file: cast(BinaryIO, lzma.LZMAFile(file)), } def __init__( self, - datapipe: IterDataPipe[Tuple[str, io.IOBase]], + datapipe: IterDataPipe[Tuple[str, BinaryIO]], *, type: Optional[Union[str, CompressionType]] = None, ) -> None: @@ -174,7 +160,7 @@ def _detect_compression_type(self, path: str) -> CompressionType: else: raise RuntimeError("FIXME") - def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]: + def __iter__(self) -> Iterator[Tuple[str, BinaryIO]]: for path, file in self.datapipe: type = self._detect_compression_type(path) decompressor = self._DECOMPRESSORS[type] @@ -257,69 +243,6 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[st return dp -def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: - # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable - return bytearray(file.read(-1 if count == -1 else count * item_size)) - - -def fromfile( - file: BinaryIO, - *, - dtype: torch.dtype, - byte_order: str, - count: int = -1, -) -> torch.Tensor: - """Construct a tensor from a binary file. - - .. note:: - - This function is similar to :func:`numpy.fromfile` with two notable differences: - - 1. This function only accepts an open binary file, but not a path to it. - 2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that - concept. - - .. note:: - - If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as - long as the file is still open, inplace operations on the returned tensor will reflect back to the file. - - Args: - file (IO): Open binary file. - dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor. - byte_order (str): Byte order of the data. Can be "little" or "big" endian. - count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file. - """ - byte_order = "<" if byte_order == "little" else ">" - char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u") - item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 - np_dtype = byte_order + char + str(item_size) - - buffer: Union[memoryview, bytearray] - if platform.system() != "Windows": - # PyTorch does not support tensors with underlying read-only memory. In case - # - the file has a .fileno(), - # - the file was opened for updating, i.e. 'r+b' or 'w+b', - # - the file is seekable - # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it - # to a mutable location afterwards. - try: - buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :] - # Reading from the memoryview does not advance the file cursor, so we have to do it manually. - file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR)) - except (PermissionError, io.UnsupportedOperation): - buffer = _read_mutable_buffer_fallback(file, count, item_size) - else: - # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state - # so no data can be read afterwards. Thus, we simply ignore the possible speed-up. - buffer = _read_mutable_buffer_fallback(file, count, item_size) - - # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. 
Thus, we - # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the - # successive .astype() call. - return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False)) - - def read_flo(file: BinaryIO) -> torch.Tensor: if file.read(4) != b"PIEH": raise ValueError("Magic number incorrect. Invalid .flo file") @@ -329,9 +252,9 @@ def read_flo(file: BinaryIO) -> torch.Tensor: return flow.reshape((height, width, 2)).permute((2, 0, 1)) -def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]: +def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter: return ShardingFilter(datapipe) -def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]: +def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]: return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE) diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py index 4d77d3a5ce3..218c8876495 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/features/__init__.py @@ -1,4 +1,6 @@ -from ._bounding_box import BoundingBoxFormat, BoundingBox -from ._feature import Feature, DEFAULT -from ._image import Image, ColorSpace -from ._label import Label +from ._bounding_box import BoundingBox, BoundingBoxFormat +from ._encoded import EncodedData, EncodedImage, EncodedVideo +from ._feature import _Feature +from ._image import ColorSpace, Image +from ._label import Label, OneHotLabel +from ._segmentation_mask import SegmentationMask diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 64ba449ae76..fbe19549dca 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -1,155 +1,51 @@ -import enum -import functools -from typing import Callable, Union, Tuple, Dict, Any, Optional, cast +from __future__ import annotations + +from typing import Any, Tuple, Union, Optional import torch from torchvision.prototype.utils._internal import StrEnum -from ._feature import Feature, DEFAULT +from ._feature import _Feature class BoundingBoxFormat(StrEnum): - # this is just for test purposes - _SENTINEL = -1 - XYXY = enum.auto() - XYWH = enum.auto() - CXCYWH = enum.auto() - - -def to_parts(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return input.unbind(dim=-1) # type: ignore[return-value] - - -def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor) -> torch.Tensor: - return torch.stack((a, b, c, d), dim=-1) - - -def format_converter_wrapper( - part_converter: Callable[ - [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], - ] -): - def wrapper(input: torch.Tensor) -> torch.Tensor: - return from_parts(*part_converter(*to_parts(input))) - - return wrapper - - -@format_converter_wrapper -def xywh_to_xyxy( - x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - x1 = x - y1 = y - x2 = x + w - y2 = y + h - return x1, y1, x2, y2 - - -@format_converter_wrapper -def xyxy_to_xywh( - x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - x = x1 - y = y1 - w = x2 - x1 - h = y2 - y1 - return x, y, w, h - + XYXY = StrEnum.auto() + XYWH = StrEnum.auto() + CXCYWH 
= StrEnum.auto() -@format_converter_wrapper -def cxcywh_to_xyxy( - cx: torch.Tensor, cy: torch.Tensor, w: torch.Tensor, h: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - x1 = cx - 0.5 * w - y1 = cy - 0.5 * h - x2 = cx + 0.5 * w - y2 = cy + 0.5 * h - return x1, y1, x2, y2 - -@format_converter_wrapper -def xyxy_to_cxcywh( - x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - cx = (x1 + x2) / 2 - cy = (y1 + y2) / 2 - w = x2 - x1 - h = y2 - y1 - return cx, cy, w, h - - -class BoundingBox(Feature): - formats = BoundingBoxFormat +class BoundingBox(_Feature): format: BoundingBoxFormat image_size: Tuple[int, int] - @classmethod - def _parse_meta_data( + def __new__( cls, - format: Union[str, BoundingBoxFormat] = DEFAULT, # type: ignore[assignment] - image_size: Optional[Tuple[int, int]] = DEFAULT, # type: ignore[assignment] - ) -> Dict[str, Tuple[Any, Any]]: + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + format: Union[BoundingBoxFormat, str], + image_size: Tuple[int, int], + ) -> BoundingBox: + bounding_box = super().__new__(cls, data, dtype=dtype, device=device) + if isinstance(format, str): format = BoundingBoxFormat[format] - format_fallback = BoundingBoxFormat.XYXY - return dict( - format=(format, format_fallback), - image_size=(image_size, functools.partial(cls.guess_image_size, format=format_fallback)), - ) - _TO_XYXY_MAP = { - BoundingBoxFormat.XYWH: xywh_to_xyxy, - BoundingBoxFormat.CXCYWH: cxcywh_to_xyxy, - } - _FROM_XYXY_MAP = { - BoundingBoxFormat.XYWH: xyxy_to_xywh, - BoundingBoxFormat.CXCYWH: xyxy_to_cxcywh, - } + bounding_box._metadata.update(dict(format=format, image_size=image_size)) - @classmethod - def guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> Tuple[int, int]: - if format not in (BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH): - if format != BoundingBoxFormat.XYXY: - data = cls._TO_XYXY_MAP[format](data) - data = cls._FROM_XYXY_MAP[BoundingBoxFormat.XYWH](data) - *_, w, h = to_parts(data) - if data.dtype.is_floating_point: - w = w.ceil() - h = h.ceil() - return int(h.max()), int(w.max()) + return bounding_box - @classmethod - def from_parts( - cls, - a, - b, - c, - d, - *, - like: Optional["BoundingBox"] = None, - format: Union[str, BoundingBoxFormat] = DEFAULT, # type: ignore[assignment] - image_size: Optional[Tuple[int, int]] = DEFAULT, # type: ignore[assignment] - ) -> "BoundingBox": - return cls(from_parts(a, b, c, d), like=like, image_size=image_size, format=format) + def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: + # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we + # promote this out of the prototype state - def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return to_parts(self) + # import at runtime to avoid cyclic imports + from torchvision.prototype.transforms.kernels import convert_bounding_box_format - def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox": if isinstance(format, str): format = BoundingBoxFormat[format] - if format == self.format: - return cast(BoundingBox, self.clone()) - - data = self - - if self.format != BoundingBoxFormat.XYXY: - data = self._TO_XYXY_MAP[self.format](data) - - if format != BoundingBoxFormat.XYXY: - data = self._FROM_XYXY_MAP[format](data) - - return BoundingBox(data, like=self, 
format=format) + return BoundingBox.new_like( + self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format + ) diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py new file mode 100644 index 00000000000..ab6b821d673 --- /dev/null +++ b/torchvision/prototype/features/_encoded.py @@ -0,0 +1,52 @@ +import os +import sys +from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any + +import PIL.Image +import torch +from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer + +from ._feature import _Feature +from ._image import Image + +D = TypeVar("D", bound="EncodedData") + + +class EncodedData(_Feature): + @classmethod + def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor: + # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8? + return super()._to_tensor(data, dtype=dtype, device=device) + + @classmethod + def from_file(cls: Type[D], file: BinaryIO) -> D: + return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder)) + + @classmethod + def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D: + with open(path, "rb") as file: + return cls.from_file(file) + + +class EncodedImage(EncodedData): + # TODO: Use @functools.cached_property if we can depend on Python 3.8 + @property + def image_size(self) -> Tuple[int, int]: + if not hasattr(self, "_image_size"): + with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image: + self._image_size = image.height, image.width + + return self._image_size + + def decode(self) -> Image: + # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we + # promote this out of the prototype state + + # import at runtime to avoid cyclic imports + from torchvision.prototype.transforms.kernels import decode_image_with_pil + + return Image(decode_image_with_pil(self)) + + +class EncodedVideo(EncodedData): + pass diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 1837ffc1e89..14c1a9a2a9e 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,85 +1,80 @@ -from typing import Tuple, cast, TypeVar, Set, Dict, Any, Callable, Optional, Mapping, Type, Sequence +from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping import torch from torch._C import _TensorBase, DisableTorchFunction -from torchvision.prototype.utils._internal import add_suggestion -F = TypeVar("F", bound="Feature") +F = TypeVar("F", bound="_Feature") -DEFAULT = object() - - -class Feature(torch.Tensor): +class _Feature(torch.Tensor): _META_ATTRS: Set[str] = set() - _meta_data: Dict[str, Any] - - def __init_subclass__(cls): - # In order to help static type checkers, we require subclasses of `Feature` add the meta data attributes - # as static class annotations: - # - # >>> class Foo(Feature): - # ... bar: str - # ... baz: Optional[str] - # - # Internally, this information is used twofold: - # - # 1. A class annotation is contained in `cls.__annotations__` but not in `cls.__dict__`. We use this difference - # to automatically detect the meta data attributes and expose them as `@property`'s for convenient runtime - # access. This happens in this method. - # 2. The information extracted in 1. 
is also used at creation (`__new__`) to perform an input parsing for - # unknown arguments. + _metadata: Dict[str, Any] + + def __init_subclass__(cls) -> None: + """ + For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes. + By adding the metadata attributes as class annotations on subclasses of :class:`Feature`, this method adds + properties to have the same convenient access as regular attributes. + + >>> class Foo(_Feature): + ... bar: str + ... baz: Optional[str] + >>> foo = Foo() + >>> foo.bar + >>> foo.baz + + This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata. + """ meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")} for super_cls in cls.__mro__[1:]: - if super_cls is Feature: + if super_cls is _Feature: break - meta_attrs.update(super_cls._META_ATTRS) + meta_attrs.update(cast(Type[_Feature], super_cls)._META_ATTRS) cls._META_ATTRS = meta_attrs - for attr in meta_attrs: - setattr(cls, attr, property(lambda self, attr=attr: self._meta_data[attr])) - - def __new__(cls, data, *, dtype=None, device=None, like=None, **kwargs): - unknown_meta_attrs = kwargs.keys() - cls._META_ATTRS - if unknown_meta_attrs: - unknown_meta_attr = sorted(unknown_meta_attrs)[0] - raise TypeError( - add_suggestion( - f"{cls.__name__}() got unexpected keyword '{unknown_meta_attr}'.", - word=unknown_meta_attr, - possibilities=sorted(cls._META_ATTRS), - ) - ) - - if like is not None: - dtype = dtype or like.dtype - device = device or like.device - data = cls._to_tensor(data, dtype=dtype, device=device) - requires_grad = False - self = torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad) - - meta_data = dict() - for attr, (explicit, fallback) in cls._parse_meta_data(**kwargs).items(): - if explicit is not DEFAULT: - value = explicit - elif like is not None: - value = getattr(like, attr) - else: - value = fallback(data) if callable(fallback) else fallback - meta_data[attr] = value - self._meta_data = meta_data - - return self + for name in meta_attrs: + setattr(cls, name, property(cast(Callable[[F], Any], lambda self, name=name: self._metadata[name]))) + + def __new__( + cls: Type[F], + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, + ) -> F: + if isinstance(device, str): + device = torch.device(device) + feature = cast( + F, + torch.Tensor._make_subclass( + cast(_TensorBase, cls), + cls._to_tensor(data, dtype=dtype, device=device), + # requires_grad + False, + ), + ) + feature._metadata = dict() + return feature @classmethod - def _to_tensor(cls, data, *, dtype, device): + def _to_tensor(self, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor: return torch.as_tensor(data, dtype=dtype, device=device) @classmethod - def _parse_meta_data(cls) -> Dict[str, Tuple[Any, Any]]: - return dict() + def new_like( + cls: Type[F], + other: F, + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, + **metadata: Any, + ) -> F: + _metadata = other._metadata.copy() + _metadata.update(metadata) + return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata) @classmethod def __torch_function__( @@ -89,12 +84,37 @@ def __torch_function__( args: Sequence[Any] = (), kwargs: Optional[Mapping[str, Any]] = None, ) -> torch.Tensor: + """For general information about how 
the __torch_function__ protocol works, + see https://pytorch.org/docs/stable/notes/extending.html#extending-torch + + TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the + ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the + ``args`` and ``kwargs`` of the original call. + + The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature` + use case, this has two downsides: + + 1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e. + ``return cls(func(*args, **kwargs))``, will fail for them. + 2. For most operations, there is no way of knowing if the input type is still valid for the output. + + For these reasons, the automatic output wrapping is turned off for most operators. + + Exceptions to this are: + + - :func:`torch.clone` + - :meth:`torch.Tensor.to` + """ + kwargs = kwargs or dict() with DisableTorchFunction(): - output = func(*args, **(kwargs or dict())) - if func is not torch.Tensor.clone: - return output + output = func(*args, **kwargs) - return cls(output, like=args[0]) + if func is torch.Tensor.clone: + return cls.new_like(args[0], output) + elif func is torch.Tensor.to: + return cls.new_like(args[0], output, dtype=output.dtype, device=output.device) + else: + return output def __repr__(self) -> str: - return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__) + return cast(str, torch.Tensor.__repr__(self)).replace("tensor", type(self).__name__) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 3d0b3d0c0af..5ecc4cbedb7 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -1,40 +1,64 @@ -from typing import Dict, Any, Union, Tuple +from __future__ import annotations + +import warnings +from typing import Any, Optional, Union, Tuple, cast import torch from torchvision.prototype.utils._internal import StrEnum +from torchvision.transforms.functional import to_pil_image +from torchvision.utils import draw_bounding_boxes +from torchvision.utils import make_grid -from ._feature import Feature, DEFAULT +from ._bounding_box import BoundingBox +from ._feature import _Feature class ColorSpace(StrEnum): - # this is just for test purposes - _SENTINEL = -1 OTHER = 0 GRAYSCALE = 1 RGB = 3 -class Image(Feature): - color_spaces = ColorSpace +class Image(_Feature): color_space: ColorSpace + def __new__( + cls, + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + color_space: Optional[Union[ColorSpace, str]] = None, + ) -> Image: + image = super().__new__(cls, data, dtype=dtype, device=device) + + if color_space is None: + color_space = cls.guess_color_space(image) + if color_space == ColorSpace.OTHER: + warnings.warn("Unable to guess a specific color space. 
Consider passing it explicitly.") + elif isinstance(color_space, str): + color_space = ColorSpace[color_space] + + image._metadata.update(dict(color_space=color_space)) + + return image + @classmethod - def _to_tensor(cls, data, *, dtype, device): - tensor = torch.as_tensor(data, dtype=dtype, device=device) - if tensor.ndim == 2: + def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor: + tensor = super()._to_tensor(data, dtype=dtype, device=device) + if tensor.ndim < 2: + raise ValueError + elif tensor.ndim == 2: tensor = tensor.unsqueeze(0) - elif tensor.ndim != 3: - raise ValueError("Only single images with 2 or 3 dimensions are allowed.") return tensor - @classmethod - def _parse_meta_data( - cls, - color_space: Union[str, ColorSpace] = DEFAULT, # type: ignore[assignment] - ) -> Dict[str, Tuple[Any, Any]]: - if isinstance(color_space, str): - color_space = ColorSpace[color_space] - return dict(color_space=(color_space, cls.guess_color_space)) + @property + def image_size(self) -> Tuple[int, int]: + return cast(Tuple[int, int], self.shape[-2:]) + + @property + def num_channels(self) -> int: + return self.shape[-3] @staticmethod def guess_color_space(data: torch.Tensor) -> ColorSpace: @@ -50,3 +74,13 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace: return ColorSpace.RGB else: return ColorSpace.OTHER + + def show(self) -> None: + # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we + # promote this out of the prototype state + to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show() + + def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image: + # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we + # promote this out of the prototype state + return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py index ebdc6bbbc26..618c020dbfc 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -1,14 +1,59 @@ -from typing import Dict, Any, Optional, Tuple +from __future__ import annotations -from ._feature import Feature, DEFAULT +from typing import Any, Optional, Sequence, cast +import torch +from torchvision.prototype.utils._internal import apply_recursively -class Label(Feature): - category: Optional[str] +from ._feature import _Feature + + +class Label(_Feature): + categories: Optional[Sequence[str]] + + def __new__( + cls, + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + like: Optional[Label] = None, + categories: Optional[Sequence[str]] = None, + ) -> Label: + label = super().__new__(cls, data, dtype=dtype, device=device) + + label._metadata.update(dict(categories=categories)) + + return label @classmethod - def _parse_meta_data( + def from_category(cls, category: str, *, categories: Sequence[str]) -> Label: + return cls(categories.index(category), categories=categories) + + def to_categories(self) -> Any: + if not self.categories: + raise RuntimeError() + + return apply_recursively(lambda idx: cast(Sequence[str], self.categories)[idx], self.tolist()) + + +class OneHotLabel(_Feature): + categories: Optional[Sequence[str]] + + def __new__( cls, - category: Optional[str] = DEFAULT, # type: ignore[assignment] - ) -> Dict[str, Tuple[Any, Any]]: - 
return dict(category=(category, None)) + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + like: Optional[Label] = None, + categories: Optional[Sequence[str]] = None, + ) -> OneHotLabel: + one_hot_label = super().__new__(cls, data, dtype=dtype, device=device) + + if categories is not None and len(categories) != one_hot_label.shape[-1]: + raise ValueError() + + one_hot_label._metadata.update(dict(categories=categories)) + + return one_hot_label diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py new file mode 100644 index 00000000000..dc41697ae9b --- /dev/null +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -0,0 +1,5 @@ +from ._feature import _Feature + + +class SegmentationMask(_Feature): + pass diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 56cca7b0402..c9988be1930 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,6 +1,5 @@ -from ._transform import Transform -from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort: skip +from . import kernels # usort: skip +from . import functional # usort: skip +from .kernels import InterpolationMode # usort: skip -from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop -from ._misc import Identity, Normalize from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py deleted file mode 100644 index 86d7804dd17..00000000000 --- a/torchvision/prototype/transforms/_container.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Any, List - -import torch -from torch import nn -from torchvision.prototype.transforms import Transform - - -class ContainerTransform(nn.Module): - def supports(self, obj: Any) -> bool: - raise NotImplementedError() - - def forward(self, *inputs: Any) -> Any: - raise NotImplementedError() - - def _make_repr(self, lines: List[str]) -> str: - extra_repr = self.extra_repr() - if extra_repr: - lines = [self.extra_repr(), *lines] - head = f"{type(self).__name__}(" - tail = ")" - body = [f" {line.rstrip()}" for line in lines] - return "\n".join([head, *body, tail]) - - -class WrapperTransform(ContainerTransform): - def __init__(self, transform: Transform): - super().__init__() - self._transform = transform - - def supports(self, obj: Any) -> bool: - return self._transform.supports(obj) - - def __repr__(self) -> str: - return self._make_repr(repr(self._transform).splitlines()) - - -class MultiTransform(ContainerTransform): - def __init__(self, *transforms: Transform) -> None: - super().__init__() - self._transforms = transforms - - def supports(self, obj: Any) -> bool: - return all(transform.supports(obj) for transform in self._transforms) - - def __repr__(self) -> str: - lines = [] - for idx, transform in enumerate(self._transforms): - partial_lines = repr(transform).splitlines() - lines.append(f"({idx:d}): {partial_lines[0]}") - lines.extend(partial_lines[1:]) - return self._make_repr(lines) - - -class Compose(MultiTransform): - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - for transform in self._transforms: - sample = transform(sample) - return sample - - -class RandomApply(WrapperTransform): - def __init__(self, transform: Transform, *, p: float = 0.5) 
-> None: - super().__init__(transform) - self._p = p - - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if float(torch.rand(())) < self._p: - return sample - - return self._transform(sample) - - def extra_repr(self) -> str: - return f"p={self._p}" - - -class RandomChoice(MultiTransform): - def forward(self, *inputs: Any) -> Any: - idx = int(torch.randint(len(self._transforms), size=())) - transform = self._transforms[idx] - return transform(*inputs) - - -class RandomOrder(MultiTransform): - def forward(self, *inputs: Any) -> Any: - for idx in torch.randperm(len(self._transforms)): - transform = self._transforms[idx] - inputs = transform(*inputs) - return inputs diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py deleted file mode 100644 index f34e5daa063..00000000000 --- a/torchvision/prototype/transforms/_geometry.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Any, Dict, Tuple, Union - -import torch -from torch.nn.functional import interpolate -from torchvision.prototype.datasets.utils import SampleQuery -from torchvision.prototype.features import BoundingBox, Image, Label -from torchvision.prototype.transforms import Transform - - -class HorizontalFlip(Transform): - NO_OP_FEATURE_TYPES = {Label} - - @staticmethod - def image(input: Image) -> Image: - return Image(input.flip((-1,)), like=input) - - @staticmethod - def bounding_box(input: BoundingBox) -> BoundingBox: - x, y, w, h = input.convert("xywh").to_parts() - x = input.image_size[1] - (x + w) - return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh").convert(input.format) - - -class Resize(Transform): - NO_OP_FEATURE_TYPES = {Label} - - def __init__( - self, - size: Union[int, Tuple[int, int]], - *, - interpolation_mode: str = "nearest", - ) -> None: - super().__init__() - self.size = (size, size) if isinstance(size, int) else size - self.interpolation_mode = interpolation_mode - - def get_params(self, sample: Any) -> Dict[str, Any]: - return dict(size=self.size, interpolation_mode=self.interpolation_mode) - - @staticmethod - def image(input: Image, *, size: Tuple[int, int], interpolation_mode: str = "nearest") -> Image: - return Image(interpolate(input.unsqueeze(0), size, mode=interpolation_mode).squeeze(0), like=input) - - @staticmethod - def bounding_box(input: BoundingBox, *, size: Tuple[int, int], **_: Any) -> BoundingBox: - old_height, old_width = input.image_size - new_height, new_width = size - - height_scale = new_height / old_height - width_scale = new_width / old_width - - old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts() - - new_x1 = old_x1 * width_scale - new_y1 = old_y1 * height_scale - - new_x2 = old_x2 * width_scale - new_y2 = old_y2 * height_scale - - return BoundingBox.from_parts( - new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=size - ).convert(input.format) - - def extra_repr(self) -> str: - extra_repr = f"size={self.size}" - if self.interpolation_mode != "bilinear": - extra_repr += f", interpolation_mode={self.interpolation_mode}" - return extra_repr - - -class RandomResize(Transform, wraps=Resize): - def __init__(self, min_size: Union[int, Tuple[int, int]], max_size: Union[int, Tuple[int, int]]) -> None: - super().__init__() - self.min_size = (min_size, min_size) if isinstance(min_size, int) else min_size - self.max_size = (max_size, max_size) if isinstance(max_size, int) else max_size - - def get_params(self, sample: Any) -> Dict[str, Any]: - 
min_height, min_width = self.min_size - max_height, max_width = self.max_size - height = int(torch.randint(min_height, max_height + 1, size=())) - width = int(torch.randint(min_width, max_width + 1, size=())) - return dict(size=(height, width)) - - def extra_repr(self) -> str: - return f"min_size={self.min_size}, max_size={self.max_size}" - - -class Crop(Transform): - NO_OP_FEATURE_TYPES = {BoundingBox, Label} - - def __init__(self, crop_box: BoundingBox) -> None: - super().__init__() - self.crop_box = crop_box.convert("xyxy") - - def get_params(self, sample: Any) -> Dict[str, Any]: - return dict(crop_box=self.crop_box) - - @staticmethod - def image(input: Image, *, crop_box: BoundingBox) -> Image: - # FIXME: pad input in case it is smaller than crop_box - x1, y1, x2, y2 = crop_box.convert("xyxy").to_parts() - return Image(input[..., y1 : y2 + 1, x1 : x2 + 1], like=input) # type: ignore[misc] - - -class CenterCrop(Transform, wraps=Crop): - def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None: - super().__init__() - self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size - - def get_params(self, sample: Any) -> Dict[str, Any]: - image_size = SampleQuery(sample).image_size() - image_height, image_width = image_size - cx = image_width // 2 - cy = image_height // 2 - h, w = self.crop_size - crop_box = BoundingBox.from_parts(cx, cy, w, h, image_size=image_size, format="cxcywh") - return dict(crop_box=crop_box) - - def extra_repr(self) -> str: - return f"crop_size={self.crop_size}" - - -class RandomCrop(Transform, wraps=Crop): - def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None: - super().__init__() - self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size - - def get_params(self, sample: Any) -> Dict[str, Any]: - image_size = SampleQuery(sample).image_size() - image_height, image_width = image_size - crop_height, crop_width = self.crop_size - x = torch.randint(0, image_width - crop_width + 1, size=()) if crop_width < image_width else 0 - y = torch.randint(0, image_height - crop_height + 1, size=()) if crop_height < image_height else 0 - crop_box = BoundingBox.from_parts(x, y, crop_width, crop_height, image_size=image_size, format="xywh") - return dict(crop_box=crop_box) - - def extra_repr(self) -> str: - return f"crop_size={self.crop_size}" diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py deleted file mode 100644 index 47062aeaf03..00000000000 --- a/torchvision/prototype/transforms/_misc.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Any, Dict, Sequence - -import torch -from torchvision.prototype.features import Image, BoundingBox, Label -from torchvision.prototype.transforms import Transform - - -class Identity(Transform): - """Identity transform that supports all built-in :class:`~torchvision.prototype.features.Feature`'s.""" - - def __init__(self): - super().__init__() - for feature_type in self._BUILTIN_FEATURE_TYPES: - self.register_feature_transform(feature_type, lambda input, **params: input) - - -class Normalize(Transform): - NO_OP_FEATURE_TYPES = {BoundingBox, Label} - - def __init__(self, mean: Sequence[float], std: Sequence[float]): - super().__init__() - self.mean = mean - self.std = std - - def get_params(self, sample: Any) -> Dict[str, Any]: - return dict(mean=self.mean, std=self.std) - - @staticmethod - def _channel_stats_to_tensor(stats: Sequence[float], *, like: torch.Tensor) -> torch.Tensor: - return torch.as_tensor(stats, 
device=like.device, dtype=like.dtype).view(-1, 1, 1) - - @staticmethod - def image(input: Image, *, mean: Sequence[float], std: Sequence[float]) -> Image: - mean_t = Normalize._channel_stats_to_tensor(mean, like=input) - std_t = Normalize._channel_stats_to_tensor(std, like=input) - return Image((input - mean_t) / std_t, like=input) - - def extra_repr(self) -> str: - return f"mean={tuple(self.mean)}, std={tuple(self.std)}" diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py deleted file mode 100644 index 8062ff0fad0..00000000000 --- a/torchvision/prototype/transforms/_transform.py +++ /dev/null @@ -1,406 +0,0 @@ -import collections.abc -import inspect -import re -from typing import Any, Callable, Dict, Optional, Type, Union, cast, Set, Collection - -import torch -from torch import nn -from torchvision.prototype import features -from torchvision.prototype.utils._internal import add_suggestion - - -class Transform(nn.Module): - """Base class for transforms. - - A transform operates on a full sample at once, which might be a nested container of elements to transform. The - non-container elements of the sample will be dispatched to feature transforms based on their type in case it is - supported by the transform. Each transform needs to define at least one feature transform, which is canonical done - as static method: - - .. code-block:: - - class ImageIdentity(Transform): - @staticmethod - def image(input): - return input - - To achieve correct results for a complete sample, each transform should implement feature transforms for every - :class:`Feature` it can handle: - - .. code-block:: - - class Identity(Transform): - @staticmethod - def image(input): - return input - - @staticmethod - def bounding_box(input): - return input - - ... - - If the name of a static method in camel-case matches the name of a :class:`Feature`, the feature transform is - auto-registered. Supported pairs are: - - +----------------+----------------+ - | method name | `Feature` | - +================+================+ - | `image` | `Image` | - +----------------+----------------+ - | `bounding_box` | `BoundingBox` | - +----------------+----------------+ - | `label` | `Label` | - +----------------+----------------+ - - If you don't want to stick to this scheme, you can disable the auto-registration and perform it manually: - - .. code-block:: - - def my_image_transform(input): - ... - - class MyTransform(Transform, auto_register=False): - def __init__(self): - super().__init__() - self.register_feature_transform(Image, my_image_transform) - self.register_feature_transform(BoundingBox, self.my_bounding_box_transform) - - @staticmethod - def my_bounding_box_transform(input): - ... - - In any case, the registration will assert that the feature transform can be invoked with - ``feature_transform(input, **params)``. - - .. warning:: - - Feature transforms are **registered on the class and not on the instance**. This means you cannot have two - instances of the same :class:`Transform` with different feature transforms. - - If the feature transforms needs additional parameters, you need to - overwrite the :meth:`~Transform.get_params` method. It needs to return the parameter dictionary that will be - unpacked and its contents passed to each feature transform: - - .. 
code-block:: - - class Rotate(Transform): - def __init__(self, degrees): - super().__init__() - self.degrees = degrees - - def get_params(self, sample): - return dict(degrees=self.degrees) - - def image(input, *, degrees): - ... - - The :meth:`~Transform.get_params` method will be invoked once per sample. Thus, in case of randomly sampled - parameters they will be the same for all features of the whole sample. - - .. code-block:: - - class RandomRotate(Transform) - def __init__(self, range): - super().__init__() - self._dist = torch.distributions.Uniform(range) - - def get_params(self, sample): - return dict(degrees=self._dist.sample().item()) - - @staticmethod - def image(input, *, degrees): - ... - - In case the sampling depends on one or more features at runtime, the complete ``sample`` gets passed to the - :meth:`Transform.get_params` method. Derivative transforms that only changes the parameter sampling, but the - feature transformations are identical, can simply wrap the transform they dispatch to: - - .. code-block:: - - class RandomRotate(Transform, wraps=Rotate): - def get_params(self, sample): - return dict(degrees=float(torch.rand(())) * 30.0) - - To transform a sample, you simply call an instance of the transform with it: - - .. code-block:: - - transform = MyTransform() - sample = dict(input=Image(torch.tensor(...)), target=BoundingBox(torch.tensor(...)), ...) - transformed_sample = transform(sample) - - .. note:: - - To use a :class:`Transform` with a dataset, simply use it as map: - - .. code-block:: - - torchvision.datasets.load(...).map(MyTransform()) - """ - - _BUILTIN_FEATURE_TYPES = ( - features.BoundingBox, - features.Image, - features.Label, - ) - _FEATURE_NAME_MAP = { - "_".join([part.lower() for part in re.findall("[A-Z][^A-Z]*", feature_type.__name__)]): feature_type - for feature_type in _BUILTIN_FEATURE_TYPES - } - _feature_transforms: Dict[Type[features.Feature], Callable] - - NO_OP_FEATURE_TYPES: Collection[Type[features.Feature]] = () - - def __init_subclass__( - cls, *, wraps: Optional[Type["Transform"]] = None, auto_register: bool = True, verbose: bool = False - ): - cls._feature_transforms = {} if wraps is None else wraps._feature_transforms.copy() - if wraps: - cls.NO_OP_FEATURE_TYPES = wraps.NO_OP_FEATURE_TYPES - if auto_register: - cls._auto_register(verbose=verbose) - - @staticmethod - def _has_allowed_signature(feature_transform: Callable) -> bool: - """Checks if ``feature_transform`` can be invoked with ``feature_transform(input, **params)``""" - - parameters = tuple(inspect.signature(feature_transform).parameters.values()) - if not parameters: - return False - elif len(parameters) == 1: - return parameters[0].kind != inspect.Parameter.KEYWORD_ONLY - else: - return parameters[1].kind != inspect.Parameter.POSITIONAL_ONLY - - @classmethod - def register_feature_transform(cls, feature_type: Type[features.Feature], transform: Callable) -> None: - """Registers a transform for given feature on the class. - - If a transform object is called or :meth:`Transform.apply` is invoked, inputs are dispatched to the registered - transforms based on their type. - - Args: - feature_type: Feature type the transformation is registered for. - transform: Feature transformation. - - Raises: - TypeError: If ``transform`` cannot be invoked with ``transform(input, **params)``. 
- """ - if not cls._has_allowed_signature(transform): - raise TypeError("Feature transform cannot be invoked with transform(input, **params)") - cls._feature_transforms[feature_type] = transform - - @classmethod - def _auto_register(cls, *, verbose: bool = False) -> None: - """Auto-registers methods on the class as feature transforms if they meet the following criteria: - - 1. They are static. - 2. They can be invoked with `cls.feature_transform(input, **params)`. - 3. They are public. - 4. Their name in camel case matches the name of a builtin feature, e.g. 'bounding_box' and 'BoundingBox'. - - The name from 4. determines for which feature the method is registered. - - .. note:: - - The ``auto_register`` and ``verbose`` flags need to be passed as keyword arguments to the class: - - .. code-block:: - - class MyTransform(Transform, auto_register=True, verbose=True): - ... - - Args: - verbose: If ``True``, prints to STDOUT which methods were registered or why a method was not registered - """ - for name, value in inspect.getmembers(cls): - # check if attribute is a static method and was defined in the subclass - # TODO: this needs to be revisited to allow subclassing of custom transforms - if not (name in cls.__dict__ and inspect.isfunction(value)): - continue - - not_registered_prefix = f"{cls.__name__}.{name}() was not registered as feature transform, because" - - if not cls._has_allowed_signature(value): - if verbose: - print(f"{not_registered_prefix} it cannot be invoked with {name}(input, **params).") - continue - - if name.startswith("_"): - if verbose: - print(f"{not_registered_prefix} it is private.") - continue - - try: - feature_type = cls._FEATURE_NAME_MAP[name] - except KeyError: - if verbose: - print( - add_suggestion( - f"{not_registered_prefix} its name doesn't match any known feature type.", - word=name, - possibilities=cls._FEATURE_NAME_MAP.keys(), - close_match_hint=lambda close_match: ( - f"Did you mean to name it '{close_match}' " - f"to be registered for type '{cls._FEATURE_NAME_MAP[close_match]}'?" - ), - ) - ) - continue - - cls.register_feature_transform(feature_type, value) - if verbose: - print( - f"{cls.__name__}.{name}() was registered as feature transform for type '{feature_type.__name__}'." - ) - - @classmethod - def from_callable( - cls, - feature_transform: Union[Callable, Dict[Type[features.Feature], Callable]], - *, - name: str = "FromCallable", - get_params: Optional[Union[Dict[str, Any], Callable[[Any], Dict[str, Any]]]] = None, - ) -> "Transform": - """Creates a new transform from a callable. - - Args: - feature_transform: Feature transform that will be registered to handle :class:`Image`'s. Can be passed as - dictionary in which case each key-value-pair is needs to consists of a ``Feature`` type and the - corresponding transform. - name: Name of the transform. - get_params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``. - Can be passed as callable in which case it will be called with the transform instance (``self``) and - the input of the transform. - - Raises: - TypeError: If ``feature_transform`` cannot be invoked with ``feature_transform(input, **params)``. 
- """ - if get_params is None: - get_params = dict() - attributes = dict( - get_params=get_params if callable(get_params) else lambda self, sample: get_params, # type: ignore[misc] - ) - transform_cls = cast(Type[Transform], type(name, (cls,), attributes)) - - if callable(feature_transform): - feature_transform = {features.Image: feature_transform} - for feature_type, transform in feature_transform.items(): - transform_cls.register_feature_transform(feature_type, transform) - - return transform_cls() - - @classmethod - def supported_feature_types(cls) -> Set[Type[features.Feature]]: - return set(cls._feature_transforms.keys()) - - @classmethod - def supports(cls, obj: Any) -> bool: - """Checks if object or type is supported. - - Args: - obj: Object or type. - """ - # TODO: should this handle containers? - feature_type = obj if isinstance(obj, type) else type(obj) - return feature_type is torch.Tensor or feature_type in cls.supported_feature_types() - - @classmethod - def transform(cls, input: Union[torch.Tensor, features.Feature], **params: Any) -> torch.Tensor: - """Applies the registered feature transform to the input based on its type. - - This can be uses as feature type generic functional interface: - - .. code-block:: - - transform = Rotate.transform - transformed_image = transform(Image(torch.tensor(...)), degrees=30.0) - transformed_bbox = transform(BoundingBox(torch.tensor(...)), degrees=-10.0) - - Args: - input: ``input`` in ``feature_transform(input, **params)`` - **params: Parameter dictionary ``params`` in ``feature_transform(input, **params)``. - - Returns: - Transformed input. - """ - feature_type = type(input) - if not cls.supports(feature_type): - raise TypeError(f"{cls.__name__}() is not able to handle inputs of type {feature_type}.") - - if feature_type is torch.Tensor: - # To keep BC, we treat all regular torch.Tensor's as images - feature_type = features.Image - input = feature_type(input) - feature_type = cast(Type[features.Feature], feature_type) - - feature_transform = cls._feature_transforms[feature_type] - output = feature_transform(input, **params) - - if type(output) is torch.Tensor: - output = feature_type(output, like=input) - return output - - def _transform_recursively(self, sample: Any, *, params: Dict[str, Any]) -> Any: - """Recurses through a sample and invokes :meth:`Transform.transform` on non-container elements. - - If an element is not supported by the transform, it is returned untransformed. - - Args: - sample: Sample. - params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``. - """ - # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: - # "a" == "a"[0][0]... - if isinstance(sample, collections.abc.Sequence) and not isinstance(sample, str): - return [self._transform_recursively(item, params=params) for item in sample] - elif isinstance(sample, collections.abc.Mapping): - return {name: self._transform_recursively(item, params=params) for name, item in sample.items()} - else: - feature_type = type(sample) - if not self.supports(feature_type): - if ( - not issubclass(feature_type, features.Feature) - # issubclass is not a strict check, but also allows the type checked against. Thus, we need to - # check it separately - or feature_type is features.Feature - or feature_type in self.NO_OP_FEATURE_TYPES - ): - return sample - - raise TypeError( - f"{type(self).__name__}() is not able to handle inputs of type {feature_type}. 
" - f"If you want it to be a no-op, add the feature type to {type(self).__name__}.NO_OP_FEATURE_TYPES." - ) - - return self.transform(cast(Union[torch.Tensor, features.Feature], sample), **params) - - def get_params(self, sample: Any) -> Dict[str, Any]: - """Returns the parameter dictionary used to transform the current sample. - - .. note:: - - Since ``sample`` might be a nested container, it is recommended to use the - :class:`torchvision.datasets.utils.Query` class if you need to extract information from it. - - Args: - sample: Current sample. - - Returns: - Parameter dictionary ``params`` in ``feature_transform(input, **params)``. - """ - return dict() - - def forward( - self, - *inputs: Any, - params: Optional[Dict[str, Any]] = None, - ) -> Any: - if not self._feature_transforms: - raise RuntimeError(f"{type(self).__name__}() has no registered feature transform.") - - sample = inputs if len(inputs) > 1 else inputs[0] - if params is None: - params = self.get_params(sample) - return self._transform_recursively(sample, params=params) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py new file mode 100644 index 00000000000..9f05f16df2d --- /dev/null +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -0,0 +1,14 @@ +from ._augment import erase, mixup, cutmix +from ._color import ( + adjust_brightness, + adjust_contrast, + adjust_saturation, + adjust_sharpness, + posterize, + solarize, + autocontrast, + equalize, + invert, +) +from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate +from ._misc import normalize diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py new file mode 100644 index 00000000000..2eafe0d3c1f --- /dev/null +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -0,0 +1,57 @@ +from typing import TypeVar, Any + +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K +from torchvision.transforms import functional as _F + +from ._utils import dispatch + +T = TypeVar("T", bound=features._Feature) + + +@dispatch( + { + torch.Tensor: _F.erase, + features.Image: K.erase_image, + } +) +def erase(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + features.Image: K.mixup_image, + features.OneHotLabel: K.mixup_one_hot_label, + } +) +def mixup(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + features.Image: K.cutmix_image, + features.OneHotLabel: K.cutmix_one_hot_label, + } +) +def cutmix(input: T, *args: Any, **kwargs: Any) -> T: + """Perform the CutMix operation as introduced in the paper + `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_. + + Dispatch to the corresponding kernels happens according to this table: + + .. 
table:: + :widths: 30 70 + + ==================================================== ================================================================ + :class:`~torchvision.prototype.features.Image` :func:`~torch.prototype.transforms.kernels.cutmix_image` + :class:`~torchvision.prototype.features.OneHotLabel` :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label` + ==================================================== ================================================================ + + Please refer to the kernel documentations for a detailed explanation of the functionality and parameters. + """ + ... diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py new file mode 100644 index 00000000000..23e128b7856 --- /dev/null +++ b/torchvision/prototype/transforms/functional/_color.py @@ -0,0 +1,119 @@ +from typing import TypeVar, Any + +import PIL.Image +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K +from torchvision.transforms import functional as _F + +from ._utils import dispatch + +T = TypeVar("T", bound=features._Feature) + + +@dispatch( + { + torch.Tensor: _F.adjust_brightness, + PIL.Image.Image: _F.adjust_brightness, + features.Image: K.adjust_brightness_image, + } +) +def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.adjust_saturation, + PIL.Image.Image: _F.adjust_saturation, + features.Image: K.adjust_saturation_image, + } +) +def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.adjust_contrast, + PIL.Image.Image: _F.adjust_contrast, + features.Image: K.adjust_contrast_image, + } +) +def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.adjust_sharpness, + PIL.Image.Image: _F.adjust_sharpness, + features.Image: K.adjust_sharpness_image, + } +) +def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.posterize, + PIL.Image.Image: _F.posterize, + features.Image: K.posterize_image, + } +) +def posterize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.solarize, + PIL.Image.Image: _F.solarize, + features.Image: K.solarize_image, + } +) +def solarize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.autocontrast, + PIL.Image.Image: _F.autocontrast, + features.Image: K.autocontrast_image, + } +) +def autocontrast(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.equalize, + PIL.Image.Image: _F.equalize, + features.Image: K.equalize_image, + } +) +def equalize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.invert, + PIL.Image.Image: _F.invert, + features.Image: K.invert_image, + } +) +def invert(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... 
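
A rough usage sketch of the color dispatchers added above (illustrative, not a hunk of this patch): it assumes the prototype modules exactly as introduced in this diff, and the factor value and tensor shapes are arbitrary.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

plain = torch.rand(3, 32, 32)
image = features.Image(torch.rand(3, 32, 32), color_space="RGB")

# Plain tensors (and PIL images) are routed to the stable
# torchvision.transforms.functional kernel and come back unwrapped.
out_plain = F.adjust_brightness(plain, brightness_factor=1.5)

# Image features are routed to kernels.adjust_brightness_image and re-wrapped
# by the dispatcher, so the feature type and its metadata survive the op.
out_image = F.adjust_brightness(image, brightness_factor=1.5)
assert isinstance(out_image, features.Image)
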
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py new file mode 100644 index 00000000000..147baa3a066 --- /dev/null +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -0,0 +1,95 @@ +from typing import TypeVar, Any, cast + +import PIL.Image +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K +from torchvision.transforms import functional as _F + +from ._utils import dispatch + +T = TypeVar("T", bound=features._Feature) + + +@dispatch( + { + torch.Tensor: _F.hflip, + PIL.Image.Image: _F.hflip, + features.Image: K.horizontal_flip_image, + features.BoundingBox: None, + }, +) +def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + if isinstance(input, features.BoundingBox): + output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) + return cast(T, features.BoundingBox.new_like(input, output)) + + raise RuntimeError + + +@dispatch( + { + torch.Tensor: _F.resize, + PIL.Image.Image: _F.resize, + features.Image: K.resize_image, + features.SegmentationMask: K.resize_segmentation_mask, + features.BoundingBox: None, + } +) +def resize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + if isinstance(input, features.BoundingBox): + size = kwargs.pop("size") + output = K.resize_bounding_box(input, size=size, image_size=input.image_size) + return cast(T, features.BoundingBox.new_like(input, output, image_size=size)) + + raise RuntimeError + + +@dispatch( + { + torch.Tensor: _F.center_crop, + PIL.Image.Image: _F.center_crop, + features.Image: K.center_crop_image, + } +) +def center_crop(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.resized_crop, + PIL.Image.Image: _F.resized_crop, + features.Image: K.resized_crop_image, + } +) +def resized_crop(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.affine, + PIL.Image.Image: _F.affine, + features.Image: K.affine_image, + } +) +def affine(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.rotate, + PIL.Image.Image: _F.rotate, + features.Image: K.rotate_image, + } +) +def rotate(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py new file mode 100644 index 00000000000..7cf0765105a --- /dev/null +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -0,0 +1,21 @@ +from typing import TypeVar, Any + +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K +from torchvision.transforms import functional as _F + +from ._utils import dispatch + +T = TypeVar("T", bound=features._Feature) + + +@dispatch( + { + torch.Tensor: _F.normalize, + features.Image: K.normalize_image, + } +) +def normalize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... 
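
The `features.BoundingBox: None` entries in `functional/_geometry.py` above mean the dispatcher runs the decorated body instead of a kernel for that type, since bounding boxes carry metadata (format, image size) that a bare kernel call cannot supply. A rough sketch of both paths (illustrative, not a hunk of this patch; it assumes the prototype API exactly as added in this diff):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

image = features.Image(torch.rand(3, 32, 48))
box = features.BoundingBox(
    torch.tensor([0.0, 0.0, 24.0, 16.0]), format="XYWH", image_size=(32, 48)
)

# Image inputs hit kernels.resize_image through the dispatch table.
resized_image = F.resize(image, size=[64, 96])

# BoundingBox has a None entry, so the body of resize() runs instead: it pops
# `size`, calls kernels.resize_bounding_box and re-wraps with the new image_size.
resized_box = F.resize(box, size=(64, 96))
assert resized_box.image_size == (64, 96)
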
diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py new file mode 100644 index 00000000000..591f9a83101 --- /dev/null +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -0,0 +1,89 @@ +import functools +import inspect +from typing import Any, Optional, Callable, TypeVar, Dict + +import torch +import torch.overrides +from torchvision.prototype import features + +F = TypeVar("F", bound=features._Feature) + + +def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]: + """Decorates a function to automatically dispatch to registered kernels based on the call arguments. + + The dispatch function should have this signature + + .. code:: python + + @dispatch( + ... + ) + def dispatch_fn(input, *args, **kwargs): + ... + + where ``input`` is used to determine which kernel to dispatch to. + + Args: + kernels: Dictionary with types as keys that maps to a kernel to call. The resolution order is checking for + exact type matches first and if none is found falls back to checking for subclasses. If a value is + ``None``, the decorated function is called. + + Raises: + TypeError: If any value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``. + TypeError: If the decorated function is called with an input that cannot be dispatched. + """ + + def check_kernel(kernel: Any) -> bool: + if kernel is None: + return True + + if not callable(kernel): + return False + + params = list(inspect.signature(kernel).parameters.values()) + if not params: + return False + + return params[0].kind != inspect.Parameter.KEYWORD_ONLY + + for feature_type, kernel in kernels.items(): + if not check_kernel(kernel): + raise TypeError( + f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)." + ) + + def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]: + @functools.wraps(dispatch_fn) + def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F: + feature_type = type(input) + try: + kernel = kernels[feature_type] + except KeyError: + try: + feature_type, kernel = next( + (feature_type, kernel) + for feature_type, kernel in kernels.items() + if isinstance(input, feature_type) + ) + except StopIteration: + raise TypeError(f"No support for {type(input).__name__}") from None + + if kernel is None: + output = dispatch_fn(input, *args, **kwargs) + if output is None: + raise RuntimeError( + f"{dispatch_fn.__name__}() did not handle inputs of type {type(input).__name__} " + f"although it was configured to do so." 
+ ) + else: + output = kernel(input, *args, **kwargs) + + if issubclass(feature_type, features._Feature) and type(output) is torch.Tensor: + output = feature_type.new_like(input, output) + + return output + + return inner_wrapper + + return outer_wrapper diff --git a/torchvision/prototype/transforms/kernels/__init__.py b/torchvision/prototype/transforms/kernels/__init__.py new file mode 100644 index 00000000000..6f74f6af0e9 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/__init__.py @@ -0,0 +1,34 @@ +from torchvision.transforms import InterpolationMode # usort: skip +from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip + +from ._augment import ( + erase_image, + mixup_image, + mixup_one_hot_label, + cutmix_image, + cutmix_one_hot_label, +) +from ._color import ( + adjust_brightness_image, + adjust_contrast_image, + adjust_saturation_image, + adjust_sharpness_image, + posterize_image, + solarize_image, + autocontrast_image, + equalize_image, + invert_image, +) +from ._geometry import ( + horizontal_flip_bounding_box, + horizontal_flip_image, + resize_bounding_box, + resize_image, + resize_segmentation_mask, + center_crop_image, + resized_crop_image, + affine_image, + rotate_image, +) +from ._misc import normalize_image +from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot diff --git a/torchvision/prototype/transforms/kernels/_augment.py b/torchvision/prototype/transforms/kernels/_augment.py new file mode 100644 index 00000000000..526ed85ffd8 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_augment.py @@ -0,0 +1,45 @@ +from typing import Tuple + +import torch +from torchvision.transforms import functional as _F + + +erase_image = _F.erase + + +def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: + input = input.clone() + return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) + + +def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + + return _mixup(image_batch, -4, lam) + + +def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup(one_hot_label_batch, -2, lam) + + +def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + + x1, y1, x2, y2 = box + image_rolled = image_batch.roll(1, -4) + + image_batch = image_batch.clone() + image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return image_batch + + +def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup(one_hot_label_batch, -2, lam_adjusted) diff --git a/torchvision/prototype/transforms/kernels/_color.py b/torchvision/prototype/transforms/kernels/_color.py new file mode 100644 index 00000000000..0d828e6d169 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_color.py @@ -0,0 +1,12 @@ +from torchvision.transforms import functional as _F + + +adjust_brightness_image = _F.adjust_brightness +adjust_saturation_image = _F.adjust_saturation +adjust_contrast_image = _F.adjust_contrast +adjust_sharpness_image = _F.adjust_sharpness +posterize_image = _F.posterize +solarize_image = _F.solarize 
diff --git a/torchvision/prototype/transforms/kernels/_color.py b/torchvision/prototype/transforms/kernels/_color.py
new file mode 100644
index 00000000000..0d828e6d169
--- /dev/null
+++ b/torchvision/prototype/transforms/kernels/_color.py
@@ -0,0 +1,12 @@
+from torchvision.transforms import functional as _F
+
+
+adjust_brightness_image = _F.adjust_brightness
+adjust_saturation_image = _F.adjust_saturation
+adjust_contrast_image = _F.adjust_contrast
+adjust_sharpness_image = _F.adjust_sharpness
+posterize_image = _F.posterize
+solarize_image = _F.solarize
+autocontrast_image = _F.autocontrast
+equalize_image = _F.equalize
+invert_image = _F.invert
diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py
new file mode 100644
index 00000000000..fb25f0fdf47
--- /dev/null
+++ b/torchvision/prototype/transforms/kernels/_geometry.py
@@ -0,0 +1,70 @@
+from typing import Tuple, List, Optional, TypeVar
+
+import torch
+from torchvision.prototype import features
+from torchvision.transforms import functional as _F, InterpolationMode
+
+from ._meta_conversion import convert_bounding_box_format
+
+
+T = TypeVar("T", bound=features._Feature)
+
+
+horizontal_flip_image = _F.hflip
+
+
+def horizontal_flip_bounding_box(
+    bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
+) -> torch.Tensor:
+    shape = bounding_box.shape
+
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
+
+    return convert_bounding_box_format(
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+    ).view(shape)
+
+
+def resize_image(
+    image: torch.Tensor,
+    size: List[int],
+    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    max_size: Optional[int] = None,
+    antialias: Optional[bool] = None,
+) -> torch.Tensor:
+    new_height, new_width = size
+    num_channels, old_height, old_width = image.shape[-3:]
+    batch_shape = image.shape[:-3]
+    return _F.resize(
+        image.reshape((-1, num_channels, old_height, old_width)),
+        size=size,
+        interpolation=interpolation,
+        max_size=max_size,
+        antialias=antialias,
+    ).reshape(batch_shape + (num_channels, new_height, new_width))
+
+
+def resize_segmentation_mask(
+    segmentation_mask: torch.Tensor,
+    size: List[int],
+    max_size: Optional[int] = None,
+) -> torch.Tensor:
+    return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
+
+
+# TODO: handle max_size
+def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
+    old_height, old_width = image_size
+    new_height, new_width = size
+    ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
+    return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
+
+
+center_crop_image = _F.center_crop
+resized_crop_image = _F.resized_crop
+affine_image = _F.affine
+rotate_image = _F.rotate
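A short sketch of how the bounding box kernels above behave; the box values and image size are made up, and the expected outputs in the comments follow directly from the arithmetic in the kernels:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import kernels as K

    image_size = (480, 640)  # (height, width)
    boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])  # XYXY

    flipped = K.horizontal_flip_bounding_box(
        boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size
    )
    # x coordinates are mirrored around the image width: [[530., 20., 630., 220.]]

    resized = K.resize_bounding_box(boxes, size=[240, 320], image_size=image_size)
    # both coordinates are scaled by new_size / old_size: [[5., 10., 55., 110.]]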
diff --git a/torchvision/prototype/transforms/kernels/_meta_conversion.py b/torchvision/prototype/transforms/kernels/_meta_conversion.py
new file mode 100644
index 00000000000..4acaf9fe9e4
--- /dev/null
+++ b/torchvision/prototype/transforms/kernels/_meta_conversion.py
@@ -0,0 +1,69 @@
+import torch
+from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
+from torchvision.transforms.functional_tensor import rgb_to_grayscale as _rgb_to_grayscale
+
+
+def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
+    xyxy = xywh.clone()
+    xyxy[..., 2:] += xyxy[..., :2]
+    return xyxy
+
+
+def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
+    xywh = xyxy.clone()
+    xywh[..., 2:] -= xywh[..., :2]
+    return xywh
+
+
+def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
+    cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
+    x1 = cx - 0.5 * w
+    y1 = cy - 0.5 * h
+    x2 = cx + 0.5 * w
+    y2 = cy + 0.5 * h
+    return torch.stack((x1, y1, x2, y2), dim=-1)
+
+
+def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
+    x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
+    cx = (x1 + x2) / 2
+    cy = (y1 + y2) / 2
+    w = x2 - x1
+    h = y2 - y1
+    return torch.stack((cx, cy, w, h), dim=-1)
+
+
+def convert_bounding_box_format(
+    bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
+) -> torch.Tensor:
+    if new_format == old_format:
+        return bounding_box.clone()
+
+    if old_format == BoundingBoxFormat.XYWH:
+        bounding_box = _xywh_to_xyxy(bounding_box)
+    elif old_format == BoundingBoxFormat.CXCYWH:
+        bounding_box = _cxcywh_to_xyxy(bounding_box)
+
+    if new_format == BoundingBoxFormat.XYWH:
+        bounding_box = _xyxy_to_xywh(bounding_box)
+    elif new_format == BoundingBoxFormat.CXCYWH:
+        bounding_box = _xyxy_to_cxcywh(bounding_box)
+
+    return bounding_box
+
+
+def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
+    # expand the (single) channel dimension to 3 while keeping the spatial dimensions
+    return grayscale.expand(3, -1, -1)
+
+
+def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor:
+    if new_color_space == old_color_space:
+        return image.clone()
+
+    if old_color_space == ColorSpace.GRAYSCALE:
+        image = _grayscale_to_rgb(image)
+
+    if new_color_space == ColorSpace.GRAYSCALE:
+        image = _rgb_to_grayscale(image)
+
+    return image
diff --git a/torchvision/prototype/transforms/kernels/_misc.py b/torchvision/prototype/transforms/kernels/_misc.py
new file mode 100644
index 00000000000..de148ab194a
--- /dev/null
+++ b/torchvision/prototype/transforms/kernels/_misc.py
@@ -0,0 +1,4 @@
+from torchvision.transforms import functional as _F
+
+
+normalize_image = _F.normalize
diff --git a/torchvision/prototype/transforms/kernels/_type_conversion.py b/torchvision/prototype/transforms/kernels/_type_conversion.py
new file mode 100644
index 00000000000..09cb61b8a21
--- /dev/null
+++ b/torchvision/prototype/transforms/kernels/_type_conversion.py
@@ -0,0 +1,25 @@
+import unittest.mock
+from typing import Dict, Any, Tuple, cast
+
+import numpy as np
+import PIL.Image
+import torch
+from torch.nn.functional import one_hot
+from torchvision.io.video import read_video
+from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
+
+
+def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
+    image = torch.as_tensor(np.array(PIL.Image.open(ReadOnlyTensorBuffer(encoded_image)), copy=True))
+    if image.ndim == 2:
+        image = image.unsqueeze(2)
+    return image.permute(2, 0, 1)
+
+
+def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
+    with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
+        return read_video(ReadOnlyTensorBuffer(encoded_video))  # type: ignore[arg-type]
+
+
+def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
+    return cast(torch.Tensor, one_hot(label, num_classes=num_categories))
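For intuition, converting a box between formats and one-hot encoding a label with the kernels above might look like this; the values are illustrative only:

    import torch
    from torchvision.prototype.features import BoundingBoxFormat
    from torchvision.prototype.transforms import kernels as K

    xywh = torch.tensor([[10.0, 20.0, 100.0, 200.0]])  # x, y, width, height
    xyxy = K.convert_bounding_box_format(
        xywh, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY
    )
    # tensor([[ 10.,  20., 110., 220.]])

    cxcywh = K.convert_bounding_box_format(
        xyxy, old_format=BoundingBoxFormat.XYXY, new_format=BoundingBoxFormat.CXCYWH
    )
    # tensor([[ 60., 120., 100., 200.]])

    one_hot = K.label_to_one_hot(torch.tensor([0, 2, 1]), num_categories=3)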
diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py
index 9468dcf08a9..fe75c19eb75 100644
--- a/torchvision/prototype/utils/_internal.py
+++ b/torchvision/prototype/utils/_internal.py
@@ -3,11 +3,33 @@
 import enum
 import functools
 import inspect
+import io
+import mmap
 import os
 import os.path
+import platform
 import textwrap
 import warnings
-from typing import Collection, Sequence, Callable, Any, Iterator, NoReturn, Mapping, TypeVar, Iterable, Tuple, cast
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    cast,
+    Collection,
+    Iterable,
+    Iterator,
+    Mapping,
+    NoReturn,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    List,
+    Dict,
+)
+
+import numpy as np
+import torch
 
 __all__ = [
     "StrEnum",
@@ -17,10 +39,15 @@
     "make_repr",
     "FrozenBunch",
     "kwonly_to_pos_or_kw",
+    "fromfile",
+    "ReadOnlyTensorBuffer",
+    "apply_recursively",
 ]
 
 
 class StrEnumMeta(enum.EnumMeta):
+    auto = enum.auto
+
     def __getitem__(self, item):
         return super().__getitem__(item.upper() if isinstance(item, str) else item)
 
@@ -186,3 +213,114 @@ def wrapper(*args: Any, **kwargs: Any) -> D:
         return fn(*args, **kwargs)
 
     return wrapper
+
+
+def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
+    # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
+    return bytearray(file.read(-1 if count == -1 else count * item_size))
+
+
+def fromfile(
+    file: BinaryIO,
+    *,
+    dtype: torch.dtype,
+    byte_order: str,
+    count: int = -1,
+) -> torch.Tensor:
+    """Construct a tensor from a binary file.
+    .. note::
+        This function is similar to :func:`numpy.fromfile` with two notable differences:
+        1. This function only accepts an open binary file, but not a path to it.
+        2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that
+            concept.
+    .. note::
+        If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
+        long as the file is still open, inplace operations on the returned tensor will reflect back to the file.
+    Args:
+        file (IO): Open binary file.
+        dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
+        byte_order (str): Byte order of the data. Can be "little" or "big" endian.
+        count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
+    """
+    byte_order = "<" if byte_order == "little" else ">"
+    char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
+    item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
+    np_dtype = byte_order + char + str(item_size)
+
+    buffer: Union[memoryview, bytearray]
+    if platform.system() != "Windows":
+        # PyTorch does not support tensors with underlying read-only memory. In case
+        # - the file has a .fileno(),
+        # - the file was opened for updating, i.e. 'r+b' or 'w+b',
+        # - the file is seekable
+        # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
+        # to a mutable location afterwards.
+        try:
+            buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
+            # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
+            file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
+        except (AttributeError, PermissionError, io.UnsupportedOperation):
+            buffer = _read_mutable_buffer_fallback(file, count, item_size)
+    else:
+        # On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state
+        # so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
+        buffer = _read_mutable_buffer_fallback(file, count, item_size)
+
+    # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
+    # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
+    # successive .astype() call.
+    return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))
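A quick sketch of how ``fromfile`` is meant to be called; the file name, offset, and dtype are only an example, and the file must of course contain enough bytes for the requested count:

    import torch
    from torchvision.prototype.utils._internal import fromfile

    # read 1000 big-endian int32 values from an (illustrative) binary file
    with open("data.bin", "r+b") as file:  # "r+b"/"w+b" enables the fast mmap path on non-Windows systems
        values = fromfile(file, dtype=torch.int32, byte_order="big", count=1000)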
+
+
+class ReadOnlyTensorBuffer:
+    def __init__(self, tensor: torch.Tensor) -> None:
+        self._memory = memoryview(tensor.numpy())
+        self._cursor: int = 0
+
+    def tell(self) -> int:
+        return self._cursor
+
+    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
+        if whence == io.SEEK_SET:
+            self._cursor = offset
+        elif whence == io.SEEK_CUR:
+            self._cursor += offset
+        elif whence == io.SEEK_END:
+            self._cursor = len(self._memory) + offset
+        else:
+            raise ValueError(
+                f"'whence' should be ``{io.SEEK_SET}``, ``{io.SEEK_CUR}``, or ``{io.SEEK_END}``, "
+                f"but got {repr(whence)} instead"
+            )
+        return self.tell()
+
+    def read(self, size: int = -1) -> bytes:
+        cursor = self.tell()
+        offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
+        return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
+
+
+def apply_recursively(fn: Callable, obj: Any) -> Any:
+    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
+    # "a" == "a"[0][0]...
+    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
+        sequence: List[Any] = []
+        for item in obj:
+            result = apply_recursively(fn, item)
+            if isinstance(result, collections.abc.Sequence) and hasattr(result, "__inline__"):
+                sequence.extend(result)
+            else:
+                sequence.append(result)
+        return sequence
+    elif isinstance(obj, collections.abc.Mapping):
+        mapping: Dict[Any, Any] = {}
+        for name, item in obj.items():
+            result = apply_recursively(fn, item)
+            if isinstance(result, collections.abc.Mapping) and hasattr(result, "__inline__"):
+                mapping.update(result)
+            else:
+                mapping[name] = result
+        return mapping
+    else:
+        return fn(obj)
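To round things off, a small sketch of the two helpers added to _internal.py above; the inputs are made up and not taken from this diff:

    import torch
    from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer, apply_recursively

    # ReadOnlyTensorBuffer exposes a tensor of raw bytes through a file-like read/seek/tell interface
    raw = torch.tensor(list(b"hello world"), dtype=torch.uint8)
    buffer = ReadOnlyTensorBuffer(raw)
    buffer.seek(6)
    assert buffer.read() == b"world"

    # apply_recursively maps a function over the leaves of arbitrarily nested sequences and mappings
    sample = {"image": torch.rand(3, 32, 32), "boxes": [torch.rand(4), torch.rand(4)]}
    shapes = apply_recursively(lambda t: tuple(t.shape), sample)
    # {'image': (3, 32, 32), 'boxes': [(4,), (4,)]}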