Skip to content

Commit

Permalink
reimplement in torchscriptable way
Browse files (browse the repository at this point in the history)
  • Loading branch information
Caroline Chen committed May 20, 2021
1 parent 5f697b8 commit 0623b3f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -592,12 +592,17 @@ def func(tensor):
self._assert_consistency(func, tensor)

def test_resample(self):
def func(tensor):
def func_sinc(tensor):
sr1, sr2 = 16000., 8000.
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation")

def func_kaiser(tensor):
sr1, sr2 = 16000., 8000.
return F.resample(tensor, sr1, sr2)
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window")

tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)
self._assert_consistency(func_sinc, tensor)
self._assert_consistency(func_kaiser, tensor)

@parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_paseudo_complex):
Expand Down
14 changes: 6 additions & 8 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ def _get_sinc_resample_kernel(
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
beta: Optional[float],
beta: float,
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None):

Expand All @@ -1328,9 +1328,6 @@ def _get_sinc_resample_kernel(
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd

if resampling_method == "kaiser_window" and beta is None:
beta = 14.769656459379492

assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
Expand Down Expand Up @@ -1373,9 +1370,10 @@ def _get_sinc_resample_kernel(
# at specific positions, not over a regular grid.
if resampling_method == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2
elif resampling_method == "kaiser_window":
beta = torch.tensor(beta, dtype=float)
window = torch.i0(beta * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta)
else:
# kaiser_window
beta_tensor = torch.tensor(beta)
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window)
Expand Down Expand Up @@ -1422,7 +1420,7 @@ def resample(
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_method: str = "sinc_interpolation",
beta: Optional[float] = None,
beta: float = 14.769656459379492,
) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def __init__(self,
resampling_method: str = 'sinc_interpolation',
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None) -> None:
beta: float = 14.769656459379492) -> None:
super(Resample, self).__init__()

self.orig_freq = orig_freq
Expand Down

0 comments on commit 0623b3f

Please sign in to comment.