Skip to content

Commit

Permalink
add kaiser window
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed May 17, 2021
1 parent 52e7bfd commit dd6c076
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 31 deletions.
43 changes: 28 additions & 15 deletions test/torchaudio_unittest/compliance_kaldi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from torchaudio_unittest import common_utils
from .compliance import utils as compliance_utils
from parameterized import parameterized


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
Expand Down Expand Up @@ -182,20 +183,26 @@ def get_output_fn(sound, args):

self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)

def test_resample_waveform_upsample_size(self):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_size(self, resampling_method):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2,
resampling_method=resampling_method)
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)

def test_resample_waveform_downsample_size(self):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)

def test_resample_waveform_identity_size(self):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_identity_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))

def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
atol=1e-1, rtol=1e-4):
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
# resample the signal and compare it to the ground truth
n_to_trim = 20
sample_rate = 1000
Expand All @@ -211,7 +218,8 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)

sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate,
resampling_method=resampling_method).squeeze()

new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
Expand All @@ -222,27 +230,32 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact

self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)

def test_resample_waveform_downsample_accuracy(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_accuracy(self, resampling_method):
for i in range(1, 20):
self._test_resample_waveform_accuracy(down_scale_factor=i * 2)
self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)

def test_resample_waveform_upsample_accuracy(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_accuracy(self, resampling_method):
for i in range(1, 20):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0)
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)

def test_resample_waveform_multi_channel(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_multi_channel(self, resampling_method):
num_channels = 3

multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)

for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5

multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2)
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2,
resampling_method=resampling_method)

# check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels):
single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2)
self.test1_signal_sr // 2,
resampling_method=resampling_method)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
6 changes: 4 additions & 2 deletions torchaudio/compliance/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> Tensor:
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation") -> Tensor:
r"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``.
Expand All @@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor,
Returns:
Tensor: The waveform at the new frequency
"""
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width, rolloff)
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width,
rolloff, resampling_method)
23 changes: 17 additions & 6 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel(
new_freq: float,
gcd: int,
lowpass_filter_width: int,
rolloff: float):
rolloff: float,
resampling_method: str,
beta: float = 6.):

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
Expand Down Expand Up @@ -1352,15 +1354,20 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq)
idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)

for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
t *= math.pi
# we do not use torch.hann_window here as we need to evaluate the window

# we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid.
window = torch.cos(t / lowpass_filter_width / 2)**2
if resampling_method == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2
elif resampling_method == "kaiser_window":
beta = torch.tensor(beta, dtype=float)
window = torch.i0(beta * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta)
t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window)
kernels.append(kernel)
Expand Down Expand Up @@ -1403,6 +1410,8 @@ def resample(
new_freq: float,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
**kwargs,
) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
Expand All @@ -1421,6 +1430,7 @@ def resample(
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
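resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)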
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Expand All @@ -1433,6 +1443,7 @@ def resample(

gcd = math.gcd(int(orig_freq), int(new_freq))

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff)
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
resampling_method, **kwargs)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
19 changes: 11 additions & 8 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,8 @@ class Resample(torch.nn.Module):
Args:
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Expand All @@ -669,7 +670,8 @@ def __init__(self,
new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
rolloff: float = 0.99,
**kwargs) -> None:
super(Resample, self).__init__()

self.orig_freq = orig_freq
Expand All @@ -679,8 +681,12 @@ def __init__(self,
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff

if self.resampling_method not in ['sinc_interpolation', 'kaiser_window']:
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))

self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff)
self.lowpass_filter_width, self.rolloff,
self.resampling_method, **kwargs)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -690,11 +696,8 @@ def forward(self, waveform: Tensor) -> Tensor:
Returns:
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)

raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)


class ComplexNorm(torch.nn.Module):
Expand Down

0 comments on commit dd6c076

Please sign in to comment.