Skip to content

Commit

Permalink
Add F.resample torchscript test (#1516)
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed May 20, 2021
1 parent a21b08e commit 7763ed8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,28 @@ def func(tensor):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)

def test_resample_sinc(self):
def func(tensor):
sr1, sr2 = 16000., 8000.
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation")

tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)

def test_resample_kaiser(self):
def func(tensor):
sr1, sr2 = 16000., 8000.
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window")

def func_beta(tensor):
sr1, sr2 = 16000., 8000.
beta = 6.
return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta)

tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, tensor)
self._assert_consistency(func_beta, tensor)

@parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_paseudo_complex):
def func(tensor):
Expand Down
12 changes: 6 additions & 6 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,9 +1328,6 @@ def _get_sinc_resample_kernel(
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd

if resampling_method == "kaiser_window" and beta is None:
beta = 14.769656459379492

assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
Expand Down Expand Up @@ -1373,9 +1370,12 @@ def _get_sinc_resample_kernel(
# at specific positions, not over a regular grid.
if resampling_method == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2
elif resampling_method == "kaiser_window":
beta = torch.tensor(beta, dtype=float)
window = torch.i0(beta * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta)
else:
# kaiser_window
if beta is None:
beta = 14.769656459379492
beta_tensor = torch.tensor(float(beta))
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
kernel.mul_(window)
Expand Down

0 comments on commit 7763ed8

Please sign in to comment.