-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
train_with_whisper.py
329 lines (269 loc) · 11.4 KB
/
train_with_whisper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
#!/usr/bin/env python3
"""Recipe for training a Whisper-based ASR system with LibriSpeech.

The system employs Whisper from OpenAI (https://cdn.openai.com/papers/whisper.pdf).
This recipe fine-tunes the full Whisper encoder-decoder with an NLL loss.
If you want to use only the Whisper encoder, please refer to the recipe
speechbrain/recipes/LibriSpeech/ASR/CTC/train_with_whisper.py

To run this recipe, do the following:
> python train_with_whisper.py hparams/train_hf_whisper.yaml

Authors
 * Adel Moumen 2022
 * Titouan Parcollet 2022
"""
import os
import sys
import torch
import logging
import speechbrain as sb
from speechbrain.utils.distributed import run_on_main
from speechbrain.utils.data_utils import undo_padding
from hyperpyyaml import load_hyperpyyaml
from pathlib import Path
logger = logging.getLogger(__name__)
# Define training procedure
class ASR(sb.Brain):
    """Brain subclass that fine-tunes a Whisper encoder-decoder with an NLL loss.

    Validation hypotheses come from ``hparams.valid_greedy_searcher`` and test
    hypotheses from ``hparams.test_beam_searcher``; WER and CER are tracked on
    every non-training stage.

    NOTE: this class reads ``self.tokenizer`` but never sets it, so the
    tokenizer must be assigned on the instance before training/evaluation.
    """

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities.

        Arguments
        ---------
        batch : PaddedBatch
            Batch providing ``sig`` (waveforms, relative lengths) and
            ``tokens_bos`` (decoder input tokens, relative lengths).
        stage : sb.Stage
            Current stage; enables augmentation on TRAIN and selects the
            searcher used on VALID / TEST.

        Returns
        -------
        log_probs : torch.Tensor
            ``hparams.log_softmax`` applied to the decoder logits.
        hyps : object or None
            Searcher output on VALID / TEST, ``None`` on TRAIN.
        wav_lens : torch.Tensor
            Relative waveform lengths, passed through unchanged.
        """
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        bos_tokens, bos_tokens_lens = batch.tokens_bos

        # Add augmentation if specified
        if stage == sb.Stage.TRAIN:
            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        # We compute the padding mask and replace the padded values with the
        # pad_token_id that the Whisper decoder expects to see.
        # Relative lengths are floats: round (as speechbrain's undo_padding
        # does) rather than truncate, otherwise a value such as 2.9999998
        # becomes 2 and the last real token would be overwritten by padding.
        max_len = bos_tokens.shape[1]
        abs_tokens_lens = torch.round(bos_tokens_lens * max_len).long()
        # Build the mask over the actual padded width so its shape always
        # matches `bos_tokens`.
        pad_mask = (
            torch.arange(max_len, device=bos_tokens.device)[None, :]
            < abs_tokens_lens[:, None]
        )
        bos_tokens[~pad_mask] = self.tokenizer.pad_token_id

        # Forward encoder + decoder
        enc_out, logits, _ = self.modules.whisper(wavs, bos_tokens)
        log_probs = self.hparams.log_softmax(logits)

        hyps = None
        if stage == sb.Stage.VALID:
            hyps, _ = self.hparams.valid_greedy_searcher(enc_out, wav_lens)
        elif stage == sb.Stage.TEST:
            hyps, _ = self.hparams.test_beam_searcher(enc_out, wav_lens)

        return log_probs, hyps, wav_lens

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss NLL given predictions and targets.

        On non-training stages it also decodes hypotheses and targets to word
        lists and appends them to the WER and CER metrics.

        Arguments
        ---------
        predictions : tuple
            ``(log_probs, hyps, wav_lens)`` as returned by ``compute_forward``.
        batch : PaddedBatch
            Batch providing ``id``, ``tokens_eos`` and ``tokens``.
        stage : sb.Stage
            Current stage.

        Returns
        -------
        torch.Tensor
            The value returned by ``hparams.nll_loss``.
        """
        log_probs, hyps, wav_lens, = predictions
        batch = batch.to(self.device)
        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos

        loss = self.hparams.nll_loss(
            log_probs, tokens_eos, length=tokens_eos_lens,
        )

        if stage != sb.Stage.TRAIN:
            tokens, tokens_lens = batch.tokens

            # Decode token terms to words
            predicted_words = self.tokenizer.batch_decode(
                hyps, skip_special_tokens=True
            )

            # Convert indices to words
            target_words = undo_padding(tokens, tokens_lens)
            target_words = self.tokenizer.batch_decode(
                target_words, skip_special_tokens=True
            )

            # Honour the flag's value, not just its presence, so that
            # `normalized_transcripts: False` really disables normalization.
            if getattr(self.hparams, "normalized_transcripts", False):
                predicted_words = [
                    self.tokenizer._normalize(text).split(" ")
                    for text in predicted_words
                ]
                target_words = [
                    self.tokenizer._normalize(text).split(" ")
                    for text in target_words
                ]
            else:
                predicted_words = [text.split(" ") for text in predicted_words]
                target_words = [text.split(" ") for text in target_words]

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch.

        Creates fresh CER / WER metric trackers for non-training stages.
        """
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch.

        TRAIN: stores the stage stats for later logging.
        VALID: anneals the learning rate on the validation loss, logs stats
        and saves a checkpoint, keeping the one with the lowest WER.
        TEST: logs stats and writes detailed WER statistics to
        ``hparams.wer_file``.
        """
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr_whisper, new_lr_whisper = self.hparams.lr_annealing_whisper(
                stage_stats["loss"]
            )
            sb.nnet.schedulers.update_learning_rate(
                self.optimizer, new_lr_whisper
            )
            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr_whisper": old_lr_whisper},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            # Transcripts may contain non-ASCII words; don't depend on the
            # platform's default encoding.
            with open(self.hparams.wer_file, "w", encoding="utf-8") as w:
                self.wer_metric.write_stats(w)
