CTC-only training recipes for LibriSpeech (code from Samsung AI Cambridge) #2290

Merged (22 commits) on Apr 18, 2024
Changes from 5 commits
18 changes: 18 additions & 0 deletions recipes/LibriSpeech/ASR/CTC/README.md
@@ -24,7 +24,22 @@ To run a fine-tuning of "WavLM" with signal downsampled inputs (for faster train
```
python train_with_wav2vec.py hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml --downsampling_factor 2
```
To train a model from scratch (without any pre-training), first go to the Tokenizer folder and train a tokenizer:

```
cd ../../Tokenizer
python train.py hparams/128_bpe.yaml
```
Then, go back to this directory. You can train a Branchformer CTC model with:

```
python train_from_scratch.py hparams/train_branchformer.yaml
```
or a Conformer CTC model with:

```
python train_from_scratch.py hparams/train_conformer.yaml
```
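
Both recipes read their data and tokenizer locations from `!PLACEHOLDER` entries in the YAML (`data_folder` and `pretrained_tokenizer_path` in `train_branchformer.yaml`), so these must be set, e.g. via command-line overrides. A typical launch, with illustrative paths, looks like:

```
python train_from_scratch.py hparams/train_branchformer.yaml --data_folder /path/to/LibriSpeech --pretrained_tokenizer_path /path/to/tokenizer/save/folder
```

The tokenizer folder must contain the `128_bpe.model` file produced by the Tokenizer recipe.
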
# KenLM n-gram CTC rescoring
To enable n-gram rescoring during decoding, you can download the official LibriSpeech LM from [here](https://www.openslr.org/11/). Please make sure to install the extra dependencies first. Any KenLM language model may be used with this rescoring technique. Results are reported without rescoring.
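
For instance, the 4-gram ARPA model can be fetched and decompressed as follows (the file name follows the OpenSLR listing; any of the provided ARPA models works):

```
wget https://www.openslr.org/resources/11/4-gram.arpa.gz
gunzip 4-gram.arpa.gz
```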

@@ -35,6 +50,9 @@ To enable n-gram rescoring during the decoding, you can download the LibriSpeech
| 09-09-21 | train_hf_wav2vec.yaml | 960h | 1.90 | [Link](https://huggingface.co/speechbrain/asr-wav2vec2-librispeech) | [Link](https://www.dropbox.com/sh/qj2ps85g8oiicrj/AAAxlkQw5Pfo0M9EyHMi8iAra?dl=0) | 1xRTX8000 48GB |
| 22-09-22 | train_sb_wav2vec.yaml | 960h | 4.2 | Not Avail. | Not Avail. | 2xTesla V100 32GB |
| 06-12-23 | train_hf_whisper.yaml (small) | 960h | 4.89 | Not Avail. | Not Avail. | 4xRTX 2080 Ti |
| 06-12-23 | train_branchformer.yaml (25.9M) | 960h | 3.6 (no LM) | Not Avail. | Not Avail. | 8xA40 46G |
| 06-12-23 | train_conformer.yaml (28.8M) | 960h | 3.7 (no LM) | Not Avail. | Not Avail. | 8xA40 46G |


# Downsampling inputs for faster fine-tuning and inferences using SSL Models
This repository contains the code needed to reproduce part of the results obtained in the paper: "Fine-tuning Strategies for Faster Inference using Speech Self-Supervised Models: A Comparative Study"
214 changes: 214 additions & 0 deletions recipes/LibriSpeech/ASR/CTC/hparams/train_branchformer.yaml
@@ -0,0 +1,214 @@
# ############################################################################
# Model: E2E ASR with CTC
# Encoder: Branchformer Encoder
# Decoder: CTC only (no autoregressive decoder)
# Tokens: BPE
# Training: Librispeech 960h
# Authors: Titouan Parcollet
# Shucong Zhang
# ############################################################################
# Seed needs to be set at top of yaml, before objects with parameters are made

seed: 3402
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/branchformer_ctc/
wer_file: !ref <output_folder>/wer.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Data files
data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech
# If RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
# then data_folder_rirs should be /localscratch/xxx_corpus
# otherwise the dataset will automatically be downloaded
# data_folder_rirs: !ref <data_folder>
train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
dev_splits: ["dev-clean"]
test_splits: ["test-clean", "test-other"]
skip_prep: False
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev-clean.csv
test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv

pretrained_tokenizer_path: !PLACEHOLDER # e.g., /path/to/128_bpe_model
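# Note: the folder above must contain the 128_bpe.model SentencePiece model
# expected by the pretrainer at the bottom of this file (e.g., the output of
# the ../../Tokenizer recipe).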
tokenizer: !new:sentencepiece.SentencePieceProcessor

number_of_epochs: 500
batch_size: 16 # This works for 2x GPUs with 32GB
grad_accumulation_factor: 2
max_grad_norm: 5.0
sorting: descending #random
num_workers: 8
loss_reduction: batchmean
valid_search_interval: 1

lr_adam: 0.001
weight_decay: 0.0005

# Feature parameters
sample_rate: 16000
n_fft: 512
n_mels: 80
win_length: 25

# Training parameters
# To make Transformers converge, the global batch size should be large enough.
# The global batch size is max_batch_len * n_gpus * grad_accumulation_factor.
# Empirically, we used 850 * 8 A40 45G GPUs * 2 or 1700 * 4 A100 80G GPUs * 2.
# Please set your parameters accordingly.
dynamic_batching: True
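# Note: both recommended settings above give the same effective global batch:
# 850 * 8 * 2 = 1700 * 4 * 2 = 13600.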
max_batch_len: 850
max_batch_len_val: 100 # we reduce it as the beam is much wider (VRAM)
num_bucket: 200

dynamic_batch_sampler:
    max_batch_len: !ref <max_batch_len>
    max_batch_len_val: !ref <max_batch_len_val>
    num_buckets: !ref <num_bucket>
    shuffle_ex: False # if true re-creates batches at each epoch shuffling examples.
    batch_ordering: random
    max_batch_ex: 128

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: !ref <num_workers>

valid_dataloader_opts:
    batch_size: 1

test_dataloader_opts:
    batch_size: 1

####################### Model parameters ###########################
# Transformer
attention_type: RelPosMHAXL
d_model: 256
nhead: 4
csgu_linear_units: 2400
csgu_kernel_size: 31
num_encoder_layers: 18
num_decoder_layers: 0
transformer_dropout: 0.1
activation: !name:torch.nn.GELU
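# output_neurons must match the vocabulary size of the 128-unit BPE tokenizer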
output_neurons: 128


# Outputs
blank_index: 0

############################## models ################################

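# Convolutional front-end: two Conv blocks with stride 2 each, downsampling the
# Fbank features by a factor of 4 before the Branchformer encoder.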
CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
    input_shape: (8, 10, 80)
    num_blocks: 2
    num_layers_per_block: 1
    out_channels: (64, 32)
    kernel_sizes: (3, 3)
    strides: (2, 2)
    residuals: (False, False)

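# input_size of 640 = 32 CNN output channels * (80 mel bins / 4), i.e. the
# flattened per-frame feature size produced by the front-end above.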
Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
    input_size: 640
    tgt_vocab: !ref <output_neurons>
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    dropout: !ref <transformer_dropout>
    activation: !ref <activation>
    encoder_module: branchformer
    attention_type: !ref <attention_type>
    normalize_before: True
    causal: False
    csgu_linear_units: !ref <csgu_linear_units>
    kernel_size: !ref <csgu_kernel_size>


ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <output_neurons>


normalize: !new:speechbrain.processing.features.InputNormalization
    norm_type: global
    update_until_epoch: 4

modules:
    CNN: !ref <CNN>
    Transformer: !ref <Transformer>
    ctc_lin: !ref <ctc_lin>
    normalize: !ref <normalize>

model: !new:torch.nn.ModuleList
    - [!ref <CNN>, !ref <Transformer>, !ref <ctc_lin>]

Adam: !name:torch.optim.AdamW
    lr: !ref <lr_adam>
    betas: (0.9, 0.98)
    eps: 0.000000001
    weight_decay: !ref <weight_decay>

log_softmax: !new:torch.nn.LogSoftmax
    dim: -1

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
    reduction: !ref <loss_reduction>

noam_annealing: !new:speechbrain.nnet.schedulers.LinearNoamScheduler
    lr_initial: !ref <lr_adam>
    n_warmup_steps: 7500
    n_keep_steps: 36000


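# Checkpointing keeps the model, the Noam scheduler, the input normalizer
# statistics and the epoch counter, so training can resume from the latest checkpoint.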
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        noam_scheduler: !ref <noam_annealing>
        normalizer: !ref <normalize>
        counter: !ref <epoch_counter>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

augmentation: !new:speechbrain.lobes.augment.SpecAugment
    time_warp: True
    time_warp_window: 5
    time_warp_mode: bicubic
    freq_mask: True
    n_freq_mask: 2
    time_mask: True
    n_time_mask: 7
    replace_with_zero: False
    freq_mask_width: 27
    time_mask_width: 30
    time_mask_ratio: 0.05

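# Speed perturbation resamples each utterance to 95%, 100%, or 105% of its original speed.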
speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]

compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    win_length: !ref <win_length>
    n_mels: !ref <n_mels>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
wer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    collect_in: !ref <save_folder>
    loadables:
        tokenizer: !ref <tokenizer>
    paths:
        tokenizer: !ref <pretrained_tokenizer_path>/128_bpe.model