Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions test/torchaudio_unittest/functional/functional_cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,19 @@
skipIfNoSox,
)

from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
from .functional_impl import Functional, FunctionalComplex


class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
class TestFunctionalFloat32(Functional, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')

@unittest.expectedFailure
def test_9th_order_filter_stability(self):
super().test_9th_order_filter_stability()
def test_lfilter_9th_order_filter_stability(self):
super().test_lfilter_9th_order_filter_stability()


class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')


class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')


class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
class TestFunctionalFloat64(Functional, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')

Expand Down
22 changes: 5 additions & 17 deletions test/torchaudio_unittest/functional/functional_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,21 @@
import unittest

from torchaudio_unittest import common_utils
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
from .functional_impl import Functional, FunctionalComplex


@common_utils.skipIfNoCuda
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
class TestFunctionalloat32(Functional, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')

@unittest.expectedFailure
def test_9th_order_filter_stability(self):
super().test_9th_order_filter_stability()
def test_lfilter_9th_order_filter_stability(self):
super().test_lfilter_9th_order_filter_stability()


@common_utils.skipIfNoCuda
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')


@common_utils.skipIfNoCuda
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
class TestLFilterFloat64(Functional, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')

Expand Down
14 changes: 6 additions & 8 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from torchaudio_unittest.common_utils import nested_params


class Lfilter(common_utils.TestBaseMixin):
def test_simple(self):
class Functional(common_utils.TestBaseMixin):
def test_lfilter_simple(self):
"""
Create a very basic signal,
Then make a simple 4th order delay
Expand All @@ -25,7 +25,7 @@ def test_simple(self):

self.assertEqual(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)

def test_clamp(self):
def test_lfilter_clamp(self):
input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device)
Expand All @@ -40,15 +40,15 @@ def test_clamp(self):
((2, 3, 44100),),
((1, 2, 3, 44100),)
])
def test_shape(self, shape):
def test_lfilter_shape(self, shape):
torch.random.manual_seed(42)
waveform = torch.rand(*shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert shape == waveform.size() == output_waveform.size()

def test_9th_order_filter_stability(self):
def test_lfilter_9th_order_filter_stability(self):
"""
Validate the precision of lfilter against reference scipy implementation when using high order filter.
The reference implementation use cascaded second-order filters so is more numerically accurate.
Expand All @@ -70,10 +70,8 @@ def test_9th_order_filter_stability(self):
yhat = F.lfilter(x, a, b, False)
self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5)


class Spectrogram(common_utils.TestBaseMixin):
@parameterized.expand([(0., ), (1., ), (2., ), (3., )])
def test_grad_at_zero(self, power):
def test_spectogram_grad_at_zero(self, power):
"""The gradient of power spectrogram should not be nan but zero near x=0

https://github.com/pytorch/audio/issues/993
Expand Down