
parameters #13

Closed
123456789-qwer opened this issue Apr 14, 2022 · 1 comment

Comments

@123456789-qwer

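```python
import torch
import torch.nn as nn


class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable frequency-domain filter of shape (h, w, dim), stored as real/imag pairs.
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape
        # 2D real FFT over the spatial axes: dim 1 (H) and dim 2 (W).
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        # Reinterpret the trailing size-2 axis as a complex number.
        weight = torch.view_as_complex(self.complex_weight)
        # Element-wise multiplication with the spectrum (the global filter).
        x = x * weight
        # Inverse real FFT; s=(H, W) fixes the spatial size of the output.
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x
```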

Thank you very much for your work. I have some questions: what is the meaning of "h=14, w=8", and of "s=(H, W), dim=(1, 2)"?

@raoyongming
Owner

Hi, thanks for your interest in our paper.

The rFFT of a real signal with a shape of HxW has only Hx(W//2+1) independent components, because the spectrum of a real signal is conjugate-symmetric. Therefore, for a real input feature with a shape of Bx14x14xC, torch.fft.rfft2(x, dim=(1, 2), norm='ortho') (which transforms dims 1 and 2, i.e. H and W) yields a Bx14x8xC complex tensor. This is why we set the shape of the filter to 14x8xdimx2, where the trailing 2 holds the real and imaginary parts of a complex tensor. Conversely, a Bx14x8xC complex feature is consistent with both a Bx14x14xC and a Bx14x15xC real feature (14//2+1 = 15//2+1 = 8), so irfft cannot tell on its own which one to return. Therefore, we specify the shape of the output feature explicitly: s=(H, W) = (14, 14).
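A minimal PyTorch sketch of the shapes described in the answer, using arbitrary illustrative sizes (batch 2 and C=384 channels are assumptions, not values from the repository). It checks that `dim=(1, 2)` transforms the H and W axes, that a 14x8xC complex filter broadcasts against the spectrum, and that a 15-wide input also yields 8 frequency columns, so `irfft2` only reconstructs it correctly when `s` is given.

```python
import torch

# Illustrative sizes (assumed): batch 2, a 14x14 feature map, 384 channels.
B, H, W, C = 2, 14, 14, 384
x = torch.randn(B, H, W, C)

# dim=(1, 2) selects the spatial axes H and W of the (B, H, W, C) tensor.
X = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
print(X.shape)  # torch.Size([2, 14, 8, 384]); W//2 + 1 = 8 frequency columns

# A real (14, 8, C, 2) tensor viewed as a (14, 8, C) complex filter
# broadcasts against the (B, 14, 8, C) spectrum.
weight = torch.view_as_complex(torch.randn(14, 8, C, 2) * 0.02)
print((X * weight).shape)  # torch.Size([2, 14, 8, 384])

# A 15-wide input also produces 8 frequency columns (15//2 + 1 = 8).
x_odd = torch.randn(B, H, 15, C)
X_odd = torch.fft.rfft2(x_odd, dim=(1, 2), norm='ortho')
print(X_odd.shape)  # torch.Size([2, 14, 8, 384])

# Without s, irfft2 assumes an even width of 2 * (8 - 1) = 14.
print(torch.fft.irfft2(X_odd, dim=(1, 2), norm='ortho').shape)  # torch.Size([2, 14, 14, 384])

# With s=(H, W), the original 14x15 signal is recovered.
y_odd = torch.fft.irfft2(X_odd, s=(H, 15), dim=(1, 2), norm='ortho')
print(y_odd.shape)  # torch.Size([2, 14, 15, 384])
print(torch.allclose(y_odd, x_odd, atol=1e-4))  # True
```

For the 14x14 case the default width happens to be correct, but passing `s=(H, W)` makes the output size always match the input size, including odd widths.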
