diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 717a7bc87b..0f7b612482 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -125,6 +125,31 @@ def test_fade(self, fade_shape): waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) self.assert_grad(transform, [waveform], nondet_tol=1e-10) + @parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)]) + def test_masking(self, masking_transform): + sample_rate = 8000 + n_fft = 400 + spectrogram = get_spectrogram( + get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), + n_fft=n_fft, power=1) + deterministic_transform = _DeterministicWrapper(masking_transform(400)) + self.assert_grad(deterministic_transform, [spectrogram]) + + @parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)]) + def test_masking_iid(self, masking_transform): + sample_rate = 8000 + n_fft = 400 + specs = [get_spectrogram( + get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2, seed=i), + n_fft=n_fft, power=1) + for i in range(3) + ] + + batch = torch.stack(specs) + assert batch.ndim == 4 + deterministic_transform = _DeterministicWrapper(masking_transform(400, True)) + self.assert_grad(deterministic_transform, [batch]) + def test_spectral_centroid(self): sample_rate = 8000 transform = T.SpectralCentroid(sample_rate=sample_rate)