Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions torchaudio/compliance/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,8 @@ def mfcc(
def resample_waveform(waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6) -> Tensor:
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> Tensor:
r"""Resamples the waveform at the new frequency.

This is a wrapper around ``torchaudio.functional.resample``.
Expand All @@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor,
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist frequency.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)

Returns:
Tensor: The waveform at the new frequency
"""
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width)
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width, rolloff)
18 changes: 13 additions & 5 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,16 +1291,21 @@ def compute_kaldi_pitch(
return result


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
device: torch.device, dtype: torch.dtype):
def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
lowpass_filter_width: int,
rolloff: float,
device: torch.device,
dtype: torch.dtype):
assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
# This will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
base_freq *= 0.99
base_freq *= rolloff

# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
Expand Down Expand Up @@ -1345,7 +1350,8 @@ def resample(
waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
Expand All @@ -1362,6 +1368,8 @@ def resample(
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist frequency.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)

Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Expand All @@ -1379,7 +1387,7 @@ def resample(
new_freq = new_freq // gcd

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
waveform.device, waveform.dtype)
rolloff, waveform.device, waveform.dtype)

num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
Expand Down
12 changes: 10 additions & 2 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,16 +639,24 @@ class Resample(torch.nn.Module):
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist frequency.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)
"""

def __init__(self,
orig_freq: int = 16000,
new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation') -> None:
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
super(Resample, self).__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -659,7 +667,7 @@ def forward(self, waveform: Tensor) -> Tensor:
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
return F.resample(waveform, self.orig_freq, self.new_freq)
return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff)

raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))

Expand Down