In [1]:
import typing as tp
from pathlib import Path
from functools import partial
from dataclasses import dataclass, field

import pandas as pd
import pyctcdecode
import numpy as np
from tqdm.notebook import tqdm

import librosa

import pyctcdecode
import kenlm
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
from bnunicodenormalizer import Normalizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Target sampling rate (samples per second). It is handed to the dataset, which
# passes it to librosa.load as `sr`, and to the feature extractor used for collation.
SAMPLING_RATE = 16_000

## Load model, processor, decoder

In [3]:
# Restore the serialized model onto the CPU, whatever device it was saved from.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints that come from a trusted source.
model = torch.load(
    'model/best.pth',
    map_location=torch.device('cpu'),
)

In [4]:
# Processor stored in the `model` directory; its feature extractor and tokenizer
# are both used further down.
processor = Wav2Vec2Processor.from_pretrained('model')

# Token -> id mapping, re-ordered by ascending id so that the key order can be
# used directly as the label list handed to the CTC decoder.
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_dict = dict(sorted(vocab_dict.items(), key=lambda pair: pair[1]))

## prepare dataloader

In [5]:
class BengaliSRTestDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields one decoded waveform per audio file.

    Args:
        audio_paths: Paths of the audio files to read, one item per file.
        sampling_rate: Rate every file is resampled to while loading.
    """

    def __init__(
        self,
        audio_paths: list[str],
        sampling_rate: int
    ):
        self.audio_paths = audio_paths
        self.sampling_rate = sampling_rate

    def __len__(self) -> int:
        """Return the number of audio files in the dataset."""
        return len(self.audio_paths)

    def __getitem__(self, index: int) -> np.ndarray:
        """Load the file at position ``index`` and return its samples.

        The second value returned by ``librosa.load`` (the sampling rate) is
        dropped because it always equals ``self.sampling_rate``.
        """
        # mono=True down-mixes multi-channel recordings to a single channel.
        # With mono=False such files come back as (channels, samples) arrays,
        # which the padding collate function is not meant to receive.
        waveform, _ = librosa.load(
            self.audio_paths[index], sr=self.sampling_rate, mono=True
        )
        return waveform

In [6]:
# Evaluate on the first 1024 rows of the training metadata.
test = pd.read_csv("data/train.csv").head(1024)
test

Unnamed: 0,id,sentence,split
0,000005f3362c,ও বলেছে আপনার ঠিকানা!,train
1,00001dddd002,কোন মহান রাষ্ট্রের নাগরিক হতে চাও?,train
2,00001e0bc131,"আমি তোমার কষ্টটা বুঝছি, কিন্তু এটা সঠিক পথ না।",train
3,000024b3d810,নাচ শেষ হওয়ার পর সকলে শরীর ধুয়ে একসঙ্গে ভোজন...,train
4,000028220ab3,"হুমম, ওহ হেই, দেখো।",train
...,...,...,...
1019,00417912a6ee,আমাদের সঙ্গে ছিলেন এক বৃদ্ধ সাধু।,train
1020,0041949399ee,"প্রেম, মার তাকে!",train
1021,0041a6298d5c,আমি তাকে অবাধ্য হতে প্ররোচিত করিনি।,train
1022,0041a78a26ec,তিনি বর্তমানে হেভিওয়েট বিভাগে প্রতিদ্বন্দ্বিত...,train


In [7]:
# One mp3 path per utterance id, in the same row order as `test`.
test_audio_paths = [f"data/train_mp3s/{audio_id}.mp3" for audio_id in test["id"].values]

In [8]:
# Waveforms are decoded lazily, one file per __getitem__ call.
test_dataset = BengaliSRTestDataset(
    audio_paths=test_audio_paths,
    sampling_rate=SAMPLING_RATE,
)

# The feature extractor doubles as the collate function: it receives the list
# of waveforms that make up a batch and, with padding=True and
# return_tensors="pt", hands them back as padded PyTorch tensors.
collate_func = partial(
    processor.feature_extractor,
    sampling_rate=SAMPLING_RATE,
    padding=True,
    return_tensors="pt",
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,  # keep batches in the same order as the rows of `test`
    drop_last=False,
    num_workers=8,
    collate_fn=collate_func,
    pin_memory=True,
)

## Inference

In [9]:
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [10]:
# Place the model on the selected device and switch it to evaluation mode.
model = model.to(device).eval()

In [11]:
# Run the model over the whole loader and keep the raw logits for later decoding.
# Despite its name, `pred_sentence_list` holds one NumPy array of logits per
# batch (first axis indexes the utterances of that batch), not decoded text.
pred_sentence_list = []
with torch.no_grad():
    for batch in test_loader:
        # The model lives on `device`; the inputs must be moved there too,
        # otherwise the forward pass fails as soon as `device` is a GPU.
        input_values = batch["input_values"].to(device)
        logits = model(input_values).logits
        # Bring the result back to host memory so it can be stored as NumPy.
        pred_sentence_list.append(logits.detach().cpu().numpy())

tcmalloc: large alloc 1925120000 bytes == 0x61162000 @  0x7f5f5d761680 0x7f5f5d782824 0x7f5f5d782b8a 0x7f5e38d45e55 0x7f5e38d22f03 0x7f5e612183a9 0x7f5e60a40233 0x7f5e6114bc6b 0x7f5e612019a8 0x7f5e612024ce 0x7f5e6120296c 0x7f5e61b23728 0x7f5e61b23775 0x7f5e6135f1e5 0x7f5e60b02b1d 0x7f5e61b244c6 0x7f5e61b24547 0x7f5e6134712b 0x7f5e60af6c3d 0x7f5e61b23f85 0x7f5e61b23fef 0x7f5e6130cd8f 0x7f5e62e0cb03 0x7f5e62e0d9f7 0x7f5e61346473 0x7f5e60af9e35 0x7f5e61cd4b61 0x7f5e617a29bc 0x7f5e7837e97a 0x4ea71b 0x63425d
tcmalloc: large alloc 1925120000 bytes == 0xd3d52000 @  0x7f5f5d761680 0x7f5f5d782824 0x7f5f5d782b8a 0x7f5e38d45e55 0x7f5e38d22f03 0x7f5e6074c0b1 0x7f5e60745af4 0x7f5e60745b40 0x7f5e60745b94 0x7f5e60e49fef 0x7f5e61980e61 0x7f5e61980ebb 0x7f5e61613717 0x7f5e619441bf 0x7f5e61653952 0x7f5e6122bfb7 0x7f5e6120278c 0x7f5e6120296c 0x7f5e61b23728 0x7f5e61b23775 0x7f5e6135f1e5 0x7f5e60b02b1d 0x7f5e61b244c6 0x7f5e61b24547 0x7f5e6134712b 0x7f5e60af6c3d 0x7f5e61b23f85 0x7f5e61b23fef 0x7f5e6130cd8f 

In [18]:
# Word-error-rate metric from the `evaluate` library; WER() calls
# metric.compute(predictions=..., references=...) on it.
import evaluate
metric = evaluate.load("wer")

Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.49k/4.49k [00:00<00:00, 18.2MB/s]


In [32]:
def WER(
    args,
    beam_width=512,
    lm_path='arijitx-wav2vec2-xls-r-300m-bengali/language_model/5gram.bin',
):
    """Decode the cached logits with an LM-backed CTC decoder and score them.

    The function takes its two tuning weights as a single positional argument
    so that it can be mapped directly over an iterable of ``(alpha, beta)``
    pairs (e.g. with ``pool.map``).

    It reads the notebook-level names ``sorted_vocab_dict``, ``processor``,
    ``pred_sentence_list``, ``test`` and ``metric``.

    Args:
        args: Indexable whose first two items are ``alpha`` and ``beta``; both
            are forwarded unchanged to ``pyctcdecode.build_ctcdecoder``.
        beam_width: Beam width passed to every ``decode`` call.
        lm_path: Path of the KenLM model file given to the decoder.

    Returns:
        The value of ``metric.compute`` for the decoded transcriptions against
        ``test.sentence``.
    """
    alpha, beta = args[0], args[1]

    # A new decoder is built on every call because alpha/beta are fixed at
    # construction time here.
    decoder = pyctcdecode.build_ctcdecoder(
        list(sorted_vocab_dict.keys()),
        lm_path,
        alpha=alpha,
        beta=beta,
    )

    processor_with_lm = Wav2Vec2ProcessorWithLM(
        feature_extractor=processor.feature_extractor,
        tokenizer=processor.tokenizer,
        decoder=decoder,
    )

    # `pred_sentence_list` is a list of per-batch logit arrays; flatten it so
    # that each utterance is decoded on its own, in dataset order.
    transcriptions = [
        processor_with_lm.decode(logit, beam_width=beam_width).text
        for logits in pred_sentence_list
        for logit in logits
    ]

    references = test.sentence.values.tolist()
    return metric.compute(predictions=transcriptions, references=references)

In [31]:
import itertools

36

In [33]:
import multiprocessing

# Worker pool for the decoder grid search: 36 processes, the same number as
# the 6 x 6 (alpha, beta) combinations that are mapped over it.
pool = multiprocessing.Pool(36)

In [None]:
# Materialise the (alpha, beta) grid so every score can be matched to the pair
# that produced it afterwards.
param_grid = list(itertools.product(np.linspace(0.5, 5, 6), np.linspace(0.5, 6, 6)))

try:
    # pool.map preserves input order: scores[i] belongs to param_grid[i].
    scores = pool.map(WER, param_grid)
finally:
    # Release the worker processes whether or not the search succeeded.
    # Re-run the cell that creates `pool` before running this cell again.
    pool.terminate()
    pool.join()

In [None]:
# `pickle` is not imported anywhere else in the notebook; without this import
# the cell raises NameError.
import pickle

# Persist the grid-search scores so they can be inspected without re-running it.
with open(r"scores.pickle", "wb") as output_file:
    pickle.dump(scores, output_file)