Skip to content

Commit

Permalink
restructure internals
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed May 12, 2021
1 parent ffd2f0a commit 74c0ac0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 43 deletions.
62 changes: 30 additions & 32 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,9 +1302,7 @@ def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
lowpass_filter_width: int,
rolloff: float,
device: torch.device,
dtype: torch.dtype):
rolloff: float):
assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
Expand Down Expand Up @@ -1336,7 +1334,7 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
idx = torch.arange(-width, width + orig_freq)

for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
Expand All @@ -1353,13 +1351,36 @@ def _get_sinc_resample_kernel(
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width


def _apply_sinc_resample_kernel(
waveform: Tensor,
orig_freq: int,
new_freq: int,
kernel: Tensor,
width: int,
):
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)

num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]

# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
return resampled


def resample(
waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
kernel: Tensor = None
) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
Expand All @@ -1378,20 +1399,13 @@ def resample(
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
kernel (Tensor, optional): Tensor of dimension (f, 1, w) representing the windowed sinc function that is
used in convolution to calculate the resampled waveform. ``f = new_freq_gcd`` and ``w = 2 *
math.ceil(lowpass_filter_width * (orig_freq_gcd) / (rolloff * min(orig_freq_gcd, new_freq_gcd)) + orig_freq_gcd``,
where ``new_freq_gcd`` and ``orig_freq_gcd`` are equal to ``new_freq // gcd`` and ``old_freq // gcd``
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Note: transforms.Resample passes in a precomputed kernel, which will result in more efficient computation if reusing
the same set of resampling parameters to resample multiple waveforms.
Note: ``transforms.Resample` precomputes and reuses the resampling kernel, so using it will result in
more efficient computation if resampling multiple waveforms with the same resampling parameters.
"""
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])

assert orig_freq > 0.0 and new_freq > 0.0

Expand All @@ -1401,22 +1415,6 @@ def resample(
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd

if kernel == None:
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
rolloff, waveform.device, waveform.dtype)
else:
base_freq = min(orig_freq, new_freq) * rolloff
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
assert kernel.shape[0] == new_freq
assert kernel.shape[2] == 2 * width + orig_freq

num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]

# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, rolloff)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, kernel, width)
return resampled
23 changes: 12 additions & 11 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from torch import Tensor
from torchaudio import functional as F

from .functional.functional import _get_sinc_resample_kernel
from .functional.functional import (
_get_sinc_resample_kernel,
_apply_sinc_resample_kernel,
)

__all__ = [
'Spectrogram',
Expand Down Expand Up @@ -654,12 +657,15 @@ def __init__(self,
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
super(Resample, self).__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
self.orig_freq = int(orig_freq)
self.new_freq = int(new_freq)
self.gcd = math.gcd(self.orig_freq, self.new_freq)
self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff
self.kernel = None

self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq // self.gcd, self.new_freq // self.gcd,
self.lowpass_filter_width, self.rolloff)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -670,13 +676,8 @@ def forward(self, waveform: Tensor) -> Tensor:
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
if self.kernel == None:
gcd = math.gcd(self.orig_freq, self.new_freq)
orig_freq = self.orig_freq // gcd
new_freq = self.new_freq // gcd
self.kernel, _ = _get_sinc_resample_kernel(orig_freq, new_freq, self.lowpass_filter_width,
self.rolloff, waveform.device, waveform.dtype)
return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.kernel)
return _apply_sinc_resample_kernel(waveform, self.orig_freq // self.gcd, self.new_freq // self.gcd,
self.kernel, self.width)

raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))

Expand Down

0 comments on commit 74c0ac0

Please sign in to comment.