Skip to content

Commit

Permalink
Refactor sox_io load_test (#1394)
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Mar 17, 2021
1 parent 6bad3a6 commit 80a8739
Showing 1 changed file with 58 additions and 194 deletions.
252 changes: 58 additions & 194 deletions test/torchaudio_unittest/backend/sox_io/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,217 +29,76 @@


class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path('reference.wav')
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
data, sr = sox_io_backend.load(path, normalize=normalize)
assert sr == sample_rate
self.assertEqual(data, expected)

def assert_24bit_wav(self, sample_rate, num_channels, normalize, duration):
""" `sox_io_backend.load` can load 24-bit signed PCM wav format. Since torch does not support the ``int24`` dtype,
we implicitly cast the resulting tensor to the ``int32`` dtype.
It is not possible to use #assert_wav method above, as #get_wav_data does not support
the 'int24' dtype. This is because torch does not support the ``int24`` dtype.
Hence, we must use the following workaround.
x
|
| 1. Generate 24-bit wav with Sox.
|
v 2. Convert 24-bit wav to 32-bit wav with Sox.
wav(24-bit) ----------------------> wav(32-bit)
| |
| 3. Load 24-bit wav with torchaudio| 4. Load 32-bit wav with scipy
| |
v v
tensor ----------> x <----------- tensor
5. Compare
# Underlying assumptions are:
# i. Sox properly converts from 24-bit to 32-bit
# ii. Loading 32-bit wav file with scipy is correct.
"""
path = self.get_temp_path('1.original.wav')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate 24-bit signed wav with Sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=24, duration=duration)

# 2. Convert from 24-bit wav to 32-bit wav with sox
sox_utils.convert_audio_file(path, ref_path, bit_depth=32)
# 3. Load 24-bit wav with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load 32-bit wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=3e-03, rtol=1.3e-06)

def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
"""`sox_io_backend.load` can load mp3 format.
mp3 encoding introduces delay and boundary effects so
we create reference wav file from mp3
def assert_format(
self,
format: str,
sample_rate: float,
num_channels: int,
compression: float = None,
bit_depth: int = None,
duration: float = 1,
normalize: bool = True,
encoding: str = None,
atol: float = 4e-05,
rtol: float = 1.3e-06,
):
"""`sox_io_backend.load` can load given format correctly.
file encodings introduce delay and boundary effects so
we create a reference wav file from the original file format
x
|
| 1. Generate mp3 with Sox
| 1. Generate given format with Sox
|
v 2. Convert to wav with Sox
mp3 ------------------------------> wav
given format ----------------------> wav
| |
| 3. Load with torchaudio | 4. Load with scipy
| 3. Load with torchaudio | 4. Load with scipy
| |
v v
tensor ----------> x <----------- tensor
5. Compare
Underlying assumptions are:
i. Conversion of mp3 to wav with Sox preserves data.
Underlying assumptions are;
i. Conversion of given format to wav with Sox preserves data.
ii. Loading wav file with scipy is correct.
By combining i & ii, step 2. and 4. allows to load reference mp3 data
without using torchaudio
By combining i & ii, step 2. and 4. allows to load reference given format
data without using torchaudio
"""
path = self.get_temp_path('1.original.mp3')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate mp3 with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load mp3 with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=3e-03, rtol=1.3e-06)

def assert_flac(self, sample_rate, num_channels, compression_level, duration):
"""`sox_io_backend.load` can load flac format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.flac')
path = self.get_temp_path(f'1.original.{format}')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate flac with sox
# 1. Generate the given format with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, bit_depth=16, duration=duration)
path, sample_rate, num_channels, encoding=encoding,
compression=compression, bit_depth=bit_depth, duration=duration,
)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load flac with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
"""`sox_io_backend.load` can load vorbis format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.vorbis')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate vorbis with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, bit_depth=16, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load vorbis with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_sphere(self, sample_rate, num_channels, duration):
"""`sox_io_backend.load` can load sph format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.sph')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate sph with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=32, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load sph with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)

def assert_amb(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load amb format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.amb')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate amb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amb with torchaudio
wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav
sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth)
# 3. Load the given format with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
self.assertEqual(data, data_ref, atol=atol, rtol=rtol)

def assert_amr_nb(self, duration):
"""`sox_io_backend.load` can load amr-nb format.
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
This test takes the same strategy as mp3 to compare the result
Wav data loaded with sox_io backend should match those with scipy
"""
sample_rate = 8000
num_channels = 1
path = self.get_temp_path('1.original.amr-nb')
ref_path = self.get_temp_path('2.reference.wav')

# 1. Generate amr-nb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=32, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amr-nb with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
path = self.get_temp_path('reference.wav')
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
data, sr = sox_io_backend.load(path, normalize=normalize)
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
self.assertEqual(data, expected)


@skipIfNoExec('sox')
Expand All @@ -263,7 +122,7 @@ def test_wav(self, dtype, sample_rate, num_channels, normalize):
)), name_func=name_func)
def test_24bit_wav(self, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load 24bit wav format correctly. Corectly casts it to ``int32`` tensor dtype."""
self.assert_24bit_wav(sample_rate, num_channels, normalize, duration=1)
self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1)

@parameterized.expand(list(itertools.product(
['int16'],
Expand Down Expand Up @@ -293,7 +152,7 @@ def test_multiple_channels(self, dtype, num_channels):
)), name_func=name_func)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly."""
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05)

@parameterized.expand(list(itertools.product(
[16000],
Expand All @@ -303,7 +162,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60
self.assert_mp3(sample_rate, num_channels, bit_rate, two_hours)
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05)

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -312,7 +171,7 @@ def test_mp3_large(self, sample_rate, num_channels, bit_rate):
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load flac format correctly."""
self.assert_flac(sample_rate, num_channels, compression_level, duration=1)
self.assert_format("flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=1)

@parameterized.expand(list(itertools.product(
[16000],
Expand All @@ -322,7 +181,8 @@ def test_flac(self, sample_rate, num_channels, compression_level):
def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load large flac file correctly."""
two_hours = 2 * 60 * 60
self.assert_flac(sample_rate, num_channels, compression_level, two_hours)
self.assert_format(
"flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours)

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -331,7 +191,7 @@ def test_flac_large(self, sample_rate, num_channels, compression_level):
)), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load vorbis format correctly."""
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)
self.assert_format("vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=1)

@parameterized.expand(list(itertools.product(
[16000],
Expand All @@ -341,7 +201,8 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load large vorbis file correctly."""
two_hours = 2 * 60 * 60
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
self.assert_format(
"vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours)

@parameterized.expand(list(itertools.product(
['96k'],
Expand All @@ -366,7 +227,7 @@ def test_opus(self, bitrate, num_channels, compression_level):
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_sphere(sample_rate, num_channels, duration=1)
self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1)

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16'],
Expand All @@ -375,12 +236,15 @@ def test_sphere(self, sample_rate, num_channels):
[False, True],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_amb(dtype, sample_rate, num_channels, normalize, duration=1)
"""`sox_io_backend.load` can load amb format correctly."""
bit_depth = sox_utils.get_bit_depth(dtype)
encoding = sox_utils.get_encoding(dtype)
self.assert_format(
"amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize)

def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_amr_nb(duration=1)
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)


@skipIfNoExec('sox')
Expand Down

0 comments on commit 80a8739

Please sign in to comment.