Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add kaiser window support to resampling #1509

Merged
merged 4 commits into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 28 additions & 15 deletions test/torchaudio_unittest/compliance_kaldi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from torchaudio_unittest import common_utils
from .compliance import utils as compliance_utils
from parameterized import parameterized


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
Expand Down Expand Up @@ -182,20 +183,26 @@ def get_output_fn(sound, args):

self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)

def test_resample_waveform_upsample_size(self):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_size(self, resampling_method):
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2,
resampling_method=resampling_method)
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)

def test_resample_waveform_downsample_size(self):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)

def test_resample_waveform_identity_size(self):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr)
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_identity_size(self, resampling_method):
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr,
resampling_method=resampling_method)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))

def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
atol=1e-1, rtol=1e-4):
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
# resample the signal and compare it to the ground truth
n_to_trim = 20
sample_rate = 1000
Expand All @@ -211,7 +218,8 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)

sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate,
resampling_method=resampling_method).squeeze()

new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
Expand All @@ -222,27 +230,32 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact

self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)

def test_resample_waveform_downsample_accuracy(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_downsample_accuracy(self, resampling_method):
for i in range(1, 20):
self._test_resample_waveform_accuracy(down_scale_factor=i * 2)
self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)

def test_resample_waveform_upsample_accuracy(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_upsample_accuracy(self, resampling_method):
for i in range(1, 20):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0)
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)

def test_resample_waveform_multi_channel(self):
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
def test_resample_waveform_multi_channel(self, resampling_method):
num_channels = 3

multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)

for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5

multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2)
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2,
resampling_method=resampling_method)

# check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels):
single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2)
self.test1_signal_sr // 2,
resampling_method=resampling_method)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
6 changes: 4 additions & 2 deletions test/torchaudio_unittest/transforms/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ def test_resample_size(self):

upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2
invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
invalid_resampling_method = 'foo'

self.assertRaises(ValueError, invalid_resample, waveform)
with self.assertRaises(ValueError):
torchaudio.transforms.Resample(sample_rate, upsample_rate,
resampling_method=invalid_resampling_method)

upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method='sinc_interpolation')
Expand Down
6 changes: 4 additions & 2 deletions torchaudio/compliance/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> Tensor:
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation") -> Tensor:
r"""Resamples the waveform at the new frequency.

This is a wrapper around ``torchaudio.functional.resample``.
Expand All @@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor,
Returns:
Tensor: The waveform at the new frequency
"""
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width, rolloff)
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width,
rolloff, resampling_method)
31 changes: 25 additions & 6 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel(
new_freq: float,
gcd: int,
lowpass_filter_width: int,
rolloff: float):
rolloff: float,
resampling_method: str,
beta: Optional[float]):

carolineechen marked this conversation as resolved.
Show resolved Hide resolved
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
Expand All @@ -1318,9 +1320,15 @@ def _get_sinc_resample_kernel(
"https://github.com/pytorch/audio/issues/1487."
)

if resampling_method not in ['sinc_interpolation', 'kaiser_window']:
raise ValueError('Invalid resampling method: {}'.format(resampling_method))

orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd

if resampling_method == "kaiser_window" and beta is None:
beta = 14.769656459379492

assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
Expand Down Expand Up @@ -1352,15 +1360,20 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq)
idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

through running separate scripts, I realized that #1499 introduced rounding errors if waveform dtype was greater than float32. adding dtype=float64 here to retain implementation accuracy from the prior version, and this implementation will likely be improved after additional discussion on transforms kernel dtype/device computation


for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
t *= math.pi
# we do not use torch.hann_window here as we need to evaluate the window

# we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid.
window = torch.cos(t / lowpass_filter_width / 2)**2
if resampling_method == "sinc_interpolation":
carolineechen marked this conversation as resolved.
Show resolved Hide resolved
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2
elif resampling_method == "kaiser_window":
beta = torch.tensor(beta, dtype=float)
window = torch.i0(beta * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta)
t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window)
kernels.append(kernel)
Expand Down Expand Up @@ -1403,6 +1416,8 @@ def resample(
new_freq: float,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
beta: Optional[float] = None,
) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
Expand All @@ -1421,6 +1436,9 @@ def resample(
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
beta (float, optional): The shape parameter used for kaiser window.

Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Expand All @@ -1433,6 +1451,7 @@ def resample(

gcd = math.gcd(int(orig_freq), int(new_freq))

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff)
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
resampling_method, beta)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
17 changes: 9 additions & 8 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,19 +657,22 @@ class Resample(torch.nn.Module):
Args:
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
resampling_method (str, optional): The resampling method.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

followup: This docstring can be improved.

lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
beta (float, optional): The shape parameter used for kaiser window.
"""

def __init__(self,
orig_freq: float = 16000,
new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
rolloff: float = 0.99,
beta: Optional[float] = None) -> None:
super(Resample, self).__init__()

self.orig_freq = orig_freq
Expand All @@ -680,7 +683,8 @@ def __init__(self,
self.rolloff = rolloff

self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff)
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -690,11 +694,8 @@ def forward(self, waveform: Tensor) -> Tensor:
Returns:
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)

raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)


class ComplexNorm(torch.nn.Module):
Expand Down