Skip to content

Commit

Permalink
Add info for sox_io backend
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 18, 2020
1 parent c67250d commit 8bdfdb7
Show file tree
Hide file tree
Showing 13 changed files with 406 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c pytorch-nightly pytorch "${cudatoolkit}"

printf "* Installing torchaudio\n"
BUILD_SOX=1 python setup.py develop
python setup.py develop
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

python -m torch.utils.collect_env
export PATH="${PATH}:third_party/build/bin/"
pytest --cov=torchaudio --junitxml=test-results/junit.xml -v --durations 20 test
5 changes: 4 additions & 1 deletion .circleci/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ printf "* Installing dependencies (except PyTorch)\n"
conda env update --file "${this_dir}/environment.yml" --prune

# 4. Build codecs
build_tools/setup_helpers/build_third_party.sh
# build_tools/setup_helpers/build_third_party.sh
# 4. Install codecs
apt update -q
apt install -y -q sox libsox-dev libsox-fmt-all
27 changes: 27 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,33 @@ def set_audio_backend(backend):
torchaudio.set_audio_backend(be)


class TempDirMixin:
    """Mixin to provide easy access to a per-test temporary directory.

    A fresh directory is created in ``setUp`` and removed in ``tearDown``.
    Both hooks chain through ``super()``, so the mixin is meant to be listed
    before the ``TestCase`` base class in the class definition.
    """
    # ``tempfile.TemporaryDirectory`` handle; ``None`` while no directory is active.
    temp_dir_ = None
    # Path of the active temporary directory; ``None`` while no directory is active.
    temp_dir = None

    def setUp(self):
        """Run the next ``setUp`` in the MRO, then create the temporary directory."""
        super().setUp()
        self._init_temp_dir()

    def tearDown(self):
        """Run the next ``tearDown`` in the MRO, then remove the temporary directory."""
        # This is the per-test hook, so it must chain to ``tearDown`` and not to
        # the class-level ``tearDownClass``. The ``finally`` guarantees the
        # directory is removed even when the chained ``tearDown`` raises.
        try:
            super().tearDown()
        finally:
            self._clean_up_temp_dir()

    def _init_temp_dir(self):
        """Create a new temporary directory and remember its handle and path."""
        self.temp_dir_ = tempfile.TemporaryDirectory()
        self.temp_dir = self.temp_dir_.name

    def _clean_up_temp_dir(self):
        """Delete the temporary directory, if any, and reset the bookkeeping."""
        if self.temp_dir_ is not None:
            self.temp_dir_.cleanup()
            self.temp_dir_ = None
            self.temp_dir = None

    def get_temp_path(self, *paths):
        """Join ``paths`` onto the temporary directory and return the result.

        Only valid between ``setUp`` and ``tearDown``; outside that window
        ``temp_dir`` is ``None`` and ``os.path.join`` raises ``TypeError``.
        """
        return os.path.join(self.temp_dir, *paths)


class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
Expand Down
Empty file added test/sox_io/__init__.py
Empty file.
88 changes: 88 additions & 0 deletions test/sox_io/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Optional, Tuple

import scipy.io.wavfile
import torch


def get_test_name(func, _, params):
    """Build a test name from the function name and its parameter values.

    Used as ``name_func`` for ``parameterized.expand``. The second positional
    argument is ignored. ``params.args`` are stringified and joined with
    underscores after the function name.
    """
    suffix = '_'.join(map(str, params.args))
    return f'{func.__name__}_{suffix}'


def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
    """Convert integer PCM samples to float32 in the range [-1.0, 1.0].

    int32, int16 and uint8 tensors are converted to float32 (uint8 is first
    re-centred by subtracting 128). Positive and negative samples are divided
    by different full-scale values so that both extremes of the integer range
    map exactly onto +1.0 and -1.0. Tensors of any other dtype, including
    float32, are returned unchanged (the very same object).
    """
    # dtype -> (offset to subtract, positive full scale, negative full scale)
    full_scales = {
        torch.int32: (0, 2147483647., 2147483648.),
        torch.int16: (0, 32767., 32768.),
        torch.uint8: (128, 127., 128.),
    }
    spec = full_scales.get(tensor.dtype)
    if spec is None:
        return tensor

    offset, positive_scale, negative_scale = spec
    samples = tensor.to(torch.float32)
    if offset:
        samples = samples - offset
    # Zero samples stay zero whichever divisor is applied to them.
    return torch.where(samples > 0, samples / positive_scale, samples / negative_scale)


