The rFFT of a real signal of shape HxW has only Hx(W//2+1) independent components, because the spectrum of a real signal is conjugate-symmetric. For a real input feature of shape Bx14x14xC, torch.fft.rfft2(x, dim=(1, 2), norm='ortho') therefore yields a Bx14x8xC complex tensor; dim=(1, 2) selects the two spatial axes (H and W) to transform and leaves the batch and channel axes untouched. This is why the filter has shape 14x8xdimx2: h=14 and w=8=14//2+1 match the shape of the spectrum, and the trailing 2 stores the real and imaginary parts of each complex weight. In the inverse direction, a Bx14x8xC complex feature is ambiguous, since both a Bx14x14xC and a Bx14x15xC real feature have a spectrum of that shape (14//2+1 = 15//2+1 = 8). We therefore pass the output shape to irfft2 explicitly as s=(H, W)=(14, 14).
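The shape bookkeeping above can be checked directly. The following is a small sketch, not code from the repository; the batch size and channel count (B=2, C=4) are arbitrary values chosen for illustration.

```python
import torch

# Illustrative sizes (assumptions): batch 2, 14x14 spatial grid, 4 channels.
B, H, W, C = 2, 14, 14, 4
x = torch.randn(B, H, W, C)

# Real 2D FFT over the spatial axes (1, 2); the last transformed axis
# keeps only W // 2 + 1 = 8 frequency bins.
X = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
print(X.shape, X.dtype)

# A width-15 input has the same half-spectrum width: 15 // 2 + 1 == 8.
x15 = torch.randn(B, H, 15, C)
X15 = torch.fft.rfft2(x15, dim=(1, 2), norm='ortho')
print(X15.shape)

# Passing s tells irfft2 which real width to reconstruct.
y = torch.fft.irfft2(X, s=(H, W), dim=(1, 2), norm='ortho')
y15 = torch.fft.irfft2(X15, s=(H, 15), dim=(1, 2), norm='ortho')
print(y.shape, y15.shape)
```

Both spectra have shape Bx14x8xC, so only the s argument distinguishes the 14-wide reconstruction from the 15-wide one.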
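import torch
import torch.nn as nn

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable frequency-domain filter; the last axis holds (real, imag).
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)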
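The excerpt above stops at the constructor. As a rough sketch of how such a weight could be applied, assuming the filter is multiplied element-wise with the spectrum, a forward pass might look like the following. The class name GlobalFilterSketch and the forward method are illustrative assumptions, not the repository's code.

```python
import torch
import torch.nn as nn

class GlobalFilterSketch(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # (h, w) must match the rfft2 spectrum: h = H, w = W // 2 + 1.
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        # x: real tensor of shape (B, H, W, C)
        B, H, W, C = x.shape
        X = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')     # (B, H, W//2+1, C), complex
        weight = torch.view_as_complex(self.complex_weight)  # (h, w, dim), complex
        X = X * weight                                       # broadcast over the batch axis
        # s=(H, W) fixes the output width, which the spectrum alone does not determine.
        return torch.fft.irfft2(X, s=(H, W), dim=(1, 2), norm='ortho')

layer = GlobalFilterSketch(dim=4)
out = layer(torch.randn(2, 14, 14, 4))
print(out.shape)
```

Here the output has the same shape as the input, so the layer can be dropped in wherever a Bx14x14xC feature is expected.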
Thank you very much for your work. I have some questions: what is the meaning of "h=14, w=8" and of "s=(H, W), dim=(1, 2)"?