From 1f17b78c9c652494cc3aaa516ba56c0e0aab2760 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 7 May 2020 21:19:25 +0000 Subject: [PATCH 1/4] Make fbank support cuda --- test/test_kaldi_compatibility.py | 16 +++++---- torchaudio/compliance/kaldi.py | 61 ++++++++++++++++++++------------ 2 files changed, 48 insertions(+), 29 deletions(-) diff --git a/test/test_kaldi_compatibility.py b/test/test_kaldi_compatibility.py index 87b5ca029d..2c29141317 100644 --- a/test/test_kaldi_compatibility.py +++ b/test/test_kaldi_compatibility.py @@ -37,7 +37,7 @@ def _run_kaldi(command, input_type, input_value): key = 'foo' process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) if input_type == 'ark': - kaldi_io.write_mat(process.stdin, input_value.numpy(), key=key) + kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key) elif input_type == 'scp': process.stdin.write(f'{key} {input_value}'.encode('utf8')) else: @@ -47,7 +47,7 @@ def _run_kaldi(command, input_type, input_value): return torch.from_numpy(result.copy()) # copy supresses some torch warning -class TestFunctional: +class Kaldi(common_utils.TestBaseMixin): @unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available') def test_sliding_window_cmn(self): """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" @@ -58,11 +58,11 @@ def test_sliding_window_cmn(self): 'norm_vars': False, } - tensor = torch.randn(40, 10) + tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device) result = F.sliding_window_cmn(tensor, **kwargs) command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-'] kaldi_result = _run_kaldi(command, 'ark', tensor) - torch.testing.assert_allclose(result, kaldi_result) + torch.testing.assert_allclose(result.cpu(), kaldi_result.to(self.dtype)) @unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available') def test_fbank(self): @@ -93,7 +93,11 @@ def test_fbank(self): } wave_file = common_utils.get_asset_path('kaldi_file.wav') - result = torchaudio.compliance.kaldi.fbank(torchaudio.load_wav(wave_file)[0], **kwargs) + waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device) + result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] kaldi_result = _run_kaldi(command, 'scp', wave_file) - torch.testing.assert_allclose(result, kaldi_result) + torch.testing.assert_allclose(result.cpu(), kaldi_result.to(dtype=self.dtype)) + + +common_utils.define_test_suites(globals(), [Kaldi]) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 1790e93cf6..e73300455a 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -33,6 +33,10 @@ WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] +def _get_epsilon(device, dtype): + return EPSILON.to(device=device, dtype=dtype) + + def _next_power_of_2(x: int) -> int: r"""Returns the smallest power of 2 that is greater than x """ @@ -60,7 +64,7 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg if snip_edges: if num_samples < window_size: - return torch.empty((0, 0)) + return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device) else: m = 1 + (num_samples - window_size) // window_shift else: @@ -83,24 +87,27 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg def _feature_window_function(window_type: str, window_size: int, - blackman_coeff: float) -> Tensor: + blackman_coeff: float, + device: torch.device, + dtype: torch.dtype +) -> Tensor: r"""Returns a window function with the given type and size """ if window_type == HANNING: - return torch.hann_window(window_size, periodic=False) + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype) elif window_type == HAMMING: - return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46) + return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype) elif window_type == POVEY: # like hanning but goes to zero at edges - return torch.hann_window(window_size, periodic=False).pow(0.85) + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85) elif window_type == RECTANGULAR: - return torch.ones(window_size) + return torch.ones(window_size, device=device, dtype=dtype) elif window_type == BLACKMAN: a = 2 * math.pi / (window_size - 1) - window_function = torch.arange(window_size) + window_function = torch.arange(window_size, device=device, dtype=dtype) # can't use torch.blackman_window as they use different coefficients return (blackman_coeff - 0.5 * torch.cos(a * window_function) + - (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)) + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype) else: raise Exception('Invalid window type ' + window_type) @@ -110,12 +117,12 @@ def _get_log_energy(strided_input: Tensor, energy_floor: float) -> Tensor: r"""Returns the log energy of size (m) for a strided_input (m,*) """ + device, dtype = strided_input.device, strided_input.dtype log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) if energy_floor == 0.0: return log_energy - else: - return torch.max(log_energy, - torch.tensor(math.log(energy_floor))) + return torch.max( + log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype)) def _get_waveform_and_window_properties(waveform: Tensor, @@ -160,12 +167,15 @@ def _get_window(waveform: Tensor, Returns: (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + # size (m, window_size) strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) if dither != 0.0: # Returns a random number strictly between 0 and 1 - x = torch.max(EPSILON, torch.rand(strided_input.shape)) + x = torch.max(epsilon, torch.rand(strided_input.shape, device=device, dtype=dtype)) rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x) strided_input = strided_input + rand_gauss * dither @@ -177,7 +187,7 @@ def _get_window(waveform: Tensor, if raw_energy: # Compute the log energy of each row/frame before applying preemphasis and # window function - signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor) # size (m) + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) if preemphasis_coefficient != 0.0: # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j @@ -187,7 +197,7 @@ def _get_window(waveform: Tensor, # Apply window_function to each row/frame window_function = _feature_window_function( - window_type, window_size, blackman_coeff).unsqueeze(0) # size (1, window_size) + window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0) # size (1, window_size) strided_input = strided_input * window_function # size (m, window_size) # Pad columns with zero until we reach size (m, padded_window_size) @@ -198,7 +208,7 @@ def _get_window(waveform: Tensor, # Compute energy after window function (not the raw one) if not raw_energy: - signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor) # size (m) + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) return strided_input, signal_log_energy @@ -541,12 +551,14 @@ def fbank(waveform: Tensor, Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``) where m is calculated in _get_strided """ + device, dtype = waveform.device, waveform.dtype + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient) if len(waveform) < min_duration * sample_frequency: # signal is too short - return torch.empty(0) + return torch.empty(0, device=device, dtype=dtype) # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) strided_input, signal_log_energy = _get_window( @@ -563,6 +575,7 @@ def fbank(waveform: Tensor, # size (num_mel_bins, padded_window_size // 2) mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp) + mel_energies = mel_energies.to(device=device, dtype=dtype) # pad right column with zeros and add dimension, size (1, num_mel_bins, padded_window_size // 2 + 1) mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0).unsqueeze(0) @@ -571,7 +584,7 @@ def fbank(waveform: Tensor, mel_energies = (power_spectrum * mel_energies).sum(dim=2) if use_log_fbank: # avoid log of zero (which should be prevented anyway by dithering) - mel_energies = torch.max(mel_energies, EPSILON).log() + mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() # if use_energy then add it as the last column for htk_compat == true else first column if use_energy: @@ -737,7 +750,9 @@ def _get_LR_indices_and_weights(orig_freq: float, output_samples_in_unit: int, window_width: float, lowpass_cutoff: float, - lowpass_filter_width: int) -> Tuple[Tensor, Tensor]: + lowpass_filter_width: int, + device: torch.device, + dtype: torch.dtype) -> Tuple[Tensor, Tensor]: r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for resampling as well as the indices in which they are valid. LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e the output signal has a @@ -785,7 +800,7 @@ def _get_LR_indices_and_weights(orig_freq: float, which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)). """ assert lowpass_cutoff < min(orig_freq, new_freq) / 2 - output_t = torch.arange(0., output_samples_in_unit) / new_freq + output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq min_t = output_t - window_width max_t = output_t + window_width @@ -795,7 +810,7 @@ def _get_LR_indices_and_weights(orig_freq: float, max_weight_width = num_indices.max() # create a group of weights of size (output_samples_in_unit, max_weight_width) - j = torch.arange(max_weight_width).unsqueeze(0) + j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0) input_index = min_input_index.unsqueeze(1) + j delta_t = (input_index / orig_freq) - output_t.unsqueeze(1) @@ -905,9 +920,9 @@ def resample_waveform(waveform: Tensor, output_samples_in_unit = int(new_freq) // base_freq window_width = lowpass_filter_width / (2.0 * lowpass_cutoff) - first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, - window_width, lowpass_cutoff, lowpass_filter_width) - weights = weights.to(device=device, dtype=dtype) # TODO Create weights on device directly + first_indices, weights = _get_LR_indices_and_weights( + orig_freq, new_freq, output_samples_in_unit, + window_width, lowpass_cutoff, lowpass_filter_width, device, dtype) assert first_indices.dim() == 1 # TODO figure a better way to do this. conv1d reaches every element i*stride + padding From df7daa33db1031e8827dc01888ad52b33696310a Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 8 May 2020 02:56:22 +0000 Subject: [PATCH 2/4] Reduce rtol for kaldi --- test/test_kaldi_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_kaldi_compatibility.py b/test/test_kaldi_compatibility.py index 2c29141317..580900cd1d 100644 --- a/test/test_kaldi_compatibility.py +++ b/test/test_kaldi_compatibility.py @@ -97,7 +97,7 @@ def test_fbank(self): result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] kaldi_result = _run_kaldi(command, 'scp', wave_file) - torch.testing.assert_allclose(result.cpu(), kaldi_result.to(dtype=self.dtype)) + torch.testing.assert_allclose(result.cpu(), kaldi_result.to(dtype=self.dtype), rtol=1e-4, atol=1e-8) common_utils.define_test_suites(globals(), [Kaldi]) From f6443fc0ea952f577165987864af6327aba3f335 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Sat, 9 May 2020 20:03:31 +0000 Subject: [PATCH 3/4] fix test --- torchaudio/compliance/kaldi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index e73300455a..c9187e94d1 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -89,7 +89,7 @@ def _feature_window_function(window_type: str, window_size: int, blackman_coeff: float, device: torch.device, - dtype: torch.dtype + dtype: int ) -> Tensor: r"""Returns a window function with the given type and size """ @@ -752,7 +752,7 @@ def _get_LR_indices_and_weights(orig_freq: float, lowpass_cutoff: float, lowpass_filter_width: int, device: torch.device, - dtype: torch.dtype) -> Tuple[Tensor, Tensor]: + dtype: int) -> Tuple[Tensor, Tensor]: r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for resampling as well as the indices in which they are valid. LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e the output signal has a From c2cd73ec8d938503f561565f2628201452990029 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 12 May 2020 03:14:57 +0000 Subject: [PATCH 4/4] fix flake8 --- torchaudio/compliance/kaldi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index c9187e94d1..8115c3a25f 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -89,8 +89,8 @@ def _feature_window_function(window_type: str, window_size: int, blackman_coeff: float, device: torch.device, - dtype: int -) -> Tensor: + dtype: int, + ) -> Tensor: r"""Returns a window function with the given type and size """ if window_type == HANNING: @@ -552,7 +552,7 @@ def fbank(waveform: Tensor, where m is calculated in _get_strided """ device, dtype = waveform.device, waveform.dtype - + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)