From a6a73e2745411d5b2062d8cfa66d50dee6a2478d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 4 Nov 2021 16:44:44 +0100 Subject: [PATCH 1/7] add initial chunk of prototype transforms --- .../prototype/datasets/_builtin/imagenet.py | 2 +- .../prototype/datasets/utils/__init__.py | 1 + .../prototype/datasets/utils/_dataset.py | 3 +- .../prototype/datasets/utils/_internal.py | 85 ---- .../prototype/datasets/utils/_query.py | 41 ++ torchvision/prototype/features/_image.py | 9 + torchvision/prototype/models/alexnet.py | 2 +- torchvision/prototype/models/densenet.py | 2 +- .../prototype/models/detection/faster_rcnn.py | 3 +- torchvision/prototype/models/efficientnet.py | 2 +- torchvision/prototype/models/googlenet.py | 2 +- torchvision/prototype/models/inception.py | 2 +- torchvision/prototype/models/mnasnet.py | 2 +- torchvision/prototype/models/mobilenetv2.py | 2 +- torchvision/prototype/models/mobilenetv3.py | 2 +- .../models/quantization/googlenet.py | 2 +- .../models/quantization/inception.py | 2 +- .../prototype/models/quantization/resnet.py | 2 +- .../models/quantization/shufflenetv2.py | 2 +- torchvision/prototype/models/regnet.py | 2 +- torchvision/prototype/models/resnet.py | 2 +- .../models/segmentation/deeplabv3.py | 3 +- .../prototype/models/segmentation/fcn.py | 3 +- .../prototype/models/segmentation/lraspp.py | 3 +- torchvision/prototype/models/shufflenetv2.py | 2 +- torchvision/prototype/models/squeezenet.py | 2 +- torchvision/prototype/models/vgg.py | 2 +- torchvision/prototype/models/video/resnet.py | 2 +- torchvision/prototype/transforms/__init__.py | 7 +- .../prototype/transforms/_container.py | 92 ++++ torchvision/prototype/transforms/_geometry.py | 130 ++++++ torchvision/prototype/transforms/_misc.py | 48 +++ .../transforms/{presets.py => _presets.py} | 0 .../prototype/transforms/_transform.py | 403 ++++++++++++++++++ torchvision/prototype/utils/_internal.py | 90 +++- 35 files changed, 847 insertions(+), 112 deletions(-) create mode 100644 torchvision/prototype/datasets/utils/_query.py create mode 100644 torchvision/prototype/transforms/_container.py create mode 100644 torchvision/prototype/transforms/_geometry.py create mode 100644 torchvision/prototype/transforms/_misc.py rename torchvision/prototype/transforms/{presets.py => _presets.py} (100%) create mode 100644 torchvision/prototype/transforms/_transform.py diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 986b01cd8f5..b8ca0089e65 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -20,8 +20,8 @@ Enumerator, getitem, read_mat, - FrozenMapping, ) +from torchvision.prototype.utils._internal import FrozenMapping class ImageNet(Dataset): diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 018553e0908..93c56ce173b 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,3 +1,4 @@ from . 
import _internal from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset +from ._query import SampleQuery from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 23474686a3b..e42a5d2e48e 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -9,10 +9,11 @@ import torch from torch.utils.data import IterDataPipe +from torchvision.prototype.utils._internal import FrozenBunch, make_repr from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str from .._home import use_sharded_dataset -from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe +from ._internal import BUILTIN_DIR, _make_sharded_datapipe from ._resource import OnlineResource diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 2c48c4414e3..f931aa58f31 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -1,4 +1,3 @@ -import csv import enum import gzip import io @@ -7,7 +6,6 @@ import os.path import pathlib import pickle -import textwrap from typing import ( Sequence, Callable, @@ -18,10 +16,7 @@ Iterator, Dict, Optional, - NoReturn, IO, - Iterable, - Mapping, Sized, ) from typing import cast @@ -38,10 +33,6 @@ __all__ = [ "INFINITE_BUFFER_SIZE", "BUILTIN_DIR", - "make_repr", - "FrozenMapping", - "FrozenBunch", - "create_categories_file", "read_mat", "image_buffer_from_array", "SequenceIterator", @@ -62,82 +53,6 @@ BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin" -def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str: - def to_str(sep: str) -> str: - return sep.join([f"{key}={value}" for key, value in items]) - - prefix = f"{name}(" - postfix = ")" - body = to_str(", ") - - line_length = int(os.environ.get("COLUMNS", 80)) - body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length - multiline_body = len(str(body).splitlines()) > 1 - if not (body_too_long or multiline_body): - return prefix + body + postfix - - body = textwrap.indent(to_str(",\n"), " " * 2) - return f"{prefix}\n{body}\n{postfix}" - - -class FrozenMapping(Mapping[K, D]): - def __init__(self, *args: Any, **kwargs: Any) -> None: - data = dict(*args, **kwargs) - self.__dict__["__data__"] = data - self.__dict__["__final_hash__"] = hash(tuple(data.items())) - - def __getitem__(self, item: K) -> D: - return cast(Mapping[K, D], self.__dict__["__data__"])[item] - - def __iter__(self) -> Iterator[K]: - return iter(self.__dict__["__data__"].keys()) - - def __len__(self) -> int: - return len(self.__dict__["__data__"]) - - def __setitem__(self, key: K, value: Any) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __delitem__(self, key: K) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __hash__(self) -> int: - return cast(int, self.__dict__["__final_hash__"]) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, FrozenMapping): - return NotImplemented - - return hash(self) == hash(other) - - def __repr__(self) -> str: - return repr(self.__dict__["__data__"]) - - -class FrozenBunch(FrozenMapping): - def __getattr__(self, name: str) -> Any: - try: - return self[name] - except KeyError as error: - raise AttributeError(f"'{type(self).__name__}' 
object has no attribute '{name}'") from error - - def __setattr__(self, key: Any, value: Any) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __delattr__(self, item: Any) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __repr__(self) -> str: - return make_repr(type(self).__name__, self.items()) - - -def create_categories_file( - root: Union[str, pathlib.Path], name: str, categories: Sequence[Union[str, Sequence[str]]], **fmtparams: Any -) -> None: - with open(pathlib.Path(root) / f"{name}.categories", "w", newline="") as file: - csv.writer(file, **fmtparams).writerows(categories) - - def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: try: import scipy.io as sio diff --git a/torchvision/prototype/datasets/utils/_query.py b/torchvision/prototype/datasets/utils/_query.py new file mode 100644 index 00000000000..0adec6e83a6 --- /dev/null +++ b/torchvision/prototype/datasets/utils/_query.py @@ -0,0 +1,41 @@ +import collections.abc +from typing import Any, Callable, Iterator, Optional, Tuple, TypeVar, cast + +from torchvision.prototype.features import BoundingBox, Image + +T = TypeVar("T") + + +class SampleQuery: + def __init__(self, sample: Any) -> None: + self.sample = sample + + @staticmethod + def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]: + if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)): + for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample: + yield from SampleQuery._query_recursively(item, fn) + else: + result = fn(sample) + if result is not None: + yield result + + def query(self, fn: Callable[[Any], Optional[T]]) -> T: + results = set(self._query_recursively(self.sample, fn)) + if not results: + raise RuntimeError("Query turned up empty.") + elif len(results) > 1: + raise RuntimeError(f"Found more than one result: {results}") + + return results.pop() + + def image_size(self) -> Tuple[int, int]: + def fn(sample: Any) -> Optional[Tuple[int, int]]: + if isinstance(sample, Image): + return cast(Tuple[int, int], sample.shape[-2:]) + elif isinstance(sample, BoundingBox): + return sample.image_size + else: + return None + + return self.query(fn) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index a8eab249997..a39c654f0b7 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -18,6 +18,15 @@ class Image(Feature): color_spaces = ColorSpace color_space: ColorSpace + @classmethod + def _to_tensor(cls, data, *, dtype, device): + tensor = torch.as_tensor(data, dtype=dtype, device=device) + if tensor.ndim == 2: + tensor = tensor.unsqueeze(0) + elif tensor.ndim != 3: + raise ValueError() + return tensor + @classmethod def _parse_meta_data( cls, diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index 5236eb87de6..0f61c8f31fa 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.alexnet import AlexNet -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/densenet.py 
b/torchvision/prototype/models/densenet.py index adb1140063d..db0c742e48d 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -4,10 +4,10 @@ from typing import Any, Optional, Tuple import torch.nn as nn +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.densenet import DenseNet -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 4f8ec08edc3..e04932a1b6e 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -1,6 +1,8 @@ import warnings from typing import Any, Optional, Union +from torchvision.prototype.transforms import CocoEval + from ....models.detection.faster_rcnn import ( _mobilenet_extractor, _resnet_fpn_extractor, @@ -10,7 +12,6 @@ misc_nn_ops, overwrite_eps, ) -from ...transforms.presets import CocoEval from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 9e550c966ed..da63e5a9d45 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -3,10 +3,10 @@ from typing import Any, Optional from torch import nn +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.efficientnet import EfficientNet, MBConvConfig -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index 0fb45e103e3..5691fb81b03 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index daac42a3d3f..88649e61293 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index b55aa1c46f7..7301817ebc8 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import 
ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.mnasnet import MNASNet -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 40721792b3c..990ede1eafc 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv2 import MobileNetV2 -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index 8e1fad903a0..5b7f42517d1 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional, List +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 2cf6527cccf..84f1b91d1b3 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -2,6 +2,7 @@ from functools import partial from typing import Any, Optional, Union +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.googlenet import ( @@ -9,7 +10,6 @@ _replace_relu, quantize_model, ) -from ...transforms.presets import ImageNetEval from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES from ..googlenet import GoogLeNetWeights diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index a783f33d177..cc864677b78 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -2,6 +2,7 @@ from functools import partial from typing import Any, Optional, Union +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.inception import ( @@ -9,7 +10,6 @@ _replace_relu, quantize_model, ) -from ...transforms.presets import ImageNetEval from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES from ..inception import InceptionV3Weights diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 361bf3dc385..c678619353a 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -2,6 +2,7 @@ from functools import partial from typing import Any, List, Optional, Type, Union +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from 
....models.quantization.resnet import ( @@ -11,7 +12,6 @@ _replace_relu, quantize_model, ) -from ...transforms.presets import ImageNetEval from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index 2bc0383cff1..06c1f29b631 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -2,6 +2,7 @@ from functools import partial from typing import Any, List, Optional, Union +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.shufflenetv2 import ( @@ -9,7 +10,6 @@ _replace_relu, quantize_model, ) -from ...transforms.presets import ImageNetEval from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index c972f6f27e8..60da724f9bd 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -3,10 +3,10 @@ from typing import Any, Optional from torch import nn +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.regnet import RegNet, BlockParams -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index 764ea49641e..d5c774623f0 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, List, Optional, Type, Union +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.resnet import BasicBlock, Bottleneck, ResNet -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 016e6cc507d..6423773a7d5 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -2,8 +2,9 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import VocEval + from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet -from ...transforms.presets import VocEval from .._api import Weights, WeightEntry from .._meta import _VOC_CATEGORIES from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 0c671053bf5..892b0cdb50f 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -2,8 +2,9 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import VocEval + from ....models.segmentation.fcn import FCN, _fcn_resnet -from ...transforms.presets import VocEval from 
.._api import Weights, WeightEntry from .._meta import _VOC_CATEGORIES from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 2a696c24bd4..f5ee7b17f0d 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -2,8 +2,9 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import VocEval + from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 -from ...transforms.presets import VocEval from .._api import Weights, WeightEntry from .._meta import _VOC_CATEGORIES from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 22eb0843f3f..d866a66bc50 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.shufflenetv2 import ShuffleNetV2 -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index d354eb70721..56aabc3714b 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.squeezenet import SqueezeNet -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index 54afd3b10fc..b2c0a4aabc3 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -2,10 +2,10 @@ from functools import partial from typing import Any, Optional +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ...models.vgg import VGG, make_layers, cfgs -from ..transforms.presets import ImageNetEval from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 280420be4fe..f88f5a2aa31 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -3,6 +3,7 @@ from typing import Any, Callable, List, Optional, Sequence, Type, Union from torch import nn +from torchvision.prototype.transforms import Kinect400Eval from torchvision.transforms.functional import InterpolationMode from ....models.video.resnet import ( @@ -15,7 +16,6 @@ R2Plus1dStem, VideoResNet, ) -from ...transforms.presets import Kinect400Eval from .._api import Weights, WeightEntry from .._meta import _KINETICS400_CATEGORIES diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 061b9381e43..f6a608b6572 100644 --- a/torchvision/prototype/transforms/__init__.py +++ 
b/torchvision/prototype/transforms/__init__.py @@ -1 +1,6 @@ -from .presets import * +from ._transform import Transform +from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort: skip + +from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop +from ._misc import Identity, Lambda, Normalize +from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py new file mode 100644 index 00000000000..53a430c6b92 --- /dev/null +++ b/torchvision/prototype/transforms/_container.py @@ -0,0 +1,92 @@ +from typing import Any, List + +import torch +from torch import nn +from torchvision.prototype.transforms import Transform + + +class ContainerTransform(nn.Module): + def supports(self, obj: Any) -> bool: + raise NotImplementedError() + + def forward(self, *inputs: Any, strict: bool = False) -> Any: + raise NotImplementedError() + + def _make_repr(self, lines: List[str]) -> str: + extra_repr = self.extra_repr() + if extra_repr: + lines = [self.extra_repr(), *lines] + head = f"{type(self).__name__}(" + tail = ")" + body = [f" {line.rstrip()}" for line in lines] + return "\n".join([head, *body, tail]) + + +class WrapperTransform(ContainerTransform): + def __init__(self, transform: Transform): + super().__init__() + self._transform = transform + + def supports(self, obj: Any) -> bool: + return self._transform.supports(obj) + + def __repr__(self) -> str: + return self._make_repr(repr(self._transform).splitlines()) + + +class MultiTransform(ContainerTransform): + def __init__(self, *transforms: Transform) -> None: + super().__init__() + self._transforms = transforms + + def supports(self, obj: Any, *, strict: bool = False) -> bool: + aggregator = all if strict else any + return aggregator(transform.supports(obj) for transform in self._transforms) + + def __repr__(self) -> str: + lines = [] + for idx, transform in enumerate(self._transforms): + partial_lines = repr(transform).splitlines() + lines.append(f"({idx:d}): {partial_lines[0]}") + lines.extend(partial_lines[1:]) + return self._make_repr(lines) + + +class Compose(MultiTransform): + def forward(self, *inputs: Any, strict: bool = False) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + for transform in self._transforms: + sample = transform(sample, strict=strict) + return sample + + +class RandomApply(WrapperTransform): + def __init__(self, transform: Transform, *, p: float = 0.5) -> None: + super().__init__(transform) + self._p = p + + def forward(self, *inputs: Any, strict: bool = False) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + if float(torch.rand(())) < self._p: + # TODO: Should we check here is sample is supported if strict=True? 
+ return sample + + return self._transform(sample, strict=strict) + + def extra_repr(self) -> str: + return f"p={self._p}" + + +class RandomChoice(MultiTransform): + def forward(self, *inputs: Any, strict: bool = False) -> Any: + idx = int(torch.randint(len(self._transforms), size=())) + transform = self._transforms[idx] + return transform(*inputs, strict=strict) + + +class RandomOrder(MultiTransform): + def forward(self, *inputs: Any, strict: bool = False) -> Any: + for idx in torch.randperm(len(self._transforms)): + transform = self._transforms[idx] + inputs = transform(*inputs, strict=strict) + return inputs diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py new file mode 100644 index 00000000000..c2d2dffefcd --- /dev/null +++ b/torchvision/prototype/transforms/_geometry.py @@ -0,0 +1,130 @@ +from typing import Any, Dict, Tuple, Union + +import torch +from torch.nn.functional import interpolate +from torchvision.prototype.datasets.utils import SampleQuery +from torchvision.prototype.features import BoundingBox, Image +from torchvision.prototype.transforms import Transform + + +class HorizontalFlip(Transform): + @staticmethod + def image(input: Image) -> Image: + return Image(input.flip((-1,)), like=input) + + @staticmethod + def bounding_box(input: BoundingBox) -> BoundingBox: + x, y, w, h = input.convert("xywh").to_parts() + x = input.image_size[1] - (x + w) + return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh") + + +class Resize(Transform): + def __init__( + self, + size: Union[int, Tuple[int, int]], + *, + interpolation_mode: str = "nearest", + ) -> None: + super().__init__() + self.size = (size, size) if isinstance(size, int) else size + self.interpolation_mode = interpolation_mode + + def get_params(self, sample: Any) -> Dict[str, Any]: + return dict(size=self.size, interpolation_mode=self.interpolation_mode) + + @staticmethod + def image(input: Image, *, size: Tuple[int, int], interpolation_mode: str = "nearest") -> Image: + return Image(interpolate(input.unsqueeze(0), size, mode=interpolation_mode).squeeze(0), like=input) + + @staticmethod + def bounding_box(input: BoundingBox, *, size: Tuple[int, int], **_: Any) -> BoundingBox: + old_height, old_width = input.image_size + new_height, new_width = size + + height_scale = new_height / old_height + width_scale = new_width / old_width + + old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts() + + new_x1 = old_x1 * width_scale + new_y1 = old_y1 * height_scale + + new_x2 = old_x2 * width_scale + new_y2 = old_y2 * height_scale + + return BoundingBox.from_parts(new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=size) + + def extra_repr(self) -> str: + extra_repr = f"size={self.size}" + if self.interpolation_mode != "bilinear": + extra_repr += f", interpolation_mode={self.interpolation_mode}" + return extra_repr + + +class RandomResize(Transform, wraps=Resize): + def __init__(self, min_size: Union[int, Tuple[int, int]], max_size: Union[int, Tuple[int, int]]) -> None: + super().__init__() + self.min_size = (min_size, min_size) if isinstance(min_size, int) else min_size + self.max_size = (max_size, max_size) if isinstance(max_size, int) else max_size + + def get_params(self, sample: Any) -> Dict[str, Any]: + min_height, min_width = self.min_size + max_height, max_width = self.max_size + height = int(torch.randint(min_height, max_height + 1, size=())) + width = int(torch.randint(min_width, max_width + 1, size=())) + return dict(size=(height, 
width)) + + def extra_repr(self) -> str: + return f"min_size={self.min_size}, max_size={self.max_size}" + + +class Crop(Transform): + def __init__(self, crop_box: BoundingBox) -> None: + super().__init__() + self.crop_box = crop_box.convert("xyxy") + + def get_params(self, sample: Any) -> Dict[str, Any]: + return dict(crop_box=self.crop_box) + + @staticmethod + def image(input: Image, *, crop_box: BoundingBox) -> Image: + # FIXME: pad input in case it is smaller than crop_box + x1, y1, x2, y2 = crop_box.convert("xyxy").to_parts() + return Image(input[..., y1 : y2 + 1, x1 : x2 + 1], like=input) # type: ignore[misc] + + +class CenterCrop(Transform, wraps=Crop): + def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None: + super().__init__() + self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size + + def get_params(self, sample: Any) -> Dict[str, Any]: + image_size = SampleQuery(sample).image_size() + image_height, image_width = image_size + cx = image_width // 2 + cy = image_height // 2 + h, w = self.crop_size + crop_box = BoundingBox.from_parts(cx, cy, w, h, image_size=image_size, format="cxcywh") + return dict(crop_box=crop_box.convert("xyxy")) + + def extra_repr(self) -> str: + return f"crop_size={self.crop_size}" + + +class RandomCrop(Transform, wraps=Crop): + def __init__(self, crop_size: Union[int, Tuple[int, int]]) -> None: + super().__init__() + self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size + + def get_params(self, sample: Any) -> Dict[str, Any]: + image_size = SampleQuery(sample).image_size() + image_height, image_width = image_size + crop_height, crop_width = self.crop_size + x = torch.randint(0, image_width - crop_width + 1, size=()) if crop_width < image_width else 0 + y = torch.randint(0, image_height - crop_height + 1, size=()) if crop_height < image_height else 0 + crop_box = BoundingBox.from_parts(x, y, crop_width, crop_height, image_size=image_size, format="xywh") + return dict(crop_box=crop_box.convert("xyxy")) + + def extra_repr(self) -> str: + return f"crop_size={self.crop_size}" diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py new file mode 100644 index 00000000000..88fe0d68b22 --- /dev/null +++ b/torchvision/prototype/transforms/_misc.py @@ -0,0 +1,48 @@ +import warnings +from typing import Any, Dict, Sequence +from typing import Callable + +import torch +from torchvision.prototype.features import Image +from torchvision.prototype.transforms import Transform + + +class Identity(Transform): + """Identity transform that supports all built-in :class:`~torchvision.prototype.features.Feature`'s.""" + + def __init__(self): + super().__init__() + for feature_type in self._BUILTIN_FEATURE_TYPES: + self.register_feature_transform(feature_type, lambda input, **params: input) + + +class Lambda(Transform): + def __new__(cls, lambd: Callable) -> Transform: # type: ignore[misc] + warnings.warn("transforms.Lambda(...) is deprecated. Use transforms.Transform.from_callable(...) instead.") + # We need to generate a new class everytime a Lambda transform is created, since the feature transforms are + # registered on the class rather than on the instance. If we didn't, registering a feature transform will + # overwrite it on **all** Lambda transform instances. 
+ return Transform.from_callable(lambd, name="Lambda") + + +class Normalize(Transform): + def __init__(self, mean: Sequence[float], std: Sequence[float]): + super().__init__() + self.mean = mean + self.std = std + + def get_params(self, sample: Any) -> Dict[str, Any]: + return dict(mean=self.mean, std=self.std) + + @staticmethod + def _channel_stats_to_tensor(stats: Sequence[float], *, like: torch.Tensor) -> torch.Tensor: + return torch.as_tensor(stats, device=like.device, dtype=like.dtype).view(-1, 1, 1) + + @staticmethod + def image(input: Image, *, mean: Sequence[float], std: Sequence[float]) -> Image: + mean_t = Normalize._channel_stats_to_tensor(mean, like=input) + std_t = Normalize._channel_stats_to_tensor(std, like=input) + return Image((input - mean_t) / std_t, like=input) + + def extra_repr(self) -> str: + return f"mean={tuple(self.mean)}, std={tuple(self.std)}" diff --git a/torchvision/prototype/transforms/presets.py b/torchvision/prototype/transforms/_presets.py similarity index 100% rename from torchvision/prototype/transforms/presets.py rename to torchvision/prototype/transforms/_presets.py diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py new file mode 100644 index 00000000000..141379231aa --- /dev/null +++ b/torchvision/prototype/transforms/_transform.py @@ -0,0 +1,403 @@ +import collections.abc +import inspect +import re +from typing import Any, Callable, Dict, Optional, Type, Union, cast, Set + +import torch +from torch import nn +from torchvision.prototype import features +from torchvision.prototype.utils._internal import add_suggestion + + +class Transform(nn.Module): + """Base class for transforms. + + A transform operates on a full sample at once, which might be a nested container of elements to transform. The + non-container elements of the sample will be dispatched to feature transforms based on their type in case it is + supported by the transform. Each transform needs to define at least one feature transform, which is canonical done + as static method: + + .. code-block:: + + class ImageIdentity(Transform): + @staticmethod + def image(input): + return input + + To achieve correct results for a complete sample, each transform should implement feature transforms for every + :class:`Feature` it can handle: + + .. code-block:: + + class Identity(Transform): + @staticmethod + def image(input): + return input + + @staticmethod + def bounding_box(input): + return input + + ... + + If the name of a static method in camel-case matches the name of a :class:`Feature`, the feature transform is + auto-registered. Supported pairs are: + + +----------------+----------------+ + | method name | `Feature` | + +================+================+ + | `image` | `Image` | + +----------------+----------------+ + | `bounding_box` | `BoundingBox` | + +----------------+----------------+ + | `label` | `Label` | + +----------------+----------------+ + + If you don't want to stick to this scheme, you can disable the auto-registration and perform it manually: + + .. code-block:: + + def my_image_transform(input): + ... + + class MyTransform(Transform, auto_register=False): + def __init__(self): + super().__init__() + self.register_feature_transform(Image, my_image_transform) + self.register_feature_transform(BoundingBox, self.my_bounding_box_transform) + + @staticmethod + def my_bounding_box_transform(input): + ... 
+ + In any case, the registration will assert that the feature transform can be invoked with + ``feature_transform(input, **params)``. + + .. warning:: + + Feature transforms are **registered on the class and not on the instance**. This means you cannot have two + instances of the same :class:`Transform` with different feature transforms. + + If the feature transforms needs additional parameters, you need to + overwrite the :meth:`~Transform.get_params` method. It needs to return the parameter dictionary that will be + unpacked and its contents passed to each feature transform: + + .. code-block:: + + class Rotate(Transform): + def __init__(self, degrees): + super().__init__() + self.degrees = degrees + + def get_params(self, sample): + return dict(degrees=self.degrees) + + def image(input, *, degrees): + ... + + The :meth:`~Transform.get_params` method will be invoked once per sample. Thus, in case of randomly sampled + parameters they will be the same for all features of the whole sample. + + .. code-block:: + + class RandomRotate(Transform) + def __init__(self, range): + super().__init__() + self._dist = torch.distributions.Uniform(range) + + def get_params(self, sample): + return dict(degrees=self._dist.sample().item()) + + @staticmethod + def image(input, *, degrees): + ... + + In case the sampling depends on one or more features at runtime, the complete ``sample`` gets passed to the + :meth:`Transform.get_params` method. Derivative transforms that only changes the parameter sampling, but the + feature transformations are identical, can simply wrap the transform they dispatch to: + + .. code-block:: + + class RandomRotate(Transform, wraps=Rotate): + def get_params(self, sample): + return dict(degrees=float(torch.rand(())) * 30.0) + + To transform a sample, you simply call an instance of the transform with it: + + .. code-block:: + + transform = MyTransform() + sample = dict(input=Image(torch.tensor(...)), target=BoundingBox(torch.tensor(...)), ...) + transformed_sample = transform(sample) + + By default elements in the ``sample`` that are not supported by the transform are returned without modification. + You can set the ``strict=True`` flag to force a transformation of every element or bail out in case one is not + supported. + + .. note:: + + To use a :class:`Transform` with a dataset, simply use it as map: + + .. 
code-block:: + + torchvision.datasets.load(...).map(MyTransform()) + """ + + _BUILTIN_FEATURE_TYPES = ( + features.BoundingBox, + features.Image, + features.Label, + ) + _FEATURE_NAME_MAP = { + "_".join([part.lower() for part in re.findall("[A-Z][^A-Z]*", feature_type.__name__)]): feature_type + for feature_type in _BUILTIN_FEATURE_TYPES + } + _feature_transforms: Dict[Type[features.Feature], Callable] + + def __init_subclass__( + cls, *, wraps: Optional[Type["Transform"]] = None, auto_register: bool = True, verbose: bool = False + ): + cls._feature_transforms = {} if wraps is None else wraps._feature_transforms.copy() + if auto_register: + cls._auto_register(verbose=verbose) + + @staticmethod + def _has_allowed_signature(feature_transform: Callable) -> bool: + """Checks if ``feature_transform`` can be invoked with ``feature_transform(input, **params)``""" + + parameters = tuple(inspect.signature(feature_transform).parameters.values()) + if not parameters: + return False + elif len(parameters) == 1: + return parameters[0].kind != inspect.Parameter.KEYWORD_ONLY + else: + return parameters[1].kind != inspect.Parameter.POSITIONAL_ONLY + + @classmethod + def register_feature_transform(cls, feature_type: Type[features.Feature], transform: Callable) -> None: + """Registers a transform for given feature on the class. + + If a transform object is called or :meth:`Transform.apply` is invoked, inputs are dispatched to the registered + transforms based on their type. + + Args: + feature_type: Feature type the transformation is registered for. + transform: Feature transformation. + + Raises: + TypeError: If ``transform`` cannot be invoked with ``transform(input, **params)``. + """ + if not cls._has_allowed_signature(transform): + raise TypeError("Feature transform cannot be invoked with transform(input, **params)") + cls._feature_transforms[feature_type] = transform + + @classmethod + def _auto_register(cls, *, verbose: bool = False) -> None: + """Auto-registers methods on the class as feature transforms if they meet the following criteria: + + 1. They are static. + 2. They can be invoked with `cls.feature_transform(input, **params)`. + 3. They are public. + 4. Their name in camel case matches the name of a builtin feature, e.g. 'bounding_box' and 'BoundingBox'. + + The name from 4. determines for which feature the method is registered. + + .. note:: + + The ``auto_register`` and ``verbose`` flags need to be passed as keyword arguments to the class: + + .. code-block:: + + class MyTransform(Transform, auto_register=True, verbose=True): + ... 
+ + Args: + verbose: If ``True``, prints to STDOUT which methods were registered or why a method was not registered + """ + for name, value in inspect.getmembers(cls): + # check if attribute is a static method and was defined in the subclass + # TODO: this needs to be revisited to allow subclassing of custom transforms + if not (name in cls.__dict__ and inspect.isfunction(value)): + continue + + not_registered_prefix = f"{cls.__name__}.{name}() was not registered as feature transform, because" + + if not cls._has_allowed_signature(value): + if verbose: + print(f"{not_registered_prefix} it cannot be invoked with {name}(input, **params).") + continue + + if name.startswith("_"): + if verbose: + print(f"{not_registered_prefix} it is private.") + continue + + try: + feature_type = cls._FEATURE_NAME_MAP[name] + except KeyError: + if verbose: + print( + add_suggestion( + f"{not_registered_prefix} its name doesn't match any known feature type.", + word=name, + possibilities=cls._FEATURE_NAME_MAP.keys(), + close_match_hint=lambda close_match: ( + f"Did you mean to name it '{close_match}' " + f"to be registered for type '{cls._FEATURE_NAME_MAP[close_match]}'?" + ), + ) + ) + continue + + cls.register_feature_transform(feature_type, value) + if verbose: + print( + f"{cls.__name__}.{name}() was registered as feature transform for type '{feature_type.__name__}'." + ) + + @classmethod + def from_callable( + cls, + feature_transform: Union[Callable, Dict[Type[features.Feature], Callable]], + *, + name: str = "FromCallable", + get_params: Optional[Union[Dict[str, Any], Callable[[Any], Dict[str, Any]]]] = None, + ) -> "Transform": + """Creates a new transform from a callable. + + Args: + feature_transform: Feature transform that will be registered to handle :class:`Image`'s. Can be passed as + dictionary in which case each key-value-pair is needs to consists of a ``Feature`` type and the + corresponding transform. + name: Name of the transform. + get_params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``. + Can be passed as callable in which case it will be called with the transform instance (``self``) and + the input of the transform. + + Raises: + TypeError: If ``feature_transform`` cannot be invoked with ``feature_transform(input, **params)``. + """ + if get_params is None: + get_params = dict() + attributes = dict( + get_params=get_params if callable(get_params) else lambda self, sample: get_params, # type: ignore[misc] + ) + transform_cls = cast(Type[Transform], type(name, (cls,), attributes)) + + if callable(feature_transform): + feature_transform = {features.Image: feature_transform} + for feature_type, transform in feature_transform.items(): + transform_cls.register_feature_transform(feature_type, transform) + + return transform_cls() + + @classmethod + def supported_feature_types(cls) -> Set[Type[features.Feature]]: + return set(cls._feature_transforms.keys()) + + @classmethod + def supports(cls, obj: Any) -> bool: + """Checks if object or type is supported. + + Args: + obj: Object or type. + """ + # TODO: should this handle containers? + feature_type = obj if isinstance(obj, type) else type(obj) + return feature_type is torch.Tensor or feature_type in cls.supported_feature_types() + + @classmethod + def transform(cls, input: Union[torch.Tensor, features.Feature], **params: Any) -> torch.Tensor: + """Applies the registered feature transform to the input based on its type. + + This can be uses as feature type generic functional interface: + + .. 
code-block:: + + transform = Rotate.transform + transformed_image = transform(Image(torch.tensor(...)), degrees=30.0) + transformed_bbox = transform(BoundingBox(torch.tensor(...)), degrees=-10.0) + + Args: + input: ``input`` in ``feature_transform(input, **params)`` + **params: Parameter dictionary ``params`` in ``feature_transform(input, **params)``. + + Returns: + Transformed input. + """ + feature_type = type(input) + if not cls.supports(feature_type): + raise TypeError(f"{cls.__name__}() is not able to handle inputs of type {feature_type}.") + + if feature_type is torch.Tensor: + # To keep BC, we treat all regular torch.Tensor's as images + feature_type = features.Image + input = feature_type(input) + feature_type = cast(Type[features.Feature], feature_type) + + feature_transform = cls._feature_transforms[feature_type] + output = feature_transform(input, **params) + + if type(output) is torch.Tensor: + output = feature_type(output, like=input) + return output + + def _transform_recursively(self, sample: Any, *, params: Dict[str, Any], strict: bool) -> Any: + """Recurses through a sample and invokes :meth:`Transform.transform` on non-container elements. + + If an element is not supported by the transform, it is returned untransformed. + + Args: + sample: Sample. + params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``. + strict: If ``True``, raises an error in case a non-container element of the ``sample`` is not supported by + the transform. + + Raises: + TypeError: If ``strict=True`` and a non-container element of the ``sample`` is not supported. + """ + if isinstance(sample, collections.abc.Sequence): + return [self._transform_recursively(item, params=params, strict=strict) for item in sample] + elif isinstance(sample, collections.abc.Mapping): + return { + name: self._transform_recursively(item, params=params, strict=strict) for name, item in sample.items() + } + else: + feature_type = type(sample) + if not self.supports(feature_type): + if not strict: + return sample + + raise TypeError(f"{type(self).__name__}() is not able to handle inputs of type {feature_type}.") + + return self.transform(sample, **params) + + def get_params(self, sample: Any) -> Dict[str, Any]: + """Returns the parameter dictionary used to transform the current sample. + + .. note:: + + Since ``sample`` might be a nested container, it is recommended to use the + :class:`torchvision.datasets.utils.Query` class if you need to extract information from it. + + Args: + sample: Current sample. + + Returns: + Parameter dictionary ``params`` in ``feature_transform(input, **params)``. 
+ """ + return dict() + + def forward( + self, + *inputs: Any, + params: Optional[Dict[str, Any]] = None, + strict: bool = True, + ) -> Any: + if not self._feature_transforms: + raise RuntimeError(f"{type(self).__name__}() has no registered feature transform.") + + sample = inputs if len(inputs) > 1 else inputs[0] + if params is None: + params = self.get_params(sample) + return self._transform_recursively(sample, params=params, strict=strict) diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 3e411ef1faa..2c0be022a6d 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -1,9 +1,19 @@ import collections.abc import difflib import enum -from typing import Sequence, Collection, Callable +import os +import os.path +import textwrap +from typing import Collection, Sequence, Callable, Any, Iterator, NoReturn, Mapping, TypeVar, Iterable, Tuple, cast -__all__ = ["StrEnum", "sequence_to_str", "add_suggestion"] +__all__ = [ + "StrEnum", + "sequence_to_str", + "add_suggestion", + "FrozenMapping", + "make_repr", + "FrozenBunch", +] class StrEnumMeta(enum.EnumMeta): @@ -40,3 +50,79 @@ def add_suggestion( return msg return f"{msg.strip()} {hint}" + + +K = TypeVar("K") +D = TypeVar("D") + + +class FrozenMapping(Mapping[K, D]): + def __init__(self, *args: Any, **kwargs: Any) -> None: + data = dict(*args, **kwargs) + self.__dict__["__data__"] = data + self.__dict__["__final_hash__"] = hash(tuple(data.items())) + + def __getitem__(self, item: K) -> D: + return cast(Mapping[K, D], self.__dict__["__data__"])[item] + + def __iter__(self) -> Iterator[K]: + return iter(self.__dict__["__data__"].keys()) + + def __len__(self) -> int: + return len(self.__dict__["__data__"]) + + def __immutable__(self) -> NoReturn: + raise RuntimeError(f"'{type(self).__name__}' object is immutable") + + def __setitem__(self, key: K, value: Any) -> NoReturn: + self.__immutable__() + + def __delitem__(self, key: K) -> NoReturn: + self.__immutable__() + + def __hash__(self) -> int: + return cast(int, self.__dict__["__final_hash__"]) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FrozenMapping): + return NotImplemented + + return hash(self) == hash(other) + + def __repr__(self) -> str: + return repr(self.__dict__["__data__"]) + + +def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str: + def to_str(sep: str) -> str: + return sep.join([f"{key}={value}" for key, value in items]) + + prefix = f"{name}(" + postfix = ")" + body = to_str(", ") + + line_length = int(os.environ.get("COLUMNS", 80)) + body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length + multiline_body = len(str(body).splitlines()) > 1 + if not (body_too_long or multiline_body): + return prefix + body + postfix + + body = textwrap.indent(to_str(",\n"), " " * 2) + return f"{prefix}\n{body}\n{postfix}" + + +class FrozenBunch(FrozenMapping): + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError as error: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error + + def __setattr__(self, key: Any, value: Any) -> NoReturn: + self.__immutable__() + + def __delattr__(self, item: Any) -> NoReturn: + self.__immutable__() + + def __repr__(self) -> str: + return make_repr(type(self).__name__, self.items()) From 64bdeb578f6ba21f41f22a7ad3aecd457591fa7c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 5 Nov 2021 00:04:25 +0100 Subject: [PATCH 2/7] fix tests --- 
test/test_prototype_datasets_api.py | 2 +- test/test_prototype_features.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py index 0cb6365278f..6b539400770 100644 --- a/test/test_prototype_datasets_api.py +++ b/test/test_prototype_datasets_api.py @@ -2,7 +2,7 @@ import pytest from torchvision.prototype import datasets -from torchvision.prototype.datasets.utils._internal import FrozenMapping, FrozenBunch +from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs): diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py index e4a178e3594..147243286d4 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_features.py @@ -11,6 +11,11 @@ make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32) +def make_image(**kwargs): + data = make_tensor((3, *torch.randint(16, 33, (2,)).tolist())) + return features.Image(data, **kwargs) + + def make_bounding_box(*, format="xyxy", image_size=(10, 10)): if isinstance(format, str): format = features.BoundingBoxFormat[format] @@ -42,6 +47,7 @@ def make_bounding_box(*, format="xyxy", image_size=(10, 10)): MAKE_DATA_MAP = { + features.Image: make_image, features.BoundingBox: make_bounding_box, } From 35c8dfdba8a120a241aa51f1c7a69d9bb21933e6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 8 Nov 2021 08:35:57 +0100 Subject: [PATCH 3/7] add error message --- torchvision/prototype/features/_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index a39c654f0b7..3d0b3d0c0af 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -24,7 +24,7 @@ def _to_tensor(cls, data, *, dtype, device): if tensor.ndim == 2: tensor = tensor.unsqueeze(0) elif tensor.ndim != 3: - raise ValueError() + raise ValueError("Only single images with 2 or 3 dimensions are allowed.") return tensor @classmethod From a3b047cccbb5cff6e85fd010f460720c7ecdd99d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 10 Nov 2021 11:09:27 +0100 Subject: [PATCH 4/7] fix more imports --- torchvision/prototype/models/detection/keypoint_rcnn.py | 3 ++- torchvision/prototype/models/detection/mask_rcnn.py | 2 +- torchvision/prototype/models/detection/retinanet.py | 2 +- torchvision/prototype/models/detection/ssd.py | 2 +- torchvision/prototype/models/detection/ssdlite.py | 2 +- torchvision/prototype/models/quantization/mobilenetv2.py | 2 +- torchvision/prototype/models/quantization/mobilenetv3.py | 2 +- 7 files changed, 8 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index 392446fee8b..48b6640e044 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -1,6 +1,8 @@ import warnings from typing import Any, Optional +from torchvision.prototype.transforms import CocoEval + from ....models.detection.keypoint_rcnn import ( _resnet_fpn_extractor, _validate_trainable_layers, @@ -8,7 +10,6 @@ misc_nn_ops, overwrite_eps, ) -from ...transforms.presets import CocoEval from .._api import Weights, WeightEntry from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from ..resnet import 
ResNet50Weights, resnet50 diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index efce203f1bb..80603c7781d 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -1,6 +1,7 @@ import warnings from typing import Any, Optional +from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.mask_rcnn import ( @@ -10,7 +11,6 @@ misc_nn_ops, overwrite_eps, ) -from ...transforms.presets import CocoEval from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES from ..resnet import ResNet50Weights, resnet50 diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index e44e4fe9285..93428dad662 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -1,6 +1,7 @@ import warnings from typing import Any, Optional +from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.retinanet import ( @@ -11,7 +12,6 @@ misc_nn_ops, overwrite_eps, ) -from ...transforms.presets import CocoEval from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES from ..resnet import ResNet50Weights, resnet50 diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 5759b8cd40f..543e997aad2 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -1,6 +1,7 @@ import warnings from typing import Any, Optional +from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.ssd import ( @@ -9,7 +10,6 @@ DefaultBoxGenerator, SSD, ) -from ...transforms.presets import CocoEval from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES from ..vgg import VGG16Weights, vgg16 diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index ae7092cb2a3..4cda67d573e 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Optional from torch import nn +from torchvision.prototype.transforms import CocoEval from torchvision.transforms.functional import InterpolationMode from ....models.detection.ssdlite import ( @@ -14,7 +15,6 @@ SSD, SSDLiteHead, ) -from ...transforms.presets import CocoEval from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index 578b3e5e37f..cf8d2a617f2 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -2,6 +2,7 @@ from functools import partial from typing import Any, Optional, Union +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.mobilenetv2 import ( @@ -10,7 +11,6 @@ _replace_relu, quantize_model, ) -from ...transforms.presets import ImageNetEval from .._api import Weights, 
WeightEntry from .._meta import _IMAGENET_CATEGORIES from ..mobilenetv2 import MobileNetV2Weights diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index 924ee91852e..459e99d2e0f 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -3,6 +3,7 @@ from typing import Any, List, Optional, Union import torch +from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode from ....models.quantization.mobilenetv3 import ( @@ -11,7 +12,6 @@ QuantizableMobileNetV3, _replace_relu, ) -from ...transforms.presets import ImageNetEval from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf From 660befb55b1a7ac00e97ee3387d5242be0f11379 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Nov 2021 10:55:02 +0100 Subject: [PATCH 5/7] add explicit no-ops --- .../prototype/transforms/_container.py | 24 +++++++------- .../prototype/transforms/_transform.py | 31 +++++++------------ 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py index 53a430c6b92..86d7804dd17 100644 --- a/torchvision/prototype/transforms/_container.py +++ b/torchvision/prototype/transforms/_container.py @@ -9,7 +9,7 @@ class ContainerTransform(nn.Module): def supports(self, obj: Any) -> bool: raise NotImplementedError() - def forward(self, *inputs: Any, strict: bool = False) -> Any: + def forward(self, *inputs: Any) -> Any: raise NotImplementedError() def _make_repr(self, lines: List[str]) -> str: @@ -39,9 +39,8 @@ def __init__(self, *transforms: Transform) -> None: super().__init__() self._transforms = transforms - def supports(self, obj: Any, *, strict: bool = False) -> bool: - aggregator = all if strict else any - return aggregator(transform.supports(obj) for transform in self._transforms) + def supports(self, obj: Any) -> bool: + return all(transform.supports(obj) for transform in self._transforms) def __repr__(self) -> str: lines = [] @@ -53,10 +52,10 @@ def __repr__(self) -> str: class Compose(MultiTransform): - def forward(self, *inputs: Any, strict: bool = False) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] for transform in self._transforms: - sample = transform(sample, strict=strict) + sample = transform(sample) return sample @@ -65,28 +64,27 @@ def __init__(self, transform: Transform, *, p: float = 0.5) -> None: super().__init__(transform) self._p = p - def forward(self, *inputs: Any, strict: bool = False) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] if float(torch.rand(())) < self._p: - # TODO: Should we check here is sample is supported if strict=True? 
return sample - return self._transform(sample, strict=strict) + return self._transform(sample) def extra_repr(self) -> str: return f"p={self._p}" class RandomChoice(MultiTransform): - def forward(self, *inputs: Any, strict: bool = False) -> Any: + def forward(self, *inputs: Any) -> Any: idx = int(torch.randint(len(self._transforms), size=())) transform = self._transforms[idx] - return transform(*inputs, strict=strict) + return transform(*inputs) class RandomOrder(MultiTransform): - def forward(self, *inputs: Any, strict: bool = False) -> Any: + def forward(self, *inputs: Any) -> Any: for idx in torch.randperm(len(self._transforms)): transform = self._transforms[idx] - inputs = transform(*inputs, strict=strict) + inputs = transform(*inputs) return inputs diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 141379231aa..a45292cbb66 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -1,7 +1,7 @@ import collections.abc import inspect import re -from typing import Any, Callable, Dict, Optional, Type, Union, cast, Set +from typing import Any, Callable, Dict, Optional, Type, Union, cast, Set, Collection import torch from torch import nn @@ -130,10 +130,6 @@ def get_params(self, sample): sample = dict(input=Image(torch.tensor(...)), target=BoundingBox(torch.tensor(...)), ...) transformed_sample = transform(sample) - By default elements in the ``sample`` that are not supported by the transform are returned without modification. - You can set the ``strict=True`` flag to force a transformation of every element or bail out in case one is not - supported. - .. note:: To use a :class:`Transform` with a dataset, simply use it as map: @@ -154,6 +150,8 @@ def get_params(self, sample): } _feature_transforms: Dict[Type[features.Feature], Callable] + NO_OP_FEATURE_TYPES: Collection[Type[features.Feature]] = () + def __init_subclass__( cls, *, wraps: Optional[Type["Transform"]] = None, auto_register: bool = True, verbose: bool = False ): @@ -342,7 +340,7 @@ def transform(cls, input: Union[torch.Tensor, features.Feature], **params: Any) output = feature_type(output, like=input) return output - def _transform_recursively(self, sample: Any, *, params: Dict[str, Any], strict: bool) -> Any: + def _transform_recursively(self, sample: Any, *, params: Dict[str, Any]) -> Any: """Recurses through a sample and invokes :meth:`Transform.transform` on non-container elements. If an element is not supported by the transform, it is returned untransformed. @@ -350,25 +348,21 @@ def _transform_recursively(self, sample: Any, *, params: Dict[str, Any], strict: Args: sample: Sample. params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``. - strict: If ``True``, raises an error in case a non-container element of the ``sample`` is not supported by - the transform. - - Raises: - TypeError: If ``strict=True`` and a non-container element of the ``sample`` is not supported. 
""" if isinstance(sample, collections.abc.Sequence): - return [self._transform_recursively(item, params=params, strict=strict) for item in sample] + return [self._transform_recursively(item, params=params) for item in sample] elif isinstance(sample, collections.abc.Mapping): - return { - name: self._transform_recursively(item, params=params, strict=strict) for name, item in sample.items() - } + return {name: self._transform_recursively(item, params=params) for name, item in sample.items()} else: feature_type = type(sample) if not self.supports(feature_type): - if not strict: + if feature_type in self.NO_OP_FEATURE_TYPES: return sample - raise TypeError(f"{type(self).__name__}() is not able to handle inputs of type {feature_type}.") + raise TypeError( + f"{type(self).__name__}() is not able to handle inputs of type {feature_type}. " + f"If you want it to be a no-op, add the feature type to {type(self).__name__}.NO_OP_FEATURE_TYPES." + ) return self.transform(sample, **params) @@ -392,7 +386,6 @@ def forward( self, *inputs: Any, params: Optional[Dict[str, Any]] = None, - strict: bool = True, ) -> Any: if not self._feature_transforms: raise RuntimeError(f"{type(self).__name__}() has no registered feature transform.") @@ -400,4 +393,4 @@ def forward( sample = inputs if len(inputs) > 1 else inputs[0] if params is None: params = self.get_params(sample) - return self._transform_recursively(sample, params=params, strict=strict) + return self._transform_recursively(sample, params=params) From 39ba6adbd7fd8a19f6bc3487c49e1e1810af58c0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Nov 2021 11:15:44 +0100 Subject: [PATCH 6/7] add test for no-ops --- test/test_prototype_transforms.py | 55 +++++++++++++++++++ torchvision/prototype/transforms/_geometry.py | 8 ++- torchvision/prototype/transforms/_misc.py | 4 +- .../prototype/transforms/_transform.py | 2 + 4 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 test/test_prototype_transforms.py diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py new file mode 100644 index 00000000000..678a7133ea6 --- /dev/null +++ b/test/test_prototype_transforms.py @@ -0,0 +1,55 @@ +import pytest +from torchvision.prototype import transforms, features +from torchvision.prototype.utils._internal import sequence_to_str + + +FEATURE_TYPES = { + feature_type + for name, feature_type in features.__dict__.items() + if not name.startswith("_") + and isinstance(feature_type, type) + and issubclass(feature_type, features.Feature) + and feature_type is not features.Feature +} + +TRANSFORM_TYPES = tuple( + transform_type + for name, transform_type in transforms.__dict__.items() + if not name.startswith("_") + and isinstance(transform_type, type) + and issubclass(transform_type, transforms.Transform) + and transform_type is not transforms.Transform +) + + +def test_feature_type_support(): + missing_feature_types = FEATURE_TYPES - set(transforms.Transform._BUILTIN_FEATURE_TYPES) + if missing_feature_types: + names = sorted([feature_type.__name__ for feature_type in missing_feature_types]) + raise AssertionError( + f"The feature(s) {sequence_to_str(names, separate_last='and ')} is/are exposed at " + f"`torchvision.prototype.features`, but are missing in Transform._BUILTIN_FEATURE_TYPES. " + f"Please add it/them to the collection." 
+ ) + + +@pytest.mark.parametrize( + "transform_type", + [ + transform_type + for transform_type in TRANSFORM_TYPES + if transform_type not in (transforms.Lambda, transforms.Identity) + ], + ids=lambda transform_type: transform_type.__name__, +) +def test_no_op(transform_type): + unsupported_features = ( + FEATURE_TYPES - transform_type.supported_feature_types() - set(transform_type.NO_OP_FEATURE_TYPES) + ) + if unsupported_features: + names = sorted([feature_type.__name__ for feature_type in unsupported_features]) + raise AssertionError( + f"The feature(s) {sequence_to_str(names, separate_last='and ')} are neither supported nor declared as " + f"no-op for transform `{transform_type.__name__}`. Please either implement a feature transform for them, " + f"or add them to the `{transform_type.__name__}.NO_OP_FEATURE_TYPES` collection." + ) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index c2d2dffefcd..62a698f2088 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -3,11 +3,13 @@ import torch from torch.nn.functional import interpolate from torchvision.prototype.datasets.utils import SampleQuery -from torchvision.prototype.features import BoundingBox, Image +from torchvision.prototype.features import BoundingBox, Image, Label from torchvision.prototype.transforms import Transform class HorizontalFlip(Transform): + NO_OP_FEATURE_TYPES = {Label} + @staticmethod def image(input: Image) -> Image: return Image(input.flip((-1,)), like=input) @@ -20,6 +22,8 @@ def bounding_box(input: BoundingBox) -> BoundingBox: class Resize(Transform): + NO_OP_FEATURE_TYPES = {Label} + def __init__( self, size: Union[int, Tuple[int, int]], @@ -80,6 +84,8 @@ def extra_repr(self) -> str: class Crop(Transform): + NO_OP_FEATURE_TYPES = {BoundingBox, Label} + def __init__(self, crop_box: BoundingBox) -> None: super().__init__() self.crop_box = crop_box.convert("xyxy") diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 88fe0d68b22..8dd92ef2e3e 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -3,7 +3,7 @@ from typing import Callable import torch -from torchvision.prototype.features import Image +from torchvision.prototype.features import Image, BoundingBox, Label from torchvision.prototype.transforms import Transform @@ -26,6 +26,8 @@ def __new__(cls, lambd: Callable) -> Transform: # type: ignore[misc] class Normalize(Transform): + NO_OP_FEATURE_TYPES = {BoundingBox, Label} + def __init__(self, mean: Sequence[float], std: Sequence[float]): super().__init__() self.mean = mean diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index a45292cbb66..0b3bf94bc8c 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -156,6 +156,8 @@ def __init_subclass__( cls, *, wraps: Optional[Type["Transform"]] = None, auto_register: bool = True, verbose: bool = False ): cls._feature_transforms = {} if wraps is None else wraps._feature_transforms.copy() + if wraps: + cls.NO_OP_FEATURE_TYPES = wraps.NO_OP_FEATURE_TYPES if auto_register: cls._auto_register(verbose=verbose) From 20bb26ceb3770ccf610cf16e54c219ff3e4c15ca Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 11 Nov 2021 11:24:17 +0100 Subject: [PATCH 7/7] cleanup --- test/test_prototype_transforms.py | 6 +----- 
torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_geometry.py | 10 ++++++---- torchvision/prototype/transforms/_misc.py | 11 ----------- 4 files changed, 8 insertions(+), 21 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 678a7133ea6..80ede68bdc5 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -35,11 +35,7 @@ def test_feature_type_support(): @pytest.mark.parametrize( "transform_type", - [ - transform_type - for transform_type in TRANSFORM_TYPES - if transform_type not in (transforms.Lambda, transforms.Identity) - ], + [transform_type for transform_type in TRANSFORM_TYPES if transform_type is not transforms.Identity], ids=lambda transform_type: transform_type.__name__, ) def test_no_op(transform_type): diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index f6a608b6572..c91542933b8 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -2,5 +2,5 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort: skip from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop -from ._misc import Identity, Lambda, Normalize +from ._misc import Identity, Normalize from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 62a698f2088..f34e5daa063 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -18,7 +18,7 @@ def image(input: Image) -> Image: def bounding_box(input: BoundingBox) -> BoundingBox: x, y, w, h = input.convert("xywh").to_parts() x = input.image_size[1] - (x + w) - return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh") + return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh").convert(input.format) class Resize(Transform): @@ -57,7 +57,9 @@ def bounding_box(input: BoundingBox, *, size: Tuple[int, int], **_: Any) -> Boun new_x2 = old_x2 * width_scale new_y2 = old_y2 * height_scale - return BoundingBox.from_parts(new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=size) + return BoundingBox.from_parts( + new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=size + ).convert(input.format) def extra_repr(self) -> str: extra_repr = f"size={self.size}" @@ -112,7 +114,7 @@ def get_params(self, sample: Any) -> Dict[str, Any]: cy = image_height // 2 h, w = self.crop_size crop_box = BoundingBox.from_parts(cx, cy, w, h, image_size=image_size, format="cxcywh") - return dict(crop_box=crop_box.convert("xyxy")) + return dict(crop_box=crop_box) def extra_repr(self) -> str: return f"crop_size={self.crop_size}" @@ -130,7 +132,7 @@ def get_params(self, sample: Any) -> Dict[str, Any]: x = torch.randint(0, image_width - crop_width + 1, size=()) if crop_width < image_width else 0 y = torch.randint(0, image_height - crop_height + 1, size=()) if crop_height < image_height else 0 crop_box = BoundingBox.from_parts(x, y, crop_width, crop_height, image_size=image_size, format="xywh") - return dict(crop_box=crop_box.convert("xyxy")) + return dict(crop_box=crop_box) def extra_repr(self) -> str: return f"crop_size={self.crop_size}" diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 8dd92ef2e3e..47062aeaf03 100644 --- 
a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,6 +1,4 @@ -import warnings from typing import Any, Dict, Sequence -from typing import Callable import torch from torchvision.prototype.features import Image, BoundingBox, Label @@ -16,15 +14,6 @@ def __init__(self): self.register_feature_transform(feature_type, lambda input, **params: input) -class Lambda(Transform): - def __new__(cls, lambd: Callable) -> Transform: # type: ignore[misc] - warnings.warn("transforms.Lambda(...) is deprecated. Use transforms.Transform.from_callable(...) instead.") - # We need to generate a new class everytime a Lambda transform is created, since the feature transforms are - # registered on the class rather than on the instance. If we didn't, registering a feature transform will - # overwrite it on **all** Lambda transform instances. - return Transform.from_callable(lambd, name="Lambda") - - class Normalize(Transform): NO_OP_FEATURE_TYPES = {BoundingBox, Label}
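A minimal usage sketch (not part of the patch series) of the explicit no-op mechanism added in patches 5/7 through 7/7. It only relies on the prototype API shown in the diffs above; the exact Label constructor call is an assumption.

    import torch
    from torchvision.prototype import features, transforms

    # A sample mixing feature types, as in the Transform docstring.
    sample = dict(
        image=features.Image(torch.rand(3, 16, 16)),
        label=features.Label(torch.tensor(0)),  # assumed to accept a bare tensor
    )

    # HorizontalFlip declares NO_OP_FEATURE_TYPES = {Label}, so the label is returned
    # unchanged while the image is flipped. A feature type that is neither supported
    # nor listed as a no-op raises a TypeError with the message added in patch 5/7.
    flipped = transforms.HorizontalFlip()(sample)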