Introduce common utility for defining test matrix for device/dtype (#616)

* Introduce common utility for defining test matrix for device/dtype

* Make resample_waveform support float64

* Mark lfilter related test as xfail when float64

* fix
mthrok committed May 8, 2020
1 parent 7a0d419 commit 00d3820
Showing 4 changed files with 100 additions and 62 deletions.
36 changes: 36 additions & 0 deletions test/common_utils.py
@@ -1,10 +1,12 @@
import os
import tempfile
from typing import Type, Iterable
from contextlib import contextmanager
from shutil import copytree

import torch
import torchaudio
import pytest

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio._backend._audio_backends
@@ -78,3 +80,37 @@ def supports_mp3(backend):


BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)


class TestBaseMixin:
dtype = None
device = None


def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
if dtype not in ['float32', 'float64']:
raise NotImplementedError(f'Unexpected dtype: {dtype}')
if device not in ['cpu', 'cuda']:
raise NotImplementedError(f'Unexpected device: {device}')

name = f'Test{testbase.__name__}_{device.upper()}_{dtype.capitalize()}'
attrs = {'dtype': getattr(torch, dtype), 'device': torch.device(device)}
testsuite = type(name, (testbase,), attrs)

if device == 'cuda':
testsuite = pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA not available')(testsuite)
return testsuite


def define_test_suites(
scope: dict,
testbases: Iterable[Type[TestBaseMixin]],
dtypes: Iterable[str] = ('float32', 'float64'),
devices: Iterable[str] = ('cpu', 'cuda'),
):
for suite in testbases:
for device in devices:
for dtype in dtypes:
t = define_test_suite(suite, dtype, device)
scope[t.__name__] = t
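
For context, here is a minimal sketch of how a test module would consume the new helper. The module, class, and test names below are hypothetical illustrations, not part of this commit; only `TestBaseMixin` and `define_test_suites` come from the diff above.

```python
# test_example.py -- hypothetical consumer of the new test-matrix helper
import torch

import common_utils


class Add(common_utils.TestBaseMixin):
    """Each generated suite runs this test with its own self.dtype / self.device."""

    def test_add(self):
        a = torch.ones(3, dtype=self.dtype, device=self.device)
        assert torch.allclose(a + a, 2 * a)


# Injects TestAdd_CPU_Float32, TestAdd_CPU_Float64, TestAdd_CUDA_Float32 and
# TestAdd_CUDA_Float64 into this module for pytest to collect; the CUDA variants
# are skipped when torch.cuda.is_available() is False.
common_utils.define_test_suites(globals(), [Add])
```

Passing `globals()` puts the generated classes into the calling module's namespace, which is what makes them discoverable by the test collector.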
24 changes: 4 additions & 20 deletions test/test_functional.py
@@ -9,10 +9,7 @@
import common_utils


class _LfilterMixin:
device = None
dtype = None

class Lfilter(common_utils.TestBaseMixin):
def test_simple(self):
"""
Create a very basic signal,
@@ -33,25 +30,12 @@ def test_clamp(self):
b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device)
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=True)
self.assertTrue(output_signal.max() <= 1)
assert output_signal.max() <= 1
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
self.assertTrue(output_signal.max() > 1)


class TestLfilterFloat32CPU(_LfilterMixin, unittest.TestCase):
device = torch.device('cpu')
dtype = torch.float32


class TestLfilterFloat64CPU(_LfilterMixin, unittest.TestCase):
device = torch.device('cpu')
dtype = torch.float64
assert output_signal.max() > 1


@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
class TestLfilterFloat32CUDA(_LfilterMixin, unittest.TestCase):
device = torch.device('cuda')
dtype = torch.float32
common_utils.define_test_suites(globals(), [Lfilter])


class TestComputeDeltas(unittest.TestCase):
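
For reference, with the default dtypes and devices the single `define_test_suites(globals(), [Lfilter])` call above stands in for the four hand-written classes it replaces. A sketch of the expansion (the class names follow the f-string in `define_test_suite`; they are not spelled out in this diff):

```python
import pytest
import torch
# Lfilter is the suite class defined earlier in this file.

TestLfilter_CPU_Float32 = type(
    'TestLfilter_CPU_Float32', (Lfilter,),
    {'dtype': torch.float32, 'device': torch.device('cpu')},
)
TestLfilter_CUDA_Float64 = pytest.mark.skipif(
    not torch.cuda.is_available(), reason='CUDA not available',
)(type(
    'TestLfilter_CUDA_Float64', (Lfilter,),
    {'dtype': torch.float64, 'device': torch.device('cuda')},
))
# ...plus the CPU/float64 and CUDA/float32 combinations.
```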
94 changes: 55 additions & 39 deletions test/test_torchscript_consistency.py
@@ -1,5 +1,6 @@
"""Test suites for jit-ability and its numerical compatibility"""
import unittest
import pytest

import torch
import torchaudio
@@ -9,8 +10,7 @@
import common_utils


def _assert_functional_consistency(func, tensor, device, shape_only=False):
tensor = tensor.to(device)
def _assert_functional_consistency(func, tensor, shape_only=False):
ts_func = torch.jit.script(func)
output = func(tensor)
ts_output = ts_func(tensor)
@@ -21,21 +21,18 @@ def _assert_functional_consistency(func, tensor, device, shape_only=False):
torch.testing.assert_allclose(ts_output, output)


def _assert_transforms_consistency(transform, tensor, device):
tensor = tensor.to(device)
transform = transform.to(device)
def _assert_transforms_consistency(transform, tensor):
ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
torch.testing.assert_allclose(ts_output, output)


class _FunctionalTestMixin:
class Functional(common_utils.TestBaseMixin):
"""Implements test for `functinoal` modul that are performed for different devices"""
device = None

def _assert_consistency(self, func, tensor, shape_only=False):
return _assert_functional_consistency(func, tensor, self.device, shape_only=shape_only)
tensor = tensor.to(device=self.device, dtype=self.dtype)
return _assert_functional_consistency(func, tensor, shape_only=shape_only)

def test_spectrogram(self):
def func(tensor):
@@ -159,7 +156,7 @@ def func(tensor):
return F.complex_norm(tensor, power)

tensor = torch.randn(1, 2, 1025, 400, 2)
_assert_functional_consistency(func, tensor, self.device)
self._assert_consistency(func, tensor)

def test_mask_along_axis(self):
def func(tensor):
@@ -211,6 +208,9 @@ def func(tensor):
self._assert_consistency(func, tensor, shape_only=True)

def test_lfilter(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -252,6 +252,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_lowpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -263,6 +266,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_highpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -274,6 +280,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_allpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -286,6 +295,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_bandpass_with_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -298,7 +310,10 @@ def func(tensor):

self._assert_consistency(func, waveform)

def test_bandpass_withou_csg(self):
def test_bandpass_without_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -312,6 +327,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_bandreject(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -324,6 +342,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_band_with_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -337,6 +358,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_band_without_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -350,6 +374,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_treble(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -363,6 +390,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_deemph(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -373,6 +403,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_riaa(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -383,6 +416,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_equalizer(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -396,6 +432,9 @@ def func(tensor):
self._assert_consistency(func, waveform)

def test_perf_biquad_filtering(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)

@@ -489,12 +528,12 @@ def func(tensor):
self._assert_consistency(func, waveform)


class _TransformsTestMixin:
class Transforms(common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
device = None

def _assert_consistency(self, transform, tensor):
_assert_transforms_consistency(transform, tensor, self.device)
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
_assert_transforms_consistency(transform, tensor)

def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
@@ -578,27 +617,4 @@ def test_Vad(self):
self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)


class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on CPU"""
device = torch.device('cpu')


@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestFunctionalCUDA(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on GPU"""
device = torch.device('cuda')


class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on CPU"""
device = torch.device('cpu')


@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestTransformsCUDA(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on GPU"""
device = torch.device('cuda')


if __name__ == '__main__':
unittest.main()
common_utils.define_test_suites(globals(), [Functional, Transforms])
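
A note on the float64 guards added throughout this file: `pytest.xfail()` is called imperatively inside the test body because the condition depends on the per-suite `self.dtype`, which is only known at run time for these generated classes. A hypothetical helper (not part of this commit) could centralize the repeated guard:

```python
import pytest
import torch


def xfail_if_float64(dtype, reason="This test is known to fail for float64"):
    """Imperatively mark the calling test as an expected failure for float64 suites."""
    if dtype == torch.float64:
        pytest.xfail(reason)


# Usage at the top of an affected test method: xfail_if_float64(self.dtype)
```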
8 changes: 5 additions & 3 deletions torchaudio/compliance/kaldi.py
@@ -890,6 +890,8 @@ def resample_waveform(waveform: Tensor,
Returns:
Tensor: The waveform at the new frequency
"""
device, dtype = waveform.device, waveform.dtype

assert waveform.dim() == 2
assert orig_freq > 0.0 and new_freq > 0.0

@@ -905,7 +907,7 @@
window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit,
window_width, lowpass_cutoff, lowpass_filter_width)
weights = weights.to(waveform.device) # TODO Create weights on device directly
weights = weights.to(device=device, dtype=dtype) # TODO Create weights on device directly

assert first_indices.dim() == 1
# TODO figure a better way to do this. conv1d reaches every element i*stride + padding
@@ -918,9 +920,9 @@
window_size = weights.size(1)
tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
output = torch.zeros((num_channels, tot_output_samp),
device=waveform.device)
device=device, dtype=dtype)
# eye size: (num_channels, num_channels, 1)
eye = torch.eye(num_channels, device=waveform.device).unsqueeze(2)
eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)
for i in range(first_indices.size(0)):
wave_to_conv = waveform
first_index = int(first_indices[i].item())
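
A minimal sketch of the float64 path these changes enable; the waveform values and sample rates are arbitrary, and `resample_waveform`'s signature is unchanged by this commit:

```python
import torch
from torchaudio.compliance.kaldi import resample_waveform

# A stereo double-precision signal at 16 kHz, downsampled to 8 kHz; the filter
# weights and output buffers are now created with the input's dtype, so the
# result stays float64 instead of being silently computed in float32.
waveform = torch.randn(2, 16000, dtype=torch.float64)
resampled = resample_waveform(waveform, orig_freq=16000., new_freq=8000.)
assert resampled.dtype == torch.float64
```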
