Skip to content

Commit

Permalink
Test around non-zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Apr 13, 2021
1 parent a2e9b7f commit ad3f7c5
Showing 1 changed file with 45 additions and 19 deletions.
64 changes: 45 additions & 19 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
import unittest

from parameterized import parameterized
import torch
Expand Down Expand Up @@ -128,26 +129,51 @@ def test_spectral_centroid(self):
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

# Note: rate=0.7 fails
# https://github.com/pytorch/pytorch/issues/55557
@nested_params(
[0.8, 0.9, 1.0, 1.3],
[True, False],
)
def test_timestretch(self, rate, test_complex):
transform = T.TimeStretch(fixed_rate=rate)
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
spectrogram = get_spectrogram(waveform, n_fft=400, power=1 if test_complex else None)
@unittest.expectedFailure
def test_timestretch_zeros_fail(self):
"""Test that ``T.TimeStretch`` fails gradcheck at 0
This is because ``F.phase_vocoder`` converts data from cartesian to polar coordinate,
which performs ``atan2(img, real)``, and gradient is not defined at 0.
"""
n_fft = 16
transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99)
waveform = torch.zeros(2, 40)
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
self.assert_grad(transform, [spectrogram])

# Note: rate=0.7 fails
# https://github.com/pytorch/pytorch/issues/55557
@nested_params(
[0.8, 0.9, 1.0, 1.3],
[True, False],
[0.7, 0.8, 0.9, 1.0, 1.3],
[False, True],
)
def test_timestretch_override(self, rate, test_complex):
transform = T.TimeStretch()
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
spectrogram = get_spectrogram(waveform, n_fft=400, power=1 if test_complex else None)
self.assert_grad(transform, [spectrogram, rate])
def test_timestretch(self, rate, test_pseudo_complex):
"""Verify that ``T.TimeStretch`` does not fail if it's not too close to 0
``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability
for cases where input is not zero, and different configurations of `TimeStretch`.
Ideally, we should be testing on Spectrogram of random waveform but it is hard to control
the values around zeros.
"""
n_fft = 16
transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate)
waveform = torch.zeros(2, 40)
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)

# Epsilon values tried
#
# Note:
# This is not experimental and comprehensive.
# The result also depends on ``n_fft``.
#
# CPU / CUDA
# * 1e-1 ok / ok
# * 1e-2 ok / ok
# * 1e-3 ok / ok
# * 1e-3 + 1e-3j ok / ok
# * 1e-4 ok / NG

spectrogram += 1e-3
if test_pseudo_complex:
spectrogram = torch.view_as_real(spectrogram)
self.assert_grad(transform, [spectrogram])

0 comments on commit ad3f7c5

Please sign in to comment.