diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index a5c3354c9b..db423bea46 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -755,7 +755,8 @@ def mfcc( def resample_waveform(waveform: Tensor, orig_freq: float, new_freq: float, - lowpass_filter_width: int = 6) -> Tensor: + lowpass_filter_width: int = 6, + rolloff: float = 0.99) -> Tensor: r"""Resamples the waveform at the new frequency. This is a wrapper around ``torchaudio.functional.resample``. @@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor, new_freq (float): The desired frequency lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) + rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. + Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) Returns: Tensor: The waveform at the new frequency """ - return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width) + return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width, rolloff) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index cef2553d82..28c7ef7503 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -1291,8 +1291,13 @@ def compute_kaldi_pitch( return result -def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int, - device: torch.device, dtype: torch.dtype): +def _get_sinc_resample_kernel( + orig_freq: int, + new_freq: int, + lowpass_filter_width: int, + rolloff: float, + device: torch.device, + dtype: torch.dtype): assert lowpass_filter_width > 0 kernels = [] base_freq = min(orig_freq, new_freq) @@ -1300,7 +1305,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt # At first I thought I only needed this when downsampling, but when upsampling # you will get edge artifacts without this, as the edge is equivalent to zero padding, # which will add high freq artifacts. - base_freq *= 0.99 + base_freq *= rolloff # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) # using the sinc interpolation formula: @@ -1345,7 +1350,8 @@ def resample( waveform: Tensor, orig_freq: float, new_freq: float, - lowpass_filter_width: int = 6 + lowpass_filter_width: int = 6, + rolloff: float = 0.99, ) -> Tensor: r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample @@ -1362,6 +1368,8 @@ def resample( new_freq (float): The desired frequency lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) + rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. + Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) Returns: Tensor: The waveform at the new frequency of dimension (..., time). @@ -1379,7 +1387,7 @@ def resample( new_freq = new_freq // gcd kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, - waveform.device, waveform.dtype) + rolloff, waveform.device, waveform.dtype) num_wavs, length = waveform.shape waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index a04592f2c4..6b0a30a4d2 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -639,16 +639,24 @@ class Resample(torch.nn.Module): orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``) new_freq (float, optional): The desired frequency. (Default: ``16000``) resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``) + lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper + but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) + rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. + Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) """ def __init__(self, orig_freq: int = 16000, new_freq: int = 16000, - resampling_method: str = 'sinc_interpolation') -> None: + resampling_method: str = 'sinc_interpolation', + lowpass_filter_width: int = 6, + rolloff: float = 0.99) -> None: super(Resample, self).__init__() self.orig_freq = orig_freq self.new_freq = new_freq self.resampling_method = resampling_method + self.lowpass_filter_width = lowpass_filter_width + self.rolloff = rolloff def forward(self, waveform: Tensor) -> Tensor: r""" @@ -659,7 +667,7 @@ def forward(self, waveform: Tensor) -> Tensor: Tensor: Output signal of dimension (..., time). """ if self.resampling_method == 'sinc_interpolation': - return F.resample(waveform, self.orig_freq, self.new_freq) + return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff) raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))