# Evaluate the WER of Whisper with a PEFT LoRA adapter

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/evaluate-whisper-lora.ipynb)

In [None]:
# Log in to the Hugging Face Hub so the datasets and models below can be downloaded
# (NOTE(review): presumably required for gated repos such as Common Voice — confirm).
from huggingface_hub import notebook_login
notebook_login()
# non-interactive alternative: !huggingface-cli login --token=███

In [None]:
# Workaround for a bug in the `datasets` package: replace the preinstalled `cudf` / `dask-cuda` /
# `dask-cudf` packages with `cudf-cu12` from NVIDIA's index
# (NOTE(review): presumably a conflict specific to the Kaggle image — confirm).
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
%pip install -qU 'datasets[audio]' accelerate transformers jiwer bitsandbytes peft
# `evaluate` is deliberately not installed: once it is installed, `import evaluate` throws an error on Kaggle

In [None]:
import torch
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer
import datasets as hugDS
import jiwer
# DO NOT USE `evaluate.evaluator`: it is buggy and cannot set the language, which results in a very bad WER

# Text normalization applied identically to references and hypotheses before computing WER.
# `jiwer.RemoveEmptyStrings` is deliberately NOT used. Applied to a list of sentences, it drops
# empty/whitespace-only entries. An empty model prediction would then vanish from the hypothesis
# list, and the references and hypotheses would no longer line up (jiwer raises a length mismatch).
# Punctuation is removed BEFORE whitespace is normalized, so the gaps it leaves are collapsed too.
JIWER_TRANS = jiwer.Compose([
	jiwer.ToLowerCase(),
	jiwer.RemovePunctuation(),
	jiwer.RemoveMultipleSpaces(),
	jiwer.Strip(),
	jiwer.ReduceToListOfListOfWords(),
])

In [None]:
# Target audio sampling rate: every dataset's "audio" column is resampled to it,
# and the feature extractor is told to expect it.
SAMPLING_RATE = 16_000
def load_my_data(**kwargs):
	"""Load the "test" split of a Hub dataset with its audio resampled to SAMPLING_RATE.

	``kwargs`` are forwarded to ``datasets.load_dataset`` (e.g. ``path``, ``name``).
	Streaming is disabled because these test splits are lightweight enough to download in full.
	"""
	test_split = hugDS.load_dataset(split="test", num_proc=2, trust_remote_code=True, **kwargs)
	resampled_audio = hugDS.Audio(sampling_rate=SAMPLING_RATE)
	return test_split.cast_column("audio", resampled_audio)

# One test split per benchmark, each reduced to the same two columns: "audio" and "sentence".
MY_DATA = hugDS.DatasetDict({
	"commonvoice": (
		load_my_data(path="mozilla-foundation/common_voice_16_1", name="vi")
		.select_columns(["audio", "sentence"])
	),
	"fleurs": (
		load_my_data(path="google/fleurs", name="vi_vn")
		.select_columns(["audio", "transcription"])
		.rename_column("transcription", "sentence")  # align with the other benchmarks' column name
	),
	"vivos": (
		load_my_data(path="vivos")
		.select_columns(["audio", "sentence"])
	),
})
# samples count: 1326 + 857 + 760

In [None]:
PEFT_MODEL_ID = "daila/whisper-large-v3_LoRA_Common-Vi_WER"  # @param ["daila/whisper-large-v3_LoRA_Common-Vi_WER", "daila/whisper-large-v3_LoRA_vi", "vikas85/whisper-vlsp-peft", "vikas85/whisper-vlsp", "vikas85/whisper-fosd-peft", "vikas85/whisper-fleurs-peft-vi-2", "DuyTa/vi-whisper-medium-Lora", "vikas85/whisper-cv-fleur-v6", "vikas85/fleurs-vn-peft-v2", "Yuhthe/openai-whisper-small-vivos-LORA-colab"]
# The adapter config records which base Whisper checkpoint it was trained on.
BASE_MODEL_ID = PeftConfig.from_pretrained(PEFT_MODEL_ID).base_model_name_or_path
print("adapter to", BASE_MODEL_ID)

FEATURE_EXTRACTOR = WhisperFeatureExtractor.from_pretrained(BASE_MODEL_ID)
TOKENIZER = WhisperTokenizer.from_pretrained(BASE_MODEL_ID, language="vi", task="transcribe")

# Passing `load_in_8bit=True` directly to `from_pretrained` is deprecated;
# the supported way is an explicit bitsandbytes quantization config.
from transformers import BitsAndBytesConfig

# Load the 8-bit base model, attach the LoRA adapter, then merge it into the base weights
# so inference runs without the extra adapter layers (reduces latency).
# NOTE(review): PEFT warns that merging LoRA into 8-bit weights may change generations due to
# rounding errors. If exact adapter quality matters, skip the merge or load the base unquantized.
MODEL = PeftModel.from_pretrained(
	WhisperForConditionalGeneration.from_pretrained(
		BASE_MODEL_ID,
		quantization_config=BitsAndBytesConfig(load_in_8bit=True),
		device_map="auto",
	),
	PEFT_MODEL_ID,
).merge_and_unload(progressbar=True)

In [None]:
# NOTE(review): autocast is presumably needed because the 8-bit model's non-quantized layers run
# in half precision while the feature extractor returns float32 features — confirm on your setup.
@torch.autocast(device_type="cuda")
@torch.inference_mode()
def predict(batch):
	"""Transcribe a single dataset example and store the text in ``batch["pred"]``.

	Intended for ``Dataset.map`` without ``batched=True``: ``batch`` is one example whose
	``"audio"`` column has already been resampled to ``SAMPLING_RATE``.
	"""
	inputs = FEATURE_EXTRACTOR(batch["audio"]["array"], sampling_rate=SAMPLING_RATE, return_tensors="pt").to(MODEL.device)
	# Force Vietnamese transcription. Without `language`/`task`, Whisper auto-detects the language
	# of every clip and can pick the wrong one, inflating WER. (The tokenizer's language setting
	# only affects its own prefix tokens, not generation.)
	predicted_ids = MODEL.generate(**inputs, language="vi", task="transcribe")
	batch["pred"] = TOKENIZER.batch_decode(predicted_ids, skip_special_tokens=True)[0]
	return batch

MY_DATA_BIS = MY_DATA.map(predict, remove_columns=["audio"])  # progress bar included

In [None]:
# Compute and print the WER (as a percentage) of each split, with the same
# normalization applied to the reference sentences and the model predictions.
for split in ("commonvoice", "fleurs", "vivos"):
	subset = MY_DATA_BIS[split]
	wer = jiwer.wer(
		reference=subset["sentence"],
		hypothesis=subset["pred"],
		reference_transform=JIWER_TRANS,
		hypothesis_transform=JIWER_TRANS,
	) * 100
	print(f"WER on {split} = {wer:.1f}%")