In [1]:
# !pip install jiwer pyctcdecode "pypi-kenlm" --upgrade -qq

In [2]:
# for cpu inference
# !pip uninstall onnxruntime-gpu -y
# !pip install onnxruntime -qq

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from sk import predict,load_model,load_decoder,labels,get_files
from pathlib import Path
import jiwer
import pandas as pd
import string
from fastcore.basics import partialler
from IPython.display import Audio
import multiprocessing
import numpy as np
# Raise the column display width so long paths and transcripts are not truncated in DataFrame output.
pd.set_option('max_colwidth', 400)

In [3]:
# Location of the exported ONNX model that load_model() reads in the next cell.
path = "/content/MyDrive/neem/onnx/conformer_small.onnx"
# Alternative model location, kept for switching by hand.
# path = "/content/model"

In [4]:
# Load the model from `path` with the project helper imported from `sk`.
# NOTE(review): `model` is not referenced again in the cells shown here; predict()
# presumably relies on state that load_model() sets up inside `sk` — confirm.
model = load_model(path)

loading model


In [5]:
import kenlm
from pyctcdecode import build_ctcdecoder

# Load the KenLM language model once; the same object is reused when the
# decoder is rebuilt later in the notebook.
kenlm_model = kenlm.Model("/content/out.trie.klm")

# Initial decoder with starting alpha/beta values; the grid search below
# overrides them through reset_params(). The index of the "_" label is passed
# as the CTC token index.
decoder = build_ctcdecoder(labels, kenlm_model, alpha=0.5, beta=1.0, ctc_token_idx=labels.index("_"))

In [6]:
# Directory that predict() scans for test audio; reused by the later predict() call.
fn = "/content/test-bahasa/"
# With logits=True, predict() returns two values: the model outputs and the file list.
# Keeping the outputs lets the grid search below re-decode them without running the model again.
log_probs,files = predict(fn,logits=True)

  0%|          | 0/780 [00:00<?, ?it/s]

Total input path: 1
Total audio found(.wav): 780
start prediction


100%|██████████| 780/780 [00:44<00:00, 17.46it/s]


In [7]:
# Reference transcript for every audio file: the sibling file with the same
# name and a ".txt" suffix, in the same order as `files`.
gt = [audio_path.with_suffix(".txt").read_text() for audio_path in files]

In [8]:
def parallel_decode(log_prob,decoder=decoder):
    """Beam-search decode one utterance and return the top beam's first field.

    Args:
        log_prob: model output for a single utterance, forwarded unchanged to
            ``decoder.decode_beams``.
        decoder: decoder object to use. The default is bound when this cell
            runs, so rebinding the global name ``decoder`` afterwards does not
            change it (in-place changes such as ``reset_params`` do apply).

    Returns:
        ``beams[0][0]`` of the ``decode_beams`` result, presumably the decoded
        text of the best-scoring beam (see the pyctcdecode docs to confirm).
    """
    beams = decoder.decode_beams(log_prob, prune_history=True)
    top_beam = beams[0]
    return top_beam[0]

# Grid Search

In [9]:
# Brute-force grid: candidate values 0.0, 0.1, ..., 0.5 (6 points).
# The grid-search loop uses these directly for alpha and as `1 - value` for beta.
# linspace states the number of points explicitly; arange with a fractional
# step can yield an off-by-one length because of floating-point rounding.
alpha = np.linspace(0, 0.5, 6)

In [11]:
# Display the grid values as a sanity check.
alpha

array([0. , 0.1, 0.2, 0.3, 0.4, 0.5])

In [12]:
# Size of the full grid: the search loop pairs every value in `alpha` with every value in `alpha`.
len(alpha)**2

36

In [13]:
%%time
from pqdm.threads import pqdm
data_grid = []
for a in alpha:
    best_wer = 1
    print("*"*20)
    for b in alpha:
        a = round(a,1)
        b = round(1-b,1)
        if a==b and a != 0.5:
            continue
        decoder.reset_params(alpha=a,beta=b)
        outs = pqdm(log_probs,parallel_decode,n_jobs=4,disable=True)
        wer = np.mean([jiwer.compute_measures(i,j)["wer"] for i,j in zip(gt,outs)])
        cer = np.mean([jiwer.compute_measures(list(i),list(j))["wer"] for i,j in zip(gt,outs)])

        if best_wer > wer:
            print(f"a={a} b={b} wer={wer} cer={cer}")
            best_wer = wer
        data_grid.append((a, b, wer,cer))

********************
a=0.0 b=1.0 wer=0.1720843642606008 cer=0.039871626034745346
a=0.0 b=0.9 wer=0.17110145827769477 cer=0.039871626034745346
a=0.0 b=0.8 wer=0.17084504802128453 cer=0.03991741358053289
a=0.0 b=0.7 wer=0.16989322206945862 cer=0.03991741358053289
a=0.0 b=0.6 wer=0.1696795468557834 cer=0.039984889963798745
a=0.0 b=0.5 wer=0.16902686620310275 cer=0.039984889963798745
********************
a=0.1 b=1.0 wer=0.15159775333137757 cer=0.03670908833013627
a=0.1 b=0.9 wer=0.15150062823425245 cer=0.036681810643284116
a=0.1 b=0.8 wer=0.1512442179778422 cer=0.03664619810767158
a=0.1 b=0.7 wer=0.15081686755049178 cer=0.03657751678899026
a=0.1 b=0.6 wer=0.15021145444507866 cer=0.03657751678899026
a=0.1 b=0.5 wer=0.14910034333396757 cer=0.03659853402312225
********************
a=0.2 b=1.0 wer=0.1408932266178558 cer=0.035654586176781904
a=0.2 b=0.7 wer=0.14084218447003946 cer=0.03587817953160415
a=0.2 b=0.6 wer=0.14082692195477695 cer=0.03596401177845335
a=0.2 b=0.5 wer=0.1406987168265718 

