Skip to content

Commit

Permalink
add transforms warning and fix input type
Browse files (browse the repository at this point in the history)
  • Loading branch information
Caroline Chen committed May 12, 2021
1 parent 5d5dc1f commit 7a482fb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 32 deletions.
55 changes: 30 additions & 25 deletions torchaudio/functional/functional.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1299,10 +1299,28 @@ def compute_kaldi_pitch(


def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
orig_freq: float,
new_freq: float,
gcd: int,
lowpass_filter_width: int,
rolloff: float):

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)

orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd

assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
Expand Down Expand Up @@ -1353,11 +1371,15 @@ def _get_sinc_resample_kernel(

def _apply_sinc_resample_kernel(
waveform: Tensor,
orig_freq: int,
new_freq: int,
orig_freq: float,
new_freq: float,
gcd: int,
kernel: Tensor,
width: int,
):
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd

# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
Expand Down Expand Up @@ -1403,31 +1425,14 @@ def resample(
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Note: ``transforms.Resample` precomputes and reuses the resampling kernel, so using it will result in
Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
more efficient computation if resampling multiple waveforms with the same resampling parameters.
"""

assert orig_freq > 0.0 and new_freq > 0.0

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
warnings.warn(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)

orig_freq = int(orig_freq)
new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd
gcd = math.gcd(int(orig_freq), int(new_freq))

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, rolloff)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, kernel, width)
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
15 changes: 8 additions & 7 deletions torchaudio/transforms.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -651,20 +651,21 @@ class Resample(torch.nn.Module):
"""

def __init__(self,
orig_freq: int = 16000,
new_freq: int = 16000,
orig_freq: float = 16000,
new_freq: float = 16000,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99) -> None:
super(Resample, self).__init__()
self.orig_freq = int(orig_freq)
self.new_freq = int(new_freq)
self.gcd = math.gcd(self.orig_freq, self.new_freq)

self.orig_freq = orig_freq
self.new_freq = new_freq
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
self.resampling_method = resampling_method
self.lowpass_filter_width = lowpass_filter_width
self.rolloff = rolloff

self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq // self.gcd, self.new_freq // self.gcd,
self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
self.lowpass_filter_width, self.rolloff)

def forward(self, waveform: Tensor) -> Tensor:
Expand All @@ -676,7 +677,7 @@ def forward(self, waveform: Tensor) -> Tensor:
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
return _apply_sinc_resample_kernel(waveform, self.orig_freq // self.gcd, self.new_freq // self.gcd,
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
self.kernel, self.width)

raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
Expand Down

0 comments on commit 7a482fb

Please sign in to comment.