Periodic/Circular Padding in Transpose Convolution #38

Open

Ceyron opened this issue Feb 7, 2024 · 9 comments

Ceyron commented Feb 7, 2024

Hi,

Thanks a lot for open-sourcing the code; it has inspired me greatly! 😊

I am currently trying to set up a UNet architecture similar to the ones used in this repo to be applied to fields with periodic boundary conditions. In the PDE-Refiner paper, you write:

Crucially, all convolutions use circular padding in the U-Net to account for the periodic domain.

As far as I understand, the KS example (which has periodic BCs) uses this architecture, which should instantiate the model of this file. For this UNet, you use the default padding_mode, which you overwrite to be circular in this file. The UNet also includes upsampling via a transposed convolution. For this convolution, however, there seems to be only the default choice of padding_mode="zeros", which is also the only mode PyTorch supports.

Maybe the question is stupid, but shouldn't the transposed operator also use periodic padding, as, for instance, Flax supports?
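
(For context, a minimal check, assuming a recent PyTorch version; PyTorch's transposed convolutions only accept padding_mode='zeros' and raise a ValueError otherwise:)

import torch

## Regular convolutions accept circular padding:
conv = torch.nn.Conv1d(1, 1, 3, padding=1, padding_mode='circular')

## Transposed convolutions do not; construction fails:
try:
    torch.nn.ConvTranspose1d(1, 1, 3, padding=1, padding_mode='circular')
except ValueError as err:
    print(err)  ## e.g.: Only "zeros" padding mode is supported for ConvTranspose1d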

Hrrsmjd (Contributor) commented Feb 15, 2024

@Ceyron I had the same concern. You might be able to use Conv1d (instead of ConvTranspose1d) with padding_mode='circular' and adjust the padding.

Ceyron (Author) commented Feb 15, 2024

Thanks for the reply.

Don't we need the transpose convolution for upsampling?

Hrrsmjd (Contributor) commented Feb 15, 2024

> Thanks for the reply.
>
> Don't we need the transpose convolution for upsampling?

Yes, but it is always possible to reproduce a transposed convolution using a regular convolution.

I think, in this case with padding=1, you can simply replace ConvTranspose1d(_, _, kernel_size=3, stride=~~2~~ 1, padding=1) with Conv1d(_, _, kernel_size=3, stride=~~2~~ 1, padding=1) and get the same output.

Ceyron (Author) commented Feb 15, 2024

I am not sure about this. Under the same settings (with stride=2), a forward conv reduces the spatial dimension (to approx. half) and a transpose conv increases it (to approx. twice).

In [1]: import torch

In [2]: conv = torch.nn.Conv1d(1, 1, 3, 2, 1)

In [3]: conv_transpose = torch.nn.ConvTranspose1d(1, 1, 3, 2, 1)

In [4]: x = torch.rand(1, 1, 10)

In [5]: conv(x).shape
Out[5]: torch.Size([1, 1, 5])

In [6]: conv_transpose(x).shape
Out[6]: torch.Size([1, 1, 19])
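
(For reference, the shape formulas from the PyTorch docs with dilation=1: Conv1d gives L_out = floor((L_in + 2*padding - kernel_size) / stride) + 1, while ConvTranspose1d gives L_out = (L_in - 1)*stride - 2*padding + kernel_size + output_padding. With L_in=10, kernel_size=3, stride=2, padding=1, these give 5 and 19, matching the shapes above.)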

Hrrsmjd (Contributor) commented Feb 15, 2024

> I am not sure about this. Under the same settings (with stride=2), a forward conv reduces the spatial dimension (to approx. half) and a transpose conv increases it (to approx. twice).

Sorry. You are right, that doesn't work if you set stride=2 for both. You might have to adjust stride, padding, and dilation.

Ceyron (Author) commented Feb 15, 2024

I don't think that we can make the forward conv equal to the transpose conv (given stride=2) with any of the options. The essential ingredient is that transpose convs do "lhs dilation" (in the terminology of JAX).

Hrrsmjd (Contributor) commented Feb 15, 2024

@Ceyron See below:

import torch

## Input
x = torch.rand(1, 1, 64)

## ConvTranspose1d that we want to reproduce:
conv_transpose = torch.nn.ConvTranspose1d(1, 1, 3, stride=2, padding=1)
print(conv_transpose(x).shape) ## torch.Size([1, 1, 127])

## Reproduced ConvTranspose1d:
## Insert zeros between input elements:
x_new = torch.zeros(1, 1, x.shape[-1]*2 - 1, device=x.device, dtype=x.dtype)
x_new[..., ::2] = x
conv = torch.nn.Conv1d(1, 1, 3, stride=1, padding=1, padding_mode='circular')
print(conv(x_new).shape) ## torch.Size([1, 1, 127])

Ceyron (Author) commented Feb 15, 2024

Yeah, that works. It's a manual "lhs dilation" 😉 Would be nice if PyTorch had this by default.

But coming back to my original question: Don't you think that the UNet needs this transpose conv with circular/periodic padding?
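
(A hypothetical sketch packaging the zero-insertion trick above as a reusable module; the name CircularConvTranspose1d and the choice to double the length exactly, i.e. to 2*L rather than the 2*L - 1 that ConvTranspose1d with padding=1 produces, are illustrative choices, not from the repo:)

import torch

class CircularConvTranspose1d(torch.nn.Module):
    ## Stride-2 "transposed" convolution with periodic boundaries,
    ## built from manual lhs dilation (zero insertion) followed by
    ## a regular convolution with circular padding.
    ## Assumes an odd kernel_size so the length is preserved.

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=1, padding=kernel_size // 2, padding_mode='circular',
        )

    def forward(self, x):
        batch, channels, length = x.shape
        ## Insert zeros between input elements ("lhs dilation"),
        ## created directly on the input's device and dtype:
        x_new = torch.zeros(batch, channels, 2 * length,
                            device=x.device, dtype=x.dtype)
        x_new[..., ::2] = x
        return self.conv(x_new)

up = CircularConvTranspose1d(8, 8)
print(up(torch.rand(1, 8, 64)).shape)  ## torch.Size([1, 8, 128])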

phlippe (Contributor) commented Mar 15, 2024

Hi @Ceyron, sorry for the late reply, I just came across the issue. :)

Well spotted with the transposed convolutions! We found that, in practice, this small non-periodicity didn't affect the results. The skip connections reintroduce fully periodic information, and the model can correct these small differences with subsequent periodic convolutions. To make the upsampling periodic, I would recommend reducing the kernel size to 2x2, which gives no overlap, or using PyTorch's Upsample, as in other U-Net implementations. We haven't seen any benefit from it, but in other settings it might be more important.
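
(A minimal sketch of both suggestions in 1D, assuming illustrative channel counts; in 2D the same idea applies with the 2x2 kernel:)

import torch

## Option 1: kernel_size=2 with stride=2. Every output element
## depends on exactly one input element (no overlap), so no padding
## is involved and the layer is compatible with a periodic domain.
up1 = torch.nn.ConvTranspose1d(8, 8, kernel_size=2, stride=2)

## Option 2: parameter-free upsampling followed by a circular
## convolution, as in other U-Net implementations.
up2 = torch.nn.Sequential(
    torch.nn.Upsample(scale_factor=2, mode='nearest'),
    torch.nn.Conv1d(8, 8, kernel_size=3, padding=1, padding_mode='circular'),
)

x = torch.rand(1, 8, 64)
print(up1(x).shape)  ## torch.Size([1, 8, 128])
print(up2(x).shape)  ## torch.Size([1, 8, 128])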
