Merged
169 changes: 101 additions & 68 deletions research/deep_speech/data/dataset.py
@@ -17,14 +17,15 @@
from __future__ import division
from __future__ import print_function

import functools
import multiprocessing

import numpy as np
import scipy.io.wavfile as wavfile
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf

# pylint: disable=g-bad-import-order
from data.featurizer import AudioFeaturizer
from data.featurizer import TextFeaturizer
import data.featurizer as featurizer # pylint: disable=g-bad-import-order


class AudioConfig(object):
@@ -44,7 +45,7 @@ def __init__(self,
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
normalize: a boolean for whether apply normalization on the audio tensor.
normalize: a boolean for whether to apply normalization on the audio feature.
spect_type: a string for the type of spectrogram to be extracted.
"""

@@ -78,90 +79,122 @@ def __init__(self, audio_config, data_path, vocab_file_path):
self.vocab_file_path = vocab_file_path


def _normalize_audio_feature(audio_feature):
"""Perform mean and variance normalization on the spectrogram feature.

Args:
audio_feature: a numpy array for the spectrogram feature.

Returns:
a numpy array of the normalized spectrogram.
"""
mean = np.mean(audio_feature, axis=0)
var = np.var(audio_feature, axis=0)
normalized = (audio_feature - mean) / (np.sqrt(var) + 1e-6)

return normalized
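
A minimal check of the normalization above on a dummy spectrogram; the array here is illustrative:

import numpy as np

dummy = np.random.rand(5, 3).astype(np.float32)  # 5 frames x 3 frequency bins
normalized = _normalize_audio_feature(dummy)
print(normalized.mean(axis=0))  # per-bin means, close to 0
print(normalized.std(axis=0))   # per-bin standard deviations, close to 1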


def _preprocess_audio(
audio_file_path, audio_sample_rate, audio_featurizer, normalize):
"""Load the audio file in memory and compute spectrogram feature."""
tf.logging.info(
"Extracting spectrogram feature for {}".format(audio_file_path))
sample_rate, data = wavfile.read(audio_file_path)
assert sample_rate == audio_sample_rate
if data.dtype not in [np.float32, np.float64]:
data = data.astype(np.float32) / np.iinfo(data.dtype).max
feature = featurizer.compute_spectrogram_feature(
data, audio_featurizer.frame_length, audio_featurizer.frame_step,
audio_featurizer.fft_length)
if normalize:
feature = _normalize_audio_feature(feature)
return feature
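
The integer-to-float scaling above maps PCM samples into roughly [-1.0, 1.0]; a small check:

import numpy as np

pcm = np.array([0, 16384, -32768], dtype=np.int16)
scaled = pcm.astype(np.float32) / np.iinfo(pcm.dtype).max  # max is 32767
print(scaled)  # [0.0, ~0.5, ~-1.0] -- roughly within [-1, 1]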


def _preprocess_transcript(transcript, token_to_index):
"""Process transcript as label features."""
return featurizer.compute_label_feature(transcript, token_to_index)
Review comment (Member):
Usually we should avoid a one-line wrapper function with the same parameters.
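
Inlined per this comment, the call site in _preprocess_data could hand the featurizer function to the pool directly; a sketch, assuming compute_label_feature keeps the signature shown in featurizer.py below:

labels = pool.map(
    functools.partial(
        featurizer.compute_label_feature, token_to_idx=token_to_index),
    [line[2] for line in lines])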



def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
"""Generate a list of waveform, transcript pair.

Each dataset file contains three columns: "wav_filename", "wav_filesize",
and "transcript". This function parses the csv file and stores each example
in increasing order of audio length (indicated by wav_filesize).
As the waveforms are ordered in increasing length, audio samples in a
mini-batch have similar length.

Args:
dataset_config: an instance of DatasetConfig.
audio_featurizer: an instance of AudioFeaturizer.
token_to_index: the mapping from character to its index.

Returns:
features and labels array processed from the audio/text input.
"""

file_path = dataset_config.data_path
sample_rate = dataset_config.audio_config.sample_rate
normalize = dataset_config.audio_config.normalize

with tf.gfile.Open(file_path, "r") as f:
lines = f.read().splitlines()
lines = [line.split("\t") for line in lines]
# Skip the csv header.
lines = lines[1:]
# Sort input data by the length of waveform.
lines.sort(key=lambda item: int(item[1]))

# Use multiprocessing for feature/label extraction
num_cores = multiprocessing.cpu_count()
pool = multiprocessing.Pool(processes=num_cores)
Review comment (Contributor):
One trick I learned recently is that contextlib can let you use a context manager in Python 2 & 3.

with contextlib.closing(multiprocessing.Pool(processes=num_cores)) as pool:
  pool.map(...)
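
Applied to the code below, the suggestion would read roughly as follows; a sketch, not the committed version:

import contextlib

with contextlib.closing(multiprocessing.Pool(processes=num_cores)) as pool:
  features = pool.map(
      functools.partial(
          _preprocess_audio, audio_sample_rate=sample_rate,
          audio_featurizer=audio_featurizer, normalize=normalize),
      [line[0] for line in lines])
  labels = pool.map(
      functools.partial(
          _preprocess_transcript, token_to_index=token_to_index),
      [line[2] for line in lines])
# closing() calls pool.close() on exit, so no manual terminate() is needed.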


features = pool.map(
functools.partial(
_preprocess_audio, audio_sample_rate=sample_rate,
audio_featurizer=audio_featurizer, normalize=normalize),
[line[0] for line in lines])
labels = pool.map(
functools.partial(
_preprocess_transcript, token_to_index=token_to_index),
[line[2] for line in lines])

pool.terminate()
return features, labels
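
For reference, a tiny self-contained sketch of the manifest parsing and length sort above; the file contents here are made up:

# Tab-separated manifest: a header row, then one example per line.
lines = [
    "wav_filename\twav_filesize\ttranscript",
    "b.wav\t64000\thello world",
    "a.wav\t32000\thi",
]
rows = [line.split("\t") for line in lines][1:]  # skip the csv header
rows.sort(key=lambda item: int(item[1]))  # sort by waveform length
print([row[0] for row in rows])  # ['a.wav', 'b.wav']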


class DeepSpeechDataset(object):
"""Dataset class for training/evaluation of DeepSpeech model."""

def __init__(self, dataset_config):
"""Initialize the class.

Each dataset file contains three columns: "wav_filename", "wav_filesize",
and "transcript". This function parses the csv file and stores each example
by the increasing order of audio length (indicated by wav_filesize).
"""Initialize the DeepSpeechDataset class.

Args:
dataset_config: DatasetConfig object.
"""
self.config = dataset_config
# Instantiate audio feature extractor.
self.audio_featurizer = AudioFeaturizer(
self.audio_featurizer = featurizer.AudioFeaturizer(
sample_rate=self.config.audio_config.sample_rate,
frame_length=self.config.audio_config.frame_length,
frame_step=self.config.audio_config.frame_step,
fft_length=self.config.audio_config.fft_length,
spect_type=self.config.audio_config.spect_type)
fft_length=self.config.audio_config.fft_length)
# Instantiate text feature extractor.
self.text_featurizer = TextFeaturizer(
self.text_featurizer = featurizer.TextFeaturizer(
vocab_file=self.config.vocab_file_path)

self.speech_labels = self.text_featurizer.speech_labels
self.features, self.labels = self._preprocess_data(self.config.data_path)
self.features, self.labels = _preprocess_data(
self.config,
self.audio_featurizer,
self.text_featurizer.token_to_idx
)

self.num_feature_bins = (
self.features[0].shape[1] if len(self.features) else None)

def _preprocess_data(self, file_path):
"""Generate a list of waveform, transcript pair.

Note that the waveforms are ordered in increasing length, so that audio
samples in a mini-batch have similar length.

Args:
file_path: a string specifying the csv file path for a data set.

Returns:
features and labels array processed from the audio/text input.
"""

with tf.gfile.Open(file_path, "r") as f:
lines = f.read().splitlines()
lines = [line.split("\t") for line in lines]
# Skip the csv header.
lines = lines[1:]
# Sort input data by the length of waveform.
lines.sort(key=lambda item: int(item[1]))
features = [self._preprocess_audio(line[0]) for line in lines]
labels = [self._preprocess_transcript(line[2]) for line in lines]
return features, labels

def _normalize_audio_tensor(self, audio_tensor):
"""Perform mean and variance normalization on the spectrogram tensor.

Args:
audio_tensor: a tensor for the spectrogram feature.

Returns:
a tensor for the normalized spectrogram.
"""
mean, var = tf.nn.moments(audio_tensor, axes=[0])
normalized = (audio_tensor - mean) / (tf.sqrt(var) + 1e-6)
return normalized

def _preprocess_audio(self, audio_file_path):
"""Load the audio file in memory."""
tf.logging.info(
"Extracting spectrogram feature for {}".format(audio_file_path))
sample_rate, data = wavfile.read(audio_file_path)
assert sample_rate == self.config.audio_config.sample_rate
if data.dtype not in [np.float32, np.float64]:
data = data.astype(np.float32) / np.iinfo(data.dtype).max
feature = self.audio_featurizer.featurize(data)
if self.config.audio_config.normalize:
feature = self._normalize_audio_tensor(feature)
return tf.Session().run(
feature) # return a numpy array rather than a tensor

def _preprocess_transcript(self, transcript):
return self.text_featurizer.featurize(transcript)
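
Putting the pieces together, a hypothetical end-to-end construction; the manifest and vocab paths are made up, and AudioConfig is assumed to accept the arguments its docstring lists:

audio_conf = AudioConfig(
    sample_rate=16000, frame_length=25, frame_step=10, normalize=True)
data_conf = DatasetConfig(
    audio_config=audio_conf,
    data_path="train.csv",        # hypothetical tab-separated manifest
    vocab_file_path="vocab.txt")  # hypothetical vocabulary file
dataset = DeepSpeechDataset(data_conf)
print(dataset.num_feature_bins)   # e.g. 257 for a 512-point FFT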


def input_fn(batch_size, deep_speech_dataset, repeat=1):
"""Input function for model training and evaluation.
72 changes: 20 additions & 52 deletions research/deep_speech/data/featurizer.py
@@ -18,9 +18,21 @@
from __future__ import print_function

import codecs
import functools
import numpy as np
import tensorflow as tf
from scipy import signal


def compute_spectrogram_feature(waveform, frame_length, frame_step, fft_length):
"""Compute the spectrograms for the input waveform."""
_, _, stft = signal.stft(
waveform,
nperseg=frame_length,
noverlap=frame_step,
nfft=fft_length)

# Perform transpose to set its shape as [time_steps, feature_num_bins]
spectrogram = np.transpose(np.absolute(stft), (1, 0))
return spectrogram
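
A quick sanity check of compute_spectrogram_feature on a synthetic tone; a sketch, noting that scipy.signal.stft interprets noverlap as the number of overlapping samples between segments:

import numpy as np

sample_rate = 16000
t = np.arange(sample_rate) / sample_rate  # one second of samples
waveform = np.sin(2 * np.pi * 440.0 * t)  # a 440 Hz tone

# 400, 160, and 512 samples correspond to 25 ms frames, 10 ms steps, and
# the next power-of-two FFT size at 16 kHz.
spectrogram = compute_spectrogram_feature(
    waveform, frame_length=400, frame_step=160, fft_length=512)
print(spectrogram.shape)  # (time_steps, 257), since 512 // 2 + 1 == 257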


class AudioFeaturizer(object):
@@ -30,64 +42,26 @@ def __init__(self,
sample_rate=16000,
frame_length=25,
frame_step=10,
fft_length=None,
window_fn=functools.partial(
tf.contrib.signal.hann_window, periodic=True),
spect_type="linear"):
fft_length=None):
"""Initialize the audio featurizer class according to the configs.

Args:
sample_rate: an integer specifying the sample rate of the input waveform.
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
window_fn: windowing function.
spect_type: a string for the type of spectrogram to be extracted.
Currently only support 'linear', otherwise will raise a value error.

Raises:
ValueError: In case of invalid arguments for `spect_type`.
"""
if spect_type != "linear":
raise ValueError("Unsupported spectrogram type: %s" % spect_type)
self.window_fn = window_fn
self.frame_length = int(sample_rate * frame_length / 1e3)
self.frame_step = int(sample_rate * frame_step / 1e3)
self.fft_length = fft_length if fft_length else int(2**(np.ceil(
np.log2(self.frame_length))))
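
At the default 16 kHz sample rate, the millisecond-to-sample conversion and the power-of-two FFT fallback above evaluate to:

import numpy as np

sample_rate = 16000
frame_length = int(sample_rate * 25 / 1e3)  # 25 ms -> 400 samples
frame_step = int(sample_rate * 10 / 1e3)    # 10 ms -> 160 samples
fft_length = int(2 ** np.ceil(np.log2(frame_length)))  # next pow2 -> 512
print(frame_length, frame_step, fft_length)  # 400 160 512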

def featurize(self, waveform):
"""Extract spectrogram feature tensors from the waveform."""
return self._compute_linear_spectrogram(waveform)

def _compute_linear_spectrogram(self, waveform):
"""Compute the linear-scale, magnitude spectrograms for the input waveform.

Args:
waveform: a float32 audio tensor.
Returns:
a float 32 tensor with shape [len, num_bins]
"""

# `stfts` is a complex64 Tensor representing the Short-time Fourier
# Transform of each signal in `signals`. Its shape is
# [?, fft_unique_bins] where fft_unique_bins = fft_length // 2 + 1.
stfts = tf.contrib.signal.stft(
waveform,
frame_length=self.frame_length,
frame_step=self.frame_step,
fft_length=self.fft_length,
window_fn=self.window_fn,
pad_end=True)

# An energy spectrogram is the magnitude of the complex-valued STFT.
# A float32 Tensor of shape [?, 257].
magnitude_spectrograms = tf.abs(stfts)
return magnitude_spectrograms

def _compute_mel_filterbank_features(self, waveform):
"""Compute the mel filterbank features."""
raise NotImplementedError("MFCC feature extraction not supported yet.")
def compute_label_feature(text, token_to_idx):
"""Convert string to a list of integers."""
tokens = list(text.strip().lower())
feats = [token_to_idx[token] for token in tokens]
Review comment (Member):
You can just return here.

return feats
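
Folding in the suggestion, the body collapses to a single return; an equivalent sketch, since iterating a string yields its characters:

def compute_label_feature(text, token_to_idx):
  """Convert string to a list of integers."""
  return [token_to_idx[token] for token in text.strip().lower()]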


class TextFeaturizer(object):
@@ -114,9 +88,3 @@ def __init__(self, vocab_file):
self.idx_to_token[idx] = line
self.speech_labels += line
idx += 1

def featurize(self, text):
"""Convert string to a list of integers."""
tokens = list(text.strip().lower())
feats = [self.token_to_idx[token] for token in tokens]
return feats
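
Finally, a hypothetical use of TextFeaturizer; the vocab path and its contents are made up, assuming the file holds one token per line as the loop above suggests:

text_featurizer = TextFeaturizer(vocab_file="vocab.txt")  # hypothetical path
print(text_featurizer.token_to_idx["a"])  # index determined by line order
print(text_featurizer.speech_labels)      # concatenation of all vocab lines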