This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Added a librispeech data generator. #419
Merged
Changes from all commits (6 commits):
0c026d2 wingsbr: Added a librispeech data generator.
75ec0f6 wingsbr: .
5365113 wingsbr: Expanded to include librispeech Problem and Modality.
844df4d wingsbr: Added librispeech to data_generators/all_problems.py
98c7b41 wingsbr: Switched to full librispeech datasets.
23129f2 wingsbr: Variety of fixes based on PR comments.
@@ -0,0 +1,310 @@
import os
import tarfile
import wave
from subprocess import call

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import modality
from tensor2tensor.utils import registry
_LIBRISPEECH_TRAIN_DATASETS = [
    [
        "http://www.openslr.org/resources/12/train-clean-100.tar.gz",
        "train-clean-100"
    ],
    [
        "http://www.openslr.org/resources/12/train-clean-360.tar.gz",
        "train-clean-360"
    ],
    [
        "http://www.openslr.org/resources/12/train-other-500.tar.gz",
        "train-other-500"
    ],
]
_LIBRISPEECH_TEST_DATASETS = [
    [
        "http://www.openslr.org/resources/12/dev-clean.tar.gz",
        "dev-clean"
    ],
    [
        "http://www.openslr.org/resources/12/dev-other.tar.gz",
        "dev-other"
    ],
]
def _collect_data(directory, input_ext, transcription_ext):
  """Traverses directory collecting input and target files."""
  # Directory from string to tuple pair of strings
  # key: the filepath to a datafile including the datafile's basename. Example,
  #   if the datafile was "/path/to/datafile.wav" then the key would be
  #   "/path/to/datafile"
  # value: a pair of strings (media_filepath, label)
  data_files = dict()
  for root, _, filenames in os.walk(directory):
    transcripts = [filename for filename in filenames
                   if transcription_ext in filename]
    for transcript in transcripts:
      transcript_path = os.path.join(root, transcript)
      with open(transcript_path, "r") as transcript_file:
        for transcript_line in transcript_file:
          # Each line is "<utterance-id> <transcript>"; strip the trailing
          # newline so it does not leak into the label.
          line_contents = transcript_line.strip().split(" ", 1)
          assert len(line_contents) == 2
          media_base, label = line_contents
          key = os.path.join(root, media_base)
          assert key not in data_files
          media_name = "%s.%s" % (media_base, input_ext)
          media_path = os.path.join(root, media_name)
          data_files[key] = (media_path, label)
  return data_files
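For orientation, a brief usage sketch of _collect_data; the paths are hypothetical, but the layout mirrors how the LibriSpeech archives extract, with one .trans.txt transcript per chapter directory and one "<utterance-id> <transcript>" pair per line:

  # Hypothetical extracted layout:
  #   /tmp/LibriSpeech/dev-clean/84/121123/84-121123.trans.txt
  #   /tmp/LibriSpeech/dev-clean/84/121123/84-121123-0000.flac
  data_files = _collect_data("/tmp/LibriSpeech/dev-clean", "flac", "txt")
  for media_path, label in data_files.values():
    print(media_path, label)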
def _get_audio_data(filepath):
  # Construct a true .wav file.
  out_filepath = filepath[:-len(".flac")] + ".wav"
  # Assumes sox is installed on system. Sox converts from FLAC to WAV.
  call(["sox", filepath, out_filepath])
  wav_file = wave.open(out_filepath, "rb")
  frame_count = wav_file.getnframes()
  byte_array = wav_file.readframes(frame_count)

  data = np.frombuffer(byte_array, np.uint8).tolist()
  return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels()
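As a sanity check on the returned values, a small sketch (assumes sox is on the PATH and the hypothetical FLAC file exists; LibriSpeech audio is 16 kHz, 16-bit, mono):

  data, frame_count, sample_width, num_channels = _get_audio_data(
      "/tmp/LibriSpeech/dev-clean/84/121123/84-121123-0000.flac")
  assert sample_width == 2 and num_channels == 1
  # data is the raw PCM bytes as a list of uint8 values.
  assert len(data) == frame_count * sample_width * num_channels
  print("duration: %.2fs" % (frame_count / 16000.0))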
class LibrispeechTextEncoder(text_encoder.TextEncoder):
  """Character-level encoder that offsets ids by the reserved-token count."""

  def encode(self, s):
    return [self._num_reserved_ids + ord(c) for c in s]

  def decode(self, ids):
    """Transform a sequence of int ids into a human-readable string.

    EOS is not expected in ids.

    Args:
      ids: list of integers to be converted.

    Returns:
      s: human-readable string.
    """
    decoded_ids = []
    for id_ in ids:
      if 0 <= id_ < self._num_reserved_ids:
        decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)])
      else:
        decoded_ids.append(chr(id_ - self._num_reserved_ids))
    return "".join(decoded_ids)
@registry.register_audio_modality
class LibrispeechModality(modality.Modality):
  """Performs strided conv compressions for audio spectral data."""

  def bottom(self, inputs):
    """Transform input from data space to model space.

    Args:
      inputs: A Tensor with shape [batch, ...]
    Returns:
      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
    """
    with tf.variable_scope(self.name):
      # TODO(aidangomez): Will need to sort out a better audio pipeline
      def xnet_resblock(x, filters, res_relu, name):
        with tf.variable_scope(name):
          # We only stride along the length dimension to preserve the spectral
          # bins (which are tiny in dimensionality relative to length)
          y = common_layers.separable_conv_block(
              x,
              filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
              first_relu=True,
              padding="SAME",
              force2d=True,
              name="sep_conv_block")
          y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
          return y + common_layers.conv_block(
              x,
              filters, [((1, 1), (1, 1))],
              padding="SAME",
              strides=(2, 1),
              first_relu=res_relu,
              force2d=True,
              name="res_conv0")

      # Rescale from UINT8 to floats in [-1, 1].
      signals = (tf.to_float(inputs) - 127) / 128.
      # signals = tf.contrib.framework.nest.flatten(signals)
      signals = tf.squeeze(signals, [2, 3])

      # `stfts` is a complex64 Tensor representing the Short-time Fourier
      # Transform of each signal in `signals`. Its shape is
      # [batch_size, ?, fft_unique_bins]
      # where fft_unique_bins = fft_length // 2 + 1 = 513.
      stfts = tf.contrib.signal.stft(signals, frame_length=1024,
                                     frame_step=512, fft_length=1024)

      # An energy spectrogram is the magnitude of the complex-valued STFT.
      # A float32 Tensor of shape [batch_size, ?, 513].
      magnitude_spectrograms = tf.abs(stfts)

      log_offset = 1e-6
      log_magnitude_spectrograms = tf.log(magnitude_spectrograms + log_offset)

      # Warp the linear-scale magnitude spectrograms into the mel-scale.
      num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
      lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
      sample_rate = 16000
      linear_to_mel_weight_matrix = (
          tf.contrib.signal.linear_to_mel_weight_matrix(
              num_mel_bins, num_spectrogram_bins, sample_rate,
              lower_edge_hertz, upper_edge_hertz))
      mel_spectrograms = tf.tensordot(
          magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
      # Note: Shape inference for `tf.tensordot` does not currently handle
      # this case.
      mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
          linear_to_mel_weight_matrix.shape[-1:]))

      # Try without the conversion to MFCCs, first.
      # num_mfccs = 13
      # # Keep the first `num_mfccs` MFCCs.
      # mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(
      #     log_mel_spectrograms)[..., :num_mfccs]

      x = tf.expand_dims(mel_spectrograms, 2)
      x.set_shape([None, None, None, num_mel_bins])
      for i in xrange(self._model_hparams.audio_compression):
        x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
      return xnet_resblock(x, self._body_input_depth, False,
                           "compress_block_final")
@registry.register_problem()
class Librispeech(problem.Problem):
  """Problem spec for LibriSpeech speech recognition."""

  @property
  def is_character_level(self):
    return True

  @property
  def input_space_id(self):
    return problem.SpaceID.AUDIO_SPECTRAL

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def num_shards(self):
    return 100

  @property
  def use_subword_tokenizer(self):
    return False

  @property
  def num_dev_shards(self):
    return 1

  @property
  def use_train_shards_for_dev(self):
    """If true, we only generate training data and hold out shards for dev."""
    return False

  def feature_encoders(self, _):
    return {
        "inputs": text_encoder.TextEncoder(),
        "targets": LibrispeechTextEncoder(),
    }

  def example_reading_spec(self):
    data_fields = {
        "inputs": tf.VarLenFeature(tf.int64),
        # "audio/channel_count": tf.FixedLenFeature([], tf.int64),
Review comment on the line above: this can be reserved!
wingsbr (author): I'm sorry, I don't understand. What are you suggesting?
        # "audio/sample_count": tf.FixedLenFeature([], tf.int64),
        # "audio/sample_width": tf.FixedLenFeature([], tf.int64),
        "targets": tf.VarLenFeature(tf.int64),
    }
    data_items_to_decoders = None
    return (data_fields, data_items_to_decoders)
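A hedged sketch of how one serialized record would be parsed against this spec (TF 1.x API, matching the rest of the file; serialized_example is a stand-in for a record read from the generated files):

  serialized_example = tf.placeholder(tf.string, shape=[])
  data_fields, _ = Librispeech().example_reading_spec()
  parsed = tf.parse_single_example(serialized_example, data_fields)
  # VarLenFeature fields come back as SparseTensors; densify before use.
  inputs = tf.sparse_tensor_to_dense(parsed["inputs"])
  targets = tf.sparse_tensor_to_dense(parsed["targets"])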
  def generator(self, data_dir, tmp_dir, training,
                eos_list=None, start_from=0, how_many=0):
    eos_list = [1] if eos_list is None else eos_list
    datasets = (_LIBRISPEECH_TRAIN_DATASETS if training
                else _LIBRISPEECH_TEST_DATASETS)
    num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
    i = 0
    for url, subdir in datasets:
      filename = os.path.basename(url)
      compressed_file = generator_utils.maybe_download(tmp_dir, filename, url)

      read_type = "r:gz" if filename.endswith((".tgz", ".gz")) else "r"
      with tarfile.open(compressed_file, read_type) as corpus_tar:
        # Create a subset of files that don't already exist.
        # tarfile.extractall errors when encountering an existing file
        # and tarfile.extract is extremely slow.
        members = []
        for f in corpus_tar:
          if not os.path.isfile(os.path.join(tmp_dir, f.name)):
            members.append(f)
        corpus_tar.extractall(tmp_dir, members=members)

      data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir)
      data_files = _collect_data(data_dir, "flac", "txt")
      data_pairs = data_files.values()
      for media_file, text_data in sorted(data_pairs)[start_from:]:
        if how_many > 0 and i == how_many:
          return
        i += 1
        audio_data, sample_count, sample_width, num_channels = (
            _get_audio_data(media_file))
        label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
        yield {
            "inputs": audio_data,
            "audio/channel_count": [num_channels],
            "audio/sample_count": [sample_count],
            "audio/sample_width": [sample_width],
            "targets": label
        }
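To spot-check a couple of examples without generating the full dataset, the how_many cap can be used; a minimal sketch (downloads and extraction still run; the directories here are hypothetical):

  gen = Librispeech().generator("/tmp/t2t_data", "/tmp/t2t_tmp",
                                training=False, how_many=2)
  for example in gen:
    print(len(example["inputs"]), example["audio/sample_count"],
          example["targets"][:10])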
  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    train_paths = self.training_filepaths(
        data_dir, self.num_shards, shuffled=False)
    dev_paths = self.dev_filepaths(
        data_dir, self.num_dev_shards, shuffled=False)
    if self.use_train_shards_for_dev:
      all_paths = train_paths + dev_paths
      generator_utils.generate_files(
          self.generator(data_dir, tmp_dir, True), all_paths)
      generator_utils.shuffle_dataset(all_paths)
    else:
      generator_utils.generate_dataset_and_shuffle(
          self.generator(data_dir, tmp_dir, True), train_paths,
          self.generator(data_dir, tmp_dir, False), dev_paths)
  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    p.stop_at_eos = int(False)
    p.input_modality = {"inputs": ("audio:librispeech_modality", None)}

Review comment on the line above: registry.Modalities.AUDIO
wingsbr (author): Wouldn't that result in the base Audio modality being used, and bypass the custom signal processing added to

    p.target_modality = (registry.Modalities.SYMBOL, 256)
  def preprocess_example(self, example, mode, hparams):
    return example
# TODO: clean up hparams
@registry.register_hparams
def librispeech_hparams():
  # Or whatever you'd like to build off.
  hparams = transformer.transformer_base_single_gpu()
  hparams.batch_size = 36
  hparams.audio_compression = 8
  hparams.hidden_size = 2048
  hparams.max_input_seq_length = 600000
  hparams.max_target_seq_length = 350
  hparams.max_length = hparams.max_input_seq_length
  hparams.min_length_bucket = hparams.max_input_seq_length // 2
  hparams.learning_rate = 0.05
  hparams.train_steps = 5000000
  hparams.num_hidden_layers = 4
  return hparams
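Once registered, both pieces are addressed by snake_cased name per the registry's conventions; a minimal lookup sketch (the flag names below follow the t2t docs of this era and may vary across versions):

  # e.g. t2t-datagen / t2t-trainer would be pointed at:
  #   --problems=librispeech --hparams_set=librispeech_hparams
  librispeech_problem = registry.problem("librispeech")
  # Depending on the t2t version, registry.hparams returns either the
  # registered function or the constructed HParams.
  librispeech_hparams_set = registry.hparams("librispeech_hparams")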
Review comment: transformer_base_single_gpu, is this used?
wingsbr (author): I'm not sure I understand the question, but I used transformer_base_single_gpu as the basis for the hparams because that was what was referenced in all of the examples:
https://github.com/tensorflow/tensor2tensor/blob/master/docs/new_problem.md
https://github.com/tensorflow/tensor2tensor/blob/master/docs/walkthrough.md
https://github.com/tensorflow/tensor2tensor/blob/master/README.md