Run functional librosa compatibility test on CUDA as well (#1436)
mthrok committed Apr 7, 2021
1 parent 52e2943 commit 9a0e70e
Showing 7 changed files with 289 additions and 169 deletions.
31 changes: 26 additions & 5 deletions test/torchaudio_unittest/common_utils/__init__.py
@@ -2,6 +2,7 @@
get_asset_path,
get_whitenoise,
get_sinusoid,
get_spectrogram,
)
from .backend_utils import (
set_audio_backend,
@@ -30,8 +31,28 @@
nested_params
)

__all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend',
'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoKaldi', 'skipIfNoSox',
'skipIfNoSoxBackend', 'skipIfRocm', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav',
'load_params', 'nested_params']
__all__ = [
'get_asset_path',
'get_whitenoise',
'get_sinusoid',
'get_spectrogram',
'set_audio_backend',
'TempDirMixin',
'HttpServerMixin',
'TestBaseMixin',
'PytorchTestCase',
'TorchaudioTestCase',
'skipIfNoCuda',
'skipIfNoExec',
'skipIfNoModule',
'skipIfNoKaldi',
'skipIfNoSox',
'skipIfNoSoxBackend',
'skipIfRocm',
'get_wav_data',
'normalize_wav',
'load_wav',
'save_wav',
'load_params',
'nested_params',
]
44 changes: 42 additions & 2 deletions test/torchaudio_unittest/common_utils/data_utils.py
@@ -1,5 +1,5 @@
import os.path
from typing import Union
from typing import Union, Optional

import torch

@@ -62,7 +62,7 @@ def get_whitenoise(
"""
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
if dtype not in [torch.float32, torch.int32, torch.int16, torch.uint8]:
if dtype not in [torch.float64, torch.float32, torch.int32, torch.int16, torch.uint8]:
raise NotImplementedError(f'dtype {dtype} is not supported.')
# According to the doc, forking rng on all CUDA devices is slow when there are many CUDA devices,
# so we only fork on CPU, generate values and move the data to the given device
@@ -110,3 +110,43 @@ def get_sinusoid(
if not channels_first:
tensor = tensor.t()
return convert_tensor_encoding(tensor, dtype)


def get_spectrogram(
waveform,
*,
n_fft: int = 2048,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[torch.Tensor] = None,
center: bool = True,
pad_mode: str = 'reflect',
power: Optional[float] = None,
):
"""Generate a spectrogram of the given Tensor
Args:
n_fft: The number of FFT bins.
hop_length: Stride for sliding window. default: ``n_fft // 4``.
win_length: The size of window frame and STFT filter. default: ``n_fft``.
winwdow: Window function. default: Hann window
center: Pad the input sequence if True. See ``torch.stft`` for the detail.
pad_mode: Padding method used when center is True. Default: "reflect".
power: If ``None``, raw spectrogram with complex values are returned,
otherwise the norm of the spectrogram is returned.
"""
hop_length = hop_length or n_fft // 4
win_length = win_length or n_fft
window = torch.hann_window(win_length) if window is None else window
spec = torch.stft(
waveform,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
center=center,
window=window,
pad_mode=pad_mode,
return_complex=True)
if power is not None:
spec = spec.abs() ** power
return spec
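
As a quick illustration of how this new helper composes with the existing test-signal generators, here is a minimal usage sketch; the ``get_sinusoid`` keyword arguments below are assumptions about that helper's signature and are not part of this diff.

import torch
from torchaudio_unittest.common_utils import get_sinusoid, get_spectrogram

# Synthetic test signal; the keyword arguments here are assumed, not taken from this commit.
waveform = get_sinusoid(sample_rate=8000, dtype='float32', device='cpu')

# power=None keeps the raw complex-valued STFT.
complex_spec = get_spectrogram(waveform, n_fft=400, power=None)

# power=2.0 returns the power spectrogram, i.e. |STFT| ** 2.
power_spec = get_spectrogram(waveform, n_fft=400, power=2.0)

assert complex_spec.is_complex()
assert torch.allclose(power_spec, complex_spec.abs() ** 2)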
52 changes: 38 additions & 14 deletions test/torchaudio_unittest/common_utils/parameterized_utils.py
@@ -11,17 +11,41 @@ def load_params(*paths):
return [param(json.loads(line)) for line in file]


def nested_params(*params):
def _name_func(func, _, params):
strs = []
for arg in params.args:
if isinstance(arg, tuple):
strs.append("_".join(str(a) for a in arg))
else:
strs.append(str(arg))
return f'{func.__name__}_{"_".join(strs)}'

return parameterized.expand(
list(product(*params)),
name_func=_name_func
)
def _name_func(func, _, params):
strs = []
for arg in params.args:
if isinstance(arg, tuple):
strs.append("_".join(str(a) for a in arg))
else:
strs.append(str(arg))
return f'{func.__name__}_{"_".join(strs)}'


def nested_params(*params_set):
"""Generate the cartesian product of the given list of parameters.
Args:
params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
all the parameters have to be specified with the class, only using kwargs.
"""
flatten = [p for params in params_set for p in params]

# Parameters to be nested are given as list of plain objects
if all(not isinstance(p, param) for p in flatten):
args = list(product(*params_set))
return parameterized.expand(args, name_func=_name_func)

# Parameters to be nested are given as list of `parameterized.param`
if not all(isinstance(p, param) for p in flatten):
raise TypeError(
"When using ``parameterized.param``, "
"all the parameters have to be of the ``param`` type.")
if any(p.args for p in flatten):
raise ValueError(
"When using ``parameterized.param``, "
"all the parameters have to be provided as keyword argument."
)
args = [param()]
for params in params_set:
args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
return parameterized.expand(args)
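
For reference, a minimal sketch of the two calling conventions the rewritten ``nested_params`` supports; the test names and values below are made up for illustration and do not come from this commit.

from unittest import TestCase
from parameterized import param
from torchaudio_unittest.common_utils import nested_params


class NestedParamsExample(TestCase):
    # Plain-object mode: the lists are expanded into their cartesian product,
    # producing test_foo_float32_cpu, test_foo_float32_cuda, and so on.
    @nested_params(
        ['float32', 'float64'],
        ['cpu', 'cuda'],
    )
    def test_foo(self, dtype, device):
        self.assertIn(dtype, ['float32', 'float64'])

    # ``parameterized.param`` mode: every entry must be a ``param`` built from
    # kwargs only; one entry from each list is merged into a single keyword call.
    @nested_params(
        [param(n_fft=400), param(n_fft=1024)],
        [param(power=None), param(power=2.0)],
    )
    def test_bar(self, n_fft, power):
        self.assertIn(n_fft, [400, 1024])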
@@ -0,0 +1,10 @@
from torchaudio_unittest.common_utils import PytorchTestCase
from .librosa_compatibility_test_impl import Functional, FunctionalComplex


class TestFunctionalCPU(Functional, PytorchTestCase):
device = 'cpu'


class TestFunctionalComplexCPU(FunctionalComplex, PytorchTestCase):
device = 'cpu'
@@ -0,0 +1,12 @@
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .librosa_compatibility_test_impl import Functional, FunctionalComplex


@skipIfNoCuda
class TestFunctionalCUDA(Functional, PytorchTestCase):
device = 'cuda'


@skipIfNoCuda
class TestFunctionalComplexCUDA(FunctionalComplex, PytorchTestCase):
device = 'cuda'
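
The shared ``librosa_compatibility_test_impl`` module is one of the changed files but is not shown on this page; the sketch below only illustrates the mixin pattern the two new files rely on and is not that module's actual content.

import torch
from torchaudio_unittest.common_utils import TestBaseMixin, PytorchTestCase


class Functional(TestBaseMixin):
    """Hypothetical shared test body; the real tests compare torchaudio against librosa."""

    def test_runs_on_configured_device(self):
        # Inputs are allocated on ``self.device`` so the same code exercises CPU and CUDA.
        waveform = torch.rand(1, 8000, device=self.device)
        self.assertEqual(waveform.device.type, torch.device(self.device).type)


# Mirrors the new CPU test file: the device-specific class only binds ``device``.
class TestFunctionalCPU(Functional, PytorchTestCase):
    device = 'cpu'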
148 changes: 0 additions & 148 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py

This file was deleted.
