Conversation

@datumbox datumbox (Contributor) commented Mar 8, 2022

PyTorch's pad supports only a single scalar constant as the fill value. Unfortunately, this can be a problem for data augmentation techniques that require padding with a specific fill colour. For some of them we have previously employed the following trick:

if isinstance(image, torch.Tensor):
    # PyTorch's pad supports only integers on fill. So we need to overwrite the colour
    v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
    image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
        ..., :, (left + orig_w) :
    ] = v

This PR adapts the approach and moves it into F.pad(). The fill can be either a float or a List[float]. Unfortunately, JIT doesn't allow us to also include int and List[int] in the union. The PR also modifies the default fill values of some of the methods.
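For illustration, here is a minimal standalone sketch of the idea: pad with the default value first, then overwrite the four padded bands with the requested colour. The function name pad_with_fill and the argument layout are made up for the example; this is not the exact code in the PR.

from typing import List

import torch
from torch.nn.functional import pad as torch_pad


def pad_with_fill(img: torch.Tensor, padding: List[int], fill: List[float]) -> torch.Tensor:
    # padding is [left, right, top, bottom], as expected by torch.nn.functional.pad
    left, right, top, bottom = padding
    out = torch_pad(img, [left, right, top, bottom], mode="constant", value=0.0)
    # per-channel fill, broadcastable over (..., C, H, W)
    fill_img = torch.tensor(fill, dtype=out.dtype, device=out.device).view(-1, 1, 1)
    # overwrite the four padded bands with the fill colour
    if top > 0:
        out[..., :top, :] = fill_img
    if bottom > 0:
        out[..., -bottom:, :] = fill_img
    if left > 0:
        out[..., :, :left] = fill_img
    if right > 0:
        out[..., :, -right:] = fill_img
    return out


# e.g. pad a 3x4x4 image by 2 pixels on every side with an RGB fill of (255, 0, 0)
padded = pad_with_fill(torch.zeros(3, 4, 4), [2, 2, 2, 2], [255.0, 0.0, 0.0])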

@facebook-github-bot (Contributor) commented Mar 8, 2022

💊 CI failures summary and remediations

As of commit 05384cf (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@datumbox datumbox changed the title Extending padding to support non-constant fill [EXPERIMENTAL] Extending padding to support non-constant fill Mar 8, 2022
@datumbox datumbox left a comment (Contributor Author)

Providing comments to assist review

{"padding_mode": "constant", "fill": 20},
{"padding_mode": "constant", "fill": 10.0},
{"padding_mode": "constant", "fill": [10.0, 10.0, 10.0]},
{"padding_mode": "constant", "fill": [10.0, 0.0, 10.0]},
Contributor Author

The test won't pass if we provide integers, because the test also runs JIT-script checks.

Here we check a single value, a list with the same value per channel, and a list with different values per channel.
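For context, the JIT-script check that these configurations go through looks roughly like this sketch, assuming the fill signature introduced by this PR (the exact test harness in torchvision differs):

import torch
from torchvision.transforms import functional as F

scripted_pad = torch.jit.script(F.pad)  # the JIT-script check referenced above

img = torch.zeros(3, 8, 8)
out = scripted_pad(img, [2, 2, 2, 2], fill=[10.0, 0.0, 10.0], padding_mode="constant")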


- def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
+ def pad(
+     img: Tensor, padding: List[int], fill: Union[List[float], float] = 0.0, padding_mode: str = "constant"
Contributor Author

The default value changed from int to float. JIT will fail if we pass integers.

Collaborator

Where exactly do we need float values? Maybe we could keep int and List[int] and cast to float where required?

Contributor Author

The int used previously was misleading. We mainly use floats because, as you know, our tensors get rescaled. Unfortunately, adding both List[int] and List[float] to the union doesn't work due to JIT issues. See pytorch/pytorch#69434
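To make the constraint concrete, this is roughly the signature we would have liked versus what the PR settles on (illustrative only; the names pad_desired/pad_actual are made up, and the broader union is what pytorch/pytorch#69434 tracks):

from typing import List, Union

from torch import Tensor

# Desired, but problematic for torch.jit.script at the time of this PR:
def pad_desired(img: Tensor, padding: List[int], fill: Union[int, float, List[int], List[float]] = 0,
                padding_mode: str = "constant") -> Tensor:
    ...

# What the PR uses instead:
def pad_actual(img: Tensor, padding: List[int], fill: Union[List[float], float] = 0.0,
               padding_mode: str = "constant") -> Tensor:
    ...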

(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
- img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
+ img = pad(img, padding_ltrb, fill=0.0)  # PIL uses fill value 0
Contributor Author

Again we need floats to appease JIT.

if not isinstance(padding, (numbers.Number, tuple, list)):
    raise TypeError("Got inappropriate padding arg")
- if not isinstance(fill, (numbers.Number, str, tuple)):
+ if not isinstance(fill, (numbers.Number, str, list, tuple)):
Contributor Author

Unrelated bug fix in the original code.

Collaborator

If there is a bug fix here and we do not expect this PR to land, maybe it's better to split it into a separate PR?

Contributor Author

Yeah, we can cherry-pick it afterwards if we don't land this.

if isinstance(fill, (list, tuple)):
    fill = tuple(int(x) for x in fill)
else:
    fill = int(fill)
Contributor Author

Unrelated bug fix in the original code. This method doesn't work if floats are provided for PIL images, despite having floats in its signature.

Collaborator

Same here?

if left < 0 or top < 0 or right > w or bottom > h:
    padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)]
-   return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
+   return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0.0)
Contributor Author

Floats to please JIT

if not isinstance(fill, (tuple, list)):
    fill = [fill]
fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
if pad_top > 0:
Contributor Author

Handling negative padding values.
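To illustrate why the guards are needed, here is a small standalone example of torch.nn.functional.pad behaviour (not code from this PR): a negative padding value crops that side instead of adding pixels, so there is no band to overwrite there.

import torch
from torch.nn.functional import pad as torch_pad

img = torch.ones(1, 4, 4)
# A negative pad value crops instead of padding: width goes from 4 to 4 - 1 + 2 = 5,
# and the left side loses a pixel, so a guard like (pad_left > 0) must skip the overwrite.
out = torch_pad(img, [-1, 2, 0, 0], mode="constant", value=0.0)
print(out.shape)  # torch.Size([1, 4, 5])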

"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels))
_assert_fill(fill, num_channels)
Contributor Author

Just moved the code out into a helper so it can be reused above.

raise TypeError("Got inappropriate padding arg")

- if not isinstance(fill, (numbers.Number, str, tuple)):
+ if not isinstance(fill, (numbers.Number, str, tuple, list)):
Contributor Author

Unrelated bug fix in the original code.

return i, j, th, tw

- def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
+ def __init__(self, size, padding=None, pad_if_needed=False, fill=0.0, padding_mode="constant"):
Contributor Author

Floats to please JIT

@datumbox datumbox force-pushed the transforms/pad_fill branch from 2ff188a to 8625bd4 Compare March 8, 2022 20:29
@datumbox datumbox (Contributor Author) commented Mar 8, 2022

The previous solution failed. It seems we might have a gap in our JIT tests, because the problem was caught by the doc scripts instead.

@datumbox datumbox force-pushed the transforms/pad_fill branch from 8625bd4 to 9589e59 Compare March 8, 2022 20:30
if padding_mode == "constant":
# The following if/else can't be simplified due to JIT
if isinstance(fill, (tuple, list)):
fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
Collaborator

nit: can't we create it directly as

fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device)

?

Contributor Author

No. :( JIT requires it to be behind an if statement. I believe this is because it invokes a different C++ method (the one that receives a list vs the one that receives a scalar).

@vfdev-5 vfdev-5 (Collaborator) Mar 9, 2022

I mean avoiding the call to .to():

- fill_img = torch.tensor(fill).to(dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
+ fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)

Contributor Author

Sorry, I missed that. The cast is also needed for the scalar case. I believe some of the tests were failing because fill was a float while the image dtype was integer; casting solves this.

BTW you are welcome to push to the branch if you want to experiment.

Contributor Author

@ansley FYI this is the kind of weird code one must write to make things JIT-scriptable. Without the explicit if statement, JIT doesn't know how to handle fill when it is a scalar vs when it is a list. I believe this is because the C++ implementation ends up calling a different method in each case.
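A minimal sketch of the pattern described above (the helper name make_fill_tensor is hypothetical, and whether the exact snippet scripts cleanly depends on the PyTorch version):

from typing import List, Union

import torch


def make_fill_tensor(fill: Union[List[float], float], dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Both branches build the same tensor, but the isinstance check refines the Union
    # so that torch.tensor() sees a concrete type in each branch. Collapsing the two
    # branches into a single call is what torch.jit.script refuses to compile.
    if isinstance(fill, (tuple, list)):
        return torch.tensor(fill, dtype=dtype, device=device).view(1, -1, 1, 1)
    else:
        return torch.tensor(fill, dtype=dtype, device=device).view(1, -1, 1, 1)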

img = img.to(torch.float32)

- img = torch_pad(img, p, mode=padding_mode, value=float(fill))
+ img = torch_pad(img, p, mode=padding_mode)
Collaborator

If fill is a scalar, we now still transform it to a tensor and write it into the image up to 4 times below (img[..., :, :pad_left] = fill_img). Maybe, for performance reasons, we could do an if/else here: keep the previous behaviour with a single torch_pad call for scalars, and do what you coded for list/tuple?

Contributor Author

I had that; see earlier versions of the commit. Unfortunately, I couldn't find a way to write it in a JIT-friendly manner. See here for more details. If you have ideas on how to keep the optimization and remain JIT-scriptable, I'm happy to use them :)

@datumbox datumbox commented (Contributor Author)

This solution is not good enough. Though the JIT tests pass, there are issues:

  • This PR leaves the code in a worse state. It contains a few "voodoo" parts which are non-obvious and exist only to please JIT.
  • The end result is very brittle. Making small changes can instantly break the code.
  • It's not efficient. JIT doesn't let us apply the patch only when necessary.

For the above reasons, I will close the PR. Perhaps we can revisit this in the future when issue pytorch/pytorch#69434 is addressed. For now, I think applying the right fill for padding can be handled on the side of the new class Transforms, where things don't have to be JIT-scriptable.

@datumbox datumbox closed this Mar 10, 2022
@datumbox datumbox deleted the transforms/pad_fill branch March 10, 2022 12:02