In [1]:
# %env LC_ALL=C.UTF-8
# %env LANG=C.UTF-8
# %env TRANSFORMERS_CACHE=/content/cache
# %env HF_DATASETS_CACHE=/content/cache
# %env CUDA_LAUNCH_BLOCKING=1

In [32]:
import torch
import torchaudio
import librosa
from importlib import reload

from datasets import load_dataset, load_metric

import pandas as pd
import numpy as np

import hazm
from num2fawords import words, ordinal_words
from tqdm import tqdm

import os
import string
import six
import re
import glob

## Tokenizer

In [2]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer(
    "./fa-vocab.json", 
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    word_delimiter_token="|",
    do_lower_case=False
)

In [3]:
text = "از مهمونداری کنار بکشم"
print(" ".join(tokenizer.tokenize(text)))
print(tokenizer.decode(tokenizer.encode(text)))

ا ز | م ه م و ن د ا ر ی | ک ن ا ر | ب ک ش م
از مهمونداری کنار بکشم


## Feature Extractor

In [4]:
from transformers import Wav2Vec2FeatureExtractor


feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, 
                                             sampling_rate=16000, 
                                             padding_value=0.0, 
                                             do_normalize=True, 
                                             return_attention_mask=True)

## Processor

In [5]:
from transformers import Wav2Vec2Processor


processor = Wav2Vec2Processor(feature_extractor=feature_extractor, 
                              tokenizer=tokenizer)

In [6]:
if len(processor.tokenizer.get_vocab()) == len(processor.tokenizer):
    print(len(processor.tokenizer))

40


In [7]:
save_dir = "weights/wav2vec2-large-xlsr-persian-shemo"

processor.save_pretrained(save_dir)

## Dataset

In [8]:
import dataset
reload(dataset)
from dataset import get_datasets

common_voice_train, common_voice_test = get_datasets('/media/data/soroosh/dataset/ASR/shemo-fa/train.csv', 
                                                     '/media/data/soroosh/dataset/ASR/shemo-fa/test.csv', 
                                                     processor,
                                                     n_jobs=20, 
                                                     min_secs=1, 
                                                     max_secs=20, 
                                                     make_vocab=False)

Using custom data configuration default-f62212371f721c8f
Reusing dataset csv (/home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0)
Using custom data configuration default-7e5897945ab50ade
Reusing dataset csv (/home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0)
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-3e33ac38e0d9988c.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-7cfd08d61b6b542d.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd

Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-5c32b1c150907228.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-c148ea65c45f4d1c.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-997e3933d272dca6.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-738ee03122b39276.arrow


Split sizes [BEFORE]: 2554 train and 284 validation.


Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-9decde67de1b5490.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-0d00a7b47d9f00bc.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-ebd04b6274edd110.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-de0fe88bcd681ff2.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8

Split sizes [AFTER]: 2410 train and 284 validation.


Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-f62212371f721c8f/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-29f282eb18e45237.arrow























Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-620826a316a58d9f.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-edd9e55bc63661ec.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-e0a40f3142824a23.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0/cache-d31eaab6290fc8c2.arrow
Loading cached processed dataset at /home/soroosh/.cache/huggingface/datasets/csv/default-7e5897945ab50ade/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8

## Data Collator

In [9]:
from dataset import DataCollatorCTCWithPadding

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

## Metric

In [10]:
wer_metric = load_metric("wer")

In [11]:
import random


def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    
#     if isinstance(label_str, list):
#         if isinstance(pred_str, list) and len(pred_str) == len(label_str):
#             for index in random.sample(range(len(label_str)), 3):
#                 print(f'reference: "{label_str[index]}"')
#                 print(f'predicted: "{pred_str[index]}"')

#         else:
#             for index in random.sample(range(len(label_str)), 3):
#                 print(f'reference: "{label_str[index]}"')
#                 print(f'predicted: "{pred_str}"')

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

## Model

