In [None]:
!pip install -q datasets \
                huggingface \
                jiwer \
                peft \
                transformers \
                torchaudio \
                torch \
                tqdm
!pip install -U -q datasets
!pip install -U -q bitsandbytes


from datasets import Audio, DatasetDict, load_dataset
from jiwer import wer, cer
from peft import LoraConfig, get_peft_model
from transformers import WhisperFeatureExtractor, \
                         WhisperForConditionalGeneration, \
                         WhisperProcessor, \
                         WhisperTokenizer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
from transformers.trainer_seq2seq import Seq2SeqTrainer
import torch
from dataclasses import dataclass
from tqdm import tqdm
from typing import Any
import os

# Checkpoint that the feature extractor, tokenizer, processor and model are
# all loaded from.
base_model_name = "openai/whisper-small"
# Language / task passed to the tokenizer and processor.
language = "romanian"
task = "transcribe"
# NOTE(review): not referenced anywhere in the visible code.
language_prefix = "ro"
# Hub namespace used to build the model repository id.
org = "victors3136"

# Shared by `prepare` (feature extraction / tokenisation) and `compute_metric`
# (decoding).
extractor = WhisperFeatureExtractor.from_pretrained(base_model_name)
tokenizer = WhisperTokenizer.from_pretrained(
    base_model_name,
    language=language,
    task=task,
)

# LoRA adapter settings applied to every model trained by `train_model`.
config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)

def compute_metric(prediction):
    """Compute word and character error rates for one evaluation pass.

    Args:
        prediction: Object exposing ``predictions`` (predicted token ids) and
            ``label_ids`` (reference token ids in which ignored positions are
            marked with ``-100``).

    Returns:
        A dict with the keys ``"wer"`` and ``"cer"``, computed by jiwer with
        the decoded labels as reference and the decoded predictions as
        hypothesis.
    """
    import numpy as np

    pred_ids = prediction.predictions
    # Build a new array instead of assigning into ``prediction.label_ids``:
    # the original in-place write modified the caller's array as a side
    # effect. -100 is swapped for the pad token so the ids can be decoded.
    label_ids = np.where(
        prediction.label_ids == -100,
        tokenizer.pad_token_id,
        prediction.label_ids,
    )

    pred_strs = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_strs = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return {
        "wer": wer(label_strs, pred_strs),
        "cer": cer(label_strs, pred_strs),
    }

@dataclass
class Collator:
    """Collate preprocessed examples into one padded training batch.

    The audio features and the label token ids are padded independently, by
    ``processor.feature_extractor`` and ``processor.tokenizer`` respectively.
    """

    # Object providing ``feature_extractor`` and ``tokenizer`` attributes,
    # each with a ``pad`` method.
    processor: Any

    def __call__(self, features):
        """Return the padded batch for ``features``.

        Args:
            features: Sequence of mappings, each holding ``"input_features"``
                and ``"labels"``.

        Returns:
            The padded feature batch with an added ``"labels"`` entry in
            which padding positions are set to ``-100``.
        """
        tokenizer = self.processor.tokenizer

        audio_inputs = []
        token_inputs = []
        for example in features:
            audio_inputs.append({"input_features": example["input_features"]})
            token_inputs.append({"input_ids": example["labels"]})

        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")
        padded_labels = tokenizer.pad(token_inputs, return_tensors="pt")

        # Positions where the attention mask is not 1 are padding; mark them
        # with -100 in the label tensor.
        padding_positions = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(padding_positions, -100)

        # Drop the first column only when every sequence starts with the
        # tokenizer's BOS token.
        starts_with_bos = (labels[:, 0] == tokenizer.bos_token_id).all().cpu().item()
        batch["labels"] = labels[:, 1:] if starts_with_bos else labels

        return batch

