Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing Kaldi Spectrogram #119

Merged
merged 22 commits into from
Jun 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added test/assets/kaldi_file.wav
Binary file not shown.
80 changes: 80 additions & 0 deletions test/compliance/generate_test_stft_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import random

# Path to the compute-spectrogram-feats executable.
EXE_PATH = '/scratch/jamarshon/kaldi/src/featbin/compute-spectrogram-feats'

# Path to the scp file. An example of its contents would be "my_id /scratch/jamarshon/audio/test/assets/kaldi_file.wav"
# where the space separates an id from a wav file.
SCP_PATH = 'scp:/scratch/jamarshon/downloads/a.scp'
# The directory to which the stft features will be written to.
OUTPUT_DIR = 'ark:/scratch/jamarshon/audio/test/assets/kaldi/'

# The number of samples inside the input wave file read from `SCP_PATH`
WAV_LEN = 20

# How many output files should be generated.
NUMBER_OF_OUTPUTS = 100

WINDOWS = ['hamming', 'hanning', 'povey', 'rectangular', 'blackman']


def generate_rand_boolean():
# Generates a random boolean ('true', 'false')
return 'true' if random.randint(0, 1) else 'false'


def generate_rand_window_type():
# Generates a random window type
return WINDOWS[random.randint(0, len(WINDOWS) - 1)]


def run():
for i in range(NUMBER_OF_OUTPUTS):
inputs = {
'blackman_coeff': '%.4f' % (random.random() * 5),
'dither': '0',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the tests were done with dither=0. and the energy_floor set to a random value.

'energy_floor': '%.4f' % (random.random() * 5),
'frame_length': '%.4f' % (float(random.randint(2, WAV_LEN - 1)) / 16000 * 1000),
'frame_shift': '%.4f' % (float(random.randint(1, WAV_LEN - 1)) / 16000 * 1000),
'preemphasis_coefficient': '%.2f' % random.random(),
'raw_energy': generate_rand_boolean(),
'remove_dc_offset': generate_rand_boolean(),
'round_to_power_of_two': generate_rand_boolean(),
'snip_edges': generate_rand_boolean(),
'subtract_mean': generate_rand_boolean(),
'window_type': generate_rand_window_type()
}

fn = 'spec-' + ('-'.join(list(inputs.values())))

arg = [
EXE_PATH,
'--blackman-coeff=' + inputs['blackman_coeff'],
'--dither=' + inputs['dither'],
'--energy-floor=' + inputs['energy_floor'],
'--frame-length=' + inputs['frame_length'],
'--frame-shift=' + inputs['frame_shift'],
'--preemphasis-coefficient=' + inputs['preemphasis_coefficient'],
'--raw-energy=' + inputs['raw_energy'],
'--remove-dc-offset=' + inputs['remove_dc_offset'],
'--round-to-power-of-two=' + inputs['round_to_power_of_two'],
'--sample-frequency=16000',
'--snip-edges=' + inputs['snip_edges'],
'--subtract-mean=' + inputs['subtract_mean'],
'--window-type=' + inputs['window_type'],
SCP_PATH,
OUTPUT_DIR + fn + '.ark'
]

print(fn)
print(inputs)
print(' '.join(arg))

try:
subprocess.call(arg)
except Exception:
pass


if __name__ == '__main__':
run()
140 changes: 140 additions & 0 deletions test/compliance/test_kaldi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import math
import os
import test.common_utils
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import unittest


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
# just a copy of ExtractWindow from feature-window.cc in python
def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
if snip_edges:
return frame * window_shift
else:
midpoint_of_frame = frame * window_shift + window_shift // 2
beginning_of_frame = midpoint_of_frame - window_size // 2
return beginning_of_frame

sample_offset = 0
num_samples = sample_offset + wave.size(0)
start_sample = first_sample_of_frame(f, frame_length, frame_shift, snip_edges)
end_sample = start_sample + frame_length

if snip_edges:
assert(start_sample >= sample_offset and end_sample <= num_samples)
else:
assert(sample_offset == 0 or start_sample >= sample_offset)

wave_start = start_sample - sample_offset
wave_end = wave_start + frame_length
if wave_start >= 0 and wave_end <= wave.size(0):
window[f, :] = wave[wave_start:(wave_start + frame_length)]
else:
wave_dim = wave.size(0)
for s in range(frame_length):
s_in_wave = s + wave_start
while s_in_wave < 0 or s_in_wave >= wave_dim:
if s_in_wave < 0:
s_in_wave = - s_in_wave - 1
else:
s_in_wave = 2 * wave_dim - 1 - s_in_wave
window[f, s] = wave[s_in_wave]


class Test_Kaldi(unittest.TestCase):
test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')

