-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[EXPERIMENTAL] Extending padding to support non-constant fill #5568
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ccef4e4
d427cbe
49c6c1a
9589e59
6d1b41f
05384cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
import numbers | ||
import warnings | ||
from enum import Enum | ||
from typing import List, Tuple, Any, Optional | ||
from typing import List, Tuple, Any, Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
@@ -431,7 +431,9 @@ def resize( | |
return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias) | ||
|
||
|
||
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: | ||
def pad( | ||
img: Tensor, padding: List[int], fill: Union[List[float], float] = 0.0, padding_mode: str = "constant" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The default value changed from int to float. JIT will fail if we pass integers. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where exactly do we need float values? Maybe we could keep ints and […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The […] |
||
) -> Tensor: | ||
r"""Pad the given image on all sides with the given "pad" value. | ||
If the image is torch Tensor, it is expected | ||
to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, | ||
|
@@ -451,7 +453,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con | |
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. | ||
If a tuple of length 3, it is used to fill R, G, B channels respectively. | ||
This value is only used when the padding_mode is constant. | ||
Only number is supported for torch Tensor. | ||
Only number or tuple is supported for torch Tensor. | ||
Only int or str or tuple value is supported for PIL Image. | ||
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. | ||
Default is constant. | ||
|
@@ -536,7 +538,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: | |
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0, | ||
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0, | ||
] | ||
img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0 | ||
img = pad(img, padding_ltrb, fill=0.0) # PIL uses fill value 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again we need floats to appease JIT. |
||
_, image_height, image_width = get_dimensions(img) | ||
if crop_width == image_width and crop_height == image_height: | ||
return img | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -154,7 +154,7 @@ def pad( | |
|
||
if not isinstance(padding, (numbers.Number, tuple, list)): | ||
raise TypeError("Got inappropriate padding arg") | ||
if not isinstance(fill, (numbers.Number, str, tuple)): | ||
if not isinstance(fill, (numbers.Number, str, list, tuple)): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unrelated bug fix on the original code. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If there is a bug fix here and we do not expect this to land, maybe it is better to split this into a separate PR? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, we can cherry-pick it afterwards if we don't land this. |
||
raise TypeError("Got inappropriate fill arg") | ||
if not isinstance(padding_mode, str): | ||
raise TypeError("Got inappropriate padding_mode arg") | ||
|
@@ -291,7 +291,7 @@ def _parse_fill( | |
# Process fill color for affine transforms | ||
num_bands = len(img.getbands()) | ||
if fill is None: | ||
fill = 0 | ||
fill = 0.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Floats to please JIT |
||
if isinstance(fill, (int, float)) and num_bands > 1: | ||
fill = tuple([fill] * num_bands) | ||
if isinstance(fill, (list, tuple)): | ||
|
@@ -301,6 +301,12 @@ def _parse_fill( | |
|
||
fill = tuple(fill) | ||
|
||
if img.mode != "F": | ||
if isinstance(fill, (list, tuple)): | ||
fill = tuple(int(x) for x in fill) | ||
else: | ||
fill = int(fill) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unrelated bug fix on the original code. This method doesn't work if floats are provided for PIL images, despite the method having floats in the signature. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here? |
||
|
||
return {name: fill} | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import warnings | ||
from typing import Optional, Tuple, List | ||
from typing import Optional, Tuple, List, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
|
@@ -141,7 +141,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: | |
|
||
if left < 0 or top < 0 or right > w or bottom > h: | ||
padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)] | ||
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0) | ||
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0.0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Floats to please JIT |
||
return img[..., top:bottom, left:right] | ||
|
||
|
||
|
@@ -353,13 +353,15 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: | |
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet") | ||
|
||
|
||
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: | ||
def pad( | ||
img: Tensor, padding: List[int], fill: Union[List[float], float] = 0.0, padding_mode: str = "constant" | ||
) -> Tensor: | ||
_assert_image_tensor(img) | ||
|
||
if not isinstance(padding, (int, tuple, list)): | ||
raise TypeError("Got inappropriate padding arg") | ||
if not isinstance(fill, (int, float)): | ||
raise TypeError("Got inappropriate fill arg") | ||
channels, height, width = get_dimensions(img) | ||
_assert_fill(fill, channels) | ||
if not isinstance(padding_mode, str): | ||
raise TypeError("Got inappropriate padding_mode arg") | ||
|
||
|
@@ -411,7 +413,22 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con | |
need_cast = True | ||
img = img.to(torch.float32) | ||
|
||
img = torch_pad(img, p, mode=padding_mode, value=float(fill)) | ||
img = torch_pad(img, p, mode=padding_mode) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had that, see earlier versions of the commit. Unfortunately I couldn't find a way to write it in a JIT-friendly way. See here for more details. If you have ideas on how to have the optimization and be JIT-scriptable I'm happy to use them :) |
||
|
||
if padding_mode == "constant": | ||
# The following if/else can't be simplified due to JIT | ||
if isinstance(fill, (tuple, list)): | ||
fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device).view(1, -1, 1, 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: can't we create it directly as fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device)? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No. :( JIT requires it to be behind an if statement. I believe this is because it invokes a different C++ method (the one that receives a list vs. a scalar). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I mean avoid calling `.to()` separately, i.e.:
- fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
+ fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry I missed that. That's also needed for the scalar case. I believe some of the tests were failing due to […]. BTW, you are welcome to push to the branch if you want to experiment. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @ansley FYI this is the kind of weird code one must write to make things JIT-scriptable. Without the explicit if statement, JIT doesn't know how to handle […] |
||
else: | ||
fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device) | ||
if pad_top > 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handling negative padding values. |
||
img[..., :pad_top, :] = fill_img | ||
if pad_left > 0: | ||
img[..., :, :pad_left] = fill_img | ||
if pad_bottom > 0: | ||
img[..., -pad_bottom:, :] = fill_img | ||
if pad_right > 0: | ||
img[..., :, -pad_right:] = fill_img | ||
|
||
if need_squeeze: | ||
img = img.squeeze(dim=0) | ||
|
@@ -499,11 +516,28 @@ def resize( | |
return img | ||
|
||
|
||
def _assert_fill(fill: Optional[Union[List[float], float]], num_channels: int): | ||
if fill is None: | ||
return | ||
if not isinstance(fill, (int, float, tuple, list)): | ||
warnings.warn("Argument fill should be either int, float, tuple or list") | ||
|
||
# Check fill | ||
if isinstance(fill, (tuple, list)): | ||
length = len(fill) | ||
if length > 1 and length != num_channels: | ||
msg = ( | ||
"The number of elements in 'fill' cannot broadcast to match the number of " | ||
"channels of the image ({} != {})" | ||
) | ||
raise ValueError(msg.format(length, num_channels)) | ||
|
||
|
||
def _assert_grid_transform_inputs( | ||
img: Tensor, | ||
matrix: Optional[List[float]], | ||
interpolation: str, | ||
fill: Optional[List[float]], | ||
fill: Optional[Union[List[float], float]], | ||
supported_interpolation_modes: List[str], | ||
coeffs: Optional[List[float]] = None, | ||
) -> None: | ||
|
@@ -522,17 +556,8 @@ def _assert_grid_transform_inputs( | |
if coeffs is not None and len(coeffs) != 8: | ||
raise ValueError("Argument coeffs should have 8 float values") | ||
|
||
if fill is not None and not isinstance(fill, (int, float, tuple, list)): | ||
warnings.warn("Argument fill should be either int, float, tuple or list") | ||
|
||
# Check fill | ||
num_channels = get_dimensions(img)[0] | ||
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): | ||
msg = ( | ||
"The number of elements in 'fill' cannot broadcast to match the number of " | ||
"channels of the image ({} != {})" | ||
) | ||
raise ValueError(msg.format(len(fill), num_channels)) | ||
_assert_fill(fill, num_channels) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just moved the code out so it can be reused above. |
||
|
||
if interpolation not in supported_interpolation_modes: | ||
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -422,13 +422,13 @@ class Pad(torch.nn.Module): | |
will result in [2, 1, 1, 2, 3, 4, 4, 3] | ||
""" | ||
|
||
def __init__(self, padding, fill=0, padding_mode="constant"): | ||
def __init__(self, padding, fill=0.0, padding_mode="constant"): | ||
super().__init__() | ||
_log_api_usage_once(self) | ||
if not isinstance(padding, (numbers.Number, tuple, list)): | ||
raise TypeError("Got inappropriate padding arg") | ||
|
||
if not isinstance(fill, (numbers.Number, str, tuple)): | ||
if not isinstance(fill, (numbers.Number, str, tuple, list)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unrelated bug fix on the original code. |
||
raise TypeError("Got inappropriate fill arg") | ||
|
||
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: | ||
|
@@ -641,7 +641,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int | |
j = torch.randint(0, w - tw + 1, size=(1,)).item() | ||
return i, j, th, tw | ||
|
||
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): | ||
def __init__(self, size, padding=None, pad_if_needed=False, fill=0.0, padding_mode="constant"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Floats to please JIT |
||
super().__init__() | ||
_log_api_usage_once(self) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test won't pass if we provide integers. That's because the test conducts JIT-script checks as well.
Here we check for single values, lists with the same value and lists with different values.