Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TorchScript-able "info" func to sox_io backend #728

Merged
merged 6 commits into from
Jun 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c pytorch-nightly pytorch "${cudatoolkit}"

printf "* Installing torchaudio\n"
BUILD_SOX=1 python setup.py develop
python setup.py develop
5 changes: 4 additions & 1 deletion .circleci/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ printf "* Installing dependencies (except PyTorch)\n"
conda env update --file "${this_dir}/environment.yml" --prune

# 4. Build codecs
build_tools/setup_helpers/build_third_party.sh
# build_tools/setup_helpers/build_third_party.sh
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can you elaborate on why you are changing this default mechanic as part of this pr?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The static build of libsox (which is what we ship as binary distribution) does not contain codecs for ogg/vorbis.
Using libsox-fmt-all we can test ogg/vorbis types too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we can still build third parties by using BUILD_SOX=1, but by default, we link with the static build. Is that correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When BUILD_SOX=0 _torchaudio is linked against a libsox found in the system. static/dynamic depends on what is found.

# 4. Install codecs
apt update -q
apt install -y -q sox libsox-dev libsox-fmt-all
39 changes: 39 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import shutil
import tempfile
import unittest
from typing import Union
Expand All @@ -7,6 +8,7 @@
import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import is_module_available

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio.list_audio_backends()
Expand Down Expand Up @@ -87,6 +89,33 @@ def set_audio_backend(backend):
torchaudio.set_audio_backend(be)


class TempDirMixin:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not for this PR: just as is done below in a prior pr, what's the advantage of not inheriting PytorchTestCase here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid inheritance of PytorchTestCase multiple times at the test module.
TempDirMixin is designed to be composable so that each test case can choose use or not to use.
This decision should not collide with the fact we are requiring each test case to inherit PytorchTestCase.

"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
temp_dir = None

def setUp(self):
super().setUp()
self._init_temp_dir()

def tearDown(self):
super().tearDownClass()
self._clean_up_temp_dir()

def _init_temp_dir(self):
self.temp_dir_ = tempfile.TemporaryDirectory()
self.temp_dir = self.temp_dir_.name

def _clean_up_temp_dir(self):
if self.temp_dir_ is not None:
self.temp_dir_.cleanup()
self.temp_dir_ = None
self.temp_dir = None

def get_temp_path(self, *paths):
return os.path.join(self.temp_dir, *paths)


class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
Expand All @@ -102,8 +131,18 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass


def skipIfNoExec(cmd):
return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available')


def skipIfNoModule(module, display_name=None):
display_name = display_name or module
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')


skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension')


def get_whitenoise(
Expand Down
14 changes: 4 additions & 10 deletions test/kaldi_compatibility_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Test suites for checking numerical compatibility against Kaldi"""
import json
import shutil
import unittest
import subprocess

import kaldi_io
Expand All @@ -13,10 +11,6 @@
from parameterized import parameterized, param


def _not_available(cmd):
return shutil.which(cmd) is None


def _convert_args(**kwargs):
args = []
for key, value in kwargs.items():
Expand Down Expand Up @@ -61,7 +55,7 @@ def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)

@unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
@common_utils.skipIfNoExec('apply-cmvn-sliding')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: again, i'm not sure why this is changing as part of this pr?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So as to avoid two duplicated implementation.

kaldi_compatibility_impl had this original version of skipIfNoExec implementation (which is _not_available func + unittest.skipIf), and now that implementation is promoted to common utility because the new sox_io test uses it too.

def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
Expand All @@ -78,7 +72,7 @@ def test_sliding_window_cmn(self):
self.assert_equal(result, expected=kaldi_result)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json')))
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
@common_utils.skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand All @@ -89,7 +83,7 @@ def test_fbank(self, kwargs):
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json')))
@unittest.skipIf(_not_available('compute-spectrogram-feats'), '`compute-spectrogram-feats` not available')
@common_utils.skipIfNoExec('compute-spectrogram-feats')
def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand All @@ -100,7 +94,7 @@ def test_spectrogram(self, kwargs):
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json')))
@unittest.skipIf(_not_available('compute-mfcc-feats'), '`compute-mfcc-feats` not available')
@common_utils.skipIfNoExec('compute-mfcc-feats')
def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand Down
Empty file added test/sox_io_backend/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions test/sox_io_backend/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def get_test_name(func, _, params):
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}'
54 changes: 54 additions & 0 deletions test/sox_io_backend/sox_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import subprocess


def get_encoding(dtype):
encodings = {
'float32': 'floating-point',
'int32': 'signed-integer',
'int16': 'signed-integer',
'uint8': 'unsigned-integer',
}
return encodings[dtype]


def get_bit_depth(dtype):
bit_depths = {
'float32': 32,
'int32': 32,
'int16': 16,
'uint8': 8,
}
return bit_depths[dtype]


def gen_audio_file(
path, sample_rate, num_channels,
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
):
"""Generate synthetic audio file with `sox` command."""
command = [
'sox',
'-V', # verbose
'--rate', str(sample_rate),
'--null', # no input
'--channels', str(num_channels),
]
if compression is not None:
command += ['--compression', str(compression)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if encoding is not None:
command += ['--encoding', str(encoding)]
command += [
str(path),
'synth', str(duration), # synthesizes for the given duration [sec]
'sawtooth', '1',
# saw tooth covers the both ends of value range, which is a good property for test.
# similar to linspace(-1., 1.)
# this introduces bigger boundary effect than sine when converted to mp3
]
if attenuation is not None:
command += ['vol', f'-{attenuation}dB']
print(' '.join(command))
subprocess.run(command, check=True)
subprocess.run(['soxi', path], check=True)
114 changes: 114 additions & 0 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import itertools
from parameterized import parameterized

from torchaudio.backend import sox_io_backend

from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name
)
from . import sox_utils


@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
Comment on lines +22 to +24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is nice :) is there a way to have pass keyword arguments with parameterized?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not know. Please refer to the doc https://pypi.org/project/parameterized/

)), name_func=get_test_name)
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[4, 8, 16, 32],
)), name_func=get_test_name)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)), name_func=get_test_name)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
# mp3 does not preserve the number of samples
# assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=get_test_name)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=get_test_name)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels
48 changes: 48 additions & 0 deletions test/sox_io_backend/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import itertools

import torch
from torchaudio.backend import sox_io_backend
from parameterized import parameterized

from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
)
from . import sox_utils


def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath)


@skipIfNoExec('sox')
@skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=get_test_name)
def test_info_wav(self, dtype, sample_rate, num_channels):
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
)

script_path = self.get_temp_path('info_func')
torch.jit.script(py_info_func).save(script_path)
ts_info_func = torch.jit.load(script_path)

py_info = py_info_func(audio_path)
ts_info = ts_info_func(audio_path)

assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_samples() == ts_info.get_num_samples()
assert py_info.get_num_channels() == ts_info.get_num_channels()
7 changes: 2 additions & 5 deletions test/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest

import torchaudio
from torchaudio._internal.module_utils import is_module_available

from . import common_utils

Expand All @@ -28,15 +27,13 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
backend_module = torchaudio.backend.no_backend


@unittest.skipIf(
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
@common_utils.skipIfNoExtension
class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox'
backend_module = torchaudio.backend.sox_backend


@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend
10 changes: 10 additions & 0 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from torchaudio._internal import (
module_utils as _mod_utils,
)


@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
"""Get signal information of an audio file."""
return torch.ops.torchaudio.sox_io_get_info(filepath)