WIP French dataset #2459

Open · wants to merge 4 commits into base: develop
1 change: 1 addition & 0 deletions .dict-speechbrain.txt
@@ -1196,3 +1196,4 @@ quelques
Université
Università
vie
ois
1 change: 1 addition & 0 deletions recipes/ESTER+EPAC+ETAPE+REPERE/ASR/.gitignore
@@ -0,0 +1 @@
wav2vec2_checkpoint/
46 changes: 46 additions & 0 deletions recipes/ESTER+EPAC+ETAPE+REPERE/ASR/README.md
@@ -0,0 +1,46 @@
# *ESTER+EPAC+ETAPE+REPERE* (ELRA) CTC ASR with pre-trained wav2vec2.

Information about the datasets here:
- ESTER1: https://catalogue.elra.info/en-us/repository/browse/ELRA-S0241
- ESTER2: https://catalogue.elra.info/en-us/repository/browse/ELRA-S0338
- ETAPE: https://catalogue.elra.info/en-us/repository/browse/ELRA-E0046
- EPAC: https://catalogue.elra.info/en-us/repository/browse/ELRA-S0305
- REPERE: https://catalogue.elra.info/en-us/repository/browse/ELRA-E0044

**Supported pre-trained wav2vec2 from LeBenchmark:** [HuggingFace](https://huggingface.co/LeBenchmark)

## Installing Extra Dependencies

Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:

```
pip install -r extra_requirements.txt
```

# WFST-based CTC training/inference
To fine-tune a wav2vec 2.0 model with the WFST-based CTC loss, you can use the `train_with_wav2vec_ctc_k2.py` script. This will create a `lang` directory inside your output folder, which contains the files required to build a lexicon FST. The tokenization method used here is a basic character-based tokenization (e.g. `hello -> h e l l o`).

To use this script, you will first need to install `k2`. The integration has been tested with `k2==1.24.4` and `torch==2.0.1`, although it should also work with any `torch` version as long as `k2` supports it (compatibility list [here](https://k2-fsa.github.io/k2/installation/pre-compiled-cuda-wheels-linux/index.html)). You can install `k2` by following the instructions [here](https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cuda-example).
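
For instance, installing a pre-compiled CUDA wheel might look like the following. The wheel tag below is illustrative only, not prescriptive; pick the one matching your `torch` and CUDA versions from the compatibility list linked above:

```
# Illustrative wheel tag -- choose the one matching your torch/CUDA versions
# from the k2 pre-compiled wheels page linked above.
pip install k2==1.24.4.dev20231220+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html

# Sanity-check the installation
python3 -m k2.version
```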

Using a lexicon FST (L) during training can help guide the model to better predictions. When decoding, you can either use a simple HL decoding graph (where H is the CTC topology), or an HLG graph (where G is usually a 3-gram language model) to further improve the results. Whole-lattice rescoring, typically with a 4-gram language model, is also supported. See `hparams/train_with_wav2vec_ctc_k2_phone.yaml` (or its `_char` variant) for more details.

If you choose to use a 3-gram or a 4-gram language model, you can either supply pre-existing ARPA LMs for both cases (including ones you have trained yourself), or specify a name in the YAML docstring for automatic downloading. Comprehensive instructions are provided in the hyperparameter YAMLs.

For those interested in training their own language model, please consult our recipe at `recipes/ESTER+EPAC+ETAPE+REPERE/LM/train_ngram.py`, as sketched below.
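
A possible invocation, assuming the LM recipe follows the usual SpeechBrain pattern of a hyperparameter YAML plus a `--data_folder` override. The YAML file name here is an assumption; check the recipe itself for the exact arguments:

```
cd recipes/ESTER+EPAC+ETAPE+REPERE/LM
# hparams/train_ngram.yaml is assumed; see the LM recipe for the actual file name
python3 train_ngram.py hparams/train_ngram.yaml --data_folder=/path/to/ESTER+EPAC+ETAPE+REPERE/parent_dir
```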

Example usage:
```
python3 train_with_wav2vec_ctc_k2.py hparams/train_with_wav2vec_ctc_k2_phone.yaml --data_folder=/path/to/ESTER+EPAC+ETAPE+REPERE/parent_dir
```

To use the HLG graph (instead of the default HL), pass `--compose_HL_with_G=True`. To use the 4-gram LM for rescoring, pass `--decoding_method=whole-lattice-rescoring`. Note that this requires more memory, as the whole lattice is kept in memory during decoding. In this recipe, the default `rescoring_lm_scale` is 0.4.
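
For example, combining both options on top of the training command above (each flag overrides the matching entry in the hyperparameter YAML; `--rescoring_lm_scale=0.4` is shown explicitly even though it is the default):

```
python3 train_with_wav2vec_ctc_k2.py hparams/train_with_wav2vec_ctc_k2_phone.yaml \
    --data_folder=/path/to/ESTER+EPAC+ETAPE+REPERE/parent_dir \
    --compose_HL_with_G=True \
    --decoding_method=whole-lattice-rescoring \
    --rescoring_lm_scale=0.4
```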

| Release | Hyperparams file | Decoding method | Text Normalization for scoring | EPAC WER | ESTER1 WER | ESTER2 WER | ETAPE WER | REPERE WER |
|:----------:|:------------------------------------:|:---------------------------------:|:-------------------------------:|:----------:|:----------:|:----------:|:---------:|:----------:|
| 21/05/2024 | train_with_wav2vec_ctc_k2_phone.yaml | k2CTC + HL graph + 1best decoding | No | 14.41 | 12.81 | 13.66 | 24.90 | 13.95 |
| 21/05/2024 | train_with_wav2vec_ctc_k2_char.yaml | k2CTC + HL graph + 1best decoding | No | 15.17 | 13.18 | 14.21 | 26.16 | 14.74 |
| 30/05/2024 | train_with_wav2vec_ctc_k2_phone.yaml | k2CTC + HL graph + 1best decoding | Yes | 9.49 | 10.19 | 11.36 | 23.01 | 11.58 |
| 30/05/2024 | train_with_wav2vec_ctc_k2_char.yaml | k2CTC + HL graph + 1best decoding | Yes | 10.96 | 11.00 | 12.39 | 24.83 | 12.88 |

## Text normalization for scoring
The script used for text normalization before scoring is available at: https://github.com/pchampio/UB-WER-NORM_ESTER-EPAC-ETAPE-REPERE/
4 changes: 4 additions & 0 deletions recipes/ESTER+EPAC+ETAPE+REPERE/ASR/extra_requirements.txt
@@ -0,0 +1,4 @@
ftfy
# k2 # It is better to install k2 with the procedure listed here: https://k2-fsa.github.io/k2/installation/from_wheels.html
num2words
soundfile
@@ -0,0 +1,228 @@
# ################################
# Model: wav2vec2 + DNN + CTC + LM (k2)
# Augmentation: SpecAugment + Speed
# Authors: Pierre Champion 2023
# ################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1116
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/<merge_train_csv>_wav2vec2_<token_type>_k2/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# wav2vec2_hub: LeBenchmark/wav2vec2-FR-3K-base
# wav2vec2_out_shape: 768

# wav2vec2_hub: LeBenchmark/wav2vec2-FR-7K-large
# wav2vec2_out_shape: 1024

wav2vec2_hub: LeBenchmark/wav2vec2-FR-14K-xlarge
wav2vec2_out_shape: 1280

wav2vec2_folder: !ref wav2vec2_checkpoint

# Data files
data_folder: !PLACEHOLDER # e.g., /path/to/corpus (**/*.stm)
stm_directory: !ref <data_folder>/**/[^\.ne_e2\.|\.ne\.|\.spk\.|part\.]*.stm
wav_directory: !ref <data_folder>/**/*.wav
train_splits: {"train_ESTER2":["/ESTER2/train_trans_rapide/*", "/ESTER2/train/*"], "train_ESTER1":["/ESTER1/train/*"], "train_EPAC":["/EPAC/train/*"], "train_ETAPE":["/ETAPE/train/*"], "train_REPERE":["/REPERE/train/*"]}
dev_splits: {"dev_ESTER2":["/ESTER2/dev/*"], "dev_ESTER1":["/ESTER1/dev/*"], "dev_ETAPE":["/ETAPE/dev/*"], "dev_REPERE2014":["/REPERE/dev2014/*"]}
test_splits: {"test_ESTER2":["/ESTER2/test/*"], "test_ESTER1":["/ESTER1/test/*"], "test_ETAPE":["/ETAPE/test/*"], "test_EPAC":["/EPAC/test/*"], "test_REPERE2014":["/REPERE/test2014/*"]}
merge_train_csv: "train_ESTER2+train_ESTER1+train_EPAC+train_ETAPE+train_REPERE"
prep_save_folder: !ref <output_folder>
skip_prep: False
skip_token_prep: False
ckpt_interval_minutes: 10 # save checkpoint every N min
# the following CSVs are found in <output_folder>
train_csv: train.csv
valid_csv: dev_ESTER2.csv
test_csv:
    - test_ESTER2.csv
    - test_ESTER1.csv
    - test_ETAPE.csv
    - test_EPAC.csv
    - test_REPERE2014.csv

# For k2 CTC training
caching: True
lang_dir: !ref <output_folder>/lang
vocab_file: !ref <output_folder>/vocab.txt
# token_type: phone
token_type: char
sil_prob: 0.
add_word_boundary: True
# For k2 decoding
test_search_beam: 32
# Beam size (for decoding)
test_output_beam: 8
test_min_active_state: 300
test_max_active_state: 3000
# Acoustic scale (multiplied by the log probs)
ac_scale: 1.5
compose_HL_with_G: True
# 1best or whole-lattice-rescoring
# decoding_method: whole-lattice-rescoring
decoding_method: 1best
# LM scale to be used for rescoring. Only used if rescoring
rescoring_lm_scale: 0.4
# This is where the 3gram and (optionally) 4gram LM are stored
# They can be in either ARPA or FST format. If the former, then
# the FST equivalent will be created in the same directory by
# using kaldilm.
lm_dir: ../LM/results/n_gram_lm

G_arpa: 3-for-char-gram.arpa
G_rescoring_arpa: 4-for-char-gram.arpa

# Training parameters
number_of_epochs: 20
lr: 1.5
lr_wav2vec: 0.001
sorting: ascending # only ascending and descending are supported currently
precision: fp32
sample_rate: 16000

# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
batch_size: 8
num_workers: 20
grad_accumulation_factor: 2
test_batch_size: 12
# nonfinite loss does happen from time to time on some segments with char training
nonfinite_patience: 20

# In seconds
avoid_if_longer_than: 90.0
avoid_if_smaller_than: 0.5

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>

valid_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>

test_dataloader_opts:
    batch_size: !ref <test_batch_size>
    num_workers: !ref <num_workers>

# Model parameters
activation: !name:torch.nn.LeakyReLU
dnn_layers: 4
freeze_wav2vec: True
number_lines_for_tokens: 52

# Outputs
# in k2 check lang/tokens.txt
# BPE size, index(blank/eos/bos) = 0
output_neurons: !apply:speechbrain.lobes.utils.NumberOfLines
    file: !ref <lang_dir>/tokens.txt
    default: !ref <number_lines_for_tokens>

#
# Functions and classes
#
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]

# Frequency drop: randomly drops a number of frequency bands to zero.
drop_freq_low: 0 # Min frequency band dropout probability
drop_freq_high: 1 # Max frequency band dropout probability
drop_freq_count_low: 1 # Min number of frequency bands to drop
drop_freq_count_high: 3 # Max number of frequency bands to drop
drop_freq_width: 0.05 # Width of frequency bands to drop

drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: !ref <drop_freq_low>
    drop_freq_high: !ref <drop_freq_high>
    drop_freq_count_low: !ref <drop_freq_count_low>
    drop_freq_count_high: !ref <drop_freq_count_high>
    drop_freq_width: !ref <drop_freq_width>

# Augmenter: Combines previously defined augmentations to perform data augmentation
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    parallel_augment: False
    repeat_augment: 1
    shuffle_augmentations: False
    min_augmentations: 1
    max_augmentations: 1
    augment_prob: 1.0
    augmentations: [
        !ref <speed_perturb>,
        !ref <drop_freq>,
    ]

enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    input_shape: [null, null, !ref <wav2vec2_out_shape>]
    activation: !ref <activation>
    dnn_blocks: !ref <dnn_layers>
    dnn_neurons: !ref <wav2vec2_out_shape>

wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec>
    save_path: !ref <wav2vec2_folder>

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <wav2vec2_out_shape>
    n_neurons: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

ctc_cost: !name:speechbrain.k2_integration.losses.ctc_k2
    reduction: mean
    beam_size: 10

modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>

model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]

model_opt_class: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8

wav2vec_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec>

lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0

lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        wav2vec2: !ref <wav2vec2>
        model: !ref <model>
        scheduler_model: !ref <lr_annealing_model>
        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
        counter: !ref <epoch_counter>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True