Skip to content

Commit

Permalink
Add autograd test to T.TimeStretch (and F.phase_vocoder)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Apr 5, 2021
1 parent 8d2eeb1 commit 8ef832f
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,20 @@
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
nested_params,
)


# TODO:
# - replace T.Spectrogram
# - generalize it
# - move to common_utils
def get_spectrogram(return_complex):
spectrogram = T.Spectrogram(return_complex=return_complex, power=None)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
return spectrogram(waveform)


class AutogradTestMixin(TestBaseMixin):
def assert_grad(
self,
Expand All @@ -23,8 +34,12 @@ def assert_grad(

inputs_ = []
for i in inputs:
i.requires_grad = True
inputs_.append(i.to(dtype=torch.float64, device=self.device))
if torch.is_tensor(i):
i = i.to(
dtype=torch.cdouble if i.is_complex() else torch.double,
device=self.device)
i.requires_grad = True
inputs_.append(i)
assert gradcheck(transform, inputs_)
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)

Expand Down Expand Up @@ -88,3 +103,21 @@ def test_fade(self, fade_shape):
transform = T.Fade(fade_shape=fade_shape)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

@nested_params(
[0.7, 0.8, 0.9, 1.0, 1.3],
[True, False],
)
def test_timestretch(self, rate, test_complex):
transform = T.TimeStretch(fixed_rate=rate)
spectrogram = get_spectrogram(return_complex=test_complex)
self.assert_grad(transform, [spectrogram])

@nested_params(
[0.7, 0.8, 0.9, 1.0, 1.3],
[True, False],
)
def test_timestretch_override(self, rate, test_complex):
transform = T.TimeStretch()
spectrogram = get_spectrogram(return_complex=test_complex)
self.assert_grad(transform, [spectrogram, rate])

0 comments on commit 8ef832f

Please sign in to comment.