In [1]:
import numpy as np
from tqdm import tqdm
import json, pickle, os, string, tqdm, kenlm
from collections import defaultdict, Counter
from itertools import groupby
import Levenshtein as Lev

In [2]:
os.chdir("/home/hemant/decode_humonics/")

In [3]:
true_txt = 'AT FIRST THIS SORT OF THING IS UNPLEASANT ENOUGH IT TOUCHES ONE\'S SENSE OF HONOUR' # Example text

In [4]:
#s1 = True text
#s2 = predicted text

def wer_(s1, s2):
    """
    Computes the Word Error Rate, defined as the edit distance between the
    two provided sentences after tokenizing to words.
    Arguments:
        s1 (string): space-separated sentence
        s2 (string): space-separated sentence
    """

    # build mapping of words to integers
    b = set(s1.split() + s2.split())
    word2char = dict(zip(b, range(len(b))))

    # map the words to a char array (Levenshtein packages only accepts
    # strings)
    w1 = [chr(word2char[w]) for w in s1.split()]
    w2 = [chr(word2char[w]) for w in s2.split()]
    
    return Lev.distance(''.join(w1), ''.join(w2))

def cer_(s1, s2):
    """
    Computes the Character Error Rate, defined as the edit distance.

    Arguments:
        s1 (string): space-separated sentence
        s2 (string): space-separated sentence
    """
    s1, s2, = s1.replace(' ', ''), s2.replace(' ', '')

    return Lev.distance(s1, s2)




#When using the above implementation, use the code belove to calculate the wer in percentatge: 
#pred = list of ouput prediction of model (it is the text) # example [" MY NAME IS HEMANT", " I AM A GOD"]
# total_wer = 0
# for x in range(len(pred)):
#     transcript, reference = data_[x][1], pred[x]
#     wer_inst = wer(transcript, reference)
#     total_wer += float(wer_inst)
# print("WER is : ",total_wer/len(pred),"%")


In [5]:
with open("out.txt","rb") as f:
    out = pickle.load(f)[0] # Out is a 2d numpy array with shape(number_of_steps,labels_len)
with open("labels.json") as label_file:
    labels = str(''.join(json.load(label_file))) # "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ "

In [6]:
def ctc_best_path(out,labels):
    "implements best path decoding as shown by Graves"
    out = [labels[i] for i in np.argmax(out, axis=1) if i!=labels[-1]]
    o = ""
    for i,j in groupby(out):
        o = o + i
    return o.replace("_","")

In [7]:
gred_txt = ctc_best_path(out,labels)

### WORD LM Implementation

In [8]:
lm_w = kenlm.LanguageModel('/home/hemant/deep/lm/libri_lm/3-gram.binary')

In [9]:
def sort_beam(ptot,k):
    if len(ptot) < k:
        return [i for i in ptot.keys()]
    else:
        dict_ = sorted(dict((v,k) for k,v in ptot.items()).items(),reverse=True)[:k]
        return [i[1] for i in dict_]

#using WORD LM
def ctc_beam_search(out,labels, prune=0.0001, k=20, lm=None,alpha=0.3,beta=12):
    "implements CTC Prefix Search Decoding Algo13.043478260869565%'rithm as shown by Graves"
    '''
    out = ctc output
    labels = string of labels
    prune = prune the ctc output
    k=beam-width
    lm=word age model used
    alpha,beta = hyper-parameters
    '''

    bc_i = 0 # blank/special charatcter index 
    F = out.shape[1]
    out = np.vstack((np.zeros(F), out))
    steps = out.shape[0]
    
    pb, pnb = defaultdict(Counter), defaultdict(Counter)
    pb[0][''], pnb[0][''] = 1, 0
    prev_beams = ['']
    for t in range(1,steps):
        pruned_alphabet = [labels[i] for i in np.where(out[t] > prune)[0]]
        for b in prev_beams:
            for c_t in pruned_alphabet:
                index = labels.index(c_t)
                #Collapsing case (copy case as the last character in the beam)
                if c_t == "_": #Extending with a blank
                    pb[t][b] += out[t][index]*(pb[t-1][b] + pnb[t-1][b])   
                else:
                    i_plus = b + c_t
                    if len(b) > 0 and c_t == b[-1]: #Extending with the same character as the last one
                        pnb[t][b] += out[t][index]*pnb[t-1][b]
                        pnb[t][i_plus] += out[t][index]*pb[t-1][b]
                    #expanding the beam (extend case as the last character is different)
                    elif c_t == " " and len(b.replace(' ', '')) > 0 : # LM constraints
                        prob = [i[0] for i in lm.full_scores(i_plus,eos=False,bos=False)][-1]
                        lm_p = (10**prob)**alpha
                        pnb[t][i_plus] += lm_p*out[t][index]*(pb[t-1][b] + pnb[t-1][b])
                    else:
                        pnb[t][i_plus] += out[t][index]*(pb[t-1][b] + pnb[t-1][b])
                        
                    if i_plus not in prev_beams:
                        pb[t][i_plus] += out[t][index] * (pb[t - 1][i_plus] + pnb[t - 1][i_plus])
                        pnb[t][i_plus] += out[t][index] * pnb[t - 1][i_plus]

        ptot = pb[t] + pnb[t]
        for i in ptot.keys():
            ptot[i] = ptot[i]*(len(i)+1)**beta
        prev_beams = sort_beam(ptot,k)
    return prev_beams[0]

In [10]:
beam_txt=ctc_beam_search(out,labels,0.001,k=10,lm=lm_w)

### CHARACTER LM Implementation

In [11]:
lm_c = kenlm.LanguageModel('/home/hemant/decode_humonics/3_char_gram.arpa')

