-
Notifications
You must be signed in to change notification settings - Fork 634
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
595 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
import itertools | ||
|
||
from torchaudio.backend import sox_io_backend | ||
from parameterized import parameterized | ||
|
||
from .. import common_utils | ||
from ..common_utils import ( | ||
TempDirMixin, | ||
PytorchTestCase, | ||
) | ||
from .common import ( | ||
get_test_name, | ||
get_wav_data, | ||
load_wav, | ||
save_wav, | ||
) | ||
from . import sox_utils | ||
|
||
|
||
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertion helpers that compare `sox_io_backend.load` to a reference.

    Each helper generates an audio file with Sox in a temporary location,
    loads it through `sox_io_backend.load`, and compares the returned sample
    rate and data against the expected values.
    """

    def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
        """`sox_io_backend.load` can load wav format correctly.

        Wav data loaded with sox_io backend should match those with scipy
        """
        path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
        expected = load_wav(path, normalize=normalize)[0]
        data, sr = sox_io_backend.load(path, normalize=normalize)
        # Use the TestCase assertion rather than a bare `assert`, which is
        # removed under `python -O` and reports less detail on failure.
        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, expected)

    def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
        """`sox_io_backend.load` can load mp3 format.

        mp3 encoding introduces delay and boundary effects so
        we create reference wav file from mp3

         x
         |
         |    Generate mp3 with Sox
         |
         v
        mp3 --- Convert to wav with Sox --> wav
         |                                   |
         | load with torchaudio              | load with scipy
         |                                   |
         v                                   v
        tensor --------> compare <--------- tensor
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3')
        ref_path = f'{path}.wav'

        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=bit_rate, duration=duration)
        # The reference is decoded from the very same mp3 file, so both sides
        # of the comparison contain the same encoder artifacts.
        sox_utils.convert_audio_file(path, ref_path)

        data, sr = sox_io_backend.load(path)
        data_ref = load_wav(ref_path)[0]

        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, data_ref, atol=3e-03, rtol=1.3e-06)

    def assert_flac(self, sample_rate, num_channels, compression_level, duration):
        """`sox_io_backend.load` can load flac format.

        This test takes the same strategy as mp3 to compare the result
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac')
        ref_path = f'{path}.wav'

        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=compression_level, bit_depth=16, duration=duration)
        sox_utils.convert_audio_file(path, ref_path)

        data, sr = sox_io_backend.load(path)
        data_ref = load_wav(ref_path)[0]

        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

    def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
        """`sox_io_backend.load` can load vorbis format.

        This test takes the same strategy as mp3 to compare the result
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis')
        ref_path = f'{path}.wav'

        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=quality_level, bit_depth=16, duration=duration)
        sox_utils.convert_audio_file(path, ref_path)

        data, sr = sox_io_backend.load(path)
        data_ref = load_wav(ref_path)[0]

        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
|
||
|
||
@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class TestLoad(LoadTestBase):
    """Check `sox_io_backend.load` results for wav, mp3, flac and vorbis files."""

    @parameterized.expand(
        list(itertools.product(
            ['float32', 'int32', 'int16', 'uint8'],
            [8000, 16000],
            [1, 2],
            [False, True],
        )),
        name_func=get_test_name,
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize):
        """Wav files of each dtype / rate / channel count load correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

    @parameterized.expand(
        list(itertools.product(
            ['int16'],
            [16000],
            [2],
            [False],
        )),
        name_func=get_test_name,
    )
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """A two-hour wav file loads correctly."""
        long_duration = 2 * 60 * 60  # seconds
        self.assert_wav(dtype, sample_rate, num_channels, normalize, long_duration)

    @parameterized.expand(
        list(itertools.product(
            ['float32', 'int32', 'int16', 'uint8'],
            [4, 8, 16, 32],
        )),
        name_func=get_test_name,
    )
    def test_multiple_channels(self, dtype, num_channels):
        """Wav files with more than 2 channels load correctly."""
        self.assert_wav(dtype, 8000, num_channels, False, duration=1)

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000, 44100],
            [1, 2],
            [96, 128, 160, 192, 224, 256, 320],
        )),
        name_func=get_test_name,
    )
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """Mp3 files across sample rates, channel counts and bit rates load correctly."""
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)

    @parameterized.expand(
        list(itertools.product(
            [16000],
            [2],
            [128],
        )),
        name_func=get_test_name,
    )
    def test_mp3_large(self, sample_rate, num_channels, bit_rate):
        """A two-hour mp3 file loads correctly."""
        long_duration = 2 * 60 * 60  # seconds
        self.assert_mp3(sample_rate, num_channels, bit_rate, long_duration)

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000],
            [1, 2],
            list(range(9)),
        )),
        name_func=get_test_name,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """Flac files at compression levels 0 through 8 load correctly."""
        self.assert_flac(sample_rate, num_channels, compression_level, duration=1)

    @parameterized.expand(
        list(itertools.product(
            [16000],
            [2],
            [0],
        )),
        name_func=get_test_name,
    )
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """A two-hour flac file loads correctly."""
        long_duration = 2 * 60 * 60  # seconds
        self.assert_flac(sample_rate, num_channels, compression_level, long_duration)

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000],
            [1, 2],
            [-1, 0, 1, 2, 3, 3.6, 5, 10],
        )),
        name_func=get_test_name,
    )
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """Vorbis files across the listed quality levels load correctly."""
        self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)

    @parameterized.expand(
        list(itertools.product(
            [16000],
            [2],
            [10],
        )),
        name_func=get_test_name,
    )
    def test_vorbis_large(self, sample_rate, num_channels, quality_level):
        """A two-hour vorbis file loads correctly."""
        long_duration = 2 * 60 * 60  # seconds
        self.assert_vorbis(sample_rate, num_channels, quality_level, long_duration)
|
||
|
||
@common_utils.skipIfNoExec('sox')
@common_utils.skipIfNoExtension
class TestLoadParams(TempDirMixin, PytorchTestCase):
    """Check the frame and layout parameters of `sox_io_backend.load`."""

    # Populated per-test in `setUp`: the reference data and the file holding it.
    original = None
    path = None

    def setUp(self):
        """Write a 2-channel float32 reference signal to a temporary file."""
        super().setUp()
        sample_rate = 8000
        self.original = get_wav_data('float32', num_channels=2)
        # NOTE(review): the '.wave' extension is unusual - confirm it is intended.
        self.path = self.get_temp_path('test.wave')
        save_wav(self.path, self.original, sample_rate)

    @parameterized.expand(
        list(itertools.product(
            [0, 1, 10, 100, 1000],
            [-1, 1, 10, 100, 1000],
        )),
        name_func=get_test_name,
    )
    def test_frame(self, frame_offset, num_frames):
        """`frame_offset` and `num_frames` select the expected slice of the data."""
        loaded, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
        # A `num_frames` of -1 is compared against everything from
        # `frame_offset` to the end of the reference signal.
        if num_frames == -1:
            stop = None
        else:
            stop = frame_offset + num_frames
        self.assertEqual(loaded, self.original[:, frame_offset:stop])

    @parameterized.expand([(True, ), (False, )], name_func=get_test_name)
    def test_channels_first(self, channels_first):
        """`channels_first=False` yields the transposed reference data."""
        loaded, _ = sox_io_backend.load(self.path, channels_first=channels_first)
        reference = self.original
        if not channels_first:
            reference = reference.transpose(1, 0)
        self.assertEqual(loaded, reference)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.