diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 5b0693a2e78..b6085bb1c71 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -71,6 +71,7 @@ class TestSmoke:
         transforms.CenterCrop([16, 16]),
         transforms.ConvertImageDtype(),
         transforms.RandomHorizontalFlip(),
+        transforms.Pad(5),
     )
     def test_common(self, transform, input):
         transform(input)
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 0a3d23db3bd..1cb53f2a885 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -15,6 +15,7 @@
     TenCrop,
     BatchMultiCrop,
     RandomHorizontalFlip,
+    Pad,
     RandomZoomOut,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 19e5ced791e..061f12cd446 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -1,5 +1,6 @@
 import collections.abc
 import math
+import numbers
 import warnings
 from typing import Any, Dict, List, Union, Sequence, Tuple, cast
 
@@ -9,6 +10,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.functional import pil_to_tensor
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
+from typing_extensions import Literal
 
 from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
 
@@ -272,42 +274,31 @@ def apply_recursively(obj: Any) -> Any:
         return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
 
 
-class RandomZoomOut(Transform):
+class Pad(Transform):
     def __init__(
-        self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
+        self,
+        padding: Union[int, Sequence[int]],
+        fill: Union[float, Sequence[float]] = 0.0,
+        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
+        if not isinstance(padding, (numbers.Number, tuple, list)):
+            raise TypeError("Got inappropriate padding arg")
 
-        if fill is None:
-            fill = 0.0
-        self.fill = fill
-
-        self.side_range = side_range
-        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
-            raise ValueError(f"Invalid canvas side range provided {side_range}.")
-
-        self.p = p
-
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        orig_c, orig_h, orig_w = get_image_dimensions(image)
-
-        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
-        canvas_width = int(orig_w * r)
-        canvas_height = int(orig_h * r)
+        if not isinstance(fill, (numbers.Number, str, tuple, list)):
+            raise TypeError("Got inappropriate fill arg")
 
-        r = torch.rand(2)
-        left = int((canvas_width - orig_w) * r[0])
-        top = int((canvas_height - orig_h) * r[1])
-        right = canvas_width - (left + orig_w)
-        bottom = canvas_height - (top + orig_h)
-        padding = [left, top, right, bottom]
+        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
 
-        fill = self.fill
-        if not isinstance(fill, collections.abc.Sequence):
-            fill = [fill] * orig_c
+        if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
+            raise ValueError(
+                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
+            )
 
-        return dict(padding=padding, fill=fill)
+        self.padding = padding
+        self.fill = fill
+        self.padding_mode = padding_mode
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         if isinstance(input, features.Image) or is_simple_tensor(input):
@@ -349,6 +340,48 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         else:
             return input
 
+
+class RandomZoomOut(Transform):
+    def __init__(
+        self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
+    ) -> None:
+        super().__init__()
+
+        if fill is None:
+            fill = 0.0
+        self.fill = fill
+
+        self.side_range = side_range
+        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
+            raise ValueError(f"Invalid canvas side range provided {side_range}.")
+
+        self.p = p
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        orig_c, orig_h, orig_w = get_image_dimensions(image)
+
+        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
+        canvas_width = int(orig_w * r)
+        canvas_height = int(orig_h * r)
+
+        r = torch.rand(2)
+        left = int((canvas_width - orig_w) * r[0])
+        top = int((canvas_height - orig_h) * r[1])
+        right = canvas_width - (left + orig_w)
+        bottom = canvas_height - (top + orig_h)
+        padding = [left, top, right, bottom]
+
+        fill = self.fill
+        if not isinstance(fill, collections.abc.Sequence):
+            fill = [fill] * orig_c
+
+        return dict(padding=padding, fill=fill)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        transform = Pad(**params, padding_mode="constant")
+        return transform(input)
+
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         if torch.rand(1) >= self.p:
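For context, a minimal usage sketch of the new Pad transform and the refactored RandomZoomOut. The import path follows the prototype namespace this diff touches; the tensor shape and parameter values are illustrative, not taken from the patch:

import torch
from torchvision.prototype import transforms

# Illustrative uint8 image tensor in (C, H, W) layout.
image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# Pad every side by 5 pixels with the default constant fill,
# mirroring the smoke-test entry transforms.Pad(5) added above.
padded = transforms.Pad(5)(image)

# A 4-element padding sequence is (left, top, right, bottom); "edge"
# replicates border pixels instead of filling with a constant.
padded_ltrb = transforms.Pad([1, 2, 3, 4], padding_mode="edge")(image)

# RandomZoomOut now samples a larger canvas from side_range, computes the
# per-side padding in _get_params, and delegates the actual padding to
# Pad(**params, padding_mode="constant") in _transform.
zoom_out = transforms.RandomZoomOut(fill=0.0, side_range=(1.0, 4.0), p=0.5)
zoomed = zoom_out(image)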