#!/usr/bin/env python3
"""Recipe for training a sequence-to-sequence ASR system with librispeech.
The system employs an encoder, a decoder, and an attention mechanism
between them. Decoding is performed with beamsearch coupled with a neural
language model.
To run this recipe, do the following:
> python train.py hparams/train_BPE1000.yaml
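Any parameter defined in the YAML file can also be overridden from the command
line (the data path below is only an illustrative example):
> python train.py hparams/train_BPE1000.yaml --data_folder=/path/to/LibriSpeech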
With the default hyperparameters, the system employs a CRDNN encoder.
The decoder is based on a standard GRU. Beam search coupled with an RNN
language model is applied on top of the decoder probabilities.
The neural network is trained with both CTC and negative log-likelihood
targets, and sub-word units estimated with Byte Pair Encoding (BPE)
are used as basic recognition tokens. Training is performed on the full
LibriSpeech dataset (960 h).
The experiment file is flexible enough to support a large variety of
different systems. By properly changing the parameter files, you can try
different encoders, decoders, tokens (e.g., characters instead of BPE),
training splits (e.g., train-clean-100 rather than the full one), and many
other possible variations.
This recipe assumes that the tokenizer and the LM are already trained.
To avoid token mismatches, the tokenizer used for the acoustic model is
the same one used for the LM. The recipe downloads the pre-trained
tokenizer and LM.
If you would like to train a full system from scratch, do the following:
1- Train a tokenizer (see ../../Tokenizer)
2- Train a language model (see ../../LM)
3- Train the acoustic model (with this code).
Authors
* Ju-Chieh Chou 2020
* Mirco Ravanelli 2020
* Abdel Heba 2020
* Peter Plantinga 2020
* Samuele Cornell 2020
* Andreas Nautsch 2021
"""
import logging
import sys
from pathlib import Path
import torch
from hyperpyyaml import load_hyperpyyaml
import speechbrain as sb
from speechbrain.utils.distributed import if_main_process, run_on_main
logger = logging.getLogger(__name__)
# Define training procedure
class ASR(sb.Brain):
def compute_forward(self, batch, stage):
"""Forward computations from the waveform batches to the output probabilities."""
batch = batch.to(self.device)
wavs, wav_lens = batch.sig
tokens_bos, _ = batch.tokens_bos
wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
# Add waveform augmentation if specified.
if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
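# Augmentation may replicate utterances along the batch dimension, so the
# BOS targets are replicated to stay aligned with the augmented waveforms.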
tokens_bos = self.hparams.wav_augment.replicate_labels(tokens_bos)
# Forward pass
feats = self.hparams.compute_features(wavs)
feats = self.modules.normalize(feats, wav_lens)
x = self.modules.enc(feats.detach())
e_in = self.modules.emb(tokens_bos) # y_in bos + tokens
h, _ = self.modules.dec(e_in, x, wav_lens)
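# The GRU decoder attends over the encoder representations x and produces
# one hidden state per target token step.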
# Output layer for seq2seq log-probabilities
logits = self.modules.seq_lin(h)
p_seq = self.hparams.log_softmax(logits)
# Compute outputs
p_ctc, p_tokens = None, None
if stage == sb.Stage.TRAIN:
current_epoch = self.hparams.epoch_counter.current
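# The CTC head is only used during the first 'number_of_ctc_epochs' epochs,
# as an auxiliary objective computed on the encoder output.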
if current_epoch <= self.hparams.number_of_ctc_epochs:
# Output layer for ctc log-probabilities
logits = self.modules.ctc_lin(x)
p_ctc = self.hparams.log_softmax(logits)
else:
if stage == sb.Stage.VALID:
# Get token strings from index prediction
p_tokens, _, _, _ = self.hparams.valid_search(x, wav_lens)
else:
p_tokens, _, _, _ = self.hparams.test_search(x, wav_lens)
return p_ctc, p_seq, wav_lens, p_tokens
def compute_objectives(self, predictions, batch, stage):
"""Computes the loss (CTC+NLL) given predictions and targets."""
current_epoch = self.hparams.epoch_counter.current
p_ctc, p_seq, wav_lens, predicted_tokens = predictions
ids = batch.id
tokens_eos, tokens_eos_lens = batch.tokens_eos
tokens, tokens_lens = batch.tokens
# Labels must be extended if parallel augmentation or concatenated
# augmentation was performed on the input (increasing the time dimension)
if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
(
tokens,
tokens_lens,
tokens_eos,
tokens_eos_lens,
) = self.hparams.wav_augment.replicate_multiple_labels(
tokens, tokens_lens, tokens_eos, tokens_eos_lens
)
loss_seq = self.hparams.seq_cost(
p_seq, tokens_eos, length=tokens_eos_lens
)
# Add ctc loss if necessary
if (
stage == sb.Stage.TRAIN
and current_epoch <= self.hparams.number_of_ctc_epochs
):
loss_ctc = self.hparams.ctc_cost(
p_ctc, tokens, wav_lens, tokens_lens
)
loss = self.hparams.ctc_weight * loss_ctc
loss += (1 - self.hparams.ctc_weight) * loss_seq
else:
loss = loss_seq
if stage != sb.Stage.TRAIN:
# Decode token terms to words
predicted_words = [
self.tokenizer.decode_ids(utt_seq).split(" ")
for utt_seq in predicted_tokens
]
target_words = [wrd.split(" ") for wrd in batch.wrd]
self.wer_metric.append(ids, predicted_words, target_words)
self.cer_metric.append(ids, predicted_words, target_words)
return loss
def on_stage_start(self, stage, epoch):
"""Gets called at the beginning of each epoch"""
if stage != sb.Stage.TRAIN:
self.cer_metric = self.hparams.cer_computer()
self.wer_metric = self.hparams.error_rate_computer()
def on_stage_end(self, stage, stage_loss, epoch):
"""Gets called at the end of a epoch."""
# Compute/store important stats
stage_stats = {"loss": stage_loss}
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats
else:
stage_stats["CER"] = self.cer_metric.summarize("error_rate")
stage_stats["WER"] = self.wer_metric.summarize("error_rate")
# Perform end-of-iteration things, like annealing, logging, etc.
if stage == sb.Stage.VALID:
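# Anneal the learning rate based on the validation WER, log the epoch
# statistics, and keep only the checkpoint with the lowest WER.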
old_lr, new_lr = self.hparams.lr_annealing(stage_stats["WER"])
sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
self.hparams.train_logger.log_stats(
stats_meta={"epoch": epoch, "lr": old_lr},
train_stats=self.train_stats,
valid_stats=stage_stats,
)
self.checkpointer.save_and_keep_only(
meta={"WER": stage_stats["WER"]},
min_keys=["WER"],
)
elif stage == sb.Stage.TEST:
self.hparams.train_logger.log_stats(
stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
test_stats=stage_stats,
)
if if_main_process():
with open(self.hparams.test_wer_file, "w") as w:
self.wer_metric.write_stats(w)
def dataio_prepare(hparams):
"""This function prepares the datasets to be used in the brain class.
It also defines the data processing pipeline through user-defined functions.
"""
data_folder = hparams["data_folder"]
train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=hparams["train_csv"],
replacements={"data_root": data_folder},
)
if hparams["sorting"] == "ascending":
# we sort training data to speed up training and get better results.
train_data = train_data.filtered_sorted(sort_key="duration")
# when sorting, do not shuffle in the dataloader! Otherwise sorting is pointless.
hparams["train_dataloader_opts"]["shuffle"] = False
elif hparams["sorting"] == "descending":
train_data = train_data.filtered_sorted(
sort_key="duration", reverse=True
)
# when sorting, do not shuffle in the dataloader! Otherwise sorting is pointless.
hparams["train_dataloader_opts"]["shuffle"] = False
elif hparams["sorting"] == "random":
pass
else:
raise NotImplementedError(
"sorting must be random, ascending or descending"
)
valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=hparams["valid_csv"],
replacements={"data_root": data_folder},
)
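# Validation (and test) data are sorted by duration so that padding within
# each batch is reduced.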
valid_data = valid_data.filtered_sorted(sort_key="duration")
# test is separate
test_datasets = {}
for csv_file in hparams["test_csv"]:
name = Path(csv_file).stem
test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=csv_file, replacements={"data_root": data_folder}
)
test_datasets[name] = test_datasets[name].filtered_sorted(
sort_key="duration"
)
datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
# We get the tokenizer as we need it to encode the labels when creating
# mini-batches.
tokenizer = hparams["tokenizer"]
# 2. Define audio pipeline:
@sb.utils.data_pipeline.takes("wav")
@sb.utils.data_pipeline.provides("sig")
def audio_pipeline(wav):
sig = sb.dataio.dataio.read_audio(wav)
return sig
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
# 3. Define text pipeline:
@sb.utils.data_pipeline.takes("wrd")
@sb.utils.data_pipeline.provides(
"wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
)
def text_pipeline(wrd):
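# Each yield provides one of the keys declared above, in order: the raw
# transcript, the token id list, the BOS-prepended and EOS-appended token
# tensors, and the plain token tensor.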
yield wrd
tokens_list = tokenizer.encode_as_ids(wrd)
yield tokens_list
tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
yield tokens_bos
tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
yield tokens_eos
tokens = torch.LongTensor(tokens_list)
yield tokens
sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
# 4. Set output:
sb.dataio.dataset.set_output_keys(
datasets,
["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"],
)
train_batch_sampler = None
valid_batch_sampler = None
if hparams["dynamic_batching"]:
from speechbrain.dataio.batch import PaddedBatch # noqa
from speechbrain.dataio.dataloader import SaveableDataLoader # noqa
from speechbrain.dataio.sampler import DynamicBatchSampler # noqa
dynamic_hparams = hparams["dynamic_batch_sampler"]
hop_size = hparams["feats_hop_size"]
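# The sampler groups utterances by total length; the duration (in seconds)
# is converted to an approximate number of feature frames via the hop size.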
train_batch_sampler = DynamicBatchSampler(
train_data,
length_func=lambda x: x["duration"] * (1 / hop_size),
**dynamic_hparams,
)
valid_batch_sampler = DynamicBatchSampler(
valid_data,
length_func=lambda x: x["duration"] * (1 / hop_size),
**dynamic_hparams,
)
return (
train_data,
valid_data,
test_datasets,
train_batch_sampler,
valid_batch_sampler,
)
if __name__ == "__main__":
# CLI:
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
# create ddp_group with the right communication protocol
sb.utils.distributed.ddp_init_group(run_opts)
with open(hparams_file) as fin:
hparams = load_hyperpyyaml(fin, overrides)
# Create experiment directory
sb.create_experiment_directory(
experiment_directory=hparams["output_folder"],
hyperparams_to_save=hparams_file,
overrides=overrides,
)
# Dataset prep (parsing Librispeech)
from librispeech_prepare import prepare_librispeech # noqa
# Multi-GPU (DDP)-safe data preparation: run it on the main process only
run_on_main(
prepare_librispeech,
kwargs={
"data_folder": hparams["data_folder"],
"tr_splits": hparams["train_splits"],
"dev_splits": hparams["dev_splits"],
"te_splits": hparams["test_splits"],
"save_folder": hparams["output_folder"],
"merge_lst": hparams["train_splits"],
"merge_name": "train.csv",
"skip_prep": hparams["skip_prep"],
},
)
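# Prepare the noise data referenced in the YAML (typically used by the
# waveform augmentation); this is also restricted to the main process.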
run_on_main(hparams["prepare_noise_data"])
# here we create the datasets objects as well as tokenization and encoding
(
train_data,
valid_data,
test_datasets,
train_bsampler,
valid_bsampler,
) = dataio_prepare(hparams)
# We download the pretrained LM from HuggingFace (or elsewhere depending on
# the path given in the YAML file). The tokenizer is loaded at the same time.
run_on_main(hparams["pretrainer"].collect_files)
hparams["pretrainer"].load_collected()
# Trainer initialization
asr_brain = ASR(
modules=hparams["modules"],
opt_class=hparams["opt_class"],
hparams=hparams,
run_opts=run_opts,
checkpointer=hparams["checkpointer"],
)
# We dynamically add the tokenizer to our brain class.
# NB: This tokenizer corresponds to the one used for the LM!!
asr_brain.tokenizer = hparams["tokenizer"]
train_dataloader_opts = hparams["train_dataloader_opts"]
valid_dataloader_opts = hparams["valid_dataloader_opts"]
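# When dynamic batching is enabled, the samplers returned by dataio_prepare
# take over batch formation, so the static dataloader options are replaced.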
if train_bsampler is not None:
train_dataloader_opts = {"batch_sampler": train_bsampler}
if valid_bsampler is not None:
valid_dataloader_opts = {"batch_sampler": valid_bsampler}
# Training
asr_brain.fit(
asr_brain.hparams.epoch_counter,
train_data,
valid_data,
train_loader_kwargs=train_dataloader_opts,
valid_loader_kwargs=valid_dataloader_opts,
)
import os
# Testing
if not os.path.exists(hparams["output_wer_folder"]):
os.makedirs(hparams["output_wer_folder"])
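# Each test set is evaluated with the checkpoint that achieved the lowest
# validation WER (min_key="WER"); the detailed WER report for each set is
# written to output_wer_folder.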
for k in test_datasets.keys(): # keys are test_clean, test_other etc
asr_brain.hparams.test_wer_file = os.path.join(
hparams["output_wer_folder"], f"wer_{k}.txt"
)
asr_brain.evaluate(
test_datasets[k],
test_loader_kwargs=hparams["test_dataloader_opts"],
min_key="WER",
)