-
Notifications
You must be signed in to change notification settings - Fork 633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TorchScript-able "info" func to sox_io backend #728
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import os | ||
import shutil | ||
import tempfile | ||
import unittest | ||
from typing import Union | ||
|
@@ -7,6 +8,7 @@ | |
import torch | ||
from torch.testing._internal.common_utils import TestCase as PytorchTestCase | ||
import torchaudio | ||
from torchaudio._internal.module_utils import is_module_available | ||
|
||
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) | ||
BACKENDS = torchaudio.list_audio_backends() | ||
|
@@ -87,6 +89,33 @@ def set_audio_backend(backend): | |
torchaudio.set_audio_backend(be) | ||
|
||
|
||
class TempDirMixin: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not for this PR: just as is done below in a prior pr, what's the advantage of not inheriting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To avoid inheritance of |
||
"""Mixin to provide easy access to temp dir""" | ||
temp_dir_ = None | ||
temp_dir = None | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self._init_temp_dir() | ||
|
||
def tearDown(self): | ||
super().tearDownClass() | ||
self._clean_up_temp_dir() | ||
|
||
def _init_temp_dir(self): | ||
self.temp_dir_ = tempfile.TemporaryDirectory() | ||
self.temp_dir = self.temp_dir_.name | ||
|
||
def _clean_up_temp_dir(self): | ||
if self.temp_dir_ is not None: | ||
self.temp_dir_.cleanup() | ||
self.temp_dir_ = None | ||
self.temp_dir = None | ||
|
||
def get_temp_path(self, *paths): | ||
return os.path.join(self.temp_dir, *paths) | ||
|
||
|
||
class TestBaseMixin: | ||
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase""" | ||
dtype = None | ||
|
@@ -102,8 +131,18 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase): | |
pass | ||
|
||
|
||
def skipIfNoExec(cmd): | ||
return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available') | ||
|
||
|
||
def skipIfNoModule(module, display_name=None): | ||
display_name = display_name or module | ||
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available') | ||
|
||
|
||
skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available') | ||
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available') | ||
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension') | ||
|
||
|
||
def get_whitenoise( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
"""Test suites for checking numerical compatibility against Kaldi""" | ||
import json | ||
import shutil | ||
import unittest | ||
import subprocess | ||
|
||
import kaldi_io | ||
|
@@ -13,10 +11,6 @@ | |
from parameterized import parameterized, param | ||
|
||
|
||
def _not_available(cmd): | ||
return shutil.which(cmd) is None | ||
|
||
|
||
def _convert_args(**kwargs): | ||
args = [] | ||
for key, value in kwargs.items(): | ||
|
@@ -61,7 +55,7 @@ def assert_equal(self, output, *, expected, rtol=None, atol=None): | |
expected = expected.to(dtype=self.dtype, device=self.device) | ||
self.assertEqual(output, expected, rtol=rtol, atol=atol) | ||
|
||
@unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available') | ||
@common_utils.skipIfNoExec('apply-cmvn-sliding') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: again, i'm not sure why this is changing as part of this pr? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So as to avoid two duplicated implementation.
|
||
def test_sliding_window_cmn(self): | ||
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" | ||
kwargs = { | ||
|
@@ -78,7 +72,7 @@ def test_sliding_window_cmn(self): | |
self.assert_equal(result, expected=kaldi_result) | ||
|
||
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json'))) | ||
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available') | ||
@common_utils.skipIfNoExec('compute-fbank-feats') | ||
def test_fbank(self, kwargs): | ||
"""fbank should be numerically compatible with compute-fbank-feats""" | ||
wave_file = common_utils.get_asset_path('kaldi_file.wav') | ||
|
@@ -89,7 +83,7 @@ def test_fbank(self, kwargs): | |
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) | ||
|
||
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json'))) | ||
@unittest.skipIf(_not_available('compute-spectrogram-feats'), '`compute-spectrogram-feats` not available') | ||
@common_utils.skipIfNoExec('compute-spectrogram-feats') | ||
def test_spectrogram(self, kwargs): | ||
"""spectrogram should be numerically compatible with compute-spectrogram-feats""" | ||
wave_file = common_utils.get_asset_path('kaldi_file.wav') | ||
|
@@ -100,7 +94,7 @@ def test_spectrogram(self, kwargs): | |
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) | ||
|
||
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json'))) | ||
@unittest.skipIf(_not_available('compute-mfcc-feats'), '`compute-mfcc-feats` not available') | ||
@common_utils.skipIfNoExec('compute-mfcc-feats') | ||
def test_mfcc(self, kwargs): | ||
"""mfcc should be numerically compatible with compute-mfcc-feats""" | ||
wave_file = common_utils.get_asset_path('kaldi_file.wav') | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def get_test_name(func, _, params): | ||
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import subprocess | ||
|
||
|
||
def get_encoding(dtype): | ||
encodings = { | ||
'float32': 'floating-point', | ||
'int32': 'signed-integer', | ||
'int16': 'signed-integer', | ||
'uint8': 'unsigned-integer', | ||
} | ||
return encodings[dtype] | ||
|
||
|
||
def get_bit_depth(dtype): | ||
bit_depths = { | ||
'float32': 32, | ||
'int32': 32, | ||
'int16': 16, | ||
'uint8': 8, | ||
} | ||
return bit_depths[dtype] | ||
|
||
|
||
def gen_audio_file( | ||
path, sample_rate, num_channels, | ||
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1, | ||
): | ||
"""Generate synthetic audio file with `sox` command.""" | ||
command = [ | ||
'sox', | ||
'-V', # verbose | ||
'--rate', str(sample_rate), | ||
'--null', # no input | ||
'--channels', str(num_channels), | ||
] | ||
if compression is not None: | ||
command += ['--compression', str(compression)] | ||
if bit_depth is not None: | ||
command += ['--bits', str(bit_depth)] | ||
if encoding is not None: | ||
command += ['--encoding', str(encoding)] | ||
command += [ | ||
str(path), | ||
'synth', str(duration), # synthesizes for the given duration [sec] | ||
'sawtooth', '1', | ||
# saw tooth covers the both ends of value range, which is a good property for test. | ||
# similar to linspace(-1., 1.) | ||
# this introduces bigger boundary effect than sine when converted to mp3 | ||
] | ||
if attenuation is not None: | ||
command += ['vol', f'-{attenuation}dB'] | ||
print(' '.join(command)) | ||
subprocess.run(command, check=True) | ||
subprocess.run(['soxi', path], check=True) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import itertools | ||
from parameterized import parameterized | ||
|
||
from torchaudio.backend import sox_io_backend | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name | ||
) | ||
from . import sox_utils | ||
|
||
|
||
@skipIfNoExec('sox') | ||
@skipIfNoExtension | ||
class TestInfo(TempDirMixin, PytorchTestCase): | ||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[1, 2], | ||
Comment on lines
+22
to
+24
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this is nice :) is there a way to have pass keyword arguments with parameterized? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not know. Please refer to the doc https://pypi.org/project/parameterized/ |
||
)), name_func=get_test_name) | ||
def test_wav(self, dtype, sample_rate, num_channels): | ||
"""`sox_io_backend.info` can check wav file correctly""" | ||
duration = 1 | ||
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') | ||
sox_utils.gen_audio_file( | ||
path, sample_rate, num_channels, | ||
bit_depth=sox_utils.get_bit_depth(dtype), | ||
encoding=sox_utils.get_encoding(dtype), | ||
duration=duration, | ||
) | ||
info = sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
assert info.get_num_samples() == sample_rate * duration | ||
assert info.get_num_channels() == num_channels | ||
|
||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[4, 8, 16, 32], | ||
)), name_func=get_test_name) | ||
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): | ||
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly""" | ||
duration = 1 | ||
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') | ||
sox_utils.gen_audio_file( | ||
path, sample_rate, num_channels, | ||
bit_depth=sox_utils.get_bit_depth(dtype), | ||
encoding=sox_utils.get_encoding(dtype), | ||
duration=duration, | ||
) | ||
info = sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
assert info.get_num_samples() == sample_rate * duration | ||
assert info.get_num_channels() == num_channels | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
[96, 128, 160, 192, 224, 256, 320], | ||
)), name_func=get_test_name) | ||
def test_mp3(self, sample_rate, num_channels, bit_rate): | ||
"""`sox_io_backend.info` can check mp3 file correctly""" | ||
duration = 1 | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3') | ||
sox_utils.gen_audio_file( | ||
path, sample_rate, num_channels, | ||
compression=bit_rate, duration=duration, | ||
) | ||
info = sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
# mp3 does not preserve the number of samples | ||
# assert info.get_num_samples() == sample_rate * duration | ||
assert info.get_num_channels() == num_channels | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
list(range(9)), | ||
)), name_func=get_test_name) | ||
def test_flac(self, sample_rate, num_channels, compression_level): | ||
"""`sox_io_backend.info` can check flac file correctly""" | ||
duration = 1 | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac') | ||
sox_utils.gen_audio_file( | ||
path, sample_rate, num_channels, | ||
compression=compression_level, duration=duration, | ||
) | ||
info = sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
assert info.get_num_samples() == sample_rate * duration | ||
assert info.get_num_channels() == num_channels | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
[-1, 0, 1, 2, 3, 3.6, 5, 10], | ||
)), name_func=get_test_name) | ||
def test_vorbis(self, sample_rate, num_channels, quality_level): | ||
"""`sox_io_backend.info` can check vorbis file correctly""" | ||
duration = 1 | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis') | ||
sox_utils.gen_audio_file( | ||
path, sample_rate, num_channels, | ||
compression=quality_level, duration=duration, | ||
) | ||
info = sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
assert info.get_num_samples() == sample_rate * duration | ||
assert info.get_num_channels() == num_channels |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import itertools | ||
|
||
import torch | ||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
TorchaudioTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: | ||
return sox_io_backend.info(filepath) | ||
|
||
|
||
@skipIfNoExec('sox') | ||
@skipIfNoExtension | ||
class SoxIO(TempDirMixin, TorchaudioTestCase): | ||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[1, 2], | ||
)), name_func=get_test_name) | ||
def test_info_wav(self, dtype, sample_rate, num_channels): | ||
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') | ||
sox_utils.gen_audio_file( | ||
audio_path, sample_rate, num_channels, | ||
bit_depth=sox_utils.get_bit_depth(dtype), | ||
encoding=sox_utils.get_encoding(dtype), | ||
) | ||
|
||
script_path = self.get_temp_path('info_func') | ||
torch.jit.script(py_info_func).save(script_path) | ||
ts_info_func = torch.jit.load(script_path) | ||
|
||
py_info = py_info_func(audio_path) | ||
ts_info = ts_info_func(audio_path) | ||
|
||
assert py_info.get_sample_rate() == ts_info.get_sample_rate() | ||
assert py_info.get_num_samples() == ts_info.get_num_samples() | ||
assert py_info.get_num_channels() == ts_info.get_num_channels() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
from torchaudio._internal import ( | ||
module_utils as _mod_utils, | ||
) | ||
|
||
|
||
@_mod_utils.requires_module('torchaudio._torchaudio') | ||
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo: | ||
"""Get signal information of an audio file.""" | ||
return torch.ops.torchaudio.sox_io_get_info(filepath) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Can you elaborate on why you are changing this default mechanic as part of this pr?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The static build of
libsox
(which is what we ship as binary distribution) does not contain codecs forogg/vorbis
.Using
libsox-fmt-all
we can testogg/vorbis
types too.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So we can still build third parties by using
BUILD_SOX=1
, but by default, we link with the static build. Is that correct?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When
BUILD_SOX=0
_torchaudio
is linked against alibsox
found in the system. static/dynamic depends on what is found.