torch.fft tracking issue #42175
Summary: This PR creates a new namespace, torch.fft (torch::fft), and puts a single function, fft, in it. This function is a simplified version of NumPy's [numpy.fft.fft](https://numpy.org/doc/1.18/reference/generated/numpy.fft.fft.html?highlight=fft#numpy.fft.fft) that accepts no optional arguments. It is intended to demonstrate how to add and document functions in the namespace, and is not intended to deprecate the existing torch.fft function.

Adding this namespace was complicated by the existence of the torch.fft function in Python. Creating a torch.fft Python module makes this name ambiguous: does it refer to a function or a module? If the JIT didn't exist, a solution to this problem would have been to make torch.fft refer to a callable class that mimicked both the function and the module. The JIT, however, cannot understand this pattern. As a workaround, it's required to explicitly `import torch.fft` to access the torch.fft.fft function in Python:

```python
import torch.fft

t = torch.randn(128, dtype=torch.cdouble)
torch.fft.fft(t)
```

See #42175 for future work. Another possible future PR is to get the JIT to understand torch.fft as a callable class so it need not be imported explicitly to be used.

Pull Request resolved: #41911
Reviewed By: glaringlee
Differential Revision: D22941894
Pulled By: mruberry
fbshipit-source-id: c8e0b44cbe90d21e998ca3832cf3a533f28dbe8d
A collection of my current thoughts on this.

**Top priority functions**

The linked PR covers
I haven't added tests for this yet, but I expect the current implementations to be differentiable because it still goes through
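A minimal way to check differentiability once tests are added is `torch.autograd.gradcheck`, which compares analytic and numerical Jacobians in double precision. This is only a hedged sketch, assuming a build where `torch.fft` supports complex autograd:

```python
import torch

# Hedged sketch: verify torch.fft.fft is differentiable via gradcheck.
# gradcheck needs double precision, so complex inputs use torch.cdouble.
x = torch.randn(8, dtype=torch.cdouble, requires_grad=True)
ok = torch.autograd.gradcheck(torch.fft.fft, (x,))
print(ok)  # True when the numerical and analytical Jacobians agree
```

`gradgradcheck` can be called the same way to exercise second-order gradients.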
There's no equivalent in NumPy or SciPy. The commit history reveals the current interface was created for compatibility with librosa. With that in mind, I suggest these signatures, putting the new arguments at the end of the
**Multi-dim transforms**

**Wrapping**

**Batching**

**Helper functions**
In

@peterbell10
Thanks Ralf. It looks like the API there is somewhat similar but requires some translation:
The SciPy function also has some other functionality that PyTorch doesn't have, like

In that case we need to decide between maintaining
Alternatively, don't put
Looks like there are two issues up for discussion:
@peterbell10, maybe we should talk about the former issue on Slack? As for stft and istft, let's talk about our options:
I'm not sure how much interest there is in other functions in the signal namespace, although @rgommers has pointed out that windowing functions like the Hann window are in scipy.signal.windows. NumPy also has windowing functions directly under its namespace (although it uses the wrong name for the Hann window), like Torch does. What are your thoughts?

There is some potential for confusion in keeping both stft functions, but the signatures are distinct. Maybe calling torch.stft with SciPy-like arguments could not only throw an error but also redirect the user to the SciPy-compatible version? Or, if the SciPy-compatible stft can replace the librosa-compatible stft, maybe we can deprecate the librosa version?

@vincentqb, what are your thoughts from the torchaudio perspective?
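To make the signature translation concrete, here is a hedged, NumPy-only sketch of the framing that both APIs describe. The helper name `stft_frames` is hypothetical; the key mapping is that librosa/torch-style `n_fft` and `hop_length` correspond to SciPy's `nperseg` and `nperseg - noverlap`:

```python
import numpy as np

# Hypothetical helper showing the framing both STFT styles share.
# librosa/torch.stft: n_fft, hop_length; scipy.signal.stft: nperseg, noverlap,
# related by hop_length = nperseg - noverlap.
def stft_frames(x, n_fft=256, hop_length=64, window=None):
    if window is None:
        window = np.hanning(n_fft)
    n_frames = 1 + (len(x) - n_fft) // hop_length
    frames = np.stack([x[i * hop_length : i * hop_length + n_fft] * window
                       for i in range(n_frames)])
    return np.fft.rfft(frames, axis=-1)  # onesided output, n_fft//2 + 1 bins

x = np.random.randn(1024)
S = stft_frames(x, n_fft=256, hop_length=64)
print(S.shape)  # (13, 129): (n_frames, n_fft//2 + 1)
```

Everything beyond this framing (padding/centering, boundary handling, overlap-add inversion) is where the two interfaces actually diverge.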
I think

Functionality that builds on top of that (e.g.
A third option, then, would be to update the librosa-compatible torch.stft and not implement a SciPy-compatible stft? Sounds consistent to me.
Yes, that seems best to me. Except I'm not sure if
Right, sorry, that was ambiguous. I was going to suggest we leave its name as is and keep its signature consistent with librosa, but allow it to take and return complex tensors.
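For reference, a sketch of what that would look like from the user side, assuming a build where torch.stft has grown a `return_complex` flag (it keeps the librosa-style `n_fft`/`hop_length` signature but returns a complex tensor):

```python
import torch

# Hedged sketch: librosa-style signature, complex output via return_complex.
signal = torch.randn(1024)
window = torch.hann_window(256)
spec = torch.stft(signal, n_fft=256, hop_length=64, window=window,
                  return_complex=True)
print(spec.dtype, spec.shape)  # complex64, (n_fft//2 + 1, n_frames)
```

The old real-valued layout, if anything still needs it, is recoverable with `torch.view_as_real(spec)`.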
Quick question. In #43011 I currently implement the SciPy behavior, which is to promote half inputs to single precision. If there are future plans for complex half, then it would be better to error out instead.
We don't really support complex half right now, but in principle we could add support for it at some point in the future (probably CUDA only). You should plan as if, at some point in the future, we will add support for it.
Okay, all the transforms in #43011 now raise an error for half and bfloat16.
Summary: Ref #42175

This removes the 4 deprecated spectral functions: `torch.{fft,rfft,ifft,irfft}`. `torch.fft` is also now imported by default. The actual `at::native` functions are still used in `torch.stft`, so they can't be fully removed yet, but they will be once #47601 has been merged.

Pull Request resolved: #48594
Reviewed By: heitorschueroff
Differential Revision: D25298929
Pulled By: mruberry
fbshipit-source-id: e36737fe8192fcd16f7e6310f8b49de478e63bf0
Fixes #42175 (comment) [ghstack-poisoned]
I've run some simple benchmarks comparing `torch.fft` against NumPy, SciPy, `mkl_fft`, and CuPy.

**Simple benchmark code**

```python
# IPython benchmark script (uses %timeit magics).
import torch
import numpy as np
import scipy.fft
import cupy
import mkl_fft._numpy_fft
import mkl
import itertools

shape = (40, 40, 100)
#shape = (42, 14, 18)

c = torch.rand(*shape, dtype=torch.cdouble)
cn = c.numpy()
cc = c.cuda()
r = torch.rand(*shape, dtype=torch.double)
rn = r.numpy()
rc = r.cuda()

operators = ['fft', 'ifft', 'rfft', 'irfft', 'fftn', 'ifftn', 'rfftn',
             'irfftn', 'hfft', 'ihfft']
results = []

def add_result(name, operator, dim, times):
    results.append((name, operator, dim, times.best, times.average, times.stdev))

for op, dim in itertools.product(operators, (0, 1, 2)):
    torch_fn = getattr(torch.fft, op)
    numpy_fn = getattr(np.fft, op)
    scipy_fn = getattr(scipy.fft, op)
    mkl_fn = getattr(mkl_fft._numpy_fft, op, None)
    cupy_fn = getattr(cupy.fft, op)

    name = f'{op} dim={dim}'

    # Real-input transforms take the real tensors; the rest take complex.
    if op.startswith('rfft') or op.startswith('ihfft'):
        x, xn, xc = r, rn, rc
    else:
        x, xn, xc = c, cn, cc
    xcp = cupy.array(xn)

    if op.endswith('n'):
        # n-dimensional transforms: move the benchmarked dim to the end.
        x = x.movedim(dim, -1)
        xn = np.moveaxis(xn, dim, -1)
        xc = xc.movedim(dim, -1)
        xcp = cupy.moveaxis(xcp, dim, -1)
        dim_kwargs = axis_kwargs = dict()
    else:
        axis_kwargs = {'axis': dim}
        dim_kwargs = {'dim': dim}

    print(name, 'multi-threaded')
    torch.set_num_threads(8)
    t = %timeit -o torch_fn(x, **dim_kwargs)
    add_result('torch multithreaded', op, dim, t)
    if mkl_fn is not None:
        mkl.set_num_threads(8)
        t = %timeit -o mkl_fn(xn, **axis_kwargs)
        add_result('mkl_fft multithreaded', op, dim, t)
    t = %timeit -o scipy_fn(xn, workers=8, **axis_kwargs)
    add_result('scipy multithreaded', op, dim, t)

    print(name, 'single threaded')
    torch.set_num_threads(1)
    t = %timeit -o torch_fn(x, **dim_kwargs)
    add_result('torch single threaded', op, dim, t)
    if mkl_fn is not None:
        mkl.set_num_threads(1)
        t = %timeit -o mkl_fn(xn, **axis_kwargs)
        add_result('mkl_fft single threaded', op, dim, t)
    t = %timeit -o numpy_fn(xn, **axis_kwargs)
    add_result('numpy', op, dim, t)
    t = %timeit -o scipy_fn(xn, **axis_kwargs)
    add_result('scipy single threaded', op, dim, t)

    print(name, 'cufft')
    t = %timeit -o torch_fn(xc, **dim_kwargs); torch.cuda.synchronize()
    add_result('torch cuda', op, dim, t)
    t = %timeit -o cupy_fn(xcp, **axis_kwargs); cupy.cuda.runtime.deviceSynchronize()
    add_result('cupy', op, dim, t)
```

The full results are available in some spreadsheets, but I'll summarise here. For cuda,

For single threaded CPU performance I compared against
For multithreaded CPU performance, I only compare against
These results sound great, @peterbell10! Is there more performance work you'd like to do, or do you think we should declare victory here?
I'm very happy with the results here. Beating
Summary: Ref #42175

This adds out argument support to all functions in the `torch.fft` namespace except for `fftshift` and `ifftshift`, because they rely on `at::roll`, which doesn't have an out argument version.

Note that there's no general way to do the transforms directly into the output, since both cufft and mkl-fft only support single batch dimensions. At a minimum, the output may need to be re-strided, which I don't think is expected from `out` arguments normally. So, on CPU this just copies the result into the out tensor. On CUDA, the normalization is changed to call `at::mul_out` instead of an in-place multiply.

If it's desirable, I could add a special case to transform into the output when `out.numel() == 0`, since there's no expectation to preserve the strides in that case anyway. But that would lead to the slightly odd situation where `out` having the correct shape follows a different code path from `out.resize_(0)`.

Pull Request resolved: #49335
Reviewed By: mrshenli
Differential Revision: D25756635
Pulled By: mruberry
fbshipit-source-id: d29843f024942443c8857139a2abdde09affd7d6
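From the Python side, the out= support described above looks like the sketch below (assuming a build where this PR has landed). The result is computed and then copied into the preallocated tensor:

```python
import torch

# Hedged sketch of out= support in torch.fft: the returned tensor
# aliases the preallocated out argument.
t = torch.randn(64, dtype=torch.cdouble)
out = torch.empty(64, dtype=torch.cdouble)
result = torch.fft.fft(t, out=out)
print(result is out)  # the return value is the out tensor itself
```

As the summary notes, on CPU this is a copy rather than a transform directly into `out`, so it saves an allocation but not the intermediate.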
Although we're still working on a blog post and documentation can always be improved, I believe this issue has addressed its principal goal: creating a torch.fft namespace analogous to NumPy's fft module. Great work @peterbell10! Going forward I think we'll either track fft-related issues independently or start a new tracker with a new goal.
Is it okay that the old rfft supported the onesided argument while the new one only implements onesided=True? For new code it's not a problem, but for porting it can cause some rewriting trouble. At the least, the docs should have a mini porting guide.
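As a tiny illustration of what such a porting guide might contain, here is a hedged sketch (the old calls appear only in comments, since they were removed). The old torch.rfft returned a real tensor with a trailing dimension of size 2; the new torch.fft.rfft returns a complex tensor, and `torch.view_as_real` recovers the old layout:

```python
import torch

x = torch.randn(128)

# old: torch.rfft(x, 1)  (onesided=True, the default)
new_onesided = torch.view_as_real(torch.fft.rfft(x))

# old: torch.rfft(x, 1, onesided=False) -- the new rfft has no
# onesided=False; the full spectrum comes from torch.fft.fft on
# the real input instead.
new_twosided = torch.view_as_real(torch.fft.fft(x))

print(new_onesided.shape, new_twosided.shape)  # (65, 2) (128, 2)
```

The onesided=False case is the one that needs an actual rewrite, which is presumably what a porting guide would call out.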
That's a good point. Would you comment on this issue: #49637? That way it'll be clearer how impactful creating such a guide would be.
Tracking issue for tasks related to the torch.fft namespace, analogous to NumPy's numpy.fft namespace and SciPy's scipy.fft namespace.
PyTorch already has fft functions (fft, ifft, rfft, irfft, stft, istft), but they're inconsistent with NumPy and don't accept complex tensor inputs. The torch.fft namespace should be consistent with NumPy and SciPy where possible, plus provide a path towards removing PyTorch's existing fft functions in the 1.8 release (deprecating them in 1.7).
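A minimal consistency check along these lines might look like the following sketch, which simply compares torch.fft.fft against numpy.fft.fft on the same complex input:

```python
import numpy as np
import torch

# Hedged parity check: torch.fft should agree with numpy.fft on
# the same input, per the goal stated above.
x = np.random.randn(64) + 1j * np.random.randn(64)
expected = np.fft.fft(x)
actual = torch.fft.fft(torch.from_numpy(x)).numpy()
print(np.allclose(actual, expected))  # True
```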
While adding the torch.fft namespace infrastructure and deprecating PyTorch's current fft-related functions are the top priorities, PyTorch is also missing many helpful functions, listed below, which should (eventually) be added to the new namespace, too.
Tasks:
Completed:
- `_fft_with_size` to handle transforming arbitrary dimensions without transposing/cloning (Improve torch.fft n-dimensional transforms #46911 still requires cloning, but this is required for best performance)
- `gradgradcheck` passes

cc @ezyang @gchanan @zou3519 @anjali411 @dylanbespalko @mruberry @rgommers @peterbell10