Precompute transforms.Resample kernel #1499

Merged
merged 5 commits into from
May 14, 2021
95 changes: 56 additions & 39 deletions torchaudio/functional/functional.py
@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(


def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
orig_freq: float,
new_freq: float,
gcd: int,
lowpass_filter_width: int,
rolloff: float,
device: torch.device,
dtype: torch.dtype):
rolloff: float):

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)

orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd

assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
idx = torch.arange(-width, width + orig_freq)

for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width


def _apply_sinc_resample_kernel(
waveform: Tensor,
orig_freq: float,
new_freq: float,
gcd: int,
kernel: Tensor,
width: int,
):
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd

# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)
Collaborator

Since this function will be called from the Transform, we might want to avoid moving the kernel to a particular device here.

In a use case like

  1. Initialize Resample
  2. Move the pipeline to a specific device
  3. Run inference.

and let's say that the pipeline is on GPU but the waveform is on CPU. We want the pipeline to fail so that users can fix it, instead of silently performing the operation on CPU, which incurs the cost of moving the kernel from GPU to CPU. (And it would then fail anyway at the next step of the pipeline, which expects the Tensor to be on GPU.)

I think moving the kernel to the target device can happen between the _get_kernel and _apply_kernel functions.

Contributor Author

sure, where do you think target device/dtype should initially be set in transforms?

Collaborator

> sure, where do you think target device/dtype should initially be set in transforms?

They are expected to be set by users with .to, so we do not need to set them explicitly.

resamp = T.Resample(...)
resamp.to(device, dtype)

Contributor Author

as discussed offline, this is BC-breaking and left as a follow-up


num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]

# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
return resampled
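
The length arithmetic in `_apply_sinc_resample_kernel` can be checked without any tensors. The sketch below is only an illustration: `resampled_lengths` is a hypothetical helper (not part of this PR), and it assumes the kernel's last dimension has `2 * width + orig_freq` taps, as implied by `torch.arange(-width, width + orig_freq)`, and that `orig_freq` and `new_freq` have already been divided by their gcd.

```python
import math


def resampled_lengths(length, orig_freq, new_freq, width):
    """Frame and sample counts for the pad -> strided conv -> trim steps.

    Hypothetical helper; `orig_freq` and `new_freq` are the gcd-reduced rates.
    """
    # Number of taps per kernel row: len(arange(-width, width + orig_freq)).
    kernel_len = 2 * width + orig_freq
    # Padding of (width, width + orig_freq) is added to the time axis.
    padded_len = length + width + (width + orig_freq)
    # Standard unpadded conv1d output length with stride=orig_freq.
    frames = (padded_len - kernel_len) // orig_freq + 1
    # Each frame yields new_freq samples (one per output channel).
    produced = frames * new_freq
    # Length kept after trimming, as computed in the function above.
    target = int(math.ceil(new_freq * length / orig_freq))
    return frames, produced, target


frames, produced, target = resampled_lengths(1000, 441, 160, width=20)
print(frames, produced, target)
```

Under these assumptions `width` cancels out of the frame count, and the convolution always produces at least `target` samples, so the final slice only ever trims.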


def resample(
waveform: Tensor,
orig_freq: float,
@@ -1380,42 +1424,15 @@ def resample(

Returns:
Tensor: The waveform at the new frequency of dimension (..., time).

Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
more efficient computation if resampling multiple waveforms with the same resampling parameters.
"""
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])

assert orig_freq > 0.0 and new_freq > 0.0

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)

orig_freq = int(orig_freq)
new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
rolloff, waveform.device, waveform.dtype)

num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]
gcd = math.gcd(int(orig_freq), int(new_freq))

# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
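
To make the gcd reduction and the non-integer warning concrete, here is a small stdlib-only sketch. `reduce_rates` is a hypothetical helper that mirrors the `int(...) // gcd` steps in the diff; it is not a function in torchaudio.

```python
import math


def reduce_rates(orig_freq, new_freq):
    """Cast both rates to int and divide by their gcd (hypothetical helper)."""
    orig_int, new_int = int(orig_freq), int(new_freq)
    gcd = math.gcd(orig_int, new_int)
    return orig_int // gcd, new_int // gcd


# Integer rates reduce to a small ratio, which keeps the kernel small.
print(reduce_rates(44100, 16000))
# The workaround from the warning: express the factor of 8 directly.
print(reduce_rates(8, 1))
# Truncating 5512.5 to 5512 no longer gives an exact factor of 8.
print(reduce_rates(44100, 5512.5))
```

The last call illustrates why the warning exists: after truncation the reduced ratio is 11025:1378 rather than 8:1, so both the resampling ratio and the kernel size differ from what the caller intended.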
16 changes: 13 additions & 3 deletions torchaudio/transforms.py
@@ -8,6 +8,10 @@
from torch import Tensor
from torchaudio import functional as F

from .functional.functional import (
_get_sinc_resample_kernel,
_apply_sinc_resample_kernel,
)

__all__ = [
'Spectrogram',
@@ -647,18 +651,23 @@ class Resample(torch.nn.Module):
"""

def __init__(self,
orig_freq: int = 16000,
new_freq: int = 16000,
orig_freq: float = 16000,
new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
super(Resample, self).__init__()

self.orig_freq = orig_freq
self.new_freq = new_freq
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff

self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
@@ -668,7 +677,8 @@ def forward(self, waveform: Tensor) -> Tensor:
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff)
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)

raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
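
The precompute-and-reuse pattern that this change introduces in `Resample` can be sketched in plain Python. Everything below is a toy under stated assumptions: `PrecomputedResampler` is a hypothetical class, its box-filter kernel is a placeholder for the one returned by `_get_sinc_resample_kernel`, and it only handles integer downsampling factors.

```python
import math


class PrecomputedResampler:
    """Toy, hypothetical stand-in for a transform that builds its kernel once."""

    kernel_builds = 0  # counts kernel constructions, for illustration only

    def __init__(self, orig_freq, new_freq):
        gcd = math.gcd(int(orig_freq), int(new_freq))
        self.orig_freq = int(orig_freq) // gcd
        self.new_freq = int(new_freq) // gcd
        if self.new_freq != 1:
            raise ValueError("this toy only handles integer downsampling factors")
        self.kernel = self._build_kernel()  # done once, at construction

    def _build_kernel(self):
        PrecomputedResampler.kernel_builds += 1
        # Box filter as a placeholder for the real sinc-based kernel.
        return [1.0 / self.orig_freq] * self.orig_freq

    def __call__(self, waveform):
        # Apply the stored kernel to non-overlapping blocks (stride=orig_freq).
        step = self.orig_freq
        return [
            sum(k * x for k, x in zip(self.kernel, waveform[i:i + step]))
            for i in range(0, len(waveform) - step + 1, step)
        ]


resampler = PrecomputedResampler(16000, 8000)
first = resampler([1.0, 3.0, 5.0, 7.0])
second = resampler([2.0, 2.0, 4.0, 4.0])
print(first, second, PrecomputedResampler.kernel_builds)
```

Both calls reuse the kernel built in `__init__`, which is the saving the docstring note on `resample` refers to when many waveforms share the same resampling parameters.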