def get_wav_data(
        dtype: str,
        num_channels: int,
        *,
        num_samples: Optional[int] = None,
        normalize: bool = False,
):
    """Generate linear signal of the given dtype and num_channels

    Data range is
        [-1.0, 1.0] for float32,
        [-2147483648, 2147483647] for int32
        [-32768, 32767] for int16
        [0, 255] for uint8

    Args:
        dtype: One of ``'float32'``, ``'int32'``, ``'int16'`` or ``'uint8'``.
        num_channels: Number of identical channels in the result.
        num_samples: Number of points of the linear ramp. Defaults to 256 for
            uint8 and ``1 << 16`` otherwise, so that the whole int16 value
            range is covered.
        normalize: When true, the result is passed through ``normalize_wav``.

    Returns:
        Tensor of shape ``(num_samples, num_channels)``.

    Raises:
        AttributeError: If ``dtype`` does not name an attribute of ``torch``.
        ValueError: If ``dtype`` is not one of the four supported names.
    """
    dtype_ = getattr(torch, dtype)

    if num_samples is None:
        num_samples = 256 if dtype == 'uint8' else 1 << 16

    if dtype == 'uint8':
        base = torch.linspace(0, 255, num_samples, dtype=dtype_)
    elif dtype == 'float32':
        base = torch.linspace(-1., 1., num_samples, dtype=dtype_)
    elif dtype == 'int32':
        # torch.linspace is broken when dtype=torch.int32
        # https://github.com/pytorch/pytorch/issues/40118
        # Build the ramp in float32, cast, then pin both end points because
        # float32 cannot represent the int32 extremes exactly.
        base = torch.linspace(-2147483648, 2147483647, num_samples, dtype=torch.float32)
        base = base.to(torch.int32)
        base[0] = -2147483648
        base[-1] = 2147483647
    elif dtype == 'int16':
        base = torch.linspace(-32768, 32767, num_samples, dtype=dtype_)
    else:
        # Without this branch ``base`` would be unbound and the failure would
        # surface as an unrelated-looking UnboundLocalError below.
        raise ValueError(f'Unsupported dtype: {dtype}')

    # (num_samples,) -> (num_channels, num_samples) -> (num_samples, num_channels)
    data = base.repeat([num_channels, 1]).transpose(1, 0)
    if normalize:
        data = normalize_wav(data)
    return data


def load_wav(path: str, normalize=False) -> Tuple[torch.Tensor, int]:
    """Load wav file without torchaudio

    Args:
        path: Path of the wav file, read with ``scipy.io.wavfile.read``.
        normalize: When true, the samples are passed through ``normalize_wav``.

    Returns:
        ``(data, sample_rate)`` where ``data`` is a 2D tensor. One-dimensional
        data returned by scipy gets a trailing dimension of size 1.
    """
    sample_rate, data = scipy.io.wavfile.read(path)
    # NOTE(review): the copy presumably guarantees a writable, owned buffer
    # before handing the array to torch — confirm against scipy's behaviour.
    data = torch.from_numpy(data.copy())
    if data.ndim == 1:
        data = data.unsqueeze(1)
    if normalize:
        data = normalize_wav(data)
    return data, sample_rate


def save_wav(path, data, sample_rate):
    """Write a tensor to a wav file using scipy instead of torchaudio.

    ``data`` is converted with ``Tensor.numpy()`` and handed, together with
    ``sample_rate``, to ``scipy.io.wavfile.write``.
    """
    samples = data.numpy()
    scipy.io.wavfile.write(path, sample_rate, samples)
69 changes: 69 additions & 0 deletions test/sox_io/sox_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import subprocess


def get_encoding(dtype):
    """Return the value for sox's ``--encoding`` option for a dtype name.

    Raises ``KeyError`` for names other than ``'float32'``, ``'int32'``,
    ``'int16'`` and ``'uint8'``.
    """
    signed = 'signed-integer'
    by_dtype = dict(
        float32='floating-point',
        int32=signed,
        int16=signed,
        uint8='unsigned-integer',
    )
    return by_dtype[dtype]


