# evaluate WER of whisper with PEFT LoRA

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/eval/evaluate-whisper-lora.ipynb)

*kaggle TPU crash when running inference* ~~can be used on kaggle TPU, but do not enable `XLA_USE_BF16` because of AMP (Automatic Mixed Precision)~~

try `transformers.pipeline` but error with `torch.autocast`

In [None]:
# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): presumably needed to download gated datasets/models used below — confirm.
from huggingface_hub import notebook_login
notebook_login()
# Shell alternative, kept disabled; it takes the token on the command line:
# !huggingface-cli login --token=███

In [None]:
# Workaround for a bug in the `datasets` package: remove any existing
# cudf / dask-cuda / dask-cudf, then install `cudf-cu12` from NVIDIA's package index.
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
# Libraries used by the rest of this notebook (quiet install, upgraded to the latest versions).
%pip install -qU 'datasets[audio]' accelerate transformers jiwer bitsandbytes peft
# `evaluate` is not installed here: installing it and then running `import evaluate` throws an error on Kaggle.

In [None]:
import torch
# import torch_xla.core.xla_model as xm  # on kaggle TPU
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer
import datasets as hugDS
import jiwer

# Text normalisation pipeline, used as both the reference and the hypothesis
# transform when computing WER. The steps run in the order listed.
# DO NOT add `jiwer.RemoveEmptyStrings`: it can cause a rows count mismatch.
_JIWER_STEPS = [
	jiwer.ToLowerCase(),                 # case-insensitive comparison
	jiwer.RemoveKaldiNonWords(),         # drop Kaldi-style non-word markers
	jiwer.RemoveMultipleSpaces(),        # collapse runs of spaces
	jiwer.Strip(),                       # trim leading / trailing whitespace
	jiwer.RemovePunctuation(),           # punctuation does not count as an error
	jiwer.ReduceToListOfListOfWords(),   # final step: sentences -> lists of words
]
JIWER_TRANS = jiwer.Compose(_JIWER_STEPS)

In [None]:
SAMPLING_RATE = 16_000  # audio is decoded at this rate and the same value is given to the feature extractor


def load_my_data(**kwargs):
	"""Open the ``test`` split of a Hugging Face dataset in streaming mode.

	Args:
		**kwargs: forwarded unchanged to ``datasets.load_dataset``
			(e.g. ``path=...``, ``name=...``). Must not contain ``split``,
			``trust_remote_code`` or ``streaming``, which are fixed here.

	Returns:
		The streamed split with its ``"audio"`` column cast to
		``datasets.Audio(sampling_rate=SAMPLING_RATE)``.
	"""
	streamed_split = hugDS.load_dataset(split="test", trust_remote_code=True, streaming=True, **kwargs)
	audio_feature = hugDS.Audio(sampling_rate=SAMPLING_RATE)
	return streamed_split.cast_column("audio", audio_feature)

# One test split per corpus. Every entry ends up with an "audio" column and a
# "sentence" column holding the reference transcript; corpora that name it
# "transcription" are renamed.
MY_DATA = hugDS.DatasetDict({
	"commonvoice": (
		load_my_data(path="mozilla-foundation/common_voice_16_1", name="vi")
		.select_columns(["audio", "sentence"])
	),
	"fleurs": (
		load_my_data(path="google/fleurs", name="vi_vn")
		.select_columns(["audio", "transcription"])
		.rename_column("transcription", "sentence")
	),
	"vivos": (
		load_my_data(path="vivos")
		.select_columns(["audio", "sentence"])
	),
	"bud500": (  # no column selection for this corpus, only the rename
		load_my_data(path="linhtran92/viet_bud500")
		.rename_column("transcription", "sentence")
	),
	"lsvsc": (
		load_my_data(path="doof-ferb/LSVSC")
		.select_columns(["audio", "transcription"])
		.rename_column("transcription", "sentence")
	),
})
# samples count: 1326 + 857 + 760 + 7500 + 5683

