Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed May 19, 2021
1 parent 7078fcd commit a041c3f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
15 changes: 10 additions & 5 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,9 @@ def _get_sinc_resample_kernel(
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
beta: Optional[float]):
beta: Optional[float],
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None):

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
Expand Down Expand Up @@ -1360,7 +1362,8 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
idx_dtype = dtype if dtype is not None else torch.float64
idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)

for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
Expand All @@ -1379,7 +1382,10 @@ def _get_sinc_resample_kernel(
kernels.append(kernel)

scale = base_freq / orig_freq
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
if dtype is None:
kernels = kernels.to(dtype=torch.float32)
return kernels, width


def _apply_sinc_resample_kernel(
Expand All @@ -1396,7 +1402,6 @@ def _apply_sinc_resample_kernel(
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)

num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
Expand Down Expand Up @@ -1452,6 +1457,6 @@ def resample(
gcd = math.gcd(int(orig_freq), int(new_freq))

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
resampling_method, beta)
resampling_method, beta, waveform.device, waveform.dtype)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
12 changes: 9 additions & 3 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,11 @@ class Resample(torch.nn.Module):
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
beta (float, optional): The shape parameter used for kaiser window.
Note: If resampling on waveforms of higher precision than float32, there may be a small loss of precision
because the kernel is cached once as float32. If high precision resampling is important for your application,
the functional form will retain higher precision, but run slower because it does not cache the kernel.
Alternatively, you could rewrite a transform that caches a higher precision kernel.
"""

def __init__(self,
Expand All @@ -682,9 +687,10 @@ def __init__(self,
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff

self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff,
self.resampling_method, beta)
self.register_buffer('kernel', kernel)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand Down

0 comments on commit a041c3f

Please sign in to comment.