From 8885f69ed29095637411133f41c65f19db58db8c Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Sat, 6 Mar 2021 05:12:26 +0000 Subject: [PATCH 1/5] Adopt native complex dtype in griffnlim --- torchaudio/functional/functional.py | 57 ++++++++++++++++------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index a596c84e01..f44987ba2a 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -125,6 +125,14 @@ def spectrogram( return spec_f +def _get_complex_dtype(real_dtype: torch.dtype): + if real_dtype == torch.double: + return torch.cdouble + if real_dtype == torch.float: + return torch.cfloat + raise ValueError(f'Unexpected dtype {real_dtype}') + + def griffinlim( specgram: Tensor, window: Tensor, @@ -180,23 +188,19 @@ def griffinlim( specgram = specgram.pow(1 / power) - # randomly initialize the phase - batch, freq, frames = specgram.size() + # initialize the phase if rand_init: - angles = 2 * math.pi * torch.rand(batch, freq, frames) + angles = torch.rand( + specgram.size(), + dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) else: - angles = torch.zeros(batch, freq, frames) - angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \ - .to(dtype=specgram.dtype, device=specgram.device) - specgram = specgram.unsqueeze(-1).expand_as(angles) + angles = torch.full( + specgram.size(), 1, + dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) # And initialize the previous iterate to 0 - rebuilt = torch.tensor(0.) - + tprev = torch.tensor(0., dtype=specgram.dtype, device=specgram.device) for _ in range(n_iter): - # Store the previous iterate - tprev = rebuilt - # Invert with our current estimate of the phases inverse = torch.istft(specgram * angles, n_fft=n_fft, @@ -206,26 +210,27 @@ def griffinlim( length=length) # Rebuild the spectrogram - rebuilt = torch.view_as_real( - torch.stft( - input=inverse, - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - window=window, - center=True, - pad_mode='reflect', - normalized=False, - onesided=True, - return_complex=True, - ) + rebuilt = torch.stft( + input=inverse, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True, ) # Update our phase estimates angles = rebuilt if momentum: angles = angles - tprev.mul_(momentum / (1 + momentum)) - angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles)) + angles = angles.div(angles.abs().add(1e-16)) + + # Store the previous iterate + tprev = rebuilt # Return the final phase estimates waveform = torch.istft(specgram * angles, From d90a1298b75c512c8ee86cabadfc69c1ad803ea4 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Mon, 5 Apr 2021 15:15:16 -0700 Subject: [PATCH 2/5] Update autograd test --- .../transforms/autograd_test_impl.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 504b0e9d01..ac03a3ebe0 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -8,6 +8,7 @@ from torchaudio_unittest.common_utils import ( TestBaseMixin, get_whitenoise, + nested_params, ) @@ -65,14 +66,17 @@ def test_melspectrogram(self): waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) self.assert_grad(transform, [waveform], nondet_tol=1e-10) - @parameterized.expand([(0, ), (0.99, )]) - def test_griffinlim(self, momentum): + @nested_params( + [0, 0.99], + [False, True], + ) + def test_griffinlim(self, momentum, rand_init): n_fft = 400 n_frames = 5 n_iter = 3 spec = torch.rand(n_fft // 2 + 1, n_frames) * n_fft - transform = T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=False) - self.assert_grad(transform, [spec], nondet_tol=1e-10) + transform = T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init) + self.assert_grad(transform, [spec]) @parameterized.expand([(False, ), (True, )]) def test_mfcc(self, log_mels): From 12f77e1508e414766e66cdd2a73dea4ed0790c58 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 6 Apr 2021 16:10:29 +0000 Subject: [PATCH 3/5] Add support to half --- torchaudio/functional/functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index f44987ba2a..19b940caa8 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -130,6 +130,8 @@ def _get_complex_dtype(real_dtype: torch.dtype): return torch.cdouble if real_dtype == torch.float: return torch.cfloat + if real_dtype == torch.half: + return torch.complex32 raise ValueError(f'Unexpected dtype {real_dtype}') From 77c37648e6bfe50f87f71f1d0a913e4cb87a34b5 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 7 Apr 2021 07:48:03 -0700 Subject: [PATCH 4/5] Make TimeStretch deterministic --- .../transforms/autograd_test_impl.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index ac03a3ebe0..72e6684c4b 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -8,10 +8,23 @@ from torchaudio_unittest.common_utils import ( TestBaseMixin, get_whitenoise, + get_spectrogram, nested_params, ) +class _DeterministicWrapper(torch.nn.Module): + """Helper transform wrapper to make the given transform deterministic""" + def __init__(self, transform, seed=0): + super().__init__() + self.seed = seed + self.transform = transform + + def forward(self, input: torch.Tensor): + torch.random.manual_seed(0) + return self.transform(input) + + class AutogradTestMixin(TestBaseMixin): def assert_grad( self, @@ -72,10 +85,13 @@ def test_melspectrogram(self): ) def test_griffinlim(self, momentum, rand_init): n_fft = 400 - n_frames = 5 + power = 1 n_iter = 3 - spec = torch.rand(n_fft // 2 + 1, n_frames) * n_fft - transform = T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init) + spec = get_spectrogram( + get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2), + n_fft=n_fft, power=power) + transform = _DeterministicWrapper( + T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init, power=power)) self.assert_grad(transform, [spec]) @parameterized.expand([(False, ), (True, )]) From 079668a3c6dde4aee6b309e3a36cdf58031a17da Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 7 Apr 2021 18:16:11 +0000 Subject: [PATCH 5/5] Fix wrapper --- test/torchaudio_unittest/transforms/autograd_test_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 72e6684c4b..da396f5521 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -21,7 +21,7 @@ def __init__(self, transform, seed=0): self.transform = transform def forward(self, input: torch.Tensor): - torch.random.manual_seed(0) + torch.random.manual_seed(self.seed) return self.transform(input)