def get_bit_depth(dtype):
    """Return the number of bits per sample for a dtype name.

    Raises ``KeyError`` for names other than ``'float32'``, ``'int32'``,
    ``'int16'`` and ``'uint8'``.
    """
    by_dtype = dict(float32=32, int32=32, int16=16, uint8=8)
    return by_dtype[dtype]


def gen_audio_file(
        path, sample_rate, num_channels,
        *, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
):
    """Synthesize a sawtooth audio file by invoking the `sox` command.

    Args:
        path: Output file path.
        sample_rate: Passed to ``--rate``.
        num_channels: Passed to ``--channels``.
        encoding: Passed to ``--encoding`` when given.
        bit_depth: Passed to ``--bits`` when given.
        compression: Passed to ``--compression`` when given.
        attenuation: When given, a ``vol -<attenuation>dB`` effect is appended.
        duration: Length handed to ``synth``, in seconds.

    The assembled command is printed before it runs, and ``soxi`` is run on
    the result afterwards. Both subprocesses use ``check=True``, so a
    non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    command = [
        'sox',
        '-V',  # verbose
        '--rate', str(sample_rate),
        '--null',  # no input
        '--channels', str(num_channels),
    ]
    # Optional output-format flags, emitted in this fixed order.
    optional_flags = (
        ('--compression', compression),
        ('--bits', bit_depth),
        ('--encoding', encoding),
    )
    for flag, value in optional_flags:
        if value is not None:
            command.extend([flag, str(value)])
    # A sawtooth reaches both ends of the value range (much like
    # linspace(-1., 1.)), which is a useful property for tests. The price is
    # a bigger boundary effect than a sine wave when converted to mp3.
    command.extend([
        str(path),
        'synth', str(duration),  # synthesizes for the given duration [sec]
        'sawtooth', '1',
    ])
    if attenuation is not None:
        command.extend(['vol', f'-{attenuation}dB'])
    print(' '.join(command))
    subprocess.run(command, check=True)
    subprocess.run(['soxi', path], check=True)


def convert_audio_file(
        src_path, dst_path,
        *, bit_depth=None, compression=None):
    """Convert audio file with `sox` command.

    Args:
        src_path: Input file path (string or path-like).
        dst_path: Output file path (string or path-like).
        bit_depth: Passed to ``--bits`` when given.
        compression: Passed to ``--compression`` when given.

    The assembled command is printed before it runs, and ``soxi`` is run on
    the destination afterwards.

    Raises:
        subprocess.CalledProcessError: If ``sox`` or ``soxi`` exits with a
            non-zero status (both are run with ``check=True``).
    """
    # Stringify the destination just like the source: ``' '.join`` below only
    # accepts strings, so a path-like destination used to raise TypeError.
    dst_path = str(dst_path)
    command = ['sox', str(src_path)]
    if bit_depth is not None:
        command += ['--bits', str(bit_depth)]
    if compression is not None:
        command += ['--compression', str(compression)]
    command += [dst_path]
    print(' '.join(command))
    subprocess.run(command, check=True)
    subprocess.run(['soxi', dst_path], check=True)
89 changes: 89 additions & 0 deletions test/sox_io/test_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import itertools
from parameterized import parameterized

import torchaudio

from ..common_utils import (
TempDirMixin,
PytorchTestCase,
)
from .common import (
get_test_name
)
from . import sox_utils


class TestInfo(TempDirMixin, PytorchTestCase):
    """Check ``sox_io_backend.info`` against files synthesized with the `sox` command."""

    def _make_wav(self, dtype, sample_rate, num_channels):
        """Synthesize a wav file of the given format in the temp dir; return its path."""
        wav_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        sox_utils.gen_audio_file(
            wav_path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
        )
        return wav_path

    def _check_info(self, path, sample_rate, num_channels, *, check_num_samples=True):
        """Assert that ``info(path)`` reports the expected signal properties."""
        signal_info = torchaudio.backend.sox_io_backend.info(path)
        assert signal_info.get_sample_rate() == sample_rate
        if check_num_samples:
            # Files are generated with the default duration of one second,
            # so the expected number of samples equals the sample rate.
            assert signal_info.get_num_samples() == sample_rate
        assert signal_info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_info_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` reports correct properties for mono/stereo wav."""
        wav_path = self._make_wav(dtype, sample_rate, num_channels)
        self._check_info(wav_path, sample_rate, num_channels)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [4, 8, 16, 32],
    )), name_func=get_test_name)
    def test_info_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav files with more than 2 channels."""
        wav_path = self._make_wav(dtype, sample_rate, num_channels)
        self._check_info(wav_path, sample_rate, num_channels)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=get_test_name)
    def test_info_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.info` reports sample rate and channels for mp3."""
        mp3_path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
        sox_utils.gen_audio_file(mp3_path, sample_rate, num_channels, compression=bit_rate)
        # The sample-count check is deliberately left out for mp3.
        # NOTE(review): presumably mp3 encoding does not preserve the exact
        # number of samples — confirm before enabling it.
        self._check_info(mp3_path, sample_rate, num_channels, check_num_samples=False)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=get_test_name)
    def test_info_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.info` reports correct properties for flac."""
        flac_path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
        sox_utils.gen_audio_file(flac_path, sample_rate, num_channels, compression=compression_level)
        self._check_info(flac_path, sample_rate, num_channels)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=get_test_name)
    def test_info_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.info` reports correct properties for vorbis."""
        vorbis_path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
        sox_utils.gen_audio_file(vorbis_path, sample_rate, num_channels, compression=quality_level)
        self._check_info(vorbis_path, sample_rate, num_channels)
44 changes: 44 additions & 0 deletions test/sox_io/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import itertools

import torch
import torchaudio
from parameterized import parameterized

from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
)
from .common import (
get_test_name,
)
from . import sox_utils


def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
    """Return ``torchaudio.info(filepath)``; the tests in this module also compile it with ``torch.jit.script``."""
    signal_info = torchaudio.info(filepath)
    return signal_info


class SoxIO(TempDirMixin, TorchaudioTestCase):
    """Compare eager and TorchScript-compiled ``info`` calls."""
    # NOTE(review): presumably read by TorchaudioTestCase to select the audio
    # backend for these tests — confirm against common_utils.
    backend = 'sox_io'

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_info_wav(self, dtype, sample_rate, num_channels):
        """Scripted and eager ``py_info_func`` report identical wav metadata."""
        wav_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        bits = sox_utils.get_bit_depth(dtype)
        enc = sox_utils.get_encoding(dtype)
        sox_utils.gen_audio_file(
            wav_path, sample_rate, num_channels, bit_depth=bits, encoding=enc)

        scripted_info_func = torch.jit.script(py_info_func)

        eager_info = py_info_func(wav_path)
        scripted_info = scripted_info_func(wav_path)

        assert eager_info.get_sample_rate() == scripted_info.get_sample_rate()
        assert eager_info.get_num_samples() == scripted_info.get_num_samples()
        assert eager_info.get_num_channels() == scripted_info.get_num_channels()
10 changes: 10 additions & 0 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from torchaudio._internal import (
module_utils as _mod_utils,
)


@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
    """Get signal information of an audio file.

    Args:
        filepath (str): Path to the audio file.

    Returns:
        torch.classes.torchaudio.SignalInfo: The object produced by the
        ``torchaudio::sox_io_get_info`` operator.
    """
    signal_info = torch.ops.torchaudio.sox_io_get_info(filepath)
    return signal_info
7 changes: 7 additions & 0 deletions torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TORCHAUDIO_REGISTER_H
#define TORCHAUDIO_REGISTER_H

#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/typedefs.h>

namespace torchaudio {
Expand All @@ -13,6 +14,12 @@ static auto registerSignalInfo =
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples);

// Registers the `torchaudio::sox_io_get_info` operator. Per the schema below
// it takes a path string and returns a `SignalInfo` custom-class instance;
// `sox_io::get_info` is bound as its kernel through `catchAllKernel`.
static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info")
.catchAllKernel<decltype(sox_io::get_info), &sox_io::get_info>());

} // namespace
} // namespace torchaudio
#endif

0 comments on commit 8bdfdb7

Please sign in to comment.