
Stop gradient option for padding #68879

Open
pyscorcher opened this issue Nov 24, 2021 · 3 comments
Labels
- enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix
- module: autograd: Related to torch.autograd, and the autograd engine in general
- module: nn: Related to torch.nn
- needs research: We need to decide whether or not this merits inclusion, based on research
- Stale
- triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

pyscorcher commented Nov 24, 2021

🚀 Feature

I'd like to propose an option to disable gradient computation through padding pixels, so that all source pixels get the same "weight". I propose a boolean option detach_padding that defaults to False for backwards compatibility; if set to True, the replicated pixels in the padding are not considered in the gradient computation. This concerns all functions and modules that take a padding option, such as torch.nn.functional.pad, torch.nn.ReflectionPad1d, torch.nn.ReplicationPad2d, torch.nn.Conv3d, etc.

Motivation

Currently, if we use padding options that reuse values from within the image, some pixels are counted multiple times in the gradient computation. This causes edge artifacts and is especially inconvenient with the circular mode. When we use circular padding we usually want toroidal topology (a.k.a. wallpaper tiling), where we consider just one patch of a repeating pattern; in that case we definitely don't want certain pixels to be treated differently from the rest. But if we look at the gradient, we see that this is not the case: the gradient computation also goes through the replicated pixels:

import torch
x = torch.zeros((1, 1, 6, 6), requires_grad=True)
# 1x1 kernel, stride 1, padding 1; same problem for the other padding modes
conv = torch.nn.Conv2d(1, 1, 1, 1, 1, bias=False, padding_mode='circular')
with torch.no_grad():
    conv.weight[:] = 1  # all-ones weight so the gradient is easy to read
y = conv(x).sum()
y.backward()
print(x.grad)

I'd expect the gradient to be all ones, but this is not the case:

tensor([[[[4., 2., 2., 2., 2., 4.],
          [2., 1., 1., 1., 1., 2.],
          [2., 1., 1., 1., 1., 2.],
          [2., 1., 1., 1., 1., 2.],
          [2., 1., 1., 1., 1., 2.],
          [4., 2., 2., 2., 2., 4.]]]])
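
(With circular padding of width 1 and a 1×1 kernel, every pixel of the 8×8 padded tensor contributes a gradient of 1 to the source pixel it was copied from: interior pixels appear once, edge pixels appear once more in the wrapped border, giving 2, and corner pixels are additionally copied into the wrapped row, the wrapped column, and the wrapped corner, giving 4.)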

I claim that this behaviour is also not really wanted for the other padding modes, as it can cause artifacts at the edges of the images.

Pitch

Add an option to disable gradient through padding pixels in the various padding functions and modules.
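
For illustration, a minimal sketch of how this behaviour could be approximated today with existing ops (pad_detached is a hypothetical helper written for this issue, not an existing PyTorch API): pad a detached copy of the input so the replicated border carries no gradient, then splice the gradient-tracking tensor back into the interior.

import torch
import torch.nn.functional as F

def pad_detached(x, pad, mode='circular'):
    # Hypothetical helper: the replicated border comes from a detached copy,
    # so it contributes nothing to the gradient; the interior is the original x.
    left, right, top, bottom = pad
    out = F.pad(x.detach(), pad, mode=mode)
    out[..., top:top + x.shape[-2], left:left + x.shape[-1]] = x
    return out

x = torch.zeros((1, 1, 6, 6), requires_grad=True)
conv = torch.nn.Conv2d(1, 1, 1, 1, 0, bias=False)  # padding is done manually above
with torch.no_grad():
    conv.weight[:] = 1
y = conv(pad_detached(x, (1, 1, 1, 1))).sum()
y.backward()
print(x.grad)  # all ones: the replicated padding no longer adds extra gradient

A built-in detach_padding=True flag could provide the same semantics directly in the existing padding kernels.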

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @mruberry @jbschlosser @walterddr @kshitij12345

@ejguan added the module: autograd, module: nn, and triaged labels on Nov 24, 2021
@jbschlosser added the enhancement and needs research labels on Nov 24, 2021
@github-actions bot added the Stale label on Jan 23, 2022

ezyang commented Feb 6, 2022

This is a reasonable ask, though (1) I'm not sure about the API, and (2) I'm not sure there's a way to efficiently work around this problem right now.

@lezcano assigned and then unassigned lezcano on Feb 6, 2022
pyscorcher (Author) commented

While not directly related to this problem, I want to highlight the following repo. They had to deal with 360° panoramic images, and they came up with a CircularConv2d that ensures no artificial boundary in the circular direction gets a different gradient weight: https://github.com/kazuto1011/circular-conv-pytorch
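
As a rough sketch of that idea (assuming a panorama that wraps only horizontally; the class name and details below are illustrative, not the linked repo's actual code), one can wrap the columns with circular padding, zero-pad the rows, and run the convolution with no built-in padding:

import torch
import torch.nn.functional as F

class HorizontalCircularConv2d(torch.nn.Module):
    # Illustrative sketch: circular padding along the width (the 360° direction),
    # ordinary zero padding along the height; assumes an odd kernel_size.
    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.p = kernel_size // 2
        self.conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size, padding=0)

    def forward(self, x):
        x = F.pad(x, (self.p, self.p, 0, 0), mode='circular')           # wrap columns
        x = F.pad(x, (0, 0, self.p, self.p), mode='constant', value=0)  # zero-pad rows
        return self.conv(x)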


OlliNiemitalo commented Feb 2, 2024

That option could be useful for padding modes other than circular.

I find that circular convolution works as it should. In the code above, x is 6×6 while the conv(x) result is 8×8, so we are looking beyond a single period of the convolved periodic image, and the edge pixels get "seen twice" due to wraparound. If we increase the kernel size to 3×3, so that the output is again 6×6 and covers exactly one period, everything looks good:

import torch
x = torch.zeros((1, 1, 6, 6), requires_grad=True)
# 3x3 kernel, stride 1, padding 1: the output is 6x6, exactly one period
conv = torch.nn.Conv2d(1, 1, 3, 1, 1, bias=False, padding_mode='circular')
with torch.no_grad():
    conv.weight[:] = 1
y = conv(x).sum()
y.backward()
print(x.grad)

Output:

tensor([[[[9., 9., 9., 9., 9., 9.],
          [9., 9., 9., 9., 9., 9.],
          [9., 9., 9., 9., 9., 9.],
          [9., 9., 9., 9., 9., 9.],
          [9., 9., 9., 9., 9., 9.],
          [9., 9., 9., 9., 9., 9.]]]])
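
(With a 3×3 kernel and all-ones weights, each of the 36 output positions sums 9 input pixels, and with circular wraparound every input pixel is covered by exactly 3 × 3 = 9 output positions, hence the uniform gradient of 9.)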
