# **Fine-tuning XLS-R for Multi-Lingual ASR with 🤗 Transformers**

**Wav2Vec2** is a pretrained model for Automatic Speech Recognition (ASR) and was released in [September 2020](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/) by *Alexei Baevski, Michael Auli, and Alex Conneau*. Soon after the superior performance of Wav2Vec2 was demonstrated on one of the most popular English datasets for ASR, called [LibriSpeech](https://huggingface.co/datasets/librispeech_asr), *Facebook AI* presented a multi-lingual version of Wav2Vec2, called [XLSR](https://arxiv.org/abs/2006.13979). XLSR stands for *cross-lingual speech representations* and refers to the model's ability to learn speech representations that are useful across multiple languages.

XLSR's successor, simply called **XLS-R** (referring to the [*''XLM-R*](https://ai.facebook.com/blog/-xlm-r-state-of-the-art-cross-lingual-understanding-through-self-supervision/) *for Speech''*), was released in [November 2021](https://ai.facebook.com/blog/xls-r-self-supervised-speech-processing-for-128-languages) by *Arun Babu, Changhan Wang, Andros Tjandra, et al.* XLS-R used almost **half a million** hours of audio data in 128 languages for self-supervised pre-training and comes in sizes ranging from 300 million up to **two billion** parameters. You can find the pretrained checkpoints on the 🤗 Hub:

- [**Wav2Vec2-XLS-R-300M**](https://huggingface.co/facebook/wav2vec2-xls-r-300m)
- [**Wav2Vec2-XLS-R-1B**](https://huggingface.co/facebook/wav2vec2-xls-r-1b)
- [**Wav2Vec2-XLS-R-2B**](https://huggingface.co/facebook/wav2vec2-xls-r-2b)

Similar to [BERT's masked language modeling objective](http://jalammar.github.io/illustrated-bert/), XLS-R learns contextualized speech representations by randomly masking feature vectors before passing them to a transformer network during self-supervised pre-training (see the diagram on the left below).

For fine-tuning, a single linear layer is added on top of the pre-trained network to train the model on labeled data of audio downstream tasks such as speech recognition, speech translation and audio classification (see the diagram on the right below).

![wav2vec2_structure](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/xls_r.png)

XLS-R shows impressive improvements over previous state-of-the-art results on speech recognition, speech translation and speaker/language identification, *cf.* Tables 3-6, Tables 7-10, and Tables 11-12, respectively, of the official [paper](https://ai.facebook.com/blog/xls-r-self-supervised-speech-processing-for-128-languages).

XLS-R is fine-tuned using Connectionist Temporal Classification (CTC), which is an algorithm that is used to train neural networks for sequence-to-sequence problems, such as ASR and handwriting recognition. 

More reading on CTC can be found here: [*Sequence Modeling with CTC (2017)*](https://distill.pub/2017/ctc/) by Awni Hannun.

In [1]:
!pip install transformers datasets accelerate -U
!pip install jiwer
!pip install librosa



In [2]:
import os
from datasets import load_dataset
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
import utils

In [3]:
FINETUNED_MODEL_PATH = "finetuned_wav2vec2" #path to save the finetuned model
BASE_MODEL_NAME = "facebook/wav2vec2-xls-r-300m" #path of the model in huggingface
DATASET_NAME =  "google/fleurs" #name of the dataset in huggingface
KEEP_HEBREW_ONLY = True #Only keep samples that contain pure Hebrew text. This is mostly useful to ensure alignment with the kenlm model vocabulary later on.

### Dataset prep



Before training the model, some data preparation is required. The following steps will be performed:
- **Load the dataset**
- **Standardize the dataset**: This function standardizes the dataset to have the same column names as other datasets on Hugging Face, which makes it easy to switch between datasets.
- **Drop non-Hebrew samples**: This function drops all samples that contain non-Hebrew characters. This is useful to ensure that the kenlm model vocabulary will match the wav2vec2 model vocabulary.
- **Subsample the dataset** (optional): This function subsamples the dataset to a smaller size. This is useful for testing purposes.

The code for these functions is imported from a separate `utils` file in this repo to keep the notebook readable.
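The `utils` module itself is not shown in this notebook, so the snippet below is only a hypothetical sketch of what a Hebrew-only filter could look like. The name `is_pure_hebrew` and its regular expression are assumptions made for illustration, not the actual `utils.drop_english_samples` implementation.

```python
import re

# Assumption: "pure Hebrew" means only Hebrew letters (U+05D0 to U+05EA) and whitespace.
_HEBREW_ONLY = re.compile(r"^[\u05D0-\u05EA\s]+$")


def is_pure_hebrew(text: str) -> bool:
    """Return True if `text` contains only Hebrew letters and whitespace."""
    return bool(_HEBREW_ONLY.match(text))


samples = ["שלום עולם", "שלום world", "abc"]
kept = [s for s in samples if is_pure_hebrew(s)]
print(kept)
```

With 🤗 Datasets, a predicate like this could be passed to `Dataset.filter`, for example `ds.filter(lambda ex: is_pure_hebrew(ex["transcription"]))`.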

In [4]:
if not os.path.isdir(FINETUNED_MODEL_PATH):
    os.mkdir(FINETUNED_MODEL_PATH)
dataset = load_dataset(DATASET_NAME, 'he_il')
print("Standardizing train and test splits")

for split in dataset:
    dataset[split] = utils.standardize_dataset(dataset[split])
    dataset[split] = dataset[split].map(utils.remove_special_characters)  # map over the current split, not the whole DatasetDict

    if KEEP_HEBREW_ONLY:
        dataset[split] = utils.drop_english_samples(dataset[split])

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Standardizing train and test splits
Removing unecessary columns
writing name of dataset to dataset column


Map:   0%|          | 0/3242 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3242 [00:00<?, ? examples/s]

Removing unecessary columns
writing name of dataset to dataset column


Map:   0%|          | 0/328 [00:00<?, ? examples/s]

Filter:   0%|          | 0/328 [00:00<?, ? examples/s]

Removing unecessary columns
writing name of dataset to dataset column


Map:   0%|          | 0/792 [00:00<?, ? examples/s]

Filter:   0%|          | 0/792 [00:00<?, ? examples/s]

In [5]:
if utils.SUBSAMPLE_RATIO < 1.0: #change this to 1.0 if you want to use the full dataset
    for name in dataset.keys():
        dataset[name] = utils.subsample_dataset(dataset[name])

In [7]:
dataset.column_names

{'train': ['audio', 'transcription'],
 'validation': ['audio', 'transcription'],
 'test': ['audio', 'transcription']}

In [10]:
utils.show_random_elements(dataset['train'].remove_columns('audio'), num_examples=10)

Unnamed: 0,transcription
0,עקב חשיבותה הדתית של העיר ובעיקר האתרים הרבים באזור העיר העתיקה ירושלים היא אחד מאתרי התיירות העיקריים של ישראל
1,עליכם לשים לב לתנוחת הקורבן כשאתם ניגשים אליו או אליה ולנורות אזהרה אוטומטיים
2,בדרך כלל משתמשים במערכת הבינה המלאכותית בתחומי הכלכלה רפואה הנדסה וצבא כמו כן היא הותקנה במספר תוכנות למחשבים ביתיים ומשחקי מחשב
3,אוריגמי פיורלנד הוא סוג אוריגמי שכולל את המגבלה שבכל שלב מותר לבצע רק קיפול בודד קיפולים מורכבים יותר כגון קיפול הפוך אסורים וכל הקיפולים מתבצעים במיקומים ברורים
4,סופוקלס ואריסטופנס הם עדיין מחזאים פופולריים והמחזות שלהם נמנים עם היצירות הגדולות של ספרות העולם
5,נתחיל בהסבר על התוכניות של איטליה איטליה הייתה בעיקר האח הקטן של גרמניה ויפן
6,עיצוב אינטראקטיבי דורש הערכה מחודשת של ההנחות שלכם לגבי הפקת מדיה ושתלמדו לחשוב בדרכים לא ליניאריות
7,רבות מהחיות אקזוטיות הן קשות למציאה ובפארקים יש לעתים חוקים לגבי צילום תמונות למטרות מסחריות
8,למשל ילדים שמזדהים עם מיעוט אתני שקיים לגביו סטראוטיפ שלפיו אינם מצליחים בלימודים נוטים שלא להצליח בלימודים לאחר שלמדו על הסטראוטיפ שנקשר לגזע שלהם
9,בתום הקרב על צרפת גרמניה החלה להערך לפלישה לאי הבריטי


### Create `Wav2Vec2CTCTokenizer` and `Wav2Vec2FeatureExtractor`

A pre-trained XLS-R model maps the speech signal to a sequence of context representations as illustrated in the figure above. However, for speech recognition the model has to map this sequence of context representations to its corresponding transcription, which means that a linear layer has to be added on top of the transformer block (shown in yellow in the diagram above). This linear layer is used to classify each context representation to a token class, analogous to how, *e.g.*, after pretraining a linear layer is added on top of BERT's embeddings for further classification - *cf.* the *'BERT'* section of this [blog post](https://huggingface.co/blog/warm-starting-encoder-decoder).

The output size of this layer corresponds to the number of tokens in the vocabulary, which does **not** depend on XLS-R's pretraining task, but only on the labeled dataset used for fine-tuning. So in the first step, we will take a look at the chosen dataset and define a vocabulary based on the transcriptions.

In CTC, it is common to classify speech chunks into letters, so we will do the same here. 
Let's extract all distinct letters of the training and test data and build our vocabulary from this set of letters.
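As a rough illustration of this step, here is a minimal sketch of a character-level vocabulary builder. The function `build_char_vocab` is a hypothetical stand-in; the notebook's real `utils.get_vocab` is not shown and may differ.

```python
def build_char_vocab(transcriptions):
    """Map every distinct character to an integer id (sorted for reproducibility)."""
    chars = set()
    for text in transcriptions:
        chars.update(text)  # adds each character of the string to the set
    return {ch: idx for idx, ch in enumerate(sorted(chars))}


toy_vocab = build_char_vocab(["שלום", "עולם טוב"])
print(toy_vocab)
```

Sorting the characters keeps the ids stable across runs, which matters because the ids are baked into the fine-tuned model's output layer.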

In [19]:
from datasets import Audio
import json
vocab = utils.get_vocab(dataset)

Map:   0%|          | 0/2349 [00:00<?, ? examples/s]

Map:   0%|          | 0/264 [00:00<?, ? examples/s]

Map:   0%|          | 0/603 [00:00<?, ? examples/s]

printing vocab with length:  34


In the `add_special_characters` function, we replace the space character `" "` with the more visible character `|` to make it clear that it has its own token class. In addition, we add an "unknown" token so that the model can later deal with characters not encountered in the training set.

Finally, we also add a padding token that corresponds to CTC's "*blank token*". The "blank token" is a core component of the CTC algorithm. For more information, please take a look at the "Alignment" section [here](https://distill.pub/2017/ctc/).
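A minimal sketch of these vocabulary changes could look as follows. `add_special_tokens` is a hypothetical helper written for illustration (the actual `utils.add_special_characters` is not shown), and it assumes the incoming ids are contiguous from 0 and that `|` is not already a key.

```python
def add_special_tokens(vocab):
    """Return a copy of `vocab` with "|" as word delimiter plus [UNK] and [PAD]."""
    vocab = dict(vocab)  # do not mutate the caller's dictionary
    if " " in vocab:
        vocab["|"] = vocab.pop(" ")  # reuse the id of the space character
    vocab["[UNK]"] = len(vocab)  # next free id
    vocab["[PAD]"] = len(vocab)  # doubles as CTC's blank token
    return vocab


toy = add_special_tokens({" ": 0, "א": 1, "ב": 2})
print(toy)
```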

In [None]:
vocab = utils.add_special_characters(vocab)
print(vocab)

Now our vocabulary is complete and consists of 34 tokens, which means that the linear layer that we will add on top of the pretrained XLS-R checkpoint will have an output dimension of 34.

We save the vocabulary as a JSON file and load it into an instance of the `Wav2Vec2CTCTokenizer` class.


In [None]:
print("printing vocab with length: ", len(vocab))
with open(f"{FINETUNED_MODEL_PATH}/vocab.json", "w") as f:
    json.dump(vocab, f)
    
tokenizer = Wav2Vec2CTCTokenizer(f"{FINETUNED_MODEL_PATH}/vocab.json",
            unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

A `Wav2Vec2FeatureExtractor` object requires the following parameters to be instantiated:

- `feature_size`: Speech models take a sequence of feature vectors as an input. While the length of this sequence obviously varies, the feature size should not. In the case of Wav2Vec2, the feature size is 1 because the model was trained on the raw speech signal.
- `sampling_rate`: The sampling rate at which the model was trained.
- `padding_value`: For batched inference, shorter inputs need to be padded with a specific value.
- `do_normalize`: Whether the input should be *zero-mean-unit-variance* normalized or not. Usually, speech models perform better when the input is normalized.
- `return_attention_mask`: Whether the model should make use of an `attention_mask` for batched inference. In general, XLS-R model checkpoints should **always** use the `attention_mask`.
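To make `do_normalize` concrete, here is a small NumPy sketch of zero-mean-unit-variance normalization. It is an illustration only: the helper name and the `eps` value are assumptions, and the feature extractor's internal implementation may differ in details.

```python
import numpy as np


def zero_mean_unit_variance(waveform: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Shift to zero mean and scale to unit variance; eps avoids division by zero."""
    return (waveform - waveform.mean()) / np.sqrt(waveform.var() + eps)


rng = np.random.default_rng(0)
# One second of fake "audio" at 16 kHz with an offset and a large amplitude.
raw = rng.normal(loc=3.0, scale=5.0, size=16_000).astype(np.float32)
normed = zero_mean_unit_variance(raw)
print(float(normed.mean()), float(normed.std()))
```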

Regarding the sampling rate:

Speech is a continuous signal and to be treated by computers, it first has to be discretized, which is usually called **sampling**. The sampling rate hereby plays an important role in that it defines how many data points of the speech signal are measured per second. Therefore, sampling with a higher sampling rate results in a better approximation of the *real* speech signal but also necessitates more values per second.

A pretrained checkpoint expects its input data to have been sampled more or less from the same distribution as the data it was trained on. The same speech signal sampled at two different rates has a very different distribution, *e.g.*, doubling the sampling rate results in twice as many data points for the same duration, so the sequence becomes twice as long. Thus, before fine-tuning a pretrained checkpoint of an ASR model, it is crucial to verify that the sampling rate of the data that was used to pretrain the model matches the sampling rate of the dataset used to fine-tune the model.
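The effect of the sampling rate on sequence length can be seen with a toy signal. The sketch below samples the same one-second sine wave at 8 kHz and 16 kHz; the helper `sample_sine` is made up for this illustration.

```python
import math


def sample_sine(freq_hz: float, duration_s: float, sampling_rate: int) -> list:
    """Sample a sine wave of `freq_hz` for `duration_s` seconds at `sampling_rate` Hz."""
    n_samples = int(duration_s * sampling_rate)
    return [math.sin(2 * math.pi * freq_hz * n / sampling_rate) for n in range(n_samples)]


low = sample_sine(440.0, 1.0, 8_000)
high = sample_sine(440.0, 1.0, 16_000)
# Same signal, same duration, but twice as many data points at the higher rate.
print(len(low), len(high))
```

If a dataset's sampling rate does not match the checkpoint's, 🤗 Datasets can resample on the fly with `dataset.cast_column("audio", Audio(sampling_rate=16_000))`.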

In [20]:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(BASE_MODEL_NAME,
                            feature_size=1, sampling_rate=16000, padding_value=0.0, 
                            do_normalize=True, return_attention_mask=True
                            )



Great, XLS-R's feature extraction pipeline is thereby fully defined!

For improved user-friendliness, the feature extractor and tokenizer are *wrapped* into a single `Wav2Vec2Processor` class so that one only needs a `model` and `processor` object.

In [None]:
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

Finally, we can leverage `Wav2Vec2Processor` to process the data into the format expected by `Wav2Vec2ForCTC` for training. For Wav2Vec2, this mostly just normalizes the data and converts feature names to their expected values.

We will do this in the `prepare_dataset` function, which also encodes the transcriptions to label ids. To do so, we will make use of Dataset's [`map(...)`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=map#datasets.DatasetDict.map) function.
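Since `utils.prepare_dataset` is not shown here, the following is only a hypothetical sketch of what such a function might do, assuming the `audio` and `transcription` columns seen earlier. The name `prepare_dataset_sketch` and the extra `input_length` field are assumptions for illustration.

```python
def prepare_dataset_sketch(batch, processor, input_key="input_values"):
    """Hypothetical stand-in for `utils.prepare_dataset` (the real helper is not shown)."""
    audio = batch["audio"]
    # Feature extraction: normalizes the waveform and returns a batch of size one.
    features = processor(audio["array"], sampling_rate=audio["sampling_rate"])
    batch[input_key] = features.input_values[0]
    batch["input_length"] = len(batch[input_key])
    # Tokenization: turn the transcription into label ids for the CTC loss.
    batch["labels"] = processor(text=batch["transcription"]).input_ids
    return batch
```

Passing `processor` and `input_key` as arguments mirrors the `fn_kwargs` used in the `map` call, which keeps the function free of global state.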

In [21]:
for split, data in dataset.items():
    dataset[split] = dataset[split].map(utils.prepare_dataset, fn_kwargs={"processor": processor, "input_key": "input_values"},
                                        remove_columns=dataset[split].column_names)

Map:   0%|          | 0/2349 [00:00<?, ? examples/s]

Map:   0%|          | 0/264 [00:00<?, ? examples/s]

Map:   0%|          | 0/603 [00:00<?, ? examples/s]

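Let's validate that the dataset looks as expected. Note that the cell below reads the raw `audio` and `transcription` columns, which the `map` call above removes (via `remove_columns`), so it has to be run before that step: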

In [17]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(dataset["train"])-1)

print(dataset["train"][rand_int]["transcription"])
print("input array shape", dataset["train"][rand_int]["audio"]["array"].shape)
print("sampling rate", dataset["train"][rand_int]["audio"]["sampling_rate"])
ipd.Audio(data=dataset["train"][rand_int]["audio"]["array"], autoplay=True, rate=16000)


הם מתחילים כמשפכים היורדים מענני סערה ונהפכים ל סופות טורנדו כשהם נוגעים בקרקע
input array shape (147840,)
sampling rate 16000


## Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗's [Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer) for which we essentially need to do the following:

- Define a data collator. In contrast to most NLP models, XLS-R has a much larger input length than output length. *E.g.*, a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically, meaning that all training samples should only be padded to the longest sample in their batch and not the overall longest sample. Therefore, fine-tuning XLS-R requires a special padding data collator, which we will define below.

- Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a `compute_metrics` function accordingly

- Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

- Define the training configuration.

After having fine-tuned the model, we will correctly evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.
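To see why dynamic padding matters for long audio inputs, the toy calculation below compares the number of values a model would process when padding to the global maximum versus the per-batch maximum. The helper `padded_size` and the lengths are invented for this illustration.

```python
def padded_size(lengths, batch_size, dynamic=True):
    """Total number of values after padding, batching `lengths` in the given order."""
    global_max = max(lengths)
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        # Dynamic padding uses the longest sample in the batch, static the longest overall.
        target = max(batch) if dynamic else global_max
        total += target * len(batch)
    return total


lengths = [50_000, 48_000, 120_000, 118_000]  # toy input lengths in samples
print(padded_size(lengths, batch_size=2, dynamic=False))  # 480000
print(padded_size(lengths, batch_size=2, dynamic=True))   # 340000
```

The saving grows when samples of similar length end up in the same batch, which is what `group_by_length` aims for.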

In [22]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch


In [23]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

The most common metric in ASR is the word error rate [(WER)](https://huggingface.co/metrics/wer) metric. We will use this to evaluate our model. 
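For intuition, WER for a single reference/hypothesis pair is the word-level edit distance divided by the number of reference words. The function below is a minimal self-written sketch of that definition (it assumes a non-empty reference); the notebook itself relies on the `wer` metric rather than this code.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] = edit distance between the processed prefix of ref and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1] / len(ref)


print(word_error_rate("the cat sat", "the cat sat"))       # 0.0
print(word_error_rate("the cat sat", "the dog sat down"))  # one substitution + one insertion
```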

To allow models to become independent of the speaking rate, in CTC, consecutive tokens that are identical are simply grouped as a single token. However, the encoded labels should not be grouped when decoding since they don't correspond to the predicted tokens of the model, which is why the `group_tokens=False` parameter has to be passed. If we didn't pass this parameter, a word like `"hello"` would incorrectly be decoded as `"helo"`.

${}^2$ The blank token allows the model to predict a word, such as `"hello"` by forcing it to insert the blank token between the two l's. A CTC-conform prediction of `"hello"` of our model would be `[PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]`.
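The grouping rule and the role of the blank token can be reproduced in a few lines. The sketch below is a simplified greedy CTC collapse written for illustration (the function name `ctc_collapse` is made up); it is not the tokenizer's actual decoding code.

```python
from itertools import groupby


def ctc_collapse(tokens, blank="[PAD]"):
    """Merge consecutive repeats, then drop the blank token."""
    merged = [token for token, _ in groupby(tokens)]
    return "".join(token for token in merged if token != blank)


prediction = ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "[PAD]"]
print(ctc_collapse(prediction))     # hello
# Grouping already-clean label characters wrongly merges the double "l".
print(ctc_collapse(list("hello")))  # helo
```

The second call shows why labels must be decoded with `group_tokens=False`: they contain no blank tokens to protect repeated letters.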

In [24]:
#import load_metric
from datasets import load_metric
import numpy as np
wer_metric = load_metric("wer")
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Now, we can load the pretrained checkpoint of [Wav2Vec2-XLS-R-300M](https://huggingface.co/facebook/wav2vec2-xls-r-300m). The tokenizer's `pad_token_id` must be used to define the model's `pad_token_id`, which in the case of `Wav2Vec2ForCTC` is also CTC's *blank token* ${}^2$. To save GPU memory, we enable PyTorch's [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html) and also set the loss reduction to "*mean*".

Because the dataset is quite small, you may need to experiment with arguments such as the masking probability, layer dropout, and the learning rate until training is stable enough for your dataset.

In [29]:
model = Wav2Vec2ForCTC.from_pretrained(
    BASE_MODEL_NAME, 
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.0,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.bias', 'lm_head.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The first component of XLS-R consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretraining and as stated in the [paper](https://arxiv.org/pdf/2006.13979.pdf) does not need to be fine-tuned anymore. 
Thus, we can set the `requires_grad` to `False` for all parameters of the *feature extraction* part.

In [30]:
# `freeze_feature_extractor()` is deprecated in recent 🤗 Transformers releases
model.freeze_feature_encoder()

Enable gradient checkpointing to trade compute for memory. Instead of keeping every intermediate activation from the forward pass, the model stores only a subset of them and recomputes the rest during the backward pass. This makes training slower but reduces memory usage.
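The idea can be tried on a tiny module with PyTorch's `torch.utils.checkpoint`. This is a standalone sketch with a made-up two-layer block, not what `gradient_checkpointing_enable()` does internally; it only shows that recomputing activations during the backward pass yields the same gradients.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Regular forward/backward: intermediate activations inside `block` are kept.
block(x).sum().backward()
grad_regular = x.grad.clone()

# Checkpointed: activations inside `block` are recomputed during backward.
x.grad = None
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_checkpointed = x.grad.clone()

print(torch.allclose(grad_regular, grad_checkpointed))
```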

In [None]:
model.gradient_checkpointing_enable() 

In a final step, we define all parameters related to training. 
To give more explanation on some of the parameters:
- `group_by_length` makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training by heavily reducing the overall number of useless padding tokens that are passed through the model.
- `learning_rate` and `weight_decay` were heuristically tuned until fine-tuning became stable. Note that these parameters strongly depend on the dataset and might be suboptimal for other speech datasets.

For more explanations on other parameters, one can take a look at the [docs](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer#trainingarguments).

In [31]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir= FINETUNED_MODEL_PATH,
  group_by_length=True,
  per_device_train_batch_size=8,
  gradient_accumulation_steps=4,
  evaluation_strategy="steps",
  num_train_epochs=3,
  fp16=True,
  gradient_checkpointing=True, 
  save_steps=500,
  # eval_steps=200,
  logging_steps=50,
  learning_rate=1e-5, #learning_rate=5e-5,
  weight_decay=0.005,
  # warmup_steps=1000,
  save_total_limit=2,
)

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.feature_extractor,
)




In [None]:
trainer.train()

In [None]:
trainer.save_model(FINETUNED_MODEL_PATH)