# Fine-tuning ArTST for Arabic ASR

Dataset: [Classical Arabic TTS Corpus](https://huggingface.co/datasets/MBZUAI/ClArTTS)

In [None]:
! pip install -q transformers datasets librosa evaluate jiwer accelerate transformers[torch] pyarabic sentencepiece

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
import torch
import warnings
warnings.filterwarnings("ignore") #prevent printing of warning messages

### Data Preprocessing functions

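The text preprocessing steps described in the paper, as implemented below: remove diacritics, map Arabic-Indic numerals to Western digits, then strip punctuation and lowercase the text.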

In [None]:
import re
import sys
import unicodedata
import pyarabic.araby as araby
# Lookup table from Arabic-Indic digits to their Western (ASCII) counterparts,
# built by inverting the Western -> Arabic-Indic pairs.
map_numbers = {
    arabic_indic: western
    for western, arabic_indic in {
        '0': '٠', '1': '١', '2': '٢', '3': '٣', '4': '٤',
        '5': '٥', '6': '٦', '7': '٧', '8': '٨', '9': '٩',
    }.items()
}

# Every code point whose Unicode general category starts with 'P' (punctuation) ...
punctuations = ''.join(
    chr(code_point)
    for code_point in range(sys.maxunicode)
    if unicodedata.category(chr(code_point)).startswith('P')
)
# ... plus a hand-picked set of extra symbols and stray characters.
punctuations += '÷#ݣ+=|$×⁄<>`åûݘ ڢ̇ پ'

def convert_numerals_to_digit(word):
    """Return ``word`` with each character found in ``map_numbers`` replaced.

    Characters that have no entry in ``map_numbers`` are kept unchanged.
    """
    return ''.join(map_numbers.get(char, char) for char in word)

def remove_diacritics(word):
    """Return ``word`` after passing it through ``araby.strip_diacritics``."""
    stripped = araby.strip_diacritics(word)
    return stripped

# Translation tables keyed by the ``punctuations`` string they were built from,
# so the (large) table is built once instead of on every call.
_PUNCTUATION_TABLES = {}


def remove_punctuation(word):
    """Delete punctuation characters from ``word`` and lowercase the result.

    Every character listed in ``punctuations`` is removed, except '@', '%'
    and the space character, which are kept in the text.
    """
    table = _PUNCTUATION_TABLES.get(punctuations)
    if table is None:
        # '@', '%' and ' ' are taken out of the deletion set before building the table
        deletable = re.sub('[@% ]', '', punctuations)
        table = _PUNCTUATION_TABLES[punctuations] = str.maketrans('', '', deletable)
    return word.translate(table).lower()

def normalize_text(text):
    """Apply the full text-normalisation pipeline to ``text``.

    The steps run in this order: diacritics removal, numeral mapping,
    punctuation removal (which also lowercases the text).
    """
    without_diacritics = remove_diacritics(text)
    with_mapped_digits = convert_numerals_to_digit(without_diacritics)
    return remove_punctuation(with_mapped_digits)

### Load Dataset from Huggingface

In [None]:
from datasets import load_dataset

dataset = load_dataset("MBZUAI/ClArTTS")

dataset

In [None]:
# View dataset features
dataset['train'].features

In [None]:
from IPython.display import Audio

# play audio sample
# Fetch the row once: indexing the whole 'sampling_rate' column first would
# load that entire column just to read a single value.
sample = dataset['train'][0]
print(sample['text'])
Audio(sample['audio'], rate=sample['sampling_rate'])


In [None]:
from IPython.display import Audio

# play audio sample
# Fetch the row once: indexing the whole 'sampling_rate' column first would
# load that entire column just to read a single value.
sample = dataset['train'][5]
print(sample['text'])
Audio(sample['audio'], rate=sample['sampling_rate'])


### Feature Extraction

In [None]:
import numpy as np
import librosa
from transformers import SpeechT5Processor, SpeechT5Tokenizer

# checkpoint id from which the tokenizer and processor are loaded
model_id = "mbzuai/artst_asr"
tokenizer = SpeechT5Tokenizer.from_pretrained(model_id)
processor = SpeechT5Processor.from_pretrained(model_id)
# sampling rate configured on the processor's feature extractor;
# prepare_dataset resamples every clip to this rate
sampling_rate = processor.feature_extractor.sampling_rate
print(f"Model expects {sampling_rate} sr")

def prepare_dataset(example):
    """Convert one raw dataset row into features for the model.

    Reads ``example["audio"]``, ``example["sampling_rate"]`` and
    ``example["text"]``. The audio is resampled with librosa to the global
    ``sampling_rate``, the text is passed through ``normalize_text``, and both
    are handed to ``processor``.

    Returns the processor output with an extra ``"input_length"`` entry: the
    duration of the resampled audio in seconds.
    """
    waveform = np.array(example["audio"])
    resampled = librosa.resample(
        waveform, orig_sr=example['sampling_rate'], target_sr=sampling_rate
    )
    transcript = normalize_text(example["text"])

    # the SpeechT5 processor extracts the audio features and encodes the target text
    features = processor(audio=resampled, sampling_rate=sampling_rate, text_target=transcript)

    # number of samples divided by samples-per-second gives seconds
    features["input_length"] = len(resampled) / sampling_rate
    return features

In [None]:
# apply prepare_dataset to the dataset, removing the columns listed for the 'train' split
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names['train'])
dataset

Next, we filter out any samples whose audio is 30 s or longer. We define a function that returns True for samples shorter than 30 s and False otherwise:

In [None]:
# exclusive upper bound on clip duration, in seconds
max_input_length = 30.0


def is_audio_in_length_range(length):
    """Return True when ``length`` is strictly below ``max_input_length``."""
    return max_input_length > length

In [None]:
# keep only the rows whose "input_length" value passes is_audio_in_length_range
dataset = dataset.filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)

### Data Collator for Training

In [None]:
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from torch.nn.utils.rnn import pad_sequence

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed examples into one padded training batch.

    Each feature must provide ``input_values`` and ``attention_mask`` (only
    element ``[0]`` of each is used) and ``labels`` (the target token ids).
    """

    # processor whose ``tokenizer.pad`` pads the label sequences
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]], padding=True
    ) -> Dict[str, torch.Tensor]:
        """Pad ``features`` to a common length and return the batch tensors.

        Args:
            features: the examples to collate.
            padding: unused; kept so existing callers keep working.

        Returns:
            A dict with ``input_values``, ``attention_mask`` and ``labels``.
        """
        # inputs and labels have different lengths and need different padding
        # methods, so the labels are padded separately by the tokenizer
        labels_batch = self.processor.tokenizer.pad(
            {"input_ids": [sample["labels"] for sample in features]},
            return_tensors="pt",
        )
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        return {
            "input_values": pad_sequence(
                [torch.tensor(sample["input_values"][0]) for sample in features],
                batch_first=True,
            ),
            "attention_mask": pad_sequence(
                [torch.tensor(sample["attention_mask"][0]) for sample in features],
                batch_first=True,
            ),
            "labels": labels,
        }

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Test Data Collator

In [None]:
features = [
    dataset['train'][0],
    dataset['train'][1],
    dataset['train'][2],
]

batch = data_collator(features)

In [None]:
{k:v.shape for k,v in batch.items()}

In [None]:
batch['labels']

In [None]:
batch['input_values'][2]

### Evaluation Metrics

In [None]:
import evaluate

wer = evaluate.load("wer")
cer = evaluate.load("cer")

In [None]:
def compute_metrics(pred):
    """Compute WER and CER, as percentages, for a set of predictions.

    Reads ``pred.predictions`` and ``pred.label_ids``. Note that the -100
    entries of ``pred.label_ids`` are overwritten in place.

    Returns a dict with:
        ``wer``          - word error rate over all samples
        ``cer``          - character error rate over all samples, spaces removed
        ``wer_non_zero`` - word error rate over samples with a non-empty reference
    """
    predicted_ids = pred.predictions
    reference_ids = pred.label_ids

    # replace -100 with the pad_token_id before decoding
    reference_ids[reference_ids == -100] = processor.tokenizer.pad_token_id

    # decode both sides to text, dropping special tokens
    hypotheses = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    references = processor.batch_decode(reference_ids, skip_special_tokens=True)

    # filtering step: keep only the pairs whose reference text is non-empty
    has_reference = [len(ref) > 0 for ref in references]
    hypotheses_non_zero = [hyp for hyp, keep in zip(hypotheses, has_reference) if keep]
    references_non_zero = [ref for ref, keep in zip(references, has_reference) if keep]

    # character-level comparison ignores spaces
    hypotheses_chr = [hyp.replace(' ', '') for hyp in hypotheses]
    references_chr = [ref.replace(' ', '') for ref in references]

    # compute metrics
    wer_all = 100 * wer.compute(predictions=hypotheses, references=references)
    wer_filtered = 100 * wer.compute(
        predictions=hypotheses_non_zero, references=references_non_zero
    )
    cer_all = 100 * cer.compute(predictions=hypotheses_chr, references=references_chr)

    return {"wer": wer_all, "cer": cer_all, "wer_non_zero": wer_filtered}

### Load Pre-trained Checkpoint

In [None]:
from transformers import SpeechT5ForSpeechToText

model = SpeechT5ForSpeechToText.from_pretrained(model_id)
model.to(device)

In [None]:
# disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False

### Define the Training Configuration

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # output and logging
    output_dir="./ASR_Output",
    report_to="tensorboard",
    logging_steps=100,
    # batch sizes
    auto_find_batch_size=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    # optimisation schedule
    learning_rate=6e-5,
    lr_scheduler_type="inverse_sqrt",
    warmup_steps=100,
    max_steps=2000,
    # evaluation, using generation to produce the predictions
    evaluation_strategy="steps",
    eval_steps=250,
    predict_with_generate=True,
    generation_max_length=225,
    # checkpointing and best-model selection (a lower "wer" is better)
    save_steps=250,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

In [None]:
# Hold out 20% of the train split for validation. The fixed seed keeps the
# split identical across runs, so the reported metrics and the checkpoint
# picked by load_best_model_at_end stay comparable between runs.
training_data = dataset["train"].train_test_split(test_size=0.2, seed=42)

In [None]:
from transformers import Seq2SeqTrainer

# Assemble the trainer from the pieces defined above; the "test" part of
# training_data (the split carved out of the train set) is used for evaluation
# during training.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=training_data["train"],
    eval_dataset=training_data["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

### Training

In [None]:
trainer.train()

### Evaluate

In [None]:
# evaluate on the dataset's own 'test' split, not on the validation split carved out of 'train'
trainer.evaluate(dataset['test'])