# Evaluate the WER of Wav2Vec2-BERT 2.0 (w2v-bert-2.0) on Vietnamese test sets

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/eval/evaluate-whisper.ipynb)

In [None]:
# Log in to the Hugging Face Hub through the interactive notebook widget.
# Non-interactive alternative: !huggingface-cli login --token=███
from huggingface_hub import notebook_login

notebook_login()

In [None]:
# Workaround for a bug in the `datasets` package: replace the preinstalled cudf / dask-cuda /
# dask-cudf packages with the CUDA 12 build of cudf from NVIDIA's package index.
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
%pip install -qU 'datasets[audio]' accelerate transformers jiwer
# `evaluate` is deliberately not installed: installing it and then running `import evaluate`
# throws an error on Kaggle, so WER is computed with `jiwer` directly.
# NOTE(review): `-U` pulls the newest `datasets`; newer major releases reportedly drop
# loading-script support (`trust_remote_code=True` is used below) — consider pinning a version.

In [None]:
import torch
from transformers import AutoProcessor, Wav2Vec2BertForCTC
import datasets as hugDS
import jiwer

# Text normalisation shared by references and hypotheses: both sides go through the same
# steps, in this order, before words are compared. The final step splits sentences into words.
# DO NOT add `jiwer.RemoveEmptyStrings`: it can drop sentences, so the reference and
# hypothesis row counts would no longer match.
# NOTE(review): no Unicode normalisation (NFC/NFD) is applied; if a dataset and the model
# encode Vietnamese diacritics differently, identical-looking words would not match.
_jiwer_steps = (
	jiwer.ToLowerCase(),
	jiwer.RemoveKaldiNonWords(),
	jiwer.RemoveMultipleSpaces(),
	jiwer.Strip(),
	jiwer.RemovePunctuation(),
	jiwer.ReduceToListOfListOfWords(),
)
JIWER_TRANS = jiwer.Compose(list(_jiwer_steps))
del _jiwer_steps

In [None]:
SAMPLING_RATE = 16_000


def load_my_data(**kwargs):
	"""Open the "test" split of a Hub dataset in streaming mode, resampling audio to SAMPLING_RATE.

	Keyword arguments (e.g. `path`, `name`) are forwarded to `datasets.load_dataset`.
	Returns a streaming (iterable) dataset, since `streaming=True` is passed.
	"""
	stream = hugDS.load_dataset(**kwargs, split="test", trust_remote_code=True, streaming=True)
	return stream.cast_column("audio", hugDS.Audio(sampling_rate=SAMPLING_RATE))


# One entry per evaluation set; each keeps the audio plus its reference transcript as "sentence".
# NOTE(review): the values are streaming datasets, so treat this as a plain container —
# `DatasetDict` methods such as `.map` expect regular `Dataset` values.
MY_DATA = hugDS.DatasetDict()

# Common Voice and VIVOS already call the transcript column "sentence".
MY_DATA["commonvoice"] = (
	load_my_data(path="mozilla-foundation/common_voice_16_1", name="vi")
	.select_columns(["audio", "sentence"])
)
MY_DATA["fleurs"] = (
	load_my_data(path="google/fleurs", name="vi_vn")
	.select_columns(["audio", "transcription"])
	.rename_column("transcription", "sentence")
)
MY_DATA["vivos"] = load_my_data(path="vivos").select_columns(["audio", "sentence"])
# bud500 is only renamed, not column-selected (presumably it has no extra columns — confirm).
MY_DATA["bud500"] = load_my_data(path="linhtran92/viet_bud500").rename_column("transcription", "sentence")
MY_DATA["lsvsc"] = (
	load_my_data(path="doof-ferb/LSVSC")
	.select_columns(["audio", "transcription"])
	.rename_column("transcription", "sentence")
)
# test-split sample counts: 1326 + 857 + 760 + 7500 + 5683 = 16126

In [None]:
MODEL_ID = "trick4kid/w2v-bert-2.0-vietnamese-CV16.0"  # @param ["facebook/w2v-bert-2.0", "trick4kid/w2v-bert-2.0-vietnamese-CV16.0"]
# Use the GPU when one is available, otherwise fall back to the CPU (slow, but it runs instead
# of crashing); `predict` picks the device up again via `MODEL.device`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): the "facebook/w2v-bert-2.0" option is presumably the pretrained encoder without a
# fine-tuned CTC vocabulary/tokenizer, so decoding with it may fail — confirm before selecting it.
# NOTE(review): `target_lang="vi"` is kept as-is; confirm it has any effect for Wav2Vec2BertForCTC.
MODEL = Wav2Vec2BertForCTC.from_pretrained(MODEL_ID, target_lang="vi").to(DEVICE)
PROCESSOR = AutoProcessor.from_pretrained(MODEL_ID)

In [None]:
@torch.inference_mode()
def predict(batch):
	"""Transcribe one example and store the text under the "pred" key.

	`batch` is a single row (the map below is not batched). `batch["audio"]["array"]` holds the
	waveform, which `load_my_data` resamples to SAMPLING_RATE. Decoding is greedy: the
	highest-scoring token is taken at every output frame. The row is updated and returned.
	"""
	inputs = PROCESSOR(batch["audio"]["array"], sampling_rate=SAMPLING_RATE, return_tensors="pt").to(MODEL.device)
	logits = MODEL(**inputs).logits
	predicted_ids = torch.argmax(logits, dim=-1)  # greedy: best token per frame
	batch["pred"] = PROCESSOR.batch_decode(predicted_ids)[0]  # one input -> take the only decoded string
	return batch


# `MY_DATA` holds streaming datasets (`streaming=True` in `load_my_data`), so `DatasetDict.map`
# cannot be used on it: that method only accepts regular `Dataset` values. Instead, map each split
# on its own and drain the lazy stream exactly once into an in-memory `Dataset` (the audio column
# is dropped). Inference therefore runs a single time, and the WER cell can read whole columns.
MY_DATA_BIS = hugDS.DatasetDict()
for split_name, split_stream in MY_DATA.items():
	print(f"transcribing {split_name} ...")
	MY_DATA_BIS[split_name] = hugDS.Dataset.from_list(list(split_stream.map(predict, remove_columns=["audio"])))

In [None]:
# Print the word error rate (in %) for every evaluation split.
for split, split_ds in MY_DATA_BIS.items():
	# jiwer expects plain Python lists of sentences. Wrap each column in list() because newer
	# `datasets` releases may return a lazy column object instead of a list for `ds["column"]`.
	references = list(split_ds["sentence"])
	hypotheses = list(split_ds["pred"])
	# The same JIWER_TRANS normalisation is applied to both sides before words are compared.
	wer = 100 * jiwer.wer(
		reference=references,
		hypothesis=hypotheses,
		reference_transform=JIWER_TRANS,
		hypothesis_transform=JIWER_TRANS,
	)
	print(f"WER on {split} = {wer:.1f}%")