def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
waveform = torch.arange(num_samples).float()
output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges)

# from NumFrames in feature-window.cc
n = window_size
if snip_edges:
m = 0 if num_samples < window_size else 1 + (num_samples - window_size) // window_shift
else:
m = (num_samples + (window_shift // 2)) // window_shift

self.assertTrue(output.dim() == 2)
self.assertTrue(output.shape[0] == m and output.shape[1] == n)

window = torch.empty((m, window_size))

for r in range(m):
extract_window(window, waveform, r, window_size, window_shift, snip_edges)
self.assertTrue(torch.allclose(window, output))

def test_get_strided(self):
# generate any combination where 0 < window_size <= num_samples and
# 0 < window_shift.
for num_samples in range(1, 20):
for window_size in range(1, num_samples + 1):
for window_shift in range(1, 2 * num_samples + 1):
for snip_edges in range(0, 2):
self._test_get_strided_helper(num_samples, window_size, window_shift, snip_edges)

def _create_data_set(self):
# used to generate the dataset to test on. this is not used in testing (offline procedure)
test_dirpath = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')
sr = 16000
x = torch.arange(0, 20).float()
# between [-6,6]
y = torch.cos(2 * math.pi * x) + 3 * torch.sin(math.pi * x) + 2 * torch.cos(x)
# between [-2^30, 2^30]
y = (y / 6 * (1 << 30)).long()
# clear the last 16 bits because they aren't used anyways
y = ((y >> 16) << 16).float()
jamarshon marked this conversation as resolved.
Show resolved Hide resolved
torchaudio.save(test_filepath, y, sr)
sound, sample_rate = torchaudio.load(test_filepath, normalization=False)
print(y >> 16)
self.assertTrue(sample_rate == sr)
self.assertTrue(torch.allclose(y, sound))

def test_spectrogram(self):
sound, sample_rate = torchaudio.load_wav(self.test_filepath)
kaldi_output_dir = os.path.join(self.test_dirpath, 'assets', 'kaldi')
files = list(filter(lambda x: x.startswith('spec'), os.listdir(kaldi_output_dir)))
print('Results:', len(files))

for f in files:
print(f)
kaldi_output_path = os.path.join(kaldi_output_dir, f)
kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(kaldi_output_path)}

assert len(kaldi_output_dict) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
kaldi_output = kaldi_output_dict['my_id']

args = f.split('-')
args[-1] = os.path.splitext(args[-1])[0]
assert len(args) == 13, 'invalid test kaldi file name'

spec_output = kaldi.spectrogram(
sound,
blackman_coeff=float(args[1]),
dither=float(args[2]),
energy_floor=float(args[3]),
frame_length=float(args[4]),
frame_shift=float(args[5]),
preemphasis_coefficient=float(args[6]),
raw_energy=args[7] == 'true',
remove_dc_offset=args[8] == 'true',
round_to_power_of_two=args[9] == 'true',
snip_edges=args[10] == 'true',
subtract_mean=args[11] == 'true',
window_type=args[12])

error = spec_output - kaldi_output
mse = error.pow(2).sum() / spec_output.numel()
max_error = torch.max(error.abs())

print('mse:', mse.item(), 'max_error:', max_error.item())
self.assertTrue(spec_output.shape, kaldi_output.shape)
self.assertTrue(torch.allclose(spec_output, kaldi_output, atol=1e-3, rtol=0))


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TORCHAUDIODS(Dataset):

def __init__(self):
self.asset_dirpath = os.path.join(self.test_dirpath, "assets")
sound_files = list(filter(lambda x: '.wav' in x or '.mp3' in x, os.listdir(self.asset_dirpath)))
sound_files = ["sinewave.wav", "steam-train-whistle-daniel_simon.mp3"]
self.data = [os.path.join(self.asset_dirpath, fn) for fn in sound_files]
self.si, self.ei = torchaudio.info(os.path.join(self.asset_dirpath, "sinewave.wav"))
self.si.precision = 16
Expand Down
10 changes: 9 additions & 1 deletion torchaudio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import _torch_sox

from torchaudio import transforms, datasets, kaldi_io, sox_effects, legacy
from torchaudio import transforms, datasets, kaldi_io, sox_effects, legacy, compliance


def check_input(src):
Expand Down Expand Up @@ -92,6 +92,14 @@ def load(filepath,
return out, sample_rate


def load_wav(filepath, **kwargs):
""" Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting
the input right by 16 bits.
"""
kwargs['normalization'] = 1 << 16
return load(filepath, **kwargs)


def save(filepath, src, sample_rate, precision=16, channels_first=True):
"""Convenience function for `save_encinfo`.

Expand Down
1 change: 1 addition & 0 deletions torchaudio/compliance/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import kaldi
Loading