In [14]:
# Grid-search results as a table, lowest word error rate first.
df = pd.DataFrame(data_grid, columns=["alpha", "beta", "wer", "cer"])
df = df.sort_values(by="wer")
df.head()

Unnamed: 0,alpha,beta,wer,cer
26,0.4,0.8,0.13501,0.035834
25,0.4,0.9,0.13501,0.035834
30,0.5,1.0,0.135084,0.036035
27,0.4,0.7,0.135218,0.035719
28,0.4,0.6,0.135325,0.03574


In [20]:
# Same grid-search results, ordered by character error rate instead (lowest first).
dfcer = pd.DataFrame(data_grid, columns=["alpha", "beta", "wer", "cer"])
dfcer = dfcer.sort_values(by="cer")
dfcer.head()

Unnamed: 0,alpha,beta,wer,cer
12,0.2,1.0,0.140893,0.035655
27,0.4,0.7,0.135218,0.035719
29,0.4,0.5,0.135698,0.035722
13,0.2,0.9,0.140899,0.035736
28,0.4,0.6,0.135325,0.03574


### Rebuild the decoder with the best hyperparameters to confirm the result is reproducible

In [15]:
# Take the best-WER setting straight from `data_grid` instead of the global `df`:
# a later cell rebinds `df` to the per-file results table, so `df.iloc[0].alpha`
# fails if this cell is re-run after that point. The table is built and sorted
# exactly as in the grid-search summary cell, so the same row is selected.
best_params = (
    pd.DataFrame(data_grid, columns=["alpha", "beta", "wer", "cer"])
    .sort_values(by="wer")
    .iloc[0]
)
decoder = build_ctcdecoder(
    labels,
    kenlm_model,
    alpha=best_params.alpha,
    beta=best_params.beta,
    ctc_token_idx=labels.index("_")
)

In [16]:
# Run prediction on the same directory with the tuned decoder; predict() returns four values here.
# `files` is rebound, replacing the list returned by the earlier logits=True call.
# NOTE(review): `ents` and `timesteps` are not used in the cells shown here.
preds,files,ents,timesteps = predict(fn,decoder=decoder)

  1%|          | 6/780 [00:00<00:14, 52.13it/s]

Total input path: 1
Total audio found(.wav): 780
start prediction


100%|██████████| 780/780 [00:17<00:00, 44.46it/s]


In [17]:
# One row per audio file: path, reference text, prediction, word-level and
# character-level error (the latter by giving jiwer the texts as character lists).
data = []
for pred_text, audio_path in zip(preds, files):
    label = audio_path.with_suffix('.txt').read_text()
    word_measures = jiwer.compute_measures(label, pred_text)
    char_measures = jiwer.compute_measures(list(label), list(pred_text))
    data.append([audio_path, label, pred_text, word_measures['wer'], char_measures['wer']])

In [18]:
# Per-file results table, worst character error rate first.
# NOTE: this rebinds `df`, which held the grid-search table until now.
df = pd.DataFrame(data)
df.columns = ["path","label","pred","wer","cer"]
df = df.sort_values(by="cer", ascending=False)
df.head(n=50)

Unnamed: 0,path,label,pred,wer,cer
761,/content/test-bahasa/wattpad-audio-wattpad-646.wav,who are you,ayu,1.0,0.666667
450,/content/test-bahasa/-home-husein-speech-bahasa-streaming-iaitu paus odontoceti paus.wav,iaitu paus odontoceti paus,iaitu pawesordantercetipawes,0.75,0.347826
352,/content/test-bahasa/-home-husein-speech-bahasa-haqkiem-LJ118-000005.wav,boleh fuckers quora sila tutup fuck tentang iq,boleh pakar sekolah sila tutup fak tentang iku,0.5,0.307692
664,/content/test-bahasa/-home-husein-speech-bahasa-sebut-perkataan-man-ampe.wav,sebut perkataan ampe,sebuk pakatan ampa,1.0,0.277778
353,/content/test-bahasa/wattpad-audio-wattpad-81.wav,libra dan gisel yang baru datang dengan nafas ngos ngosan menatap,tiba dan kisa yang berdatang dengan nafas nusa menatap,0.545455,0.254545
64,/content/test-bahasa/wattpad-audio-wattpad-88.wav,udah ga cape kita caw yuk,udagacapir kita cayok,0.833333,0.25
643,/content/test-bahasa/wattpad-audio-wattpad-61.wav,duh duh mulai deh jangan bikin kita tambah panik dong yas protes libra,duhdumelairdah jangan bikin kita taman pendek dorong yes protes liberal,0.692308,0.241379
510,/content/test-bahasa/-home-husein-speech-bahasa-tolong-sebut-biperforate.wav,tolong sebut biperforate,tolong sebut bipufferad,0.333333,0.227273
102,/content/test-bahasa/-home-husein-speech-bahasa-sebut-perkataan-man-alau.wav,sebut perkataan alau,sebut keadaan alau,0.333333,0.222222
47,/content/test-bahasa/iium-audio-iium-45.wav,memang salah aku,memang selaku,0.666667,0.214286


In [19]:
# Mean of the per-file "wer" and "cer" columns, for comparison with the best grid-search row.
df["wer"].mean(),df["cer"].mean()

(0.13501009020404703, 0.03583407086803436)