Integrating Pruned Fast RNNT with Transducer + new recipe for mTEDx dataset #1465

Open: wants to merge 36 commits into base: develop from pruned_fast_rnnt

Changes from 13 commits

Commits (36)
608e1de
added script to prepare mTEDx dataset
Jun 22, 2022
26dddf5
added an asr transducer training file for mTEDx recipe
Jun 22, 2022
1a6fbbe
added jointer network to be used with the pruned loss of Fast RNNT
Jun 22, 2022
e9e6241
added pruned-loss to the losses script
Jun 22, 2022
af2c3b3
created simple beam searcher for the pruned loss; just the same as Tr…
Jun 22, 2022
9e6e1ee
added a recipe for creating a tokenizer on mTEDx dataset
Jun 22, 2022
a2073fa
added a recipe for creating an RNN language model on mTEDx-French dat…
Jun 22, 2022
22ec024
added a recipe for creating an RNN language model on mTEDx-French dat…
Jun 22, 2022
b74a424
added yaml file for training ASR transducer on mTEDx
Jun 22, 2022
61ccae8
added yaml file for training ASR transducer on mTEDx
Jun 22, 2022
0ca6e7a
Merge remote-tracking branch 'upstream/develop' into 'pruned_fast_rnnt'
Jul 29, 2022
78c9008
added README file for mTEDx recipe
Anwarvic Aug 2, 2022
f9b9e03
Merge branch 'speechbrain:develop' into pruned_fast_rnnt
Anwarvic Aug 12, 2022
900c261
Merge branch 'speechbrain:develop' into pruned_fast_rnnt
Anwarvic Sep 16, 2022
4e38371
Merge branch 'speechbrain:develop' into pruned_fast_rnnt
Anwarvic Sep 18, 2022
2d60e5e
updated Transducer recipes + added README
Anwarvic Sep 19, 2022
bde66d5
updated Transducer recipes + added README
Anwarvic Sep 19, 2022
112b688
added CTC recipes
Anwarvic Sep 19, 2022
c5cbe1f
updated files with latest updates
Anwarvic Sep 19, 2022
2974d3a
Merge branch 'pruned_fast_rnnt' of https://github.com/Anwarvic/speech…
Anwarvic Sep 19, 2022
eb37ab2
updated scripts with latest updates
Anwarvic Sep 19, 2022
cddef0a
fixed pre-commit errors
Anwarvic Sep 19, 2022
6b2e8f5
fixed pre-commit errors
Anwarvic Sep 19, 2022
0ea78d7
added recipes yaml files to tests/recipes.csv
Anwarvic Sep 19, 2022
3960a4e
fixed the unused dnn_neurons variable in train_wav2vec.yaml file
Anwarvic Sep 19, 2022
f76f0a1
pre-commit passed successfully
Anwarvic Sep 19, 2022
9f36769
updated transducer configs in the other dataset recipes to match the …
Anwarvic Sep 19, 2022
405bcee
updated transducer configs in the other dataset recipes to match the …
Anwarvic Sep 19, 2022
82d450b
added needed README files for mTEDx recipes
Anwarvic Sep 25, 2022
1bf6988
changed use_torchaudio flag in Transducer recipes README all across d…
Anwarvic Sep 25, 2022
b71c2d7
fixed wrong paths in tests/recipes.csv
Anwarvic Sep 25, 2022
9fd2564
added CTC models to CTC README of mTEDx recipe
Anwarvic Sep 26, 2022
dd41227
fixed merged issues in tests/recipes.csv
Anwarvic Sep 26, 2022
209a60a
minor changes in README file
Anwarvic Sep 26, 2022
f22f144
removed unused variables in conf files in mTEDx recipes
Anwarvic Sep 26, 2022
b2aedf5
fixed the naming issue for transducer recipe
Anwarvic Sep 26, 2022
226 changes: 226 additions & 0 deletions recipes/mTEDx/ASR/Transducer/hparams/train_pruned.yaml
@@ -0,0 +1,226 @@
# ############################################################################
# Model: E2E ASR with Transducer (pruned Fast RNNT)
# Encoder: CRDNN model
# Decoder: GRU + beamsearch + RNNLM
# Tokens: BPE with unigram
# Losses: Transducer (pruned)
# Training: mTEDx-fr
# Authors: Mohamed Anwar 2022
# ############################################################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1234
__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
output_folder: !ref ./results/mTEDx_fr/CASCADE/CRDNN_pruned
wer_file: !ref <output_folder>/wer.txt
save_folder: !ref <output_folder>/save

# Language model (LM) pretraining
# NB: To avoid mismatch, the speech recognizer must be trained with the same
# tokenizer used for LM training. Here, we download everything from the
# speechbrain HuggingFace repository. However, a local path pointing to a
# directory containing the lm.ckpt and tokenizer.ckpt may also be specified
# instead, e.g., if you want to use your own LM / tokenizer.
pretrained_tokenizer_path: #!PLACEHOLDER
pretrained_lm_path: #!PLACEHOLDER

# Data files
data_folder: #!PLACEHOLDER
langs:
- fr
remove_punc_cap: True # remove punctuation & capitalization from text

train_json: !ref <data_folder>/train_fr.json
valid_json: !ref <data_folder>/valid_fr.json
test_json: !ref <data_folder>/test_fr.json

# Training parameters
number_of_epochs: 30
batch_size: 8
batch_size_valid: 2 # for valid & test
lr: 1
sorting: random # ascending, descending, or random
dynamic_batching: False

# Feature parameters
sample_rate: 16000
n_fft: 400
n_mels: 40

opt_class: !name:torch.optim.Adadelta
lr: !ref <lr>
rho: 0.95
eps: 1.e-8

# Dataloader options
train_dataloader_opts:
batch_size: !ref <batch_size>

valid_dataloader_opts:
batch_size: !ref <batch_size_valid>

test_dataloader_opts:
batch_size: !ref <batch_size_valid>

# Model parameters
activation: !name:torch.nn.LeakyReLU
dropout: 0.15
cnn_blocks: 2
cnn_channels: (128, 256)
inter_layer_pooling_size: (2, 2)
cnn_kernelsize: (3, 3)
time_pooling_size: 4
rnn_class: !name:speechbrain.nnet.RNN.LSTM
rnn_layers: 4
rnn_neurons: 1024
rnn_bidirectional: True
dnn_blocks: 2
output_neurons: 1000 # index(blank/eos/bos) = 0
dnn_neurons: !ref <output_neurons>
dec_neurons: !ref <output_neurons>
joint_dim: 1024
blank_index: 0

# Decoding parameters
beam_size: 4
nbest: 1
# by default, {state,expand}_beam = 2.3, as mentioned in the paper
# https://arxiv.org/abs/1904.02619
state_beam: 2.3
expand_beam: 2.3
lm_weight: 0.50

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
limit: !ref <number_of_epochs>

normalize: !new:speechbrain.processing.features.InputNormalization
norm_type: global

compute_features: !new:speechbrain.lobes.features.Fbank
sample_rate: !ref <sample_rate>
n_fft: !ref <n_fft>
n_mels: !ref <n_mels>

enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
input_shape: [null, null, !ref <n_mels>]
activation: !ref <activation>
dropout: !ref <dropout>
cnn_blocks: !ref <cnn_blocks>
cnn_channels: !ref <cnn_channels>
cnn_kernelsize: !ref <cnn_kernelsize>
inter_layer_pooling_size: !ref <inter_layer_pooling_size>
time_pooling: True
using_2d_pooling: False
time_pooling_size: !ref <time_pooling_size>
rnn_class: !ref <rnn_class>
rnn_layers: !ref <rnn_layers>
rnn_neurons: !ref <rnn_neurons>
rnn_bidirectional: !ref <rnn_bidirectional>
rnn_re_init: True
dnn_blocks: !ref <dnn_blocks>
dnn_neurons: !ref <dnn_neurons>

emb: !new:speechbrain.nnet.embedding.Embedding
num_embeddings: !ref <output_neurons>
consider_as_one_hot: True
blank_id: !ref <blank_index>

dec: !new:speechbrain.nnet.RNN.GRU
input_shape: [null, null, !ref <output_neurons> - 1]
hidden_size: !ref <dec_neurons>
num_layers: 1
re_init: True

Tjoint: !new:speechbrain.nnet.transducer.transducer_joint.Fast_RNNT_Joiner
input_dim: !ref <output_neurons>
inner_dim: !ref <joint_dim>
output_dim: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
apply_log: True

transducer_cost: !name:speechbrain.nnet.losses.fast_rnnt_pruned_loss
blank_index: !ref <blank_index>
    prune_range: 115 # how many symbol positions to keep per frame
    loss_scale: 0.5 # scale for the `mode` loss before adding the pruned loss
jointer: !ref <Tjoint>
mode: simple #simple | smoothed
reduction: mean
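# NOTE: the pruned loss is computed in two passes: a cheap `mode` loss
# (simple/smoothed) over the full (frame, symbol) lattice, whose gradients
# select, per frame, the `prune_range` symbol positions worth keeping, and
# a full jointer pass evaluated only on that pruned lattice (see the sketch
# after this file).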

# This is the RNNLM that is used according to the HuggingFace repository
# NB: It has to match the pre-trained RNNLM!!
lm_model: !new:speechbrain.lobes.models.RNNLM.RNNLM
output_neurons: !ref <output_neurons>
embedding_dim: 128
activation: !name:torch.nn.LeakyReLU
dropout: 0.5
rnn_layers: 4
rnn_neurons: 512
dnn_blocks: 1
dnn_neurons: 256
return_hidden: True # For inference

# for MTL
# update model if any HEAD module is added
modules:
enc: !ref <enc>
emb: !ref <emb>
dec: !ref <dec>
Tjoint: !ref <Tjoint>
normalize: !ref <normalize>
lm_model: !ref <lm_model>

model: !new:torch.nn.ModuleList
- [!ref <enc>, !ref <emb>, !ref <dec>, !ref <Tjoint>]

# Tokenizer initialization
tokenizer: !new:sentencepiece.SentencePieceProcessor

Greedysearcher: !new:speechbrain.decoders.transducer.FastRNNTBeamSearcher
decode_network_lst: [!ref <emb>, !ref <dec>]
tjoint: !ref <Tjoint>
blank_id: !ref <blank_index>
beam_size: 1
nbest: 1
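# (With beam_size: 1 and nbest: 1, the searcher above reduces to greedy
# decoding.)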

Beamsearcher: !new:speechbrain.decoders.transducer.FastRNNTBeamSearcher
decode_network_lst: [!ref <emb>, !ref <dec>]
tjoint: !ref <Tjoint>
blank_id: !ref <blank_index>
beam_size: !ref <beam_size>
nbest: !ref <nbest>
lm_module: !ref <lm_model>
lm_weight: !ref <lm_weight>
state_beam: !ref <state_beam>
expand_beam: !ref <expand_beam>
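# (lm_weight shallow-fuses the RNNLM scores with the transducer scores
# during beam search.)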

lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
initial_value: !ref <lr>
improvement_threshold: 0.0025
annealing_factor: 0.8
patient: 0

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
checkpoints_dir: !ref <save_folder>
recoverables:
model: !ref <model>
scheduler: !ref <lr_annealing>
normalizer: !ref <normalize>
counter: !ref <epoch_counter>

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
collect_in: !ref <save_folder>
loadables:
lm: !ref <lm_model>
tokenizer: !ref <tokenizer>
paths:
lm: !ref <pretrained_lm_path>
tokenizer: !ref <pretrained_tokenizer_path>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
save_file: !ref <output_folder>/train_log.txt

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
split_tokens: True
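For readers unfamiliar with the pruned transducer loss referenced by `transducer_cost` above, here is a minimal sketch of the two-pass computation, written against the `fast_rnnt` package (https://github.com/danpovey/fast_rnnt). The `Joiner` module, tensor shapes, smoothing scales, and the wiring of `mode`, `loss_scale`, and `prune_range` are assumptions inferred from the YAML, not necessarily the PR's exact `speechbrain.nnet.losses.fast_rnnt_pruned_loss` implementation.

```python
import torch
import fast_rnnt


class Joiner(torch.nn.Module):
    """Hypothetical stand-in for the PR's Fast_RNNT_Joiner: project the
    pruned encoder (am) and decoder (lm) activations into a shared space,
    combine them, and map to output logits."""

    def __init__(self, input_dim, inner_dim, output_dim):
        super().__init__()
        self.am_proj = torch.nn.Linear(input_dim, inner_dim)
        self.lm_proj = torch.nn.Linear(input_dim, inner_dim)
        self.out = torch.nn.Linear(inner_dim, output_dim)

    def forward(self, am_pruned, lm_pruned):
        # Both inputs: (B, T, prune_range, input_dim)
        return self.out(torch.tanh(self.am_proj(am_pruned) + self.lm_proj(lm_pruned)))


def pruned_transducer_loss(
    am, lm, symbols, boundary, joiner,
    blank_index=0, prune_range=5, loss_scale=0.5,
    mode="simple", reduction="mean",
):
    """am: (B, T, V) encoder scores, lm: (B, U+1, V) decoder scores,
    symbols: (B, U) target tokens, boundary: (B, 4) rows [0, 0, U, T]."""
    # Pass 1: a cheap lattice over the full (T, U) grid using a trivial
    # (additive) joiner; its arc gradients indicate which symbol positions
    # matter at each frame.
    if mode == "simple":
        simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
            lm=lm, am=am, symbols=symbols, termination_symbol=blank_index,
            boundary=boundary, reduction=reduction, return_grad=True,
        )
    else:  # "smoothed": interpolate with LM-only / AM-only lattices
        simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_smoothed(
            lm=lm, am=am, symbols=symbols, termination_symbol=blank_index,
            lm_only_scale=0.25, am_only_scale=0.0,  # illustrative scales
            boundary=boundary, reduction=reduction, return_grad=True,
        )

    # Keep only `prune_range` symbol positions per frame, chosen from the
    # first-pass gradients.
    ranges = fast_rnnt.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary,
        s_range=prune_range,
    )
    am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

    # Pass 2: the full (nonlinear) joiner runs only on the pruned lattice.
    logits = joiner(am_pruned, lm_pruned)
    pruned_loss = fast_rnnt.rnnt_loss_pruned(
        logits=logits, symbols=symbols, ranges=ranges,
        termination_symbol=blank_index, boundary=boundary,
        reduction=reduction,
    )
    return loss_scale * simple_loss + pruned_loss
```

This lines up with the Tjoint block above, where `input_dim` and `output_dim` are both `output_neurons` and `inner_dim` is `joint_dim`: the first-pass scores already live in vocabulary space, so the same dimensionality feeds both the cheap lattice and the pruned joiner.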