[BC-Breaking] Avoid resampling kernel device and dtype moves #1514

Merged 1 commit on May 19, 2021
15 changes: 10 additions & 5 deletions torchaudio/functional/functional.py
@@ -1305,7 +1305,9 @@ def _get_sinc_resample_kernel(
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
-        beta: Optional[float]):
+        beta: Optional[float],
+        device: torch.device = torch.device("cpu"),
+        dtype: Optional[torch.dtype] = None):

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
@@ -1360,7 +1362,8 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
-    idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
+    idx_dtype = dtype if dtype is not None else torch.float64
+    idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)
Contributor:

What if someone passes a low precision type like uint8? I think it might be better to pick whatever dtype is most efficient for this operation.

Contributor (Author):

Following offline discussion, we keep the higher-precision float64 type because the kernel computation happens only once, and its dimensions are limited to roughly (orig_freq // gcd) × (new_freq // gcd). Typical resampling frequency pairs have a large gcd, so the dtype difference has only a minor computational cost.
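A quick sanity check of this claim (a hypothetical illustration, not part of the PR): for a common conversion such as 44.1 kHz → 16 kHz, the gcd-reduced kernel dimensions stay small.

```python
import math

# 44.1 kHz -> 16 kHz: the kernel is computed on the gcd-reduced rates,
# so its size is bounded by (orig_freq // gcd) and (new_freq // gcd).
orig_freq, new_freq = 44100, 16000
gcd = math.gcd(orig_freq, new_freq)   # 100
phases = new_freq // gcd              # 160 output phases
taps_scale = orig_freq // gcd         # filter length scales with 441
print(gcd, phases, taps_scale)        # -> 100 160 441
```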


for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
@@ -1379,7 +1382,10 @@ def _get_sinc_resample_kernel(
kernels.append(kernel)

scale = base_freq / orig_freq
-    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
+    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
+    if dtype is None:
+        kernels = kernels.to(dtype=torch.float32)
Contributor:

It might be better to just return the kernel and do the dtype and device cast at the callsite, since you're not using dtype outside of arange.

Contributor (Author):

Following offline discussion, it is fine to convert to this generally "default" float32 type before returning the kernel to the caller in transforms.

+    return kernels, width


def _apply_sinc_resample_kernel(
@@ -1396,7 +1402,6 @@ def _apply_sinc_resample_kernel(
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
-    kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)

num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
@@ -1452,6 +1457,6 @@ def resample(
gcd = math.gcd(int(orig_freq), int(new_freq))

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
-                                              resampling_method, beta)
+                                              resampling_method, beta, waveform.device, waveform.dtype)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
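The new functional flow can be sketched in isolation (`make_kernel` is a simplified, hypothetical stand-in for `_get_sinc_resample_kernel`, with a toy sinc filter in place of the real windowed design): the kernel is generated directly with the waveform's device and dtype, so the per-call `kernel.to(...)` removed from `_apply_sinc_resample_kernel` is no longer needed.

```python
import torch

def make_kernel(orig_freq: int, new_freq: int,
                device: torch.device = torch.device("cpu"),
                dtype: torch.dtype = None):
    # Simplified stand-in: build the index grid directly on the target
    # device/dtype (float64 when unspecified, for kernel precision),
    # and only downcast to the float32 "default" when no dtype was given.
    idx_dtype = dtype if dtype is not None else torch.float64
    idx = torch.arange(-8, 8 + orig_freq, device=device, dtype=idx_dtype)
    kernel = torch.sinc(idx / orig_freq).repeat(new_freq, 1).unsqueeze(1)
    if dtype is None:
        kernel = kernel.to(torch.float32)
    return kernel

waveform = torch.randn(1, 100, dtype=torch.float64)
kernel = make_kernel(4, 3, waveform.device, waveform.dtype)
assert kernel.dtype == torch.float64  # matches the waveform, no later cast needed
```

With no explicit dtype, the kernel comes back as float32, which is the behavior the transform path relies on.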
12 changes: 9 additions & 3 deletions torchaudio/transforms.py
@@ -664,6 +664,11 @@ class Resample(torch.nn.Module):
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
beta (float, optional): The shape parameter used for kaiser window.

+        Note: If resampling on waveforms of higher precision than float32, there may be a small loss of precision
+        because the kernel is cached once as float32. If high-precision resampling is important for your application,
+        the functional form will retain higher precision, but run slower because it does not cache the kernel.
+        Alternatively, you could write your own transform that caches a higher-precision kernel.
"""

def __init__(self,
@@ -682,9 +687,10 @@ def __init__(self,
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff

-        self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
-                                                            self.lowpass_filter_width, self.rolloff,
-                                                            self.resampling_method, beta)
+        kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
+                                                       self.lowpass_filter_width, self.rolloff,
+                                                       self.resampling_method, beta)
+        self.register_buffer('kernel', kernel)

def forward(self, waveform: Tensor) -> Tensor:
r"""
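The switch to `register_buffer` can be illustrated with a toy module (`WithBuffer` is a hypothetical stand-in, not part of the PR): a registered buffer follows `Module.to(...)` and appears in `state_dict()`, which is what lets the cached `Resample` kernel track the module's device and dtype without the per-call cast this PR removes.

```python
import torch

class WithBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        kernel = torch.randn(4, 1, 9)
        # register_buffer makes the cached kernel follow Module.to(...)
        # and get saved/loaded with state_dict(), unlike a plain attribute.
        self.register_buffer('kernel', kernel)

m = WithBuffer().to(torch.float64)
assert m.kernel.dtype == torch.float64  # buffer followed the module's dtype
assert 'kernel' in m.state_dict()       # and is part of serialization
```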