diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py index 9f4b4f1a64..b9c5a6f64e 100644 --- a/test/torchaudio_unittest/functional/functional_cpu_test.py +++ b/test/torchaudio_unittest/functional/functional_cpu_test.py @@ -14,29 +14,19 @@ skipIfNoSox, ) -from .functional_impl import Lfilter, Spectrogram, FunctionalComplex +from .functional_impl import Functional, FunctionalComplex -class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): +class TestFunctionalFloat32(Functional, common_utils.PytorchTestCase): dtype = torch.float32 device = torch.device('cpu') @unittest.expectedFailure - def test_9th_order_filter_stability(self): - super().test_9th_order_filter_stability() + def test_lfilter_9th_order_filter_stability(self): + super().test_lfilter_9th_order_filter_stability() -class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase): - dtype = torch.float64 - device = torch.device('cpu') - - -class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase): - dtype = torch.float32 - device = torch.device('cpu') - - -class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase): +class TestFunctionalFloat64(Functional, common_utils.PytorchTestCase): dtype = torch.float64 device = torch.device('cpu') diff --git a/test/torchaudio_unittest/functional/functional_cuda_test.py b/test/torchaudio_unittest/functional/functional_cuda_test.py index 62d9eff3eb..ebf8cd8326 100644 --- a/test/torchaudio_unittest/functional/functional_cuda_test.py +++ b/test/torchaudio_unittest/functional/functional_cuda_test.py @@ -2,33 +2,21 @@ import unittest from torchaudio_unittest import common_utils -from .functional_impl import Lfilter, Spectrogram, FunctionalComplex +from .functional_impl import Functional, FunctionalComplex @common_utils.skipIfNoCuda -class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): +class TestFunctionalloat32(Functional, common_utils.PytorchTestCase): dtype = torch.float32 device = torch.device('cuda') @unittest.expectedFailure - def test_9th_order_filter_stability(self): - super().test_9th_order_filter_stability() + def test_lfilter_9th_order_filter_stability(self): + super().test_lfilter_9th_order_filter_stability() @common_utils.skipIfNoCuda -class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase): - dtype = torch.float64 - device = torch.device('cuda') - - -@common_utils.skipIfNoCuda -class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase): - dtype = torch.float32 - device = torch.device('cuda') - - -@common_utils.skipIfNoCuda -class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase): +class TestLFilterFloat64(Functional, common_utils.PytorchTestCase): dtype = torch.float64 device = torch.device('cuda') diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index c4f863fe45..8b890e9563 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -9,8 +9,8 @@ from torchaudio_unittest.common_utils import nested_params -class Lfilter(common_utils.TestBaseMixin): - def test_simple(self): +class Functional(common_utils.TestBaseMixin): + def test_lfilter_simple(self): """ Create a very basic signal, Then make a simple 4th order delay @@ -25,7 +25,7 @@ def test_simple(self): self.assertEqual(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5) - def test_clamp(self): + def test_lfilter_clamp(self): input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device) b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device) a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device) @@ -40,7 +40,7 @@ def test_clamp(self): ((2, 3, 44100),), ((1, 2, 3, 44100),) ]) - def test_shape(self, shape): + def test_lfilter_shape(self, shape): torch.random.manual_seed(42) waveform = torch.rand(*shape, dtype=self.dtype, device=self.device) b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device) @@ -48,7 +48,7 @@ def test_shape(self, shape): output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) assert shape == waveform.size() == output_waveform.size() - def test_9th_order_filter_stability(self): + def test_lfilter_9th_order_filter_stability(self): """ Validate the precision of lfilter against reference scipy implementation when using high order filter. The reference implementation use cascaded second-order filters so is more numerically accurate. @@ -70,10 +70,8 @@ def test_9th_order_filter_stability(self): yhat = F.lfilter(x, a, b, False) self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5) - -class Spectrogram(common_utils.TestBaseMixin): @parameterized.expand([(0., ), (1., ), (2., ), (3., )]) - def test_grad_at_zero(self, power): + def test_spectogram_grad_at_zero(self, power): """The gradient of power spectrogram should not be nan but zero near x=0 https://github.com/pytorch/audio/issues/993