In [1]:
import numpy as np
from tqdm.auto import tqdm
from time import time
import json, pickle, os, string, kenlm, json
from collections import defaultdict, Counter
from itertools import groupby
import Levenshtein as Lev
import math

In [2]:
def lse(*args):
    """
    Stable log sum exp.
    """
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
#     print('--',args)
#     args = [i for i in args] + [-float('inf')]
    lsp = math.log(sum(math.exp(a - a_max)
                      for a in args))
    return a_max + lsp

In [3]:
def wer_(s1, s2):
    b = set(s1.split() + s2.split())
    word2char = dict(zip(b, range(len(b))))
    w1 = [chr(word2char[w]) for w in s1.split()]
    w2 = [chr(word2char[w]) for w in s2.split()]
    
    return Lev.distance(''.join(w1), ''.join(w2))

In [4]:
labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ "

In [5]:
out = np.load("out.npy")[0]
with open("true.txt", "r") as f:
    reference = f.read()
with open("pred.txt", "r") as f:
    transcript = f.read()

# greedy decoding

In [6]:
def ctc_best_path(out,labels):
    "implements best path decoding as shown by Graves"
    out = [labels[i] for i in np.argmax(out, axis=1) if i!=labels[-1]]
    o = ""
    for i,j in groupby(out):
        o = o + i
    return o.replace("_","")

In [7]:
gred_txt = ctc_best_path(out,labels)
print(gred_txt)
wer_(gred_txt,reference)/len(reference.split(' '))*100

AND CHARGED IFEVER HE MIGHT FIND SIR GAWANE AND SIR UWANE TO BRING THEM TO THE COURT AGAIN AND THEN WERE THEY ALL GLAD AND SO PRAY DHAS OR MORE HOUSE TO RIDE WITH THEM TO THE KING'S COURT


23.076923076923077

# beam search decoding

In [8]:
lm_w = kenlm.LanguageModel('/home/hemant/4_gram.arpa')

In [9]:
def prefix_bsp(out,labels, prune=0.00001, beam_size=25,alpha=1.45,beta=3,lm=None):
    
    blank_symbol = '_'
    F = out.shape[1] # length of labels
    steps = out.shape[0] # number of time steps
    
    t_b = [('', (1.0 ,0.0 ))] # beam at every time step gets updated
    t_1 = None
    
    for t in tqdm(range(0,steps)):
        pruned_alphabet = [labels[i] for i in np.where(out[t]>prune)[0]]
        dummy_beam = defaultdict(lambda: (0,0))
        dummy = t_b
        for prefix, (pb,pnb) in t_b:
            for c in pruned_alphabet:
                p_t = out[t][labels.index(c)]
                
                if c == blank_symbol:
                    dpb,dpnb = dummy_beam[prefix]
                    dpb += p_t*(pb + pnb)
                    dummy_beam[prefix] = (dpb,dpnb)
                    continue
                
                end_t = prefix[-1] if prefix else None
                c_t = prefix + c
                dpb,dpnb = dummy_beam[c_t]
                if c == end_t and len(prefix) > 0:
                    dpb_,dpnb_ = dummy_beam[prefix]
                    dpnb += p_t*pb
                    dpnb_ += p_t*pnb
                    dummy_beam[prefix] = (dpb_,dpnb_)
                    
                elif c == ' ' and len(prefix.strip().split(' ')) > 1:
                    if prefix.upper().split()[-1] in lm:
                        prob = ([10**i[0] for i in lm.full_scores(c_t.upper(),eos=False,bos=False)][-1])**alpha
                    else: 
                        prob = 0.00000001
                        
                    word_inser = (len(prefix.strip().split(' ')))**beta
                    dpnb += prob*p_t*(pb + pnb)*word_inser
                    
                
                else:
                    dpnb += p_t*(pb + pnb)
                dummy_beam[c_t] = (dpb,dpnb)

                if c_t not in t_b and t_1 != None:
                    dpbn,dpnbn = dummy_beam[c_t]
                    for i in t_1:
                        if i[0] == c_t:
                            b_, nb_  = i[1][0], i[1][1]
                        else:
                            b_, nb_  = 0, 0
                    dpbn  += out[t][labels.index("_")]*(b_ + nb_)
                    dpnbn += p_t*nb_
                    dummy_beam[c_t] = (dpbn,dpnbn)

        t_1 = t_b
        t_b = sorted(dummy_beam.items(),
                      key=lambda x:np.sum(x[1]),
                      reverse=True)
        t_b = t_b[:beam_size]
        
    
    best = sorted([(10**lm.score(i[0],bos=True, eos=False),i[0]) for i in t_b],reverse=True)[0][1]
    
    return best

In [24]:
# prune=0.00001, beam_size=75,alpha=1.45,beta=3,lm=lm_w
beam_txt = prefix_bsp(out,labels,prune=0.00001, beam_size=500,alpha=1.6,beta=3,lm=lm_w)
print(beam_txt)
wer_(beam_txt,reference)/len(reference.strip().split(' '))*100

