From 66dfeea633bd613df98cc36521991acd4ac122d7 Mon Sep 17 00:00:00 2001 From: Alexandre Defossez Date: Mon, 14 Dec 2020 11:58:34 +0100 Subject: [PATCH 1/8] plop --- torchaudio/compliance/kaldi.py | 283 +++++++++------------------------ 1 file changed, 79 insertions(+), 204 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index afabebeb6a..41db1aeae7 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -3,6 +3,7 @@ import math import torch from torch import Tensor +from torch.nn import functional as F import torchaudio import torchaudio._internal.fft @@ -752,147 +753,59 @@ def mfcc( return feature -def _get_LR_indices_and_weights(orig_freq: float, - new_freq: float, - output_samples_in_unit: int, - window_width: float, - lowpass_cutoff: float, - lowpass_filter_width: int, - device: torch.device, - dtype: int) -> Tuple[Tensor, Tensor]: - r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for - resampling as well as the indices in which they are valid. LinearResample (LR) means - that the output signal is at linearly spaced intervals (i.e the output signal has a - frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample - the signal. - - The reason why the same filter is not used for multiple convolutions is because the - sinc function could sampled at different points in time. For example, suppose - a signal is sampled at the timestamps (seconds) - 0 16 32 - and we want it to be sampled at the timestamps (seconds) - 0 5 10 15 20 25 30 35 - at the timestamp of 16, the delta timestamps are - 16 11 6 1 4 9 14 19 - at the timestamp of 32, the delta timestamps are - 32 27 22 17 12 8 2 3 - - As we can see from deltas, the sinc function is sampled at different points of time - assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ....] - for 16 vs [...., 2, 3, ....] for 32) - - Example, one case is when the ``orig_freq`` and ``new_freq`` are multiples of each other then - there needs to be one filter. - - A windowed filter function (i.e. Hanning * sinc) because the ideal case of sinc function - has infinite support (non-zero for all values) so instead it is truncated and multiplied by - a window function which gives it less-than-perfect rolloff [1]. - - [1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm - - Args: - orig_freq (float): The original frequency of the signal - new_freq (float): The desired frequency - output_samples_in_unit (int): The number of output samples in the smallest repeating unit: - num_samp_out = new_freq / Gcd(orig_freq, new_freq) - window_width (float): The width of the window which is nonzero - lowpass_cutoff (float): The filter cutoff in Hz. The filter cutoff needs to be less - than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. - lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper but less - efficient. We suggest around 4 to 10 for normal use - - Returns: - (Tensor, Tensor): A tuple of ``min_input_index`` (which is the minimum indices - where the window is valid, size (``output_samples_in_unit``)) and ``weights`` (which is the weights - which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)). - """ - assert lowpass_cutoff < min(orig_freq, new_freq) / 2 - output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq - min_t = output_t - window_width - max_t = output_t + window_width - - min_input_index = torch.ceil(min_t * orig_freq) # size (output_samples_in_unit) - max_input_index = torch.floor(max_t * orig_freq) # size (output_samples_in_unit) - num_indices = max_input_index - min_input_index + 1 # size (output_samples_in_unit) - - max_weight_width = num_indices.max() - # create a group of weights of size (output_samples_in_unit, max_weight_width) - j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0) - input_index = min_input_index.unsqueeze(1) + j - delta_t = (input_index / orig_freq) - output_t.unsqueeze(1) - - weights = torch.zeros_like(delta_t) - inside_window_indices = delta_t.abs().lt(window_width) - # raised-cosine (Hanning) window with width `window_width` - weights[inside_window_indices] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff / - lowpass_filter_width * delta_t[inside_window_indices])) - - t_eq_zero_indices = delta_t.eq(0.0) - t_not_eq_zero_indices = ~t_eq_zero_indices - # sinc filter function - weights[t_not_eq_zero_indices] *= torch.sin( - 2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (math.pi * delta_t[t_not_eq_zero_indices]) - # limit of the function at t = 0 - weights[t_eq_zero_indices] *= 2 * lowpass_cutoff - - weights /= orig_freq # size (output_samples_in_unit, max_weight_width) - return min_input_index, weights - - -def _lcm(a: int, b: int) -> int: - return abs(a * b) // math.gcd(a, b) - - -def _get_num_LR_output_samples(input_num_samp: int, - samp_rate_in: float, - samp_rate_out: float) -> int: - r"""Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that - the output signal is at linearly spaced intervals (i.e the output signal has a - frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample - the signal. - - Args: - input_num_samp (int): The number of samples in the input - samp_rate_in (float): The original frequency of the signal - samp_rate_out (float): The desired frequency - - Returns: - int: The number of output samples - """ - # For exact computation, we measure time in "ticks" of 1.0 / tick_freq, - # where tick_freq is the least common multiple of samp_rate_in and - # samp_rate_out. - samp_rate_in = int(samp_rate_in) - samp_rate_out = int(samp_rate_out) - - tick_freq = _lcm(samp_rate_in, samp_rate_out) - ticks_per_input_period = tick_freq // samp_rate_in - - # work out the number of ticks in the time interval - # [ 0, input_num_samp/samp_rate_in ). - interval_length_in_ticks = input_num_samp * ticks_per_input_period - if interval_length_in_ticks <= 0: - return 0 - ticks_per_output_period = tick_freq // samp_rate_out - # Get the last output-sample in the closed interval, i.e. replacing [ ) with - # [ ]. Note: integer division rounds down. See - # http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of - # the notation. - last_output_samp = interval_length_in_ticks // ticks_per_output_period - # We need the last output-sample in the open interval, so if it takes us to - # the end of the interval exactly, subtract one. - if last_output_samp * ticks_per_output_period == interval_length_in_ticks: - last_output_samp -= 1 - # First output-sample index is zero, so the number of output samples - # is the last output-sample plus one. - num_output_samp = last_output_samp + 1 - return num_output_samp +def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int, + rolloff: float, device: torch.device, dtype: torch.dtype): + kernels = [] + base_freq = min(orig_freq, new_freq) + # rolloff will perform antialiasing filtering by removing the highest frequencies. + # At first I thought I only needed this when downsampling, but when upsampling + # you will get edge artifacts without this, the edge is equivalent to zero padding, + # which will add high freq artifacts. + base_freq *= rolloff + + # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) + # using the sinc interpolation formula: + # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t)) + # We can then sample the function x(t) with a different sample rate: + # y[j] = x(j / new_freq) + # or, + # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) + + # We see here that y[j] is the convolution of x[i] with a specific filter, for which + # we take an FIR approximation, stopping when we see at least `zeros` zeros crossing. + # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq]. + # Indeed: + # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq)) + # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq)) + # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) + # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`. + # This will explain the F.conv1d after, with a stride of orig_freq. + width = math.ceil(lowpass_filter_width * orig_freq / base_freq) + # If old_sr is still big after GCD reduction, most filters will be very unbalanced, i.e., + # they will have a lot of almost zero values to the left or to the right... + # There is probably a way to evaluate those filters more efficiently, but this is kept for + # future work. + idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype) + + for i in range(new_freq): + t = (-i / new_freq + idx / orig_freq) * base_freq + t = t.clamp_(-lowpass_filter_width, lowpass_filter_width) + t *= math.pi + # we do not use torch.hann_window here as we need to evaluate the window + # at spectifics positions, not over a regular grid. + window = torch.cos(t / lowpass_filter_width / 2)**2 + sinc = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) + kernels.append(sinc(t) * window) + + scale = base_freq / orig_freq + return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width def resample_waveform(waveform: Tensor, orig_freq: float, new_freq: float, - lowpass_filter_width: int = 6) -> Tensor: + lowpass_filter_width: int = 6, + rolloff: float = 0.99) -> Tensor: r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e @@ -904,80 +817,42 @@ def resample_waveform(waveform: Tensor, Args: waveform (Tensor): The input signal of size (c, n) - orig_freq (float): The original frequency of the signal - new_freq (float): The desired frequency + orig_freq (int): The original frequency of the signal + new_freq (int): The desired frequency lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) + rolloff (float): use a lowpass filter that is `rolloff * new_freq / 2`, + to ensure sufficient margin due to the imperfection of the FIR filter used. + Lowering this value will reduce anti-aliasing, but will lose some of the high + frequency content. Returns: Tensor: The waveform at the new frequency - """ - device, dtype = waveform.device, waveform.dtype + .. caution:: + After dividing `orig_freq` and `new_freq` by their GCD, both should be small + for this implementation to be fast. + """ assert waveform.dim() == 2 + if int(orig_freq) != orig_freq or int(new_freq) != new_freq: + raise ValueError("orig_freq and new_freq should be integers") assert orig_freq > 0.0 and new_freq > 0.0 - min_freq = min(orig_freq, new_freq) - lowpass_cutoff = 0.99 * 0.5 * min_freq - - assert lowpass_cutoff * 2 <= min_freq - - base_freq = math.gcd(int(orig_freq), int(new_freq)) - input_samples_in_unit = int(orig_freq) // base_freq - output_samples_in_unit = int(new_freq) // base_freq - - window_width = lowpass_filter_width / (2.0 * lowpass_cutoff) - first_indices, weights = _get_LR_indices_and_weights( - orig_freq, new_freq, output_samples_in_unit, - window_width, lowpass_cutoff, lowpass_filter_width, device, dtype) - - assert first_indices.dim() == 1 - # TODO figure a better way to do this. conv1d reaches every element i*stride + padding - # all the weights have the same stride but have different padding. - # Current implementation takes the input and applies the various padding before - # doing a conv1d for that specific weight. - conv_stride = input_samples_in_unit - conv_transpose_stride = output_samples_in_unit - num_channels, wave_len = waveform.size() - window_size = weights.size(1) - tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq) - output = torch.zeros((num_channels, tot_output_samp), - device=device, dtype=dtype) - # eye size: (num_channels, num_channels, 1) - eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2) - for i in range(first_indices.size(0)): - wave_to_conv = waveform - first_index = int(first_indices[i].item()) - if first_index >= 0: - # trim the signal as the filter will not be applied before the first_index - wave_to_conv = wave_to_conv[..., first_index:] - - # pad the right of the signal to allow partial convolutions meaning compute - # values for partial windows (e.g. end of the window is outside the signal length) - max_unit_index = (tot_output_samp - 1) // output_samples_in_unit - end_index_of_last_window = max_unit_index * conv_stride + window_size - current_wave_len = wave_len - first_index - right_padding = max(0, end_index_of_last_window + 1 - current_wave_len) - - left_padding = max(0, -first_index) - if left_padding != 0 or right_padding != 0: - wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding)) - - conv_wave = torch.nn.functional.conv1d( - wave_to_conv.unsqueeze(0), weights[i].repeat(num_channels, 1, 1), - stride=conv_stride, groups=num_channels) - - # we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride] - dilated_conv_wave = torch.nn.functional.conv_transpose1d( - conv_wave, eye, stride=conv_transpose_stride).squeeze(0) - - # pad dilated_conv_wave so it reaches the output length if needed. - dialated_conv_wave_len = dilated_conv_wave.size(-1) - left_padding = i - right_padding = max(0, tot_output_samp - (left_padding + dialated_conv_wave_len)) - dilated_conv_wave = torch.nn.functional.pad( - dilated_conv_wave, (left_padding, right_padding))[..., :tot_output_samp] - - output += dilated_conv_wave - - return output + if orig_freq == new_freq: + return waveform + + orig_freq = int(orig_freq) + new_freq = int(new_freq) + gcd = math.gcd(orig_freq, new_freq) + orig_freq = orig_freq // gcd + new_freq = new_freq // gcd + + kernel, width = _get_sinc_resample_kernel( + orig_freq, new_freq, lowpass_filter_width, rolloff, waveform.device, waveform.dtype) + + num_wavs, length = waveform.shape + waveform = F.pad(waveform, (width, width + orig_freq)) + resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq) + resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) + target_length = int(new_freq * length / orig_freq) + return resampled[..., :target_length] From 124bfdaebc82f2dfb15cf3b3d9f2b12de95dc700 Mon Sep 17 00:00:00 2001 From: Alexandre Defossez Date: Mon, 14 Dec 2020 13:46:29 +0100 Subject: [PATCH 2/8] remove identity fix for compatibility --- torchaudio/compliance/kaldi.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 41db1aeae7..bbc6a90291 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -794,8 +794,9 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt # we do not use torch.hann_window here as we need to evaluate the window # at spectifics positions, not over a regular grid. window = torch.cos(t / lowpass_filter_width / 2)**2 - sinc = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) - kernels.append(sinc(t) * window) + kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) + kernel.mul_(window) + kernels.append(kernel) scale = base_freq / orig_freq return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width @@ -838,9 +839,6 @@ def resample_waveform(waveform: Tensor, raise ValueError("orig_freq and new_freq should be integers") assert orig_freq > 0.0 and new_freq > 0.0 - if orig_freq == new_freq: - return waveform - orig_freq = int(orig_freq) new_freq = int(new_freq) gcd = math.gcd(orig_freq, new_freq) From 1f3a5e9b5995ee18911c91f7343c2c320bbd1d74 Mon Sep 17 00:00:00 2001 From: Alexandre Defossez Date: Mon, 14 Dec 2020 13:50:10 +0100 Subject: [PATCH 3/8] same length as previous implementation --- torchaudio/compliance/kaldi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index bbc6a90291..c4cd1f3c40 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -852,5 +852,5 @@ def resample_waveform(waveform: Tensor, waveform = F.pad(waveform, (width, width + orig_freq)) resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq) resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) - target_length = int(new_freq * length / orig_freq) + target_length = int(math.ceil(new_freq * length / orig_freq)) return resampled[..., :target_length] From 4bf4b583882af22962abc2e65b64875623b8dca4 Mon Sep 17 00:00:00 2001 From: Alexandre Defossez Date: Mon, 14 Dec 2020 13:59:10 +0100 Subject: [PATCH 4/8] fix doc --- torchaudio/compliance/kaldi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index c4cd1f3c40..69048b28cf 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -772,16 +772,16 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) # We see here that y[j] is the convolution of x[i] with a specific filter, for which - # we take an FIR approximation, stopping when we see at least `zeros` zeros crossing. + # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing. # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq]. # Indeed: # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq)) - # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq)) - # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) + # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq)) + # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`. # This will explain the F.conv1d after, with a stride of orig_freq. width = math.ceil(lowpass_filter_width * orig_freq / base_freq) - # If old_sr is still big after GCD reduction, most filters will be very unbalanced, i.e., + # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e., # they will have a lot of almost zero values to the left or to the right... # There is probably a way to evaluate those filters more efficiently, but this is kept for # future work. From 4bc3ec3a7475c2bd91055fc8c60c46c096c11495 Mon Sep 17 00:00:00 2001 From: Alexandre Defossez Date: Tue, 15 Dec 2020 11:23:50 +0100 Subject: [PATCH 5/8] Updated doc and type annotations --- torchaudio/compliance/kaldi.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 69048b28cf..872cf0acd9 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -754,14 +754,17 @@ def mfcc( def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int, - rolloff: float, device: torch.device, dtype: torch.dtype): + lowpass_cutoff_ratio: float, + device: torch.device, dtype: torch.dtype): + assert lowpass_filter_width > 0 + assert 1 >= lowpass_cutoff_ratio > 0 kernels = [] base_freq = min(orig_freq, new_freq) # rolloff will perform antialiasing filtering by removing the highest frequencies. # At first I thought I only needed this when downsampling, but when upsampling # you will get edge artifacts without this, the edge is equivalent to zero padding, # which will add high freq artifacts. - base_freq *= rolloff + base_freq *= lowpass_cutoff_ratio # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) # using the sinc interpolation formula: @@ -803,10 +806,10 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt def resample_waveform(waveform: Tensor, - orig_freq: float, - new_freq: float, + orig_freq: int, + new_freq: int, lowpass_filter_width: int = 6, - rolloff: float = 0.99) -> Tensor: + lowpass_cutoff_ratio: float = 0.99) -> Tensor: r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e @@ -822,10 +825,10 @@ def resample_waveform(waveform: Tensor, new_freq (int): The desired frequency lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) - rolloff (float): use a lowpass filter that is `rolloff * new_freq / 2`, - to ensure sufficient margin due to the imperfection of the FIR filter used. - Lowering this value will reduce anti-aliasing, but will lose some of the high - frequency content. + lowpass_cutoff_ratio (float): Controls the cutoff frequency of the low pass filter. + Lower values will reduce aliasing, but also lose more of the high frequency content. + Typical values range from 0.9 to 1 (Default is ``0.9``). + Returns: Tensor: The waveform at the new frequency @@ -845,8 +848,9 @@ def resample_waveform(waveform: Tensor, orig_freq = orig_freq // gcd new_freq = new_freq // gcd - kernel, width = _get_sinc_resample_kernel( - orig_freq, new_freq, lowpass_filter_width, rolloff, waveform.device, waveform.dtype) + kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, + lowpass_filter_width, lowpass_cutoff_ratio, + waveform.device, waveform.dtype) num_wavs, length = waveform.shape waveform = F.pad(waveform, (width, width + orig_freq)) From 9da8579b9f53e1dd959f3a4d35b75a046d252c3d Mon Sep 17 00:00:00 2001 From: Alexandre Defossez Date: Tue, 15 Dec 2020 11:28:04 +0100 Subject: [PATCH 6/8] fixing docs --- torchaudio/compliance/kaldi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 872cf0acd9..5bd48b629c 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -760,9 +760,9 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt assert 1 >= lowpass_cutoff_ratio > 0 kernels = [] base_freq = min(orig_freq, new_freq) - # rolloff will perform antialiasing filtering by removing the highest frequencies. + # `lowpass_cutoff_ratio` will perform antialiasing filtering by removing the highest frequencies. # At first I thought I only needed this when downsampling, but when upsampling - # you will get edge artifacts without this, the edge is equivalent to zero padding, + # you will get edge artifacts without this, as the edge is equivalent to zero padding, # which will add high freq artifacts. base_freq *= lowpass_cutoff_ratio From e459eb49975cf7af67ad5d7b61b62cca9185d343 Mon Sep 17 00:00:00 2001 From: Alexandre Defossez Date: Wed, 16 Dec 2020 10:44:24 +0100 Subject: [PATCH 7/8] keeping only change in implementation --- torchaudio/compliance/kaldi.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 5bd48b629c..3f8e871c54 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -754,17 +754,15 @@ def mfcc( def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int, - lowpass_cutoff_ratio: float, device: torch.device, dtype: torch.dtype): assert lowpass_filter_width > 0 - assert 1 >= lowpass_cutoff_ratio > 0 kernels = [] base_freq = min(orig_freq, new_freq) - # `lowpass_cutoff_ratio` will perform antialiasing filtering by removing the highest frequencies. + # This will perform antialiasing filtering by removing the highest frequencies. # At first I thought I only needed this when downsampling, but when upsampling # you will get edge artifacts without this, as the edge is equivalent to zero padding, # which will add high freq artifacts. - base_freq *= lowpass_cutoff_ratio + base_freq *= 0.99 # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) # using the sinc interpolation formula: @@ -806,10 +804,9 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt def resample_waveform(waveform: Tensor, - orig_freq: int, - new_freq: int, - lowpass_filter_width: int = 6, - lowpass_cutoff_ratio: float = 0.99) -> Tensor: + orig_freq: float, + new_freq: float, + lowpass_filter_width: int = 6) -> Tensor: r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e @@ -821,14 +818,10 @@ def resample_waveform(waveform: Tensor, Args: waveform (Tensor): The input signal of size (c, n) - orig_freq (int): The original frequency of the signal - new_freq (int): The desired frequency + orig_freq (float): The original frequency of the signal + new_freq (float): The desired frequency lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) - lowpass_cutoff_ratio (float): Controls the cutoff frequency of the low pass filter. - Lower values will reduce aliasing, but also lose more of the high frequency content. - Typical values range from 0.9 to 1 (Default is ``0.9``). - Returns: Tensor: The waveform at the new frequency @@ -838,8 +831,6 @@ def resample_waveform(waveform: Tensor, for this implementation to be fast. """ assert waveform.dim() == 2 - if int(orig_freq) != orig_freq or int(new_freq) != new_freq: - raise ValueError("orig_freq and new_freq should be integers") assert orig_freq > 0.0 and new_freq > 0.0 orig_freq = int(orig_freq) @@ -848,8 +839,7 @@ def resample_waveform(waveform: Tensor, orig_freq = orig_freq // gcd new_freq = new_freq // gcd - kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, - lowpass_filter_width, lowpass_cutoff_ratio, + kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, waveform.device, waveform.dtype) num_wavs, length = waveform.shape From 54e3462ac05f14d9c8c3f72edb19f4e1beff03f7 Mon Sep 17 00:00:00 2001 From: Alexandre Defossez Date: Thu, 17 Dec 2020 17:48:55 +0100 Subject: [PATCH 8/8] staying as close to master --- torchaudio/compliance/kaldi.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 3f8e871c54..5b47ef3350 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -825,10 +825,6 @@ def resample_waveform(waveform: Tensor, Returns: Tensor: The waveform at the new frequency - - .. caution:: - After dividing `orig_freq` and `new_freq` by their GCD, both should be small - for this implementation to be fast. """ assert waveform.dim() == 2 assert orig_freq > 0.0 and new_freq > 0.0