-
Notifications
You must be signed in to change notification settings - Fork 633
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TorchScript-able "info" func to sox_io backend (#728)
This is part of a series of PRs adding the new "sox_io" backend (#726), and depends on #718. This PR adds an `info` function to the "sox_io" backend, which allows users to fetch some metadata of an audio file. At the moment, the following information is retrieved: - Number of samples in the audio file - Sampling rate - Number of channels
- Loading branch information
Showing
14 changed files
with
351 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def get_test_name(func, _, params):
    """Build a test name from the function name and its positional parameters.

    Intended to be passed as ``name_func`` to ``parameterized.expand``.

    Args:
        func: The test function being parameterized; only ``__name__`` is used.
        _: Ignored.
        params: Object exposing the positional parameters as ``params.args``.

    Returns:
        str: ``<func name>_<arg0>_<arg1>_...`` where each argument is
        converted with ``str``.
    """
    return f'{func.__name__}_{"_".join(str(p) for p in params.args)}'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import subprocess | ||
|
||
|
||
def get_encoding(dtype):
    """Return the value for the ``--encoding`` option of `sox` for ``dtype``.

    Args:
        dtype (str): One of ``'float32'``, ``'int32'``, ``'int16'``, ``'uint8'``.

    Returns:
        str: The encoding name associated with ``dtype``.

    Raises:
        KeyError: If ``dtype`` is not one of the supported names.
    """
    encodings = {
        'float32': 'floating-point',
        'int32': 'signed-integer',
        'int16': 'signed-integer',
        'uint8': 'unsigned-integer',
    }
    return encodings[dtype]
|
||
|
||
def get_bit_depth(dtype):
    """Return the number of bits per sample for ``dtype``.

    Args:
        dtype (str): One of ``'float32'``, ``'int32'``, ``'int16'``, ``'uint8'``.

    Returns:
        int: The bit depth associated with ``dtype``.

    Raises:
        KeyError: If ``dtype`` is not one of the supported names.
    """
    bit_depths = {
        'float32': 32,
        'int32': 32,
        'int16': 16,
        'uint8': 8,
    }
    return bit_depths[dtype]
|
||
|
||
def gen_audio_file(
        path, sample_rate, num_channels,
        *, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
):
    """Generate synthetic audio file with `sox` command.

    The assembled command line is printed before it is executed, and
    ``soxi`` is run on the resulting file afterwards.

    Args:
        path: Output file path; converted with ``str`` before use.
        sample_rate: Value passed to ``--rate``.
        num_channels: Value passed to ``--channels``.
        encoding: If not ``None``, value passed to ``--encoding``.
        bit_depth: If not ``None``, value passed to ``--bits``.
        compression: If not ``None``, value passed to ``--compression``.
        attenuation: If not ``None``, a ``vol -<attenuation>dB`` effect is
            appended to the command.
        duration: Length of the synthesized signal in seconds.

    Raises:
        subprocess.CalledProcessError: If ``sox`` or ``soxi`` exits with a
            non-zero status.
    """
    command = [
        'sox',
        '-V',  # verbose
        '--rate', str(sample_rate),
        '--null',  # no input
        '--channels', str(num_channels),
    ]
    if compression is not None:
        command += ['--compression', str(compression)]
    if bit_depth is not None:
        command += ['--bits', str(bit_depth)]
    if encoding is not None:
        command += ['--encoding', str(encoding)]
    command += [
        str(path),
        'synth', str(duration),  # synthesizes for the given duration [sec]
        'sawtooth', '1',
        # saw tooth covers the both ends of value range, which is a good property for test.
        # similar to linspace(-1., 1.)
        # this introduces bigger boundary effect than sine when converted to mp3
    ]
    if attenuation is not None:
        command += ['vol', f'-{attenuation}dB']
    print(' '.join(command))
    subprocess.run(command, check=True)
    # Stringify the path here as well, consistent with the `sox` invocation above.
    subprocess.run(['soxi', str(path)], check=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import itertools | ||
from parameterized import parameterized | ||
|
||
from torchaudio.backend import sox_io_backend | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name | ||
) | ||
from . import sox_utils | ||
|
||
|
||
@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
    """Check `sox_io_backend.info` against audio files generated with `sox`.

    Each test synthesizes a file in a temporary directory via
    ``sox_utils.gen_audio_file`` and compares the reported metadata with the
    parameters used for generation.
    """

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file correctly"""
        duration = 1
        path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
            duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.get_sample_rate() == sample_rate
        assert info.get_num_samples() == sample_rate * duration
        assert info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [4, 8, 16, 32],
    )), name_func=get_test_name)
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
        duration = 1
        path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
            duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.get_sample_rate() == sample_rate
        assert info.get_num_samples() == sample_rate * duration
        assert info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=get_test_name)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.info` can check mp3 file correctly"""
        duration = 1
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=bit_rate, duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.get_sample_rate() == sample_rate
        # mp3 does not preserve the number of samples
        # assert info.get_num_samples() == sample_rate * duration
        assert info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=get_test_name)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.info` can check flac file correctly"""
        duration = 1
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=compression_level, duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.get_sample_rate() == sample_rate
        assert info.get_num_samples() == sample_rate * duration
        assert info.get_num_channels() == num_channels

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=get_test_name)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.info` can check vorbis file correctly"""
        duration = 1
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=quality_level, duration=duration,
        )
        info = sox_io_backend.info(path)
        assert info.get_sample_rate() == sample_rate
        assert info.get_num_samples() == sample_rate * duration
        assert info.get_num_channels() == num_channels
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import itertools | ||
|
||
import torch | ||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from ..common_utils import ( | ||
TempDirMixin, | ||
TorchaudioTestCase, | ||
skipIfNoExec, | ||
skipIfNoExtension, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
    """Thin, type-annotated wrapper around `sox_io_backend.info`.

    Defined at module level so the tests can both call it directly and
    pass it to ``torch.jit.script``.
    """
    return sox_io_backend.info(filepath)
|
||
|
||
@skipIfNoExec('sox')
@skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
    """Check that "sox_io" backend functions give the same result when scripted."""

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_info_wav(self, dtype, sample_rate, num_channels):
        """Scripted `info` reports the same metadata as the Python function for wav files"""
        audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
        sox_utils.gen_audio_file(
            audio_path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype),
            encoding=sox_utils.get_encoding(dtype),
        )

        # Round-trip the scripted function through save/load before calling it.
        script_path = self.get_temp_path('info_func')
        torch.jit.script(py_info_func).save(script_path)
        ts_info_func = torch.jit.load(script_path)

        py_info = py_info_func(audio_path)
        ts_info = ts_info_func(audio_path)

        assert py_info.get_sample_rate() == ts_info.get_sample_rate()
        assert py_info.get_num_samples() == ts_info.get_num_samples()
        assert py_info.get_num_channels() == ts_info.get_num_channels()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import torch | ||
from torchaudio._internal import ( | ||
module_utils as _mod_utils, | ||
) | ||
|
||
|
||
@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
    """Get signal information of an audio file.

    Args:
        filepath (str): Path to the audio file.

    Returns:
        torch.classes.torchaudio.SignalInfo: The object returned by the
        ``torchaudio::sox_io_get_info`` op. The tests in this change read
        it through ``get_sample_rate()``, ``get_num_samples()`` and
        ``get_num_channels()``.
    """
    return torch.ops.torchaudio.sox_io_get_info(filepath)
Oops, something went wrong.