Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions test/test_kaldi_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _run_kaldi(command, input_type, input_value):
key = 'foo'
process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
if input_type == 'ark':
kaldi_io.write_mat(process.stdin, input_value.numpy(), key=key)
kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
elif input_type == 'scp':
process.stdin.write(f'{key} {input_value}'.encode('utf8'))
else:
Expand All @@ -47,7 +47,7 @@ def _run_kaldi(command, input_type, input_value):
return torch.from_numpy(result.copy()) # copy supresses some torch warning


class TestFunctional:
class Kaldi(common_utils.TestBaseMixin):
@unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
Expand All @@ -58,11 +58,11 @@ def test_sliding_window_cmn(self):
'norm_vars': False,
}

tensor = torch.randn(40, 10)
tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'ark', tensor)
torch.testing.assert_allclose(result, kaldi_result)
torch.testing.assert_allclose(result.cpu(), kaldi_result.to(self.dtype))

@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
def test_fbank(self):
Expand Down Expand Up @@ -93,7 +93,11 @@ def test_fbank(self):

}
wave_file = common_utils.get_asset_path('kaldi_file.wav')
result = torchaudio.compliance.kaldi.fbank(torchaudio.load_wav(wave_file)[0], **kwargs)
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file)
torch.testing.assert_allclose(result, kaldi_result)
torch.testing.assert_allclose(result.cpu(), kaldi_result.to(dtype=self.dtype), rtol=1e-4, atol=1e-8)


common_utils.define_test_suites(globals(), [Kaldi])
61 changes: 38 additions & 23 deletions torchaudio/compliance/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]


def _get_epsilon(device, dtype):
    """Return the module-level ``EPSILON`` tensor on *device* with *dtype*.

    Centralizes the device/dtype transfer so numeric-floor comparisons
    (e.g. before taking logs) use an epsilon that matches the input tensor.
    """
    eps = EPSILON
    return eps.to(device=device, dtype=dtype)


def _next_power_of_2(x: int) -> int:
r"""Returns the smallest power of 2 that is greater than x
"""
Expand Down Expand Up @@ -60,7 +64,7 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg

if snip_edges:
if num_samples < window_size:
return torch.empty((0, 0))
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
else:
m = 1 + (num_samples - window_size) // window_shift
else:
Expand All @@ -83,24 +87,27 @@ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edg

def _feature_window_function(window_type: str,
window_size: int,
blackman_coeff: float) -> Tensor:
blackman_coeff: float,
device: torch.device,
dtype: int,
) -> Tensor:
r"""Returns a window function with the given type and size
"""
if window_type == HANNING:
return torch.hann_window(window_size, periodic=False)
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
elif window_type == HAMMING:
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46)
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
elif window_type == POVEY:
# like hanning but goes to zero at edges
return torch.hann_window(window_size, periodic=False).pow(0.85)
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
elif window_type == RECTANGULAR:
return torch.ones(window_size)
return torch.ones(window_size, device=device, dtype=dtype)
elif window_type == BLACKMAN:
a = 2 * math.pi / (window_size - 1)
window_function = torch.arange(window_size)
window_function = torch.arange(window_size, device=device, dtype=dtype)
# can't use torch.blackman_window as they use different coefficients
return (blackman_coeff - 0.5 * torch.cos(a * window_function) +
(0.5 - blackman_coeff) * torch.cos(2 * a * window_function))
(0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype)
else:
raise Exception('Invalid window type ' + window_type)

Expand All @@ -110,12 +117,12 @@ def _get_log_energy(strided_input: Tensor,
energy_floor: float) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)
"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
if energy_floor == 0.0:
return log_energy
else:
return torch.max(log_energy,
torch.tensor(math.log(energy_floor)))
return torch.max(
log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))


