From d8d765ac1b1ecbd7bc16f7b2870f52b943cf2ecc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 22 Sep 2022 11:31:15 +0100 Subject: [PATCH 01/11] Align and Clean up transform types --- torchvision/prototype/transforms/_augment.py | 26 +++++++++---------- .../prototype/transforms/_auto_augment.py | 22 ++++++++-------- torchvision/prototype/transforms/_color.py | 12 +++------ .../prototype/transforms/_deprecated.py | 6 ++--- torchvision/prototype/transforms/_geometry.py | 24 +++++++++-------- torchvision/prototype/transforms/_meta.py | 10 +++---- torchvision/prototype/transforms/_misc.py | 6 ++--- torchvision/prototype/transforms/_utils.py | 12 ++++----- 8 files changed, 57 insertions(+), 61 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 963cfac75b2..84895d3457e 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Tuple import PIL.Image import torch @@ -11,7 +11,7 @@ from torchvision.prototype.transforms import functional as F, InterpolationMode from ._transform import _RandomApplyTransform -from ._utils import has_any, query_chw +from ._utils import has_any, ImageType, query_chw, TensorImageType class RandomErasing(_RandomApplyTransform): @@ -92,9 +92,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(i=i, j=j, h=h, w=w, v=v) - def _transform( - self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any] - ) -> Union[torch.Tensor, features.Image, PIL.Image.Image]: + def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: if params["v"] is not None: inpt = F.erase(inpt, **params, inplace=self.inplace) @@ -110,8 +108,10 @@ def __init__(self, alpha: float, p: float = 0.5) -> None: def forward(self, *inputs: Any) -> Any: if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)): raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.") - if has_any(inputs, features.BoundingBox, features.Mask, features.Label): - raise TypeError(f"{type(self).__name__}() does not support bounding boxes, masks and plain labels.") + if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label): + raise TypeError( + f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." 
+ ) return super().forward(*inputs) def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: @@ -203,15 +203,15 @@ def __init__( def _copy_paste( self, - image: Any, + image: TensorImageType, target: Dict[str, Any], - paste_image: Any, + paste_image: TensorImageType, paste_target: Dict[str, Any], random_selection: torch.Tensor, blending: bool, resize_interpolation: F.InterpolationMode, antialias: Optional[bool], - ) -> Tuple[Any, Dict[str, Any]]: + ) -> Tuple[TensorImageType, Dict[str, Any]]: paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection]) paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection]) @@ -223,7 +223,7 @@ def _copy_paste( # This is something different to TF implementation we introduced here as # originally the algorithm works on equal-sized data # (for example, coming from LSJ data augmentations) - size1 = image.shape[-2:] + size1 = cast(List[int], image.shape[-2:]) size2 = paste_image.shape[-2:] if size1 != size2: paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias) @@ -278,7 +278,7 @@ def _copy_paste( return image, out_target - def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]: + def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[TensorImageType], List[Dict[str, Any]]]: # fetch all images, bboxes, masks and labels from unstructured input # with List[image], List[BoundingBox], List[Mask], List[Label] images, bboxes, masks, labels = [], [], [], [] @@ -307,7 +307,7 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis return images, targets def _insert_outputs( - self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]] + self, flat_sample: List[Any], output_images: List[TensorImageType], output_targets: List[Dict[str, Any]] ) -> None: c0, c1, c2, c3 = 0, 0, 0, 0 for i, obj in enumerate(flat_sample): diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 0d14a434c1e..adcd722f210 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -9,7 +9,7 @@ from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms.functional._meta import get_chw -from ._utils import _isinstance, _setup_fill_arg, FillType +from ._utils import _isinstance, _setup_fill_arg, FillType, ImageType K = TypeVar("K") V = TypeVar("V") @@ -20,7 +20,7 @@ def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[FillType, Dict[Type, FillType]] = None, ) -> None: super().__init__() self.interpolation = interpolation @@ -35,7 +35,7 @@ def _extract_image( self, sample: Any, unsupported_types: Tuple[Type, ...] 
= (features.BoundingBox, features.Mask), - ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]: + ) -> Tuple[int, ImageType]: sample_flat, _ = tree_flatten(sample) images = [] for id, inpt in enumerate(sample_flat): @@ -59,13 +59,13 @@ def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any: def _apply_image_transform( self, - image: Union[torch.Tensor, PIL.Image.Image, features.Image], + image: ImageType, transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Union[Dict[Type, FillType], Dict[Type, None]], - ) -> Any: - fill_ = fill[type(image)] + fill: Optional[Dict[Type, Optional[FillType]]], + ) -> ImageType: + fill_ = fill[type(image)] if fill is not None else None fill_ = F._geometry._convert_fill_arg(fill_) if transform_id == "Identity": @@ -177,7 +177,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[FillType, Dict[Type, FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -337,7 +337,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[FillType, Dict[Type, FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -393,7 +393,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[FillType, Dict[Type, FillType]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -453,7 +453,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[FillType, Dict[Type, FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 19f90c7c984..a23b3a07400 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -1,5 +1,5 @@ import collections.abc -from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union import PIL.Image import torch @@ -7,9 +7,7 @@ from torchvision.prototype.transforms import functional as F, Transform from ._transform import _RandomApplyTransform -from ._utils import query_chw - -T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image) +from ._utils import ImageType, query_chw class ColorJitter(Transform): @@ -112,7 +110,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, ) - def _permute_channels(self, inpt: Any, permutation: torch.Tensor) -> Any: + def _permute_channels(self, inpt: ImageType, permutation: torch.Tensor) -> ImageType: if isinstance(inpt, PIL.Image.Image): inpt = F.pil_to_tensor(inpt) @@ -125,9 +123,7 @@ def _permute_channels(self, inpt: Any, permutation: torch.Tensor) -> Any: return output - def _transform( - self, inpt: Union[torch.Tensor, features.Image, 
PIL.Image.Image], params: Dict[str, Any] - ) -> Union[torch.Tensor, features.Image, PIL.Image.Image]: + def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: if params["brightness"]: inpt = F.adjust_brightness( inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 9b28551a048..2b4aa21a8b7 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -11,7 +11,7 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import DType, query_chw +from ._utils import ImageType, query_chw class ToTensor(Transform): @@ -52,7 +52,7 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: super().__init__() self.num_output_channels = num_output_channels - def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType: + def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) if isinstance(inpt, features.Image): output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY) @@ -81,7 +81,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: num_input_channels, _, _ = query_chw(sample) return dict(num_input_channels=num_input_channels) - def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType: + def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) if isinstance(inpt, features.Image): output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index a8f0b09765b..5e1ca1e8972 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -20,10 +20,10 @@ _setup_angle, _setup_fill_arg, _setup_size, - DType, FillType, has_all, has_any, + ImageType, query_bounding_box, query_chw, ) @@ -179,7 +179,9 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def _transform(self, inpt: DType, params: Dict[str, Any]) -> Tuple[DType, DType, DType, DType, DType]: + def _transform( + self, inpt: ImageType, params: Dict[str, Any] + ) -> Tuple[ImageType, ImageType, ImageType, ImageType, ImageType]: return F.five_crop(inpt, self.size) def forward(self, *inputs: Any) -> Any: @@ -200,7 +202,7 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.vertical_flip = vertical_flip - def _transform(self, inpt: DType, params: Dict[str, Any]) -> List[DType]: + def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> List[ImageType]: return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) def forward(self, *inputs: Any) -> Any: @@ -226,7 +228,7 @@ def __init__( self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self.fill[type(inpt)] if self.fill is not None else None # This cast does Sequence[int] -> List[int] and is required to make mypy happy padding = self.padding 
@@ -271,7 +273,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(padding=padding) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self.fill[type(inpt)] if self.fill is not None else None fill = F._geometry._convert_fill_arg(fill) return F.pad(inpt, **params, fill=fill) @@ -302,7 +304,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self.fill[type(inpt)] if self.fill is not None else None fill = F._geometry._convert_fill_arg(fill) return F.rotate( inpt, @@ -384,7 +386,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(angle=angle, translate=translate, scale=scale, shear=shear) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self.fill[type(inpt)] if self.fill is not None else None fill = F._geometry._convert_fill_arg(fill) return F.affine( inpt, @@ -464,7 +466,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: # TODO: (PERF) check for speed optimization if we avoid repeated pad calls - fill = self.fill[type(inpt)] + fill = self.fill[type(inpt)] if self.fill is not None else None fill = F._geometry._convert_fill_arg(fill) if self.padding is not None: @@ -535,7 +537,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(perspective_coeffs=perspective_coeffs) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self.fill[type(inpt)] if self.fill is not None else None fill = F._geometry._convert_fill_arg(fill) return F.perspective( inpt, @@ -603,7 +605,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(displacement=displacement) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] + fill = self.fill[type(inpt)] if self.fill is not None else None fill = F._geometry._convert_fill_arg(fill) return F.elastic( inpt, @@ -862,7 +864,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) if params["needs_pad"]: - fill = self.fill[type(inpt)] + fill = self.fill[type(inpt)] if self.fill is not None else None fill = F._geometry._convert_fill_arg(fill) inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index d4dc87cf6f8..bd229640125 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -6,6 +6,8 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform +from ._utils import ImageType, TensorImageType + class ConvertBoundingBoxFormat(Transform): _transformed_types = (features.BoundingBox,) @@ -28,9 +30,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: super().__init__() self.dtype = dtype - def _transform( - self, inpt: Union[torch.Tensor, features.Image], params: Dict[str, Any] - ) -> Union[torch.Tensor, features.Image]: + def _transform(self, inpt: TensorImageType, params: Dict[str, Any]) -> TensorImageType: output = F.convert_image_dtype(inpt, dtype=self.dtype) return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype) # type: ignore[arg-type] @@ -56,9 +56,7 @@ def __init__( self.copy = copy - def 
_transform( - self, inpt: Union[torch.Tensor, PIL.Image.Image, features._Feature], params: Dict[str, Any] - ) -> Union[torch.Tensor, PIL.Image.Image, features._Feature]: + def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: return F.convert_color_space( inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy ) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 82e08259662..ab8188fb29c 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -8,7 +8,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform -from ._utils import _setup_size, has_any, query_bounding_box +from ._utils import _setup_size, has_any, query_bounding_box, TensorImageType class Identity(Transform): @@ -68,7 +68,7 @@ def forward(self, *inputs: Any) -> Any: return super().forward(*inputs) - def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor: + def _transform(self, inpt: TensorImageType, params: Dict[str, Any]) -> torch.Tensor: # Image instance after linear transformation is not Image anymore due to unknown data range # Thus we will return Tensor for input Image @@ -101,7 +101,7 @@ def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = self.std = list(std) self.inplace = inplace - def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor: + def _transform(self, inpt: TensorImageType, params: Dict[str, Any]) -> torch.Tensor: return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) def forward(self, *inpts: Any) -> Any: diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 09829629e03..7bedbc89906 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -18,10 +18,12 @@ # Type shortcuts: DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] -FillType = Union[int, float, Sequence[int], Sequence[float]] +ImageType = Union[torch.Tensor, PIL.Image.Image, features.Image] +TensorImageType = Union[torch.Tensor, features.Image] +FillType = Union[int, float, Sequence[int], Sequence[float], None] -def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None: +def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: if isinstance(fill, dict): for key, value in fill.items(): # Check key for type @@ -31,12 +33,10 @@ def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> No raise TypeError("Got inappropriate fill arg") -def _setup_fill_arg( - fill: Optional[Union[FillType, Dict[Type, FillType]]] -) -> Union[Dict[Type, FillType], Dict[Type, None]]: +def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Optional[Dict[Type, FillType]]: _check_fill_arg(fill) - if isinstance(fill, dict): + if fill is None or isinstance(fill, dict): return fill return defaultdict(lambda: fill) # type: ignore[return-value] From 06aad95bd4bac06dd7174015786012ee7f276466 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 22 Sep 2022 11:44:06 +0100 Subject: [PATCH 02/11] Move type definitions to `_utils.py` --- torchvision/prototype/features/__init__.py | 4 +- torchvision/prototype/features/_feature.py | 5 --- torchvision/prototype/features/_image.py | 5 --- .../transforms/functional/_augment.py | 6 ++- 
.../prototype/transforms/functional/_color.py | 24 +++++----- .../transforms/functional/_deprecated.py | 7 +-- .../transforms/functional/_geometry.py | 45 +++++++++---------- .../prototype/transforms/functional/_meta.py | 16 ++++--- .../prototype/transforms/functional/_misc.py | 7 +-- .../prototype/transforms/functional/_utils.py | 18 ++++++++ 10 files changed, 72 insertions(+), 65 deletions(-) create mode 100644 torchvision/prototype/transforms/functional/_utils.py diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py index 42ddd9aec27..5b2e6658acd 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/features/__init__.py @@ -1,6 +1,6 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat from ._encoded import EncodedData, EncodedImage, EncodedVideo -from ._feature import _Feature, DType, is_simple_tensor -from ._image import ColorSpace, Image, ImageType +from ._feature import _Feature, is_simple_tensor +from ._image import ColorSpace, Image from ._label import Label, OneHotLabel from ._mask import Mask diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 9cfccf33e54..909772e607e 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -10,11 +10,6 @@ F = TypeVar("F", bound="_Feature") -# Due to torch.jit.script limitation we keep DType as torch.Tensor -# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature] -DType = torch.Tensor - - def is_simple_tensor(inpt: Any) -> bool: return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 62f7f2849c3..2a943840178 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -12,11 +12,6 @@ from ._feature import _Feature -# Due to torch.jit.script limitation we keep ImageType as torch.Tensor -# instead of Union[torch.Tensor, PIL.Image.Image, features.Image] -ImageType = torch.Tensor - - class ColorSpace(StrEnum): OTHER = StrEnum.auto() GRAY = StrEnum.auto() diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index cb79ef98f0a..a5973241ea8 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -5,6 +5,8 @@ from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image +from ._utils import ImageType + erase_image_tensor = _FT.erase @@ -19,14 +21,14 @@ def erase_image_pil( def erase( - inpt: features.ImageType, + inpt: ImageType, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False, -) -> features.ImageType: +) -> ImageType: if isinstance(inpt, torch.Tensor): output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index d3257dc76a0..219240068e8 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,12 +2,14 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT +from ._utils import DType + 
adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness -def adjust_brightness(inpt: features.DType, brightness_factor: float) -> features.DType: +def adjust_brightness(inpt: DType, brightness_factor: float) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) elif isinstance(inpt, features._Feature): @@ -20,7 +22,7 @@ def adjust_brightness(inpt: features.DType, brightness_factor: float) -> feature adjust_saturation_image_pil = _FP.adjust_saturation -def adjust_saturation(inpt: features.DType, saturation_factor: float) -> features.DType: +def adjust_saturation(inpt: DType, saturation_factor: float) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) elif isinstance(inpt, features._Feature): @@ -33,7 +35,7 @@ def adjust_saturation(inpt: features.DType, saturation_factor: float) -> feature adjust_contrast_image_pil = _FP.adjust_contrast -def adjust_contrast(inpt: features.DType, contrast_factor: float) -> features.DType: +def adjust_contrast(inpt: DType, contrast_factor: float) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) elif isinstance(inpt, features._Feature): @@ -46,7 +48,7 @@ def adjust_contrast(inpt: features.DType, contrast_factor: float) -> features.DT adjust_sharpness_image_pil = _FP.adjust_sharpness -def adjust_sharpness(inpt: features.DType, sharpness_factor: float) -> features.DType: +def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) elif isinstance(inpt, features._Feature): @@ -59,7 +61,7 @@ def adjust_sharpness(inpt: features.DType, sharpness_factor: float) -> features. 
adjust_hue_image_pil = _FP.adjust_hue -def adjust_hue(inpt: features.DType, hue_factor: float) -> features.DType: +def adjust_hue(inpt: DType, hue_factor: float) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) elif isinstance(inpt, features._Feature): @@ -72,7 +74,7 @@ def adjust_hue(inpt: features.DType, hue_factor: float) -> features.DType: adjust_gamma_image_pil = _FP.adjust_gamma -def adjust_gamma(inpt: features.DType, gamma: float, gain: float = 1) -> features.DType: +def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) elif isinstance(inpt, features._Feature): @@ -85,7 +87,7 @@ def adjust_gamma(inpt: features.DType, gamma: float, gain: float = 1) -> feature posterize_image_pil = _FP.posterize -def posterize(inpt: features.DType, bits: int) -> features.DType: +def posterize(inpt: DType, bits: int) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return posterize_image_tensor(inpt, bits=bits) elif isinstance(inpt, features._Feature): @@ -98,7 +100,7 @@ def posterize(inpt: features.DType, bits: int) -> features.DType: solarize_image_pil = _FP.solarize -def solarize(inpt: features.DType, threshold: float) -> features.DType: +def solarize(inpt: DType, threshold: float) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return solarize_image_tensor(inpt, threshold=threshold) elif isinstance(inpt, features._Feature): @@ -111,7 +113,7 @@ def solarize(inpt: features.DType, threshold: float) -> features.DType: autocontrast_image_pil = _FP.autocontrast -def autocontrast(inpt: features.DType) -> features.DType: +def autocontrast(inpt: DType) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return autocontrast_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -124,7 +126,7 @@ def autocontrast(inpt: features.DType) -> features.DType: equalize_image_pil = _FP.equalize -def equalize(inpt: features.DType) -> features.DType: +def equalize(inpt: DType) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return equalize_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -137,7 +139,7 @@ def equalize(inpt: features.DType) -> features.DType: invert_image_pil = _FP.invert -def invert(inpt: features.DType) -> features.DType: +def invert(inpt: DType) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return invert_image_tensor(inpt) elif isinstance(inpt, features._Feature): diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index 9c8dcff5a86..e981856d6ec 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -7,10 +7,7 @@ from torchvision.prototype import features from torchvision.transforms import functional as _F - -# Due to torch.jit.script limitation we keep LegacyImageType as torch.Tensor -# instead of Union[torch.Tensor, PIL.Image.Image] -LegacyImageType = 
torch.Tensor +from ._utils import ImageType, LegacyImageType @torch.jit.unused @@ -61,7 +58,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: return _F.to_tensor(inpt) -def get_image_size(inpt: features.ImageType) -> List[int]: +def get_image_size(inpt: ImageType) -> List[int]: warnings.warn( "The function `get_image_size(...)` is deprecated and will be removed in a future release. " "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 666e4105992..d8c57bed71f 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -17,6 +17,7 @@ from torchvision.transforms.functional_tensor import _parse_pad_padding from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor +from ._utils import DType, ImageType horizontal_flip_image_tensor = _FT.hflip @@ -43,7 +44,7 @@ def horizontal_flip_bounding_box( ).view(shape) -def horizontal_flip(inpt: features.DType) -> features.DType: +def horizontal_flip(inpt: DType) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return horizontal_flip_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -76,7 +77,7 @@ def vertical_flip_bounding_box( ).view(shape) -def vertical_flip(inpt: features.DType) -> features.DType: +def vertical_flip(inpt: DType) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return vertical_flip_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -153,12 +154,12 @@ def resize_bounding_box( def resize( - inpt: features.DType, + inpt: DType, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = None, -) -> features.DType: +) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): antialias = False if antialias is None else antialias return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) @@ -421,7 +422,7 @@ def _convert_fill_arg( def affine( - inpt: features.DType, + inpt: DType, angle: float, translate: List[float], scale: float, @@ -429,7 +430,7 @@ def affine( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[Union[int, float, List[float]]] = None, center: Optional[List[float]] = None, -) -> features.DType: +) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return affine_image_tensor( inpt, @@ -567,13 +568,13 @@ def rotate_mask( def rotate( - inpt: features.DType, + inpt: DType, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, fill: Optional[Union[int, float, List[float]]] = None, center: Optional[List[float]] = None, -) -> features.DType: +) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) elif isinstance(inpt, features._Feature): @@ -698,11 +699,11 @@ def pad_bounding_box( def pad( - inpt: features.DType, + inpt: DType, padding: Union[int, List[int]], fill: Optional[Union[int, float, 
List[float]]] = None, padding_mode: str = "constant", -) -> features.DType: +) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) @@ -739,7 +740,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) return crop_image_tensor(mask, top, left, height, width) -def crop(inpt: features.DType, top: int, left: int, height: int, width: int) -> features.DType: +def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return crop_image_tensor(inpt, top, left, height, width) elif isinstance(inpt, features._Feature): @@ -874,11 +875,11 @@ def perspective_mask( def perspective( - inpt: features.DType, + inpt: DType, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[Union[int, float, List[float]]] = None, -) -> features.DType: +) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) elif isinstance(inpt, features._Feature): @@ -968,11 +969,11 @@ def elastic_mask( def elastic( - inpt: features.DType, + inpt: DType, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[Union[int, float, List[float]]] = None, -) -> features.DType: +) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) elif isinstance(inpt, features._Feature): @@ -1069,7 +1070,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor return output -def center_crop(inpt: features.DType, output_size: List[int]) -> features.DType: +def center_crop(inpt: DType, output_size: List[int]) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return center_crop_image_tensor(inpt, output_size) elif isinstance(inpt, features._Feature): @@ -1132,7 +1133,7 @@ def resized_crop_mask( def resized_crop( - inpt: features.DType, + inpt: DType, top: int, left: int, height: int, @@ -1140,7 +1141,7 @@ def resized_crop( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[bool] = None, -) -> features.DType: +) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): antialias = False if antialias is None else antialias return resized_crop_image_tensor( @@ -1204,10 +1205,8 @@ def five_crop_image_pil( return tl, tr, bl, br, center -def five_crop( - inpt: features.ImageType, size: List[int] -) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]: - # TODO: consider breaking BC here to return List[features.ImageType] to align this op with `ten_crop` +def five_crop(inpt: ImageType, size: List[int]) -> Tuple[ImageType, ImageType, ImageType, ImageType, ImageType]: + # TODO: consider breaking BC here to return List[ImageType] to align this op with `ten_crop` if isinstance(inpt, torch.Tensor): output = five_crop_image_tensor(inpt, size) if not torch.jit.is_scripting() and isinstance(inpt, 
features.Image): @@ -1244,7 +1243,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip] -def ten_crop(inpt: features.ImageType, size: List[int], vertical_flip: bool = False) -> List[features.ImageType]: +def ten_crop(inpt: ImageType, size: List[int], vertical_flip: bool = False) -> List[ImageType]: if isinstance(inpt, torch.Tensor): output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index b89ea09b726..958edc248c9 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -6,12 +6,14 @@ from torchvision.prototype.features import BoundingBoxFormat, ColorSpace from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT +from ._utils import ImageType + get_dimensions_image_tensor = _FT.get_dimensions get_dimensions_image_pil = _FP.get_dimensions # TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init? -def get_chw(image: features.ImageType) -> Tuple[int, int, int]: +def get_chw(image: ImageType) -> Tuple[int, int, int]: if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): channels, height, width = get_dimensions_image_tensor(image) elif isinstance(image, features.Image): @@ -29,11 +31,11 @@ def get_chw(image: features.ImageType) -> Tuple[int, int, int]: # detailed above. -def get_dimensions(image: features.ImageType) -> List[int]: +def get_dimensions(image: ImageType) -> List[int]: return list(get_chw(image)) -def get_num_channels(image: features.ImageType) -> int: +def get_num_channels(image: ImageType) -> int: num_channels, *_ = get_chw(image) return num_channels @@ -43,7 +45,7 @@ def get_num_channels(image: features.ImageType) -> int: get_image_num_channels = get_num_channels -def get_spatial_size(image: features.ImageType) -> List[int]: +def get_spatial_size(image: ImageType) -> List[int]: _, *size = get_chw(image) return size @@ -208,11 +210,11 @@ def convert_color_space_image_pil( def convert_color_space( - inpt: features.ImageType, + inpt: ImageType, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True, -) -> features.ImageType: +) -> ImageType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)): if old_color_space is None: raise RuntimeError( @@ -225,4 +227,4 @@ def convert_color_space( elif isinstance(inpt, features.Image): return inpt.to_color_space(color_space, copy=copy) else: - return cast(features.ImageType, convert_color_space_image_pil(inpt, color_space, copy=copy)) + return cast(ImageType, convert_color_space_image_pil(inpt, color_space, copy=copy)) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 4c8c4697d67..58bde5b78ec 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -6,10 +6,7 @@ from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image - -# Due to torch.jit.script limitation we keep TensorImageType as torch.Tensor -# instead of Union[torch.Tensor, 
features.Image]
-TensorImageType = torch.Tensor
+from ._utils import DType, TensorImageType
 
 
 normalize_image_tensor = _FT.normalize
@@ -62,7 +59,7 @@ def gaussian_blur_image_pil(
     return to_pil_image(output, mode=img.mode)
 
 
-def gaussian_blur(inpt: features.DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> features.DType:
+def gaussian_blur(inpt: DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> DType:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
     elif isinstance(inpt, features._Feature):
diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py
new file mode 100644
index 00000000000..68756043359
--- /dev/null
+++ b/torchvision/prototype/transforms/functional/_utils.py
@@ -0,0 +1,18 @@
+import torch
+
+
+# The types defined in this file should mirror the ones in `.._utils.py`
+# Unfortunately, due to torch.jit.script limitations, we must use fake types
+# Keeping track of the actual types is useful in case this limitation is lifted
+
+# Real type: Union[torch.Tensor, PIL.Image.Image, features._Feature]
+DType = torch.Tensor
+
+# Real type: Union[torch.Tensor, PIL.Image.Image, features.Image]
+ImageType = torch.Tensor
+
+# Real type: Union[torch.Tensor, PIL.Image.Image]
+LegacyImageType = torch.Tensor
+
+# Real type: Union[torch.Tensor, features.Image]
+TensorImageType = torch.Tensor

From e5b36b66fe20eea924a1f1df25d3b7cf2b67cbe9 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 22 Sep 2022 14:05:36 +0100
Subject: [PATCH 03/11] Fix error message in tests

---
 test/test_prototype_transforms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 278b3811aeb..b9c89b2b76a 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -132,7 +132,7 @@ def test_mixup_cutmix(self, transform, input):
             transform(input_copy)
 
         # Check if we raise an error if sample contains bbox or mask or label
-        err_msg = "does not support bounding boxes, masks and plain labels"
+        err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
         input_copy = dict(input)
         for unsup_data in [
             make_label(),

From 7e61d750c57ef568ab96e3d67938419c50d885e7 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 23 Sep 2022 09:44:30 +0100
Subject: [PATCH 04/11] Apply code review suggestions

Co-authored-by: vfdev

---
 torchvision/prototype/transforms/_auto_augment.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index adcd722f210..cc1ae387fd0 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -63,7 +63,7 @@ def _apply_image_transform(
         transform_id: str,
         magnitude: float,
         interpolation: InterpolationMode,
-        fill: Optional[Dict[Type, Optional[FillType]]],
+        fill: Optional[Dict[Type, FillType]],
     ) -> ImageType:
         fill_ = fill[type(image)] if fill is not None else None
         fill_ = F._geometry._convert_fill_arg(fill_)

From 1edbc064887cae0a4f7ddc32a47c1b4a8cd631df Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 23 Sep 2022 10:09:00 +0100
Subject: [PATCH 05/11] Centralize types and switch to always getting dicts
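
After this change `_setup_fill_arg()` always returns a dict-like object:
user-provided dicts are passed through as-is, while scalar fills (including
`None`, which is now a valid `FillType`) are wrapped in a `defaultdict`.
Every `_transform()` can therefore look up `fill = self.fill[type(inpt)]`
unconditionally, dropping the previous `if self.fill is not None` guards.
A rough sketch of the contract (illustrative values and keys only):

    fill = _setup_fill_arg(0.5)   # scalar -> defaultdict(lambda: 0.5)
    fill[features.Image]          # 0.5; any transformed type works as a key
    fill = _setup_fill_arg(None)  # None -> defaultdict(lambda: None)
    fill[features.Mask]           # None, but the lookup itself is always safe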
--- .../prototype/features/_bounding_box.py | 11 ++--- torchvision/prototype/features/_feature.py | 13 +++--- torchvision/prototype/features/_image.py | 11 ++--- torchvision/prototype/features/_mask.py | 11 ++--- torchvision/prototype/features/_utils.py | 4 ++ .../prototype/transforms/_auto_augment.py | 4 +- torchvision/prototype/transforms/_geometry.py | 16 +++---- torchvision/prototype/transforms/_utils.py | 6 +-- .../transforms/functional/_geometry.py | 44 +++++++++---------- .../prototype/transforms/functional/_utils.py | 6 +++ 10 files changed, 70 insertions(+), 56 deletions(-) create mode 100644 torchvision/prototype/features/_utils.py diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 47473e7dc7c..1e67565a411 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -7,6 +7,7 @@ from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms from ._feature import _Feature +from ._utils import FillType class BoundingBoxFormat(StrEnum): @@ -115,7 +116,7 @@ def resized_crop( def pad( self, padding: Union[int, Sequence[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, padding_mode: str = "constant", ) -> BoundingBox: # This cast does Sequence[int] -> List[int] and is required to make mypy happy @@ -137,7 +138,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> BoundingBox: output = self._F.rotate_bounding_box( @@ -165,7 +166,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> BoundingBox: output = self._F.affine_bounding_box( @@ -184,7 +185,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> BoundingBox: output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) return BoundingBox.new_like(self, output, dtype=output.dtype) @@ -193,7 +194,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> BoundingBox: output = self._F.elastic_bounding_box(self, self.format, displacement) return BoundingBox.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 909772e607e..fdc1e545b9f 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -7,6 +7,9 @@ from torch._C import _TensorBase, DisableTorchFunction from torchvision.transforms import InterpolationMode +from ._utils import FillType + + F = TypeVar("F", bound="_Feature") @@ -149,7 +152,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, padding_mode: str = "constant", ) -> _Feature: return self @@ -159,7 +162,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = 
False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> _Feature: return self @@ -171,7 +174,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> _Feature: return self @@ -180,7 +183,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> _Feature: return self @@ -188,7 +191,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> _Feature: return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 2a943840178..66f1a803045 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -10,6 +10,7 @@ from ._bounding_box import BoundingBox from ._feature import _Feature +from ._utils import FillType class ColorSpace(StrEnum): @@ -176,7 +177,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, padding_mode: str = "constant", ) -> Image: output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) @@ -187,7 +188,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> Image: output = self._F._geometry.rotate_image_tensor( @@ -202,7 +203,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> Image: output = self._F._geometry.affine_image_tensor( @@ -221,7 +222,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> Image: output = self._F._geometry.perspective_image_tensor( self, perspective_coeffs, interpolation=interpolation, fill=fill @@ -232,7 +233,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> Image: output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill) return Image.new_like(self, output) diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 29858578c23..b7f2a4d3569 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -6,6 +6,7 @@ from torchvision.transforms import InterpolationMode from ._feature import _Feature +from ._utils import FillType class Mask(_Feature): @@ -51,7 +52,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, padding_mode: str = "constant", ) -> Mask: output = self._F.pad_mask(self, padding, 
padding_mode=padding_mode, fill=fill) @@ -62,7 +63,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> Mask: output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill) @@ -75,7 +76,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> Mask: output = self._F.affine_mask( @@ -93,7 +94,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> Mask: output = self._F.perspective_mask(self, perspective_coeffs, fill=fill) return Mask.new_like(self, output) @@ -102,7 +103,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> Mask: output = self._F.elastic_mask(self, displacement, fill=fill) return Mask.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/features/_utils.py b/torchvision/prototype/features/_utils.py new file mode 100644 index 00000000000..758049a5e5a --- /dev/null +++ b/torchvision/prototype/features/_utils.py @@ -0,0 +1,4 @@ +from typing import List, Union + +# Same definition as in `functional._utils.py`. Copied here to avoid cyclical dependencies. +FillType = Union[int, float, List[float], None] diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index cc1ae387fd0..414b93a85a7 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -63,9 +63,9 @@ def _apply_image_transform( transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Optional[Dict[Type, FillType]], + fill: Dict[Type, FillType], ) -> ImageType: - fill_ = fill[type(image)] if fill is not None else None + fill_ = fill[type(image)] fill_ = F._geometry._convert_fill_arg(fill_) if transform_id == "Identity": diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 5e1ca1e8972..caed0f592e7 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -228,7 +228,7 @@ def __init__( self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] if self.fill is not None else None + fill = self.fill[type(inpt)] # This cast does Sequence[int] -> List[int] and is required to make mypy happy padding = self.padding @@ -273,7 +273,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(padding=padding) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] if self.fill is not None else None + fill = self.fill[type(inpt)] fill = F._geometry._convert_fill_arg(fill) return F.pad(inpt, **params, fill=fill) @@ -304,7 +304,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] if self.fill is not None else None + fill = 
self.fill[type(inpt)] fill = F._geometry._convert_fill_arg(fill) return F.rotate( inpt, @@ -386,7 +386,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(angle=angle, translate=translate, scale=scale, shear=shear) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] if self.fill is not None else None + fill = self.fill[type(inpt)] fill = F._geometry._convert_fill_arg(fill) return F.affine( inpt, @@ -466,7 +466,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: # TODO: (PERF) check for speed optimization if we avoid repeated pad calls - fill = self.fill[type(inpt)] if self.fill is not None else None + fill = self.fill[type(inpt)] fill = F._geometry._convert_fill_arg(fill) if self.padding is not None: @@ -537,7 +537,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(perspective_coeffs=perspective_coeffs) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] if self.fill is not None else None + fill = self.fill[type(inpt)] fill = F._geometry._convert_fill_arg(fill) return F.perspective( inpt, @@ -605,7 +605,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(displacement=displacement) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - fill = self.fill[type(inpt)] if self.fill is not None else None + fill = self.fill[type(inpt)] fill = F._geometry._convert_fill_arg(fill) return F.elastic( inpt, @@ -864,7 +864,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) if params["needs_pad"]: - fill = self.fill[type(inpt)] if self.fill is not None else None + fill = self.fill[type(inpt)] fill = F._geometry._convert_fill_arg(fill) inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 7bedbc89906..4209332e263 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -33,13 +33,13 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: raise TypeError("Got inappropriate fill arg") -def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Optional[Dict[Type, FillType]]: +def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: _check_fill_arg(fill) - if fill is None or isinstance(fill, dict): + if isinstance(fill, dict): return fill - return defaultdict(lambda: fill) # type: ignore[return-value] + return defaultdict(lambda: fill) # type: ignore[return-value, arg-type] def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index d8c57bed71f..f0c490fee8c 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -17,7 +17,7 @@ from torchvision.transforms.functional_tensor import _parse_pad_padding from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor -from ._utils import DType, ImageType +from ._utils import DType, FillType, ImageType horizontal_flip_image_tensor = _FT.hflip @@ -229,7 +229,7 @@ def affine_image_tensor( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, 
List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if img.numel() == 0: @@ -261,7 +261,7 @@ def affine_image_pil( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> PIL.Image.Image: angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) @@ -379,7 +379,7 @@ def affine_mask( translate: List[float], scale: float, shear: List[float], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if mask.ndim < 3: @@ -405,9 +405,7 @@ def affine_mask( return output -def _convert_fill_arg( - fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] -) -> Optional[Union[int, float, List[float]]]: +def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> FillType: # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # So, we can't reassign fill to 0 # if fill is None: @@ -428,7 +426,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): @@ -464,7 +462,7 @@ def rotate_image_tensor( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> torch.Tensor: num_channels, height, width = img.shape[-3:] @@ -503,7 +501,7 @@ def rotate_image_pil( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> PIL.Image.Image: if center is not None and expand: @@ -543,7 +541,7 @@ def rotate_mask( mask: torch.Tensor, angle: float, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if mask.ndim < 3: @@ -572,7 +570,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, center: Optional[List[float]] = None, ) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): @@ -589,7 +587,7 @@ def rotate( def pad_image_tensor( img: torch.Tensor, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, padding_mode: str = "constant", ) -> torch.Tensor: if fill is None: @@ -653,7 +651,7 @@ def pad_mask( mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant", - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> torch.Tensor: if fill is None: fill = 0 @@ -701,7 +699,7 @@ def pad_bounding_box( def pad( inpt: DType, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, padding_mode: str = "constant", ) -> DType: if 
isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): @@ -753,7 +751,7 @@ def perspective_image_tensor( img: torch.Tensor, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> torch.Tensor: return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill) @@ -763,7 +761,7 @@ def perspective_image_pil( img: PIL.Image.Image, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BICUBIC, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> PIL.Image.Image: return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) @@ -856,7 +854,7 @@ def perspective_bounding_box( def perspective_mask( mask: torch.Tensor, perspective_coeffs: List[float], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -878,7 +876,7 @@ def perspective( inpt: DType, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) @@ -892,7 +890,7 @@ def elastic_image_tensor( img: torch.Tensor, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> torch.Tensor: return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) @@ -902,7 +900,7 @@ def elastic_image_pil( img: PIL.Image.Image, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> PIL.Image.Image: t_img = pil_to_tensor(img) output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) @@ -952,7 +950,7 @@ def elastic_bounding_box( def elastic_mask( mask: torch.Tensor, displacement: torch.Tensor, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -972,7 +970,7 @@ def elastic( inpt: DType, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillType = None, ) -> DType: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 68756043359..19d451eb215 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -1,3 +1,5 @@ +from typing import List, Union + import torch @@ -16,3 +18,7 @@ # Real type: Union[torch.Tensor, features.Image] TensorImageType = torch.Tensor + +# Similarly, JIT doesn't support Sequencies and can't support at the same time, Lists of floats and ints. 
+# Ideal type: Union[int, float, Sequence[int], Sequence[float], None] +FillType = Union[int, float, List[float], None] From 4eb82f4ce7aaafd977193d2df185c8a448a70249 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 23 Sep 2022 10:15:59 +0100 Subject: [PATCH 06/11] Fixing linter --- torchvision/prototype/transforms/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 4209332e263..7ee5369d8e0 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,7 +1,7 @@ import numbers from collections import defaultdict -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union import PIL.Image From ad60b5084486daa327bd8e603d2cb9b01d692649 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 23 Sep 2022 12:22:39 +0100 Subject: [PATCH 07/11] Refactoring typing definitions. --- .../prototype/features/_bounding_box.py | 13 ++- torchvision/prototype/features/_feature.py | 20 +++-- torchvision/prototype/features/_image.py | 22 +++-- torchvision/prototype/features/_mask.py | 13 ++- torchvision/prototype/features/_utils.py | 4 - torchvision/prototype/transforms/_augment.py | 4 +- .../prototype/transforms/_auto_augment.py | 5 +- torchvision/prototype/transforms/_color.py | 4 +- .../prototype/transforms/_deprecated.py | 4 +- torchvision/prototype/transforms/_geometry.py | 5 +- torchvision/prototype/transforms/_meta.py | 2 +- torchvision/prototype/transforms/_misc.py | 4 +- torchvision/prototype/transforms/_utils.py | 9 +- .../transforms/functional/_augment.py | 7 +- .../prototype/transforms/functional/_color.py | 25 +++--- .../transforms/functional/_deprecated.py | 6 +- .../transforms/functional/_geometry.py | 89 ++++++++++--------- .../prototype/transforms/functional/_meta.py | 16 ++-- .../prototype/transforms/functional/_misc.py | 7 +- .../prototype/transforms/functional/_utils.py | 24 ----- 20 files changed, 136 insertions(+), 147 deletions(-) delete mode 100644 torchvision/prototype/features/_utils.py delete mode 100644 torchvision/prototype/transforms/functional/_utils.py diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 1e67565a411..0b61439d10c 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -6,8 +6,7 @@ from torchvision._utils import StrEnum from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms -from ._feature import _Feature -from ._utils import FillType +from ._feature import _Feature, FillTypeJIT class BoundingBoxFormat(StrEnum): @@ -116,7 +115,7 @@ def resized_crop( def pad( self, padding: Union[int, Sequence[int]], - fill: FillType = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> BoundingBox: # This cast does Sequence[int] -> List[int] and is required to make mypy happy @@ -138,7 +137,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> BoundingBox: output = self._F.rotate_bounding_box( @@ -166,7 +165,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, 
center: Optional[List[float]] = None, ) -> BoundingBox: output = self._F.affine_bounding_box( @@ -185,7 +184,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> BoundingBox: output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) return BoundingBox.new_like(self, output, dtype=output.dtype) @@ -194,7 +193,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> BoundingBox: output = self._F.elastic_bounding_box(self, self.format, displacement) return BoundingBox.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index fdc1e545b9f..3d4357b9a99 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -3,14 +3,14 @@ from types import ModuleType from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +import PIL.Image import torch from torch._C import _TensorBase, DisableTorchFunction from torchvision.transforms import InterpolationMode -from ._utils import FillType - - F = TypeVar("F", bound="_Feature") +FillType = Union[int, float, Sequence[int], Sequence[float], None] +FillTypeJIT = Union[int, float, List[float], None] def is_simple_tensor(inpt: Any) -> bool: @@ -152,7 +152,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: FillType = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> _Feature: return self @@ -162,7 +162,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> _Feature: return self @@ -174,7 +174,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> _Feature: return self @@ -183,7 +183,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> _Feature: return self @@ -191,7 +191,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> _Feature: return self @@ -230,3 +230,7 @@ def invert(self) -> _Feature: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature: return self + + +InputType = Union[torch.Tensor, PIL.Image.Image, _Feature] +InputTypeJIT = torch.Tensor diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 66f1a803045..0b832ae0270 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -3,14 +3,14 @@ import warnings from typing import Any, cast, List, Optional, Tuple, Union +import PIL.Image import torch from torchvision._utils import StrEnum from torchvision.transforms.functional import InterpolationMode, to_pil_image from torchvision.utils import draw_bounding_boxes, make_grid from ._bounding_box import BoundingBox -from ._feature import _Feature -from ._utils import FillType +from 
._feature import _Feature, FillTypeJIT class ColorSpace(StrEnum): @@ -177,7 +177,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: FillType = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> Image: output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) @@ -188,7 +188,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> Image: output = self._F._geometry.rotate_image_tensor( @@ -203,7 +203,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> Image: output = self._F._geometry.affine_image_tensor( @@ -222,7 +222,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> Image: output = self._F._geometry.perspective_image_tensor( self, perspective_coeffs, interpolation=interpolation, fill=fill @@ -233,7 +233,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> Image: output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill) return Image.new_like(self, output) @@ -285,3 +285,11 @@ def invert(self) -> Image: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image: output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma) return Image.new_like(self, output) + + +ImageType = Union[torch.Tensor, PIL.Image.Image, Image] +ImageTypeJIT = torch.Tensor +LegacyImageType = Union[torch.Tensor, PIL.Image.Image] +LegacyImageTypeJIT = torch.Tensor +TensorImageType = Union[torch.Tensor, Image] +TensorImageTypeJIT = torch.Tensor diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index b7f2a4d3569..a0c3395dbe7 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -5,8 +5,7 @@ import torch from torchvision.transforms import InterpolationMode -from ._feature import _Feature -from ._utils import FillType +from ._feature import _Feature, FillTypeJIT class Mask(_Feature): @@ -52,7 +51,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: FillType = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> Mask: output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) @@ -63,7 +62,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> Mask: output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill) @@ -76,7 +75,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> Mask: output = self._F.affine_mask( @@ -94,7 +93,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> Mask: 
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill) return Mask.new_like(self, output) @@ -103,7 +102,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> Mask: output = self._F.elastic_mask(self, displacement, fill=fill) return Mask.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/features/_utils.py b/torchvision/prototype/features/_utils.py deleted file mode 100644 index 758049a5e5a..00000000000 --- a/torchvision/prototype/features/_utils.py +++ /dev/null @@ -1,4 +0,0 @@ -from typing import List, Union - -# Same definition as in `functional._utils.py`. Copied here to avoid cyclical dependencies. -FillType = Union[int, float, List[float], None] diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 84895d3457e..9811673314c 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -10,8 +10,10 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, InterpolationMode +from ..features._image import ImageType, TensorImageType + from ._transform import _RandomApplyTransform -from ._utils import has_any, ImageType, query_chw, TensorImageType +from ._utils import has_any, query_chw class RandomErasing(_RandomApplyTransform): diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 414b93a85a7..cc54a5059af 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -9,7 +9,10 @@ from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms.functional._meta import get_chw -from ._utils import _isinstance, _setup_fill_arg, FillType, ImageType +from ..features._feature import FillType +from ..features._image import ImageType + +from ._utils import _isinstance, _setup_fill_arg K = TypeVar("K") V = TypeVar("V") diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index a23b3a07400..82d7f6c2b5b 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -6,8 +6,10 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform +from ..features._image import ImageType + from ._transform import _RandomApplyTransform -from ._utils import ImageType, query_chw +from ._utils import query_chw class ColorJitter(Transform): diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 2b4aa21a8b7..cbfb9239a13 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -10,8 +10,10 @@ from torchvision.transforms import functional as _F from typing_extensions import Literal +from ..features._image import ImageType + from ._transform import _RandomApplyTransform -from ._utils import ImageType, query_chw +from ._utils import query_chw class ToTensor(Transform): diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index caed0f592e7..4dee4e6d9b0 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -12,6 +12,9 @@ 
from typing_extensions import Literal +from ..features._feature import FillType +from ..features._image import ImageType + from ._transform import _RandomApplyTransform from ._utils import ( _check_padding_arg, @@ -20,10 +23,8 @@ _setup_angle, _setup_fill_arg, _setup_size, - FillType, has_all, has_any, - ImageType, query_bounding_box, query_chw, ) diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index bd229640125..a1f6ff9bdc7 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -6,7 +6,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform -from ._utils import ImageType, TensorImageType +from ..features._image import ImageType, TensorImageType class ConvertBoundingBoxFormat(Transform): diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index ab8188fb29c..12ad2a375d7 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -8,7 +8,9 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform -from ._utils import _setup_size, has_any, query_bounding_box, TensorImageType +from ..features._image import TensorImageType + +from ._utils import _setup_size, has_any, query_bounding_box class Identity(Transform): diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 7ee5369d8e0..7107b14b3e0 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -5,10 +5,10 @@ import PIL.Image -import torch from torch.utils._pytree import tree_flatten from torchvision._utils import sequence_to_str from torchvision.prototype import features +from torchvision.prototype.features._feature import FillType from torchvision.prototype.transforms.functional._meta import get_chw from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 @@ -16,13 +16,6 @@ from typing_extensions import Literal -# Type shortcuts: -DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] -ImageType = Union[torch.Tensor, PIL.Image.Image, features.Image] -TensorImageType = Union[torch.Tensor, features.Image] -FillType = Union[int, float, Sequence[int], Sequence[float], None] - - def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: if isinstance(fill, dict): for key, value in fill.items(): diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index a5973241ea8..d1ffeaa8735 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -5,8 +5,7 @@ from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image -from ._utils import ImageType - +from ...features._image import ImageTypeJIT erase_image_tensor = _FT.erase @@ -21,14 +20,14 @@ def erase_image_pil( def erase( - inpt: ImageType, + inpt: ImageTypeJIT, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False, -) -> ImageType: +) -> ImageTypeJIT: if isinstance(inpt, torch.Tensor): output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): diff --git 
a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 219240068e8..dd27757e5c7 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,14 +2,13 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT -from ._utils import DType - +from ...features._feature import InputTypeJIT adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness -def adjust_brightness(inpt: DType, brightness_factor: float) -> DType: +def adjust_brightness(inpt: InputTypeJIT, brightness_factor: float) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) elif isinstance(inpt, features._Feature): @@ -22,7 +21,7 @@ def adjust_brightness(inpt: DType, brightness_factor: float) -> DType: adjust_saturation_image_pil = _FP.adjust_saturation -def adjust_saturation(inpt: DType, saturation_factor: float) -> DType: +def adjust_saturation(inpt: InputTypeJIT, saturation_factor: float) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) elif isinstance(inpt, features._Feature): @@ -35,7 +34,7 @@ def adjust_saturation(inpt: DType, saturation_factor: float) -> DType: adjust_contrast_image_pil = _FP.adjust_contrast -def adjust_contrast(inpt: DType, contrast_factor: float) -> DType: +def adjust_contrast(inpt: InputTypeJIT, contrast_factor: float) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) elif isinstance(inpt, features._Feature): @@ -48,7 +47,7 @@ def adjust_contrast(inpt: DType, contrast_factor: float) -> DType: adjust_sharpness_image_pil = _FP.adjust_sharpness -def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType: +def adjust_sharpness(inpt: InputTypeJIT, sharpness_factor: float) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) elif isinstance(inpt, features._Feature): @@ -61,7 +60,7 @@ def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType: adjust_hue_image_pil = _FP.adjust_hue -def adjust_hue(inpt: DType, hue_factor: float) -> DType: +def adjust_hue(inpt: InputTypeJIT, hue_factor: float) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) elif isinstance(inpt, features._Feature): @@ -74,7 +73,7 @@ def adjust_hue(inpt: DType, hue_factor: float) -> DType: adjust_gamma_image_pil = _FP.adjust_gamma -def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType: +def adjust_gamma(inpt: InputTypeJIT, gamma: float, gain: float = 1) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) elif isinstance(inpt, features._Feature): @@ -87,7 +86,7 @@ def adjust_gamma(inpt: DType, 
gamma: float, gain: float = 1) -> DType: posterize_image_pil = _FP.posterize -def posterize(inpt: DType, bits: int) -> DType: +def posterize(inpt: InputTypeJIT, bits: int) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return posterize_image_tensor(inpt, bits=bits) elif isinstance(inpt, features._Feature): @@ -100,7 +99,7 @@ def posterize(inpt: DType, bits: int) -> DType: solarize_image_pil = _FP.solarize -def solarize(inpt: DType, threshold: float) -> DType: +def solarize(inpt: InputTypeJIT, threshold: float) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return solarize_image_tensor(inpt, threshold=threshold) elif isinstance(inpt, features._Feature): @@ -113,7 +112,7 @@ def solarize(inpt: DType, threshold: float) -> DType: autocontrast_image_pil = _FP.autocontrast -def autocontrast(inpt: DType) -> DType: +def autocontrast(inpt: InputTypeJIT) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return autocontrast_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -126,7 +125,7 @@ def autocontrast(inpt: DType) -> DType: equalize_image_pil = _FP.equalize -def equalize(inpt: DType) -> DType: +def equalize(inpt: InputTypeJIT) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return equalize_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -139,7 +138,7 @@ def equalize(inpt: DType) -> DType: invert_image_pil = _FP.invert -def invert(inpt: DType) -> DType: +def invert(inpt: InputTypeJIT) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return invert_image_tensor(inpt) elif isinstance(inpt, features._Feature): diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index e981856d6ec..20ed419bf04 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -7,7 +7,7 @@ from torchvision.prototype import features from torchvision.transforms import functional as _F -from ._utils import ImageType, LegacyImageType +from ...features._image import ImageTypeJIT, LegacyImageTypeJIT @torch.jit.unused @@ -24,7 +24,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima return _F.to_grayscale(inpt, num_output_channels=num_output_channels) -def rgb_to_grayscale(inpt: LegacyImageType, num_output_channels: int = 1) -> LegacyImageType: +def rgb_to_grayscale(inpt: LegacyImageTypeJIT, num_output_channels: int = 1) -> LegacyImageTypeJIT: old_color_space = ( features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)) @@ -58,7 +58,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: return _F.to_tensor(inpt) -def get_image_size(inpt: ImageType) -> List[int]: +def get_image_size(inpt: ImageTypeJIT) -> List[int]: warnings.warn( "The function `get_image_size(...)` is deprecated and will be removed in a future release. " "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." 
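The `*TypeJIT` aliases that replace `DType` and `ImageType` in these hunks resolve to plain `torch.Tensor`, which is what keeps the dispatchers scriptable: each public op branches once on the input kind, and `torch.jit.script` only ever compiles the tensor path. A minimal sketch of the idiom, assuming the prototype API at this point in the series; `example_op` and the `_example_*` kernels are hypothetical stand-ins for real pairs such as `adjust_brightness_image_tensor`/`adjust_brightness_image_pil`:

import torch
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.features._feature import InputTypeJIT  # == torch.Tensor

def _example_image_tensor(img: torch.Tensor, factor: float) -> torch.Tensor:
    # hypothetical tensor kernel: the only branch TorchScript compiles
    return img * factor

def _example_image_pil(img: PIL.Image.Image, factor: float) -> PIL.Image.Image:
    # hypothetical PIL kernel; the real ones delegate to transforms.functional_pil
    return img

def example_op(inpt: InputTypeJIT, factor: float) -> InputTypeJIT:
    # Under torch.jit.script the annotation is just torch.Tensor, so the eager-mode
    # union of tensor / PIL image / _Feature never reaches the compiler.
    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
        return _example_image_tensor(inpt, factor)
    elif isinstance(inpt, features._Feature):
        return inpt  # a real dispatcher defers to the feature's own method here
    else:
        return _example_image_pil(inpt, factor)

In eager mode all three input kinds still reach the right branch: `_Feature` is a `torch.Tensor` subclass, so the `not isinstance(inpt, features._Feature)` guard is what routes features past the tensor kernel and into their own methods.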
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index f0c490fee8c..6e51bd6c318 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -16,9 +16,10 @@ ) from torchvision.transforms.functional_tensor import _parse_pad_padding -from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor -from ._utils import DType, FillType, ImageType +from ...features._feature import FillTypeJIT, InputTypeJIT +from ...features._image import ImageTypeJIT +from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor horizontal_flip_image_tensor = _FT.hflip horizontal_flip_image_pil = _FP.hflip @@ -44,7 +45,7 @@ def horizontal_flip_bounding_box( ).view(shape) -def horizontal_flip(inpt: DType) -> DType: +def horizontal_flip(inpt: InputTypeJIT) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return horizontal_flip_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -77,7 +78,7 @@ def vertical_flip_bounding_box( ).view(shape) -def vertical_flip(inpt: DType) -> DType: +def vertical_flip(inpt: InputTypeJIT) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return vertical_flip_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -154,12 +155,12 @@ def resize_bounding_box( def resize( - inpt: DType, + inpt: InputTypeJIT, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = None, -) -> DType: +) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): antialias = False if antialias is None else antialias return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) @@ -229,7 +230,7 @@ def affine_image_tensor( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if img.numel() == 0: @@ -261,7 +262,7 @@ def affine_image_pil( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> PIL.Image.Image: angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) @@ -379,7 +380,7 @@ def affine_mask( translate: List[float], scale: float, shear: List[float], - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if mask.ndim < 3: @@ -405,7 +406,7 @@ def affine_mask( return output -def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> FillType: +def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> FillTypeJIT: # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # So, we can't reassign fill to 0 # if fill is None: @@ -420,15 +421,15 @@ def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[f def affine( - inpt: DType, + inpt: InputTypeJIT, angle: float, translate: List[float], scale: float, shear: 
List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, -) -> DType: +) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return affine_image_tensor( inpt, @@ -462,7 +463,7 @@ def rotate_image_tensor( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: num_channels, height, width = img.shape[-3:] @@ -501,7 +502,7 @@ def rotate_image_pil( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> PIL.Image.Image: if center is not None and expand: @@ -541,7 +542,7 @@ def rotate_mask( mask: torch.Tensor, angle: float, expand: bool = False, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if mask.ndim < 3: @@ -566,13 +567,13 @@ def rotate_mask( def rotate( - inpt: DType, + inpt: InputTypeJIT, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: FillType = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, -) -> DType: +) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) elif isinstance(inpt, features._Feature): @@ -587,7 +588,7 @@ def rotate( def pad_image_tensor( img: torch.Tensor, padding: Union[int, List[int]], - fill: FillType = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> torch.Tensor: if fill is None: @@ -651,7 +652,7 @@ def pad_mask( mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant", - fill: FillType = None, + fill: FillTypeJIT = None, ) -> torch.Tensor: if fill is None: fill = 0 @@ -697,11 +698,11 @@ def pad_bounding_box( def pad( - inpt: DType, + inpt: InputTypeJIT, padding: Union[int, List[int]], - fill: FillType = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", -) -> DType: +) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) @@ -738,7 +739,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) return crop_image_tensor(mask, top, left, height, width) -def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType: +def crop(inpt: InputTypeJIT, top: int, left: int, height: int, width: int) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return crop_image_tensor(inpt, top, left, height, width) elif isinstance(inpt, features._Feature): @@ -751,7 +752,7 @@ def perspective_image_tensor( img: torch.Tensor, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> torch.Tensor: return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill) @@ -761,7 +762,7 @@ def perspective_image_pil( img: PIL.Image.Image, perspective_coeffs: 
List[float], interpolation: InterpolationMode = InterpolationMode.BICUBIC, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> PIL.Image.Image: return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) @@ -854,7 +855,7 @@ def perspective_bounding_box( def perspective_mask( mask: torch.Tensor, perspective_coeffs: List[float], - fill: FillType = None, + fill: FillTypeJIT = None, ) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -873,11 +874,11 @@ def perspective_mask( def perspective( - inpt: DType, + inpt: InputTypeJIT, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, -) -> DType: + fill: FillTypeJIT = None, +) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) elif isinstance(inpt, features._Feature): @@ -890,7 +891,7 @@ def elastic_image_tensor( img: torch.Tensor, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> torch.Tensor: return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) @@ -900,7 +901,7 @@ def elastic_image_pil( img: PIL.Image.Image, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> PIL.Image.Image: t_img = pil_to_tensor(img) output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) @@ -950,7 +951,7 @@ def elastic_bounding_box( def elastic_mask( mask: torch.Tensor, displacement: torch.Tensor, - fill: FillType = None, + fill: FillTypeJIT = None, ) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -967,11 +968,11 @@ def elastic_mask( def elastic( - inpt: DType, + inpt: InputTypeJIT, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: FillType = None, -) -> DType: + fill: FillTypeJIT = None, +) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) elif isinstance(inpt, features._Feature): @@ -1068,7 +1069,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor return output -def center_crop(inpt: DType, output_size: List[int]) -> DType: +def center_crop(inpt: InputTypeJIT, output_size: List[int]) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return center_crop_image_tensor(inpt, output_size) elif isinstance(inpt, features._Feature): @@ -1131,7 +1132,7 @@ def resized_crop_mask( def resized_crop( - inpt: DType, + inpt: InputTypeJIT, top: int, left: int, height: int, @@ -1139,7 +1140,7 @@ def resized_crop( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[bool] = None, -) -> DType: +) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): antialias = False if antialias is None else antialias return resized_crop_image_tensor( @@ -1203,8 +1204,10 @@ def five_crop_image_pil( return tl, tr, bl, br, center -def five_crop(inpt: ImageType, size: List[int]) -> 
Tuple[ImageType, ImageType, ImageType, ImageType, ImageType]: - # TODO: consider breaking BC here to return List[ImageType] to align this op with `ten_crop` +def five_crop( + inpt: ImageTypeJIT, size: List[int] +) -> Tuple[ImageTypeJIT, ImageTypeJIT, ImageTypeJIT, ImageTypeJIT, ImageTypeJIT]: + # TODO: consider breaking BC here to return List[ImageTypeJIT] to align this op with `ten_crop` if isinstance(inpt, torch.Tensor): output = five_crop_image_tensor(inpt, size) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): @@ -1241,7 +1244,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip] -def ten_crop(inpt: ImageType, size: List[int], vertical_flip: bool = False) -> List[ImageType]: +def ten_crop(inpt: ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[ImageTypeJIT]: if isinstance(inpt, torch.Tensor): output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 958edc248c9..894d4d1b36b 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -6,14 +6,14 @@ from torchvision.prototype.features import BoundingBoxFormat, ColorSpace from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT -from ._utils import ImageType +from ...features._image import ImageTypeJIT get_dimensions_image_tensor = _FT.get_dimensions get_dimensions_image_pil = _FP.get_dimensions # TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init? -def get_chw(image: ImageType) -> Tuple[int, int, int]: +def get_chw(image: ImageTypeJIT) -> Tuple[int, int, int]: if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): channels, height, width = get_dimensions_image_tensor(image) elif isinstance(image, features.Image): @@ -31,11 +31,11 @@ def get_chw(image: ImageType) -> Tuple[int, int, int]: # detailed above. 
-def get_dimensions(image: ImageType) -> List[int]: +def get_dimensions(image: ImageTypeJIT) -> List[int]: return list(get_chw(image)) -def get_num_channels(image: ImageType) -> int: +def get_num_channels(image: ImageTypeJIT) -> int: num_channels, *_ = get_chw(image) return num_channels @@ -45,7 +45,7 @@ def get_num_channels(image: ImageType) -> int: get_image_num_channels = get_num_channels -def get_spatial_size(image: ImageType) -> List[int]: +def get_spatial_size(image: ImageTypeJIT) -> List[int]: _, *size = get_chw(image) return size @@ -210,11 +210,11 @@ def convert_color_space_image_pil( def convert_color_space( - inpt: ImageType, + inpt: ImageTypeJIT, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True, -) -> ImageType: +) -> ImageTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)): if old_color_space is None: raise RuntimeError( @@ -227,4 +227,4 @@ def convert_color_space( elif isinstance(inpt, features.Image): return inpt.to_color_space(color_space, copy=copy) else: - return cast(ImageType, convert_color_space_image_pil(inpt, color_space, copy=copy)) + return cast(ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy)) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 58bde5b78ec..848939838c7 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -6,13 +6,14 @@ from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image -from ._utils import DType, TensorImageType +from ...features._feature import InputTypeJIT +from ...features._image import TensorImageTypeJIT normalize_image_tensor = _FT.normalize -def normalize(inpt: TensorImageType, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: +def normalize(inpt: TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: if not isinstance(inpt, torch.Tensor): raise TypeError(f"img should be Tensor Image. 
Got {type(inpt)}") else: @@ -59,7 +60,7 @@ def gaussian_blur_image_pil( return to_pil_image(output, mode=img.mode) -def gaussian_blur(inpt: DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> DType: +def gaussian_blur(inpt: InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None) -> InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) elif isinstance(inpt, features._Feature): diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py deleted file mode 100644 index 19d451eb215..00000000000 --- a/torchvision/prototype/transforms/functional/_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import List, Union - -import torch - - -# The types defined in this file should be mirroring the ones in `.._utils.py` -# Unfortunately due to torch.jit.script limitations we must use fake types -# Keeping track of the actual types is useful in-case this limitation is lifted - -# Real type: Union[torch.Tensor, PIL.Image.Image, features._Feature] -DType = torch.Tensor - -# Real type: Union[torch.Tensor, PIL.Image.Image, features.Image] -ImageType = torch.Tensor - -# Real type: Union[torch.Tensor, PIL.Image.Image] -LegacyImageType = torch.Tensor - -# Real type: Union[torch.Tensor, features.Image] -TensorImageType = torch.Tensor - -# Similarly, JIT doesn't support Sequencies and can't support at the same time, Lists of floats and ints. -# Ideal type: Union[int, float, Sequence[int], Sequence[float], None] -FillType = Union[int, float, List[float], None] From 211f2947362ba681d8a88fa67484713235f8e7fc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 23 Sep 2022 12:45:39 +0100 Subject: [PATCH 08/11] Remove relative imports. 
--- torchvision/prototype/features/__init__.py | 13 ++- torchvision/prototype/transforms/_augment.py | 19 ++-- .../prototype/transforms/_auto_augment.py | 21 ++--- torchvision/prototype/transforms/_color.py | 6 +- .../prototype/transforms/_deprecated.py | 6 +- torchvision/prototype/transforms/_geometry.py | 25 +++--- torchvision/prototype/transforms/_meta.py | 6 +- torchvision/prototype/transforms/_misc.py | 6 +- .../transforms/functional/_augment.py | 6 +- .../prototype/transforms/functional/_color.py | 24 +++-- .../transforms/functional/_deprecated.py | 6 +- .../transforms/functional/_geometry.py | 89 +++++++++---------- .../prototype/transforms/functional/_meta.py | 16 ++-- .../prototype/transforms/functional/_misc.py | 12 +-- 14 files changed, 122 insertions(+), 133 deletions(-) diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py index 5b2e6658acd..df77e8b77b3 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/features/__init__.py @@ -1,6 +1,15 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat from ._encoded import EncodedData, EncodedImage, EncodedVideo -from ._feature import _Feature, is_simple_tensor -from ._image import ColorSpace, Image +from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor +from ._image import ( + ColorSpace, + Image, + ImageType, + ImageTypeJIT, + LegacyImageType, + LegacyImageTypeJIT, + TensorImageType, + TensorImageTypeJIT, +) from ._label import Label, OneHotLabel from ._mask import Mask diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 9811673314c..3cd925fd996 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -10,8 +10,6 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, InterpolationMode -from ..features._image import ImageType, TensorImageType - from ._transform import _RandomApplyTransform from ._utils import has_any, query_chw @@ -94,7 +92,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(i=i, j=j, h=h, w=w, v=v) - def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: if params["v"] is not None: inpt = F.erase(inpt, **params, inplace=self.inplace) @@ -205,15 +203,15 @@ def __init__( def _copy_paste( self, - image: TensorImageType, + image: features.TensorImageType, target: Dict[str, Any], - paste_image: TensorImageType, + paste_image: features.TensorImageType, paste_target: Dict[str, Any], random_selection: torch.Tensor, blending: bool, resize_interpolation: F.InterpolationMode, antialias: Optional[bool], - ) -> Tuple[TensorImageType, Dict[str, Any]]: + ) -> Tuple[features.TensorImageType, Dict[str, Any]]: paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection]) paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection]) @@ -280,7 +278,9 @@ def _copy_paste( return image, out_target - def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[TensorImageType], List[Dict[str, Any]]]: + def _extract_image_targets( + self, flat_sample: List[Any] + ) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]: # fetch all images, bboxes, masks and labels from unstructured input # with List[image], 
List[BoundingBox], List[Mask], List[Label] images, bboxes, masks, labels = [], [], [], [] @@ -309,7 +309,10 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[TensorIma return images, targets def _insert_outputs( - self, flat_sample: List[Any], output_images: List[TensorImageType], output_targets: List[Dict[str, Any]] + self, + flat_sample: List[Any], + output_images: List[features.TensorImageType], + output_targets: List[Dict[str, Any]], ) -> None: c0, c1, c2, c3 = 0, 0, 0, 0 for i, obj in enumerate(flat_sample): diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index cc54a5059af..c98e5c36e4a 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -9,9 +9,6 @@ from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms.functional._meta import get_chw -from ..features._feature import FillType -from ..features._image import ImageType - from ._utils import _isinstance, _setup_fill_arg K = TypeVar("K") @@ -23,7 +20,7 @@ def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ) -> None: super().__init__() self.interpolation = interpolation @@ -38,7 +35,7 @@ def _extract_image( self, sample: Any, unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask), - ) -> Tuple[int, ImageType]: + ) -> Tuple[int, features.ImageType]: sample_flat, _ = tree_flatten(sample) images = [] for id, inpt in enumerate(sample_flat): @@ -62,12 +59,12 @@ def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any: def _apply_image_transform( self, - image: ImageType, + image: features.ImageType, transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Dict[Type, FillType], - ) -> ImageType: + fill: Dict[Type, features.FillType], + ) -> features.ImageType: fill_ = fill[type(image)] fill_ = F._geometry._convert_fill_arg(fill_) @@ -180,7 +177,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -340,7 +337,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -396,7 +393,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -456,7 +453,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Union[FillType, Dict[Type, FillType]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ) -> None: 
super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 82d7f6c2b5b..e0ee8d1b96a 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -6,8 +6,6 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform -from ..features._image import ImageType - from ._transform import _RandomApplyTransform from ._utils import query_chw @@ -112,7 +110,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, ) - def _permute_channels(self, inpt: ImageType, permutation: torch.Tensor) -> ImageType: + def _permute_channels(self, inpt: features.ImageType, permutation: torch.Tensor) -> features.ImageType: if isinstance(inpt, PIL.Image.Image): inpt = F.pil_to_tensor(inpt) @@ -125,7 +123,7 @@ def _permute_channels(self, inpt: ImageType, permutation: torch.Tensor) -> Image return output - def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: if params["brightness"]: inpt = F.adjust_brightness( inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index cbfb9239a13..a9341415c1a 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -10,8 +10,6 @@ from torchvision.transforms import functional as _F from typing_extensions import Literal -from ..features._image import ImageType - from ._transform import _RandomApplyTransform from ._utils import query_chw @@ -54,7 +52,7 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: super().__init__() self.num_output_channels = num_output_channels - def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) if isinstance(inpt, features.Image): output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY) @@ -83,7 +81,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: num_input_channels, _, _ = query_chw(sample) return dict(num_input_channels=num_input_channels) - def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) if isinstance(inpt, features.Image): output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 4dee4e6d9b0..babcb83af04 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -12,9 +12,6 @@ from typing_extensions import Literal -from ..features._feature import FillType -from ..features._image import ImageType - from ._transform import _RandomApplyTransform from ._utils import ( _check_padding_arg, @@ -181,8 +178,8 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: 
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 4dee4e6d9b0..babcb83af04 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -12,9 +12,6 @@

 from typing_extensions import Literal

-from ..features._feature import FillType
-from ..features._image import ImageType
-
 from ._transform import _RandomApplyTransform
 from ._utils import (
     _check_padding_arg,
@@ -181,8 +178,8 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None:
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(
-        self, inpt: ImageType, params: Dict[str, Any]
-    ) -> Tuple[ImageType, ImageType, ImageType, ImageType, ImageType]:
+        self, inpt: features.ImageType, params: Dict[str, Any]
+    ) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]:
         return F.five_crop(inpt, self.size)

     def forward(self, *inputs: Any) -> Any:
@@ -203,7 +200,7 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False)
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip

-    def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> List[ImageType]:
+    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)

     def forward(self, *inputs: Any) -> Any:
@@ -216,7 +213,7 @@ class Pad(Transform):
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -243,7 +240,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 class RandomZoomOut(_RandomApplyTransform):
     def __init__(
         self,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
@@ -285,7 +282,7 @@ def __init__(
         degrees: Union[numbers.Number, Sequence],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -325,7 +322,7 @@ def __init__(
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[float, Sequence[float]]] = None,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -404,7 +401,7 @@ def __init__(
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -494,7 +491,7 @@ class RandomPerspective(_RandomApplyTransform):
     def __init__(
         self,
         distortion_scale: float = 0.5,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         p: float = 0.5,
     ) -> None:
@@ -570,7 +567,7 @@ def __init__(
         self,
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
@@ -783,7 +780,7 @@ class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[FillType, Dict[Type, FillType]] = 0,
+        fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index a1f6ff9bdc7..2ea3014aa6c 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -6,8 +6,6 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform

-from ..features._image import ImageType, TensorImageType
-

 class ConvertBoundingBoxFormat(Transform):
     _transformed_types = (features.BoundingBox,)
@@ -30,7 +28,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype

-    def _transform(self, inpt: TensorImageType, params: Dict[str, Any]) -> TensorImageType:
+    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
         output = F.convert_image_dtype(inpt, dtype=self.dtype)
         return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)  # type: ignore[arg-type]
@@ -56,7 +54,7 @@ def __init__(

         self.copy = copy

-    def _transform(self, inpt: ImageType, params: Dict[str, Any]) -> ImageType:
+    def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
         return F.convert_color_space(
             inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
         )
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index 12ad2a375d7..db93378312f 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -8,8 +8,6 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform

-from ..features._image import TensorImageType
-
 from ._utils import _setup_size, has_any, query_bounding_box


@@ -70,7 +68,7 @@ def forward(self, *inputs: Any) -> Any:

         return super().forward(*inputs)

-    def _transform(self, inpt: TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
+    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
         # Image instance after linear transformation is not Image anymore due to unknown data range
         # Thus we will return Tensor for input Image

@@ -103,7 +101,7 @@ def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool =
         self.std = list(std)
         self.inplace = inplace

-    def _transform(self, inpt: TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
+    def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)

     def forward(self, *inpts: Any) -> Any:
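NOTE: both `LinearTransformation` and `Normalize` above deliberately return a
plain `torch.Tensor`, per the inline comment about the unknown data range. A
small sketch of the expected behavior, assuming the prototype API as typed in
this series:

    import torch
    from torchvision.prototype import features, transforms

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    out = normalize(features.Image(torch.rand(3, 16, 16)))
    print(type(out))  # expected: <class 'torch.Tensor'>, not features.Image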
diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py
index d1ffeaa8735..7a7780706d8 100644
--- a/torchvision/prototype/transforms/functional/_augment.py
+++ b/torchvision/prototype/transforms/functional/_augment.py
@@ -5,8 +5,6 @@
 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image

-from ...features._image import ImageTypeJIT
-

 erase_image_tensor = _FT.erase

@@ -20,14 +18,14 @@ def erase_image_pil(


 def erase(
-    inpt: ImageTypeJIT,
+    inpt: features.ImageTypeJIT,
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> ImageTypeJIT:
+) -> features.ImageTypeJIT:
     if isinstance(inpt, torch.Tensor):
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index dd27757e5c7..f375cb048c6 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -2,13 +2,11 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ...features._feature import InputTypeJIT
-
 adjust_brightness_image_tensor = _FT.adjust_brightness
 adjust_brightness_image_pil = _FP.adjust_brightness


-def adjust_brightness(inpt: InputTypeJIT, brightness_factor: float) -> InputTypeJIT:
+def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
     elif isinstance(inpt, features._Feature):
@@ -21,7 +19,7 @@ def adjust_brightness(inpt: InputTypeJIT, brightness_factor: float) -> InputType
 adjust_saturation_image_pil = _FP.adjust_saturation


-def adjust_saturation(inpt: InputTypeJIT, saturation_factor: float) -> InputTypeJIT:
+def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
     elif isinstance(inpt, features._Feature):
@@ -34,7 +32,7 @@ def adjust_saturation(inpt: InputTypeJIT, saturation_factor: float) -> InputType
 adjust_contrast_image_pil = _FP.adjust_contrast


-def adjust_contrast(inpt: InputTypeJIT, contrast_factor: float) -> InputTypeJIT:
+def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
     elif isinstance(inpt, features._Feature):
@@ -47,7 +45,7 @@ def adjust_contrast(inpt: InputTypeJIT, contrast_factor: float) -> InputTypeJIT:
 adjust_sharpness_image_pil = _FP.adjust_sharpness


-def adjust_sharpness(inpt: InputTypeJIT, sharpness_factor: float) -> InputTypeJIT:
+def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
     elif isinstance(inpt, features._Feature):
@@ -60,7 +58,7 @@ def adjust_sharpness(inpt: InputTypeJIT, sharpness_factor: float) -> InputTypeJI
 adjust_hue_image_pil = _FP.adjust_hue


-def adjust_hue(inpt: InputTypeJIT, hue_factor: float) -> InputTypeJIT:
+def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
     elif isinstance(inpt, features._Feature):
@@ -73,7 +71,7 @@ def adjust_hue(inpt: InputTypeJIT, hue_factor: float) -> InputTypeJIT:
 adjust_gamma_image_pil = _FP.adjust_gamma


-def adjust_gamma(inpt: InputTypeJIT, gamma: float, gain: float = 1) -> InputTypeJIT:
+def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
     elif isinstance(inpt, features._Feature):
@@ -86,7 +84,7 @@ def adjust_gamma(inpt: InputTypeJIT, gamma: float, gain: float = 1) -> InputType
 posterize_image_pil = _FP.posterize


-def posterize(inpt: InputTypeJIT, bits: int) -> InputTypeJIT:
+def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return posterize_image_tensor(inpt, bits=bits)
     elif isinstance(inpt, features._Feature):
@@ -99,7 +97,7 @@ def posterize(inpt: InputTypeJIT, bits: int) -> InputTypeJIT:
 solarize_image_pil = _FP.solarize


-def solarize(inpt: InputTypeJIT, threshold: float) -> InputTypeJIT:
+def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return solarize_image_tensor(inpt, threshold=threshold)
     elif isinstance(inpt, features._Feature):
@@ -112,7 +110,7 @@ def solarize(inpt: InputTypeJIT, threshold: float) -> InputTypeJIT:
 autocontrast_image_pil = _FP.autocontrast


-def autocontrast(inpt: InputTypeJIT) -> InputTypeJIT:
+def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return autocontrast_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
@@ -125,7 +123,7 @@ def autocontrast(inpt: InputTypeJIT) -> InputTypeJIT:
 equalize_image_pil = _FP.equalize


-def equalize(inpt: InputTypeJIT) -> InputTypeJIT:
+def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return equalize_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
@@ -138,7 +136,7 @@ def equalize(inpt: InputTypeJIT) -> InputTypeJIT:
 invert_image_pil = _FP.invert


-def invert(inpt: InputTypeJIT) -> InputTypeJIT:
+def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return invert_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
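NOTE: every dispatcher in this file follows the same three-way pattern: plain
tensors (and the TorchScript path) go to the `*_image_tensor` kernel, features
dispatch to their own method, and anything else falls through to the PIL
kernel. A usage sketch with one of the ops changed above (values illustrative,
assuming the prototype namespace as in this series):

    import PIL.Image
    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    t = torch.rand(3, 8, 8)
    F.adjust_brightness(t, brightness_factor=1.5)                  # tensor in, tensor out
    F.adjust_brightness(features.Image(t), brightness_factor=1.5)  # Image in, Image out
    F.adjust_brightness(PIL.Image.new("RGB", (8, 8)), brightness_factor=1.5)  # PIL in, PIL out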
diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py
index 20ed419bf04..cbdea5130ef 100644
--- a/torchvision/prototype/transforms/functional/_deprecated.py
+++ b/torchvision/prototype/transforms/functional/_deprecated.py
@@ -7,8 +7,6 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional as _F

-from ...features._image import ImageTypeJIT, LegacyImageTypeJIT
-

 @torch.jit.unused
 def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
@@ -24,7 +22,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)


-def rgb_to_grayscale(inpt: LegacyImageTypeJIT, num_output_channels: int = 1) -> LegacyImageTypeJIT:
+def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT:
     old_color_space = (
         features._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
         if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image))
@@ -58,7 +56,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)


-def get_image_size(inpt: ImageTypeJIT) -> List[int]:
+def get_image_size(inpt: features.ImageTypeJIT) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 6e51bd6c318..e619b7b26e9 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -16,9 +16,6 @@
 )
 from torchvision.transforms.functional_tensor import _parse_pad_padding

-from ...features._feature import FillTypeJIT, InputTypeJIT
-from ...features._image import ImageTypeJIT
-
 from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor

 horizontal_flip_image_tensor = _FT.hflip
@@ -45,7 +42,7 @@ def horizontal_flip_bounding_box(
     ).view(shape)


-def horizontal_flip(inpt: InputTypeJIT) -> InputTypeJIT:
+def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return horizontal_flip_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
@@ -78,7 +75,7 @@ def vertical_flip_bounding_box(
     ).view(shape)


-def vertical_flip(inpt: InputTypeJIT) -> InputTypeJIT:
+def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return vertical_flip_image_tensor(inpt)
     elif isinstance(inpt, features._Feature):
@@ -155,12 +152,12 @@ def resize_bounding_box(


 def resize(
-    inpt: InputTypeJIT,
+    inpt: features.InputTypeJIT,
     size: List[int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
-) -> InputTypeJIT:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         antialias = False if antialias is None else antialias
         return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@@ -230,7 +227,7 @@ def affine_image_tensor(
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if img.numel() == 0:
@@ -262,7 +259,7 @@ def affine_image_pil(
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -380,7 +377,7 @@ def affine_mask(
     translate: List[float],
     scale: float,
     shear: List[float],
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
@@ -406,7 +403,7 @@ def affine_mask(
     return output


-def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> FillTypeJIT:
+def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> features.FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
     # if fill is None:
@@ -421,15 +418,15 @@ def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[f


 def affine(
-    inpt: InputTypeJIT,
+    inpt: features.InputTypeJIT,
     angle: float,
     translate: List[float],
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
-) -> InputTypeJIT:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return affine_image_tensor(
             inpt,
@@ -463,7 +460,7 @@ def rotate_image_tensor(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     num_channels, height, width = img.shape[-3:]
@@ -502,7 +499,7 @@ def rotate_image_pil(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
     if center is not None and expand:
@@ -542,7 +539,7 @@ def rotate_mask(
     mask: torch.Tensor,
     angle: float,
     expand: bool = False,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
@@ -567,13 +564,13 @@ def rotate_mask(


 def rotate(
-    inpt: InputTypeJIT,
+    inpt: features.InputTypeJIT,
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
-) -> InputTypeJIT:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     elif isinstance(inpt, features._Feature):
@@ -588,7 +585,7 @@ def rotate(
 def pad_image_tensor(
     img: torch.Tensor,
     padding: Union[int, List[int]],
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     padding_mode: str = "constant",
 ) -> torch.Tensor:
     if fill is None:
@@ -652,7 +649,7 @@ def pad_mask(
     mask: torch.Tensor,
     padding: Union[int, List[int]],
     padding_mode: str = "constant",
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     if fill is None:
         fill = 0
@@ -698,11 +695,11 @@ def pad_bounding_box(


 def pad(
-    inpt: InputTypeJIT,
+    inpt: features.InputTypeJIT,
     padding: Union[int, List[int]],
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
     padding_mode: str = "constant",
-) -> InputTypeJIT:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)

@@ -739,7 +736,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
     return crop_image_tensor(mask, top, left, height, width)


-def crop(inpt: InputTypeJIT, top: int, left: int, height: int, width: int) -> InputTypeJIT:
+def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return crop_image_tensor(inpt, top, left, height, width)
     elif isinstance(inpt, features._Feature):
@@ -752,7 +749,7 @@ def perspective_image_tensor(
     img: torch.Tensor,
     perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)

@@ -762,7 +759,7 @@ def perspective_image_pil(
     img: PIL.Image.Image,
     perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BICUBIC,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
 ) -> PIL.Image.Image:
     return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
@@ -855,7 +852,7 @@ def perspective_bounding_box(
 def perspective_mask(
     mask: torch.Tensor,
     perspective_coeffs: List[float],
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
@@ -874,11 +871,11 @@ def perspective_mask(


 def perspective(
-    inpt: InputTypeJIT,
+    inpt: features.InputTypeJIT,
     perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: FillTypeJIT = None,
-) -> InputTypeJIT:
+    fill: features.FillTypeJIT = None,
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, features._Feature):
@@ -891,7 +888,7 @@ def elastic_image_tensor(
     img: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)

@@ -901,7 +898,7 @@ def elastic_image_pil(
     img: PIL.Image.Image,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(img)
     output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
@@ -951,7 +948,7 @@ def elastic_bounding_box(
 def elastic_mask(
     mask: torch.Tensor,
     displacement: torch.Tensor,
-    fill: FillTypeJIT = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
@@ -968,11 +965,11 @@ def elastic_mask(


 def elastic(
-    inpt: InputTypeJIT,
+    inpt: features.InputTypeJIT,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: FillTypeJIT = None,
-) -> InputTypeJIT:
+    fill: features.FillTypeJIT = None,
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, features._Feature):
@@ -1069,7 +1066,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
     return output


-def center_crop(inpt: InputTypeJIT, output_size: List[int]) -> InputTypeJIT:
+def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return center_crop_image_tensor(inpt, output_size)
     elif isinstance(inpt, features._Feature):
@@ -1132,7 +1129,7 @@ def resized_crop_mask(


 def resized_crop(
-    inpt: InputTypeJIT,
+    inpt: features.InputTypeJIT,
     top: int,
     left: int,
     height: int,
@@ -1140,7 +1137,7 @@ def resized_crop(
     size: List[int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     antialias: Optional[bool] = None,
-) -> InputTypeJIT:
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         antialias = False if antialias is None else antialias
         return resized_crop_image_tensor(
@@ -1205,9 +1202,11 @@ def five_crop_image_pil(


 def five_crop(
-    inpt: ImageTypeJIT, size: List[int]
-) -> Tuple[ImageTypeJIT, ImageTypeJIT, ImageTypeJIT, ImageTypeJIT, ImageTypeJIT]:
-    # TODO: consider breaking BC here to return List[ImageTypeJIT] to align this op with `ten_crop`
+    inpt: features.ImageTypeJIT, size: List[int]
+) -> Tuple[
+    features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT
+]:
+    # TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop`
     if isinstance(inpt, torch.Tensor):
         output = five_crop_image_tensor(inpt, size)
         if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
@@ -1244,7 +1243,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]


-def ten_crop(inpt: ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[ImageTypeJIT]:
+def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]:
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
         if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
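NOTE: as the TODO above says, `five_crop` still returns a 5-tuple while
`ten_crop` returns a list. A usage sketch (sizes illustrative):

    import torch
    from torchvision.prototype.transforms import functional as F

    img = torch.rand(3, 32, 32)
    tl, tr, bl, br, center = F.five_crop(img, size=[16, 16])    # tuple of five crops
    crops = F.ten_crop(img, size=[16, 16], vertical_flip=False)  # list of ten crops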
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index 894d4d1b36b..90cfffcf276 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -6,14 +6,12 @@
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ...features._image import ImageTypeJIT
-
 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions


 # TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
-def get_chw(image: ImageTypeJIT) -> Tuple[int, int, int]:
+def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]:
     if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)):
         channels, height, width = get_dimensions_image_tensor(image)
     elif isinstance(image, features.Image):
@@ -31,11 +29,11 @@ def get_chw(image: ImageTypeJIT) -> Tuple[int, int, int]:
 # detailed above.


-def get_dimensions(image: ImageTypeJIT) -> List[int]:
+def get_dimensions(image: features.ImageTypeJIT) -> List[int]:
     return list(get_chw(image))


-def get_num_channels(image: ImageTypeJIT) -> int:
+def get_num_channels(image: features.ImageTypeJIT) -> int:
     num_channels, *_ = get_chw(image)
     return num_channels

@@ -45,7 +43,7 @@ def get_num_channels(image: ImageTypeJIT) -> int:
 get_image_num_channels = get_num_channels


-def get_spatial_size(image: ImageTypeJIT) -> List[int]:
+def get_spatial_size(image: features.ImageTypeJIT) -> List[int]:
     _, *size = get_chw(image)
     return size

@@ -210,11 +208,11 @@ def convert_color_space_image_pil(


 def convert_color_space(
-    inpt: ImageTypeJIT,
+    inpt: features.ImageTypeJIT,
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
     copy: bool = True,
-) -> ImageTypeJIT:
+) -> features.ImageTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)):
         if old_color_space is None:
             raise RuntimeError(
@@ -227,4 +225,4 @@ def convert_color_space(
     elif isinstance(inpt, features.Image):
         return inpt.to_color_space(color_space, copy=copy)
     else:
-        return cast(ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
+        return cast(features.ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index 848939838c7..952dc0d9e0d 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -6,14 +6,12 @@
 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image

-from ...features._feature import InputTypeJIT
-
-from ...features._image import TensorImageTypeJIT
-
 normalize_image_tensor = _FT.normalize


-def normalize(inpt: TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
+def normalize(
+    inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False
+) -> torch.Tensor:
     if not isinstance(inpt, torch.Tensor):
         raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
     else:
@@ -60,7 +58,9 @@ def gaussian_blur_image_pil(
     return to_pil_image(output, mode=img.mode)


-def gaussian_blur(inpt: InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None) -> InputTypeJIT:
+def gaussian_blur(
+    inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
+) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
     elif isinstance(inpt, features._Feature):
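NOTE: the next patch narrows `_convert_fill_arg` to reuse `features.FillType`.
The fill aliases involved are assumed to be defined roughly as follows (a
sketch, not part of this series):

    from typing import List, Optional, Sequence, Union

    # User-facing fill values, including None...
    FillType = Union[int, float, Sequence[int], Sequence[float], None]
    # ...and the narrowed form the TorchScript kernels accept.
    FillTypeJIT = Optional[List[float]]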
From f3cefdfc18af29df4d312bb64e89fb2258c33bc0 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 23 Sep 2022 12:52:14 +0100
Subject: [PATCH 09/11] Reuse type.

---
 torchvision/prototype/transforms/functional/_geometry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index e619b7b26e9..87b65868bf9 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -403,7 +403,7 @@ def affine_mask(
     return output


-def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> features.FillTypeJIT:
+def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
     # if fill is None:

From 4bf566a8f0f54c9d93e7c5a5ba53494d9b9392ad Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 23 Sep 2022 13:25:32 +0100
Subject: [PATCH 10/11] Temporarily remove the TorchData tests.

---
 .github/workflows/prototype-tests.yml | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/.github/workflows/prototype-tests.yml b/.github/workflows/prototype-tests.yml
index 428f47a70cf..260c52f1e56 100644
--- a/.github/workflows/prototype-tests.yml
+++ b/.github/workflows/prototype-tests.yml
@@ -39,15 +39,6 @@ jobs:
       - name: Install test requirements
         run: pip install --progress-bar=off pytest pytest-mock pytest-cov

-      - name: Run prototype datasets tests
-        shell: bash
-        run: |
-          pytest \
-            --durations=20 \
-            --cov=torchvision/prototype/datasets \
-            --cov-report=term-missing \
-            test/test_prototype_datasets*.py
-
       - name: Run prototype transforms tests
         shell: bash
         run: |

From 458a60f339808eefc2fabf2fbc7e53b8aa27a35d Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 23 Sep 2022 13:40:56 +0100
Subject: [PATCH 11/11] Restore the TorchData tests.

---
 .github/workflows/prototype-tests.yml | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/.github/workflows/prototype-tests.yml b/.github/workflows/prototype-tests.yml
index 260c52f1e56..428f47a70cf 100644
--- a/.github/workflows/prototype-tests.yml
+++ b/.github/workflows/prototype-tests.yml
@@ -39,6 +39,15 @@ jobs:
       - name: Install test requirements
         run: pip install --progress-bar=off pytest pytest-mock pytest-cov

+      - name: Run prototype datasets tests
+        shell: bash
+        run: |
+          pytest \
+            --durations=20 \
+            --cov=torchvision/prototype/datasets \
+            --cov-report=term-missing \
+            test/test_prototype_datasets*.py
+
       - name: Run prototype transforms tests
         shell: bash
         run: |