
Conversation


@mthrok mthrok commented Mar 6, 2021

Get rid of pseudo complex type in F.griffinlim.

Part of #1337

griffinlim is tested

  • ✅ TorchScript on CPU/CUDA for float32 and float64
  • librosa consistency on CPU
  • ✅ Batch consistency on CPU

TODO

Benchmark

script
#!/usr/bin/env bash

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -s """
import torch
import torchaudio;

batch_size = 32
num_channels = 1
num_freq = 201
num_frames = 400

torch.manual_seed(0);
spec = torch.randn(batch_size, num_channels, num_freq, num_frames, dtype=torch.float32, device='cpu');
griffinlim = torchaudio.transforms.GriffinLim()
""" """
griffinlim(spec)
"""
device | master (8d2eeb1)                     | This PR
CPU    | 1 loop, best of 5: 6.48 sec per loop | 1 loop, best of 5: 3.57 sec per loop
GPU    | 1 loop, best of 5: 128 msec per loop | 1 loop, best of 5: 93.4 msec per loop

@mthrok mthrok added this to the v0.9 milestone Mar 6, 2021
@mthrok mthrok requested a review from anjali411 March 6, 2021 05:15
@mthrok mthrok marked this pull request as ready for review March 6, 2021 05:19
  batch, freq, frames = specgram.size()
  if rand_init:
-     angles = 2 * math.pi * torch.rand(batch, freq, frames)
+     angles = torch.rand(batch, freq, frames, dtype=torch.cfloat, device=specgram.device)
Contributor Author

TODO: check to see if cfloat is preferred.


yeah or even take a step back and consider why we were creating float tensors by default and if it makes sense to add a dtype kwarg that defaults to float

Contributor Author

@mthrok mthrok Apr 2, 2021

@anjali411

I looked into this and here are my findings:

  1. The use of cfloat here is fine when the input tensor is float32 or float64.
    This is because the dtype of angles is updated to the correct one in the next for loop: specgram * angles there produces the same dtype as specgram, so after the first iteration all the tensors have the same dtype as the input (see the promotion sketch below).
  2. The above suggests that griffinlim will not work for float16 input. However, I tested it, and pow is not implemented for float16 anyway (see the details below), so I do not think this is an issue.

Given these findings, I suggest proceeding with the current approach (using cfloat for the intermediate variables), since it works for float32 and float64. Also note that float16 is in general not tested in torchaudio.
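
A minimal standalone sketch of the dtype promotion described in finding 1 (illustrative only, not code from this PR):

import torch

# Multiplying a real spectrogram by complex64 "angles" promotes the result to the
# complex dtype matching the spectrogram's precision, so after the first
# Griffin-Lim iteration every intermediate tensor follows the input dtype.
for real_dtype in (torch.float32, torch.float64):
    specgram = torch.rand(1, 201, 400, dtype=real_dtype)
    angles = torch.rand(1, 201, 400, dtype=torch.cfloat)  # complex64
    print(real_dtype, '->', (specgram * angles).dtype)

# Expected output:
# torch.float32 -> torch.complex64
# torch.float64 -> torch.complex128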


The following is the result of running griffinlim with different dtypes, before and after this PR.

Script
import torch
import torchaudio.transforms as T


def test(waveform, dtype):
    n_fft = 1024
    win_length = None
    hop_length = 512

    spectrogram = T.Spectrogram(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
    )
    griffin_lim = T.GriffinLim(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
    )

    spec = spectrogram(waveform).to(dtype)
    print('*******')
    print('input:', spec.dtype)
    recon = griffin_lim(spec)
    print('output:', recon.dtype)
    assert recon.dtype == spec.dtype
    return recon


torch.random.manual_seed(0)
waveform = torch.randn(2, 1024)


test(waveform, torch.float64)
test(waveform, torch.float32)
test(waveform, torch.float16)

Result

Before applying this PR (commit 0433b7a)

*******
input: torch.float64
/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/functional.py:655: UserWarning: istft will require a complex-valued input tensor in a future PyTorch release. Matching the output from stft with return_complex=True.  (Triggered internally at  /opt/conda/conda-bld/pytorch_1617001512472/work/aten/src/ATen/native/SpectralOps.cpp:807.)
  return _VF.istft(input, n_fft, hop_length, win_length, window, center,  # type: ignore
output: torch.float64
*******
input: torch.float32
output: torch.float32
*******
input: torch.float16
Traceback (most recent call last):
  File "foo.py", line 36, in <module>
    test(waveform, torch.float16)
  File "foo.py", line 24, in test
    recon = griffin_lim(spec)
  File "/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch/moto/torchaudio/torchaudio/transforms.py", line 189, in forward
    return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
  File "/scratch/moto/torchaudio/torchaudio/functional/functional.py", line 168, in griffinlim
    specgram = specgram.pow(1 / power)
RuntimeError: "pow" not implemented for 'Half'

After applying this PR

*******
input: torch.float64
output: torch.float64
*******
input: torch.float32
output: torch.float32
*******
input: torch.float16
Traceback (most recent call last):
  File "foo.py", line 36, in <module>
    test(waveform, torch.float16)
  File "foo.py", line 24, in test
    recon = griffin_lim(spec)
  File "/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch/moto/torchaudio/torchaudio/transforms.py", line 189, in forward
    return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
  File "/scratch/moto/torchaudio/torchaudio/functional/functional.py", line 168, in griffinlim
    specgram = specgram.pow(1 / power)
RuntimeError: "pow" not implemented for 'Half'


This PR pytorch/pytorch#50999 recently added pow for torch.float16


The use of cfloat here is fine when the input tensor is float32 or float64.
This is because the dtype of angles is updated to the correct one in the next for loop: specgram * angles there produces the same dtype as specgram, so after the first iteration all the tensors have the same dtype as the input.

So in that case, since the multiplication in the for loop (specgram * angles) promotes all the input tensors to the same dtype anyway, I think it makes sense to instead create angles with the same dtype as specgram (i.e. dtype=specgram.dtype).

Contributor Author

Oh, this could be the reason why the autograd test in #1421 is failing.

Contributor Author

This PR pytorch/pytorch#50999 recently added pow for torch.float16

Looks like it's still landing? pytorch/pytorch#55280

So in that case, since the multiplication in the for loop (specgram * angles) promotes all the input tensors to the same dtype anyway, I think it makes sense to instead create angles with the same dtype as specgram (i.e. dtype=specgram.dtype).

Added the necessary change.
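
The change maps the real input dtype to its complex counterpart. A rough sketch of what such a helper could look like (illustrative; the actual _get_complex_dtype in torchaudio may be implemented differently):

import torch

def _get_complex_dtype(real_dtype: torch.dtype) -> torch.dtype:
    # Map a real floating-point dtype to the complex dtype of matching precision.
    if real_dtype == torch.double:
        return torch.cdouble
    if real_dtype == torch.float:
        return torch.cfloat
    if real_dtype == torch.half:
        return torch.complex32
    raise ValueError(f'Unexpected dtype {real_dtype}')

With this, torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) produces complex angles whose precision matches the input spectrogram.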

Contributor

@imaginary-person imaginary-person Apr 5, 2021

Sorry, pytorch/pytorch#55280 will hopefully land soon and enable pow for float16 on CPU.

pow for float16 is already supported on CUDA.
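
A quick way to see the CPU/CUDA asymmetry at the time of this PR (illustrative only; the second check requires a CUDA device):

import torch

# At the time of this PR: pow on float16 works on CUDA but not on CPU.
x_cpu = torch.rand(4, dtype=torch.float16)
try:
    x_cpu.pow(0.5)
except RuntimeError as e:
    print('CPU:', e)  # "pow" not implemented for 'Half'

if torch.cuda.is_available():
    x_cuda = torch.rand(4, dtype=torch.float16, device='cuda')
    print('CUDA:', x_cuda.pow(0.5).dtype)  # works: torch.float16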

Contributor Author

@imaginary-person Thanks for the info.

@anjali411 anjali411 left a comment

needs JIT, autograd and function correctness tests for complex to verify if everything works as expected

@mthrok mthrok force-pushed the migrate-griffinlim branch 2 times, most recently from 2baec15 to 218db3d Compare April 2, 2021 19:35
@mthrok mthrok changed the title [D] Adopt native complex dtype in griffnlim Adopt native complex dtype in griffnlim Apr 2, 2021
@mthrok mthrok modified the milestones: v0.9, Complex Tensor Migration Apr 5, 2021
@anjali411

@mthrok this PR needs performance benchmarking!

@anjali411 anjali411 left a comment

LGTM! Thanks Moto. Let's merge this after #1421 is merged.

-     angles = 2 * math.pi * torch.rand(batch, freq, frames)
+     angles = torch.rand(
+         specgram.size(),
+         dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)


Wait, shouldn't the angles here be a floating-point tensor?

Contributor Author

No, angles is complex-valued. The Griffin-Lim algorithm iteratively optimizes the phase (i.e. the direction in the complex plane) of each element of the given spectrogram so that, at the end, istft gives back the original waveform.
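
A rough, self-contained sketch of this phase-iteration idea (simplified and illustrative only; the actual torchaudio implementation also handles momentum, batching, windowing options, etc.):

import torch

def griffin_lim_sketch(magnitude, n_fft=400, hop_length=200, n_iter=32):
    # magnitude: (freq, frames) magnitude spectrogram with freq == n_fft // 2 + 1.
    window = torch.hann_window(n_fft)
    # Start from random unit-magnitude complex numbers; only their phase matters.
    angles = torch.rand(magnitude.size(), dtype=torch.cfloat)
    angles = angles / angles.abs().clamp(min=1e-16)
    for _ in range(n_iter):
        # Impose the known magnitude, go to the time domain and back,
        # then keep only the phase of the rebuilt spectrogram.
        inverse = torch.istft(magnitude * angles, n_fft, hop_length, window=window)
        rebuilt = torch.stft(inverse, n_fft, hop_length, window=window, return_complex=True)
        angles = rebuilt / rebuilt.abs().clamp(min=1e-16)
    return torch.istft(magnitude * angles, n_fft, hop_length, window=window)

For example, magnitude = torch.stft(waveform, 400, hop_length=200, window=torch.hann_window(400), return_complex=True).abs() gives a magnitude spectrogram that griffin_lim_sketch approximately inverts back to waveform.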

Contributor Author

@mthrok mthrok Apr 5, 2021

*The original implementation reused the variable name: it constructed a complex-valued tensor called angles from a real-valued tensor also called angles, together with the magnitude.

@mthrok mthrok force-pushed the migrate-griffinlim branch 3 times, most recently from de1877e to 4b599e5 Compare April 6, 2021 14:28
Contributor Author

mthrok commented Apr 6, 2021

@anjali411

The autograd tests are failing for the case rand_init=True.

Failing test log: https://app.circleci.com/pipelines/github/pytorch/audio/5665/workflows/746a53f5-6f18-4f2e-b717-869bf342eaf4/jobs/194528

The following is the corresponding code:

https://github.com/mthrok/audio/blob/17f61e73f8f4b0f314bb7356aa3a2222c4038347/torchaudio/functional/functional.py#L193-L201

However, it does not fail on my local env with PyTorch nightly from March 29th or today's.

[conda] pytorch 1.9.0.dev20210329 py3.8_cuda10.1_cudnn7.6.3_0 pytorch-nightly
[conda] pytorch 1.9.0.dev20210405 py3.8_cuda10.1_cudnn7.6.3_0 pytorch-nightly

I was wondering if this is a regression, but I am not sure, since it also works on my local env with the latest nightly. What do you think?

@mthrok mthrok force-pushed the migrate-griffinlim branch from 17f61e7 to 12f77e1 Compare April 7, 2021 01:28
@yoyolicoris
Contributor

@anjali411

The autograd tests are failing for the case rand_init=True.

@mthrok

Hmm, I think this is expected behavior: the numerical gradient check (I'm not sure though) evaluates the function multiple times, and with rand_init=True each forward pass starts with different initial phases, resulting in different outputs.

For functions like this, which have random behavior internally, maybe a .backward() call at the end of the test is enough, to make sure that gradients can be propagated without problems.
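
For instance, such a backward smoke test might look roughly like this (illustrative only, not the test actually added in this PR):

import torch
import torchaudio

# Only check that gradients propagate through GriffinLim; skip the numerical
# gradient comparison, which the internal randomness would break.
spec = torch.rand(1, 201, 50, requires_grad=True)
griffin_lim = torchaudio.transforms.GriffinLim(n_fft=400)
out = griffin_lim(spec)
out.sum().backward()
assert spec.grad is not None and torch.isfinite(spec.grad).all()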

Contributor Author

mthrok commented Apr 7, 2021

@anjali411
The autograd tests are failing for the case rand_init=True.

@mthrok

Hmm, I think this is expected behavior: the numerical gradient check (I'm not sure though) evaluates the function multiple times, and with rand_init=True each forward pass starts with different initial phases, resulting in different outputs.

For functions like this, which have random behavior internally, maybe a .backward() call at the end of the test is enough, to make sure that gradients can be propagated without problems.

Hi @yoyololicon

Thanks for the insight. You are right. I added a wrapper that sets the random seed and the test passes now.
(I do not know why I thought it was working on my local env...)
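
A hypothetical sketch of such a seed-fixing wrapper (the actual test utility used in this PR may be structured differently):

import torch
from torch.autograd import gradcheck

class SeededFn(torch.nn.Module):
    # Re-seed the RNG on every forward call so that repeated evaluations
    # (as performed by gradcheck) of a randomized function become deterministic.
    def __init__(self, fn, seed=0):
        super().__init__()
        self.fn = fn
        self.seed = seed

    def forward(self, *args):
        torch.manual_seed(self.seed)
        return self.fn(*args)

# Hypothetical usage with a toy randomized function:
def noisy_scale(x):
    return x * torch.rand(())  # random behavior inside the function

x = torch.randn(3, dtype=torch.double, requires_grad=True)
assert gradcheck(SeededFn(noisy_scale), (x,))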

Contributor Author

mthrok commented Apr 7, 2021

@anjali411 This PR is ready for the final review.

@mthrok mthrok merged commit 78c3480 into pytorch:master Apr 9, 2021
@mthrok mthrok deleted the migrate-griffinlim branch April 9, 2021 19:01
carolineechen pushed a commit to carolineechen/audio that referenced this pull request Apr 30, 2021