diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 278b3811aeb..b9c89b2b76a 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -132,7 +132,7 @@ def test_mixup_cutmix(self, transform, input): transform(input_copy) # Check if we raise an error if sample contains bbox or mask or label - err_msg = "does not support bounding boxes, masks and plain labels" + err_msg = "does not support PIL images, bounding boxes, masks and plain labels" input_copy = dict(input) for unsup_data in [ make_label(), diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py index 42ddd9aec27..df77e8b77b3 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/features/__init__.py @@ -1,6 +1,15 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat from ._encoded import EncodedData, EncodedImage, EncodedVideo -from ._feature import _Feature, DType, is_simple_tensor -from ._image import ColorSpace, Image, ImageType +from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor +from ._image import ( + ColorSpace, + Image, + ImageType, + ImageTypeJIT, + LegacyImageType, + LegacyImageTypeJIT, + TensorImageType, + TensorImageTypeJIT, +) from ._label import Label, OneHotLabel from ._mask import Mask diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 47473e7dc7c..0b61439d10c 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -6,7 +6,7 @@ from torchvision._utils import StrEnum from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms -from ._feature import _Feature +from ._feature import _Feature, FillTypeJIT class BoundingBoxFormat(StrEnum): @@ -115,7 +115,7 @@ def resized_crop( def pad( self, padding: Union[int, Sequence[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> BoundingBox: # This cast does Sequence[int] -> List[int] and is required to make mypy happy @@ -137,7 +137,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> BoundingBox: output = self._F.rotate_bounding_box( @@ -165,7 +165,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> BoundingBox: output = self._F.affine_bounding_box( @@ -184,7 +184,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, ) -> BoundingBox: output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) return BoundingBox.new_like(self, output, dtype=output.dtype) @@ -193,7 +193,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, ) -> BoundingBox: output = self._F.elastic_bounding_box(self, self.format, displacement) return BoundingBox.new_like(self, output, 
dtype=output.dtype) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 9cfccf33e54..3d4357b9a99 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -3,16 +3,14 @@ from types import ModuleType from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union +import PIL.Image import torch from torch._C import _TensorBase, DisableTorchFunction from torchvision.transforms import InterpolationMode F = TypeVar("F", bound="_Feature") - - -# Due to torch.jit.script limitation we keep DType as torch.Tensor -# instead of Union[torch.Tensor, PIL.Image.Image, features._Feature] -DType = torch.Tensor +FillType = Union[int, float, Sequence[int], Sequence[float], None] +FillTypeJIT = Union[int, float, List[float], None] def is_simple_tensor(inpt: Any) -> bool: @@ -154,7 +152,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> _Feature: return self @@ -164,7 +162,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> _Feature: return self @@ -176,7 +174,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> _Feature: return self @@ -185,7 +183,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, ) -> _Feature: return self @@ -193,7 +191,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, ) -> _Feature: return self @@ -232,3 +230,7 @@ def invert(self) -> _Feature: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature: return self + + +InputType = Union[torch.Tensor, PIL.Image.Image, _Feature] +InputTypeJIT = torch.Tensor diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 62f7f2849c3..0b832ae0270 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -3,18 +3,14 @@ import warnings from typing import Any, cast, List, Optional, Tuple, Union +import PIL.Image import torch from torchvision._utils import StrEnum from torchvision.transforms.functional import InterpolationMode, to_pil_image from torchvision.utils import draw_bounding_boxes, make_grid from ._bounding_box import BoundingBox -from ._feature import _Feature - - -# Due to torch.jit.script limitation we keep ImageType as torch.Tensor -# instead of Union[torch.Tensor, PIL.Image.Image, features.Image] -ImageType = torch.Tensor +from ._feature import _Feature, FillTypeJIT class ColorSpace(StrEnum): @@ -181,7 +177,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> Image: output = self._F.pad_image_tensor(self, padding, 
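
The _feature.py hunk above replaces the catch-all DType = torch.Tensor alias with paired aliases: FillType and InputType describe what the eager API accepts (scalars, arbitrary sequences, PIL images, _Feature subclasses), while the *JIT twins collapse to what torch.jit.script can type-check. A minimal sketch of why the split is needed; the two fill aliases are copied from the diff, pad_stub is a hypothetical kernel, and the scripting call assumes a PyTorch build with Union annotation support (1.10 or newer):

    from typing import List, Sequence, Union

    import torch

    # Copied from the diff: what eager callers may pass ...
    FillType = Union[int, float, Sequence[int], Sequence[float], None]
    # ... and the subset torch.jit.script can digest (no Sequence, no PIL, no _Feature).
    FillTypeJIT = Union[int, float, List[float], None]

    def pad_stub(img: torch.Tensor, padding: List[int], fill: FillTypeJIT = None) -> torch.Tensor:
        # hypothetical kernel; the real ones forward to _FT.pad / _FP.pad
        return img

    # The wide FillType alias would be rejected here because TorchScript cannot
    # represent Sequence[...]; the JIT variant compiles.
    scripted_pad = torch.jit.script(pad_stub)

The same reasoning gives InputType = Union[torch.Tensor, PIL.Image.Image, _Feature] for eager code and InputTypeJIT = torch.Tensor for anything that has to survive scripting.
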
fill=fill, padding_mode=padding_mode) @@ -192,7 +188,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> Image: output = self._F._geometry.rotate_image_tensor( @@ -207,7 +203,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> Image: output = self._F._geometry.affine_image_tensor( @@ -226,7 +222,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, ) -> Image: output = self._F._geometry.perspective_image_tensor( self, perspective_coeffs, interpolation=interpolation, fill=fill @@ -237,7 +233,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, ) -> Image: output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill) return Image.new_like(self, output) @@ -289,3 +285,11 @@ def invert(self) -> Image: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image: output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma) return Image.new_like(self, output) + + +ImageType = Union[torch.Tensor, PIL.Image.Image, Image] +ImageTypeJIT = torch.Tensor +LegacyImageType = Union[torch.Tensor, PIL.Image.Image] +LegacyImageTypeJIT = torch.Tensor +TensorImageType = Union[torch.Tensor, Image] +TensorImageTypeJIT = torch.Tensor diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 29858578c23..a0c3395dbe7 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -5,7 +5,7 @@ import torch from torchvision.transforms import InterpolationMode -from ._feature import _Feature +from ._feature import _Feature, FillTypeJIT class Mask(_Feature): @@ -51,7 +51,7 @@ def resized_crop( def pad( self, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> Mask: output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) @@ -62,7 +62,7 @@ def rotate( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> Mask: output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill) @@ -75,7 +75,7 @@ def affine( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> Mask: output = self._F.affine_mask( @@ -93,7 +93,7 @@ def perspective( self, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, ) -> Mask: output = self._F.perspective_mask(self, perspective_coeffs, fill=fill) 
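
The tail of the _image.py hunk adds the image-specific counterparts, which the first hunk (features/__init__.py) re-exports. A short reference of the aliases plus a toy function showing that the *JIT variants really are plain torch.Tensor and therefore scriptable; the summary lines paraphrase the diff, the function is illustrative only:

    from typing import List

    import torch

    # From the diff (right-hand sides as defined there, comments are mine):
    #   ImageType       = Union[torch.Tensor, PIL.Image.Image, Image]   eager, anything image-like
    #   TensorImageType = Union[torch.Tensor, Image]                     eager, PIL not accepted
    #   LegacyImageType = Union[torch.Tensor, PIL.Image.Image]           deprecated helpers only
    #   ImageTypeJIT = LegacyImageTypeJIT = TensorImageTypeJIT = torch.Tensor

    def hw(image: torch.Tensor) -> List[int]:    # i.e. annotated with ImageTypeJIT
        return [image.shape[-2], image.shape[-1]]

    scripted = torch.jit.script(hw)
    print(scripted(torch.rand(3, 480, 640)))     # [480, 640]
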
return Mask.new_like(self, output) @@ -102,7 +102,7 @@ def elastic( self, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: FillTypeJIT = None, ) -> Mask: output = self._F.elastic_mask(self, displacement, fill=fill) return Mask.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 963cfac75b2..3cd925fd996 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Tuple import PIL.Image import torch @@ -92,9 +92,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(i=i, j=j, h=h, w=w, v=v) - def _transform( - self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any] - ) -> Union[torch.Tensor, features.Image, PIL.Image.Image]: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: if params["v"] is not None: inpt = F.erase(inpt, **params, inplace=self.inplace) @@ -110,8 +108,10 @@ def __init__(self, alpha: float, p: float = 0.5) -> None: def forward(self, *inputs: Any) -> Any: if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)): raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.") - if has_any(inputs, features.BoundingBox, features.Mask, features.Label): - raise TypeError(f"{type(self).__name__}() does not support bounding boxes, masks and plain labels.") + if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label): + raise TypeError( + f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." 
+ ) return super().forward(*inputs) def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: @@ -203,15 +203,15 @@ def __init__( def _copy_paste( self, - image: Any, + image: features.TensorImageType, target: Dict[str, Any], - paste_image: Any, + paste_image: features.TensorImageType, paste_target: Dict[str, Any], random_selection: torch.Tensor, blending: bool, resize_interpolation: F.InterpolationMode, antialias: Optional[bool], - ) -> Tuple[Any, Dict[str, Any]]: + ) -> Tuple[features.TensorImageType, Dict[str, Any]]: paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection]) paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection]) @@ -223,7 +223,7 @@ def _copy_paste( # This is something different to TF implementation we introduced here as # originally the algorithm works on equal-sized data # (for example, coming from LSJ data augmentations) - size1 = image.shape[-2:] + size1 = cast(List[int], image.shape[-2:]) size2 = paste_image.shape[-2:] if size1 != size2: paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation, antialias=antialias) @@ -278,7 +278,9 @@ def _copy_paste( return image, out_target - def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]: + def _extract_image_targets( + self, flat_sample: List[Any] + ) -> Tuple[List[features.TensorImageType], List[Dict[str, Any]]]: # fetch all images, bboxes, masks and labels from unstructured input # with List[image], List[BoundingBox], List[Mask], List[Label] images, bboxes, masks, labels = [], [], [], [] @@ -307,7 +309,10 @@ def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], Lis return images, targets def _insert_outputs( - self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]] + self, + flat_sample: List[Any], + output_images: List[features.TensorImageType], + output_targets: List[Dict[str, Any]], ) -> None: c0, c1, c2, c3 = 0, 0, 0, 0 for i, obj in enumerate(flat_sample): diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 0d14a434c1e..c98e5c36e4a 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -9,7 +9,7 @@ from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms.functional._meta import get_chw -from ._utils import _isinstance, _setup_fill_arg, FillType +from ._utils import _isinstance, _setup_fill_arg K = TypeVar("K") V = TypeVar("V") @@ -20,7 +20,7 @@ def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ) -> None: super().__init__() self.interpolation = interpolation @@ -35,7 +35,7 @@ def _extract_image( self, sample: Any, unsupported_types: Tuple[Type, ...] 
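
The cast(List[int], image.shape[-2:]) change in _copy_paste is a typing-only fix: slicing Tensor.shape yields a torch.Size at runtime, which mypy will not accept where F.resize expects List[int], and typing.cast has no runtime effect. A small illustration with made-up sizes:

    from typing import List, cast

    import torch

    image = torch.rand(3, 480, 640)
    paste_image = torch.rand(3, 512, 512)

    size1 = cast(List[int], image.shape[-2:])   # runtime value is still torch.Size([480, 640])
    size2 = paste_image.shape[-2:]

    if size1 != size2:                          # torch.Size compares like a tuple, so this works
        print(f"would resize paste image from {tuple(size2)} to {tuple(size1)}")
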
= (features.BoundingBox, features.Mask), - ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]: + ) -> Tuple[int, features.ImageType]: sample_flat, _ = tree_flatten(sample) images = [] for id, inpt in enumerate(sample_flat): @@ -59,12 +59,12 @@ def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any: def _apply_image_transform( self, - image: Union[torch.Tensor, PIL.Image.Image, features.Image], + image: features.ImageType, transform_id: str, magnitude: float, interpolation: InterpolationMode, - fill: Union[Dict[Type, FillType], Dict[Type, None]], - ) -> Any: + fill: Dict[Type, features.FillType], + ) -> features.ImageType: fill_ = fill[type(image)] fill_ = F._geometry._convert_fill_arg(fill_) @@ -177,7 +177,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.policy = policy @@ -337,7 +337,7 @@ def __init__( magnitude: int = 9, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self.num_ops = num_ops @@ -393,7 +393,7 @@ def __init__( self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ): super().__init__(interpolation=interpolation, fill=fill) self.num_magnitude_bins = num_magnitude_bins @@ -453,7 +453,7 @@ def __init__( alpha: float = 1.0, all_ops: bool = True, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[FillType, Dict[Type, FillType]]] = None, + fill: Union[features.FillType, Dict[Type, features.FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) self._PARAMETER_MAX = 10 diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 19f90c7c984..e0ee8d1b96a 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -1,5 +1,5 @@ import collections.abc -from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union import PIL.Image import torch @@ -9,8 +9,6 @@ from ._transform import _RandomApplyTransform from ._utils import query_chw -T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image) - class ColorJitter(Transform): def __init__( @@ -112,7 +110,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, ) - def _permute_channels(self, inpt: Any, permutation: torch.Tensor) -> Any: + def _permute_channels(self, inpt: features.ImageType, permutation: torch.Tensor) -> features.ImageType: if isinstance(inpt, PIL.Image.Image): inpt = F.pil_to_tensor(inpt) @@ -125,9 +123,7 @@ def _permute_channels(self, inpt: Any, permutation: torch.Tensor) -> Any: return output - def _transform( - self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any] - ) -> Union[torch.Tensor, features.Image, 
PIL.Image.Image]: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: if params["brightness"]: inpt = F.adjust_brightness( inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 9b28551a048..a9341415c1a 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -11,7 +11,7 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import DType, query_chw +from ._utils import query_chw class ToTensor(Transform): @@ -52,7 +52,7 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: super().__init__() self.num_output_channels = num_output_channels - def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) if isinstance(inpt, features.Image): output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY) @@ -81,7 +81,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: num_input_channels, _, _ = query_chw(sample) return dict(num_input_channels=num_input_channels) - def _transform(self, inpt: DType, params: Dict[str, Any]) -> DType: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) if isinstance(inpt, features.Image): output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index a8f0b09765b..babcb83af04 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -20,8 +20,6 @@ _setup_angle, _setup_fill_arg, _setup_size, - DType, - FillType, has_all, has_any, query_bounding_box, @@ -179,7 +177,9 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def _transform(self, inpt: DType, params: Dict[str, Any]) -> Tuple[DType, DType, DType, DType, DType]: + def _transform( + self, inpt: features.ImageType, params: Dict[str, Any] + ) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]: return F.five_crop(inpt, self.size) def forward(self, *inputs: Any) -> Any: @@ -200,7 +200,7 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.vertical_flip = vertical_flip - def _transform(self, inpt: DType, params: Dict[str, Any]) -> List[DType]: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> List[features.ImageType]: return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) def forward(self, *inputs: Any) -> Any: @@ -213,7 +213,7 @@ class Pad(Transform): def __init__( self, padding: Union[int, Sequence[int]], - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: 
super().__init__() @@ -240,7 +240,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomZoomOut(_RandomApplyTransform): def __init__( self, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, side_range: Sequence[float] = (1.0, 4.0), p: float = 0.5, ) -> None: @@ -282,7 +282,7 @@ def __init__( degrees: Union[numbers.Number, Sequence], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -322,7 +322,7 @@ def __init__( scale: Optional[Sequence[float]] = None, shear: Optional[Union[float, Sequence[float]]] = None, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -401,7 +401,7 @@ def __init__( size: Union[int, Sequence[int]], padding: Optional[Union[int, Sequence[int]]] = None, pad_if_needed: bool = False, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -491,7 +491,7 @@ class RandomPerspective(_RandomApplyTransform): def __init__( self, distortion_scale: float = 0.5, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, interpolation: InterpolationMode = InterpolationMode.BILINEAR, p: float = 0.5, ) -> None: @@ -567,7 +567,7 @@ def __init__( self, alpha: Union[float, Sequence[float]] = 50.0, sigma: Union[float, Sequence[float]] = 5.0, - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() @@ -780,7 +780,7 @@ class FixedSizeCrop(Transform): def __init__( self, size: Union[int, Sequence[int]], - fill: Union[FillType, Dict[Type, FillType]] = 0, + fill: Union[features.FillType, Dict[Type, features.FillType]] = 0, padding_mode: str = "constant", ) -> None: super().__init__() diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index d4dc87cf6f8..2ea3014aa6c 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -28,9 +28,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: super().__init__() self.dtype = dtype - def _transform( - self, inpt: Union[torch.Tensor, features.Image], params: Dict[str, Any] - ) -> Union[torch.Tensor, features.Image]: + def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType: output = F.convert_image_dtype(inpt, dtype=self.dtype) return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype) # type: ignore[arg-type] @@ -56,9 +54,7 @@ def __init__( self.copy = copy - def _transform( - self, inpt: Union[torch.Tensor, PIL.Image.Image, features._Feature], params: Dict[str, Any] - ) -> Union[torch.Tensor, PIL.Image.Image, features._Feature]: + def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType: return 
F.convert_color_space( inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy ) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 82e08259662..db93378312f 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -68,7 +68,7 @@ def forward(self, *inputs: Any) -> Any: return super().forward(*inputs) - def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor: + def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor: # Image instance after linear transformation is not Image anymore due to unknown data range # Thus we will return Tensor for input Image @@ -101,7 +101,7 @@ def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = self.std = list(std) self.inplace = inplace - def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor: + def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> torch.Tensor: return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) def forward(self, *inpts: Any) -> Any: diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 09829629e03..7107b14b3e0 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,14 +1,14 @@ import numbers from collections import defaultdict -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union import PIL.Image -import torch from torch.utils._pytree import tree_flatten from torchvision._utils import sequence_to_str from torchvision.prototype import features +from torchvision.prototype.features._feature import FillType from torchvision.prototype.transforms.functional._meta import get_chw from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 @@ -16,12 +16,7 @@ from typing_extensions import Literal -# Type shortcuts: -DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] -FillType = Union[int, float, Sequence[int], Sequence[float]] - - -def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None: +def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: if isinstance(fill, dict): for key, value in fill.items(): # Check key for type @@ -31,15 +26,13 @@ def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> No raise TypeError("Got inappropriate fill arg") -def _setup_fill_arg( - fill: Optional[Union[FillType, Dict[Type, FillType]]] -) -> Union[Dict[Type, FillType], Dict[Type, None]]: +def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: _check_fill_arg(fill) if isinstance(fill, dict): return fill - return defaultdict(lambda: fill) # type: ignore[return-value] + return defaultdict(lambda: fill) # type: ignore[return-value, arg-type] def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index cb79ef98f0a..7a7780706d8 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -5,7 +5,6 @@ from torchvision.transforms import functional_tensor as 
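
With Optional gone from the signatures (None now lives inside features.FillType), _setup_fill_arg always hands back a Dict[Type, FillType]: a user-supplied dict passes through unchanged, and anything else is wrapped in a defaultdict so that lookups such as fill[type(image)] in the AutoAugment code always resolve. A sketch of that behavior, using builtin types as stand-ins for the feature classes:

    from collections import defaultdict
    from typing import Any, Dict

    def setup_fill_arg(fill: Any) -> Dict[Any, Any]:
        # simplified stand-in for prototype.transforms._utils._setup_fill_arg
        if isinstance(fill, dict):
            return fill
        return defaultdict(lambda: fill)

    default_fill = setup_fill_arg(0)
    print(default_fill[int], default_fill[str])   # 0 0: every type key resolves to the scalar

    per_type = setup_fill_arg({int: 255, str: 0})
    print(per_type[int])                          # 255: explicit mappings pass through as-is
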
_FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image - erase_image_tensor = _FT.erase @@ -19,14 +18,14 @@ def erase_image_pil( def erase( - inpt: features.ImageType, + inpt: features.ImageTypeJIT, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False, -) -> features.ImageType: +) -> features.ImageTypeJIT: if isinstance(inpt, torch.Tensor): output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index d3257dc76a0..f375cb048c6 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,12 +2,11 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT - adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness -def adjust_brightness(inpt: features.DType, brightness_factor: float) -> features.DType: +def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) elif isinstance(inpt, features._Feature): @@ -20,7 +19,7 @@ def adjust_brightness(inpt: features.DType, brightness_factor: float) -> feature adjust_saturation_image_pil = _FP.adjust_saturation -def adjust_saturation(inpt: features.DType, saturation_factor: float) -> features.DType: +def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) elif isinstance(inpt, features._Feature): @@ -33,7 +32,7 @@ def adjust_saturation(inpt: features.DType, saturation_factor: float) -> feature adjust_contrast_image_pil = _FP.adjust_contrast -def adjust_contrast(inpt: features.DType, contrast_factor: float) -> features.DType: +def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) elif isinstance(inpt, features._Feature): @@ -46,7 +45,7 @@ def adjust_contrast(inpt: features.DType, contrast_factor: float) -> features.DT adjust_sharpness_image_pil = _FP.adjust_sharpness -def adjust_sharpness(inpt: features.DType, sharpness_factor: float) -> features.DType: +def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) elif isinstance(inpt, features._Feature): @@ -59,7 +58,7 @@ def adjust_sharpness(inpt: features.DType, sharpness_factor: float) -> features. 
adjust_hue_image_pil = _FP.adjust_hue -def adjust_hue(inpt: features.DType, hue_factor: float) -> features.DType: +def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) elif isinstance(inpt, features._Feature): @@ -72,7 +71,7 @@ def adjust_hue(inpt: features.DType, hue_factor: float) -> features.DType: adjust_gamma_image_pil = _FP.adjust_gamma -def adjust_gamma(inpt: features.DType, gamma: float, gain: float = 1) -> features.DType: +def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) elif isinstance(inpt, features._Feature): @@ -85,7 +84,7 @@ def adjust_gamma(inpt: features.DType, gamma: float, gain: float = 1) -> feature posterize_image_pil = _FP.posterize -def posterize(inpt: features.DType, bits: int) -> features.DType: +def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return posterize_image_tensor(inpt, bits=bits) elif isinstance(inpt, features._Feature): @@ -98,7 +97,7 @@ def posterize(inpt: features.DType, bits: int) -> features.DType: solarize_image_pil = _FP.solarize -def solarize(inpt: features.DType, threshold: float) -> features.DType: +def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return solarize_image_tensor(inpt, threshold=threshold) elif isinstance(inpt, features._Feature): @@ -111,7 +110,7 @@ def solarize(inpt: features.DType, threshold: float) -> features.DType: autocontrast_image_pil = _FP.autocontrast -def autocontrast(inpt: features.DType) -> features.DType: +def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return autocontrast_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -124,7 +123,7 @@ def autocontrast(inpt: features.DType) -> features.DType: equalize_image_pil = _FP.equalize -def equalize(inpt: features.DType) -> features.DType: +def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return equalize_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -137,7 +136,7 @@ def equalize(inpt: features.DType) -> features.DType: invert_image_pil = _FP.invert -def invert(inpt: features.DType) -> features.DType: +def invert(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return invert_image_tensor(inpt) elif isinstance(inpt, features._Feature): diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index 9c8dcff5a86..cbdea5130ef 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -8,11 +8,6 @@ from 
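
Every color op in functional/_color.py now advertises features.InputTypeJIT and shares the same three-way dispatch. Reduced to a self-contained sketch (the class, the kernel math, and the names are placeholders, not the torchvision implementations):

    import PIL.Image
    import torch
    from torchvision.transforms.functional import pil_to_tensor, to_pil_image

    class Feature(torch.Tensor):
        # stand-in for features._Feature: a Tensor subclass whose method re-dispatches to the
        # tensor kernel; the real classes also re-wrap the output to keep their metadata
        def adjust_brightness(self, brightness_factor: float) -> "Feature":
            out = adjust_brightness_image_tensor(self.as_subclass(torch.Tensor), brightness_factor)
            return out.as_subclass(Feature)

    def adjust_brightness_image_tensor(img: torch.Tensor, brightness_factor: float) -> torch.Tensor:
        return (img.float() * brightness_factor).clamp(0, 255).to(img.dtype)   # placeholder math

    def adjust_brightness(inpt, brightness_factor: float):
        # 1) plain tensors, and everything while scripting, hit the tensor kernel directly
        if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, Feature)):
            return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
        # 2) feature subclasses delegate to their own method
        elif isinstance(inpt, Feature):
            return inpt.adjust_brightness(brightness_factor=brightness_factor)
        # 3) anything else is assumed to be a PIL image
        else:
            return to_pil_image(adjust_brightness_image_tensor(pil_to_tensor(inpt), brightness_factor))
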
torchvision.transforms import functional as _F -# Due to torch.jit.script limitation we keep LegacyImageType as torch.Tensor -# instead of Union[torch.Tensor, PIL.Image.Image] -LegacyImageType = torch.Tensor - - @torch.jit.unused def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: call = ", num_output_channels=3" if num_output_channels == 3 else "" @@ -27,7 +22,7 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima return _F.to_grayscale(inpt, num_output_channels=num_output_channels) -def rgb_to_grayscale(inpt: LegacyImageType, num_output_channels: int = 1) -> LegacyImageType: +def rgb_to_grayscale(inpt: features.LegacyImageTypeJIT, num_output_channels: int = 1) -> features.LegacyImageTypeJIT: old_color_space = ( features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)) @@ -61,7 +56,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: return _F.to_tensor(inpt) -def get_image_size(inpt: features.ImageType) -> List[int]: +def get_image_size(inpt: features.ImageTypeJIT) -> List[int]: warnings.warn( "The function `get_image_size(...)` is deprecated and will be removed in a future release. " "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`." diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 666e4105992..87b65868bf9 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -18,7 +18,6 @@ from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor - horizontal_flip_image_tensor = _FT.hflip horizontal_flip_image_pil = _FP.hflip @@ -43,7 +42,7 @@ def horizontal_flip_bounding_box( ).view(shape) -def horizontal_flip(inpt: features.DType) -> features.DType: +def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return horizontal_flip_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -76,7 +75,7 @@ def vertical_flip_bounding_box( ).view(shape) -def vertical_flip(inpt: features.DType) -> features.DType: +def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return vertical_flip_image_tensor(inpt) elif isinstance(inpt, features._Feature): @@ -153,12 +152,12 @@ def resize_bounding_box( def resize( - inpt: features.DType, + inpt: features.InputTypeJIT, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = None, -) -> features.DType: +) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): antialias = False if antialias is None else antialias return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) @@ -228,7 +227,7 @@ def affine_image_tensor( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if img.numel() == 0: @@ 
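
The deprecation message retouched above is about ordering, not behavior: get_image_size reports [width, height] while its replacement get_spatial_size reports [height, width], matching the tensor layout. A plain-tensor illustration (the prototype functions additionally dispatch on PIL images and features.Image):

    import torch

    image = torch.rand(3, 480, 640)        # CHW, so height=480, width=640

    height, width = image.shape[-2:]
    print([width, height])                 # [640, 480]  what the deprecated get_image_size returns
    print([height, width])                 # [480, 640]  what get_spatial_size returns
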
-260,7 +259,7 @@ def affine_image_pil( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> PIL.Image.Image: angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) @@ -378,7 +377,7 @@ def affine_mask( translate: List[float], scale: float, shear: List[float], - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if mask.ndim < 3: @@ -404,9 +403,7 @@ def affine_mask( return output -def _convert_fill_arg( - fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] -) -> Optional[Union[int, float, List[float]]]: +def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT: # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # So, we can't reassign fill to 0 # if fill is None: @@ -421,15 +418,15 @@ def _convert_fill_arg( def affine( - inpt: features.DType, + inpt: features.InputTypeJIT, angle: float, translate: List[float], scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, -) -> features.DType: +) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return affine_image_tensor( inpt, @@ -463,7 +460,7 @@ def rotate_image_tensor( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: num_channels, height, width = img.shape[-3:] @@ -502,7 +499,7 @@ def rotate_image_pil( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> PIL.Image.Image: if center is not None and expand: @@ -542,7 +539,7 @@ def rotate_mask( mask: torch.Tensor, angle: float, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if mask.ndim < 3: @@ -567,13 +564,13 @@ def rotate_mask( def rotate( - inpt: features.DType, + inpt: features.InputTypeJIT, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, center: Optional[List[float]] = None, -) -> features.DType: +) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) elif isinstance(inpt, features._Feature): @@ -588,7 +585,7 @@ def rotate( def pad_image_tensor( img: torch.Tensor, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, padding_mode: str = "constant", ) -> torch.Tensor: if fill is None: @@ -652,7 +649,7 @@ def pad_mask( mask: torch.Tensor, padding: Union[int, 
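
_convert_fill_arg is the bridge between the two fill aliases: it takes the user-facing features.FillType (scalar, any sequence, or None) and returns something the scripted kernels accept (features.FillTypeJIT, i.e. a scalar or List[float]), deliberately keeping None as-is because fill=0 and fill=None are not equivalent (pytorch/vision#6517). The function body is not part of the hunk above, so the following is only an approximation of what that conversion amounts to:

    from typing import List, Sequence, Union

    FillType = Union[int, float, Sequence[int], Sequence[float], None]
    FillTypeJIT = Union[int, float, List[float], None]

    def convert_fill_arg(fill: FillType) -> FillTypeJIT:
        if fill is None:
            return None                        # preserved on purpose: None and 0 behave differently downstream
        if not isinstance(fill, (int, float)):
            fill = [float(v) for v in fill]    # Sequence[int] / Sequence[float] -> List[float]
        return fill

    print(convert_fill_arg((255, 255, 255)))   # [255.0, 255.0, 255.0]
    print(convert_fill_arg(None))              # None
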
List[int]], padding_mode: str = "constant", - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, ) -> torch.Tensor: if fill is None: fill = 0 @@ -698,11 +695,11 @@ def pad_bounding_box( def pad( - inpt: features.DType, + inpt: features.InputTypeJIT, padding: Union[int, List[int]], - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, padding_mode: str = "constant", -) -> features.DType: +) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) @@ -739,7 +736,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) return crop_image_tensor(mask, top, left, height, width) -def crop(inpt: features.DType, top: int, left: int, height: int, width: int) -> features.DType: +def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: int) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return crop_image_tensor(inpt, top, left, height, width) elif isinstance(inpt, features._Feature): @@ -752,7 +749,7 @@ def perspective_image_tensor( img: torch.Tensor, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, ) -> torch.Tensor: return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill) @@ -762,7 +759,7 @@ def perspective_image_pil( img: PIL.Image.Image, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BICUBIC, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, ) -> PIL.Image.Image: return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) @@ -855,7 +852,7 @@ def perspective_bounding_box( def perspective_mask( mask: torch.Tensor, perspective_coeffs: List[float], - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, ) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -874,11 +871,11 @@ def perspective_mask( def perspective( - inpt: features.DType, + inpt: features.InputTypeJIT, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, -) -> features.DType: + fill: features.FillTypeJIT = None, +) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) elif isinstance(inpt, features._Feature): @@ -891,7 +888,7 @@ def elastic_image_tensor( img: torch.Tensor, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, ) -> torch.Tensor: return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) @@ -901,7 +898,7 @@ def elastic_image_pil( img: PIL.Image.Image, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, ) -> PIL.Image.Image: t_img = 
pil_to_tensor(img) output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) @@ -951,7 +948,7 @@ def elastic_bounding_box( def elastic_mask( mask: torch.Tensor, displacement: torch.Tensor, - fill: Optional[Union[int, float, List[float]]] = None, + fill: features.FillTypeJIT = None, ) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -968,11 +965,11 @@ def elastic_mask( def elastic( - inpt: features.DType, + inpt: features.InputTypeJIT, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[Union[int, float, List[float]]] = None, -) -> features.DType: + fill: features.FillTypeJIT = None, +) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) elif isinstance(inpt, features._Feature): @@ -1069,7 +1066,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor return output -def center_crop(inpt: features.DType, output_size: List[int]) -> features.DType: +def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return center_crop_image_tensor(inpt, output_size) elif isinstance(inpt, features._Feature): @@ -1132,7 +1129,7 @@ def resized_crop_mask( def resized_crop( - inpt: features.DType, + inpt: features.InputTypeJIT, top: int, left: int, height: int, @@ -1140,7 +1137,7 @@ def resized_crop( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[bool] = None, -) -> features.DType: +) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): antialias = False if antialias is None else antialias return resized_crop_image_tensor( @@ -1205,9 +1202,11 @@ def five_crop_image_pil( def five_crop( - inpt: features.ImageType, size: List[int] -) -> Tuple[features.ImageType, features.ImageType, features.ImageType, features.ImageType, features.ImageType]: - # TODO: consider breaking BC here to return List[features.ImageType] to align this op with `ten_crop` + inpt: features.ImageTypeJIT, size: List[int] +) -> Tuple[ + features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT, features.ImageTypeJIT +]: + # TODO: consider breaking BC here to return List[features.ImageTypeJIT] to align this op with `ten_crop` if isinstance(inpt, torch.Tensor): output = five_crop_image_tensor(inpt, size) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): @@ -1244,7 +1243,7 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip] -def ten_crop(inpt: features.ImageType, size: List[int], vertical_flip: bool = False) -> List[features.ImageType]: +def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool = False) -> List[features.ImageTypeJIT]: if isinstance(inpt, torch.Tensor): output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) if not torch.jit.is_scripting() and isinstance(inpt, features.Image): diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index b89ea09b726..90cfffcf276 100644 --- 
a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -11,7 +11,7 @@ # TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init? -def get_chw(image: features.ImageType) -> Tuple[int, int, int]: +def get_chw(image: features.ImageTypeJIT) -> Tuple[int, int, int]: if isinstance(image, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(image, features.Image)): channels, height, width = get_dimensions_image_tensor(image) elif isinstance(image, features.Image): @@ -29,11 +29,11 @@ def get_chw(image: features.ImageType) -> Tuple[int, int, int]: # detailed above. -def get_dimensions(image: features.ImageType) -> List[int]: +def get_dimensions(image: features.ImageTypeJIT) -> List[int]: return list(get_chw(image)) -def get_num_channels(image: features.ImageType) -> int: +def get_num_channels(image: features.ImageTypeJIT) -> int: num_channels, *_ = get_chw(image) return num_channels @@ -43,7 +43,7 @@ def get_num_channels(image: features.ImageType) -> int: get_image_num_channels = get_num_channels -def get_spatial_size(image: features.ImageType) -> List[int]: +def get_spatial_size(image: features.ImageTypeJIT) -> List[int]: _, *size = get_chw(image) return size @@ -208,11 +208,11 @@ def convert_color_space_image_pil( def convert_color_space( - inpt: features.ImageType, + inpt: features.ImageTypeJIT, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True, -) -> features.ImageType: +) -> features.ImageTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Image)): if old_color_space is None: raise RuntimeError( @@ -225,4 +225,4 @@ def convert_color_space( elif isinstance(inpt, features.Image): return inpt.to_color_space(color_space, copy=copy) else: - return cast(features.ImageType, convert_color_space_image_pil(inpt, color_space, copy=copy)) + return cast(features.ImageTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy)) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 4c8c4697d67..952dc0d9e0d 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -6,16 +6,12 @@ from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image - -# Due to torch.jit.script limitation we keep TensorImageType as torch.Tensor -# instead of Union[torch.Tensor, features.Image] -TensorImageType = torch.Tensor - - normalize_image_tensor = _FT.normalize -def normalize(inpt: TensorImageType, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: +def normalize( + inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False +) -> torch.Tensor: if not isinstance(inpt, torch.Tensor): raise TypeError(f"img should be Tensor Image. 
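
The helpers in functional/_meta.py are all thin wrappers around get_chw; each one keeps a different part of the (channels, height, width) triple. For a plain tuple the whole family reduces to star-unpacking, as in this toy illustration (not the torchvision functions themselves, which also dispatch on PIL images and features.Image):

    chw = (3, 480, 640)

    num_channels, *_ = chw        # get_num_channels
    _, *size = chw                # get_spatial_size -> [height, width]

    print(num_channels)           # 3
    print(size)                   # [480, 640]
    print(list(chw))              # [3, 480, 640], i.e. get_dimensions
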
Got {type(inpt)}") else: @@ -62,7 +58,9 @@ def gaussian_blur_image_pil( return to_pil_image(output, mode=img.mode) -def gaussian_blur(inpt: features.DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> features.DType: +def gaussian_blur( + inpt: features.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> features.InputTypeJIT: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) elif isinstance(inpt, features._Feature):