In [12]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53", 
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    ctc_zero_infinity=True,
    bos_token_id=processor.tokenizer.bos_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer.get_vocab())
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
model.freeze_feature_extractor()

In [15]:
model.num_parameters()

315479720

In a final step, we define all parameters related to training. 
To give more explanation on some of the parameters:
- `group_by_length` makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training time by heavily reducing the overall number of useless padding tokens that are passed through the model
- `learning_rate` and `weight_decay` were heuristically tuned until fine-tuning has become stable. Note that those parameters strongly depend on the Common Voice dataset and might be suboptimal for other speech datasets.

For more explanations on other parameters, one can take a look at the [docs](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer#trainingarguments).

**Note**: If one wants to save the trained models in his/her google drive the commented-out `output_dir` can be used instead.

## Trainer

In [16]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='/media/data/soroosh/' + save_dir,
    group_by_length=True,
    
    per_device_train_batch_size=5,
    per_device_eval_batch_size=5,
    gradient_accumulation_steps=2,
    
    num_train_epochs=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    
    fp16=True,
    weight_decay=1e-3,
    learning_rate=1e-4,
    warmup_steps=500,
    save_total_limit=2,
)

In [17]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

### Training

In case you want to use this google colab to fine-tune your model, you should make sure that your training doesn't stop due to inactivity. A simple hack to prevent this is to paste the following code into the console of this tab (*right mouse click -> inspect -> Console tab and insert code*).

```javascript
function ConnectButton(){
    console.log("Connect pushed"); 
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click() 
}
setInterval(ConnectButton,60000);
```

In [18]:
train_result = trainer.train()

Epoch,Training Loss,Validation Loss,Wer,Runtime,Samples Per Second
1,7.8022,3.059248,1.0,22.2578,12.76
2,3.0411,2.975186,1.0,22.2561,12.761
3,2.9518,2.633955,1.0,22.2777,12.748
4,1.947,1.010393,0.663939,22.369,12.696
5,1.2776,0.817091,0.532121,22.3706,12.695
6,1.0907,0.75778,0.491818,22.3583,12.702
7,0.9687,0.720564,0.459091,22.4363,12.658
8,0.8902,0.68329,0.415455,22.4027,12.677
9,0.8052,0.733186,0.408788,22.389,12.685
10,0.7567,0.642406,0.407879,22.4135,12.671


KeyboardInterrupt: 

In [37]:
trainer.save_model()
trainer.save_state()

metrics = train_result.metrics
max_train_samples = len(_common_voice_train)
metrics["train_samples"] = min(max_train_samples, len(_common_voice_train))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

## Evaluate

In [22]:
metrics = trainer.evaluate()
max_val_samples = len(_common_voice_test)
metrics["eval_samples"] = len(common_voice_test)

trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [23]:
metrics

{'eval_loss': 0.6420636177062988,
 'eval_wer': 0.36333333333333334,
 'eval_runtime': 22.5201,
 'eval_samples_per_second': 12.611,
 'eval_samples': 284}

## load and evaluate

In [30]:
model = Wav2Vec2ForCTC.from_pretrained('/media/data/soroosh/' + save_dir).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(save_dir)

In [33]:
input_dict = processor(common_voice_test["input_values"][0], return_tensors="pt", padding=True)

logits = model(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

It is strongly recommended to pass the ``sampling_rate`` argument to this function.Failing to do so can result in silent errors that might be hard to debug.


In [36]:
print("Prediction:")
print(processor.decode(pred_ids))

print("\nReference:")
print(processor.tokenizer.decode(common_voice_test["labels"][0]))

Prediction:
سر هنک آرتاه که به تاضگی در صمت ریاصت نضمی کرمان شاه منصوب شده است عزمه آنجاست

Reference:
سرهنگ آرتا که به تازگی در سمت ریاست نظمیه کرمانشاه منسوب شده است عازم آنجا است
