Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply functional batch consistency tests to batches of different items #1315

Merged
merged 15 commits into from
Feb 28, 2021
Merged
4 changes: 2 additions & 2 deletions test/torchaudio_unittest/common_utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def get_whitenoise(
# so we only fork on CPU, generate values and move the data to the given device
with torch.random.fork_rng([]):
torch.random.manual_seed(seed)
tensor = torch.randn([int(sample_rate * duration)], dtype=torch.float32, device='cpu')
tensor = torch.randn([n_channels, int(sample_rate * duration)],
dtype=torch.float32, device='cpu')
tensor /= 2.0
tensor *= scale_factor
tensor.clamp_(-1.0, 1.0)
tensor = tensor.repeat([n_channels, 1])
if not channels_first:
tensor = tensor.t()
return convert_tensor_encoding(tensor, dtype)
Expand Down
161 changes: 93 additions & 68 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,42 @@
import itertools
import math

from parameterized import parameterized
from parameterized import parameterized, parameterized_class
import torch
import torchaudio
import torchaudio.functional as F

from torchaudio_unittest import common_utils


@parameterized_class([
# Single-item batch isolates problems that come purely from adding a
# dimension (rather than processing multiple items)
{"batch_size": 1},
{"batch_size": 3},
])
class TestFunctional(common_utils.TorchaudioTestCase):
backend = 'default'
"""Test functions defined in `functional` module"""
backend = 'default'

def assert_batch_consistency(
self, functional, tensor, *args, batch_size=1, atol=1e-8,
rtol=1e-5, seed=42, **kwargs):
# run then batch the result
self, functional, batch, *args, atol=1e-8, rtol=1e-5, seed=42,
**kwargs):
n = batch.size(0)

# Compute items separately, then batch the result
torch.random.manual_seed(seed)
expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.repeat([batch_size] + [1] * expected.dim())
items_input = batch.clone()
items_result = torch.stack([
functional(items_input[i], *args, **kwargs) for i in range(n)
])

# batch the input and run
# Batch the input and run
torch.random.manual_seed(seed)
pattern = [batch_size] + [1] * tensor.dim()
computed = functional(tensor.repeat(pattern), *args, **kwargs)
batch_input = batch.clone()
batch_result = functional(batch_input, *args, **kwargs)

self.assertEqual(computed, expected, rtol=rtol, atol=atol)

def assert_batch_consistencies(
self, functional, tensor, *args, atol=1e-8, rtol=1e-5,
seed=42, **kwargs):
self.assert_batch_consistency(
functional, tensor, *args, batch_size=1, atol=atol,
rtol=rtol, seed=seed, **kwargs)
self.assert_batch_consistency(
functional, tensor, *args, batch_size=3, atol=atol,
rtol=rtol, seed=seed, **kwargs)
self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol)
self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol)

