Skip to content

Commit

Permalink
Adopt PyTorch's test util to torchscript test
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed May 14, 2020
1 parent 9835db7 commit 5958f67
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 41 deletions.
11 changes: 7 additions & 4 deletions test/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import tempfile
import unittest
from typing import Type, Iterable
from contextlib import contextmanager
from shutil import copytree

import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import pytest

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio._backend._audio_backends
Expand Down Expand Up @@ -87,6 +88,9 @@ class TestBaseMixin:
device = None


_SKIP_IF_NO_CUDA = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')


def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
if dtype not in ['float32', 'float64']:
raise NotImplementedError(f'Unexpected dtype: {dtype}')
Expand All @@ -95,11 +99,10 @@ def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):

name = f'Test{testbase.__name__}_{device.upper()}_{dtype.capitalize()}'
attrs = {'dtype': getattr(torch, dtype), 'device': torch.device(device)}
testsuite = type(name, (testbase,), attrs)
testsuite = type(name, (testbase, TestCase), attrs)

if device == 'cuda':
testsuite = pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA not available')(testsuite)
testsuite = _SKIP_IF_NO_CUDA(testsuite)
return testsuite


Expand Down
2 changes: 1 addition & 1 deletion test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_simple(self):
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)

torch.testing.assert_allclose(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)
self.assertEqual(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)

def test_clamp(self):
input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device)
Expand Down
64 changes: 28 additions & 36 deletions test/test_torchscript_consistency.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Test suites for jit-ability and its numerical compatibility"""
import unittest
import pytest

import torch
import torchaudio
Expand All @@ -10,29 +9,18 @@
import common_utils


def _assert_functional_consistency(func, tensor, shape_only=False):
ts_func = torch.jit.script(func)
output = func(tensor)
ts_output = ts_func(tensor)

if shape_only:
assert ts_output.shape == output.shape, (ts_output.shape, output.shape)
else:
torch.testing.assert_allclose(ts_output, output)


def _assert_transforms_consistency(transform, tensor):
ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
torch.testing.assert_allclose(ts_output, output)


class Functional(common_utils.TestBaseMixin):
"""Implements test for `functinoal` modul that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)
return _assert_functional_consistency(func, tensor, shape_only=shape_only)

ts_func = torch.jit.script(func)
output = func(tensor)
ts_output = ts_func(tensor)
if shape_only:
assert ts_output.shape == output.shape, (ts_output.shape, output.shape)
else:
self.assertEqual(ts_output, output)

def test_spectrogram(self):
def func(tensor):
Expand Down Expand Up @@ -209,7 +197,7 @@ def func(tensor):

def test_lfilter(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand Down Expand Up @@ -253,7 +241,7 @@ def func(tensor):

def test_lowpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -267,7 +255,7 @@ def func(tensor):

def test_highpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -281,7 +269,7 @@ def func(tensor):

def test_allpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -296,7 +284,7 @@ def func(tensor):

def test_bandpass_with_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -312,7 +300,7 @@ def func(tensor):

def test_bandpass_without_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -328,7 +316,7 @@ def func(tensor):

def test_bandreject(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -343,7 +331,7 @@ def func(tensor):

def test_band_with_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -359,7 +347,7 @@ def func(tensor):

def test_band_without_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -375,7 +363,7 @@ def func(tensor):

def test_treble(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -391,7 +379,7 @@ def func(tensor):

def test_deemph(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -404,7 +392,7 @@ def func(tensor):

def test_riaa(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -417,7 +405,7 @@ def func(tensor):

def test_equalizer(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand All @@ -433,7 +421,7 @@ def func(tensor):

def test_perf_biquad_filtering(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
raise unittest.SkipTest("This test is known to fail for float64")

filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
Expand Down Expand Up @@ -514,7 +502,7 @@ def func(tensor):

def test_phaser(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
waveform, _ = torchaudio.load(filepath, normalization=True)

def func(tensor):
gain_in = 0.5
Expand All @@ -533,7 +521,11 @@ class Transforms(common_utils.TestBaseMixin):
def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
_assert_transforms_consistency(transform, tensor)

ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)

def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
Expand Down

0 comments on commit 5958f67

Please sign in to comment.