In [12]:
def sort_beam(ptot,k):
    if len(ptot) < k:
        return [i for i in ptot.keys()]
    else:
        dict_ = sorted(dict((v,k) for k,v in ptot.items()).items(),reverse=True)[:k]
        return [i[1] for i in dict_]

#using CHARACTER LM
def ctc_beam_search_clm(out,labels, prune=0.001, k=20, lm=None,alpha=0.3,beta=12):
    "implements CTC Prefix Search Decoding Algorithm as shown by Graves"
    
    '''
    out = ctc output
    labels = string of labels
    prune = prune the ctc output
    k=beam-width
    lm=charac language model used
    alpha,beta = hyper-parameters
    '''
    
    bc_i = 0 # blank/special charatcter index 
    F = out.shape[1]
    out = np.vstack((np.zeros(F), out))
    steps = out.shape[0]
    
    pb, pnb = defaultdict(Counter), defaultdict(Counter)
    pb[0][''], pnb[0][''] = 1, 0
    prev_beams = ['']
    for t in range(1,steps):
        pruned_alphabet = [labels[i] for i in np.where(out[t] > prune)[0]]
        for b in prev_beams:
            for c_t in pruned_alphabet:
                index = labels.index(c_t)
                #Collapsing case (copy case as the last character in the beam)
                if c_t == "_": #Extending with a blank
                    pb[t][b] += out[t][index]*(pb[t-1][b] + pnb[t-1][b])  
                else:  # LM constraints
                    i_plus = b + c_t
                     #Extending with the same character as the last one
                    if len(b) > 0 and c_t == b[-1]:
                        pnb[t][b] += out[t][index]*pnb[t-1][b]
                        pnb[t][i_plus] += out[t][index]*pb[t-1][b]
                    #expanding the beam (extend case as the last character is different)
                    elif len(b.replace(' ', '')) > 0 :
                        prob = [i[0] for i in lm.full_scores(i_plus,eos=False,bos=False)][-1]
                        lm_p = 1#(10**prob)**alpha
                        pnb[t][i_plus] += lm_p*out[t][index]*(pb[t-1][b] + pnb[t-1][b])
                    else:
                        pnb[t][i_plus] += out[t][index]*(pb[t-1][b] + pnb[t-1][b])
                        
                    if i_plus not in prev_beams:
                        pb[t][i_plus] += out[t][index] * (pb[t - 1][i_plus] + pnb[t - 1][i_plus])
                        pnb[t][i_plus] += out[t][index] * pnb[t - 1][i_plus]
                        
        ptot = pb[t] + pnb[t]
        for i in ptot.keys():
            ptot[i] = ptot[i]*(len(i)+1)**beta
        prev_beams = sort_beam(ptot,k)
    return prev_beams[0]

In [13]:
beam_txt=ctc_beam_search_clm(out,labels,0.001,k=10,lm=lm_c)

# IMPLEMANTATION

In [14]:
import os
os.chdir("/home/hemant/deep/")

import pickle
import json
import os.path
from data.data_loader import SpectrogramParser
import torch
from decoder import GreedyDecoder
import argparse

from tqdm import tqdm
import warnings

from opts import add_decoder_args, add_inference_args
from utils import load_model

In [15]:
with open("/home/hemant/decode_humonics/updatedAfricanNames/wav_utterance.txt", "r") as f:
    data_ = f.readlines()
    
data_ = [[i.split()[0], " ".join(i.split()[1:])] for i in data_]

In [16]:
device = torch.device("cuda")
model = load_model(device, "/home/hemant/decode_humonics/updatedAfricanNames/deepspeech_final.pth",False)
spect_parser = SpectrogramParser(model.audio_conf, normalize=True)

In [17]:
lm_w = kenlm.LanguageModel("/home/hemant/decode_humonics/updatedAfricanNames/3_gram.arpa")
lm_c= kenlm.LanguageModel('/home/hemant/decode_humonics/3_char_gram.arpa')

In [None]:
total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
for i in tqdm(data_[:1000]):
    audio_path = "/home/hemant/decode_humonics/updatedNames/" + i[0] 
    
    try:
        spect = spect_parser.parse_audio(audio_path).contiguous()
        spect = spect.view(1, 1, spect.size(0), spect.size(1))
        spect = spect.to(device)

        input_sizes = torch.IntTensor([spect.size(3)]).int()
        out, output_sizes = model(spect, input_sizes)
        out = out.cpu().detach().numpy()[0]
#         out = ctc_best_path(out,labels)
        out = ctc_beam_search(out,labels,0.00000001,k=100,lm=lm_w)
#         out = ctc_beam_search_clm(out,labels,0.001,k=10,lm=lm_c)
        transcript, reference = out, i[1]
        wer_inst = wer_(transcript, reference)
        cer_inst = cer_(transcript, reference)
        total_wer += wer_inst
        total_cer += cer_inst
        num_tokens += len(reference.split())
        num_chars += len(reference)

        
    except: pass

 37%|███▋      | 370/1000 [00:49<01:21,  7.73it/s]

In [19]:
wer = float(total_wer) / num_tokens
cer = float(total_cer) / num_chars
wer,cer

(0.16761827079934746, 0.059149722735674676)

In [20]:
0.16435562805872758, 0.05772139136279617
0.2956769983686786, 0.09973113762392875
0.33849918433931486, 0.11502268526298101
(0.26101141924959215, 0.09216938329692488)
(0.200652528548124, 0.0711645101663586)
(0.17169657422512236, 0.06015795664594186)
0.16761827079934746, 0.059149722735674676

(0.17169657422512236, 0.06015795664594186)

In [21]:
#greddy lm = 2.97
#word lm = 4.33
#character lm = 4.38