Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,9 +946,9 @@ def test_adjust_gamma(device, dtype, config, channels):
@pytest.mark.parametrize(
"config",
[
{"padding_mode": "constant", "fill": 0},
{"padding_mode": "constant", "fill": 10},
{"padding_mode": "constant", "fill": 20},
{"padding_mode": "constant", "fill": 10.0},
{"padding_mode": "constant", "fill": [10.0, 10.0, 10.0]},
{"padding_mode": "constant", "fill": [10.0, 0.0, 10.0]},
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test won't pass if we provide integers. That's because the test conducts JIT-script checks as well.

Here we check for single values, lists with the same value and lists with different values.

{"padding_mode": "edge"},
{"padding_mode": "reflect"},
{"padding_mode": "symmetric"},
Expand Down
23 changes: 17 additions & 6 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,22 @@ def test_color_jitter_all(self, device, channels):
)


@pytest.mark.parametrize(
"config",
[
{"padding_mode": "constant", "fill": 10.0},
{"padding_mode": "constant", "fill": [10.0, 10.0, 10.0]},
{"padding_mode": "constant", "fill": [10.0, 0.0, 10.0]},
{"padding_mode": "edge"},
{"padding_mode": "reflect"},
{"padding_mode": "symmetric"},
],
)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"])
@pytest.mark.parametrize("mul", [1, -1])
def test_pad(m, mul, device):
fill = 127 if m == "constant" else 0
def test_pad(config, mul, device):
m = config["padding_mode"]
fill = config.get("fill", 0.0)

# Test functional.pad (PIL and Tensor) with padding as single int
_test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device)
Expand Down Expand Up @@ -252,9 +263,9 @@ def test_crop(device):
@pytest.mark.parametrize(
"padding_config",
[
{"padding_mode": "constant", "fill": 0},
{"padding_mode": "constant", "fill": 10},
{"padding_mode": "constant", "fill": 20},
{"padding_mode": "constant", "fill": 10.0},
{"padding_mode": "constant", "fill": [10.0, 10.0, 10.0]},
{"padding_mode": "constant", "fill": [10.0, 0.0, 10.0]},
{"padding_mode": "edge"},
{"padding_mode": "reflect"},
],
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
img = pad_image_tensor(img, padding_ltrb, fill=0)
img = pad_image_tensor(img, padding_ltrb)

_, image_height, image_width = get_dimensions_image_tensor(img)
if crop_width == image_width and crop_height == image_height:
Expand All @@ -280,7 +280,7 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.I

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
img = pad_image_pil(img, padding_ltrb, fill=0)
img = pad_image_pil(img, padding_ltrb)

_, image_height, image_width = get_dimensions_image_pil(img)
if crop_width == image_width and crop_height == image_height:
Expand Down
10 changes: 6 additions & 4 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numbers
import warnings
from enum import Enum
from typing import List, Tuple, Any, Optional
from typing import List, Tuple, Any, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -431,7 +431,9 @@ def resize(
return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)


def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
def pad(
img: Tensor, padding: List[int], fill: Union[List[float], float] = 0.0, padding_mode: str = "constant"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default value changed from int to float. JIT will fail if we pass integers.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where exactly we need float values ? Maybe we could keep ints and List[int] and cast to float where it is required ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The int used previously is very misleading. We mainly use floats because our tensors get rescaled as you know. Unfortunately adding both List[int] and List[float] in the union doesn't work due to JIT issues. See pytorch/pytorch#69434

) -> Tensor:
r"""Pad the given image on all sides with the given "pad" value.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
Expand All @@ -451,7 +453,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
If a tuple of length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
Only number is supported for torch Tensor.
Only number or tuple is supported for torch Tensor.
Only int or str or tuple value is supported for PIL Image.
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
Expand Down Expand Up @@ -536,7 +538,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0
img = pad(img, padding_ltrb, fill=0.0) # PIL uses fill value 0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again we need floats to appease JIT.

_, image_height, image_width = get_dimensions(img)
if crop_width == image_width and crop_height == image_height:
return img
Expand Down
10 changes: 8 additions & 2 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def pad(

if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple)):
if not isinstance(fill, (numbers.Number, str, list, tuple)):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated bug fix on the original code.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is a bug fix here and we do not expect this to land, maybe better to split this into a separate PR ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we can cherrypick afterwards if we don't land this.

raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
Expand Down Expand Up @@ -291,7 +291,7 @@ def _parse_fill(
# Process fill color for affine transforms
num_bands = len(img.getbands())
if fill is None:
fill = 0
fill = 0.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Floats to please JIT

if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands)
if isinstance(fill, (list, tuple)):
Expand All @@ -301,6 +301,12 @@ def _parse_fill(

fill = tuple(fill)

if img.mode != "F":
if isinstance(fill, (list, tuple)):
fill = tuple(int(x) for x in fill)
else:
fill = int(fill)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated bug fix on the original code. This method doesn't work if floats are provided for PIL images, despite the method having floats in the signature.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here ?


return {name: fill}


Expand Down
59 changes: 42 additions & 17 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -141,7 +141,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:

if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)]
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0.0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Floats to please JIT

return img[..., top:bottom, left:right]


Expand Down Expand Up @@ -353,13 +353,15 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")


def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
def pad(
img: Tensor, padding: List[int], fill: Union[List[float], float] = 0.0, padding_mode: str = "constant"
) -> Tensor:
_assert_image_tensor(img)

if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
raise TypeError("Got inappropriate fill arg")
channels, height, width = get_dimensions(img)
_assert_fill(fill, channels)
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")

Expand Down Expand Up @@ -411,7 +413,22 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
need_cast = True
img = img.to(torch.float32)

img = torch_pad(img, p, mode=padding_mode, value=float(fill))
img = torch_pad(img, p, mode=padding_mode)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if fill is a scalar now we still transform it to a tensor and apply to the image at most 4 times below (img[..., :, :pad_left] = fill_img). Maybe, for performance reasons we could do if/else here and keep previous behaviour with a single torch_pad call for scalars and for list/tuple do what you coded ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had that, see earlier versions of the commit. Unfortunately I couldn't find a way to write it in a JIT-friendly way. See here for more details. If you have ideas on how to have the optimization and be JIT-scriptable I'm happy to use them :)


if padding_mode == "constant":
# The following if/else can't be simplified due to JIT
if isinstance(fill, (tuple, list)):
fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can't we create it directly as

fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device)

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. :( JIT requires it to be behind an if statement. I believe this is because it invokes a different C++ method (The one that receives a list VS a scalar).

Copy link
Collaborator

@vfdev-5 vfdev-5 Mar 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean avoid to call .to:

- fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
+ fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed that. That's also needed for the scalar case. I believe some of the tests were failing due to fill being float and dtype being integer. Casting solves this.

BTW you are welcome to push to the branch if you want to experiment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ansley FYI this is the kind of weird code one must write to make things JIT-scriptable. Without the explicit if statement, JIT doesn't know how to handle fill when scalar vs when list. I believe this has to do with the fact that the C++ implementation ends up calling a different method.

else:
fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device)
if pad_top > 0:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling negative padding values.

img[..., :pad_top, :] = fill_img
if pad_left > 0:
img[..., :, :pad_left] = fill_img
if pad_bottom > 0:
img[..., -pad_bottom:, :] = fill_img
if pad_right > 0:
img[..., :, -pad_right:] = fill_img

if need_squeeze:
img = img.squeeze(dim=0)
Expand Down Expand Up @@ -499,11 +516,28 @@ def resize(
return img


def _assert_fill(fill: Optional[Union[List[float], float]], num_channels: int):
if fill is None:
return
if not isinstance(fill, (int, float, tuple, list)):
warnings.warn("Argument fill should be either int, float, tuple or list")

# Check fill
if isinstance(fill, (tuple, list)):
length = len(fill)
if length > 1 and length != num_channels:
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(length, num_channels))


def _assert_grid_transform_inputs(
img: Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[List[float]],
fill: Optional[Union[List[float], float]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
Expand All @@ -522,17 +556,8 @@ def _assert_grid_transform_inputs(
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")

if fill is not None and not isinstance(fill, (int, float, tuple, list)):
warnings.warn("Argument fill should be either int, float, tuple or list")

# Check fill
num_channels = get_dimensions(img)[0]
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels))
_assert_fill(fill, num_channels)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just move out the code to reuse it above.


if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
Expand Down
6 changes: 3 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,13 +422,13 @@ class Pad(torch.nn.Module):
will result in [2, 1, 1, 2, 3, 4, 4, 3]
"""

def __init__(self, padding, fill=0, padding_mode="constant"):
def __init__(self, padding, fill=0.0, padding_mode="constant"):
super().__init__()
_log_api_usage_once(self)
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")

if not isinstance(fill, (numbers.Number, str, tuple)):
if not isinstance(fill, (numbers.Number, str, tuple, list)):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated bug fix on the original code.

raise TypeError("Got inappropriate fill arg")

if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
Expand Down Expand Up @@ -641,7 +641,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw

def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
def __init__(self, size, padding=None, pad_if_needed=False, fill=0.0, padding_mode="constant"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Floats to please JIT

super().__init__()
_log_api_usage_once(self)

Expand Down