-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Squashed commits from pchampio to avoid storing HF model junk in the repository (still have a backup of the branch somewhere)
- Loading branch information
Showing
22 changed files
with
2,083 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1196,3 +1196,4 @@ quelques | |
Université | ||
Università | ||
vie | ||
ois |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
wav2vec2_checkpoint/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
ftfy | ||
# k2 # It is better to install k2 with the procedure listed here: https://k2-fsa.github.io/k2/installation/from_wheels.html | ||
num2words | ||
soundfile |
228 changes: 228 additions & 0 deletions
228
recipes/ESTER+EPAC+ETAPE+REPERE/ASR/hparams/train_with_wav2vec_ctc_k2_char.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
# ################################ | ||
# Model: wav2vec2 + DNN + CTC + LM (k2) | ||
# Augmentation: SpecAugment + Speed | ||
# Authors: Pierre Champion 2023 | ||
# ################################ | ||
|
||
# Seed needs to be set at top of yaml, before objects with parameters are made | ||
seed: 1116 | ||
__set_seed: !apply:torch.manual_seed [!ref <seed>] | ||
output_folder: !ref results/<merge_train_csv>_wav2vec2_<token_type>_k2/<seed> | ||
save_folder: !ref <output_folder>/save | ||
train_log: !ref <output_folder>/train_log.txt | ||
|
||
# wav2vec2_hub: LeBenchmark/wav2vec2-FR-3K-base | ||
# wav2vec2_out_shape: 768 | ||
|
||
# wav2vec2_hub: LeBenchmark/wav2vec2-FR-7K-large | ||
# wav2vec2_out_shape: 1024 | ||
|
||
wav2vec2_hub: LeBenchmark/wav2vec2-FR-14K-xlarge | ||
wav2vec2_out_shape: 1280 | ||
|
||
wav2vec2_folder: !ref wav2vec2_checkpoint | ||
|
||
# Data files | ||
data_folder: !PLACEHOLDER # e.g, /path/to/corpus (**/*.stm) | ||
stm_directory: !ref <data_folder>/**/[^\.ne_e2\.|\.ne\.|\.spk\.|part\.]*.stm | ||
wav_directory: !ref <data_folder>/**/*.wav | ||
train_splits: {"train_ESTER2":["/ESTER2/train_trans_rapide/*", "/ESTER2/train/*"], "train_ESTER1":["/ESTER1/train/*"], "train_EPAC":["/EPAC/train/*"], "train_ETAPE":["/ETAPE/train/*"], "train_REPERE":["/REPERE/train/*"]} | ||
dev_splits: {"dev_ESTER2":["/ESTER2/dev/*"], "dev_ESTER1":["/ESTER1/dev/*"], "dev_ETAPE":["/ETAPE/dev/*"], "dev_REPERE2014":["/REPERE/dev2014/*"]} | ||
test_splits: {"test_ESTER2":["/ESTER2/test/*"], "test_ESTER1":["/ESTER1/test/*"], "test_ETAPE":["/ETAPE/test/*"], "test_EPAC":["/EPAC/test/*"], "test_REPERE2014":["/REPERE/test2014/*"]} | ||
merge_train_csv: "train_ESTER2+train_ESTER1+train_EPAC+train_ETAPE+train_REPERE" | ||
prep_save_folder: !ref <output_folder> | ||
skip_prep: False | ||
skip_token_prep: False | ||
ckpt_interval_minutes: 10 # save checkpoint every N min | ||
# the following CSVs are found in <output_folder> | ||
train_csv: train.csv | ||
valid_csv: dev_ESTER2.csv | ||
test_csv: | ||
- test_ESTER2.csv | ||
- test_ESTER1.csv | ||
- test_ETAPE.csv | ||
- test_EPAC.csv | ||
- test_REPERE2014.csv | ||
|
||
# For k2 CTC training | ||
caching: True | ||
lang_dir: !ref <output_folder>/lang | ||
vocab_file: !ref <output_folder>/vocab.txt | ||
# token_type: phone | ||
token_type: char | ||
sil_prob: 0. | ||
add_word_boundary: True | ||
# For k2 decoding | ||
test_search_beam: 32 | ||
# Beam size (for decoding) | ||
test_output_beam: 8 | ||
test_min_active_state: 300 | ||
test_max_active_state: 3000 | ||
# Acoustic scale (mutliplied by the log probs) | ||
ac_scale: 1.5 | ||
compose_HL_with_G: True | ||
# 1best or whole-lattice-rescoring | ||
# decoding_method: whole-lattice-rescoring | ||
decoding_method: 1best | ||
# LM scale to be used for rescoring. Only used if rescoring | ||
rescoring_lm_scale: 0.4 | ||
# This is where the 3gram and (optionally) 4gram LM are stored | ||
# They can be in either ARPA or FST format. If the former, then | ||
# the FST equivalent will be created in the same directory by | ||
# using kaldilm. | ||
lm_dir: ../LM/results/n_gram_lm | ||
|
||
G_arpa: 3-for-char-gram.arpa | ||
G_rescoring_arpa: 4-for-char-gram.arpa | ||
|
||
# Training parameters | ||
number_of_epochs: 20 | ||
lr: 1.5 | ||
lr_wav2vec: 0.001 | ||
sorting: ascending # only ascending and descending are supported currently | ||
precision: fp32 | ||
sample_rate: 16000 | ||
|
||
# With data_parallel batch_size is split into N jobs | ||
# With DDP batch_size is multiplied by N jobs | ||
batch_size: 8 | ||
num_workers: 20 | ||
grad_accumulation_factor: 2 | ||
test_batch_size: 12 | ||
# nonfinite loss does happen from time to time on some segments with char training | ||
nonfinite_patience: 20 | ||
|
||
# In seconds | ||
avoid_if_longer_than: 90.0 | ||
avoid_if_smaller_than: 0.5 | ||
|
||
# Dataloader options | ||
train_dataloader_opts: | ||
batch_size: !ref <batch_size> | ||
num_workers: !ref <num_workers> | ||
|
||
valid_dataloader_opts: | ||
batch_size: !ref <batch_size> | ||
num_workers: !ref <num_workers> | ||
|
||
test_dataloader_opts: | ||
batch_size: !ref <test_batch_size> | ||
num_workers: !ref <num_workers> | ||
|
||
# Model parameters | ||
activation: !name:torch.nn.LeakyReLU | ||
dnn_layers: 4 | ||
freeze_wav2vec: True | ||
number_lines_for_tokens: 52 | ||
|
||
# Outputs | ||
# in k2 check lang/tokens.txt | ||
# BPE size, index(blank/eos/bos) = 0 | ||
output_neurons: !apply:speechbrain.lobes.utils.NumberOfLines | ||
file: !ref <lang_dir>/tokens.txt | ||
default: !ref <number_lines_for_tokens> | ||
|
||
# | ||
# Functions and classes | ||
# | ||
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter | ||
limit: !ref <number_of_epochs> | ||
|
||
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb | ||
orig_freq: !ref <sample_rate> | ||
speeds: [95, 100, 105] | ||
|
||
# Frequency drop: randomly drops a number of frequency bands to zero. | ||
drop_freq_low: 0 # Min frequency band dropout probability | ||
drop_freq_high: 1 # Max frequency band dropout probability | ||
drop_freq_count_low: 1 # Min number of frequency bands to drop | ||
drop_freq_count_high: 3 # Max number of frequency bands to drop | ||
drop_freq_width: 0.05 # Width of frequency bands to drop | ||
|
||
drop_freq: !new:speechbrain.augment.time_domain.DropFreq | ||
drop_freq_low: !ref <drop_freq_low> | ||
drop_freq_high: !ref <drop_freq_high> | ||
drop_freq_count_low: !ref <drop_freq_count_low> | ||
drop_freq_count_high: !ref <drop_freq_count_high> | ||
drop_freq_width: !ref <drop_freq_width> | ||
|
||
# Augmenter: Combines previously defined augmentations to perform data augmentation | ||
wav_augment: !new:speechbrain.augment.augmenter.Augmenter | ||
parallel_augment: False | ||
repeat_augment: 1 | ||
shuffle_augmentations: False | ||
min_augmentations: 1 | ||
max_augmentations: 1 | ||
augment_prob: 1.0 | ||
augmentations: [ | ||
!ref <speed_perturb>, | ||
!ref <drop_freq>, | ||
] | ||
|
||
enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN | ||
input_shape: [null, null, !ref <wav2vec2_out_shape>] | ||
activation: !ref <activation> | ||
dnn_blocks: !ref <dnn_layers> | ||
dnn_neurons: !ref <wav2vec2_out_shape> | ||
|
||
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 | ||
source: !ref <wav2vec2_hub> | ||
output_norm: True | ||
freeze: !ref <freeze_wav2vec> | ||
save_path: !ref <wav2vec2_folder> | ||
|
||
ctc_lin: !new:speechbrain.nnet.linear.Linear | ||
input_size: !ref <wav2vec2_out_shape> | ||
n_neurons: !ref <output_neurons> | ||
|
||
log_softmax: !new:speechbrain.nnet.activations.Softmax | ||
apply_log: True | ||
|
||
ctc_cost: !name:speechbrain.k2_integration.losses.ctc_k2 | ||
reduction: mean | ||
beam_size: 10 | ||
|
||
modules: | ||
wav2vec2: !ref <wav2vec2> | ||
enc: !ref <enc> | ||
ctc_lin: !ref <ctc_lin> | ||
|
||
model: !new:torch.nn.ModuleList | ||
- [!ref <enc>, !ref <ctc_lin>] | ||
|
||
model_opt_class: !name:torch.optim.Adadelta | ||
lr: !ref <lr> | ||
rho: 0.95 | ||
eps: 1.e-8 | ||
|
||
wav2vec_opt_class: !name:torch.optim.Adam | ||
lr: !ref <lr_wav2vec> | ||
|
||
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler | ||
initial_value: !ref <lr> | ||
improvement_threshold: 0.0025 | ||
annealing_factor: 0.8 | ||
patient: 0 | ||
|
||
lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler | ||
initial_value: !ref <lr_wav2vec> | ||
improvement_threshold: 0.0025 | ||
annealing_factor: 0.9 | ||
patient: 0 | ||
|
||
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer | ||
checkpoints_dir: !ref <save_folder> | ||
recoverables: | ||
wav2vec2: !ref <wav2vec2> | ||
model: !ref <model> | ||
scheduler_model: !ref <lr_annealing_model> | ||
scheduler_wav2vec: !ref <lr_annealing_wav2vec> | ||
counter: !ref <epoch_counter> | ||
|
||
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger | ||
save_file: !ref <train_log> | ||
|
||
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats | ||
|
||
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats | ||
split_tokens: True |
Oops, something went wrong.