-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
robust_asr_16k.yaml
199 lines (163 loc) · 5.99 KB
/
robust_asr_16k.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# Model: wav2vec2 + DNN + CTC
# Augmentation: SpecAugment
# Authors: Sangeet Sagar 2023
# ################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 8200
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/robust_asr/<seed>
test_wer_file: !ref <output_folder>/wer_test.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# URL for the biggest whisper model.
# NOTE: plain string — a bare `!ref` tag with no <...> reference is a no-op.
whisper_hub: openai/whisper-large-v2
whisper_folder: !ref <save_folder>/whisper_checkpoint
language: german

# Path to pre-trained models
pretrained_whisper_path: speechbrain/whisper_rescuespeech
pretrained_enhance_path: speechbrain/sepformer_rescuespeech

# Fine-tuning schedule: keep listed models frozen until `unfreeze_epoch`.
epochs_before_lr_drop: 2
unfreeze_epoch: !ref <epochs_before_lr_drop> + 1
frozen_models: [encoder, decoder, masknet, whisper]
unfrozen_models: [masknet, whisper]

# Dataset prep parameters
data_folder: !PLACEHOLDER
train_tsv_file: !ref <data_folder>/train.tsv
dev_tsv_file: !ref <data_folder>/dev.tsv
test_tsv_file: !ref <data_folder>/test.tsv
accented_letters: True
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev.csv
test_csv: !ref <output_folder>/test.csv
skip_prep: False

# We remove utterances longer than 10s in the train/dev/test sets as
# longer sentences certainly correspond to "open microphones".
avoid_if_longer_than: 10.0

## Model parameters - Enhance model
dereverberate: False
save_audio: True
enhance_sample_rate: 16000
limit_training_signal_len: False
training_signal_len: 64000
use_wavedrop: False
use_speedperturb: True
use_freq_domain: False
use_rand_shift: False
min_shift: -8000
max_shift: 8000

## Training parameters - ASR
do_augmentation: False
number_of_epochs: 10
lr_whisper: 0.00003
sorting: ascending
auto_mix_prec: False
asr_sample_rate: 16000
ckpt_interval_minutes: 30 # save checkpoint every N min
checkpoint_avg: 5

# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# Must be 6 per GPU to fit 16GB of VRAM
batch_size: 2
test_batch_size: 2

# These values are only used for the searchers.
# They need to be hardcoded and should not be changed with Whisper.
# They are used as part of the searching process.
# The bos token of the searcher will be timestamp_index
# and will be concatenated with the bos, language and task tokens.
timestamp_index: 50363
eos_index: 50257
bos_index: 50258

# Decoding parameters
min_decode_ratio: 0.0
max_decode_ratio: 1.0
test_beam_size: 8

# Whisper model parameters
freeze_whisper: False
freeze_encoder_only: False
freeze_encoder: True

# Dataloader options
train_loader_kwargs:
    batch_size: !ref <batch_size>

valid_loader_kwargs:
    batch_size: !ref <batch_size>

test_loader_kwargs:
    batch_size: !ref <test_batch_size>

# Loss weights (joint enhancement + ASR objective)
sepformer_weight: 0.1
asr_weight: 1

# Functions and classes
speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    perturb_prob: 1.0
    drop_freq_prob: 0.0
    drop_chunk_prob: 0.0
    sample_rate: !ref <enhance_sample_rate>
    speeds: [95, 100, 105]

# Enhancement model (Encoder / MaskNet / Decoder) pulled in from a shared yaml
enhance_model: !include:../models/sepformer.yaml

augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    sample_rate: !ref <asr_sample_rate>
    speeds: [95, 100, 105]

whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
    source: !ref <whisper_hub>
    freeze: !ref <freeze_whisper>
    save_path: !ref <whisper_folder>
    encoder_only: !ref <freeze_encoder_only>
    freeze_encoder: !ref <freeze_encoder>

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

nll_loss: !name:speechbrain.nnet.losses.nll_loss

whisper_opt_class: !name:torch.optim.AdamW
    lr: !ref <lr_whisper>
    weight_decay: 0.01

valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
    model: !ref <whisper>
    bos_index: !ref <timestamp_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>

test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
    module: [!ref <whisper>]
    bos_index: !ref <timestamp_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>

lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_whisper>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Enhance loss
enhance_loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

# Change the path to use a local model instead of the remote one
asr_pretrained: !new:speechbrain.utils.parameter_transfer.Pretrainer
    collect_in: !ref <save_folder>
    loadables:
        encoder: !ref <enhance_model[Encoder]>
        masknet: !ref <enhance_model[MaskNet]>
        decoder: !ref <enhance_model[Decoder]>
        whisper: !ref <whisper>
    paths:
        encoder: !ref <pretrained_enhance_path>/encoder.ckpt
        decoder: !ref <pretrained_enhance_path>/decoder.ckpt
        masknet: !ref <pretrained_enhance_path>/masknet.ckpt
        whisper: !ref <pretrained_whisper_path>/whisper.ckpt

modules:
    encoder: !ref <enhance_model[Encoder]>
    masknet: !ref <enhance_model[MaskNet]>
    decoder: !ref <enhance_model[Decoder]>
    whisper: !ref <whisper>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        encoder: !ref <enhance_model[Encoder]>
        decoder: !ref <enhance_model[Decoder]>
        masknet: !ref <enhance_model[MaskNet]>
        whisper: !ref <whisper>
        scheduler_whisper: !ref <lr_annealing_whisper>
        counter: !ref <epoch_counter>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True