Skip to content

Commit

Permalink
Add load function
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 25, 2020
1 parent 0f0d0af commit 036bf8a
Show file tree
Hide file tree
Showing 11 changed files with 772 additions and 55 deletions.
88 changes: 88 additions & 0 deletions test/sox_io_backend/common.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,90 @@
from typing import Optional

import torch
import scipy.io.wavfile


def get_test_name(func, _, params):
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}'


def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype == torch.float32:
pass
elif tensor.dtype == torch.int32:
tensor = tensor.to(torch.float32)
tensor[tensor > 0] /= 2147483647.
tensor[tensor < 0] /= 2147483648.
elif tensor.dtype == torch.int16:
tensor = tensor.to(torch.float32)
tensor[tensor > 0] /= 32767.
tensor[tensor < 0] /= 32768.
elif tensor.dtype == torch.uint8:
tensor = tensor.to(torch.float32) - 128
tensor[tensor > 0] /= 127.
tensor[tensor < 0] /= 128.
return tensor


def get_wav_data(
dtype: str,
num_channels: int,
*,
num_frames: Optional[int] = None,
normalize: bool = True,
channels_first: bool = True,
):
"""Generate linear signal of the given dtype and num_channels
Data range is
[-1.0, 1.0] for float32,
[-2147483648, 2147483647] for int32
[-32768, 32767] for int16
[0, 255] for uint8
num_frames allow to change the linear interpolation parameter.
Default values are 256 for uint8, else 1 << 16.
1 << 16 as default is so that int16 value range is completely covered.
"""
dtype_ = getattr(torch, dtype)

if num_frames is None:
if dtype == 'uint8':
num_frames = 256
else:
num_frames = 1 << 16

if dtype == 'uint8':
base = torch.linspace(0, 255, num_frames, dtype=dtype_)
if dtype == 'float32':
base = torch.linspace(-1., 1., num_frames, dtype=dtype_)
if dtype == 'int32':
base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
if dtype == 'int16':
base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_)
data = base.repeat([num_channels, 1])
if not channels_first:
data = data.transpose(1, 0)
if normalize:
data = normalize_wav(data)
return data


def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor:
"""Load wav file without torchaudio"""
sample_rate, data = scipy.io.wavfile.read(path)
data = torch.from_numpy(data.copy())
if data.ndim == 1:
data = data.unsqueeze(1)
if normalize:
data = normalize_wav(data)
if channels_first:
data = data.transpose(1, 0)
return data, sample_rate


def save_wav(path, data, sample_rate, channels_first=True):
"""Save wav file without torchaudio"""
if channels_first:
data = data.transpose(1, 0)
scipy.io.wavfile.write(path, sample_rate, data.numpy())
18 changes: 17 additions & 1 deletion test/sox_io_backend/sox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def gen_audio_file(
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
):
"""Generate synthetic audio file with `sox` command."""
if path.endswith('.wav'):
raise RuntimeError(
'Use get_wav_data and save_wav to generate wav file for accurate result.')
command = [
'sox',
'-V', # verbose
Expand All @@ -51,4 +54,17 @@ def gen_audio_file(
command += ['vol', f'-{attenuation}dB']
print(' '.join(command))
subprocess.run(command, check=True)
subprocess.run(['soxi', path], check=True)


def convert_audio_file(
src_path, dst_path,
*, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ['sox', '-V', str(src_path)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if compression is not None:
command += ['--compression', str(compression)]
command += [dst_path]
print(' '.join(command))
subprocess.run(command, check=True)
20 changes: 7 additions & 13 deletions test/sox_io_backend/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
skipIfNoExtension,
)
from .common import (
get_test_name
get_test_name,
get_wav_data,
save_wav,
)
from . import sox_utils

Expand All @@ -27,12 +29,8 @@ def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
duration=duration,
)
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
Expand All @@ -47,12 +45,8 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
duration=duration,
)
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
Expand Down

0 comments on commit 036bf8a

Please sign in to comment.