# Evaluate the word error rate (WER) of Whisper models

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/evaluate-whisper.ipynb)

In [None]:
# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): presumably needed because some of the datasets loaded below are gated — confirm.
from huggingface_hub import notebook_login
notebook_login()
# alternative: log in from the command line with an access token instead
# !huggingface-cli login --token=███

In [None]:
# workaround for a bug in `datasets` package:
# remove the preinstalled cudf / dask packages, then reinstall `cudf-cu12` from the NVIDIA package index
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
# install / upgrade the libraries used by this notebook
%pip install -qU 'datasets[audio]' accelerate transformers jiwer
# NOTE: `evaluate` is left out on purpose: installing it and then running `import evaluate` throws an error on Kaggle

In [None]:
import torch
from transformers import pipeline
import datasets as hugDS
import jiwer
# DO NOT USE `evaluate.evaluator`: it is buggy — the language cannot be set, which results in a very bad WER

# Text normalisation pipeline used for both the references and the hypotheses when computing WER.
# DO NOT add `jiwer.RemoveEmptyStrings`: it can cause a row-count mismatch.
_JIWER_STEPS = (
	jiwer.ToLowerCase(),                 # case-insensitive comparison
	jiwer.RemoveKaldiNonWords(),         # drop Kaldi-style non-word tokens
	jiwer.RemoveMultipleSpaces(),        # collapse repeated spaces
	jiwer.Strip(),                       # trim leading / trailing whitespace
	jiwer.RemovePunctuation(),           # ignore punctuation differences
	jiwer.ReduceToListOfListOfWords(),   # final step: one list of words per sentence
)
JIWER_TRANS = jiwer.Compose(list(_JIWER_STEPS))

In [None]:
SAMPLING_RATE = 16_000  # sampling rate the "audio" column is cast to
def load_my_data(**kwargs):
	"""Load the "test" split of a dataset and cast its "audio" column to SAMPLING_RATE.

	Args:
		**kwargs: forwarded unchanged to `hugDS.load_dataset` (e.g. `path`, `name`).

	Returns:
		The loaded test split with its "audio" column cast to `hugDS.Audio(sampling_rate=SAMPLING_RATE)`.
	"""
	# streaming stays disabled because these test sets are lightweight
	test_split = hugDS.load_dataset(**kwargs, split="test", trust_remote_code=True)
	audio_feature = hugDS.Audio(sampling_rate=SAMPLING_RATE)
	return test_split.cast_column("audio", audio_feature)

# Test sets to evaluate on, each reduced to an "audio" column plus a "sentence" column (the reference text).
MY_DATA = hugDS.DatasetDict({
	"commonvoice": load_my_data(
		path="mozilla-foundation/common_voice_16_1", name="vi",
	).select_columns(["audio", "sentence"]),
	# fleurs names its text column "transcription"; rename it to match the other datasets
	"fleurs": load_my_data(
		path="google/fleurs", name="vi_vn",
	).select_columns(["audio", "transcription"]).rename_column("transcription", "sentence"),
	"vivos": load_my_data(
		path="vivos",
	).select_columns(["audio", "sentence"]),
})
# samples count: 1326 + 857 + 760 (presumably in the order listed above)

In [None]:
# Checkpoint to evaluate (Colab form field; the list below populates the drop-down).
MODEL_ID = "openai/whisper-tiny"  # @param ["openai/whisper-large-v3", "openai/whisper-large-v2", "openai/whisper-medium", "openai/whisper-small", "openai/whisper-tiny", "doof-ferb/whisper-tiny-vi"]
# Use the first CUDA GPU in half precision when one is available;
# otherwise fall back to CPU in float32 instead of crashing on the hard-coded "cuda:0".
_USE_CUDA = torch.cuda.is_available()
PIPE = pipeline(
	task="automatic-speech-recognition",
	model=MODEL_ID,
	device="cuda:0" if _USE_CUDA else "cpu",
	torch_dtype=torch.float16 if _USE_CUDA else torch.float32,
)
# Passed as `generate_kwargs` on every call to PIPE: language "vi", task "transcribe".
PIPE_KWARGS = {"language": "vi", "task": "transcribe"}

In [None]:
@torch.inference_mode()
def predict(batch):
	"""Transcribe the "audio" entry of `batch` and store the text under "pred".

	Args:
		batch: mapping with an "audio" entry that is handed to PIPE as-is.

	Returns:
		The same `batch` object, with a new "pred" entry holding the "text" field of the PIPE output.
	"""
	output = PIPE(batch["audio"], generate_kwargs=PIPE_KWARGS)
	batch["pred"] = output["text"]
	return batch

# Run `predict` over every dataset in MY_DATA; the "audio" column is removed from the result.
MY_DATA_BIS = MY_DATA.map(predict, remove_columns=["audio"])  # progress bar included

In [None]:
# Print the word error rate (in percent) of every dataset that was transcribed.
# Iterating over MY_DATA_BIS avoids keeping a second, hard-coded list of dataset names in sync.
for name, dataset in MY_DATA_BIS.items():
	# materialise the columns as plain lists: jiwer transforms accept only a str or a list of str
	references = list(dataset["sentence"])
	hypotheses = list(dataset["pred"])
	wer = 100 * jiwer.wer(
		reference=references, hypothesis=hypotheses,
		reference_transform=JIWER_TRANS, hypothesis_transform=JIWER_TRANS,
	)
	print(f"WER on {name} = {wer:.1f}%")