Skip to content

Commit

Permalink
Guard test for windows
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 18, 2020
1 parent 60e8f06 commit 5e970ee
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 21 deletions.
12 changes: 12 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import shutil
import tempfile
import unittest
from typing import Union
Expand All @@ -7,6 +8,7 @@
import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import is_module_available

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio.list_audio_backends()
Expand Down Expand Up @@ -129,8 +131,18 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass


def skipIfNoExec(cmd):
return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available')


def skipIfNoModule(module, display_name=None):
display_name = display_name or module
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')


skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension')


def get_whitenoise(
Expand Down
14 changes: 4 additions & 10 deletions test/kaldi_compatibility_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Test suites for checking numerical compatibility against Kaldi"""
import json
import shutil
import unittest
import subprocess

import kaldi_io
Expand All @@ -13,10 +11,6 @@
from parameterized import parameterized, param


def _not_available(cmd):
return shutil.which(cmd) is None


def _convert_args(**kwargs):
args = []
for key, value in kwargs.items():
Expand Down Expand Up @@ -61,7 +55,7 @@ def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)

@unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
@common_utils.skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
Expand All @@ -78,7 +72,7 @@ def test_sliding_window_cmn(self):
self.assert_equal(result, expected=kaldi_result)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json')))
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
@common_utils.skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand All @@ -89,7 +83,7 @@ def test_fbank(self, kwargs):
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json')))
@unittest.skipIf(_not_available('compute-spectrogram-feats'), '`compute-spectrogram-feats` not available')
@common_utils.skipIfNoExec('compute-spectrogram-feats')
def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand All @@ -100,7 +94,7 @@ def test_spectrogram(self, kwargs):
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json')))
@unittest.skipIf(_not_available('compute-mfcc-feats'), '`compute-mfcc-feats` not available')
@common_utils.skipIfNoExec('compute-mfcc-feats')
def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand Down
4 changes: 3 additions & 1 deletion test/sox_io/test_info.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import itertools
from parameterized import parameterized

import torchaudio
from torchaudio.backend import sox_io_backend

from .. import common_utils
from ..common_utils import (
TempDirMixin,
PytorchTestCase,
Expand All @@ -14,6 +14,8 @@
from . import sox_utils


@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
Expand Down
15 changes: 10 additions & 5 deletions test/sox_io/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torchaudio.backend import sox_io_backend
from parameterized import parameterized

from .. import common_utils
from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
Expand All @@ -18,24 +19,28 @@ def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath)


@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=get_test_name)
def test_info_wav(self, dtype, sample_rate, num_channels):
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
audio_path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
)

ts_info_func = torch.jit.script(py_info_func)
script_path = self.get_temp_path(f'info_func')
torch.jit.script(py_info_func).save(script_path)
ts_info_func = torch.jit.load(script_path)

py_info = py_info_func(path)
ts_info = ts_info_func(path)
py_info = py_info_func(audio_path)
ts_info = ts_info_func(audio_path)

assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_samples() == ts_info.get_num_samples()
Expand Down
7 changes: 2 additions & 5 deletions test/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest

import torchaudio
from torchaudio._internal.module_utils import is_module_available

from . import common_utils

Expand All @@ -28,15 +27,13 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
backend_module = torchaudio.backend.no_backend


@unittest.skipIf(
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
@common_utils.skipIfNoExtension
class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox'
backend_module = torchaudio.backend.sox_backend


@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend

0 comments on commit 5e970ee

Please sign in to comment.