
support fftshift and ifftshift in pytorch #42075

Closed

Mon-ius opened this issue Jul 26, 2020 · 8 comments
Labels
module: complex Related to complex number support in PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@Mon-ius

Mon-ius commented Jul 26, 2020

Currently working on the IBM Power9 series clusters.

Cannot run such code:

x = torch.tensor(1, dtype=torch.complex64, device = 'cuda')

Error shows below:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-93-353b552465e0> in <module>
----> 1 x = torch.tensor(1, dtype=torch.complex64, device = 'cuda')

RuntimeError: Could not run 'aten::empty.memory_format' with arguments from the 'ComplexCPUTensorId' backend. 'aten::empty.memory_format' is only available for these backends: [CUDATensorId, SparseCPUTensorId, VariableTensorId, CPUTensorId, MkldnnCPUTensorId, SparseCUDATensorId].

cc @ezyang @anjali411 @dylanbespalko

@ngimel
Collaborator

ngimel commented Jul 26, 2020

What pytorch version are you using? The snippet above works on master.

@ngimel ngimel added module: complex Related to complex number support in PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jul 26, 2020
@Mon-ius
Author

Mon-ius commented Jul 26, 2020

1.4.0

  Operating System: Red Hat Enterprise Linux
       CPE OS Name: cpe:/o:redhat:enterprise_linux:7.6:GA:server
            Kernel: Linux 4.14.0-115.7.1.el7a.ppc64le
      Architecture: ppc64-le

@ngimel
Collaborator

ngimel commented Jul 27, 2020

Please update to the latest version.

@Mon-ius
Author

Mon-ius commented Jul 29, 2020

I tried the latest version, and it does support the complex type. But I couldn't find a function like np.fft.fftshift or np.fft.ifftshift. Is it already supported, or do I have to write one myself? I don't want to switch between numpy and pytorch a lot during training.
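For reference, np.fft.fftshift moves the zero-frequency component of an FFT output to the center of the array, and np.fft.ifftshift undoes that shift; a minimal NumPy illustration:

```python
import numpy as np

# Frequency bins for a length-4 FFT: the zero-frequency bin comes first.
freqs = np.fft.fftfreq(4)            # [0.0, 0.25, -0.5, -0.25]
shifted = np.fft.fftshift(freqs)     # [-0.5, -0.25, 0.0, 0.25], zero centered
restored = np.fft.ifftshift(shifted) # back to the original ordering
```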

@zasdfgbnm
Collaborator

fft support is tracked at #33152

@ArthDh

ArthDh commented Aug 16, 2020

@Mon-ius I couldn't find fftshift/ifftshift on master but this should work fine:

source: locuslab/pytorch_fft#12

import torch
import numpy as np

def roll_n(X, axis, n):
    # Circularly shift X by n positions along `axis` (a manual torch.roll).
    f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None)
                  for i in range(X.dim()))
    b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None)
                  for i in range(X.dim()))
    front = X[f_idx]
    back = X[b_idx]
    return torch.cat([back, front], axis)

def fftshift(X):
    # batch*channel*...*2
    real, imag = X.chunk(chunks=2, dim=-1)
    real, imag = real.squeeze(dim=-1), imag.squeeze(dim=-1)

    for dim in range(2, len(real.size())):
        real = roll_n(real, axis=dim, n=int(np.ceil(real.size(dim) / 2)))
        imag = roll_n(imag, axis=dim, n=int(np.ceil(imag.size(dim) / 2)))

    real, imag = real.unsqueeze(dim=-1), imag.unsqueeze(dim=-1)
    X = torch.cat((real,imag),dim=-1)
    return X

def ifftshift(X):
    # batch*channel*...*2
    real, imag = X.chunk(chunks=2, dim=-1)
    real, imag = real.squeeze(dim=-1), imag.squeeze(dim=-1)

    for dim in range(len(real.size()) - 1, 1, -1):
        real = roll_n(real, axis=dim, n=int(np.floor(real.size(dim) / 2)))
        imag = roll_n(imag, axis=dim, n=int(np.floor(imag.size(dim) / 2)))

    real, imag = real.unsqueeze(dim=-1), imag.unsqueeze(dim=-1)
    X = torch.cat((real, imag), dim=-1)

    return X

@ngimel
Collaborator

ngimel commented Aug 16, 2020

Thanks @ArthDh! I'd like to point out that torch has a torch.roll function that provides similar functionality to roll_n, likely with better performance.
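A sketch of what that could look like, assuming you want to mirror numpy's shift conventions (n // 2 forward, -(n // 2) back) and shift over all dimensions by default; the `dims` parameter and its default are assumptions here, not an existing API:

```python
import torch

def fftshift(x, dims=None):
    # Shift the zero-frequency component to the center of each dim.
    if dims is None:
        dims = tuple(range(x.dim()))
    shifts = tuple(x.size(d) // 2 for d in dims)
    return torch.roll(x, shifts=shifts, dims=tuple(dims))

def ifftshift(x, dims=None):
    # Inverse of fftshift; the two differ only for odd-length dims.
    if dims is None:
        dims = tuple(range(x.dim()))
    shifts = tuple(-(x.size(d) // 2) for d in dims)
    return torch.roll(x, shifts=shifts, dims=tuple(dims))
```

Unlike the roll_n version above, this works directly on complex tensors, with no need to carry real and imaginary parts in a trailing dimension.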

@ngimel ngimel changed the title Complex type error support fftshift and ifftshift in pytorch Aug 16, 2020
@ngimel
Collaborator

ngimel commented Aug 16, 2020

Closing in favor of #42175

@ngimel ngimel closed this as completed Aug 16, 2020

4 participants