def dataio_prepare(hparams, tokenizer):
    """This function prepares the datasets to be used in the brain class.

    It also defines the data processing pipeline through user-defined functions.

    Arguments
    ---------
    hparams : dict
        Loaded hyperparameters. Reads ``data_folder``, ``train_csv``,
        ``valid_csv``, ``test_csv`` (iterable of CSV paths), ``sorting``,
        ``bos_index`` and ``eos_index``. When ``sorting`` is ``"ascending"``
        or ``"descending"``, ``hparams["train_loader_kwargs"]["shuffle"]`` is
        set to ``False`` in place.
    tokenizer : object
        Whisper tokenizer; only its ``encode`` method is used here.

    Returns
    -------
    train_data : DynamicItemDataset
    valid_data : DynamicItemDataset
        Sorted by ascending duration.
    test_datasets : dict
        Maps each test CSV's file stem to its dataset, sorted by ascending
        duration.

    Raises
    ------
    NotImplementedError
        If ``hparams["sorting"]`` is not ``random``, ``ascending`` or
        ``descending``.
    """
    data_folder = hparams["data_folder"]

    # 1. Define datasets:
    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
    )

    sorting = hparams["sorting"]
    if sorting in ("ascending", "descending"):
        # we sort training data to speed up training and get better results.
        train_data = train_data.filtered_sorted(
            sort_key="duration", reverse=sorting == "descending"
        )
        # when sorting do not shuffle in dataloader ! otherwise is pointless
        hparams["train_loader_kwargs"]["shuffle"] = False
    elif sorting != "random":
        raise NotImplementedError(
            "sorting must be random, ascending or descending"
        )

    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
    )
    valid_data = valid_data.filtered_sorted(sort_key="duration")

    # test is separate
    test_datasets = {}
    for csv_file in hparams["test_csv"]:
        name = Path(csv_file).stem
        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=csv_file, replacements={"data_root": data_folder}
        )
        test_datasets[name] = test_datasets[name].filtered_sorted(
            sort_key="duration"
        )

    datasets = [train_data, valid_data, *test_datasets.values()]

    # 2. Define audio pipeline:
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        sig = sb.dataio.dataio.read_audio(wav)
        return sig

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)

    # 3. Define text pipeline:
    @sb.utils.data_pipeline.takes("wrd")
    @sb.utils.data_pipeline.provides(
        "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
    )
    def text_pipeline(wrd):
        yield wrd
        tokens_list = tokenizer.encode(wrd)
        # avoid bos and eos tokens: drop the first and last ids added by
        # `encode`.
        tokens_list = tokens_list[1:-1]
        yield tokens_list
        # Decoder input: bos followed by the tokens.
        tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
        yield tokens_bos
        # Decoder target: the tokens followed by eos.
        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
        yield tokens_eos
        tokens = torch.LongTensor(tokens_list)
        yield tokens

    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)

    # 4. Set output:
    sb.dataio.dataset.set_output_keys(
        datasets,
        ["id", "sig", "tokens_list", "tokens_bos", "tokens_eos", "tokens"],
    )

    return train_data, valid_data, test_datasets