HBox(children=(FloatProgress(value=0.0, max=681.0), HTML(value='')))


AND CHARGED IF EVER HE MIGHT FIND SIR GAWAINE AND SIR UWAINE TO BRING THEM TO THE COURT AGAIN AND THEN WERE THEY ALL GLAD AND SO PRAYED THEY SIR ORHOUSE TO RIDE WITH THEM TO THE KING'S COURT 


2.564102564102564

In [11]:
reference.strip()

"AND CHARGED IF EVER HE MIGHT FIND SIR GAWAINE AND SIR UWAINE TO BRING THEM TO THE COURT AGAIN AND THEN WERE THEY ALL GLAD AND SO PRAYED THEY SIR MARHAUS TO RIDE WITH THEM TO THE KING'S COURT"

In [76]:
def prefix_bsl(out,labels, prune=0.00001, beam_size=20,alpha=0.01,beta=0,lm=None):
    
    blank_symbol = '_'
    F = out.shape[1] # length of labels
    steps = out.shape[0] # number of time steps
    prob_ = out
    out = np.log(out)
    NEG_INF = -float("inf")
    
    t_b = [('', (0.0, NEG_INF ))] # beam at every time step gets updated
    t_1 = None
    
    for t in tqdm(range(0,steps)):
        pruned_alphabet = [labels[i] for i in np.where(prob_[t]>prune)[0]]
        dummy_beam = defaultdict(lambda: (NEG_INF, NEG_INF))
        dummy = t_b
        for prefix, (pb,pnb) in t_b:
            for c in pruned_alphabet:
                p_t = out[t][labels.index(c)]
                
                if c == blank_symbol:
                    dpb,dpnb = dummy_beam[prefix]
                    dpb = lse(dpb, p_t+pb, p_t+pnb)
                    dummy_beam[prefix] = (dpb,dpnb)
                    continue
                
                end_t = prefix[-1] if prefix else None
                c_t = prefix + c
                dpb,dpnb = dummy_beam[c_t]
                if c == end_t and len(prefix) > 0:
                    dpb_,dpnb_ = dummy_beam[prefix]
                    dpnb = lse(dpnb,p_t+pb)
                    dpnb_ = lse(dpnb_,p_t+pnb)
                    dummy_beam[prefix] = (dpb_,dpnb_)
                    
                elif c == ' ' and len(prefix.strip().split(' ')) > 1:
                    if prefix.upper().split()[-1] in lm:
                        prob = 10**[i for i in lm.full_scores(c_t.upper(),eos=False,bos=False)][-1][0]
                        prob = alpha*math.log(prob)
                    else: 
                        prob = -1000
                    word_inser = beta*math.log((len(prefix.strip().split(' '))))
#                     print(prefix, '--' ,prob,word_inser,lse(dpnb,p_t+pb, p_t+pnb),lse(dpnb,p_t+pb, p_t+pnb,prob+word_inser))
                    
                    dpnb = lse(dpnb,p_t+pb, p_t+pnb) + prob+word_inser
                
                else:
                    dpnb = lse(dpnb, p_t+pb, p_t+ pnb)
                dummy_beam[c_t] = (dpb,dpnb)
                
                if c_t not in t_b and t_1 != None:
                    dpbn,dpnbn = dummy_beam[c_t]
                    for i in t_1:
                        if i[0] == c_t:
                            b_, nb_  = i[1][0], i[1][1]
                        else:
                            b_, nb_  = NEG_INF, NEG_INF
                    dpbn  = lse(dpbn,out[t][labels.index("_")]+b_, out[t][labels.index("_")]+ nb_)
                    dpnbn = lse(dpnbn, p_t+nb_)
                    dummy_beam[c_t] = (dpbn,dpnbn)

        t_1 = t_b
        t_b = sorted(dummy_beam.items(),
                      key=lambda x:lse(*x[1]),
                      reverse=True)
        t_b = t_b[:beam_size]
        

    best = sorted([(10**lm.score(i[0],bos=True, eos=False),i[0]) for i in t_b],reverse=True)[0][1]
    return best

In [None]:
beam_txt = prefix_bsl(out,labels,prune=0.00, beam_size=50,alpha=1.4,beta=6,lm=lm_w)
print(beam_txt)
wer_(beam_txt,reference)/len(reference.strip().split(' '))*100

HBox(children=(FloatProgress(value=0.0, max=681.0), HTML(value='')))

In [30]:
print(reference)

AND CHARGED IF EVER HE MIGHT FIND SIR GAWAINE AND SIR UWAINE TO BRING THEM TO THE COURT AGAIN AND THEN WERE THEY ALL GLAD AND SO PRAYED THEY SIR MARHAUS TO RIDE WITH THEM TO THE KING'S COURT



In [31]:
0.4*(math.log(10**[i for i in lm_w.full_scores('and charged if ever'.upper(),eos=False,bos=False)][-1][0]))

-0.05182333219773941