## About

This notebook computes the word error rate (WER) on the validation split for the models used in the submission [Bengali-SR: Public Wav2vec2.0 w/ LM Baseline](https://www.kaggle.com/code/ttahara/bengali-sr-public-wav2vec2-0-w-lm-baseline).

## Import

In [1]:
!pip install jiwer
!pip install bnunicodenormalizer
!pip install pyctcdecode
!pip install kenlm

Collecting jiwer
  Downloading jiwer-3.0.4-py3-none-any.whl (21 kB)
Installing collected packages: jiwer
Successfully installed jiwer-3.0.4
Collecting bnunicodenormalizer
  Downloading bnunicodenormalizer-0.1.7-py3-none-any.whl (23 kB)
Installing collected packages: bnunicodenormalizer
Successfully installed bnunicodenormalizer-0.1.7
Collecting pyctcdecode
  Downloading pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)
Collecting pygtrie<3.0,>=2.1 (from pyctcdecode)
  Downloading pygtrie-2.5.0-py3-none-any.whl (25 kB)
Collecting hypothesis<7,>=6.14 (from pyctcdecode)
  Downloading hypothesis-6.108.4-py3-none-any.whl (465 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m465.2/465.2 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pygtrie, hypothesis, pyctcdecode
Successfully installed hypothesis-6.108.4 pyctcdecode-0.5.0 pygtrie-2.5.0
Collecting kenlm
  Downloading kenlm-0.2.0.tar.gz (427 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [11]:
# Pin NumPy to 1.22. Per the log below, this uninstalls the preinstalled
# numpy 1.23.5, and pip reports dependency conflicts among other preinstalled
# packages (cudf / cuml / dask-cudf, apache-beam).
# NOTE(review): the reason for this pin is not stated here; presumably it is a
# compatibility requirement of one of the decoding/LM packages. Confirm it is
# still needed, and restart the kernel if numpy was already imported.
!pip install numpy==1.22

Collecting numpy==1.22
  Downloading numpy-1.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.8/16.8 MB[0m [31m63.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.23.5
    Uninstalling numpy-1.23.5:
      Successfully uninstalled numpy-1.23.5
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf 23.6.1 requires cupy-cuda11x>=12.0.0, which is not installed.
cuml 23.6.0 requires cupy-cuda11x>=12.0.0, which is not installed.
dask-cudf 23.6.1 requires cupy-cuda11x>=12.0.0, which is not installed.
apache-beam 2.46.0 requires dill<0.3.2,>=0.3.1.1, but you have dill 0.3.6 which is incompatible.
apache-beam 2.46.0 requires pyarrow<10.0.0,>=3.0.0, but you have 

In [23]:
import typing as tp
from pathlib import Path
from functools import partial
from dataclasses import dataclass, field

import pandas as pd
import pyctcdecode
import numpy as np
from tqdm.notebook import tqdm

import librosa

import jiwer
import pyctcdecode
import kenlm
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
from bnunicodenormalizer import Normalizer

In [25]:
# Kaggle layout: the notebook runs one level below the directory holding `input/`.
ROOT = Path.cwd().parent
INPUT = ROOT / "input"

# Competition data (audio clips plus train.csv with the train/valid split).
DATA = INPUT / "bengaliai-speech"
TRAIN, TEST = DATA / "train_mp3s", DATA / "test_mp3s"

# Target sample rate (Hz) that every clip is resampled to on load.
SAMPLING_RATE = 16_000

# Acoustic model from one public checkpoint, KenLM n-gram model from another.
_PRETRAINED = INPUT / "bengali-sr-download-public-trained-models"
MODEL_PATH = _PRETRAINED / "indicwav2vec_v1_bengali"
LM_PATH = _PRETRAINED / "wav2vec2-xls-r-300m-bengali" / "language_model"

### Load model, processor, and decoder

In [27]:
# Processor (feature extractor + CTC tokenizer) and the CTC acoustic model both
# come from the same local checkpoint directory.
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH)

In [45]:
vocab_dict = processor.tokenizer.get_vocab()
# The decoder's label list must follow token-id order so that label i matches
# column i of the model's logits.
sorted_vocab_dict = dict(sorted(vocab_dict.items(), key=lambda kv: kv[1]))

# Beam-search CTC decoder scored with the KenLM 5-gram binary.
decoder = pyctcdecode.build_ctcdecoder(
    labels=list(sorted_vocab_dict),
    kenlm_model_path=str(LM_PATH / "5gram.bin"),
)

In [61]:
# Same feature extractor and tokenizer as the plain processor, but text decoding
# goes through the KenLM-backed beam-search `decoder`.
processor_with_lm = Wav2Vec2ProcessorWithLM(
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=decoder,
)

## Prepare DataLoader

In [73]:
class BengaliSRTestDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves raw audio waveforms for inference.

    Each item is the waveform returned by ``librosa.load`` for the matching
    path, resampled to ``sampling_rate``. No labels are returned; padding and
    feature extraction are left to the DataLoader's ``collate_fn``.

    Args:
        audio_paths: Audio file paths, in the order items are served.
        sampling_rate: Target sample rate (Hz) passed to ``librosa.load``.
        mono: Down-mix multi-channel audio to one channel (default ``True``).
            With ``mono=False`` (the previous hard-coded behaviour), librosa
            returns a ``(channels, samples)`` array for multi-channel files
            instead of a 1-D waveform, which the per-utterance feature-extractor
            collate cannot batch correctly. Mono files load identically either way.
    """

    def __init__(
        self,
        audio_paths: list[str],
        sampling_rate: int,
        mono: bool = True,
    ):
        self.audio_paths = audio_paths
        self.sampling_rate = sampling_rate
        self.mono = mono

    def __len__(self) -> int:
        return len(self.audio_paths)

    def __getitem__(self, index: int):
        # librosa.load returns (waveform, sample_rate); only the waveform is needed.
        waveform, _ = librosa.load(
            self.audio_paths[index], sr=self.sampling_rate, mono=self.mono
        )
        return waveform

In [75]:
# Keep only the rows of train.csv that belong to the validation split; ids are
# read as strings so they map directly onto mp3 file names.
valid = (
    pd.read_csv(DATA / "train.csv", dtype={"id": str})
    .loc[lambda frame: frame["split"] == "valid"]
    .reset_index(drop=True)
)
print(valid.head())

             id                                           sentence  split
0  0000e711c2b1  তিনি এবং তাঁর মা তাদের পৈতৃক বাড়িতে থেকে প্রত...  valid
1  00036c2a2d9d  কৃত্তিবাস রামায়ণ-বহির্ভূত অনেক গল্প এই অনুবাদ...  valid
2  00065e317123  তিনি তার সুশৃঙ্খল সামরিক বাহিনী এবং সুগঠিত শাস...  valid
3  00065f40df52  তিনি বিজয়নগর সাম্রাজ্যের বিরুদ্ধে এবং বিজাপুর...  valid
4  0009b022c8ea                        এটি মূলত একটি মরুময় অঞ্চল।  valid


In [105]:
# One mp3 path per validation row, in the same order as `valid`.
valid_audio_paths = [
    str(TRAIN.joinpath(f"{audio_id}.mp3")) for audio_id in valid["id"].tolist()
]

In [108]:
valid_dataset = BengaliSRTestDataset(valid_audio_paths, SAMPLING_RATE)

# Batches are built by calling the feature extractor on the list of raw
# waveforms: it pads them to a common length and returns PyTorch tensors.
collate_func = partial(
    processor_with_lm.feature_extractor,
    sampling_rate=SAMPLING_RATE,
    padding=True,
    return_tensors="pt",
)

# Deterministic order (no shuffling, keep the last partial batch) so that
# predictions line up row-for-row with `valid`.
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    num_workers=2,
    collate_fn=collate_func,
    pin_memory=True,
)

## Inference

In [109]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [111]:
# Move the model to the selected device and switch off dropout for inference.
model = model.to(device).eval()
# Use fp16 weights on GPU only. On CPU there is no autocast, so the model would
# receive float32 inputs, and many CPU kernels have no Half implementation;
# casting there breaks the CPU fallback path.
if device.type == "cuda":
    model = model.half()

In [None]:
# Bengali unicode normalizer shared across calls; `postprocess` applies it to
# each whitespace-separated word of a predicted sentence.
bnorm = Normalizer()

def postprocess(sentence):
    """Normalize a decoded transcript word by word and terminate it with punctuation.

    Each whitespace-separated word is passed through the global ``bnorm``
    normalizer, and words whose ``"normalized"`` value is ``None`` are dropped.
    If the rejoined sentence does not end with ``.``, ``?``, ``!`` or the
    danda ``।``, a danda is appended.

    Args:
        sentence: Raw transcript string.

    Returns:
        The normalized sentence. If nothing survives normalization, this is
        just ``"।"``.
    """
    sentence_end_marks = (".", "?", "!", "।")
    words = (bnorm(word)['normalized'] for word in sentence.split())
    sentence = " ".join(word for word in words if word is not None)
    if not sentence:
        # Empty result: emit a lone danda rather than an empty prediction.
        # This replaces the previous bare `except:` around `sentence[-1]`.
        return "।"
    if not sentence.endswith(sentence_end_marks):
        sentence += "।"
    return sentence

In [None]:
# Beam-search transcripts for every validation clip, in `valid_loader` order.
pred_sentence_list = []

with torch.no_grad():
    for batch in tqdm(valid_loader):
        x = batch["input_values"].to(device, non_blocking=True)
        # The feature extractor only returns an attention mask when it is
        # configured to (return_attention_mask=True). In that case the model
        # expects it so that zero-padding in shorter clips is ignored.
        # Otherwise this is None, exactly as before.
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device, non_blocking=True)
        # Mixed precision on GPU only. `torch.cuda.amp.autocast` is deprecated
        # in favour of the device-generic `torch.autocast`.
        with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
            y = model(x, attention_mask=attention_mask).logits
        # Hand float32 logits to the NumPy-side beam search; fp16 would only add
        # rounding error there. (`.detach()` is unnecessary under no_grad.)
        y = y.float().cpu().numpy()

        for logits in y:
            sentence = processor_with_lm.decode(logits, beam_width=512).text
            pred_sentence_list.append(sentence)

## Check validation WER score

In [None]:
# Normalize every raw transcript and make sure it ends with punctuation.
pp_pred_sentence_list = list(map(postprocess, tqdm(pred_sentence_list)))

In [None]:
valid["pred_sentence"] = pp_pred_sentence_list

# Per-utterance WER; jiwer takes (reference, hypothesis) in that order.
valid["wer"] = [
    jiwer.wer(reference, hypothesis)
    for reference, hypothesis in tqdm(
        zip(valid["sentence"], valid["pred_sentence"]), total=len(valid)
    )
]

print(valid.head())

In [None]:
# Arithmetic mean of the per-utterance WERs, with every clip weighted equally.
# This generally differs from a corpus-level WER pooled over all words.
print(valid["wer"].mean(skipna=True))

## EOF