In [None]:
PEFT_MODEL_ID = "doof-ferb/whisper-large-peft-lora-vi"  # @param ["doof-ferb/whisper-large-peft-lora-vi", "daila/whisper-large-v3_LoRA_Common-Vi_WER", "daila/whisper-large-v3_LoRA_vi", "vikas85/whisper-vlsp-peft", "vikas85/whisper-vlsp", "vikas85/whisper-fosd-peft", "vikas85/whisper-fleurs-peft-vi-2", "DuyTa/vi-whisper-medium-Lora", "vikas85/whisper-cv-fleur-v6", "vikas85/fleurs-vn-peft-v2", "Yuhthe/openai-whisper-small-vivos-LORA-colab"]
# The adapter's config names the base checkpoint it has to be applied to.
_ADAPTER_CONFIG = PeftConfig.from_pretrained(PEFT_MODEL_ID)
BASE_MODEL_ID = _ADAPTER_CONFIG.base_model_name_or_path
print("adapter to", BASE_MODEL_ID)

# Declaring task & language on the extractor / tokenizer has no effect at inference,
# so both are loaded with their defaults from the base checkpoint.
TOKENIZER = WhisperTokenizer.from_pretrained(BASE_MODEL_ID)
FEATURE_EXTRACTOR = WhisperFeatureExtractor.from_pretrained(BASE_MODEL_ID)

# Base model in half precision on the GPU (load_in_8bit makes inference super slow).
_BASE_MODEL = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16).to("cuda")
# _BASE_MODEL = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.bfloat16).to(xm.xla_device())  # on kaggle TPU

# Attach the LoRA adapter, then merge and unload it to reduce latency with LoRA.
_ADAPTED_MODEL = PeftModel.from_pretrained(_BASE_MODEL, PEFT_MODEL_ID)
MODEL = _ADAPTED_MODEL.merge_and_unload(progressbar=True)

# The only way to declare task & language: pass these tokens as the decoder prompt.
# ids: [50258, 50278, 50359, 50363] except for large-v3: [50258, 50278, 50360, 50364]
_PROMPT_TOKENS = ["<|startoftranscript|>", "<|vi|>", "<|transcribe|>", "<|notimestamps|>"]
_PROMPT_IDS = TOKENIZER.convert_tokens_to_ids(_PROMPT_TOKENS)
# Wrapping the id list in an outer list gives the leading dimension of size 1 directly.
DECODER_ID = torch.tensor([_PROMPT_IDS], device=MODEL.device)

In [None]:
@torch.autocast(device_type="cuda")  # required by PEFT
# @torch.autocast(device_type="xla", dtype=torch.bfloat16)  # on kaggle TPU
@torch.inference_mode()
def predict(batch):
	"""Transcribe one dataset row and store the text under ``batch["pred"]``.

	Args:
		batch: a single row (despite the name) whose ``batch["audio"]["array"]``
			holds the waveform at ``SAMPLING_RATE``.

	Returns:
		The same ``batch`` mapping, mutated in place with the extra ``"pred"`` key.
	"""
	waveform = batch["audio"]["array"]
	features = FEATURE_EXTRACTOR(waveform, sampling_rate=SAMPLING_RATE, return_tensors="pt")
	features = features.to(MODEL.device)
	# DECODER_ID is passed as the decoder prompt for every row
	generated = MODEL.generate(input_features=features.input_features, decoder_input_ids=DECODER_ID)
	transcripts = TOKENIZER.batch_decode(generated, skip_special_tokens=True)
	batch["pred"] = transcripts[0]  # one row in, so one transcript out
	return batch

# Run inference on every split and keep only the text columns.
#
# The splits in `MY_DATA` are streamed (`streaming=True`), i.e. `IterableDataset`
# objects. `DatasetDict.map` only accepts map-style `Dataset` values, and a
# streamed split cannot be indexed by column name afterwards. So each split is
# mapped on its own and the resulting rows (audio removed, hence small) are
# materialised into a regular `Dataset`. Inference happens while the rows are
# consumed by `list(...)`. This also works if a split is already a map-style
# `Dataset`.
MY_DATA_BIS = hugDS.DatasetDict()
for split, dataset in MY_DATA.items():
	print("transcribing", split, "...")
	rows = list(dataset.map(predict, remove_columns=["audio"]))
	MY_DATA_BIS[split] = hugDS.Dataset.from_list(rows)
	print("done:", split, "-", len(rows), "samples")

In [None]:
# Report the word error rate (in percent) of each split, applying the same
# text normalisation to the reference sentences and to the predictions.
for split, scored in MY_DATA_BIS.items():
	references = scored["sentence"]
	hypotheses = scored["pred"]
	wer = 100 * jiwer.wer(
		reference=references,
		hypothesis=hypotheses,
		reference_transform=JIWER_TRANS,
		hypothesis_transform=JIWER_TRANS,
	)
	print(f"WER on {split} = {wer:.1f}%")