diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py index caa444c988..2d9dca76d6 100644 --- a/test/torchaudio_unittest/functional/autograd_impl.py +++ b/test/torchaudio_unittest/functional/autograd_impl.py @@ -1,138 +1,174 @@ +from typing import Callable, Tuple import torch +from parameterized import parameterized +from torch import Tensor import torchaudio.functional as F from torch.autograd import gradcheck -from torchaudio_unittest import common_utils +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + get_whitenoise, +) + + +class Autograd(TestBaseMixin): + def assert_grad( + self, + transform: Callable[..., Tensor], + inputs: Tuple[torch.Tensor], + *, + enable_all_grad: bool = True, + ): + inputs_ = [] + for i in inputs: + if torch.is_tensor(i): + i = i.to(dtype=self.dtype, device=self.device) + if enable_all_grad: + i.requires_grad = True + inputs_.append(i) + assert gradcheck(transform, inputs_) - -class Autograd(common_utils.TestBaseMixin): def test_lfilter_x(self): torch.random.manual_seed(2434) - x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) - a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) - b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device) + x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) x.requires_grad = True - assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) + self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False) def test_lfilter_a(self): torch.random.manual_seed(2434) - x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) - a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) - b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device) + x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) a.requires_grad = True - assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) + self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False) def test_lfilter_b(self): torch.random.manual_seed(2434) - x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) - a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) - b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device) + x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) b.requires_grad = True - assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) + self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False) def test_lfilter_all_inputs(self): torch.random.manual_seed(2434) - x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device) - a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device) - b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device) - b.requires_grad = True - a.requires_grad = True - x.requires_grad = True - assert gradcheck(F.lfilter, (x, a, b), eps=1e-10) + x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + self.assert_grad(F.lfilter, (x, a, b)) def test_biquad(self): torch.random.manual_seed(2434) - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device, requires_grad=True) - b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]), eps=1e-10) + x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2])) - def test_band_biquad(self): + @parameterized.expand([ + (800, 0.7, True), + (800, 0.7, False), + ]) + def test_band_biquad(self, central_freq, Q, noise): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.band_biquad, (x, sr, central_freq, Q)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise)) - def test_band_biquad_with_noise(self): + @parameterized.expand([ + (800, 0.7, 10), + (800, 0.7, -10), + ]) + def test_bass_biquad(self, central_freq, Q, gain): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.band_biquad, (x, sr, central_freq, Q, True)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + gain = torch.tensor(gain) + self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q)) - def test_bass_biquad(self): - torch.random.manual_seed(2434) - sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(100, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.bass_biquad, (x, sr, gain, central_freq, Q)) - - def test_treble_biquad(self): - torch.random.manual_seed(2434) - sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(3000, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.treble_biquad, (x, sr, gain, central_freq, Q)) + @parameterized.expand([ + (3000, 0.7, 10), + (3000, 0.7, -10), - def test_allpass_biquad(self): + ]) + def test_treble_biquad(self, central_freq, Q, gain): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.allpass_biquad, (x, sr, central_freq, Q)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + gain = torch.tensor(gain) + self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q)) - def test_lowpass_biquad(self): + @parameterized.expand([ + (800, 0.7, ), + ]) + def test_allpass_biquad(self, central_freq, Q): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.lowpass_biquad, (x, sr, cutoff_freq, Q)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q)) - def test_highpass_biquad(self): + @parameterized.expand([ + (800, 0.7, ), + ]) + def test_lowpass_biquad(self, cutoff_freq, Q): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.highpass_biquad, (x, sr, cutoff_freq, Q)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + cutoff_freq = torch.tensor(cutoff_freq) + Q = torch.tensor(Q) + self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q)) - def test_bandpass_biquad(self): + @parameterized.expand([ + (800, 0.7, ), + ]) + def test_highpass_biquad(self, cutoff_freq, Q): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + cutoff_freq = torch.tensor(cutoff_freq) + Q = torch.tensor(Q) + self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q)) - def test_bandpass_biquad_with_const_skirt_gain(self): + @parameterized.expand([ + (800, 0.7, True), + (800, 0.7, False), + ]) + def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q, True)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain)) - def test_equalizer_biquad(self): + @parameterized.expand([ + (800, 0.7, 10), + (800, 0.7, -10), + ]) + def test_equalizer_biquad(self, central_freq, Q, gain): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.equalizer_biquad, (x, sr, central_freq, gain, Q)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + gain = torch.tensor(gain) + self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q)) - def test_bandreject_biquad(self): + @parameterized.expand([ + (800, 0.7, ), + ]) + def test_bandreject_biquad(self, central_freq, Q): torch.random.manual_seed(2434) sr = 22050 - x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True) - central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True) - Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True) - assert gradcheck(F.bandreject_biquad, (x, sr, central_freq, Q)) + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))