Update deep speech model with pure tensorflow API implementation (#4730)
* update

* update

* update

* update

* update
haozha111 authored and yhliang2018 committed Jul 11, 2018
1 parent 37ba230 commit d90f558
Showing 8 changed files with 548 additions and 496 deletions.
61 changes: 61 additions & 0 deletions research/deep_speech/README.md
@@ -0,0 +1,61 @@
# DeepSpeech2 Model
## Overview
This is an implementation of the [DeepSpeech2](https://arxiv.org/pdf/1512.02595.pdf) model. The current implementation is based on the authors' [DeepSpeech code](https://github.com/PaddlePaddle/DeepSpeech) and the implementation in the [MLPerf Repo](https://github.com/mlperf/reference/tree/master/speech_recognition).

DeepSpeech2 is an end-to-end deep neural network for automatic speech
recognition (ASR). It consists of 2 convolutional layers, 5 bidirectional RNN
layers and a fully connected layer. The input feature is a linear spectrogram
extracted from the audio input. The network uses Connectionist Temporal Classification [CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf) as the loss function.
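
For a concrete picture of that layer stack, below is a minimal, illustrative
`tf.keras` sketch of a DeepSpeech2-style network. It is not the model code from
this commit (the model file is not part of this diff); the filter counts,
kernel sizes, strides, use of GRU cells, `rnn_units`, and `num_classes` are
placeholder assumptions chosen for readability.
```
import tensorflow as tf


def build_ds2_sketch(num_feature_bins=161, num_classes=29, rnn_units=256):
  """Builds a DeepSpeech2-like stack over [time, feature_bins, 1] spectrograms."""
  inputs = tf.keras.Input(shape=(None, num_feature_bins, 1), name="features")

  # Two 2-D convolutions over the (time, frequency) plane.
  x = tf.keras.layers.Conv2D(32, kernel_size=(11, 41), strides=(2, 2),
                             padding="same", activation="relu")(inputs)
  x = tf.keras.layers.Conv2D(32, kernel_size=(11, 21), strides=(1, 2),
                             padding="same", activation="relu")(x)

  # Collapse the frequency and channel axes so each time step is one vector.
  feature_dim = int(x.shape[2]) * int(x.shape[3])
  x = tf.keras.layers.Reshape((-1, feature_dim))(x)

  # Five bidirectional RNN layers.
  for _ in range(5):
    x = tf.keras.layers.Bidirectional(
        tf.keras.layers.GRU(rnn_units, return_sequences=True))(x)

  # Fully connected layer producing per-time-step character logits, which are
  # then fed to the CTC loss during training.
  logits = tf.keras.layers.Dense(num_classes)(x)
  return tf.keras.Model(inputs, logits)
```
The default of 161 feature bins matches the spectrogram size produced by
`data/dataset.py` in this commit.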

## Dataset
The [OpenSLR LibriSpeech Corpus](http://www.openslr.org/12/) is used for model training and evaluation.

The training data is a combination of train-clean-100 and train-clean-360 (~130k
examples in total). The validation set is dev-clean, which has 2.7K examples.
The download script preprocesses the data into a csv file with three columns:
wav_filename, wav_filesize, and transcript. `data/dataset.py` parses the csv
file and builds a `tf.data.Dataset` object to feed data, as sketched below.
Within each epoch (except for the first if sortagrad is enabled), the training
data is shuffled batch-wise.
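
As a concrete illustration, here is a minimal sketch of how the classes added
in `data/dataset.py` (shown in the diff below) fit together. The sample rate,
window/stride values, batch size, and file paths are illustrative assumptions,
not defaults established by this commit.
```
# Assumes the working directory is research/deep_speech, matching how
# data/dataset.py itself imports data.featurizer.
from data import dataset as ds

audio_conf = ds.AudioConfig(sample_rate=16000,  # assumed sample rate (Hz)
                            window_ms=20,       # assumed spectrogram window
                            stride_ms=10,       # assumed frame stride
                            normalize=True)
data_conf = ds.DatasetConfig(
    audio_config=audio_conf,
    data_path="/tmp/librispeech_data/train.csv",             # placeholder path
    vocab_file_path="/tmp/librispeech_data/vocabulary.txt",  # placeholder path
    sortagrad=True)

speech_dataset = ds.DeepSpeechDataset(data_conf)

# Batch-wise shuffle of the (wav_filename, wav_filesize, transcript) entries;
# with sortagrad=True, epoch 0 keeps the sorted-by-size order.
speech_dataset.entries = ds.batch_wise_dataset_shuffle(
    speech_dataset.entries, epoch_index=1, sortagrad=True, batch_size=32)

# input_fn builds the tf.data.Dataset that yields padded (features, labels)
# batches for the model.
train_dataset = ds.input_fn(batch_size=32, deep_speech_dataset=speech_dataset)
```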

## Running Code

### Configure Python path
Add the top-level /models folder to the Python path with the command:
```
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```

### Install dependencies

First, install the shared dependencies before running the code. Issue the following command:
```
pip3 install -r requirements.txt
```
or
```
pip install -r requirements.txt
```

### Download and preprocess dataset
To download the dataset, issue the following command:
```
python data/download.py
```
Arguments:
* `--data_dir`: Directory in which to download and save the preprocessed data. By default, it is `/tmp/librispeech_data`.

Use the `--help` or `-h` flag to get a full list of possible arguments.

### Train and evaluate model
To train and evaluate the model, issue the following command:
```
python deep_speech.py
```
Arguments:
* `--model_dir`: Directory to save model training checkpoints. By default, it is `/tmp/deep_speech_model/`.
* `--train_data_dir`: Directory of the training dataset.
* `--eval_data_dir`: Directory of the evaluation dataset.
* `--num_gpus`: Number of GPUs to use (specify -1 if you want to use all available GPUs).

There are additional arguments for the DeepSpeech2 model and the training/evaluation process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions.

230 changes: 125 additions & 105 deletions research/deep_speech/data/dataset.py
@@ -17,13 +17,14 @@
from __future__ import division
from __future__ import print_function

import functools
import multiprocessing

import math
import random
# pylint: disable=g-bad-import-order
import numpy as np
import scipy.io.wavfile as wavfile
from six.moves import xrange # pylint: disable=redefined-builtin
import soundfile
import tensorflow as tf
# pylint: enable=g-bad-import-order

import data.featurizer as featurizer # pylint: disable=g-bad-import-order

@@ -33,40 +34,37 @@ class AudioConfig(object):

def __init__(self,
sample_rate,
frame_length,
frame_step,
fft_length=None,
normalize=False,
spect_type="linear"):
window_ms,
stride_ms,
normalize=False):
"""Initialize the AudioConfig class.
Args:
sample_rate: an integer denoting the sample rate of the input waveform.
frame_length: an integer for the length of a spectrogram frame, in ms.
frame_step: an integer for the frame stride, in ms.
fft_length: an integer for the number of fft bins.
window_ms: an integer for the length of a spectrogram frame, in ms.
stride_ms: an integer for the frame stride, in ms.
normalize: a boolean for whether to apply normalization on the audio feature.
spect_type: a string for the type of spectrogram to be extracted.
"""

self.sample_rate = sample_rate
self.frame_length = frame_length
self.frame_step = frame_step
self.fft_length = fft_length
self.window_ms = window_ms
self.stride_ms = stride_ms
self.normalize = normalize
self.spect_type = spect_type


class DatasetConfig(object):
"""Config class for generating the DeepSpeechDataset."""

def __init__(self, audio_config, data_path, vocab_file_path):
def __init__(self, audio_config, data_path, vocab_file_path, sortagrad):
"""Initialize the configs for deep speech dataset.
Args:
audio_config: AudioConfig object specifying the audio-related configs.
data_path: a string denoting the full path of a manifest file.
vocab_file_path: a string specifying the vocabulary file path.
sortagrad: a boolean, if set to true, audio sequences will be fed by
increasing length in the first training epoch, which will
expedite network convergence.
Raises:
RuntimeError: file path not exist.
@@ -77,6 +75,7 @@ def __init__(self, audio_config, data_path, vocab_file_path):
assert tf.gfile.Exists(vocab_file_path)
self.data_path = data_path
self.vocab_file_path = vocab_file_path
self.sortagrad = sortagrad


def _normalize_audio_feature(audio_feature):
@@ -95,30 +94,23 @@ def _normalize_audio_feature(audio_feature):
return normalized


def _preprocess_audio(
audio_file_path, audio_sample_rate, audio_featurizer, normalize):
"""Load the audio file in memory and compute spectrogram feature."""
tf.logging.info(
"Extracting spectrogram feature for {}".format(audio_file_path))
sample_rate, data = wavfile.read(audio_file_path)
assert sample_rate == audio_sample_rate
if data.dtype not in [np.float32, np.float64]:
data = data.astype(np.float32) / np.iinfo(data.dtype).max
def _preprocess_audio(audio_file_path, audio_featurizer, normalize):
"""Load the audio file and compute spectrogram feature."""
data, _ = soundfile.read(audio_file_path)
feature = featurizer.compute_spectrogram_feature(
data, audio_featurizer.frame_length, audio_featurizer.frame_step,
audio_featurizer.fft_length)
data, audio_featurizer.sample_rate, audio_featurizer.stride_ms,
audio_featurizer.window_ms)
# Feature normalization
if normalize:
feature = _normalize_audio_feature(feature)
return feature


def _preprocess_transcript(transcript, token_to_index):
"""Process transcript as label features."""
return featurizer.compute_label_feature(transcript, token_to_index)
# Adding Channel dimension for conv2D input.
feature = np.expand_dims(feature, axis=2)
return feature


def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
"""Generate a list of waveform, transcript pair.
def _preprocess_data(file_path):
"""Generate a list of tuples (wav_filename, wav_filesize, transcript).
Each dataset file contains three columns: "wav_filename", "wav_filesize",
and "transcript". This function parses the csv file and stores each example
@@ -127,42 +119,23 @@ def _preprocess_data(dataset_config, audio_featurizer, token_to_index):
mini-batch have similar length.
Args:
dataset_config: an instance of DatasetConfig.
audio_featurizer: an instance of AudioFeaturizer.
token_to_index: the mapping from character to its index
file_path: a string specifying the csv file path for a dataset.
Returns:
features and labels array processed from the audio/text input.
A list of tuples (wav_filename, wav_filesize, transcript) sorted by
file_size.
"""

file_path = dataset_config.data_path
sample_rate = dataset_config.audio_config.sample_rate
normalize = dataset_config.audio_config.normalize

tf.logging.info("Loading data set {}".format(file_path))
with tf.gfile.Open(file_path, "r") as f:
lines = f.read().splitlines()
lines = [line.split("\t") for line in lines]
# Skip the csv header.
# Skip the csv header in lines[0].
lines = lines[1:]
# Sort input data by the length of waveform.
# The metadata file is tab separated.
lines = [line.split("\t", 2) for line in lines]
# Sort input data by the length of audio sequence.
lines.sort(key=lambda item: int(item[1]))

# Use multiprocessing for feature/label extraction
num_cores = multiprocessing.cpu_count()
pool = multiprocessing.Pool(processes=num_cores)

features = pool.map(
functools.partial(
_preprocess_audio, audio_sample_rate=sample_rate,
audio_featurizer=audio_featurizer, normalize=normalize),
[line[0] for line in lines])
labels = pool.map(
functools.partial(
_preprocess_transcript, token_to_index=token_to_index),
[line[2] for line in lines])

pool.terminate()
return features, labels
return [tuple(line) for line in lines]


class DeepSpeechDataset(object):
@@ -178,22 +151,52 @@ def __init__(self, dataset_config):
# Instantiate audio feature extractor.
self.audio_featurizer = featurizer.AudioFeaturizer(
sample_rate=self.config.audio_config.sample_rate,
frame_length=self.config.audio_config.frame_length,
frame_step=self.config.audio_config.frame_step,
fft_length=self.config.audio_config.fft_length)
window_ms=self.config.audio_config.window_ms,
stride_ms=self.config.audio_config.stride_ms)
# Instantiate text feature extractor.
self.text_featurizer = featurizer.TextFeaturizer(
vocab_file=self.config.vocab_file_path)

self.speech_labels = self.text_featurizer.speech_labels
self.features, self.labels = _preprocess_data(
self.config,
self.audio_featurizer,
self.text_featurizer.token_to_idx
)
self.entries = _preprocess_data(self.config.data_path)
# The generated spectrogram will have 161 feature bins.
self.num_feature_bins = 161


self.num_feature_bins = (
self.features[0].shape[1] if len(self.features) else None)
def batch_wise_dataset_shuffle(entries, epoch_index, sortagrad, batch_size):
"""Batch-wise shuffling of the data entries.
Each data entry is in the format of (audio_file, file_size, transcript).
If epoch_index is 0 and sortagrad is true, we don't perform shuffling and
return entries in sorted file_size order. Otherwise, do batch_wise shuffling.
Args:
entries: a list of data entries.
epoch_index: an integer of epoch index
sortagrad: a boolean to control whether sorting the audio in the first
training epoch.
batch_size: an integer for the batch size.
Returns:
The shuffled data entries.
"""
shuffled_entries = []
if epoch_index == 0 and sortagrad:
# No need to shuffle.
shuffled_entries = entries
else:
# Shuffle entries batch-wise.
max_buckets = int(math.floor(len(entries) / batch_size))
total_buckets = [i for i in xrange(max_buckets)]
random.shuffle(total_buckets)
shuffled_entries = []
for i in total_buckets:
shuffled_entries.extend(entries[i * batch_size : (i + 1) * batch_size])
# If the last batch doesn't contain enough batch_size examples,
# just append it to the shuffled_entries.
shuffled_entries.extend(entries[max_buckets * batch_size:])

return shuffled_entries


def input_fn(batch_size, deep_speech_dataset, repeat=1):
@@ -207,49 +210,66 @@ def input_fn(batch_size, deep_speech_dataset, repeat=1):
Returns:
a tf.data.Dataset object for model to consume.
"""
features = deep_speech_dataset.features
labels = deep_speech_dataset.labels
# Dataset properties
data_entries = deep_speech_dataset.entries
num_feature_bins = deep_speech_dataset.num_feature_bins
audio_featurizer = deep_speech_dataset.audio_featurizer
feature_normalize = deep_speech_dataset.config.audio_config.normalize
text_featurizer = deep_speech_dataset.text_featurizer

def _gen_data():
for i in xrange(len(features)):
feature = np.expand_dims(features[i], axis=2)
input_length = [features[i].shape[0]]
label_length = [len(labels[i])]
yield {
"features": feature,
"labels": labels[i],
"input_length": input_length,
"label_length": label_length
}
"""Dataset generator function."""
for audio_file, _, transcript in data_entries:
features = _preprocess_audio(
audio_file, audio_featurizer, feature_normalize)
labels = featurizer.compute_label_feature(
transcript, text_featurizer.token_to_index)
input_length = [features.shape[0]]
label_length = [len(labels)]
# Yield a tuple of (features, labels) where features is a dict containing
# all info about the actual data features.
yield (
{
"features": features,
"input_length": input_length,
"label_length": label_length
},
labels)

dataset = tf.data.Dataset.from_generator(
_gen_data,
output_types={
"features": tf.float32,
"labels": tf.int32,
"input_length": tf.int32,
"label_length": tf.int32
},
output_shapes={
"features": tf.TensorShape([None, num_feature_bins, 1]),
"labels": tf.TensorShape([None]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
})
output_types=(
{
"features": tf.float32,
"input_length": tf.int32,
"label_length": tf.int32
},
tf.int32),
output_shapes=(
{
"features": tf.TensorShape([None, num_feature_bins, 1]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
},
tf.TensorShape([None]))
)

# Repeat and batch the dataset
dataset = dataset.repeat(repeat)

# Padding the features to its max length dimensions.
dataset = dataset.padded_batch(
batch_size=batch_size,
padded_shapes={
"features": tf.TensorShape([None, num_feature_bins, 1]),
"labels": tf.TensorShape([None]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
})
padded_shapes=(
{
"features": tf.TensorShape([None, num_feature_bins, 1]),
"input_length": tf.TensorShape([1]),
"label_length": tf.TensorShape([1])
},
tf.TensorShape([None]))
)

# Prefetch to improve speed of input pipeline.
dataset = dataset.prefetch(1)
dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return dataset