def prepare(batch):
    """Add model inputs and label ids to a single dataset example.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"``) and
    ``batch["sentence"]``, then stores the extracted features under
    ``"input_features"`` and the tokenised sentence under ``"labels"``.

    Returns:
        The same ``batch`` mapping with the two entries added.
    """
    waveform = batch["audio"]
    extracted = extractor(waveform["array"], sampling_rate=waveform["sampling_rate"])
    # The extractor output is indexed with [0] to take the single example.
    batch["input_features"] = extracted.input_features[0]

    encoded_sentence = tokenizer(batch["sentence"])
    batch["labels"] = encoded_sentence.input_ids
    return batch

def train_model(it: str, sp: str) -> None:
    """Fine-tune the base Whisper model with LoRA on one dataset variant.

    Loads the ``dataset-5k-{it}it-{sp}sp`` dataset from the ``org`` namespace,
    preprocesses it with ``prepare``, wraps the base model with the
    module-level LoRA ``config``, trains for five epochs, then saves the model
    and processor to the output directory and pushes to the Hub.

    Args:
        it: Tag inserted before ``it`` in the dataset name and used in the
            model name.
        sp: Tag inserted before ``sp`` in the dataset name and used in the
            model name.
    """
    # Use the shared ``org`` constant for both repositories instead of
    # repeating the namespace as a literal in the dataset id.
    dataset_name = f"{org}/dataset-5k-{it}it-{sp}sp"
    model_name = f"whisper-model-small-ro-finetune-5k-{it}-{sp}"
    model_repo = f"{org}/{model_name}"

    dataset = DatasetDict()
    dataset["train"] = load_dataset(dataset_name, split="train").shuffle(seed=42)
    dataset["validation"] = load_dataset(dataset_name, split="val")
    # NOTE(review): the test split is loaded and preprocessed but is not read
    # again in this function.
    dataset["test"] = load_dataset(dataset_name, split="test")

    processor = WhisperProcessor.from_pretrained(base_model_name, language=language, task=task)
    processor.tokenizer.set_prefix_tokens(language=language, task=task)

    # Decode audio at 16 kHz, then replace the original columns with the
    # "input_features" / "labels" produced by ``prepare``.
    dataset = dataset \
                .cast_column("audio", Audio(sampling_rate=16_000)) \
                .map(prepare, remove_columns=dataset["train"].column_names, num_proc=4)

    collator = Collator(processor=processor)

    model = WhisperForConditionalGeneration.from_pretrained(base_model_name, device_map="auto")
    # Clear the checkpoint's forced decoder ids and suppressed-token list.
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    model = get_peft_model(model, config)

    training_args = Seq2SeqTrainingArguments(
        output_dir=model_repo,
        per_device_train_batch_size=32,
        gradient_accumulation_steps=2,
        learning_rate=2e-5,
        warmup_steps=30,
        num_train_epochs=5,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        logging_steps=50,
        fp16=True,
        per_device_eval_batch_size=16,
        predict_with_generate=True,
        generation_max_length=225,
        # Keep the collator's columns: the PEFT wrapper's forward signature
        # does not list the wrapped model's arguments.
        remove_unused_columns=False,
        # Name the label column explicitly for the same reason, so the
        # trainer does not have to infer it from the wrapped model.
        label_names=["labels"],
        report_to="wandb",
        push_to_hub=False,
    )

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        data_collator=collator,
        compute_metrics=compute_metric,
    )

    model.config.use_cache = False
    trainer.train()

    # Write the trained model and the processor into the output directory
    # before pushing.
    model.save_pretrained(model_repo)
    processor.save_pretrained(model_repo)
    trainer.push_to_hub("Training done")



# (it, sp) tag pairs; one model is trained per pair, in this order.
params = [
    ("00", "00"),
    ("05", "05"),
    ("15", "15"),
    ("25", "25"),
    ("35", "35"),
    ("05", "25"),
    ("25", "05"),
    ("15", "35"),
    ("35", "15"),
    ("00", "50"),
    ("50", "00"),
    ("50", "50"),
]

# Train the variants sequentially, with a progress bar over the sweep.
for it_tag, sp_tag in tqdm(params, desc="Training models... "):
    train_model(it_tag, sp_tag)