Skip to content

Commit

Permalink
Add load function
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 18, 2020
1 parent f473bc4 commit 6da4ff4
Show file tree
Hide file tree
Showing 8 changed files with 584 additions and 0 deletions.
80 changes: 80 additions & 0 deletions test/sox_io/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,83 @@

def get_test_name(func, _, params):
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}'


def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype == torch.float32:
pass
elif tensor.dtype == torch.int32:
tensor = tensor.to(torch.float32)
tensor[tensor > 0] /= 2147483647.
tensor[tensor < 0] /= 2147483648.
elif tensor.dtype == torch.int16:
tensor = tensor.to(torch.float32)
tensor[tensor > 0] /= 32767.
tensor[tensor < 0] /= 32768.
elif tensor.dtype == torch.uint8:
tensor = tensor.to(torch.float32) - 128
tensor[tensor > 0] /= 127.
tensor[tensor < 0] /= 128.
return tensor


def get_wav_data(
dtype: str,
num_channels: int,
*,
num_samples: Optional[int] = None,
normalize: bool = False,
):
"""Generate linear signal of the given dtype and num_channels
Data range is
[-1.0, 1.0] for float32,
[-2147483647, 2147483647] for int32
[-32767, 32767] for int16
[0, 255] for uint8
num_samples allow to change the linear interpolation parameter.
Default values are 256 for uint8, else 1 << 16.
1 << 16 as default is so that int16 value range is completely covered.
"""
dtype_ = getattr(torch, dtype)

if num_samples is None:
if dtype == 'uint8':
num_samples = 256
else:
num_samples = 1 << 16

if dtype == 'uint8':
base = torch.linspace(0, 255, num_samples, dtype=dtype_)
if dtype == 'float32':
base = torch.linspace(-1., 1., num_samples, dtype=dtype_)
if dtype == 'int32':
# torch.linspace is broken when dtype=torch.int32
# https://github.com/pytorch/pytorch/issues/40118
base = torch.linspace(-2147483648, 2147483647, num_samples, dtype=torch.float32)
base = base.to(torch.int32)
base[0] = -2147483648
base[-1] = 2147483647
if dtype == 'int16':
base = torch.linspace(-32768, 32767, num_samples, dtype=dtype_)
data = base.repeat([num_channels, 1]).transpose(1, 0)
if normalize:
data = normalize_wav(data)
return data


def load_wav(path: str, normalize=False) -> torch.Tensor:
"""Load wav file without torchaudio"""
sample_rate, data = scipy.io.wavfile.read(path)
data = torch.from_numpy(data.copy())
if data.ndim == 1:
data = data.unsqueeze(1)
if normalize:
data = normalize_wav(data)
return data, sample_rate


def save_wav(path, data, sample_rate):
"""Save wav file without torchaudio"""
scipy.io.wavfile.write(path, sample_rate, data.numpy())
15 changes: 15 additions & 0 deletions test/sox_io/sox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,18 @@ def gen_audio_file(
print(' '.join(command))
subprocess.run(command, check=True)
subprocess.run(['soxi', path], check=True)


def convert_audio_file(
src_path, dst_path,
*, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ['sox', str(src_path)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if compression is not None:
command += ['--compression', str(compression)]
command += [dst_path]
print(' '.join(command))
subprocess.run(command, check=True)
subprocess.run(['soxi', dst_path], check=True)
229 changes: 229 additions & 0 deletions test/sox_io/test_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import itertools

import torchaudio
from parameterized import parameterized

from .. import common_utils
from ..common_utils import (
TempDirMixin,
PytorchTestCase,
)
from .common import (
get_test_name,
get_wav_data,
load_wav,
save_wav,
)
from . import sox_utils


class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
expected = load_wav(path, normalize=normalize)[0]
data, sr = torchaudio.backend.sox_io_backend.load(path, normalize=normalize)
assert sr == sample_rate
self.assertEqual(data, expected)

def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
"""`sox_io_backend.load` can load mp3 format.
mp3 encoding introduces delay and boundary effects so
we create reference wav file from mp3
x
|
| Generate mp3 with Sox
|
v
mp3 --- Convert to wav with Sox --> wav
| |
| load with torchaduio | load with scipy
| |
v v
tensor --------> compare <--------- tensor
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3')
ref_path = f'{path}.wav'

sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration)
sox_utils.convert_audio_file(path, ref_path)

data, sr = torchaudio.backend.sox_io_backend.load(path)
data_ref = load_wav(ref_path, normalize=True)[0]

assert sr == sample_rate
self.assertEqual(data, data_ref, atol=3e-03, rtol=1.3e-06)

def assert_flac(self, sample_rate, num_channels, compression_level, duration):
"""`sox_io_backend.load` can load flac format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac')
ref_path = f'{path}.wav'

sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, bit_depth=16, duration=duration)
sox_utils.convert_audio_file(path, ref_path)

data, sr = torchaudio.backend.sox_io_backend.load(path)
data_ref = load_wav(ref_path, normalize=True)[0]

assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
"""`sox_io_backend.load` can load vorbis format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis')
ref_path = f'{path}.wav'

sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, bit_depth=16, duration=duration)
sox_utils.convert_audio_file(path, ref_path)

data, sr = torchaudio.backend.sox_io_backend.load(path)
data_ref = load_wav(ref_path, normalize=True)[0]

assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)


@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class TestLoad(LoadTestBase):
"""Test the correctness of `sox_io_backend.load` for various formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=get_test_name)
def testload_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

@parameterized.expand(list(itertools.product(
['int16'],
[16000],
[2],
[False],
)), name_func=get_test_name)
def testload_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours)

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[4, 8, 16, 32],
)), name_func=get_test_name)
def test_load_multiple_channels(self, dtype, num_channels):
"""`sox_io_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

@parameterized.expand(list(itertools.product(
[8000, 16000, 44100],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)), name_func=get_test_name)
def test_load_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly."""
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)

