-
Notifications
You must be signed in to change notification settings - Fork 633
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
406 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
import scipy.io.wavfile | ||
|
||
|
||
def get_test_name(func, _, params): | ||
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}' | ||
|
||
|
||
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor: | ||
if tensor.dtype == torch.float32: | ||
pass | ||
elif tensor.dtype == torch.int32: | ||
tensor = tensor.to(torch.float32) | ||
tensor[tensor > 0] /= 2147483647. | ||
tensor[tensor < 0] /= 2147483648. | ||
elif tensor.dtype == torch.int16: | ||
tensor = tensor.to(torch.float32) | ||
tensor[tensor > 0] /= 32767. | ||
tensor[tensor < 0] /= 32768. | ||
elif tensor.dtype == torch.uint8: | ||
tensor = tensor.to(torch.float32) - 128 | ||
tensor[tensor > 0] /= 127. | ||
tensor[tensor < 0] /= 128. | ||
return tensor | ||
|
||
|
||
def get_wav_data( | ||
dtype: str, | ||
num_channels: int, | ||
*, | ||
num_samples: Optional[int] = None, | ||
normalize: bool = False, | ||
): | ||
"""Generate linear signal of the given dtype and num_channels | ||
Data range is | ||
[-1.0, 1.0] for float32, | ||
[-2147483647, 2147483647] for int32 | ||
[-32767, 32767] for int16 | ||
[0, 255] for uint8 | ||
num_samples allow to change the linear interpolation parameter. | ||
Default values are 256 for uint8, else 1 << 16. | ||
1 << 16 as default is so that int16 value range is completely covered. | ||
""" | ||
dtype_ = getattr(torch, dtype) | ||
|
||
if num_samples is None: | ||
if dtype == 'uint8': | ||
num_samples = 256 | ||
else: | ||
num_samples = 1 << 16 | ||
|
||
if dtype == 'uint8': | ||
base = torch.linspace(0, 255, num_samples, dtype=dtype_) | ||
if dtype == 'float32': | ||
base = torch.linspace(-1., 1., num_samples, dtype=dtype_) | ||
if dtype == 'int32': | ||
# torch.linspace is broken when dtype=torch.int32 | ||
# https://github.com/pytorch/pytorch/issues/40118 | ||
base = torch.linspace(-2147483648, 2147483647, num_samples, dtype=torch.float32) | ||
base = base.to(torch.int32) | ||
base[0] = -2147483648 | ||
base[-1] = 2147483647 | ||
if dtype == 'int16': | ||
base = torch.linspace(-32768, 32767, num_samples, dtype=dtype_) | ||
data = base.repeat([num_channels, 1]).transpose(1, 0) | ||
if normalize: | ||
data = normalize_wav(data) | ||
return data | ||
|
||
|
||
def load_wav(path: str, normalize=False) -> torch.Tensor: | ||
"""Load wav file without torchaudio""" | ||
sample_rate, data = scipy.io.wavfile.read(path) | ||
data = torch.from_numpy(data.copy()) | ||
if data.ndim == 1: | ||
data = data.unsqueeze(1) | ||
if normalize: | ||
data = normalize_wav(data) | ||
return data, sample_rate | ||
|
||
|
||
def save_wav(path, data, sample_rate): | ||
"""Save wav file without torchaudio""" | ||
scipy.io.wavfile.write(path, sample_rate, data.numpy()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import subprocess | ||
|
||
|
||
def get_encoding(dtype): | ||
encodings = { | ||
'float32': 'floating-point', | ||
'int32': 'signed-integer', | ||
'int16': 'signed-integer', | ||
'uint8': 'unsigned-integer', | ||
} | ||
return encodings[dtype] | ||
|
||
|
||
def get_bit_depth(dtype): | ||
bit_depths = { | ||
'float32': 32, | ||
'int32': 32, | ||
'int16': 16, | ||
'uint8': 8, | ||
} | ||
return bit_depths[dtype] | ||
|
||
|
||
def gen_audio_file( | ||
path, sample_rate, num_channels, | ||
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1, | ||
): | ||
"""Generate synthetic audio file with `sox` command.""" | ||
command = [ | ||
'sox', | ||
'-V', # verbose | ||
'--rate', str(sample_rate), | ||
'--null', # no input | ||
'--channels', str(num_channels), | ||
] | ||
if compression is not None: | ||
command += ['--compression', str(compression)] | ||
if bit_depth is not None: | ||
command += ['--bits', str(bit_depth)] | ||
if encoding is not None: | ||
command += ['--encoding', str(encoding)] | ||
command += [ | ||
str(path), | ||
'synth', str(duration), # synthesizes for the given duration [sec] | ||
'sawtooth', '1', | ||
# saw tooth covers the both ends of value range, which is a good property for test. | ||
# similar to linspace(-1., 1.) | ||
# this introduces bigger boundary effect than sine when converted to mp3 | ||
] | ||
if attenuation is not None: | ||
command += ['vol', f'-{attenuation}dB'] | ||
print(' '.join(command)) | ||
subprocess.run(command, check=True) | ||
subprocess.run(['soxi', path], check=True) | ||
|
||
|
||
def convert_audio_file( | ||
src_path, dst_path, | ||
*, bit_depth=None, compression=None): | ||
"""Convert audio file with `sox` command.""" | ||
command = ['sox', str(src_path)] | ||
if bit_depth is not None: | ||
command += ['--bits', str(bit_depth)] | ||
if compression is not None: | ||
command += ['--compression', str(compression)] | ||
command += [dst_path] | ||
print(' '.join(command)) | ||
subprocess.run(command, check=True) | ||
subprocess.run(['soxi', dst_path], check=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import itertools | ||
from parameterized import parameterized | ||
|
||
import torchaudio | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
) | ||
from .common import ( | ||
get_test_name | ||
) | ||
from . import sox_utils | ||
|
||
|
||
class TestInfo(TempDirMixin, PytorchTestCase): | ||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[1, 2], | ||
)), name_func=get_test_name) | ||
def test_info_wav(self, dtype, sample_rate, num_channels): | ||
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') | ||
sox_utils.gen_audio_file( | ||
path, sample_rate, num_channels, | ||
bit_depth=sox_utils.get_bit_depth(dtype), | ||
encoding=sox_utils.get_encoding(dtype), | ||
) | ||
info = torchaudio.backend.sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
assert info.get_num_samples() == sample_rate | ||
assert info.get_num_channels() == num_channels | ||
|
||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[4, 8, 16, 32], | ||
)), name_func=get_test_name) | ||
def test_info_wav_multiple_channels(self, dtype, sample_rate, num_channels): | ||
"""`sox_io_backend.save` can save wav with more than 2 channels.""" | ||
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') | ||
sox_utils.gen_audio_file( | ||
path, sample_rate, num_channels, | ||
bit_depth=sox_utils.get_bit_depth(dtype), | ||
encoding=sox_utils.get_encoding(dtype), | ||
) | ||
info = torchaudio.backend.sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
assert info.get_num_samples() == sample_rate | ||
assert info.get_num_channels() == num_channels | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
[96, 128, 160, 192, 224, 256, 320], | ||
)), name_func=get_test_name) | ||
def test_info_mp3(self, sample_rate, num_channels, bit_rate): | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3') | ||
sox_utils.gen_audio_file(path, sample_rate, num_channels, compression=bit_rate) | ||
info = torchaudio.backend.sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
# assert info.get_num_samples() == sample_rate | ||
assert info.get_num_channels() == num_channels | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
list(range(9)), | ||
)), name_func=get_test_name) | ||
def test_info_flac(self, sample_rate, num_channels, compression_level): | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac') | ||
sox_utils.gen_audio_file(path, sample_rate, num_channels, compression=compression_level) | ||
info = torchaudio.backend.sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
assert info.get_num_samples() == sample_rate | ||
assert info.get_num_channels() == num_channels | ||
|
||
@parameterized.expand(list(itertools.product( | ||
[8000, 16000], | ||
[1, 2], | ||
[-1, 0, 1, 2, 3, 3.6, 5, 10], | ||
)), name_func=get_test_name) | ||
def test_info_vorbis(self, sample_rate, num_channels, quality_level): | ||
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis') | ||
sox_utils.gen_audio_file(path, sample_rate, num_channels, compression=quality_level) | ||
info = torchaudio.backend.sox_io_backend.info(path) | ||
assert info.get_sample_rate() == sample_rate | ||
assert info.get_num_samples() == sample_rate | ||
assert info.get_num_channels() == num_channels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import itertools | ||
|
||
import torch | ||
import torchaudio | ||
from parameterized import parameterized | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
TorchaudioTestCase, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: | ||
return torchaudio.info(filepath) | ||
|
||
|
||
class SoxIO(TempDirMixin, TorchaudioTestCase): | ||
backend = 'sox_io' | ||
|
||
@parameterized.expand(list(itertools.product( | ||
['float32', 'int32', 'int16', 'uint8'], | ||
[8000, 16000], | ||
[1, 2], | ||
)), name_func=get_test_name) | ||
def test_info_wav(self, dtype, sample_rate, num_channels): | ||
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') | ||
sox_utils.gen_audio_file( | ||
path, sample_rate, num_channels, | ||
bit_depth=sox_utils.get_bit_depth(dtype), | ||
encoding=sox_utils.get_encoding(dtype), | ||
) | ||
|
||
ts_info_func = torch.jit.script(py_info_func) | ||
|
||
py_info = py_info_func(path) | ||
ts_info = ts_info_func(path) | ||
|
||
assert py_info.get_sample_rate() == ts_info.get_sample_rate() | ||
assert py_info.get_num_samples() == ts_info.get_num_samples() | ||
assert py_info.get_num_channels() == ts_info.get_num_channels() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
from torchaudio._internal import ( | ||
module_utils as _mod_utils, | ||
) | ||
|
||
|
||
@_mod_utils.requires_module('torchaudio._torchaudio') | ||
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo: | ||
"""Get signal information of an audio file.""" | ||
return torch.ops.torchaudio.sox_io_get_info(filepath) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.