Add KenLM n-gram training recipe #2304

Merged: 17 commits, Dec 26, 2023
34 changes: 25 additions & 9 deletions recipes/CommonVoice/ASR/CTC/README.md
@@ -1,9 +1,18 @@
# CommonVoice ASR with CTC based Seq2Seq models.
This folder contains scripts necessary to run an ASR experiment with the CommonVoice 14.0 dataset.

# How to run
python train.py hparams/{hparam_file}.yaml

To use an n-gram Language Model (LM) for decoding, follow these steps:
1. In the yaml file, uncomment the line `kenlm_model_path: none` in the `test_beam_search` entry.
2. Set it to the path of an ARPA or binary (bin) file containing the n-gram LM (see the sketch below).

For training an n-gram LM in ARPA (or bin) format, refer to the LM recipe in recipes/CommonVoice/LM.
Alternatively, you can download a pre-trained n-gram LM from our Dropbox repository at this link: [Pretrained n-gram LMs](https://www.dropbox.com/scl/fo/zw505t10kesqpvkt6m3tu/h?rlkey=6626h1h665tvlo1mtekop9rx5&dl=0).

These models are trained on the CommonVoice audio transcriptions available in the training set.
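For illustration, a hypothetical `test_beam_search` entry after the two steps above might look as follows; every key except `kenlm_model_path` is a placeholder, so keep whatever your hparams file already defines:

```yaml
test_beam_search:
    beam_size: 100                            # placeholder; keep the recipe's value
    kenlm_model_path: /path/to/en_5gram.arpa  # previously commented out as "none"
```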

# Data preparation
It is important to note that CommonVoice initially offers mp3 audio files at 48 kHz. Hence, audio files are resampled on the fly within the dataio function of the training script.
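For reference, here is a minimal sketch of this kind of on-the-fly resampling with torchaudio; the 48 kHz source rate and 16 kHz target are assumptions based on CommonVoice releases and wav2vec2-style models, and the recipe handles this internally:

```python
import torchaudio

# Load a CommonVoice clip (mp3) and resample it to 16 kHz.
signal, sample_rate = torchaudio.load("path/to/clip.mp3")
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
signal_16k = resampler(signal)
```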

@@ -19,17 +28,24 @@ Here is a list of the different languages that we tested within the CommonVoice dataset:
- Portuguese
- Chinese(china)

> **Note:** In our experiments, we use CTC beam search and boost performance with the 5-gram model previously trained on the transcriptions of the training data (refer to the LM recipe: recipes/CommonVoice/LM).

> **Note:** For Chinese, the concept of a word is not well defined; hence, we report the character error rate instead of the word error rate. For the same reason, we do not employ a 5-gram LM.

# Results
| Language | CommonVoice Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | HuggingFace link | Model link | GPUs |
| ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:| :-----------:|
| English | 2023-08-15 | train_en_with_wav2vec.yaml | No | 5.65 | 13.67 | 7.76 | 16.16 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-en) | [model](https://www.dropbox.com/sh/ch10cnbhf1faz3w/AACdHFG65LC6582H0Tet_glTa?dl=0) | 1xV100 32GB |
| German | 2023-08-15 | train_de_with_wav2vec.yaml | No | 1.74 | 7.40 | 2.18 | 8.39 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-de) | [model](https://www.dropbox.com/sh/dn7plq4wfsujsi1/AABS1kqB_uqLJVkg-bFkyPpVa?dl=0) | 1xV100 32GB |
| French | 2023-08-15 | train_fr_with_wav2vec.yaml | No | 2.59 | 8.47 | 3.36 | 9.71 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-fr) | [model](https://www.dropbox.com/sh/0i7esfa8jp3rxpp/AAArdi8IuCRmob2WAS7lg6M4a?dl=0) | 1xV100 32GB |
| Italian | 2023-08-15 | train_it_with_wav2vec.yaml | No | 2.10 | 7.77 | 2.30 | 7.99 |[model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-it) | [model](https://www.dropbox.com/sh/hthxqzh5boq15rn/AACftSab_FM6EFWWPgHpKw82a?dl=0) | 1xV100 32GB |
| Kinyarwanda | 2023-08-15 | train_rw_with_wav2vec.yaml | No | 5.47 | 19.58 | 7.30 | 22.52 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-rw) | [model](https://www.dropbox.com/sh/4iax0l4yfry37gn/AABuQ31JY-Sbyi1VlOJfV7haa?dl=0) | 1xV100 32GB |
| Arabic | 2023-08-15 | train_ar_with_wav2vec.yaml | No | 6.45 | 20.80 | 9.65 | 28.53 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-ar) | [model](https://www.dropbox.com/sh/7tnuqqbr4vy96cc/AAA_5_R0RmqFIiyR0o1nVS4Ia?dl=0) | 1xV100 32GB |
| Spanish | 2023-08-15 | train_es_with_wav2vec.yaml | No | 3.36 | 12.61 | 3.67 | 12.67 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-es) | [model](https://www.dropbox.com/sh/ejvzgl3d3g8g9su/AACYtbSWbDHvBr06lAb7A4mVa?dl=0) | 1xV100 32GB |
| Portuguese | 2023-08-15 | train_pt_with_wav2vec.yaml | No | 6.26 | 21.05 | 6.63 | 21.69 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-pt) | [model](https://www.dropbox.com/sh/80wucrvijdvao2a/AAD6-SZ2_ZZXmlAjOTw6fVloa?dl=0) | 1xV100 32GB |
| Chinese(china) | 2023-08-15 | train_zh-CN_with_wav2vec.yaml | No | 25.03 | - | 23.17 | - | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-14-zh-CN) | [model](https://www.dropbox.com/sh/2bikr81vgufoglf/AABMpD0rLIaZBxjtwBHgrNpga?dl=0) | 1xV100 32GB |


69 changes: 69 additions & 0 deletions recipes/CommonVoice/LM/README.md
@@ -0,0 +1,69 @@

# Training KenLM
This folder contains recipes for training a KenLM n-gram language model on the CommonVoice dataset.
Using Wav2Vec2 in combination with a language model can yield a significant improvement, especially when the model is fine-tuned on small speech datasets. This guide explains how to create an n-gram language model and combine it with an existing fine-tuned Wav2Vec2 model.


You can download CommonVoice at https://commonvoice.mozilla.org/en

## Installing Extra Dependencies

Before proceeding, ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal:

```
pip install -r extra_requirements.txt
```

We will use the popular KenLM library to build an n-gram. Let's start by installing the Ubuntu library prerequisites. For a complete guide on how to install required dependencies, please refer to [this](https://kheafield.com/code/kenlm/dependencies/) link:
```
sudo apt install build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
```

Next, download and unpack the KenLM repo.
```
wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
```

KenLM is written in C++, so we'll make use of cmake to build the binaries.
```
mkdir kenlm/build && cd kenlm/build && cmake .. && make -j2
```

Now, make sure that the KenLM executables are on your PATH by adding them to your .bashrc file. To do so:
- Open the ~/.bashrc file in a text editor.
- Scroll to the end of the file and add the following line: `export PATH=$PATH:/your/path/to/kenlm/build/bin`
- Save the file and run: `source ~/.bashrc`
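As a quick sanity check (assuming the build steps above succeeded), the binaries should now resolve from any directory:

```shell
# Both should print a path ending in kenlm/build/bin
which lmplz
which build_binary
```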

# How to run
```shell
python train.py hparams/train_kenlm.yaml --data_folder=your/data/folder
```

# Results
The script trains an n-gram language model, which is stored in the popular ARPA format. If you prefer KenLM's binary format, see the conversion sketch below.
The output folders with checkpoints and logs can be found [here](https://www.dropbox.com/scl/fo/zw505t10kesqpvkt6m3tu/h?rlkey=6626h1h665tvlo1mtekop9rx5&dl=0).
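KenLM's binary format loads faster and is smaller on disk; the ARPA model can be converted with the `build_binary` tool compiled above (paths below are placeholders):

```shell
# Convert the trained ARPA model to KenLM's binary format
build_binary results/CommonVoice/ngrams/en/1986/en_5gram.arpa en_5gram.bin
```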




# **About SpeechBrain**
- Website: https://speechbrain.github.io/
- Code: https://github.com/speechbrain/speechbrain/
- HuggingFace: https://huggingface.co/speechbrain/


# **Citing SpeechBrain**
Please, cite SpeechBrain if you use it for your research or business.

```bibtex
@misc{speechbrain,
title={{SpeechBrain}: A General-Purpose Speech Toolkit},
author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
year={2021},
eprint={2106.04624},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2106.04624}
}
```
1 change: 1 addition & 0 deletions recipes/CommonVoice/LM/common_voice_prepare.py
22 changes: 22 additions & 0 deletions recipes/CommonVoice/LM/hparams/train_kenlm.yaml
@@ -0,0 +1,22 @@
################################
# Recipe for training KenLM on CommonVoice data.
# It is used to boost Wav2Vec2 with n-grams.
#
# Author: Pooneh Mousavi (2023)
################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/CommonVoice/ngrams/<language>/<seed>

# Data files
data_folder: !PLACEHOLDER # e.g., /localscratch/cv-corpus-14.0-2023-06-23/en
train_tsv_file: !ref <data_folder>/train.tsv
language: en
# accented_letters should be set according to the language
accented_letters: True
train_csv: !ref <output_folder>/train.csv
skip_prep: False
text_file: !ref <output_folder>/train.txt
ngram: 5
ngram_file: !ref <output_folder>/<language>_<ngram>gram.arpa
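Every entry above can be overridden from the command line, as is standard for SpeechBrain/HyperPyYAML recipes; for example, a hypothetical run for French with a 4-gram model:

```shell
python train.py hparams/train_kenlm.yaml \
    --data_folder=/path/to/cv-corpus-14.0-2023-06-23/fr \
    --language=fr --accented_letters=True --ngram=4
```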
94 changes: 94 additions & 0 deletions recipes/CommonVoice/LM/train.py
@@ -0,0 +1,94 @@
"""
Recipe to train a KenLM n-gram language model that can be combined with Wav2Vec2.
https://huggingface.co/blog/wav2vec2-with-ngram

To run this recipe, do the following:
> python train.py hparams/train_kenlm.yaml --data_folder=/path/to/CommonVoice
Author
* Pooneh Mousavi 2023
"""

import os
import csv
import sys
import logging
import speechbrain as sb
from speechbrain.utils.distributed import run_on_main
from hyperpyyaml import load_hyperpyyaml


logger = logging.getLogger(__name__)


def csv2text():
    """Read the prepared CSV file and write its transcriptions (the "wrd"
    column) to a plain text file, one sentence per line.

    Relies on the globally loaded ``hparams`` dictionary.
    """
    with open(hparams["train_csv"], "r") as annotation_file, open(
        hparams["text_file"], "w+"
    ) as text_file:
        reader = csv.reader(annotation_file)
        headers = next(reader, None)
        index_label = headers.index("wrd")
        for row in reader:
            text_file.write(row[index_label] + "\n")
    logger.info("Text file created at: " + hparams["text_file"])


if __name__ == "__main__":
    # Load hyperparameters file with command-line overrides
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Dataset prep (parsing CommonVoice)
    from common_voice_prepare import prepare_common_voice  # noqa

    # multi-gpu (ddp) safe data preparation
    if not os.path.exists(hparams["text_file"]):
        run_on_main(
            prepare_common_voice,
            kwargs={
                "data_folder": hparams["data_folder"],
                "save_folder": hparams["output_folder"],
                "train_tsv_file": hparams["train_tsv_file"],
                "accented_letters": hparams["accented_letters"],
                "language": hparams["language"],
                "skip_prep": hparams["skip_prep"],
            },
        )
        csv2text()

    logger.info(f"Start training {hparams['ngram']}-gram kenlm model.")
    tmp_ngram_file = "ngram.arpa"
    # Train the n-gram LM with KenLM's lmplz on the extracted transcriptions.
    cmd = f'lmplz -o {hparams["ngram"]} <"{hparams["text_file"]}" > "{tmp_ngram_file}"'
    os.system(cmd)
    # KenLM's ARPA output uses an end-of-sentence token </s> but does not list
    # it in the unigram section, which some CTC decoders expect (see the
    # Hugging Face blog post linked in the module docstring). Fix this by
    # duplicating the <s> entry as </s> and incrementing the unigram count.
    with open(tmp_ngram_file, "r") as read_file, open(
        hparams["ngram_file"], "w"
    ) as write_file:
        has_added_eos = False
        for line in read_file:
            if not has_added_eos and "ngram 1=" in line:
                count = line.strip().split("=")[-1]
                write_file.write(line.replace(f"{count}", f"{int(count)+1}"))
            elif not has_added_eos and "<s>" in line:
                write_file.write(line)
                write_file.write(line.replace("<s>", "</s>"))
                has_added_eos = True
            else:
                write_file.write(line)
    os.remove(tmp_ngram_file)
    logger.info(
        f"{hparams['ngram']}-gram kenlm model was built and saved in {hparams['ngram_file']}."
    )
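For illustration only (the counts and scores below are made up), the post-processing step turns the head of the ARPA file from the first form into the second:

```
# Before: </s> is missing from the unigram section
\data\
ngram 1=4748
...
\1-grams:
-4.31  <s>  -0.12

# After: the unigram count is incremented and </s> mirrors <s>
\data\
ngram 1=4749
...
\1-grams:
-4.31  <s>   -0.12
-4.31  </s>  -0.12
```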
58 changes: 48 additions & 10 deletions speechbrain/inference/ASR.py
@@ -17,6 +17,8 @@
import sentencepiece
import speechbrain
from speechbrain.inference.interfaces import Pretrained
from speechbrain.utils.fetching import fetch
from speechbrain.utils.data_utils import split_path


class EncoderDecoderASR(Pretrained):
@@ -46,8 +48,11 @@ def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = self.hparams.tokenizer
        self.transducer_beam_search = False
        self.transformer_beam_search = False
        if hasattr(self.hparams, "transducer_beam_search"):
            self.transducer_beam_search = self.hparams.transducer_beam_search
        if hasattr(self.hparams, "transformer_beam_search"):
            self.transformer_beam_search = self.hparams.transformer_beam_search

    def transcribe_file(self, path, **kwargs):
        """Transcribes the given audiofile into a sequence of words.
@@ -98,6 +103,8 @@ def encode_batch(self, wavs, wav_lens):
        wavs = wavs.float()
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        encoder_out = self.mods.encoder(wavs, wav_lens)
        if self.transformer_beam_search:
            encoder_out = self.mods.transformer.encode(encoder_out, wav_lens)
        return encoder_out

    def transcribe_batch(self, wavs, wav_lens):
@@ -130,11 +137,10 @@
        wav_lens = wav_lens.to(self.device)
        encoder_out = self.encode_batch(wavs, wav_lens)
        if self.transducer_beam_search:
            inputs = [encoder_out]
        else:
            inputs = [encoder_out, wav_lens]
        predicted_tokens, _, _, _ = self.mods.decoder(*inputs)
        predicted_words = [
            self.tokenizer.decode_ids(token_seq)
            for token_seq in predicted_tokens
@@ -165,14 +171,43 @@ class EncoderASR(Pretrained):
>>> asr_model.transcribe_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
"""

    HPARAMS_NEEDED = ["tokenizer", "decoding_type"]
    MODULES_NEEDED = ["encoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = self.hparams.tokenizer
        self.decoding_type = self.hparams.decoding_type
        self.set_decoding_function()

    def set_decoding_function(self):
        """Set the decoding function based on the parameters defined in the hyperparameter file."""
        if self.decoding_type == "beam":
            if hasattr(self.hparams, "kenlm_model_path"):
                source, fl = split_path(self.hparams.kenlm_model_path)
                kenlm_model_path = str(fetch(fl, source=source, savedir="."))
                self.hparams.test_beam_search[
                    "kenlm_model_path"
                ] = kenlm_model_path

            vocab_list = [
                self.tokenizer.id_to_piece(i)
                for i in range(self.tokenizer.vocab_size())
            ]

            from speechbrain.decoders.ctc import CTCBeamSearcher

            self.decoding_function = CTCBeamSearcher(
                **self.hparams.test_beam_search, vocab_list=vocab_list
            )
        else:
            from functools import partial

            self.decoding_function = partial(
                speechbrain.decoders.ctc_greedy_decode,
                blank_id=self.hparams.blank_index,
            )

    def transcribe_file(self, path, **kwargs):
        """Transcribes the given audiofile into a sequence of words.
@@ -265,10 +300,13 @@ def transcribe_batch(self, wavs, wav_lens):
            elif isinstance(
                self.tokenizer, sentencepiece.SentencePieceProcessor
            ):
                if self.decoding_type == "greedy":
                    predicted_words = [
                        self.tokenizer.decode_ids(token_seq)
                        for token_seq in predictions
                    ]
                else:
                    predicted_words = [hyp[0].text for hyp in predictions]
            else:
                raise ValueError(
                    "The tokenizer must be sentencepiece or CTCTextEncoder"
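As a usage sketch (not part of the diff): with these changes, an `EncoderASR` checkpoint whose fetched hyperparams define `decoding_type` (and, for beam search, a `test_beam_search` entry that may point at a KenLM file) can be loaded and used as usual. The model id below is one of the checkpoints listed in the README; the audio path is a placeholder:

```python
from speechbrain.inference.ASR import EncoderASR

# set_decoding_function() selects CTC beam search (optionally rescored with
# the n-gram LM) or greedy decoding, based on decoding_type in the hparams.
asr_model = EncoderASR.from_hparams(
    source="speechbrain/asr-wav2vec2-commonvoice-14-en",
    savedir="pretrained_models/asr-wav2vec2-commonvoice-14-en",
)
print(asr_model.transcribe_file("path/to/audio.wav"))
```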
1 change: 1 addition & 0 deletions tests/recipes/CommonVoice.csv
@@ -31,3 +31,4 @@ SSL,CommonVoice,recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_w
quantization,CommonVoice,recipes/CommonVoice/quantization/train.py,recipes/CommonVoice/quantization/hparams/train_with_hubert.yaml,recipes/CommonVoice/quantization/common_voice_prepare.py,recipes/CommonVoice/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]"
quantization,CommonVoice,recipes/CommonVoice/quantization/train.py,recipes/CommonVoice/quantization/hparams/train_with_wav2vec.yaml,recipes/CommonVoice/quantization/common_voice_prepare.py,recipes/CommonVoice/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]"
quantization,CommonVoice,recipes/CommonVoice/quantization/train.py,recipes/CommonVoice/quantization/hparams/train_with_wavlm.yaml,recipes/CommonVoice/quantization/common_voice_prepare.py,recipes/CommonVoice/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]"
LM,CommonVoice,recipes/CommonVoice/LM/train.py,recipes/CommonVoice/LM/hparams/train_kenlm.yaml,recipes/CommonVoice/LM/common_voice_prepare.py,recipes/CommonVoice/LM/README.md,https://www.dropbox.com/scl/fo/zw505t10kesqpvkt6m3tu/h?rlkey=6626h1h665tvlo1mtekop9rx5&dl=0,,--data_folder=tests/samples/ASR/ --text_file=tests/samples/annotation/LM_train.txt --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]"