
torch.fft tracking issue #42175

Closed
33 of 37 tasks
mruberry opened this issue Jul 28, 2020 · 24 comments
Labels
feature A request for a proper, new feature. high priority module: complex Related to complex number support in PyTorch module: fft module: numpy Related to numpy support, and also numpy compatibility of our operators triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mruberry
Collaborator

mruberry commented Jul 28, 2020

Tracking issue for tasks related to the torch.fft namespace, analogous to NumPy's numpy.fft namespace and SciPy's scipy.fft namespace.

PyTorch already has fft functions (fft, ifft, rfft, irfft, stft, istft), but they're inconsistent with NumPy and don't accept complex tensor inputs. The torch.fft namespace should be consistent with NumPy and SciPy where possible, and provide a path towards removing PyTorch's existing fft functions in the 1.8 release (deprecating them in 1.7).

While adding the torch.fft namespace infrastructure and deprecating PyTorch's current fft-related functions are the top priorities, PyTorch is also missing many helpful functions, listed below, which should (eventually) be added to the new namespace, too.

Tasks:

  • Write doc preamble to fft module
  • Write blogpost about the torch.fft module
  • Write tutorial
  • Create forum group

Completed:

cc @ezyang @gchanan @zou3519 @anjali411 @dylanbespalko @mruberry @rgommers @peterbell10

@mruberry mruberry added high priority feature A request for a proper, new feature. triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: complex Related to complex number support in PyTorch module: numpy Related to numpy support, and also numpy compatibility of our operators labels Jul 28, 2020
facebook-github-bot pushed a commit that referenced this issue Aug 6, 2020
Summary:
This PR creates a new namespace, torch.fft (torch::fft), and puts a single function, fft, in it. This function is a simplified version of NumPy's [numpy.fft.fft](https://numpy.org/doc/1.18/reference/generated/numpy.fft.fft.html?highlight=fft#numpy.fft.fft) that accepts no optional arguments. It is intended to demonstrate how to add and document functions in the namespace, and is not intended to deprecate the existing torch.fft function.

Adding this namespace was complicated by the existence of the torch.fft function in Python. Creating a torch.fft Python module makes this name ambiguous: does it refer to a function or a module? If the JIT didn't exist, a solution would have been to make torch.fft refer to a callable class that mimicked both the function and the module. The JIT, however, cannot understand this pattern. As a workaround, you must explicitly `import torch.fft` to access the torch.fft.fft function in Python:

```
import torch.fft
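# The explicit import above is the JIT workaround: it makes `torch.fft`
# resolve to the module, so `torch.fft.fft` is reachable.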

t = torch.randn(128, dtype=torch.cdouble)
torch.fft.fft(t)
```

See #42175 for future work. Another possible future PR is to get the JIT to understand torch.fft as a callable class so it need not be imported explicitly to be used.

Pull Request resolved: #41911

Reviewed By: glaringlee

Differential Revision: D22941894

Pulled By: mruberry

fbshipit-source-id: c8e0b44cbe90d21e998ca3832cf3a533f28dbe8d
@mruberry mruberry added the fft label Aug 6, 2020
@peterbell10
Collaborator

A collection of my current thoughts on this.

Top priority functions

The linked PR covers fft, ifft, rfft and irfft. These should be fully NumPy compatible (except for using dim instead of axis).
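
A quick sanity check of that compatibility (a sketch; the shape and dtype are arbitrary):

```python
import numpy as np
import torch

x = torch.randn(4, 8, dtype=torch.cdouble)
torch_out = torch.fft.fft(x, dim=0)         # PyTorch spells the axis argument `dim`
numpy_out = np.fft.fft(x.numpy(), axis=0)   # NumPy spells it `axis`
assert np.allclose(torch_out.numpy(), numpy_out)
```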

It may be easiest to start with forward for these and implement backward as separate follow-ups.

I haven't added tests for this yet, but I expect the current implementations to be differentiable because they still go through _fft_with_size, which has a derivative registered.

> Implement torch.stft() (is there a better NumPy or SciPy analogue for this and its inverse?)

There's no equivalent in NumPy or SciPy. The commit history reveals that the current interface was created for compatibility with librosa.stft, so I think it's reasonable to maintain compatibility there. To make it more NumPy-like we could add a dim argument, change the normalized argument into the norm mode string, and, instead of having a onesided argument, have two functions: stft and rstft. librosa has none of those arguments, so this shouldn't be a problem.

With that in mind, I suggest these signatures, putting the new arguments at the end of the librosa signature:

```
stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None,
     Tensor? window=None, int dim=-1, str? norm=None) -> Tensor

rstft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None,
      Tensor? window=None, int dim=-1, str? norm=None) -> Tensor
```

Multi-dim transforms

Wrapping _fft_with_size for multi-dim transforms will not be quite as simple, since _fft_with_size only supports up to 3 transform dimensions at a time. This comes from a limitation of cuFFT's transforms; MKL's fft goes up to 7. The simple route is to split transforms into batches of 3 dimensions at a time; otherwise, a deeper rewrite of _fft_with_size is needed.

Batching _fft_with_size will be quite expensive because _fft internally reshapes the arrays to squash the batch dimensions, so it will waste time copying data around. If there's a rush to get this functionality in for the next release, I could start with a simple wrapper and move to a more performant rewrite after. A sketch of the chunked approach follows.
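
A minimal sketch of the chunking idea, using the public fftn as a stand-in for _fft_with_size (it relies on the separability of the multi-dimensional DFT and assumes the default norm="backward", which applies no forward scaling, so chunks compose cleanly):

```python
import torch

def fftn_chunked(x, dims, max_dims=3):
    # The n-dimensional DFT is separable, so transforming disjoint chunks of
    # <= max_dims dimensions in sequence equals one big transform.
    for i in range(0, len(dims), max_dims):
        x = torch.fft.fftn(x, dim=dims[i:i + max_dims])
    return x

x = torch.randn(4, 4, 4, 4, dtype=torch.cdouble)
assert torch.allclose(fftn_chunked(x, (0, 1, 2, 3)),
                      torch.fft.fftn(x, dim=(0, 1, 2, 3)))
```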

Helper functions

fftfreq and fftshift are both quite handy and actually really simple: fftfreq is just arange + mul, and fftshift is just a roll.
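
Minimal sketches of both, following the NumPy conventions (hypothetical helpers, not the eventual implementations):

```python
import torch

def fftfreq_sketch(n, d=1.0):
    # Sample frequencies in standard FFT order: zero, positive, then negative.
    k = torch.arange(n)
    k = torch.where(k < (n + 1) // 2, k, k - n)  # wrap the upper half to negatives
    return k / (n * d)

def fftshift_sketch(x, dim=-1):
    # Move the zero-frequency term to the center: a roll by n // 2.
    return torch.roll(x, x.size(dim) // 2, dims=dim)
```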

In scipy.fft we also have a next_fast_len helper, which calculates the lowest multiple of the FFT implementation's prime radices above a target value. This is useful for convolution-like applications where zero-padding is acceptable. Though cuFFT doesn't implement radix-11 transforms, so in principle the result differs between CUDA and CPU tensors.
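
A brute-force sketch of such a helper (the radix set here is an assumption, chosen to be safe for both backends given the radix-11 caveat):

```python
def next_fast_len_sketch(target, radices=(2, 3, 5, 7)):
    # Smallest n >= target whose prime factors all come from `radices`.
    n = target
    while True:
        m = n
        for p in radices:
            while m % p == 0:
                m //= p
        if m == 1:
            return n
        n += 1
```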

@rgommers
Collaborator

> There's no equivalent in NumPy or SciPy.

@peterbell10 scipy.signal has stft and istft functions; they fit better there, among the spectral-analysis functions, than in scipy.fft.

@peterbell10
Collaborator

peterbell10 commented Aug 15, 2020

Thanks Ralf. It looks like the API there is somewhat similar but requires some translation:

n_fft -> nfft
win_length -> nperseg
hop_length -> noverlap = nperseg - hop_length
pad_mode -> boundary
onesided -> return_onesided

The SciPy function also has some functionality that PyTorch doesn't, like detrend, padded and fs.
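
For concreteness, a rough sketch of the translation above, using today's torch.stft signature (hedged: the padding and normalization defaults still differ between the two, so the outputs won't line up exactly):

```python
import scipy.signal
import torch

x = torch.randn(4096)
n_fft, hop_length, win_length = 512, 128, 512
window = torch.hann_window(win_length)

torch_out = torch.stft(x, n_fft, hop_length, win_length, window,
                       return_complex=True)

_, _, scipy_out = scipy.signal.stft(
    x.numpy(),
    window=window.numpy(),
    nperseg=win_length,                # from win_length
    noverlap=win_length - hop_length,  # from hop_length
    nfft=n_fft,                        # from n_fft
    boundary=None,                     # torch's center/pad_mode behave differently
    padded=False,
)
```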

In that case we need to decide between maintaining librosa compatibility and changing to the SciPy names.

@rgommers
Collaborator

Alternatively, don't put stft and istft in torch.fft. Despite the last two letters of the function names meaning "Fourier transform", these functions are really more similar to spectrogram, power spectral density, periodogram, etc. signal-processing functionality than to regular Fourier transforms. And given that there's no equivalent in numpy.fft or scipy.fft, why add them here?

@mruberry
Collaborator Author

mruberry commented Aug 16, 2020

Looks like there are two issues up for discussion:

  • How do we support multidim transforms?
  • What do we do with stft and istft?

@peterbell10, maybe we should talk about the former issue on Slack?

As for stft and istft, let's talk about our options:

  • "Max Compat"
    • keep torch.stft and torch.istft and document them as being consistent with librosa
    • add a signal namespace
    • create torch.signal.stft and torch.signal.istft consistent with SciPy
  • "Fewer namespaces"
    • same as above except we create torch.fft.stft instead of torch.signal.stft

I'm not sure how much interest there is in other functions in a signal namespace, although @rgommers has pointed out that windowing functions like the Hann window are in scipy.signal.windows. NumPy also has windowing functions directly under its namespace (although it uses the wrong name for the Hann window), like Torch does.
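
(For reference, the same symmetric 64-point Hann window under the three APIs mentioned; a quick illustration:)

```python
import numpy as np
import scipy.signal.windows
import torch

w_np = np.hanning(64)                           # NumPy's misnamed "hanning"
w_sp = scipy.signal.windows.hann(64, sym=True)  # SciPy's correctly named "hann"
w_pt = torch.hann_window(64, periodic=False)    # Torch's symmetric variant
assert np.allclose(w_np, w_sp) and np.allclose(w_np, w_pt.numpy())
```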

What are your thoughts? There is some potential for confusion in keeping both stft functions, but the signatures are distinct. Maybe calling torch.stft with SciPy-like arguments could not only throw an error but also redirect the user to the SciPy-compatible version? Or, if the SciPy-compatible stft can replace the librosa-compatible stft, maybe we can deprecate the librosa version?

@vincentqb, what are your thoughts from the torchaudio perspective?

@rgommers
Collaborator

> I'm not sure how much interest there is in other functions in a signal namespace,

I think scipy.signal is one of the less well-maintained submodules in SciPy, and (maybe more importantly) it contains fairly high-level functionality. I would not expect that kind of functionality to be in PyTorch itself. My mental picture is that a few SciPy modules depend only on NumPy and BLAS/LAPACK, and those are the most important ones to have in PyTorch: fft, linalg, sparse, special.

Functionality that builds on top of that (e.g. signal, optimize, stats, cluster, spatial) is probably better kept separate. Now, of course, there are a few bits and pieces, like stft and LBFGS, that cover the same functionality as those SciPy modules. I'd suggest that's not a strong reason to then create another namespace to make them match.

@mruberry
Collaborator Author

> I'm not sure how much interest there is in other functions in a signal namespace,
>
> I think scipy.signal is one of the less well-maintained submodules in SciPy, and (maybe more importantly) it contains fairly high-level functionality. I would not expect that kind of functionality to be in PyTorch itself. My mental picture is that a few SciPy modules depend only on NumPy and BLAS/LAPACK, and those are the most important ones to have in PyTorch: fft, linalg, sparse, special.
>
> Functionality that builds on top of that (e.g. signal, optimize, stats, cluster, spatial) is probably better kept separate. Now, of course, there are a few bits and pieces, like stft and LBFGS, that cover the same functionality as those SciPy modules. I'd suggest that's not a strong reason to then create another namespace to make them match.

Would you like to suggest a third option, then: update the librosa-compatible torch.stft and don't implement a SciPy-compatible stft? Sounds consistent to me.

@rgommers
Collaborator

> Would you like to suggest a third option, then: update the librosa-compatible torch.stft and don't implement a SciPy-compatible stft? Sounds consistent to me.

Yes, that seems best to me. Except I'm not sure stft even needs an update? If you mean putting it into torch.fft, I see pros and cons to that. If you want to put it in, let's just say in the docstring: "note: this function matches librosa.stft and differs from scipy.signal.stft".

@mruberry
Collaborator Author

> Would you like to suggest a third option, then: update the librosa-compatible torch.stft and don't implement a SciPy-compatible stft? Sounds consistent to me.
>
> Yes, that seems best to me. Except I'm not sure stft even needs an update? If you mean putting it into torch.fft, I see pros and cons to that. If you want to put it in, let's just say in the docstring: "note: this function matches librosa.stft and differs from scipy.signal.stft".

Right, sorry, that was ambiguous. I was going to suggest we leave its name as-is and keep its signature consistent with librosa, but allow it to take and return complex tensors.

@peterbell10
Collaborator

peterbell10 commented Aug 20, 2020

Quick question: cuFFT has some limited support for half-precision FFTs (only for power-of-2 lengths), but at the moment ComplexHalf doesn't seem to have any support in PyTorch. I notice c10 has some references to it, and there is even a torch.complex32 type accessible from Python, but I can't seem to create a tensor with that dtype.

In #43011 I currently implement the SciPy behavior, which is to promote half inputs to single precision. If there are future plans for complex half, it would be better to error out instead.

@ezyang
Contributor

ezyang commented Aug 20, 2020

We don't really support complex half right now, but in principle we could add support for it at some point in the future (probably CUDA only). You should plan as if, at some point in the future, we will add support for it.

@peterbell10
Collaborator

Okay, all the transforms in #43011 now raise an error for half and bfloat16.
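
Illustration of the agreed behavior (a sketch; the exact error message is a guess):

```python
import torch

t = torch.randn(16, dtype=torch.half)
try:
    torch.fft.fft(t)
except RuntimeError as e:
    print(e)  # e.g. "Unsupported dtype Half"; exact text may differ by backend
```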

@mruberry mruberry changed the title torch.fft namespace tracking issue torch.fft tracking issue Sep 9, 2020
facebook-github-bot pushed a commit that referenced this issue Dec 5, 2020
Summary:
Ref #42175

This removes the 4 deprecated spectral functions: `torch.{fft,rfft,ifft,irfft}`. `torch.fft` is also now imported by default.

The actual `at::native` functions are still used in `torch.stft`, so they can't be fully removed yet. But they will be once #47601 has been merged.

Pull Request resolved: #48594

Reviewed By: heitorschueroff

Differential Revision: D25298929

Pulled By: mruberry

fbshipit-source-id: e36737fe8192fcd16f7e6310f8b49de478e63bf0
peterbell10 added ten commits that referenced this issue, Dec 6-9, 2020
@mruberry mruberry reopened this Dec 10, 2020
@peterbell10
Collaborator

I've run some simple benchmarks comparing torch.fft transforms against mkl_fft, numpy, scipy, and cupy. I haven't sampled a huge number of shapes because the benchmarks take a fairly long time to run, but they do cover transforming over different dimensions, which is interesting because the resulting performance depends a lot more on how batching is handled.

Simple benchmark code (run under IPython, for the %timeit magics):

```python
import torch
import numpy as np
import scipy.fft
import cupy
import mkl_fft._numpy_fft
import mkl
import itertools

shape = (40, 40, 100)
#shape = (42, 14, 18)
c = torch.rand(*shape, dtype=torch.cdouble)
cn = c.numpy()
cc = c.cuda()
r = torch.rand(*shape, dtype=torch.double)
rn = r.numpy()
rc = r.cuda()

operators = ['fft', 'ifft', 'rfft', 'irfft', 'fftn', 'ifftn', 'rfftn', 'irfftn', 'hfft', 'ihfft']
results = []
def add_result(name, operator, dim, times):
    results.append((name, operator, dim, times.best, times.average, times.stdev))

for op, dim in itertools.product(operators, (0, 1, 2)):
    torch_fn = getattr(torch.fft, op)
    numpy_fn = getattr(np.fft, op)
    scipy_fn = getattr(scipy.fft, op)
    mkl_fn = getattr(mkl_fft._numpy_fft, op, None)
    cupy_fn = getattr(cupy.fft, op)
    name = f'{op} dim={dim}'
    if op.startswith('rfft') or op.startswith('ihfft'):
        x, xn, xc = r, rn, rc
    else:
        x, xn, xc = c, cn, cc

    xcp = cupy.array(xn)
    if op.endswith('n'):
        x = x.movedim(dim, -1)
        xn = np.moveaxis(xn, dim, -1)
        xc = xc.movedim(dim, -1)
        xcp = cupy.moveaxis(xcp, dim, -1)
        dim_kwargs = axis_kwargs = dict()
    else:
        axis_kwargs = {'axis': dim}
        dim_kwargs = {'dim': dim}

    print(name, 'multi-threaded')
    torch.set_num_threads(8)
    t = %timeit -o torch_fn(x, **dim_kwargs)
    add_result('torch multithreaded', op, dim, t)
    if mkl_fn is not None:
        mkl.set_num_threads(8)
        t = %timeit -o mkl_fn(xn, **axis_kwargs)
        add_result('mkl_fft multi threaded', op, dim, t)
    t = %timeit -o scipy_fn(xn, workers=8, **axis_kwargs)
    add_result('scipy multi threaded', op, dim, t)

    print(name, 'single threaded')
    torch.set_num_threads(1)
    t = %timeit -o torch_fn(x, **dim_kwargs)
    add_result('torch single threaded', op, dim, t)
    if mkl_fn is not None:
        mkl.set_num_threads(1)
        t = %timeit -o mkl_fn(xn, **axis_kwargs)
        add_result('mkl_fft single threaded', op, dim, t)
    t = %timeit -o numpy_fn(xn, **axis_kwargs)
    add_result('numpy', op, dim, t)
    t = %timeit -o scipy_fn(xn, **axis_kwargs)
    add_result('scipy single threaded', op, dim, t)

    print(name, 'cufft')
    t = %timeit -o torch_fn(xc, **dim_kwargs); torch.cuda.synchronize()
    add_result('torch cuda', op, dim, t)
    t = %timeit -o cupy_fn(xcp, **axis_kwargs); cupy.cuda.runtime.deviceSynchronize()
    add_result('cupy', op, dim, t)
```

The full results are available in the attached spreadsheets, but I'll summarise here.
fft-comparison-benchmarks.zip

For CUDA, torch.fft.* either performed similarly to cupy.fft.* or was as much as 2x faster. That's pretty good considering both are just calling cuFFT under the hood.

For single threaded CPU performance I compared against mkl_fft, numpy and scipy.

  • numpy was slower in all cases, at best around 1.5x slower than pytorch and at worst 3.9x slower.
  • scipy hovered around 1-1.5x slower in most cases. For the larger-shaped tensor, scipy is around 15-20% faster at ihfft on all but the last dimension. This is perhaps not surprising, though, as scipy.fft has custom kernels for hfft and ihfft whereas Intel MKL doesn't.
  • Perhaps most interesting is mkl_fft. mkl_fft is often a few microseconds (around 10% for the smaller tensor) faster for fft and ifft with dim=0 or dim=-1. I would attribute this to mkl_fft caching the transform descriptor between calls, whereas torch.fft recreates it each time. Not caching significantly simplifies the code, though, so it may not be worth the extra microseconds.
    On the other hand, many dim=1 transforms were 2-4x slower on mkl_fft than torch.fft, so the way I'm handling batching is clearly paying off.

For multithreaded CPU performance, I only compare against mkl_fft and scipy since numpy doesn't support multithreading.

  • scipy was up to 30% faster for the smaller single-dimensional transform on dim=1, probably because it doesn't require contiguous batch dimensions and so avoids an extra copy. However, this difference becomes negligible for multi-dimensional transforms or transforms over larger tensors. In most other cases, torch.fft was 2-3x faster.
  • mkl_fft is again, unsurprisingly, similar in the easy cases. It does still get a slight advantage from descriptor caching, though: I tested with 8 threads, and at that level plan creation took up to 40% of the transform time for the smaller tensor, but that was still only a few microseconds of overhead.
    However, in the harder cases torch.fft can actually be much faster than mkl_fft. I see 4-6x faster for the larger tensor on rfftn and irfftn; e.g. in one case mkl_fft takes 1.9 ms while torch.fft takes only 0.3 ms to do the same transform.

@mruberry
Collaborator Author

These results sound great, @peterbell10! Is there more for performance you'd like to do, or do you think we should declare victory here?

@peterbell10
Collaborator

I'm very happy with the results here. Beating mkl_fft and cupy at all, let alone by such a large margin in some cases, is more than I was really expecting, given we're using the same FFT libraries under the hood.

facebook-github-bot pushed a commit that referenced this issue Jan 6, 2021
Summary:
Ref #42175

This adds out argument support to all functions in the `torch.fft` namespace except for `fftshift` and `ifftshift` because they rely on `at::roll` which doesn't have an out argument version.

Note that there's no general way to do the transforms directly into the output, since both cuFFT and MKL FFT only support single batch dimensions. At a minimum, the output may need to be re-strided, which I don't think is normally expected of `out` arguments. So, on CPU this just copies the result into the out tensor. On CUDA, the normalization is changed to call `at::mul_out` instead of an in-place multiply.

If it's desirable, I could add a special case to transform directly into the output when `out.numel() == 0`, since there's no expectation to preserve strides in that case anyway. But that would lead to the slightly odd situation where `out` having the correct shape follows a different code path from `out.resize_(0)`.

Pull Request resolved: #49335

Reviewed By: mrshenli

Differential Revision: D25756635

Pulled By: mruberry

fbshipit-source-id: d29843f024942443c8857139a2abdde09affd7d6
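
Usage-wise this follows the standard out= convention (a quick sketch; if out already has the result's shape and dtype, no resize is needed):

```python
import torch

x = torch.randn(8, 16, dtype=torch.cdouble)
out = torch.empty_like(x)
torch.fft.fft(x, dim=-1, out=out)  # the result is written into `out`
assert torch.allclose(out, torch.fft.fft(x, dim=-1))
```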
hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this issue Jan 14, 2021 (same summary as #49335 above)
@mruberry
Collaborator Author

Although we're still working on a blog post, and documentation can always be improved, I believe this issue has addressed its principal goal: creating a torch.fft namespace analogous to NumPy's fft module. Great work @peterbell10! Going forward I think we'll track fft-related issues independently, or we should start a new tracker with a new goal.

@vadimkantorov
Contributor

vadimkantorov commented Mar 24, 2021

Is it okay that the old rfft supported a onesided argument, while the new one only implements onesided=True? For new code it's not a problem, but for porting it can cause some rewriting trouble. At the least, the docs should have a mini porting guide.
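
A hedged porting sketch for the two-sided case raised above: the old rfft with onesided=False corresponds to a full complex FFT of the real input (the old function also stacked real and imaginary parts in a trailing dimension, which torch.view_as_real reproduces):

```python
import torch

x = torch.randn(128)
half = torch.fft.rfft(x)  # new API: one-sided only, length n // 2 + 1
full = torch.fft.fft(x)   # stands in for the old onesided=False behavior
assert torch.allclose(full[: x.numel() // 2 + 1], half)
stacked = torch.view_as_real(full)  # old-style (..., 2) real/imag layout
```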

@mruberry
Collaborator Author

> Is it okay that the old rfft supported a onesided argument, while the new one only implements onesided=True? For new code it's not a problem, but for porting it can cause some rewriting trouble. At the least, the docs should have a mini porting guide.

That's a good point. Would you comment on issue #49637? That way it'll be clearer how impactful creating such a guide would be.
