Skip to content

Commit

Permalink
Apply misc updates to functional/batch_consistency_test.py (#1341)
Browse files Browse the repository at this point in the history
* Parameterize `test_sliding_window_cmn`

* Extract test naming function

* Pass a spectrogram to `F.sliding_window_cmn`

* Set manual seed for remaining rand calls in suite
  • Loading branch information
jcaw committed Mar 5, 2021
1 parent 9a96fb7 commit 64551a6
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from torchaudio_unittest import common_utils


def _name_from_args(func, _, params):
"""Return a parameterized test name, based on parameter values."""
return "{}_{}".format(
func.__name__,
"_".join(str(arg) for arg in params.args))


@parameterized_class([
# Single-item batch isolates problems that come purely from adding a
# dimension (rather than processing multiple items)
Expand Down Expand Up @@ -58,7 +65,7 @@ def test_griffinlim(self):
@parameterized.expand(list(itertools.product(
[8000, 16000, 44100],
[1, 2],
)), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
)), name_func=_name_from_args)
def test_detect_pitch_frequency(self, sample_rate, n_channels):
# Use different frequencies to ensure each item in the batch returns a
# different answer.
Expand Down Expand Up @@ -180,16 +187,16 @@ def test_flanger(self):
sample_rate = 44100
self.assert_batch_consistency(F.flanger, waveforms, sample_rate)

def test_sliding_window_cmn(self):
waveforms = torch.randn(self.batch_size, 2, 1024) - 0.5
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=True, norm_vars=True)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=True, norm_vars=False)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=True)
@parameterized.expand(list(itertools.product(
[True, False], # center
[True, False], # norm_vars
)), name_func=_name_from_args)
def test_sliding_window_cmn(self, center, norm_vars):
torch.manual_seed(0)
spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=False)
F.sliding_window_cmn, spectrogram, center=center,
norm_vars=norm_vars)

def test_vad_from_file(self):
filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav")
Expand All @@ -202,6 +209,7 @@ def test_vad_from_file(self):
def test_vad_different_items(self):
"""Separate test to ensure VAD consistency with differing items."""
sample_rate = 44100
torch.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate)
Expand Down

0 comments on commit 64551a6

Please sign in to comment.