Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add autograd test for T.GriffinLim #1421

Merged
merged 9 commits into from
Apr 6, 2021
9 changes: 9 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def test_melspectrogram(self):
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

@parameterized.expand([(0, ), (0.99, )])
def test_griffinlim(self, momentum):
n_fft = 400
n_frames = 5
n_iter = 3
spec = torch.rand(n_fft // 2 + 1, n_frames) * n_fft
yoyololicon marked this conversation as resolved.
Show resolved Hide resolved
transform = T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=False)
self.assert_grad(transform, [spec], nondet_tol=1e-10)

@parameterized.expand([(False, ), (True, )])
def test_mfcc(self, log_mels):
sample_rate = 8000
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def griffinlim(
hop_length=hop_length,
win_length=win_length,
window=window,
length=length).float()
length=length)

# Rebuild the spectrogram
rebuilt = torch.view_as_real(
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __init__(self,
super(GriffinLim, self).__init__()

assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum > 0, 'momentum={} < 0'.format(momentum)
assert momentum >= 0, 'momentum={} < 0'.format(momentum)

self.n_fft = n_fft
self.n_iter = n_iter
Expand Down