@parameterized.expand(list(itertools.product(
[16000],
[2],
[128],
)), name_func=get_test_name)
def test_load_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60
self.assert_mp3(sample_rate, num_channels, bit_rate, two_hours)

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=get_test_name)
def test_load_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load flac format correctly."""
self.assert_flac(sample_rate, num_channels, compression_level, duration=1)

@parameterized.expand(list(itertools.product(
[16000],
[2],
[0],
)), name_func=get_test_name)
def test_load_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load large flac file correctly."""
two_hours = 2 * 60 * 60
self.assert_flac(sample_rate, num_channels, compression_level, two_hours)

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=get_test_name)
def test_load_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load vorbis format correctly."""
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)

@parameterized.expand(list(itertools.product(
[16000],
[2],
[10],
)), name_func=get_test_name)
def test_load_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load large vorbis file correctly."""
two_hours = 2 * 60 * 60
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)


@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""
original = None
path = None

def setUp(self):
super().setUp()
sample_rate = 8000
self.original = get_wav_data('float32', num_channels=2)
self.path = self.get_temp_path('test.wave')
save_wav(self.path, self.original, sample_rate)

@parameterized.expand(list(itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)), name_func=get_test_name)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
found, _ = torchaudio.backend.sox_io_backend.load(self.path, frame_offset, num_frames)
frame_end = None if num_frames == -1 else frame_offset + num_frames
self.assertEqual(found, self.original[frame_offset:frame_end])

@parameterized.expand([(True, ), (False, )], name_func=get_test_name)
def test_load_channel_first(self, channel_first):
"""channel_first swaps axes"""
found, _ = torchaudio.backend.sox_io_backend.load(self.path, channel_first=channel_first)
expected = self.original.transpose(1, 0) if channel_first else self.original
self.assertEqual(found, expected)
26 changes: 26 additions & 0 deletions test/sox_io/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath)


def py_load_func(filepath: str):
return sox_io_backend.load(filepath)


@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
Expand All @@ -45,3 +49,25 @@ def test_info_wav(self, dtype, sample_rate, num_channels):
assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_samples() == ts_info.get_num_samples()
assert py_info.get_num_channels() == ts_info.get_num_channels()

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=get_test_name)
def test_load_wav(self, dtype, sample_rate, num_channels, normalize):
audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype))

script_path = self.get_temp_path('load_func')
torch.jit.script(py_load_func).save(script_path)
ts_load_func = torch.jit.load(script_path)

py_data, py_sr = py_load_func(audio_path)
ts_data, ts_sr = ts_load_func(audio_path)

self.assertEqual(py_sr, ts_sr)
self.assertEqual(py_data, ts_data)

0 comments on commit 6da4ff4

Please sign in to comment.