Update torch.rfft to torch.fft.rfft and complex tensor #941

Merged (1 commit) on Nov 5, 2020

Conversation

@mthrok (Collaborator) commented Oct 8, 2020

Update the deprecated torch.rfft to use the equivalent of torch.fft.rfft, but without importing torch.fft.
Also, update the power computation of complex-valued tensors to use complex-dtype tensors.


from typing import Optional
import torch

def rfft(input: torch.Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> torch.Tensor:
    # see: https://pytorch.org/docs/master/fft.html#torch.fft.rfft
    return torch._C._fft.fft_rfft(input, n, dim, norm)
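
For context, a minimal sanity check (not part of the PR, assuming PyTorch 1.7+): the shim should match torch.fft.rfft and, unlike the deprecated torch.rfft, return a complex-dtype tensor.

# Illustrative check only; `rfft` here is the shim defined above.
import torch
import torch.fft  # imported only for the comparison; the shim itself avoids this import

x = torch.randn(2, 1024)
out = rfft(x)
assert out.is_complex()
assert torch.allclose(torch.view_as_real(out), torch.view_as_real(torch.fft.rfft(x)))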
@mruberry (Contributor) commented Oct 9, 2020

I wrote a response in the document you've shared about why this mapping is more complicated. I can probably write a mapping from torch.rfft to torch.fft.rfft for you, if you like, but I'm not sure that's actually what you want to do. torch.fft.rfft returns a complex tensor, for example, so switching to it doesn't make a lot of sense unless you're also using complex tensors.

For now you may just want to suppress the warning that torch.rfft throws?
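
(For illustration only, not part of the original comment: one way to locally silence the torch.rfft deprecation warning, assuming it is surfaced through Python's warnings machinery.)

import warnings

import torch

x = torch.randn(2, 1024)
with warnings.catch_warnings():
    # silence the torch.rfft deprecation warning for this call only
    warnings.simplefilter("ignore")
    spec = torch.rfft(x, signal_ndim=1, onesided=True)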

A reviewer commented:

or you could use torch.view_as_real at the end, although I agree with @mruberry that it would be kind of wasteful to use torch.fft.rfft before we migrate audio to start using complex tensors.
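
(Illustrative sketch, not from the PR: recovering the old (..., 2) real layout from torch.fft.rfft via torch.view_as_real, assuming PyTorch 1.7+.)

import torch
import torch.fft

x = torch.randn(2, 1024)
old = torch.rfft(x, signal_ndim=1, onesided=True)    # (..., freq, 2) real-valued output
new = torch.view_as_real(torch.fft.rfft(x))          # same layout, via the complex API
assert torch.allclose(old, new)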

@mthrok (Collaborator, Author) commented Oct 9, 2020

@mruberry, @anjali411

I wrote a response in the document you've shared about why this mapping is more complicated.

I looked at your response and it was very helpful. However, that document was written for the general case, whereas here we are dealing with the migration of specific use cases, and I think the concerns brought up do not apply. Here are the reasons:

All the torch.rfft usages in torchaudio share the following properties:

  1. Inputs are always 2D and onesided=True.
    This means we can simply migrate to torch.fft.rfft, and that is sufficient. We do not need to consider torch.fft.rfftn.
  2. No normalization is required.
    We can pass norm=None to torch.fft.rfft, which is the default, so it is omitted in my code change.
  3. The complex values are immediately used to compute power.
    Look at the changes in functional.py and kaldi.py. rfft is used as a means of computing the power of the input signal in the frequency domain. Therefore, with the appropriate changes (like using .abs() on the complex dtype), no computation is wasted and no unnecessary computation is introduced (see the sketch after this list).
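
As a sketch of what this migration looks like for the power computation (illustrative only, assuming PyTorch 1.7+ and a 2D real input as described above):

import torch
import torch.fft

waveform = torch.randn(1, 400)

# old path: torch.rfft returns a (..., freq, 2) real tensor; power is pow(2).sum(-1)
power_old = torch.rfft(waveform, signal_ndim=1, onesided=True).pow(2).sum(-1)

# new path: torch.fft.rfft returns a complex tensor; power is abs().pow(2)
power_new = torch.fft.rfft(waveform).abs().pow(2.)

assert torch.allclose(power_old, power_new)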

I can probably write a mapping from torch.rfft to torch.fft.rfft

For torchaudio's use case, the general mapping from torch.rfft to torch.fft.rfft is not necessary.

@mthrok (Collaborator, Author) commented:

Note that this compatibility module is not for migration from torch.rfft to torch.fft.rfft, but for TorchScript compatibility with torch.fft.rfft (and others coming up) while avoiding an explicit import of torch.fft.
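
(A minimal sketch of the intended usage, assuming the shim lands at torchaudio._internal.fft as referenced later in this thread: callers can be scripted without importing torch.fft themselves.)

import torch
from torchaudio._internal import fft  # assumed import path for the compatibility module discussed here

@torch.jit.script
def power_spectrum(waveform: torch.Tensor) -> torch.Tensor:
    # scriptable without an explicit `import torch.fft` at the call site
    return fft.rfft(waveform).abs().pow(2.)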

@mthrok (Collaborator, Author) commented:

If there were a function that computes the power spectrum (which is real-valued) directly, then we would not need torch.fft.rfft; otherwise, using the complex dtype here as an intermediate representation makes sense and is not wasteful.

@mruberry (Contributor) commented:

Interesting. Thanks for the additional context, @mthrok. Would it make sense, then, to limit the signature of this function or add a comment explaining that it only does a limited translation from torch.rfft to torch.fft.rfft?

Also, I'm not sure how many versions of PyTorch torchaudio supports, but torch.fft.rfft will only be available in PyTorch 1.7+, so previous versions of PyTorch will still need to call torch.rfft.

@mthrok (Collaborator, Author) commented:

@mruberry

Would it make sense, then, to limit the signature of this function or add a comment explaining that it only does a limited translation from torch.rfft to torch.fft.rfft?

I gave it some thought; we could do that, but I think keeping the same signature as torch.fft.rfft has advantages for maintainability.

  1. If someone wants to add new functionality that uses rfft, they can simply use the same signature as torch.fft.rfft, which frees them from having to consider the relationship between our abstraction function and the other torchaudio functions that use it. (In short, the person working on the new feature can use this abstraction as a drop-in replacement.)
  2. Similar to 1., in the future when PyTorch is done with the migration, we can simply replace the abstraction function path with the actual PyTorch implementation (torchaudio._internal.fft.rfft -> torch.fft.rfft).

Also, I'm not sure how many versions of PyTorch torchaudio supports, but torch.fft.rfft will only be available in PyTorch 1.7+, so previous versions of PyTorch will still need to call torch.rfft.

That is okay for domain libraries. We clearly state that domain libraries are expected/tested to work with the version of PyTorch that is released at the same time. So all the work on the master branch targets the master version (or the upcoming stable release) of PyTorch.

@@ -2082,7 +2083,7 @@ def _measure(
if boot_count >= 0 \
else measure_smooth_time_mult

_d = complex_norm(_dftBuf[spectrum_start:spectrum_end])
@vincentqb (Contributor) commented:

I'd keep the norm/abs change separate. In a prior attempt, there was a performance regression (#747).

@mthrok (Collaborator, Author) commented:

@vincentqb Can you elaborate on why the regression in torch.norm, which is not used in this PR, is a reason to discourage the use of torch.abs? Are they using the same implementation under the hood?

@vincentqb (Contributor) commented Oct 16, 2020

I'll let @anjali411 and @mruberry comment on whether torch.abs and torch.norm are related. If they are, we would end up with a performance regression again. We could add performance tests/checks, manually or automatically, to catch performance changes. However, my suggestion is simply to decouple the changes about torch.*fft* from those about the use of .abs() so that, in the event we do get a regression, we can easily track it and revert it.

A contributor commented:

Datapoint: torch.abs and torch.norm have separate implementations.

@anjali411 commented Oct 16, 2020

I think it should be perfectly fine to use torch.abs() here. complex_norm is not even using torch.norm https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L424. In fact, it's using three kernels (pow, sum, pow), so torch.abs() should be a strict improvement here.
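
(Illustrative equivalence check, not from the PR: complex_norm on the real view and abs() on the complex tensor compute the same magnitudes, assuming PyTorch 1.7+ and the torchaudio complex_norm linked above.)

import torch
import torch.fft
from torchaudio.functional import complex_norm

z = torch.fft.rfft(torch.randn(1, 1024))

# complex_norm does pow(2).sum(-1).pow(0.5) on the (..., 2) real view (three kernels);
# abs() computes the same magnitude on the complex tensor in a single kernel.
assert torch.allclose(complex_norm(torch.view_as_real(z)), z.abs())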

A contributor commented:

Datapoint: torch.abs and torch.norm have separate implementations.

Good to know :)

I think it should be perfectly fine to use torch.abs() here. complex_norm is not even using torch.norm https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L424. In fact, it's using three kernels (pow, sum, pow), so torch.abs() should be a strict improvement here.

We had complex_norm using torch.norm, but this led to a speed regression, so we reverted to the current implementation with three kernels. Because of this, I want to avoid a performance regression again. I'm glad to know this would launch only one kernel: has the performance difference been tested and compared in this case?

In any case, let's just separate concerns and move this to a separate pull request so we don't block the rest on this :)

@anjali411 commented:

x=torch.randn(10, 20, 2)
In [7]: def fn(x):
   ...:     t0=time.time()
   ...:     o=x.pow(2.0).sum(-1).pow(0.5)
   ...:     t1 = time.time()
   ...:     print(t1-t0)
   ...:

In [8]: fn(x)
0.00037217140197753906

In [10]: y=torch.view_as_complex(x).contiguous()

In [11]: fn(y)
0.0001480579376220703

There's a significant performance gain (as expected) so I think we should switch to torch.abs.

@mthrok (Collaborator, Author) commented Oct 30, 2020

Thanks @anjali411 !

In addition to that, I benchmarked the exact code path that @vincentqb suggested and observed that abs is 10x faster on CPU and 2x faster on GPU.

Method                        | CPU                                      | GPU
t.abs()                       | 100 loops, best of 5: 1.34 msec per loop | 100 loops, best of 5: 28.2 usec per loop
complex_norm(view_as_real(t)) | 100 loops, best of 5: 14.9 msec per loop | 100 loops, best of 5: 60.5 usec per loop

PyTorch: 1.8.0a0+edac406
torchaudio: 0.8.0a0+0076ab0

Code:
OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100));
""" """
t.abs();
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
import torchaudio.functional;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100));
""" """
torchaudio.functional.complex_norm(torch.view_as_real(t));
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100)).to('cuda');
""" """
t.abs();
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
import torchaudio.functional;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100)).to('cuda');
""" """
torchaudio.functional.complex_norm(torch.view_as_real(t));
"""

A contributor commented:

Thanks for checking!


power_spectrum = fft.pow(2).sum(2).unsqueeze(1) # size (m, 1, padded_window_size // 2 + 1)
A contributor commented:

same

complex_norm(
_cepstrum_Buf[cepstrum_start:cepstrum_end],
power=2.0)))
result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))
A contributor commented:

same


# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.pow(2).sum(2), epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log() # size (m, padded_window_size // 2 + 1)
A contributor commented:

same

@anjali411 left a comment:

complex number stuff looks good

@vincentqb (Contributor) left a comment:

LGTM. Any reason why this is marked as draft?
