-
Notifications
You must be signed in to change notification settings - Fork 45.5k
Add eval and parallel dataset #4651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,14 +17,15 @@ | |
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import functools | ||
import multiprocessing | ||
|
||
import numpy as np | ||
import scipy.io.wavfile as wavfile | ||
from six.moves import xrange # pylint: disable=redefined-builtin | ||
import tensorflow as tf | ||
|
||
# pylint: disable=g-bad-import-order | ||
from data.featurizer import AudioFeaturizer | ||
from data.featurizer import TextFeaturizer | ||
import data.featurizer as featurizer # pylint: disable=g-bad-import-order | ||
|
||
|
||
class AudioConfig(object): | ||
|
@@ -44,7 +45,7 @@ def __init__(self, | |
frame_length: an integer for the length of a spectrogram frame, in ms. | ||
frame_step: an integer for the frame stride, in ms. | ||
fft_length: an integer for the number of fft bins. | ||
normalize: a boolean for whether apply normalization on the audio tensor. | ||
normalize: a boolean for whether apply normalization on the audio feature. | ||
spect_type: a string for the type of spectrogram to be extracted. | ||
""" | ||
|
||
|
@@ -78,90 +79,122 @@ def __init__(self, audio_config, data_path, vocab_file_path): | |
self.vocab_file_path = vocab_file_path | ||
|
||
|
||
def _normalize_audio_feature(audio_feature): | ||
"""Perform mean and variance normalization on the spectrogram feature. | ||
|
||
Args: | ||
audio_feature: a numpy array for the spectrogram feature. | ||
|
||
Returns: | ||
a numpy array of the normalized spectrogram. | ||
""" | ||
mean = np.mean(audio_feature, axis=0) | ||
var = np.var(audio_feature, axis=0) | ||
normalized = (audio_feature - mean) / (np.sqrt(var) + 1e-6) | ||
|
||
return normalized | ||
|
||
|
||
def _preprocess_audio( | ||
audio_file_path, audio_sample_rate, audio_featurizer, normalize): | ||
"""Load the audio file in memory and compute spectrogram feature.""" | ||
tf.logging.info( | ||
"Extracting spectrogram feature for {}".format(audio_file_path)) | ||
sample_rate, data = wavfile.read(audio_file_path) | ||
assert sample_rate == audio_sample_rate | ||
if data.dtype not in [np.float32, np.float64]: | ||
data = data.astype(np.float32) / np.iinfo(data.dtype).max | ||
feature = featurizer.compute_spectrogram_feature( | ||
data, audio_featurizer.frame_length, audio_featurizer.frame_step, | ||
audio_featurizer.fft_length) | ||
if normalize: | ||
feature = _normalize_audio_feature(feature) | ||
return feature | ||
|
||
|
||
def _preprocess_transcript(transcript, token_to_index): | ||
"""Process transcript as label features.""" | ||
return featurizer.compute_label_feature(transcript, token_to_index) | ||
|
||
|
||
def _preprocess_data(dataset_config, audio_featurizer, token_to_index): | ||
"""Generate a list of waveform, transcript pair. | ||
|
||
Each dataset file contains three columns: "wav_filename", "wav_filesize", | ||
and "transcript". This function parses the csv file and stores each example | ||
by the increasing order of audio length (indicated by wav_filesize). | ||
AS the waveforms are ordered in increasing length, audio samples in a | ||
mini-batch have similar length. | ||
|
||
Args: | ||
dataset_config: an instance of DatasetConfig. | ||
audio_featurizer: an instance of AudioFeaturizer. | ||
token_to_index: the mapping from character to its index | ||
|
||
Returns: | ||
features and labels array processed from the audio/text input. | ||
""" | ||
|
||
file_path = dataset_config.data_path | ||
sample_rate = dataset_config.audio_config.sample_rate | ||
normalize = dataset_config.audio_config.normalize | ||
|
||
with tf.gfile.Open(file_path, "r") as f: | ||
lines = f.read().splitlines() | ||
lines = [line.split("\t") for line in lines] | ||
# Skip the csv header. | ||
lines = lines[1:] | ||
# Sort input data by the length of waveform. | ||
lines.sort(key=lambda item: int(item[1])) | ||
|
||
# Use multiprocessing for feature/label extraction | ||
num_cores = multiprocessing.cpu_count() | ||
pool = multiprocessing.Pool(processes=num_cores) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One trick I learned recently is that contextlib can let you use a context manager in 2 & 3.
|
||
|
||
features = pool.map( | ||
functools.partial( | ||
_preprocess_audio, audio_sample_rate=sample_rate, | ||
audio_featurizer=audio_featurizer, normalize=normalize), | ||
[line[0] for line in lines]) | ||
labels = pool.map( | ||
functools.partial( | ||
_preprocess_transcript, token_to_index=token_to_index), | ||
[line[2] for line in lines]) | ||
|
||
pool.terminate() | ||
return features, labels | ||
|
||
|
||
class DeepSpeechDataset(object): | ||
"""Dataset class for training/evaluation of DeepSpeech model.""" | ||
|
||
def __init__(self, dataset_config): | ||
"""Initialize the class. | ||
|
||
Each dataset file contains three columns: "wav_filename", "wav_filesize", | ||
and "transcript". This function parses the csv file and stores each example | ||
by the increasing order of audio length (indicated by wav_filesize). | ||
"""Initialize the DeepSpeechDataset class. | ||
|
||
Args: | ||
dataset_config: DatasetConfig object. | ||
""" | ||
self.config = dataset_config | ||
# Instantiate audio feature extractor. | ||
self.audio_featurizer = AudioFeaturizer( | ||
self.audio_featurizer = featurizer.AudioFeaturizer( | ||
sample_rate=self.config.audio_config.sample_rate, | ||
frame_length=self.config.audio_config.frame_length, | ||
frame_step=self.config.audio_config.frame_step, | ||
fft_length=self.config.audio_config.fft_length, | ||
spect_type=self.config.audio_config.spect_type) | ||
fft_length=self.config.audio_config.fft_length) | ||
# Instantiate text feature extractor. | ||
self.text_featurizer = TextFeaturizer( | ||
self.text_featurizer = featurizer.TextFeaturizer( | ||
vocab_file=self.config.vocab_file_path) | ||
|
||
self.speech_labels = self.text_featurizer.speech_labels | ||
self.features, self.labels = self._preprocess_data(self.config.data_path) | ||
self.features, self.labels = _preprocess_data( | ||
self.config, | ||
self.audio_featurizer, | ||
self.text_featurizer.token_to_idx | ||
) | ||
|
||
self.num_feature_bins = ( | ||
self.features[0].shape[1] if len(self.features) else None) | ||
|
||
def _preprocess_data(self, file_path): | ||
"""Generate a list of waveform, transcript pair. | ||
|
||
Note that the waveforms are ordered in increasing length, so that audio | ||
samples in a mini-batch have similar length. | ||
|
||
Args: | ||
file_path: a string specifying the csv file path for a data set. | ||
|
||
Returns: | ||
features and labels array processed from the audio/text input. | ||
""" | ||
|
||
with tf.gfile.Open(file_path, "r") as f: | ||
lines = f.read().splitlines() | ||
lines = [line.split("\t") for line in lines] | ||
# Skip the csv header. | ||
lines = lines[1:] | ||
# Sort input data by the length of waveform. | ||
lines.sort(key=lambda item: int(item[1])) | ||
features = [self._preprocess_audio(line[0]) for line in lines] | ||
labels = [self._preprocess_transcript(line[2]) for line in lines] | ||
return features, labels | ||
|
||
def _normalize_audio_tensor(self, audio_tensor): | ||
"""Perform mean and variance normalization on the spectrogram tensor. | ||
|
||
Args: | ||
audio_tensor: a tensor for the spectrogram feature. | ||
|
||
Returns: | ||
a tensor for the normalized spectrogram. | ||
""" | ||
mean, var = tf.nn.moments(audio_tensor, axes=[0]) | ||
normalized = (audio_tensor - mean) / (tf.sqrt(var) + 1e-6) | ||
return normalized | ||
|
||
def _preprocess_audio(self, audio_file_path): | ||
"""Load the audio file in memory.""" | ||
tf.logging.info( | ||
"Extracting spectrogram feature for {}".format(audio_file_path)) | ||
sample_rate, data = wavfile.read(audio_file_path) | ||
assert sample_rate == self.config.audio_config.sample_rate | ||
if data.dtype not in [np.float32, np.float64]: | ||
data = data.astype(np.float32) / np.iinfo(data.dtype).max | ||
feature = self.audio_featurizer.featurize(data) | ||
if self.config.audio_config.normalize: | ||
feature = self._normalize_audio_tensor(feature) | ||
return tf.Session().run( | ||
feature) # return a numpy array rather than a tensor | ||
|
||
def _preprocess_transcript(self, transcript): | ||
return self.text_featurizer.featurize(transcript) | ||
|
||
|
||
def input_fn(batch_size, deep_speech_dataset, repeat=1): | ||
"""Input function for model training and evaluation. | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,9 +18,21 @@ | |
from __future__ import print_function | ||
|
||
import codecs | ||
import functools | ||
import numpy as np | ||
import tensorflow as tf | ||
from scipy import signal | ||
|
||
|
||
def compute_spectrogram_feature(waveform, frame_length, frame_step, fft_length): | ||
"""Compute the spectrograms for the input waveform.""" | ||
_, _, stft = signal.stft( | ||
waveform, | ||
nperseg=frame_length, | ||
noverlap=frame_step, | ||
nfft=fft_length) | ||
|
||
# Perform transpose to set its shape as [time_steps, feature_num_bins] | ||
spectrogram = np.transpose(np.absolute(stft), (1, 0)) | ||
return spectrogram | ||
|
||
|
||
class AudioFeaturizer(object): | ||
|
@@ -30,64 +42,26 @@ def __init__(self, | |
sample_rate=16000, | ||
frame_length=25, | ||
frame_step=10, | ||
fft_length=None, | ||
window_fn=functools.partial( | ||
tf.contrib.signal.hann_window, periodic=True), | ||
spect_type="linear"): | ||
fft_length=None): | ||
"""Initialize the audio featurizer class according to the configs. | ||
|
||
Args: | ||
sample_rate: an integer specifying the sample rate of the input waveform. | ||
frame_length: an integer for the length of a spectrogram frame, in ms. | ||
frame_step: an integer for the frame stride, in ms. | ||
fft_length: an integer for the number of fft bins. | ||
window_fn: windowing function. | ||
spect_type: a string for the type of spectrogram to be extracted. | ||
Currently only support 'linear', otherwise will raise a value error. | ||
|
||
Raises: | ||
ValueError: In case of invalid arguments for `spect_type`. | ||
""" | ||
if spect_type != "linear": | ||
raise ValueError("Unsupported spectrogram type: %s" % spect_type) | ||
self.window_fn = window_fn | ||
self.frame_length = int(sample_rate * frame_length / 1e3) | ||
self.frame_step = int(sample_rate * frame_step / 1e3) | ||
self.fft_length = fft_length if fft_length else int(2**(np.ceil( | ||
np.log2(self.frame_length)))) | ||
|
||
def featurize(self, waveform): | ||
"""Extract spectrogram feature tensors from the waveform.""" | ||
return self._compute_linear_spectrogram(waveform) | ||
|
||
def _compute_linear_spectrogram(self, waveform): | ||
"""Compute the linear-scale, magnitude spectrograms for the input waveform. | ||
|
||
Args: | ||
waveform: a float32 audio tensor. | ||
Returns: | ||
a float 32 tensor with shape [len, num_bins] | ||
""" | ||
|
||
# `stfts` is a complex64 Tensor representing the Short-time Fourier | ||
# Transform of each signal in `signals`. Its shape is | ||
# [?, fft_unique_bins] where fft_unique_bins = fft_length // 2 + 1. | ||
stfts = tf.contrib.signal.stft( | ||
waveform, | ||
frame_length=self.frame_length, | ||
frame_step=self.frame_step, | ||
fft_length=self.fft_length, | ||
window_fn=self.window_fn, | ||
pad_end=True) | ||
|
||
# An energy spectrogram is the magnitude of the complex-valued STFT. | ||
# A float32 Tensor of shape [?, 257]. | ||
magnitude_spectrograms = tf.abs(stfts) | ||
return magnitude_spectrograms | ||
|
||
def _compute_mel_filterbank_features(self, waveform): | ||
"""Compute the mel filterbank features.""" | ||
raise NotImplementedError("MFCC feature extraction not supported yet.") | ||
def compute_label_feature(text, token_to_idx): | ||
"""Convert string to a list of integers.""" | ||
tokens = list(text.strip().lower()) | ||
feats = [token_to_idx[token] for token in tokens] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can just return here. |
||
return feats | ||
|
||
|
||
class TextFeaturizer(object): | ||
|
@@ -114,9 +88,3 @@ def __init__(self, vocab_file): | |
self.idx_to_token[idx] = line | ||
self.speech_labels += line | ||
idx += 1 | ||
|
||
def featurize(self, text): | ||
"""Convert string to a list of integers.""" | ||
tokens = list(text.strip().lower()) | ||
feats = [self.token_to_idx[token] for token in tokens] | ||
return feats |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually we should avoid the one line wrap function with same parameter.