Skip to content

Commit

Permalink
Add TorchScript-able "info" func to sox_io backend (#728)
Browse files Browse the repository at this point in the history
This is part of a series of PRs adding the new "sox_io" backend (#726), and it depends on #718.

This PR adds an `info` function to the "sox_io" backend, which allows users to fetch some metadata of an audio file.
At the moment, the following information is retrieved:

 - Number of samples in the audio file
 - Sampling rate
 - Number of channels
  • Loading branch information
mthrok committed Jun 19, 2020
1 parent f8eac89 commit 88fccd1
Show file tree
Hide file tree
Showing 14 changed files with 351 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c pytorch-nightly pytorch "${cudatoolkit}"

printf "* Installing torchaudio\n"
BUILD_SOX=1 python setup.py develop
python setup.py develop
5 changes: 4 additions & 1 deletion .circleci/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ printf "* Installing dependencies (except PyTorch)\n"
conda env update --file "${this_dir}/environment.yml" --prune

# 4. Build codecs
build_tools/setup_helpers/build_third_party.sh
# build_tools/setup_helpers/build_third_party.sh
# 4. Install codecs
apt update -q
apt install -y -q sox libsox-dev libsox-fmt-all
39 changes: 39 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import shutil
import tempfile
import unittest
from typing import Union
Expand All @@ -7,6 +8,7 @@
import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import is_module_available

# Directory that contains this file, with symlinks resolved by `realpath`.
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
# Audio backends reported by torchaudio at the time this module is imported.
BACKENDS = torchaudio.list_audio_backends()
Expand Down Expand Up @@ -87,6 +89,33 @@ def set_audio_backend(backend):
torchaudio.set_audio_backend(be)


class TempDirMixin:
    """Mixin that gives every test method its own temporary directory.

    The directory is created in ``setUp`` and removed in ``tearDown``.
    List this mixin before the ``TestCase`` base class so that the
    ``super()`` calls below reach the test case's own hooks.
    """
    # `tempfile.TemporaryDirectory` handle; None while no test is running.
    temp_dir_ = None
    # Path of the temporary directory; None while no test is running.
    temp_dir = None

    def setUp(self):
        """Run the parent set-up, then create the temporary directory."""
        super().setUp()
        self._init_temp_dir()

    def tearDown(self):
        """Run the parent per-test tear-down, then remove the temporary directory."""
        try:
            # Chain to the per-test hook. Calling `tearDownClass` from here
            # would skip the parent's `tearDown` and trigger class-level
            # clean-up after every single test.
            super().tearDown()
        finally:
            # Remove the directory even when the parent tear-down raises.
            self._clean_up_temp_dir()

    def _init_temp_dir(self):
        """Create a fresh temporary directory and record its handle and path."""
        self.temp_dir_ = tempfile.TemporaryDirectory()
        self.temp_dir = self.temp_dir_.name

    def _clean_up_temp_dir(self):
        """Delete the temporary directory if one exists and reset the attributes."""
        if self.temp_dir_ is not None:
            self.temp_dir_.cleanup()
            self.temp_dir_ = None
            self.temp_dir = None

    def get_temp_path(self, *paths):
        """Join ``paths`` onto the temporary directory and return the result.

        Only valid between ``setUp`` and ``tearDown``; outside of that
        window ``temp_dir`` is None.
        """
        return os.path.join(self.temp_dir, *paths)


class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
Expand All @@ -102,8 +131,18 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass


def skipIfNoExec(cmd):
    """Build a skip decorator that fires when `cmd` cannot be found by `shutil.which`."""
    executable_missing = shutil.which(cmd) is None
    reason = f'`{cmd}` is not available'
    return unittest.skipIf(executable_missing, reason)


def skipIfNoModule(module, display_name=None):
    """Build a skip decorator that fires when `module` is reported unavailable.

    `display_name` replaces the module name in the skip message when given.
    """
    label = display_name if display_name else module
    module_missing = not is_module_available(module)
    return unittest.skipIf(module_missing, f'"{label}" is not available')


# Ready-made skip decorators. Each condition is evaluated once, when this module is imported.
# Skips when 'sox' is not among the backends listed in BACKENDS.
skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
# Skips when `torch.cuda.is_available()` is false.
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
# Skips when the `torchaudio._torchaudio` module is not available.
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension')


def get_whitenoise(
Expand Down
14 changes: 4 additions & 10 deletions test/kaldi_compatibility_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Test suites for checking numerical compatibility against Kaldi"""
import json
import shutil
import unittest
import subprocess

import kaldi_io
Expand All @@ -13,10 +11,6 @@
from parameterized import parameterized, param


def _not_available(cmd):
return shutil.which(cmd) is None


def _convert_args(**kwargs):
args = []
for key, value in kwargs.items():
Expand Down Expand Up @@ -61,7 +55,7 @@ def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)

@unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
@common_utils.skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
Expand All @@ -78,7 +72,7 @@ def test_sliding_window_cmn(self):
self.assert_equal(result, expected=kaldi_result)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json')))
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
@common_utils.skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand All @@ -89,7 +83,7 @@ def test_fbank(self, kwargs):
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json')))
@unittest.skipIf(_not_available('compute-spectrogram-feats'), '`compute-spectrogram-feats` not available')
@common_utils.skipIfNoExec('compute-spectrogram-feats')
def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand All @@ -100,7 +94,7 @@ def test_spectrogram(self, kwargs):
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)

@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json')))
@unittest.skipIf(_not_available('compute-mfcc-feats'), '`compute-mfcc-feats` not available')
@common_utils.skipIfNoExec('compute-mfcc-feats')
def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav')
Expand Down
Empty file added test/sox_io_backend/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions test/sox_io_backend/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def get_test_name(func, _, params):
    """Name a parameterized test as `<function name>_<arg1>_<arg2>_...`.

    The second positional argument is ignored. The underscore after the
    function name is always present, even when `params.args` is empty.
    """
    suffix = '_'.join(map(str, params.args))
    return func.__name__ + '_' + suffix
54 changes: 54 additions & 0 deletions test/sox_io_backend/sox_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import subprocess


def get_encoding(dtype):
    """Map a dtype name to the encoding name used on the `sox` command line.

    Raises `KeyError` for a dtype name that is not in the table.
    """
    signed = 'signed-integer'
    table = dict(
        float32='floating-point',
        int32=signed,
        int16=signed,
        uint8='unsigned-integer',
    )
    return table[dtype]


def get_bit_depth(dtype):
    """Map a dtype name to its sample width in bits.

    Raises `KeyError` for a dtype name that is not in the table.
    """
    names = ('float32', 'int32', 'int16', 'uint8')
    widths = (32, 32, 16, 8)
    return dict(zip(names, widths))[dtype]


def gen_audio_file(
        path, sample_rate, num_channels,
        *, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
):
    """Synthesize a sawtooth audio file at `path` by invoking the `sox` command.

    Optional settings left as None are omitted from the command line.
    The command is printed before it runs, and `soxi` is run on the result
    afterwards. Either command failing raises `subprocess.CalledProcessError`.
    """
    # Fixed prefix: verbose output, sample rate, null input, channel count.
    command = ['sox', '-V', '--rate', str(sample_rate), '--null', '--channels', str(num_channels)]

    # Optional output-format flags, in the order they are placed on the command line.
    optional_flags = (
        ('--compression', compression),
        ('--bits', bit_depth),
        ('--encoding', encoding),
    )
    for flag, value in optional_flags:
        if value is not None:
            command.extend([flag, str(value)])

    # Output file followed by the synth effect: `duration` seconds of sawtooth.
    # A sawtooth covers both ends of the value range, which is a good property
    # for tests (similar to linspace(-1., 1.)); it introduces a bigger boundary
    # effect than a sine wave when converted to mp3.
    command.extend([str(path), 'synth', str(duration), 'sawtooth', '1'])

    if attenuation is not None:
        command.extend(['vol', f'-{attenuation}dB'])

    print(' '.join(command))
    subprocess.run(command, check=True)
    subprocess.run(['soxi', path], check=True)
114 changes: 114 additions & 0 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import itertools
from parameterized import parameterized

from torchaudio.backend import sox_io_backend

from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name
)
from . import sox_utils


@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
    """Check `sox_io_backend.info` against audio files synthesized with `sox`."""

    def _generate_and_inspect(self, file_name, sample_rate, num_channels, **gen_kwargs):
        """Synthesize `file_name` in the temp dir and return what `info` reports for it."""
        audio_path = self.get_temp_path(file_name)
        sox_utils.gen_audio_file(audio_path, sample_rate, num_channels, **gen_kwargs)
        return sox_io_backend.info(audio_path)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` reports correct metadata for mono/stereo wav files"""
        seconds = 1
        signal_info = self._generate_and_inspect(
            f'{dtype}_{sample_rate}_{num_channels}.wav', sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
            duration=seconds,
        )
        assert signal_info.get_sample_rate() == sample_rate
        assert signal_info.get_num_samples() == sample_rate * seconds
        assert signal_info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [4, 8, 16, 32],
    )), name_func=get_test_name)
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` reports correct metadata for wav files with more than 2 channels"""
        seconds = 1
        signal_info = self._generate_and_inspect(
            f'{dtype}_{sample_rate}_{num_channels}.wav', sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
            duration=seconds,
        )
        assert signal_info.get_sample_rate() == sample_rate
        assert signal_info.get_num_samples() == sample_rate * seconds
        assert signal_info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=get_test_name)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.info` reports correct metadata for mp3 files"""
        seconds = 1
        signal_info = self._generate_and_inspect(
            f'{sample_rate}_{num_channels}_{bit_rate}k.mp3', sample_rate, num_channels,
            compression=bit_rate, duration=seconds,
        )
        assert signal_info.get_sample_rate() == sample_rate
        # The sample count is deliberately not checked:
        # mp3 does not preserve the number of samples.
        assert signal_info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=get_test_name)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.info` reports correct metadata for flac files"""
        seconds = 1
        signal_info = self._generate_and_inspect(
            f'{sample_rate}_{num_channels}_{compression_level}.flac', sample_rate, num_channels,
            compression=compression_level, duration=seconds,
        )
        assert signal_info.get_sample_rate() == sample_rate
        assert signal_info.get_num_samples() == sample_rate * seconds
        assert signal_info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=get_test_name)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.info` reports correct metadata for vorbis files"""
        seconds = 1
        signal_info = self._generate_and_inspect(
            f'{sample_rate}_{num_channels}_{quality_level}.vorbis', sample_rate, num_channels,
            compression=quality_level, duration=seconds,
        )
        assert signal_info.get_sample_rate() == sample_rate
        assert signal_info.get_num_samples() == sample_rate * seconds
        assert signal_info.get_num_channels() == num_channels
48 changes: 48 additions & 0 deletions test/sox_io_backend/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import itertools

import torch
from torchaudio.backend import sox_io_backend
from parameterized import parameterized

from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
)
from . import sox_utils


def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
    """Forward `filepath` to `sox_io_backend.info` and return its result.

    This module-level function is what the test in this file passes to
    `torch.jit.script`.
    """
    return sox_io_backend.info(filepath)


@skipIfNoExec('sox')
@skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
    """Compare the eager and TorchScript-compiled `info` function of the "sox_io" backend."""

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_info_wav(self, dtype, sample_rate, num_channels):
        # Synthesize the wav file that both variants will inspect.
        wav_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        sox_utils.gen_audio_file(
            wav_path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
        )

        # Script the function, save it to disk and load it back.
        saved_script = self.get_temp_path('info_func')
        scripted_func = torch.jit.script(py_info_func)
        scripted_func.save(saved_script)
        reloaded_func = torch.jit.load(saved_script)

        eager_info = py_info_func(wav_path)
        scripted_info = reloaded_func(wav_path)

        assert eager_info.get_sample_rate() == scripted_info.get_sample_rate()
        assert eager_info.get_num_samples() == scripted_info.get_num_samples()
        assert eager_info.get_num_channels() == scripted_info.get_num_channels()
7 changes: 2 additions & 5 deletions test/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest

import torchaudio
from torchaudio._internal.module_utils import is_module_available

from . import common_utils

Expand All @@ -28,15 +27,13 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
backend_module = torchaudio.backend.no_backend


@unittest.skipIf(
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
# Skipped when the torchaudio C++ extension is not available.
@common_utils.skipIfNoExtension
class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase):
    # Backend name and the module expected for it; presumably consumed by
    # BackendSwitchMixin (its body is not shown here) - TODO confirm.
    backend = 'sox'
    backend_module = torchaudio.backend.sox_backend


@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
# Skipped when the "soundfile" module is not available.
@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
    # Backend name and the module expected for it; presumably consumed by
    # BackendSwitchMixin (its body is not shown here) - TODO confirm.
    backend = 'soundfile'
    backend_module = torchaudio.backend.soundfile_backend
10 changes: 10 additions & 0 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from torchaudio._internal import (
module_utils as _mod_utils,
)


@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
    """Get signal information of an audio file.

    Args:
        filepath (str): Path to the audio file.

    Returns:
        torch.classes.torchaudio.SignalInfo: The object returned by the
        ``torchaudio.sox_io_get_info`` op. Callers read it through
        ``get_sample_rate()``, ``get_num_samples()`` and ``get_num_channels()``.
    """
    return torch.ops.torchaudio.sox_io_get_info(filepath)

0 comments on commit 88fccd1

Please sign in to comment.