diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 3bdf0cfe34e..469ae5370f3 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -946,9 +946,9 @@ def test_adjust_gamma(device, dtype, config, channels):
 @pytest.mark.parametrize(
     "config",
     [
-        {"padding_mode": "constant", "fill": 0},
-        {"padding_mode": "constant", "fill": 10},
-        {"padding_mode": "constant", "fill": 20},
+        {"padding_mode": "constant", "fill": 10.0},
+        {"padding_mode": "constant", "fill": [10.0, 10.0, 10.0]},
+        {"padding_mode": "constant", "fill": [10.0, 0.0, 10.0]},
         {"padding_mode": "edge"},
         {"padding_mode": "reflect"},
         {"padding_mode": "symmetric"},
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 165c23dbdb8..9abad41627f 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -195,11 +195,22 @@ def test_color_jitter_all(self, device, channels):
     )
 
 
+@pytest.mark.parametrize(
+    "config",
+    [
+        {"padding_mode": "constant", "fill": 10.0},
+        {"padding_mode": "constant", "fill": [10.0, 10.0, 10.0]},
+        {"padding_mode": "constant", "fill": [10.0, 0.0, 10.0]},
+        {"padding_mode": "edge"},
+        {"padding_mode": "reflect"},
+        {"padding_mode": "symmetric"},
+    ],
+)
 @pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"])
 @pytest.mark.parametrize("mul", [1, -1])
-def test_pad(m, mul, device):
-    fill = 127 if m == "constant" else 0
+def test_pad(config, mul, device):
+    m = config["padding_mode"]
+    fill = config.get("fill", 0.0)
     # Test functional.pad (PIL and Tensor) with padding as single int
     _test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device)
 
@@ -252,9 +263,9 @@ def test_crop(device):
 @pytest.mark.parametrize(
     "padding_config",
     [
-        {"padding_mode": "constant", "fill": 0},
-        {"padding_mode": "constant", "fill": 10},
-        {"padding_mode": "constant", "fill": 20},
+        {"padding_mode": "constant", "fill": 10.0},
+        {"padding_mode": "constant", "fill": [10.0, 10.0, 10.0]},
+        {"padding_mode": "constant", "fill": [10.0, 0.0, 10.0]},
         {"padding_mode": "edge"},
         {"padding_mode": "reflect"},
     ],
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 6c9309749af..58aa50e9422 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -264,7 +264,7 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        img = pad_image_tensor(img, padding_ltrb, fill=0)
+        img = pad_image_tensor(img, padding_ltrb)
 
         _, image_height, image_width = get_dimensions_image_tensor(img)
         if crop_width == image_width and crop_height == image_height:
@@ -280,7 +280,7 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.I
 
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        img = pad_image_pil(img, padding_ltrb, fill=0)
+        img = pad_image_pil(img, padding_ltrb)
 
         _, image_height, image_width = get_dimensions_image_pil(img)
         if crop_width == image_width and crop_height == image_height:
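The prototype `_geometry.py` hunks above drop the now-redundant explicit `fill=0`, since `pad` defaults to `0.0` after this change. The updated test configs exercise the new per-channel constant fill. A minimal sketch of what they now cover, assuming this patch is applied (the image shape and values are illustrative, not taken from the test suite):

```python
import torch
import torchvision.transforms.functional as F

img = torch.zeros(3, 4, 4, dtype=torch.uint8)  # C, H, W

# Scalar fill: every padded pixel in every channel becomes 10.
out = F.pad(img, padding=2, fill=10.0, padding_mode="constant")

# Per-channel fill: the R and B borders become 10, the G border stays 0.
out = F.pad(img, padding=2, fill=[10.0, 0.0, 10.0], padding_mode="constant")
assert out.shape == (3, 8, 8)
assert out[0, 0, 0] == 10 and out[1, 0, 0] == 0 and out[2, 0, 0] == 10
```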
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 5b762ff2975..30efaa7d867 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -2,7 +2,7 @@
 import numbers
 import warnings
 from enum import Enum
-from typing import List, Tuple, Any, Optional
+from typing import List, Tuple, Any, Optional, Union
 
 import numpy as np
 import torch
@@ -431,7 +431,9 @@ def resize(
     return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
 
 
-def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
+def pad(
+    img: Tensor, padding: List[int], fill: Union[List[float], float] = 0.0, padding_mode: str = "constant"
+) -> Tensor:
     r"""Pad the given image on all sides with the given "pad" value.
     If the image is torch Tensor, it is expected to have [..., H, W] shape,
     where ... means at most 2 leading dimensions for mode reflect and symmetric,
@@ -451,7 +453,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
         fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
             If a tuple of length 3, it is used to fill R, G, B channels respectively.
             This value is only used when the padding_mode is constant.
-            Only number is supported for torch Tensor.
+            Only number or tuple is supported for torch Tensor.
             Only int or str or tuple value is supported for PIL Image.
         padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
             Default is constant.
@@ -536,7 +538,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
         (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
         (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
     ]
-    img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
+    img = pad(img, padding_ltrb, fill=0.0)  # PIL uses fill value 0
     _, image_height, image_width = get_dimensions(img)
     if crop_width == image_width and crop_height == image_height:
         return img
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 261c4000bac..e35187a9ddb 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -154,7 +154,7 @@ def pad(
 
     if not isinstance(padding, (numbers.Number, tuple, list)):
         raise TypeError("Got inappropriate padding arg")
-    if not isinstance(fill, (numbers.Number, str, tuple)):
+    if not isinstance(fill, (numbers.Number, str, list, tuple)):
         raise TypeError("Got inappropriate fill arg")
     if not isinstance(padding_mode, str):
         raise TypeError("Got inappropriate padding_mode arg")
@@ -291,7 +291,7 @@ def _parse_fill(
     # Process fill color for affine transforms
    num_bands = len(img.getbands())
     if fill is None:
-        fill = 0
+        fill = 0.0
     if isinstance(fill, (int, float)) and num_bands > 1:
         fill = tuple([fill] * num_bands)
     if isinstance(fill, (list, tuple)):
@@ -301,6 +301,12 @@ def _parse_fill(
 
         fill = tuple(fill)
 
+    if img.mode != "F":
+        if isinstance(fill, (list, tuple)):
+            fill = tuple(int(x) for x in fill)
+        else:
+            fill = int(fill)
+
     return {name: fill}
 
 
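`_parse_fill` feeds the affine-style PIL transforms (rotate, affine, perspective). With the new `img.mode != "F"` branch, float fills are cast to `int` for integer-backed modes, while mode `"F"` (32-bit float pixels) keeps the float value. A hedged sketch of the resulting behavior, using `F.rotate` as one convenient entry point into `_parse_fill`:

```python
import torchvision.transforms.functional as F
from PIL import Image

rgb = Image.new("RGB", (8, 8))
# Integer-backed mode: the scalar fill is expanded to one value per band
# and cast to int, i.e. 10.9 -> (10, 10, 10).
rotated = F.rotate(rgb, angle=45.0, fill=10.9)

flt = Image.new("F", (8, 8))
# Float mode "F": the fill stays 10.9.
rotated_f = F.rotate(flt, angle=45.0, fill=10.9)
```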
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 18b2c721f4e..eb1a4c11800 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Optional, Tuple, List
+from typing import Optional, Tuple, List, Union
 
 import torch
 from torch import Tensor
@@ -141,7 +141,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
 
     if left < 0 or top < 0 or right > w or bottom > h:
         padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)]
-        return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
+        return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0.0)
     return img[..., top:bottom, left:right]
 
 
@@ -353,13 +353,15 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
         raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
 
 
-def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
+def pad(
+    img: Tensor, padding: List[int], fill: Union[List[float], float] = 0.0, padding_mode: str = "constant"
+) -> Tensor:
     _assert_image_tensor(img)
 
     if not isinstance(padding, (int, tuple, list)):
         raise TypeError("Got inappropriate padding arg")
-    if not isinstance(fill, (int, float)):
-        raise TypeError("Got inappropriate fill arg")
+    channels, height, width = get_dimensions(img)
+    _assert_fill(fill, channels)
     if not isinstance(padding_mode, str):
         raise TypeError("Got inappropriate padding_mode arg")
 
@@ -411,7 +413,22 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
         need_cast = True
         img = img.to(torch.float32)
 
-    img = torch_pad(img, p, mode=padding_mode, value=float(fill))
+    img = torch_pad(img, p, mode=padding_mode)
+
+    if padding_mode == "constant":
+        # The following if/else can't be simplified due to JIT
+        if isinstance(fill, (tuple, list)):
+            fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
+        else:
+            fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device)
+        if pad_top > 0:
+            img[..., :pad_top, :] = fill_img
+        if pad_left > 0:
+            img[..., :, :pad_left] = fill_img
+        if pad_bottom > 0:
+            img[..., -pad_bottom:, :] = fill_img
+        if pad_right > 0:
+            img[..., :, -pad_right:] = fill_img
 
     if need_squeeze:
         img = img.squeeze(dim=0)
@@ -499,11 +516,28 @@ def resize(
     return img
 
 
+def _assert_fill(fill: Optional[Union[List[float], float]], num_channels: int):
+    if fill is None:
+        return
+    if not isinstance(fill, (int, float, tuple, list)):
+        warnings.warn("Argument fill should be either int, float, tuple or list")
+
+    # Check fill
+    if isinstance(fill, (tuple, list)):
+        length = len(fill)
+        if length > 1 and length != num_channels:
+            msg = (
+                "The number of elements in 'fill' cannot broadcast to match the number of "
+                "channels of the image ({} != {})"
+            )
+            raise ValueError(msg.format(length, num_channels))
+
+
 def _assert_grid_transform_inputs(
     img: Tensor,
     matrix: Optional[List[float]],
     interpolation: str,
-    fill: Optional[List[float]],
+    fill: Optional[Union[List[float], float]],
     supported_interpolation_modes: List[str],
     coeffs: Optional[List[float]] = None,
 ) -> None:
@@ -522,17 +556,8 @@ def _assert_grid_transform_inputs(
     if coeffs is not None and len(coeffs) != 8:
         raise ValueError("Argument coeffs should have 8 float values")
 
-    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
-        warnings.warn("Argument fill should be either int, float, tuple or list")
-
-    # Check fill
     num_channels = get_dimensions(img)[0]
-    if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
-        msg = (
-            "The number of elements in 'fill' cannot broadcast to match the number of "
-            "channels of the image ({} != {})"
-        )
-        raise ValueError(msg.format(len(fill), num_channels))
+    _assert_fill(fill, num_channels)
 
     if interpolation not in supported_interpolation_modes:
         raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
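The tensor `pad` no longer passes `value=` to `torch_pad`; it pads with zeros and then writes `fill` into the four border strips, which is what makes per-channel fills possible (the `view(1, -1, 1, 1)` reshape broadcasts one value per channel). A self-contained sketch of that technique, with a hypothetical helper name and paddings assumed non-negative:

```python
import torch
import torch.nn.functional as tnf

def constant_pad_per_channel(img, left, top, right, bottom, fill):
    # img is (N, C, H, W); paddings are assumed >= 0 in this sketch.
    out = tnf.pad(img, [left, right, top, bottom], mode="constant")
    fill_t = torch.as_tensor(fill, dtype=out.dtype, device=out.device)
    if fill_t.ndim == 1:  # one fill value per channel
        fill_t = fill_t.view(1, -1, 1, 1)
    if top > 0:
        out[..., :top, :] = fill_t
    if bottom > 0:
        out[..., -bottom:, :] = fill_t
    if left > 0:
        out[..., :, :left] = fill_t
    if right > 0:
        out[..., :, -right:] = fill_t
    return out

x = torch.zeros(1, 3, 2, 2)
y = constant_pad_per_channel(x, 1, 1, 1, 1, [10.0, 0.0, 10.0])
assert y[0, 0, 0, 0] == 10 and y[0, 1, 0, 0] == 0
```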
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 2324acdd592..836a737c535 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -422,13 +422,13 @@ class Pad(torch.nn.Module):
             will result in [2, 1, 1, 2, 3, 4, 4, 3]
     """
 
-    def __init__(self, padding, fill=0, padding_mode="constant"):
+    def __init__(self, padding, fill=0.0, padding_mode="constant"):
         super().__init__()
         _log_api_usage_once(self)
         if not isinstance(padding, (numbers.Number, tuple, list)):
             raise TypeError("Got inappropriate padding arg")
 
-        if not isinstance(fill, (numbers.Number, str, tuple)):
+        if not isinstance(fill, (numbers.Number, str, tuple, list)):
             raise TypeError("Got inappropriate fill arg")
 
         if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
@@ -641,7 +641,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int
         j = torch.randint(0, w - tw + 1, size=(1,)).item()
         return i, j, th, tw
 
-    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
+    def __init__(self, size, padding=None, pad_if_needed=False, fill=0.0, padding_mode="constant"):
         super().__init__()
         _log_api_usage_once(self)
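With the `transforms.py` changes, list fills also flow through the class APIs (`Pad` and `RandomCrop`). A brief usage sketch, assuming this patch is applied:

```python
import torch
from torchvision import transforms

# One constant fill value per channel: a red border around a CHW tensor image.
pad = transforms.Pad(padding=4, fill=[255.0, 0.0, 0.0], padding_mode="constant")
img = torch.zeros(3, 32, 32, dtype=torch.uint8)
out = pad(img)
assert out.shape == (3, 40, 40)
```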