In [None]:
# Install the packages used by this notebook; -q keeps pip's output quiet.
# NOTE(review): `huggingface` is not imported anywhere in the visible code, and
# the Hub client library is published as `huggingface_hub` - confirm that this
# entry is intended.
!pip install -q datasets \
                huggingface \
                jiwer \
                transformers \
                torchaudio \
                torch \
                tqdm
# Second pass with -U so that `datasets` is upgraded when an older copy was
# already installed in the environment.
!pip install -U -q datasets

from datasets import Audio, load_dataset
from jiwer import cer, wer
from transformers import AutoModelForSpeechSeq2Seq, \
                         WhisperProcessor
import torch
import torchaudio
from tqdm import tqdm

# Reference checkpoint; it is evaluated first and also supplies the processor.
base_model = "openai/whisper-small"

# Each pair provides the two trailing two-digit fields of a fine-tuned
# checkpoint name (see the f-string used to build `model_paths`).
params = [
    ("00", "00"), ("05", "05"), ("15", "15"), ("25", "25"), ("35", "35"),
    ("05", "25"), ("25", "05"),
    ("15", "35"), ("35", "15"),
    ("00", "50"), ("50", "00"),
    ("50", "50"),
]

# Baseline first, followed by one fine-tuned checkpoint per entry in `params`.
model_paths = [
    base_model,
    *(
        f"victors3136/whisper-model-small-ro-finetune-5k-{it}-{sp}"
        for it, sp in params
    ),
]

# Dataset / decoding settings shared by the rest of the script.
language = "ro"
task = "transcribe"
split = "test"
max_samples = 1_000  # upper bound on the number of evaluated clips

def speech_file_to_array_fn(batch, target_sampling_rate=16_000):
    """Load one example's audio file and attach the waveform and reference text.

    The audio at ``batch["audio"]["path"]`` is read with ``torchaudio.load``
    and resampled to ``target_sampling_rate`` when the file was stored at a
    different rate. Previously the file's native rate was discarded, so audio
    recorded at any other rate was handed on unchanged even though the
    feature extraction step declares its input as 16 kHz.

    Args:
        batch: A single dataset example providing ``batch["audio"]["path"]``
            and ``batch["sentence"]``.
        target_sampling_rate: Sampling rate in Hz of the returned waveform.
            Defaults to 16 000, the rate passed to the processor.

    Returns:
        The same ``batch`` object with two added keys: ``"speech"`` (first
        channel of the waveform as a NumPy array) and ``"target_text"`` (the
        lower-cased sentence).
    """
    waveform, sampling_rate = torchaudio.load(batch["audio"]["path"])
    if sampling_rate != target_sampling_rate:
        waveform = torchaudio.functional.resample(
            waveform,
            orig_freq=sampling_rate,
            new_freq=target_sampling_rate,
        )
    # Only the first channel is kept, exactly as before.
    batch["speech"] = waveform[0].numpy()
    batch["target_text"] = batch["sentence"].lower()
    return batch


# Fetch the requested split, drop rows lacking a transcript or audio entry,
# cap the number of rows at `max_samples`, then decode every remaining clip.
print("Loading Common Voice Romanian test data...")
dataset = load_dataset("mozilla-foundation/common_voice_11_0", language, split=split)
dataset = dataset.filter(
    lambda row: not (row["sentence"] is None or row["audio"] is None)
)
sample_count = min(max_samples, len(dataset))
dataset = dataset.select(range(sample_count))
dataset = dataset.map(speech_file_to_array_fn, num_proc=4)

# One processor, taken from the base checkpoint, is shared by every model.
processor = WhisperProcessor.from_pretrained(base_model, language=language, task=task)

# Compute the model inputs once up front so that each evaluated checkpoint
# reuses the same features instead of recomputing them.
preprocessed_inputs = []
for sample in tqdm(dataset, desc="Preprocessing inputs"):
    encoded = processor(
        audio=sample["speech"],
        sampling_rate=16_000,
        return_tensors="pt",
    )
    preprocessed_inputs.append(
        {
            "input_features": encoded.input_features,
            "target_text": sample["target_text"],
        }
    )

# Imported here so this evaluation section carries its own dependency.
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

# References and hypotheses must go through the SAME normaliser before
# scoring. Decoding with `normalize=True` applies Whisper's English-only
# normaliser to the predictions while the references were merely lower-cased,
# which makes the two sides differ in ways unrelated to recognition quality.
# The language-agnostic basic normaliser is applied to both sides instead.
text_normalizer = BasicTextNormalizer()

# Maps each model id to {"wer": float, "cer": float}.
results = {}
for model_id in model_paths:
    print(f"\nEvaluating {model_id}...")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        device_map="auto"
    )
    # NOTE(review): probably redundant with the language/task arguments passed
    # to generate() below - confirm against the installed transformers version.
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids()
    predictions = []
    references = []
    for sample in tqdm(preprocessed_inputs, desc=f"Benchmarking {model_id}..."):
        reference = text_normalizer(sample["target_text"])
        if not reference:
            # Nothing scoreable is left after normalisation (e.g. a reference
            # made of punctuation only); skip it before spending time decoding.
            continue

        inputs = {"input_features": sample["input_features"].to(model.device)}

        with torch.no_grad():
            predicted_ids = model.generate(
                **inputs,
                max_new_tokens=225,
                # Reuse the shared settings rather than repeating the literals.
                language=language,
                task=task
            )

        transcription = processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]
        predictions.append(text_normalizer(transcription))
        references.append(reference)

    wer_score = wer(references, predictions)
    cer_score = cer(references, predictions)
    results[model_id] = {"wer": wer_score, "cer": cer_score}

# Last expression of the cell: displays the collected scores in the notebook.
results