Update torch.rfft to torch.fft.rfft and complex tensor #941
Conversation
def rfft(input: torch.Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> torch.Tensor:
    # see: https://pytorch.org/docs/master/fft.html#torch.fft.rfft
    return torch._C._fft.fft_rfft(input, n, dim, norm)
I wrote a response in the document you've shared about why this mapping is more complicated. I can probably write a mapping from torch.rfft to torch.fft.rfft for you, if you like, but I'm not sure that's actually what you want to do. torch.fft.rfft returns a complex tensor, for example, so switching to it doesn't make a lot of sense unless you're also using complex tensors.
For now you may just want to suppress the warning that torch.rfft throws?
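One way to do that, sketched with a hypothetical stand-in for the deprecated call (suppressing the warning that torch.rfft emits would go through the same warnings machinery):

```python
import warnings

def deprecated_rfft(x):
    # hypothetical stand-in for a call that emits a deprecation warning,
    # the way the old torch.rfft does
    warnings.warn("rfft is deprecated; use torch.fft.rfft", UserWarning)
    return list(x)

# silence only this warning category, and only inside this block
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    out = deprecated_rfft([1.0, 2.0])
```

The suppression is scoped to the `with` block, so other callers still see the deprecation notice.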
Or you could use torch.view_as_real at the end, although I agree with @mruberry that it would be kind of wasteful to use torch.fft.rfft before we migrate audio to start using complex tensors.
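To illustrate the point, here is a pure-Python sketch (cmath stands in for the tensors; rfft_naive and the list-based view_as_real are illustrative helpers, not torchaudio code): torch.fft.rfft returns complex values, and a view_as_real-style conversion recovers the trailing (real, imag) pair layout that the old torch.rfft produced.

```python
import cmath

def rfft_naive(x):
    # naive onesided DFT of a real signal: bins 0 .. n//2, complex-valued,
    # mirroring the shape of torch.fft.rfft's output
    n = len(x)
    return [sum(x[k] * cmath.exp(-2j * cmath.pi * j * k / n) for k in range(n))
            for j in range(n // 2 + 1)]

def view_as_real(zs):
    # analogue of torch.view_as_real: one complex value -> a [real, imag] pair
    return [[z.real, z.imag] for z in zs]

spec = rfft_naive([1.0, 2.0, 3.0, 4.0])   # complex bins, like torch.fft.rfft
pairs = view_as_real(spec)                # old torch.rfft-style layout
```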
I wrote a response in the document you've shared about why this mapping is more complicated.

I looked at your response and it was very helpful. However, that document was for general cases, while here we are dealing with the migration of specific use cases, and I think the concerns brought up there do not apply. Here are the reasons.

All the torch.rfft usages in torchaudio share these properties:
- Inputs are always 2D and onesided=True. This means we can simply migrate to torch.fft.rfft; we do not need to consider torch.fft.rfftn.
- No normalization is required. We can pass norm=None to torch.fft.rfft, which is the default, so in my code change it is omitted.
- The complex values are immediately used to compute power. Look at the changes in functional.py and kaldi.py: rfft is used as a means of computing the power of the input signal in the frequency domain. Therefore, with the appropriate changes (like using .abs on the complex dtype), no computation is wasted and no unnecessary computation is introduced.

I can probably write a mapping from torch.rfft to torch.fft.rfft

For torchaudio's use case, the general mapping from torch.rfft to torch.fft.rfft is not necessary.
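The last bullet can be checked numerically (a pure-Python sketch with a complex scalar standing in for one spectrum bin; the variable names are illustrative): squaring the magnitude of a complex value performs the same arithmetic as squaring and summing the old (real, imag) pair layout, so no extra work is introduced.

```python
z = complex(3.0, -4.0)                    # one frequency bin from rfft

# old layout: last dim holds (real, imag), power via pow(2).sum(-1)
power_pairs = z.real ** 2 + z.imag ** 2

# complex dtype: power via .abs().pow(2)
power_abs = abs(z) ** 2
```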
Note that this compatibility module is not for migrating from torch.rfft to torch.fft.rfft, but for TorchScript compatibility of torch.fft.rfft (and others coming up) while avoiding an explicit import of torch.fft.
If there were a function that computes the power spectrum (which is real-valued) directly, then we would not need to use torch.fft.rfft; otherwise, the use of the complex dtype here as an intermediate representation makes sense and is not wasteful.
Interesting. Thanks for the additional context, @mthrok. Would it make sense, then, to limit the signature of this function or add a comment explaining that it only does a limited translation from torch.rfft to torch.fft.rfft?
Also, I'm not sure how many versions of PyTorch torchaudio supports, but torch.fft.rfft will only be available in PyTorch 1.7+, so previous versions of PyTorch will still need to call torch.rfft.
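If older PyTorch versions did have to be supported, the availability check could be gated on the version string, along these lines (a hedged sketch with hypothetical helper names, not torchaudio code):

```python
def parse_version(v: str):
    # keep only the numeric release part, e.g. "1.8.0a0+edac406" -> (1, 8, 0)
    core = v.split("+")[0].split("a")[0]
    return tuple(int(p) for p in core.split(".") if p.isdigit())

def has_torch_fft_rfft(torch_version: str) -> bool:
    # torch.fft.rfft is only available in PyTorch 1.7+
    return parse_version(torch_version) >= (1, 7)
```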
Would it make sense, then, to limit the signature of this function or add a comment explaining that it only does a limited translation from torch.rfft to torch.fft.rfft?

I gave it some thought, and we could do that, but I think having the same signature as torch.fft.rfft has advantages when it comes to maintainability:
- If someone wants to add new functionality that uses rfft, they can simply use the same signature as torch.fft.rfft, which frees them from having to consider the relationship between our abstraction function and the other torchaudio functions that use it. (In short, the person working on the new feature can use this abstraction as a drop-in replacement.)
- Similarly, in the future, once PyTorch is done with the migration, we can simply replace the abstraction function with the actual PyTorch implementation (torchaudio._internal.fft.rfft -> torch.fft.rfft).

Also, I'm not sure how many versions of PyTorch torchaudio supports, but torch.fft.rfft will only be available in PyTorch 1.7+, so previous versions of PyTorch will still need to call torch.rfft.

That is okay for domain libraries. We clearly state that domain libraries are expected/tested to work with the version of PyTorch released at the same time, so all work on the master branch targets the master version (or the next stable release) of PyTorch.
@@ -2082,7 +2083,7 @@ def _measure(
if boot_count >= 0 \
    else measure_smooth_time_mult
_d = complex_norm(_dftBuf[spectrum_start:spectrum_end])
I'd keep the norm/abs change separate. In a prior attempt there was a performance regression, #747.
@vincentqb Can you elaborate on why the regression in torch.norm, which is not used in this PR, is a reason to discourage the use of torch.abs? Are they using the same implementation under the hood?
I'll let @anjali411 @mruberry comment on how torch.abs and torch.norm are related, if they are. If they are, we would end up with a performance regression again. We could add some performance tests/checks, manually or automatically, to catch performance changes. However, my suggestion is simply to decouple the changes about torch.*fft* from those about the use of .abs() so that, in the event we do get a regression, we can easily track it and revert it.
Datapoint: torch.abs and torch.norm have separate implementations.
I think it should be perfectly fine to use torch.abs() here. complex_norm is not even using torch.norm: https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L424. In fact, it's using three kernels (pow, sum, pow), so torch.abs() should be a strict improvement here.
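The three-kernel computation and a single abs can be compared directly (a pure-Python stand-in for the tensor ops; the list comprehensions play the role of the elementwise kernels):

```python
vals = [complex(1.0, 2.0), complex(-0.5, 0.25)]   # sample spectrum bins

# complex_norm-style: pow, sum, pow -- three elementwise passes
three_kernel = [(z.real ** 2 + z.imag ** 2) ** 0.5 for z in vals]

# complex dtype: a single abs computes the same magnitude
one_kernel = [abs(z) for z in vals]
```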
Datapoint: torch.abs and torch.norm have separate implementations.

Good to know :)

I think it should be perfectly fine to use torch.abs() here. complex_norm is not even using torch.norm https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L424. In fact, it's using three kernels (pow, sum, pow), so torch.abs() should be a strict improvement here.

We had complex_norm using torch.norm, but this led to a speed regression, so we reverted to the current implementation with three kernels. Because of this, I want to avoid a performance regression again. I'm glad to know this would launch only one kernel: has the performance difference been tested and compared in this case?

In any case, let's just separate concerns and move this to a separate pull request so we don't block the rest on this :)
x = torch.randn(10, 20, 2)

In [7]: def fn(x):
   ...:     t0 = time.time()
   ...:     o = x.pow(2.0).sum(-1).pow(0.5)
   ...:     t1 = time.time()
   ...:     print(t1 - t0)
   ...:

In [8]: fn(x)
0.00037217140197753906

In [10]: y = torch.view_as_complex(x).contiguous()

In [11]: fn(y)
0.0001480579376220703

There's a significant performance gain (as expected), so I think we should switch to torch.abs.
Thanks @anjali411 !
In addition to that, I benchmarked the exact code path that @vincentqb suggested and observed that abs is 10x faster on CPU and 2x faster on GPU.

| Method | CPU | GPU |
|---|---|---|
| t.abs() | 100 loops, best of 5: 1.34 msec per loop | 100 loops, best of 5: 28.2 usec per loop |
| complex_norm(view_as_real(t)) | 100 loops, best of 5: 14.9 msec per loop | 100 loops, best of 5: 60.5 usec per loop |

PyTorch: 1.8.0a0+edac406
torchaudio: 0.8.0a0+0076ab0
code
OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100));
""" """
t.abs();
"""
OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
import torchaudio.functional;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100));
""" """
torchaudio.functional.complex_norm(torch.view_as_real(t));
"""
OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100)).to('cuda');
""" """
t.abs();
"""
OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
import torchaudio.functional;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100)).to('cuda');
""" """
torchaudio.functional.complex_norm(torch.view_as_real(t));
"""
Thanks for checking!
torchaudio/compliance/kaldi.py (outdated)
power_spectrum = fft.pow(2).sum(2).unsqueeze(1)  # size (m, 1, padded_window_size // 2 + 1)
same
complex_norm(
    _cepstrum_Buf[cepstrum_start:cepstrum_end],
    power=2.0)))
result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))
same
# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.pow(2).sum(2), epsilon).log()  # size (m, padded_window_size // 2 + 1)
power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log()  # size (m, padded_window_size // 2 + 1)
same
Force-pushed from 39edd48 to ac02f9f
complex number stuff looks good
LGTM. Any reason why this is marked as draft?
Update torch.rfft (deprecated) to use the equivalent torch.fft.rfft, but without importing torch.fft. Also, update the power computation of complex-valued tensors to use complex-type tensors.