# Evaluate the word error rate (WER) of Whisper models

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/eval/evaluate-whisper.ipynb)

In [None]:
# Authenticate against the Hugging Face Hub before anything is downloaded.
# NOTE(review): presumably required for gated datasets/models used later in this notebook — confirm.
from huggingface_hub import notebook_login
notebook_login()
# Command-line alternative to the call above, left disabled (token redacted):
# !huggingface-cli login --token=███

In [None]:
# Workaround for a bug in the `datasets` package
# NOTE(review): presumably the cudf / dask reinstall below is the workaround itself — confirm.
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
%pip install -qU 'datasets[audio]' accelerate transformers jiwer
# `evaluate` is not installed here: installing it and then running `import evaluate` throws an error on Kaggle

In [None]:
from tqdm import tqdm
import torch
from transformers import pipeline
import datasets as hugDS
import jiwer
# DO NOT USE `evaluate.evaluator`: it is buggy — the language cannot be set, which results in a very bad WER

# Text normalisation pipeline handed to `jiwer.wer` for both references and hypotheses.
# Every step is built with its default arguments, in the order listed.
# DO NOT add `jiwer.RemoveEmptyStrings`: it can cause a row-count mismatch.
JIWER_TRANS = jiwer.Compose([
	make_step()
	for make_step in (
		jiwer.ToLowerCase,
		jiwer.RemoveKaldiNonWords,
		jiwer.RemoveMultipleSpaces,
		jiwer.Strip,
		jiwer.RemovePunctuation,
		jiwer.ReduceToListOfListOfWords,  # must stay last: produces the word lists that are scored
	)
])

In [None]:
SAMPLING_RATE = 16_000  # sampling rate every "audio" column is cast to


def load_my_data(**kwargs):
	"""Open the "test" split of a dataset in streaming mode, with audio cast to SAMPLING_RATE.

	All keyword arguments (e.g. `path`, `name`) are forwarded unchanged to
	`datasets.load_dataset`; `split`, `trust_remote_code` and `streaming` are
	fixed here and therefore must not be passed by the caller.
	"""
	streamed = hugDS.load_dataset(**kwargs, split="test", trust_remote_code=True, streaming=True)
	target_audio = hugDS.Audio(sampling_rate=SAMPLING_RATE)
	return streamed.cast_column("audio", target_audio)

# One streaming test split per corpus; each one ends up with its reference text in a "sentence" column.
MY_DATA = hugDS.IterableDatasetDict()
MY_DATA["commonvoice"] = (
	load_my_data(path="mozilla-foundation/common_voice_16_1", name="vi")
	.select_columns(["audio", "sentence"])
)
MY_DATA["fleurs"] = (
	load_my_data(path="google/fleurs", name="vi_vn")
	.select_columns(["audio", "transcription"])
	.rename_column("transcription", "sentence")
)
MY_DATA["vivos"] = (
	load_my_data(path="vivos")
	.select_columns(["audio", "sentence"])
)
MY_DATA["bud500"] = (
	load_my_data(path="linhtran92/viet_bud500")
	.rename_column("transcription", "sentence")  # no column selection for this corpus
)
MY_DATA["lsvsc"] = (
	load_my_data(path="doof-ferb/LSVSC")
	.select_columns(["audio", "transcription"])
	.rename_column("transcription", "sentence")
)

# Row count per corpus, keyed like MY_DATA; `predict` passes it to tqdm as the progress-bar total.
ROWS_COUNT = dict(
	commonvoice=1326,
	fleurs=857,
	vivos=760,
	bud500=7500,
	lsvsc=5683,
)

In [None]:
# Checkpoint under evaluation (Colab form field).
MODEL_ID = "vinai/PhoWhisper-large"  # @param ["openai/whisper-large-v3", "vinai/PhoWhisper-large", "openai/whisper-large-v2", "openai/whisper-medium", "openai/whisper-small", "openai/whisper-tiny", "openai/whisper-large-v3", "doof-ferb/whisper-tiny-vi"]
# Speech-recognition pipeline on the first CUDA device, with the model loaded in half precision.
PIPE = pipeline(
	task="automatic-speech-recognition",
	model=MODEL_ID,
	device="cuda:0",
	torch_dtype=torch.float16,
)
# Generation settings passed as `generate_kwargs` on every pipeline call.
# NOTE(review): `do_sample=True` together with `num_beams=5` may make transcripts (and so WER) vary between runs — confirm this is intended.
PIPE_KWARGS = dict(language="vi", task="transcribe", do_sample=True, num_beams=5)
BATCH_SIZE = 32  # @param {type: "integer"}
# @markdown for colab free T4 @ `float16`: 32 for large model, 40 medium, 96 small, 512 tiny

In [None]:
def data(batch):
	"""Lazily yield the "audio" field of each row in `batch`.

	Workaround: `KeyDataset(MY_DATA[split], "audio")` raises an error with
	streaming datasets, so the column is extracted by hand with a generator.
	"""
	yield from (row["audio"] for row in batch)


@torch.inference_mode()
def predict(split):
	"""Transcribe every clip of `MY_DATA[split]` and pair it with its reference text.

	Args:
		split: key of `MY_DATA` (and of `ROWS_COUNT`) naming the corpus to run.

	Returns:
		A `datasets.Dataset` with two columns: "true" (the reference
		"sentence" values) and "pred" (the "text" field of each pipeline
		output). Assumes the pipeline yields results in input order.

	Raises:
		KeyError: if `split` is missing from `MY_DATA` or `ROWS_COUNT`.
	"""
	batch = MY_DATA[split]
	outputs = PIPE(data(batch), generate_kwargs=PIPE_KWARGS, batch_size=BATCH_SIZE)
	# ROWS_COUNT only feeds the progress bar; it does not limit how many rows are processed
	y_pred = [out["text"] for out in tqdm(outputs, total=ROWS_COUNT[split], unit="samples", desc=f"{split=}")]
	torch.cuda.empty_cache()  # forced clean
	# The reference pass only needs the text: drop the "audio" column first so
	# the clips are not decoded and resampled a second time.
	y_true = [row["sentence"] for row in batch.select_columns(["sentence"])]
	return hugDS.Dataset.from_dict({"true": y_true, "pred": y_pred})


# Predictions per corpus, stored in a regular (non-iterable) DatasetDict.
# Do not use MY_DATA.map() here: jiwer later needs non-iterable data.
# Filled one corpus at a time, so finished corpora stay available if a later one fails.
MY_DATA_BIS = hugDS.DatasetDict()
for split in MY_DATA:
	MY_DATA_BIS[split] = predict(split)

In [None]:
# Print the word error rate of each corpus as a percentage, using the same
# normalisation (JIWER_TRANS) on the reference and the hypothesis side.
for split, scored in MY_DATA_BIS.items():
	error_rate = jiwer.wer(
		reference=scored["true"],
		hypothesis=scored["pred"],
		reference_transform=JIWER_TRANS,
		hypothesis_transform=JIWER_TRANS,
	)
	wer = 100 * error_rate
	print(f"WER on {split} = {wer:.1f}%")