def test_griffinlim(self):
n_fft = 400
Expand All @@ -48,37 +49,44 @@ def test_griffinlim(self):
momentum = 0.99
n_iter = 32
length = 1000
tensor = torch.rand((1, 201, 6))
self.assert_batch_consistencies(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize,
torch.random.manual_seed(0)
batch = torch.rand(self.batch_size, 1, 201, 6)
self.assert_batch_consistency(
F.griffinlim, batch, window, n_fft, hop, ws, power, normalize,
n_iter, momentum, length, 0, atol=5e-5)

@parameterized.expand(list(itertools.product(
[100, 440],
[8000, 16000, 44100],
[1, 2],
)), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels):
waveform = common_utils.get_sinusoid(
frequency=frequency, sample_rate=sample_rate,
n_channels=n_channels, duration=5)
self.assert_batch_consistencies(
F.detect_pitch_frequency, waveform, sample_rate)
def test_detect_pitch_frequency(self, sample_rate, n_channels):
# Use different frequencies to ensure each item in the batch returns a
# different answer.
torch.manual_seed(0)
frequencies = torch.randint(100, 1000, [self.batch_size])
waveforms = torch.stack([
common_utils.get_sinusoid(
frequency=frequency, sample_rate=sample_rate,
n_channels=n_channels, duration=5)
for frequency in frequencies
])
self.assert_batch_consistency(
F.detect_pitch_frequency, waveforms, sample_rate)

def test_amplitude_to_DB(self):
torch.manual_seed(0)
spec = torch.rand(2, 100, 100) * 200
spec = torch.rand(self.batch_size, 2, 100, 100) * 200

amplitude_mult = 20.
amin = 1e-10
ref = 1.0
db_mult = math.log10(max(amin, ref))

# Test with & without a `top_db` clamp
self.assert_batch_consistencies(
self.assert_batch_consistency(
F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=None)
self.assert_batch_consistencies(
self.assert_batch_consistency(
F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=40.)

Expand Down Expand Up @@ -140,53 +148,70 @@ def test_amplitude_to_DB_not_channelwise_clamps(self):
assert (difference >= 1e-5).any()

def test_contrast(self):
waveform = torch.rand(2, 100) - 0.5
self.assert_batch_consistencies(
F.contrast, waveform, enhancement_amount=80.)
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also take this opportunity to replace these waveform generation with get_whitenoise?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing

self.assert_batch_consistency(
F.contrast, waveforms, enhancement_amount=80.)

def test_dcshift(self):
waveform = torch.rand(2, 100) - 0.5
self.assert_batch_consistencies(
F.dcshift, waveform, shift=0.5, limiter_gain=0.05)
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.dcshift, waveforms, shift=0.5, limiter_gain=0.05)

def test_overdrive(self):
waveform = torch.rand(2, 100) - 0.5
self.assert_batch_consistencies(
F.overdrive, waveform, gain=45, colour=30)
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.overdrive, waveforms, gain=45, colour=30)

def test_phaser(self):
sample_rate = 44100
n_channels = 2
waveform = common_utils.get_whitenoise(
sample_rate=sample_rate, duration=5,
)
self.assert_batch_consistencies(F.phaser, waveform, sample_rate)
sample_rate=sample_rate, n_channels=self.batch_size * n_channels,
duration=1)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(F.phaser, batch, sample_rate)

def test_flanger(self):
torch.random.manual_seed(40)
waveform = torch.rand(2, 100) - 0.5
torch.random.manual_seed(0)
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
sample_rate = 44100
self.assert_batch_consistencies(F.flanger, waveform, sample_rate)
self.assert_batch_consistency(F.flanger, waveforms, sample_rate)

def test_sliding_window_cmn(self):
waveform = torch.randn(2, 1024) - 0.5
self.assert_batch_consistencies(
F.sliding_window_cmn, waveform, center=True, norm_vars=True)
self.assert_batch_consistencies(
F.sliding_window_cmn, waveform, center=True, norm_vars=False)
self.assert_batch_consistencies(
F.sliding_window_cmn, waveform, center=False, norm_vars=True)
self.assert_batch_consistencies(
F.sliding_window_cmn, waveform, center=False, norm_vars=False)

def test_vad(self):
common_utils.set_audio_backend('default')
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
waveform, sample_rate = torchaudio.load(filepath)
self.assert_batch_consistencies(
F.vad, waveform, sample_rate=sample_rate)
waveforms = torch.randn(self.batch_size, 2, 1024) - 0.5
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=True, norm_vars=True)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=True, norm_vars=False)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=True)
self.assert_batch_consistency(
F.sliding_window_cmn, waveforms, center=False, norm_vars=False)

def test_vad_from_file(self):
filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav")
waveform, sample_rate = common_utils.load_wav(filepath)
# Each channel is slightly offset - we can use this to create a batch
# with different items.
batch = waveform.view(2, 1, -1)
self.assert_batch_consistency(F.vad, batch, sample_rate=sample_rate)

def test_vad_different_items(self):
"""Separate test to ensure VAD consistency with differing items."""
sample_rate = 44100
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
self.assert_batch_consistency(
F.vad, waveforms, sample_rate=sample_rate)

@common_utils.skipIfNoExtension
def test_compute_kaldi_pitch(self):
sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self.assert_batch_consistencies(F.compute_kaldi_pitch, waveform, sample_rate=sample_rate)
n_channels = 2
waveform = common_utils.get_whitenoise(
sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(
F.compute_kaldi_pitch, batch, sample_rate=sample_rate)