def _get_waveform_and_window_properties(waveform: Tensor,
Expand Down Expand Up @@ -160,12 +167,15 @@ def _get_window(waveform: Tensor,
Returns:
(Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)

# size (m, window_size)
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)

if dither != 0.0:
# Returns a random number strictly between 0 and 1
x = torch.max(EPSILON, torch.rand(strided_input.shape))
x = torch.max(epsilon, torch.rand(strided_input.shape, device=device, dtype=dtype))
rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x)
strided_input = strided_input + rand_gauss * dither

Expand All @@ -177,7 +187,7 @@ def _get_window(waveform: Tensor,
if raw_energy:
# Compute the log energy of each row/frame before applying preemphasis and
# window function
signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor) # size (m)
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)

if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
Expand All @@ -187,7 +197,7 @@ def _get_window(waveform: Tensor,

# Apply window_function to each row/frame
window_function = _feature_window_function(
window_type, window_size, blackman_coeff).unsqueeze(0) # size (1, window_size)
window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)

# Pad columns with zero until we reach size (m, padded_window_size)
Expand All @@ -198,7 +208,7 @@ def _get_window(waveform: Tensor,

# Compute energy after window function (not the raw one)
if not raw_energy:
signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor) # size (m)
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)

return strided_input, signal_log_energy

Expand Down Expand Up @@ -541,12 +551,14 @@ def fbank(waveform: Tensor,
Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype

waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)

if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0)
return torch.empty(0, device=device, dtype=dtype)

# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
strided_input, signal_log_energy = _get_window(
Expand All @@ -563,6 +575,7 @@ def fbank(waveform: Tensor,
# size (num_mel_bins, padded_window_size // 2)
mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.to(device=device, dtype=dtype)

# pad right column with zeros and add dimension, size (1, num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0).unsqueeze(0)
Expand All @@ -571,7 +584,7 @@ def fbank(waveform: Tensor,
mel_energies = (power_spectrum * mel_energies).sum(dim=2)
if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, EPSILON).log()
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

# if use_energy then add it as the last column for htk_compat == true else first column
if use_energy:
Expand Down Expand Up @@ -737,7 +750,9 @@ def _get_LR_indices_and_weights(orig_freq: float,
output_samples_in_unit: int,
window_width: float,
lowpass_cutoff: float,
lowpass_filter_width: int) -> Tuple[Tensor, Tensor]:
lowpass_filter_width: int,
device: torch.device,
dtype: int) -> Tuple[Tensor, Tensor]:
r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
resampling as well as the indices in which they are valid. LinearResample (LR) means
that the output signal is at linearly spaced intervals (i.e the output signal has a
Expand Down Expand Up @@ -785,7 +800,7 @@ def _get_LR_indices_and_weights(orig_freq: float,
which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
"""
assert lowpass_cutoff < min(orig_freq, new_freq) / 2
output_t = torch.arange(0., output_samples_in_unit) / new_freq
output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq
min_t = output_t - window_width
max_t = output_t + window_width

Expand All @@ -795,7 +810,7 @@ def _get_LR_indices_and_weights(orig_freq: float,

max_weight_width = num_indices.max()
# create a group of weights of size (output_samples_in_unit, max_weight_width)
j = torch.arange(max_weight_width).unsqueeze(0)
j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0)
input_index = min_input_index.unsqueeze(1) + j
delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)

Expand Down Expand Up @@ -905,9 +920,9 @@ def resample_waveform(waveform: Tensor,
output_samples_in_unit = int(new_freq) // base_freq

window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit,
window_width, lowpass_cutoff, lowpass_filter_width)
weights = weights.to(device=device, dtype=dtype) # TODO Create weights on device directly
first_indices, weights = _get_LR_indices_and_weights(
orig_freq, new_freq, output_samples_in_unit,
window_width, lowpass_cutoff, lowpass_filter_width, device, dtype)

assert first_indices.dim() == 1
# TODO figure a better way to do this. conv1d reaches every element i*stride + padding
Expand Down