From 0c026d294d40cd131f3e3e2ecce4df02ab661143 Mon Sep 17 00:00:00 2001 From: wingsbr Date: Tue, 14 Nov 2017 09:23:31 -0600 Subject: [PATCH 1/6] Added a librispeech data generator. --- tensor2tensor/bin/t2t-datagen | 6 + tensor2tensor/data_generators/librispeech.py | 109 +++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 tensor2tensor/data_generators/librispeech.py diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index 2ac0f0db2..b8a1027f3 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -44,6 +44,7 @@ from tensor2tensor.data_generators import audio from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import snli from tensor2tensor.data_generators import wsj_parsing +from tensor2tensor.data_generators import librispeech from tensor2tensor.utils import registry from tensor2tensor.utils import usr_dir @@ -113,6 +114,11 @@ _SUPPORTED_PROBLEM_GENERATORS = { lambda: audio.timit_generator( FLAGS.data_dir, FLAGS.tmp_dir, False, 626, vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)), + "librispeech": ( + lambda: librispeech.librispeech_generator( + FLAGS.data_dir, FLAGS.tmp_dir, True), + lambda: librispeech.librispeech_generator( + FLAGS.data_dir, FLAGS.tmp_dir, False)), } # pylint: enable=g-long-lambda diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py new file mode 100644 index 000000000..82b032c35 --- /dev/null +++ b/tensor2tensor/data_generators/librispeech.py @@ -0,0 +1,109 @@ +import os +from subprocess import call +import tarfile +import wave +import numpy as np +import six +from tensor2tensor.data_generators import generator_utils + +_LIBRISPEECH_TRAIN_DATASETS = [ + [ + "http://www.openslr.org/resources/12/train-clean-100.tar.gz", # pylint: disable=line-too-long + "train-clean-100" + ], + [ + "http://www.openslr.org/resources/12/train-clean-360.tar.gz", + "train-clean-360" + ], + [ + "http://www.openslr.org/resources/12/train-other-500.tar.gz", + "train-other-500" + ], +] +_LIBRISPEECH_TEST_DATASETS = [ + [ + "http://www.openslr.org/resources/12/dev-clean.tar.gz", + "dev-clean" + ], + [ + "http://www.openslr.org/resources/12/dev-other.tar.gz", + "dev-other" + ], +] + + +def _collect_data(directory, input_ext, transcription_ext): + """Traverses directory collecting input and target files.""" + # Directory from string to tuple pair of strings + # key: the filepath to a datafile including the datafile's basename. Example, + # if the datafile was "/path/to/datafile.wav" then the key would be + # "/path/to/datafile" + # value: a pair of strings (media_filepath, label) + data_files = dict() + for root, _, filenames in os.walk(directory): + transcripts = [filename for filename in filenames if transcription_ext in filename] + for transcript in transcripts: + basename = transcript.strip(transcription_ext) + transcript_path = os.path.join(root, transcript) + with open(transcript_path, 'r') as transcript_file: + for transcript_line in transcript_file: + line_contents = transcript_line.split(" ", 1) + assert len(line_contents) == 2 + media_base, label = line_contents + key = os.path.join(root, media_base) + assert key not in data_files + media_name = "%s.%s"%(media_base, input_ext) + media_path = os.path.join(root, media_name) + data_files[key] = (media_path, label) + return data_files + + +def _get_audio_data(filepath): + # Construct a true .wav file. 
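+  # Swap the ".flac" suffix for ".wav" by slicing; note that str.strip
+  # would remove any of the characters ".flac" from both ends of the
+  # path rather than stripping the extension itself.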
+  out_filepath = filepath[:-len(".flac")] + ".wav"
+  # Assumes sox is installed on system. Sox converts from FLAC to WAV.
+  call(["sox", filepath, out_filepath])
+  wav_file = wave.open(out_filepath, "rb")
+  frame_count = wav_file.getnframes()
+  byte_array = wav_file.readframes(frame_count)
+
+  data = np.fromstring(byte_array, np.uint8).tolist()
+  return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels()
+
+
+def librispeech_generator(data_dir, tmp_dir, training, eos_list=None, start_from=0, how_many=0):
+  eos_list = [1] if eos_list is None else eos_list
+  datasets = (_LIBRISPEECH_TRAIN_DATASETS if training else _LIBRISPEECH_TEST_DATASETS)
+  i = 0
+  for url, subdir in datasets:
+    filename = os.path.basename(url)
+    compressed_file = generator_utils.maybe_download(tmp_dir, filename, url)
+
+    read_type = "r:gz" if filename.endswith("tgz") else "r"
+    with tarfile.open(compressed_file, read_type) as corpus_tar:
+      # Create a subset of files that don't already exist.
+      # tarfile.extractall errors when encountering an existing file
+      # and tarfile.extract is extremely slow
+      members = []
+      for f in corpus_tar:
+        if not os.path.isfile(os.path.join(tmp_dir, f.name)):
+          members.append(f)
+      corpus_tar.extractall(tmp_dir, members=members)
+
+    data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir)
+    data_files = _collect_data(data_dir, "flac", "txt")
+    data_pairs = data_files.values()
+    for media_file, text_data in sorted(data_pairs)[start_from:]:
+      if how_many > 0 and i == how_many:
+        return
+      i += 1
+      audio_data, sample_count, sample_width, num_channels = _get_audio_data(
+          media_file)
+      label = [ord(c) for c in text_data] + eos_list
+      yield {
+          "inputs": audio_data,
+          "audio/channel_count": [num_channels],
+          "audio/sample_count": [sample_count],
+          "audio/sample_width": [sample_width],
+          "targets": label
+      }
\ No newline at end of file

From 75ec0f6e9950bb5e76cf897b0e7e4e61fca5a0e4 Mon Sep 17 00:00:00 2001
From: wingsbr
Date: Tue, 14 Nov 2017 09:36:30 -0600
Subject: [PATCH 2/6] Fix whitespace in the librispeech entry in t2t-datagen.

---
 tensor2tensor/bin/t2t-datagen | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen
index b8a1027f3..e9eca3672 100644
--- a/tensor2tensor/bin/t2t-datagen
+++ b/tensor2tensor/bin/t2t-datagen
@@ -118,7 +118,7 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: librispeech.librispeech_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, True),
         lambda: librispeech.librispeech_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False)),
+            FLAGS.data_dir, FLAGS.tmp_dir, False)),
 }
 
 # pylint: enable=g-long-lambda

From 5365113cc17db280974f7c80e8c6847aec235fe8 Mon Sep 17 00:00:00 2001
From: wingsbr
Date: Mon, 20 Nov 2017 16:32:52 -0600
Subject: [PATCH 3/6] Expanded to include librispeech Problem and Modality.
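
This change registers a Librispeech Problem, a character-level
LibrispeechTextEncoder, and a LibrispeechModality whose bottom()
turns raw waveform samples into mel spectrograms with tf.contrib.signal
before compressing them with strided separable-convolution blocks.
As a minimal standalone sketch of that spectrogram front end (parameter
values are the ones used in the modality below; the placeholder shape is
illustrative only and not part of this patch):

    import tensorflow as tf

    # Waveforms scaled to [-1, 1], shape [batch, num_samples].
    signals = tf.placeholder(tf.float32, [None, None])
    # Complex STFT; 1024 // 2 + 1 = 513 unique FFT bins per frame.
    stfts = tf.contrib.signal.stft(
        signals, frame_length=1024, frame_step=512, fft_length=1024)
    magnitude_spectrograms = tf.abs(stfts)  # [batch, frames, 513]
    # Warp the 513 linear-frequency bins onto 64 mel bins
    # (80-7600 Hz at a 16 kHz sample rate).
    linear_to_mel = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins=64, num_spectrogram_bins=513, sample_rate=16000,
        lower_edge_hertz=80.0, upper_edge_hertz=7600.0)
    mel_spectrograms = tf.tensordot(magnitude_spectrograms, linear_to_mel, 1)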
--- tensor2tensor/bin/t2t-datagen | 8 +- tensor2tensor/data_generators/librispeech.py | 294 ++++++++++++++++--- 2 files changed, 254 insertions(+), 48 deletions(-) mode change 100644 => 100755 tensor2tensor/bin/t2t-datagen diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen old mode 100644 new mode 100755 index e9eca3672..67890371b --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -44,7 +44,6 @@ from tensor2tensor.data_generators import audio from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import snli from tensor2tensor.data_generators import wsj_parsing -from tensor2tensor.data_generators import librispeech from tensor2tensor.utils import registry from tensor2tensor.utils import usr_dir @@ -113,12 +112,7 @@ _SUPPORTED_PROBLEM_GENERATORS = { vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15), lambda: audio.timit_generator( FLAGS.data_dir, FLAGS.tmp_dir, False, 626, - vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)), - "librispeech": ( - lambda: librispeech.librispeech_generator( - FLAGS.data_dir, FLAGS.tmp_dir, True), - lambda: librispeech.librispeech_generator( - FLAGS.data_dir, FLAGS.tmp_dir, False)), + vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)), } # pylint: enable=g-long-lambda diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py index 82b032c35..dcb5b3f88 100644 --- a/tensor2tensor/data_generators/librispeech.py +++ b/tensor2tensor/data_generators/librispeech.py @@ -1,12 +1,20 @@ +from tensor2tensor.data_generators import problem +from tensor2tensor.utils import registry +from tensor2tensor.models import transformer +from tensor2tensor.utils import modality +from tensor2tensor.layers import common_layers +from tensor2tensor.data_generators import text_encoder +import random +import tensorflow as tf +import numpy as np +from tensor2tensor.data_generators import generator_utils import os from subprocess import call import tarfile import wave -import numpy as np -import six -from tensor2tensor.data_generators import generator_utils + -_LIBRISPEECH_TRAIN_DATASETS = [ +'''_LIBRISPEECH_TRAIN_DATASETS = [ [ "http://www.openslr.org/resources/12/train-clean-100.tar.gz", # pylint: disable=line-too-long "train-clean-100" @@ -29,6 +37,18 @@ "http://www.openslr.org/resources/12/dev-other.tar.gz", "dev-other" ], +]''' +_LIBRISPEECH_TRAIN_DATASETS = [ + [ + "http://www.openslr.org/resources/12/dev-other.tar.gz", + "dev-other" + ], +] +_LIBRISPEECH_TEST_DATASETS = [ + [ + "http://www.openslr.org/resources/12/dev-clean.tar.gz", + "dev-clean" + ], ] @@ -69,41 +89,233 @@ def _get_audio_data(filepath): data = np.fromstring(byte_array, np.uint8).tolist() return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels() - - -def librispeech_generator(data_dir, tmp_dir, training, eos_list=None, start_from=0, how_many=0): - eos_list = [1] if eos_list is None else eos_list - datasets = (_LIBRISPEECH_TRAIN_DATASETS if training else _LIBRISPEECH_TEST_DATASETS) - i = 0 - for url, subdir in datasets: - filename = os.path.basename(url) - compressed_file = generator_utils.maybe_download(tmp_dir, filename, url) - - read_type = "r:gz" if filename.endswith("tgz") else "r" - with tarfile.open(compressed_file, read_type) as corpus_tar: - # Create a subset of files that don't already exist. 
-      # tarfile.extractall errors when encountering an existing file
-      # and tarfile.extract is extremely slow
-      members = []
-      for f in corpus_tar:
-        if not os.path.isfile(os.path.join(tmp_dir, f.name)):
-          members.append(f)
-      corpus_tar.extractall(tmp_dir, members=members)
+
+
+class LibrispeechTextEncoder(text_encoder.TextEncoder):
+
+  def encode(self, s):
+    return [ord[c] for c in s]
+
+  def decode(self, ids):
+    """Transform a sequence of int ids into a human-readable string.
+    EOS is not expected in ids.
+    Args:
+      ids: list of integers to be converted.
+    Returns:
+      s: human-readable string.
+    """
+    decoded_ids = []
+    for id_ in ids:
+      if 0 <= id_ < self._num_reserved_ids:
+        decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)])
+      else:
+        decoded_ids.append(id_)
+    return "".join([chr(d) for d in decoded_ids])
+
+
+
+@registry.register_audio_modality
+class LibrispeechModality(modality.Modality):
+  """Performs strided conv compressions for audio spectral data."""
+
+  def bottom(self, inputs):
+    """Transform input from data space to model space.
+    Args:
+      inputs: A Tensor with shape [batch, ...]
+    Returns:
+      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
+    """
+    with tf.variable_scope(self.name):
+      # TODO(aidangomez): Will need to sort out a better audio pipeline
+      def xnet_resblock(x, filters, res_relu, name):
+        with tf.variable_scope(name):
+          # We only stride along the length dimension to preserve the spectral
+          # bins (which are tiny in dimensionality relative to length)
+          y = common_layers.separable_conv_block(
+              x,
+              filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
+              first_relu=True,
+              padding="SAME",
+              force2d=True,
+              name="sep_conv_block")
+          y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
+          return y + common_layers.conv_block(
+              x,
+              filters, [((1, 1), (1, 1))],
+              padding="SAME",
+              strides=(2, 1),
+              first_relu=res_relu,
+              force2d=True,
+              name="res_conv0")
+
+      # Rescale from UINT8 to floats in [-1, 1]
+      signals = (tf.to_float(inputs)-127)/128.
+      #signals = tf.contrib.framework.nest.flatten(signals)
+      signals = tf.squeeze(signals, [2, 3])
+
+      # `stfts` is a complex64 Tensor representing the Short-time Fourier Transform of
+      # each signal in `signals`. Its shape is [batch_size, ?, fft_unique_bins]
+      # where fft_unique_bins = fft_length // 2 + 1 = 513.
+      stfts = tf.contrib.signal.stft(signals, frame_length=1024, frame_step=512,
+                                     fft_length=1024)
+
+      # An energy spectrogram is the magnitude of the complex-valued STFT.
+      # A float32 Tensor of shape [batch_size, ?, 513].
+      magnitude_spectrograms = tf.abs(stfts)
+
+      log_offset = 1e-6
+      log_magnitude_spectrograms = tf.log(magnitude_spectrograms + log_offset)
+
+      # Warp the linear-scale, magnitude spectrograms into the mel-scale.
+      num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
+      lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
+      sample_rate = 16000
+      linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
+          num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
+          upper_edge_hertz)
+      mel_spectrograms = tf.tensordot(
+          magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
+      # Note: Shape inference for `tf.tensordot` does not currently handle this case.
+      mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
+          linear_to_mel_weight_matrix.shape[-1:]))
+
+      # Try without the conversion to MFCCs, first.
+      '''num_mfccs = 13
+      # Keep the first `num_mfccs` MFCCs.
+      mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(
+          log_mel_spectrograms)[..., :num_mfccs]'''
+
+      x = tf.expand_dims(mel_spectrograms, 2)
+      x.set_shape([None, None, None, num_mel_bins])
+      for i in xrange(self._model_hparams.audio_compression):
+        x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
+      return xnet_resblock(x, self._body_input_depth, False,
+                           "compress_block_final")
+
+
+@registry.register_problem()
+class Librispeech(problem.Problem):
+  """Problem spec for LibriSpeech speech recognition (audio to characters)."""
+
+  @property
+  def is_character_level(self):
+    return True
+
+  @property
+  def input_space_id(self):
+    return problem.SpaceID.AUDIO_SPECTRAL
+
+  @property
+  def target_space_id(self):
+    return problem.SpaceID.EN_CHR
+
+  @property
+  def num_shards(self):
+    return 10
+
+  @property
+  def use_subword_tokenizer(self):
+    return False
+
+  @property
+  def num_dev_shards(self):
+    return 1
+
+  @property
+  def use_train_shards_for_dev(self):
+    """If true, we only generate training data and hold out shards for dev."""
+    return False
+
+  def feature_encoders(self, data_dir):
+    return {
+        "inputs": text_encoder.TextEncoder(), #None, #DoNothingEncoder(),
+        "targets": LibrispeechTextEncoder(),
+    }
+
+  def example_reading_spec(self):
+    data_fields = {
+        "inputs": tf.VarLenFeature(tf.int64),
+        #"audio/channel_count": tf.FixedLenFeature([], tf.int64),
+        #"audio/sample_count": tf.FixedLenFeature([], tf.int64),
+        #"audio/sample_width": tf.FixedLenFeature([], tf.int64),
+        "targets": tf.VarLenFeature(tf.int64),
+    }
+    data_items_to_decoders = None
+    return (data_fields, data_items_to_decoders)
+
+
+  def generator(self, data_dir, tmp_dir, training, eos_list=None, start_from=0, how_many=0):
+    eos_list = [1]
+    datasets = (_LIBRISPEECH_TRAIN_DATASETS if training else _LIBRISPEECH_TEST_DATASETS)
+    i = 0
+    for url, subdir in datasets:
+      filename = os.path.basename(url)
+      compressed_file = generator_utils.maybe_download(tmp_dir, filename, url)
+
+      read_type = "r:gz" if filename.endswith("tgz") else "r"
+      with tarfile.open(compressed_file, read_type) as corpus_tar:
+        # Create a subset of files that don't already exist.
+ # tarfile.extractall errors when encountering an existing file + # and tarfile.extract is extremely slow + members = [] + for f in corpus_tar: + if not os.path.isfile(os.path.join(tmp_dir, f.name)): + members.append(f) + corpus_tar.extractall(tmp_dir, members=members) - data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir) - data_files = _collect_data(data_dir, "flac", "txt") - data_pairs = data_files.values() - for media_file, text_data in sorted(data_pairs)[start_from:]: - if how_many > 0 and i == how_many: - return - i += 1 - audio_data, sample_count, sample_width, num_channels = _get_audio_data( - media_file) - label = [ord(c) for c in text_data] + eos_list - yield { - "inputs": audio_data, - "audio/channel_count": [num_channels], - "audio/sample_count": [sample_count], - "audio/sample_width": [sample_width], - "targets": label - } \ No newline at end of file + data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir) + data_files = _collect_data(data_dir, "flac", "txt") + data_pairs = data_files.values() + for media_file, text_data in sorted(data_pairs)[start_from:]: + if how_many > 0 and i == how_many: + return + i += 1 + audio_data, sample_count, sample_width, num_channels = _get_audio_data( + media_file) + label = [ord(c) for c in text_data] + eos_list + yield { + "inputs": audio_data, + "audio/channel_count": [num_channels], + "audio/sample_count": [sample_count], + "audio/sample_width": [sample_width], + "targets": label + } + + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + train_paths = self.training_filepaths(data_dir, self.num_shards, shuffled=False) + dev_paths = self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False) + if self.use_train_shards_for_dev: + all_paths = train_paths + dev_paths + generator_utils.generate_files(self.generator(data_dir, tmp_dir, True), all_paths) + generator_utils.shuffle_dataset(all_paths) + else: + generator_utils.generate_dataset_and_shuffle( + self.generator(data_dir, tmp_dir, True), train_paths, + self.generator(data_dir, tmp_dir, False), dev_paths) + + + def hparams(self, defaults, unused_model_hparams): + p = defaults + p.stop_at_eos = int(False) + p.input_modality = { "inputs": ("audio:librispeech_modality", None) } + p.target_modality = (registry.Modalities.SYMBOL, 256) + + def preprocess_example(self, example, mode, hparams): + return example + +# TODO: clean up hparams +@registry.register_hparams +def librispeech_hparams(): + hparams = transformer.transformer_base_single_gpu() # Or whatever you'd like to build off. 
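+  # Note: raw LibriSpeech waveforms are long (hundreds of thousands of
+  # samples at 16kHz), which is why max_input_seq_length below is so much
+  # larger than max_target_seq_length, and why the modality above compresses
+  # the input through audio_compression strided blocks.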
+ hparams.batch_size = 36 + hparams.audio_compression = 8 + hparams.hidden_size = 2048 + hparams.max_input_seq_length = 600000 + hparams.max_target_seq_length = 350 + hparams.max_length = hparams.max_input_seq_length + hparams.min_length_bucket = hparams.max_input_seq_length // 2 + hparams.learning_rate = 0.05 + hparams.train_steps = 5000000 + hparams.num_hidden_layers = 4 + return hparams From 844df4d0172b3df5fac50dd364b15dc08b6a393f Mon Sep 17 00:00:00 2001 From: wingsbr Date: Mon, 20 Nov 2017 16:36:01 -0600 Subject: [PATCH 4/6] Added librispeech to data_generators/all_problems.py --- tensor2tensor/data_generators/all_problems.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index c7f364cf1..2aca3d377 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -28,6 +28,7 @@ from tensor2tensor.data_generators import ice_parsing from tensor2tensor.data_generators import image from tensor2tensor.data_generators import imdb +from tensor2tensor.data_generators import librispeech from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import multinli from tensor2tensor.data_generators import problem_hparams From 98c7b413e3fa6a18faf262c30b3eed3a9359d085 Mon Sep 17 00:00:00 2001 From: wingsbr Date: Tue, 21 Nov 2017 09:07:49 -0600 Subject: [PATCH 5/6] Switched to full librispeech datasets. --- tensor2tensor/data_generators/librispeech.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py index dcb5b3f88..5e83cfd51 100644 --- a/tensor2tensor/data_generators/librispeech.py +++ b/tensor2tensor/data_generators/librispeech.py @@ -14,7 +14,7 @@ import wave -'''_LIBRISPEECH_TRAIN_DATASETS = [ +_LIBRISPEECH_TRAIN_DATASETS = [ [ "http://www.openslr.org/resources/12/train-clean-100.tar.gz", # pylint: disable=line-too-long "train-clean-100" @@ -37,18 +37,6 @@ "http://www.openslr.org/resources/12/dev-other.tar.gz", "dev-other" ], -]''' -_LIBRISPEECH_TRAIN_DATASETS = [ - [ - "http://www.openslr.org/resources/12/dev-other.tar.gz", - "dev-other" - ], -] -_LIBRISPEECH_TEST_DATASETS = [ - [ - "http://www.openslr.org/resources/12/dev-clean.tar.gz", - "dev-clean" - ], ] From 23129f238b5abeecca38790215e272b31913cdb5 Mon Sep 17 00:00:00 2001 From: wingsbr Date: Tue, 21 Nov 2017 16:25:58 -0600 Subject: [PATCH 6/6] Variety of fixes based on PR comments. --- tensor2tensor/data_generators/librispeech.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py index 5e83cfd51..de7ed94cc 100644 --- a/tensor2tensor/data_generators/librispeech.py +++ b/tensor2tensor/data_generators/librispeech.py @@ -82,7 +82,7 @@ def _get_audio_data(filepath): class LibrispeechTextEncoder(text_encoder.TextEncoder): def encode(self, s): - return [ord[c] for c in s] + return [self._num_reserved_ids + ord(c) for c in s] def decode(self, ids): """Transform a sequence of int ids into a human-readable string. 
@@ -97,7 +97,7 @@ def decode(self, ids):
       if 0 <= id_ < self._num_reserved_ids:
         decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)])
       else:
-        decoded_ids.append(id_)
+        decoded_ids.append(id_ - self._num_reserved_ids)
     return "".join([chr(d) for d in decoded_ids])
 
 
@@ -199,7 +199,7 @@ def target_space_id(self):
 
   @property
   def num_shards(self):
-    return 10
+    return 100
 
   @property
   def use_subword_tokenizer(self):
@@ -214,9 +214,9 @@ def use_train_shards_for_dev(self):
     """If true, we only generate training data and hold out shards for dev."""
     return False
 
-  def feature_encoders(self, data_dir):
+  def feature_encoders(self, _):
     return {
-        "inputs": text_encoder.TextEncoder(), #None, #DoNothingEncoder(),
+        "inputs": text_encoder.TextEncoder(),
         "targets": LibrispeechTextEncoder(),
     }
 
@@ -233,8 +233,9 @@ def example_reading_spec(self):
 
 
   def generator(self, data_dir, tmp_dir, training, eos_list=None, start_from=0, how_many=0):
-    eos_list = [1]
+    eos_list = [1] if eos_list is None else eos_list
     datasets = (_LIBRISPEECH_TRAIN_DATASETS if training else _LIBRISPEECH_TEST_DATASETS)
+    num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
     i = 0
     for url, subdir in datasets:
       filename = os.path.basename(url)
@@ -260,7 +261,7 @@ def generator(self, data_dir, tmp_dir, training, eos_list=None, start_from=0, ho
       i += 1
       audio_data, sample_count, sample_width, num_channels = _get_audio_data(
           media_file)
-      label = [ord(c) for c in text_data] + eos_list
+      label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
       yield {
           "inputs": audio_data,
          "audio/channel_count": [num_channels],