Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry-picked 0.9] Ensure resampling identity is unchanged #1537

Merged
merged 1 commit into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 15 additions & 25 deletions test/torchaudio_unittest/compliance_kaldi_test.py
Expand Up @@ -56,15 +56,12 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def setUp(self):
super().setUp()

# 1. test signal for testing resampling
self.test1_signal_sr = 16000
self.test1_signal = common_utils.get_whitenoise(
sample_rate=self.test1_signal_sr, duration=0.5,
# test signal for testing resampling
self.test_signal_sr = 16000
self.test_signal = common_utils.get_whitenoise(
sample_rate=self.test_signal_sr, duration=0.5,
)

# 2. test audio file corresponding to saved kaldi ark files
self.test2_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')

# separating test files by their types (e.g 'spec', 'fbank', etc.)
for f in os.listdir(kaldi_output_dir):
dash_idx = f.find('-')
Expand Down Expand Up @@ -176,30 +173,23 @@ def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))

def test_resample_waveform(self):
def get_output_fn(sound, args):
output = kaldi.resample_waveform(sound.to(torch.float32), args[1], args[2])
return output

self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)

@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_size(self, resampling_method):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2,
upsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr * 2,
resampling_method=resampling_method)
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)
self.assertTrue(upsample_sound.size(-1) == self.test_signal.size(-1) * 2)

@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2,
downsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr // 2,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)
self.assertTrue(downsample_sound.size(-1) == self.test_signal.size(-1) // 2)

@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_identity_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr,
downsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))
self.assertTrue(downsample_sound.size(-1) == self.test_signal.size(-1))

def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
Expand Down Expand Up @@ -244,18 +234,18 @@ def test_resample_waveform_upsample_accuracy(self, resampling_method):
def test_resample_waveform_multi_channel(self, resampling_method):
num_channels = 3

multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)
multi_sound = self.test_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)

for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5

multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2,
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test_signal_sr, self.test_signal_sr // 2,
resampling_method=resampling_method)

# check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels):
single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2,
single_channel = self.test_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test_signal_sr,
self.test_signal_sr // 2,
resampling_method=resampling_method)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
Expand Up @@ -259,6 +259,16 @@ def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis):

self.assertEqual(specgrams, specgrams_copy)

@parameterized.expand(list(itertools.product(
["sinc_interpolation", "kaiser_window"],
[16000, 44100],
)))
def test_resample_identity(self, resampling_method, sample_rate):
waveform = get_whitenoise(sample_rate=sample_rate, duration=1)

resampled = F.resample(waveform, sample_rate, sample_rate)
self.assertEqual(waveform, resampled)

def test_resample_no_warning(self):
sample_rate = 44100
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)
Expand Down
13 changes: 13 additions & 0 deletions test/torchaudio_unittest/transforms/transforms_test_impl.py
@@ -1,3 +1,4 @@
import itertools
import warnings

import torch
Expand All @@ -8,6 +9,7 @@
get_whitenoise,
get_spectrogram,
)
from parameterized import parameterized


def _get_ratio(mat):
Expand Down Expand Up @@ -77,3 +79,14 @@ def test_melscale_unset_weight_warning(self):
warnings.simplefilter("always")
T.MelScale(n_mels=64, sample_rate=8000, n_stft=201)
assert len(caught_warnings) == 0

@parameterized.expand(list(itertools.product(
["sinc_interpolation", "kaiser_window"],
[16000, 44100],
)))
def test_resample_identity(self, resampling_method, sample_rate):
waveform = get_whitenoise(sample_rate=sample_rate, duration=1)

resampler = T.Resample(sample_rate, sample_rate)
resampled = resampler(waveform)
self.assertEqual(waveform, resampled)
3 changes: 3 additions & 0 deletions torchaudio/functional/functional.py
Expand Up @@ -1449,6 +1449,9 @@ def resample(

assert orig_freq > 0.0 and new_freq > 0.0

if orig_freq == new_freq:
return waveform

gcd = math.gcd(int(orig_freq), int(new_freq))

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
Expand Down
11 changes: 7 additions & 4 deletions torchaudio/transforms.py
Expand Up @@ -696,10 +696,11 @@ def __init__(self,
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff

kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
self.register_buffer('kernel', kernel)
if self.orig_freq != self.new_freq:
kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
self.register_buffer('kernel', kernel)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -709,6 +710,8 @@ def forward(self, waveform: Tensor) -> Tensor:
Returns:
Tensor: Output signal of dimension (..., time).
"""
if self.orig_freq == self.new_freq:
return waveform
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)

Expand Down