# Fine-tuning ArTST for Arabic ASR

Dataset: [Classical Arabic TTS Corpus](https://huggingface.co/datasets/MBZUAI/ClArTTS)

In [35]:
! pip install -q transformers datasets librosa evaluate jiwer accelerate transformers[torch] pyarabic sentencepiece

In [36]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [37]:
import torch
import warnings
warnings.filterwarnings("ignore") #prevent printing of warning messages

### Data Preprocessing functions

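Text preprocessing steps described in the paper: strip diacritics, map Arabic-Indic numerals to ASCII digits, then remove punctuation (keeping `@` and `%`) and lowercase.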

In [42]:
import re
import sys
import unicodedata
import pyarabic.araby as araby
# Lookup table from Arabic-Indic numerals to their ASCII digit equivalents.
map_numbers = {
    '٠': '0',
    '١': '1',
    '٢': '2',
    '٣': '3',
    '٤': '4',
    '٥': '5',
    '٦': '6',
    '٧': '7',
    '٨': '8',
    '٩': '9',
}
# Every code point in range(sys.maxunicode) whose Unicode general category starts with 'P'
# (the punctuation categories), concatenated into one string.
punctuations = ''.join([chr(i) for i in list(i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith('P'))])
# Extra symbols and letters appended to the set above. remove_punctuation() later drops
# '@', '%' and the space character from this string before using it as a deletion set.
punctuations = punctuations + '÷#ݣ+=|$×⁄<>`åûݘ ڢ̇ پ'

def convert_numerals_to_digit(word):
    """Replace every character that has an entry in ``map_numbers`` with its mapped value.

    Characters without an entry are kept unchanged. Returns the rebuilt string.
    """
    return ''.join(map_numbers.get(char, char) for char in word)

def remove_diacritics(word):
    """Return ``word`` after passing it through ``pyarabic.araby.strip_diacritics``."""
    stripped = araby.strip_diacritics(word)
    return stripped

def remove_punctuation(word):
    """Delete the characters listed in ``punctuations`` from ``word`` and lowercase the result.

    '@', '%' and the space character are removed from the deletion set first,
    so they survive in the output.
    """
    chars_to_delete = re.sub('[@% ]', '', punctuations)
    deletion_table = str.maketrans('', '', chars_to_delete)
    cleaned = word.translate(deletion_table)
    return cleaned.lower()

def normalize_text(text):
    """Normalize a transcript in three steps and return the result.

    1. strip diacritics,
    2. map numerals through ``convert_numerals_to_digit``,
    3. remove punctuation (which also lowercases the text).
    """
    without_diacritics = remove_diacritics(text)
    with_mapped_digits = convert_numerals_to_digit(without_diacritics)
    return remove_punctuation(with_mapped_digits)

### Load Dataset from Huggingface

In [41]:
from datasets import load_dataset

dataset = load_dataset("MBZUAI/ClArTTS")

dataset

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/21 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'file', 'audio', 'sampling_rate', 'duration'],
        num_rows: 9500
    })
    test: Dataset({
        features: ['text', 'file', 'audio', 'sampling_rate', 'duration'],
        num_rows: 205
    })
})

In [39]:
# View dataset features
dataset['train'].features

{'input_values': Sequence(feature=Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), length=-1, id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'decoder_attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'input_length': Value(dtype='float64', id=None)}

In [43]:
from IPython.display import Audio

# Play an audio sample. The row is fetched once so that reading its sampling rate
# does not load the entire 'sampling_rate' column.
sample = dataset['train'][0]
print(sample['text'])
Audio(sample['audio'], rate=sample['sampling_rate'])


.لِأَنَّهُ لَا يَرَى أَنَّهُ عَلَى السَّفَهِ .ثُمَّ مِنْ بَعْدِ ذَلِكَ حَدِيثٌ مُنْتَشِرٌ


In [44]:
from IPython.display import Audio

# Play another audio sample. The row is fetched once so that reading its sampling rate
# does not load the entire 'sampling_rate' column.
sample = dataset['train'][5]
print(sample['text'])
Audio(sample['audio'], rate=sample['sampling_rate'])


مَا أَعْظَمُ الْمَصَائِبِ عِنْدَكُمْ؟ :فَقَالَ


### Feature Extraction

In [None]:
import numpy as np
import librosa
from transformers import SpeechT5Processor, SpeechT5Tokenizer

# Checkpoint identifier shared by the tokenizer, the processor and the model loaded later on.
model_id = "mbzuai/artst_asr"
# NOTE(review): `tokenizer` is not referenced again in the visible cells; the collator and
# the metrics go through `processor.tokenizer` instead.
tokenizer = SpeechT5Tokenizer.from_pretrained(model_id)
processor = SpeechT5Processor.from_pretrained(model_id)
# Sampling rate configured on the processor's feature extractor; prepare_dataset resamples to it.
sampling_rate = processor.feature_extractor.sampling_rate
print(f"Model expects {sampling_rate} sr")

def prepare_dataset(example):
    """Turn one raw dataset row into model-ready features.

    Reads ``example["audio"]``, ``example["sampling_rate"]`` and ``example["text"]``.
    The audio is resampled to the module-level ``sampling_rate`` and the text is
    passed through ``normalize_text``. Both are then given to ``processor``.

    Returns the processor output with an extra ``"input_length"`` entry holding
    the resampled audio duration in seconds.
    """
    # Resample the waveform with librosa to the rate the processor expects.
    waveform = librosa.resample(
        np.array(example["audio"]),
        orig_sr=example['sampling_rate'],
        target_sr=sampling_rate,
    )
    transcript = normalize_text(example["text"])

    # The SpeechT5 processor extracts audio features and tokenizes the target text.
    features = processor(audio=waveform, sampling_rate=sampling_rate, text_target=transcript)

    # Duration in seconds = number of samples / samples per second.
    features["input_length"] = len(waveform) / sampling_rate
    return features

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

tokenizer_config.json:   0%|          | 0.00/1.36k [00:00<?, ?B/s]

spm_char.model:   0%|          | 0.00/404k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/458 [00:00<?, ?B/s]

Model expects 16000 sr


In [None]:
# Apply prepare_dataset to every example and drop the columns listed for the 'train' split,
# so only the values returned by prepare_dataset remain.
# NOTE(review): this rebinds `dataset`; re-running the cell without reloading the raw data
# would call prepare_dataset on rows that no longer have 'audio'/'text' columns.
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names['train'])
dataset

Map:   0%|          | 0/9500 [00:00<?, ? examples/s]

Map:   0%|          | 0/205 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_values', 'attention_mask', 'labels', 'decoder_attention_mask', 'input_length'],
        num_rows: 9500
    })
    test: Dataset({
        features: ['input_values', 'attention_mask', 'labels', 'decoder_attention_mask', 'input_length'],
        num_rows: 205
    })
})

Finally, we filter out audio samples that are 30s or longer. The filter is applied to every split of the dataset, not only the training split. We define a function that returns True for samples shorter than 30s and False otherwise:

In [None]:
# Upper bound, in seconds, on the audio duration kept by the filter (exclusive).
max_input_length = 30.0


def is_audio_in_length_range(length):
    """Return True when ``length`` is strictly below ``max_input_length``, else False."""
    within_limit = length < max_input_length
    return within_limit

In [None]:
# Keep only the examples whose 'input_length' value passes is_audio_in_length_range.
# The filter is applied to `dataset` as a whole, so it is not limited to the 'train' split.
dataset = dataset.filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)

Filter:   0%|          | 0/9500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/205 [00:00<?, ? examples/s]

### Data Collator for Training

In [None]:
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from torch.nn.utils.rnn import pad_sequence

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate pre-processed speech examples into one padded batch of tensors.

    Each feature dict must provide ``input_values`` and ``attention_mask`` (the
    first element of each is used) and ``labels`` (a sequence of token ids).
    """

    # Processor whose ``tokenizer.pad`` pads the label sequences.
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]], padding=True
    ) -> Dict[str, torch.Tensor]:
        """Build a batch from ``features``.

        Args:
            features: per-example dicts as described on the class.
            padding: accepted for interface compatibility; not used.

        Returns:
            Dict with ``input_values`` and ``attention_mask`` zero-padded to the
            longest audio in the batch, and ``labels`` padded by the tokenizer
            with every padded position set to -100.
        """
        # Inputs and labels have different lengths, so they are padded separately.
        input_values = [torch.tensor(sample['input_values'][0]) for sample in features]
        attention_masks = [torch.tensor(sample['attention_mask'][0]) for sample in features]

        # Use the processor passed to this collator, not a module-level global.
        labels_batch = self.processor.tokenizer.pad(
            {'input_ids': [sample['labels'] for sample in features]},
            return_tensors="pt",
        )

        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), -100
        )

        return {
            'input_values': pad_sequence(input_values, batch_first=True),
            'attention_mask': pad_sequence(attention_masks, batch_first=True),
            'labels': labels,
        }

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Test Data Collator

In [None]:
# Collate the first three training examples to sanity-check the data collator.
features = [dataset['train'][idx] for idx in range(3)]

batch = data_collator(features)

In [None]:
{k:v.shape for k,v in batch.items()}

{'input_values': torch.Size([3, 84669]),
 'attention_mask': torch.Size([3, 84669]),
 'labels': torch.Size([3, 52])}

In [None]:
batch['labels']

tensor([[   4,    6,   18,    9,   15,    4,    6,    5,    4,    7,   12,   29,
            4,   18,    9,   15,    4,   13,    6,   29,    4,    5,    6,   19,
           20,   15,    4,   33,    8,    4,    8,    9,    4,   14,   13,   16,
            4,   25,    6,   21,    4,   23,   16,    7,   33,    4,    8,    9,
           11,   28,   12,    2],
        [   4,    5,    6,   13,    8,   12,    4,    7,    9,   22,   30,    4,
           10,    5,    6,   25,    9,   10,   14,    4,   11,   34,    7,   16,
            2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100],
        [   4,   22,    6,    7,    6,   17,    4,   14,    7,    9,    4,    5,
            6,    8,   21,   33,   12,    7,    9,    4,   20,   27,    9,    4,
            5,    6,    9,    5,   19,    2, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -1

In [None]:
batch['input_values'][2]

tensor([2.2279e-05, 3.5953e-08, 1.2527e-05,  ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00])

### Evaluation Metrics

In [None]:
import evaluate

wer = evaluate.load("wer")
cer = evaluate.load("cer")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]

In [None]:
def compute_metrics(pred):
    """Compute WER and CER (as percentages) for one evaluation pass.

    Args:
        pred: object exposing ``predictions`` and ``label_ids``. Both are assumed
            to be numpy integer arrays of token ids in which -100 marks padding.

    Returns:
        Dict with ``"wer"`` (all samples), ``"cer"`` (all samples, spaces removed
        before scoring) and ``"wer_non_zero"`` (only samples whose decoded
        reference is non-empty).
    """
    pad_token_id = processor.tokenizer.pad_token_id

    # Work on copies so the caller's arrays are not modified. -100 is mapped back
    # to the pad token so it can be decoded and then skipped as a special token.
    # Predictions get the same treatment because they can also carry -100 padding.
    label_ids = pred.label_ids.copy()
    label_ids[label_ids == -100] = pad_token_id
    pred_ids = pred.predictions.copy()
    pred_ids[pred_ids == -100] = pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # Keep only the pairs whose reference is non-empty.
    non_empty_pairs = [
        (hyp, ref) for hyp, ref in zip(pred_str, label_str) if len(ref) > 0
    ]
    pred_str_norm = [hyp for hyp, _ in non_empty_pairs]
    label_str_norm = [ref for _, ref in non_empty_pairs]

    # Character-level scoring ignores spaces.
    pred_chr = [text.replace(' ', '') for text in pred_str]
    label_chr = [text.replace(' ', '') for text in label_str]

    # compute metrics
    _wer = 100 * wer.compute(predictions=pred_str, references=label_str)
    _wer_non_zero = 100 * wer.compute(predictions=pred_str_norm, references=label_str_norm)
    _cer = 100 * cer.compute(predictions=pred_chr, references=label_chr)

    return {"wer": _wer, "cer": _cer, "wer_non_zero": _wer_non_zero}

### Load Pre-trained Checkpoint

In [None]:
from transformers import SpeechT5ForSpeechToText

model = SpeechT5ForSpeechToText.from_pretrained(model_id)
model.to(device)

config.json:   0%|          | 0.00/2.12k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/619M [00:00<?, ?B/s]

SpeechT5ForSpeechToText has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

SpeechT5ForSpeechToText(
  (speecht5): SpeechT5Model(
    (encoder): SpeechT5EncoderWithSpeechPrenet(
      (prenet): SpeechT5SpeechEncoderPrenet(
        (feature_encoder): SpeechT5FeatureEncoder(
          (conv_layers): ModuleList(
            (0): SpeechT5GroupNormConvLayer(
              (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
              (activation): GELUActivation()
              (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
            )
            (1-4): 4 x SpeechT5NoLayerNormConvLayer(
              (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
              (activation): GELUActivation()
            )
            (5-6): 2 x SpeechT5NoLayerNormConvLayer(
              (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
              (activation): GELUActivation()
            )
          )
        )
        (feature_projection): SpeechT5FeatureProjection(
          (layer_norm): LayerNorm((512,),

In [None]:
# disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False

### Define the Training Configuration

In [None]:
from transformers import Seq2SeqTrainingArguments

# Training configuration passed to the Seq2SeqTrainer.
# NOTE(review): the saved training log further down reports global_step=4000 and rows up to
# step 2500, which does not match max_steps=2000 here. The stored output looks like it came
# from a different configuration; confirm before relying on the numbers.
training_args = Seq2SeqTrainingArguments(
    output_dir="./ASR_Output",
    # Requested batch of 16 with 2 accumulation steps; auto_find_batch_size may change the
    # per-device batch size at runtime. TODO confirm the batch size actually used.
    auto_find_batch_size=True,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=6e-5,
    lr_scheduler_type="inverse_sqrt",
    warmup_steps=100,
    max_steps=2000,
    gradient_checkpointing=True,
    # NOTE(review): the name of this argument may differ between transformers releases.
    # Check it against the installed version, since the install cell does not pin one.
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    # Evaluation metrics are computed on generated sequences of at most 225 tokens.
    predict_with_generate=True,
    generation_max_length=225,
    # Saving and evaluating share the same 250-step interval.
    save_steps=250,
    eval_steps=250,
    logging_steps=100,
    # The best checkpoint is chosen by the lowest "wer" returned from compute_metrics.
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    report_to="tensorboard",

)

In [None]:
# Hold out 20% of the training split for validation during training.
# The seed is fixed so the same split is produced on every run.
training_data = dataset["train"].train_test_split(test_size=0.2, seed=42)

In [None]:
from transformers import Seq2SeqTrainer

# Wire the model, the train/validation splits from `training_data`, the padding collator
# and the WER/CER metric function into the trainer.
# NOTE(review): the processor is passed through the `tokenizer` argument; check that the
# installed transformers version still accepts that argument name.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=training_data["train"],
    eval_dataset=training_data["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

### Training

In [None]:
trainer.train()

Step,Training Loss,Validation Loss,Wer,Cer,Wer Non Zero
250,0.0243,0.035893,4.269945,1.538876,4.269945
500,0.0213,0.035648,3.992333,1.380392,3.992333
750,0.0504,0.032269,3.879966,1.342356,3.879966
1000,0.0426,0.032706,3.602353,1.337602,3.602353
1250,0.0409,0.032439,3.456937,1.099876,3.456937
1500,0.0366,0.033267,3.410668,1.074519,3.410668
1750,0.0393,0.03368,3.390839,1.057086,3.390839
2000,0.0329,0.034426,3.463547,1.071349,3.463547
2250,0.0333,0.03509,3.516425,1.077689,3.516425
2500,0.0312,0.034722,3.516425,1.082443,3.516425


There were missing keys in the checkpoint model loaded: ['text_decoder_postnet.lm_head.weight'].


TrainOutput(global_step=4000, training_loss=0.032829076558351517, metrics={'train_runtime': 7290.3467, 'train_samples_per_second': 17.557, 'train_steps_per_second': 0.549, 'total_flos': 1.2266540487568626e+19, 'train_loss': 0.032829076558351517, 'epoch': 16.842105263157894})

### Evaluate

In [None]:
trainer.evaluate(dataset['test'])

{'eval_loss': 0.05657927691936493,
 'eval_wer': 4.132231404958678,
 'eval_cer': 1.2680115273775217,
 'eval_wer_non_zero': 4.132231404958678,
 'eval_runtime': 13.6849,
 'eval_samples_per_second': 14.98,
 'eval_steps_per_second': 1.9,
 'epoch': 16.842105263157894}