if __name__ == "__main__":
    # CLI:
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # If distributed_launch=True then
    # create ddp_group with the right communication protocol
    sb.utils.distributed.ddp_init_group(run_opts)

    # Load hyperparameters (with CLI overrides); be explicit about the encoding
    # so the YAML parses the same way on every platform.
    with open(hparams_file, encoding="utf-8") as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Dataset prep (parsing Librispeech)
    from librispeech_prepare import prepare_librispeech  # noqa

    # multi-gpu (ddp) save data preparation
    run_on_main(
        prepare_librispeech,
        kwargs={
            "data_folder": hparams["data_folder"],
            "tr_splits": hparams["train_splits"],
            "dev_splits": hparams["dev_splits"],
            "te_splits": hparams["test_splits"],
            "save_folder": hparams["save_folder"],
            "merge_lst": hparams["train_splits"],
            "merge_name": "train.csv",
            "skip_prep": hparams["skip_prep"],
        },
    )

    # Defining tokenizer and loading it
    tokenizer = hparams["whisper"].tokenizer
    tokenizer.set_prefix_tokens(hparams["language"], "transcribe", False)

    # We need to prepare the tokens for the searchers: both receive the full
    # prefix as decoder input, and the prefix entry at index 1 as the
    # language token.
    for searcher in (
        hparams["valid_greedy_searcher"],
        hparams["test_beam_searcher"],
    ):
        searcher.set_decoder_input_tokens(tokenizer.prefix_tokens)
        searcher.set_language_token(tokenizer.prefix_tokens[1])

    # here we create the datasets objects as well as tokenization and encoding
    train_data, valid_data, test_datasets = dataio_prepare(hparams, tokenizer)

    # Trainer initialization
    asr_brain = ASR(
        modules=hparams["modules"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
        opt_class=hparams["whisper_opt_class"],
    )

    # We load the pretrained whisper model
    if "pretrainer" in hparams:
        run_on_main(hparams["pretrainer"].collect_files)
        hparams["pretrainer"].load_collected(asr_brain.device)

    # We dynamically add the tokenizer to our brain class.
    # NB: This tokenizer corresponds to the one used for Whisper.
    asr_brain.tokenizer = tokenizer

    # Training
    asr_brain.fit(
        asr_brain.hparams.epoch_counter,
        train_data,
        valid_data,
        train_loader_kwargs=hparams["train_loader_kwargs"],
        valid_loader_kwargs=hparams["valid_loader_kwargs"],
    )

    # Testing: one WER report per test set (keys are test_clean, test_other, ...)
    for name, test_data in test_datasets.items():
        asr_brain.hparams.wer_file = os.path.join(
            hparams["output_folder"], f"wer_{name}.txt"
        )
        asr_brain.evaluate(
            test_data, test_loader_kwargs=hparams["test_loader_kwargs"]
        )