
Use old and new fft in one program #49695

Closed
boeddeker opened this issue Dec 21, 2020 · 12 comments
Labels
module: fft, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@boeddeker
Contributor

boeddeker commented Dec 21, 2020

🐛 Bug

In 1.8.0.dev20201221+cpu the default torch.fft changed: it is now a module, not a callable function.

It is now very challenging to write code that works with nightly and older PyTorch versions.

To Reproduce

Steps to reproduce the behavior:

The call torch.fft(...) should not fail with TypeError: 'module' object is not callable:

try:
    import torch.fft   # only exists in PyTorch >= 1.7
except Exception:
    pass

def function_that_needs_new_torch_version():
    ...  # some code that uses torch.fft.fft

import torch

def function_that_still_uses_old_fft():
    torch.fft(...)

function_that_needs_new_torch_version()

Expected behavior

Support old and new fft in one program

Possible Solution

There is a python hack to support callable python modules:
https://stackoverflow.com/a/48100440/5766934

Environment

  • PyTorch Version (e.g., 1.0): 1.8.0.dev20201221+cpu
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.6
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

#42175 looks like the main issue for the fft change, so maybe @mruberry or @peterbell10 have an opinion on this.

cc @mruberry @peterbell10 @walterddr

@ngimel ngimel added the module: fft and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Dec 21, 2020
@ngimel
Collaborator

ngimel commented Dec 21, 2020

Hi @boeddeker, sorry you are facing this issue. This is the expected effect of introducing the torch.fft module. Unfortunately, the fft migration wiki does not have a solution to your problem: https://github.com/pytorch/pytorch/wiki/The-torch.fft-module-in-PyTorch-1.7.
cc @mruberry for possible workarounds.

@mruberry
Collaborator

Sorry you're experiencing this issue, @boeddeker.

The wiki article has a few snippets that are relevant here. For example, you can use this pattern:

import sys
import warnings

if "torch.fft" not in sys.modules:
    with warnings.catch_warnings(record=True) as w:
        ...  # calls the old torch.fft()
else:
    ...  # calls torch.fft.fft()
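A related way to branch, without checking `sys.modules`, is to inspect whether the `fft` attribute is a module (new API) or a plain function (old API). This is only a sketch, and the helper name `is_new_fft_api` is hypothetical:

```python
import types

def is_new_fft_api(fft_attr):
    # In PyTorch >= 1.7, torch.fft is a submodule (containing fft, ifft, ...);
    # in older versions it is a plain callable function.
    return isinstance(fft_attr, types.ModuleType)
```

With a predicate like this, code can dispatch once at import time instead of wrapping every call site in a try/except.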

Making the module callable was considered, but we wanted to remove the older torch.fft(), not continue to support it, and supporting a callable module would have required changes to TorchScript.

Would this pattern work for you?

Also, torch.fft.fft() is available starting in 1.7, so another pattern is to require the torch.fft module to be available if you only support PyTorch 1.7+.

@boeddeker
Contributor Author

Thank you for the suggestions. I was hoping for a longer deprecation period so the code could be ported part by part (and so the newest PyTorch version could be used without porting everything now, especially since a breaking change does not seem strictly necessary here).
Is the old function still available?
It would be nice to be able to write tests verifying that the ported code does the same as the old code.

Making the module callable was considered, but we wanted to remove the older torch.fft(), not continue to support it, and supporting a callable module would have required changes to TorchScript.

Was keeping the old torch.fft for those who don't use TorchScript considered?
The effective deprecation time in PyTorch is zero when the fft function is used in more than one place.

@mruberry
Collaborator

mruberry commented Dec 22, 2020

Is the old function still available?
Was it considered keeping the old torch.fft for those that don't use Torchscript?

No and no. We really want people to upgrade to complex tensors and to stop using float tensors that mimic complex tensors.

The deprecation time in pytorch is effectively zero

I'm sympathetic to this feeling and I'm sorry the change seems so abrupt. This was an unfortunate case where the function vs module conflict was challenging to resolve. The deprecation warning has been in nightlies for 4 months.

when you use the fft function at more than one position.

You can create a helper, like my_fft, that forwards the call to a common implementation. This helper can even replicate the old signature if you like. For example:

import torch

# Has a signature like the old torch.fft() but computes using torch.fft.fftn
def my_fft(input, signal_ndim, normalized=False):
    if signal_ndim < 1 or signal_ndim > 3:
        raise ValueError(
            f"signal_ndim out of range, was {signal_ndim}, "
            f"but expected a value between 1 and 3, inclusive")

    dims = (-1,)
    if signal_ndim == 2:
        dims = (-2, -1)
    if signal_ndim == 3:
        dims = (-3, -2, -1)

    norm = "backward"
    if normalized:
        norm = "ortho"

    return torch.view_as_real(
        torch.fft.fftn(torch.view_as_complex(input), dim=dims, norm=norm))
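The argument bookkeeping in my_fft can be checked without PyTorch installed; below is a minimal stdlib sketch of the same mapping (the helper name `ndim_to_params` is hypothetical):

```python
def ndim_to_params(signal_ndim, normalized=False):
    """Map old torch.fft() arguments to torch.fft.fftn keyword values."""
    if not 1 <= signal_ndim <= 3:
        raise ValueError(f"signal_ndim must be 1, 2, or 3, got {signal_ndim}")
    # The last signal_ndim dimensions are transformed, e.g. (-2, -1) for 2-D.
    dims = tuple(range(-signal_ndim, 0))
    # The old normalized=True corresponds to the new norm="ortho" mode.
    norm = "ortho" if normalized else "backward"
    return dims, norm
```

Factoring this mapping out makes it easy to unit-test the translation separately from the transform itself.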

@boeddeker
Contributor Author

I'm sympathetic to this feeling and I'm sorry the change seems so abrupt. This was an unfortunate case where the function vs module conflict was challenging to resolve. The deprecation warning has been in nightlies for 4 months.

I noticed the deprecation warning some time ago but ignored it, because I had observed issues with complex number support in PyTorch (missing CPU, GPU, or gradient support in some functions).
Maybe it is time to try native complex operations again.

You can create a helper, like my_fft that forwards the call to a common implementation. This helper can even replicate the old signature if you like. For example:

Thank you. That example is helpful for comparing the old code with the new code.

@gchanan
Contributor

gchanan commented Dec 28, 2020

@boeddeker please let us know if you run into missing complex number support -- we clearly don't want to push people to use incomplete APIs, so it would help us prioritize if you run into any issues.

@yihuajack

Thank you for your Python example. What about migrating fft usage in the PyTorch C++ API?

@mruberry
Collaborator

mruberry commented Jun 5, 2021

Thank you for your Python example. What about migrating fft usage in the PyTorch C++ API?

I believe you would use the same logic. All the functions used in the above example are available in the C++ API, too.

@AGenchev

AGenchev commented Oct 9, 2021

The example my_fft() above fails for me with: Tensor must have a last dimension of size 2.
The input tensor has shape [1, 1, 780, 1340] and signal_ndim is 3.
I also tried with signal_ndim=2 to reproduce the torch.rfft example from
https://pytorch.org/docs/0.4.1/torch.html#spectral-ops:

>>> x = torch.randn(5, 5)
>>> torch.rfft(x, 2, onesided=False).shape

like this:

x = torch.randn(5, 5)
my_fft(x, signal_ndim=2).shape

and it fails as well.

@peterbell10
Collaborator

@AGenchev the example you copied is for torch.rfft, which accepts a real tensor directly. However, my_fft mimics the old torch.fft, which expects a complex input represented as a real tensor whose last dimension has size 2, as the error states.
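The representation difference can be illustrated without PyTorch: the old-style "complex" input is a real array whose last axis holds (real, imag) pairs, which is roughly what torch.view_as_complex converts to native complex values. A minimal pure-Python sketch:

```python
# Two complex values in the old real-pair representation: shape (2, 2),
# where the last dimension of size 2 holds (real, imag) components.
pairs = [[1.0, 2.0], [3.0, -1.0]]

# The native complex equivalent (what view_as_complex would produce):
cplx = [complex(re, im) for re, im in pairs]
```

A plain real tensor like torch.randn(5, 5) has no such trailing size-2 axis, which is why my_fft rejects it.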

@AGenchev

AGenchev commented Oct 9, 2021

@peterbell10 my error, thank you for your answer.

@mruberry
Collaborator

Closing this issue because I believe the initial question has been addressed.

7 participants