# In [1]:
import openfst_python as fst
import math
import glob
from subprocess import check_call
from IPython.display import Image

class CreateWFST:
    
    def __init__(self, n_phone=3, n_word=1):
        """Load the lexicon and build the symbol tables used by every WFST builder.

        Side effects: reads 'lexicon.txt' from the current working directory and
        writes 'word_table.txt', 'phone_table.txt' and 'state_table.txt' there.

        Args:
            n_phone: number of HMM states per phone; also the number of
                '<phone>_<i>' symbols registered per phone in the state table.
            n_word: stored as ``num_per_word``; not used by the code in this class.
        """
        self.num_per_phone=n_phone
        self.num_per_word=n_word
        # -log(1): cost attached to the epsilon arc that returns to the start state.
        self.end_weight=fst.Weight('log',-math.log(1))
        self.lex=self.parse_lexicon('lexicon.txt')
        # Build the state table from n_phone rather than a hard-coded 3, so that a
        # non-default n_phone gets matching '<phone>_<i>' symbols.
        self.word_table, self.phone_table, self.state_table=self.generate_symbol_tables(self.lex,self.num_per_phone)
        self.word_table.write_text('word_table.txt')
        self.phone_table.write_text('phone_table.txt')
        self.state_table.write_text('state_table.txt')
        # Filled by generate_unigram / generate_bigram: *_dic hold fst weights,
        # *_dic_cp hold the raw counts.
        self.unigram_dic={}
        self.unigram_dic_cp={}
        self.bigram_dic={}
        self.bigram_dic_cp={}
        # Uniform word-entry weights: every lexicon entry gets -log(1/|lexicon|).
        self.st_w={}
        if self.lex:
            uniform_cost=-math.log(1/len(self.lex))
            for word in self.lex:
                self.st_w[word]=fst.Weight('log',uniform_cost)
            
    #1***: parse lexicon into a dic: { word: [phones] }
    def parse_lexicon(self,lex_file):
        lex={}
        with open(lex_file,'r') as f:
            for line in f:
                line=line.split()
                if line[0] in lex.keys():
                    line[0]=line[0]+'*'
                lex[line[0]]=line[1:]
        return lex

    #2***: generate fst. word_table; phone_table; state_table
    def generate_symbol_tables(self,lex_dic, n):
        """Create the word, phone and HMM-state symbol tables for a lexicon.

        '<eps>' is added to each table before anything else. Every lexicon key
        goes into the word table, every phone into the phone table, and for
        each phone the symbols '<phone>_1' .. '<phone>_n' go into the state table.

        Args:
            lex_dic: mapping of word -> list of phones.
            n: number of state symbols to register per phone.

        Returns:
            tuple ``(word_table, phone_table, state_table)``.
        """
        word_table, phone_table, state_table=(fst.SymbolTable() for _ in range(3))
        for table in (word_table, phone_table, state_table):
            table.add_symbol('<eps>')

        state_suffixes=range(1,n+1)
        for word in lex_dic:
            word_table.add_symbol(word)
            for phone in lex_dic[word]:
                phone_table.add_symbol(phone)
                for suffix in state_suffixes:
                    state_table.add_symbol("{}_{}".format(phone,suffix))

        return word_table, phone_table, state_table



    #3***: generate phone_n state wfst with final output phone
    def generate_phone_sequence_wfst(self, f, start_state, phone, n, sl_w, tr_w):
        """Append an n-state left-to-right chain for one phone to ``f``.

        For i = 1..n the current state receives a self-loop and a forward arc
        to a freshly added state, both with input label '<phone>_<i>'. Only the
        forward arc of the n-th step carries an output label (the phone);
        every other arc outputs epsilon (label 0).

        Args:
            f: the FST being extended in place.
            start_state: state the chain starts from.
            phone: phone symbol, looked up in the phone and state tables.
            n: number of states in the chain.
            sl_w: weight of each self-loop arc.
            tr_w: weight of each forward arc.

        Returns:
            the last state added (``start_state`` itself when n < 1).
        """
        phone_label=self.phone_table.find(phone)
        state=start_state
        for step in range(1,n+1):
            state_label=self.state_table.find('{}_{}'.format(phone,step))
            f.add_arc(state, fst.Arc(state_label, 0, sl_w, state))
            # The phone is emitted once, on leaving the last state of the chain.
            emitted=phone_label if step==n else 0
            successor=f.add_state()
            f.add_arc(state, fst.Arc(state_label, emitted, tr_w, successor))
            state=successor
        return state
    
    #4***: generate word recognition wfst with phone_n state for each phone,
    # phone as output of each word sequence endstate
    def generate_word_recognition_wfst(self, n, sl_w, tr_w):
        """Build a looped recogniser over all lexicon words that outputs phones.

        From the start state, one epsilon arc per lexicon entry (weighted by
        ``self.st_w[word]``) leads into that word's phone chains, built with
        ``generate_phone_sequence_wfst``. The end of each word is final and has
        an epsilon arc (weight ``self.end_weight``) back to the start state.

        Args:
            n: HMM states per phone.
            sl_w: self-loop weight.
            tr_w: forward-transition weight.

        Returns:
            the FST, with the state table as input symbols and the phone table
            as output symbols.
        """
        recogniser=fst.Fst('log')
        root=recogniser.add_state()
        recogniser.set_start(root)

        for word in self.lex:
            tail=recogniser.add_state()
            recogniser.add_arc(root, fst.Arc(0, 0, self.st_w[word], tail))
            for phone in self.lex[word]:
                tail=self.generate_phone_sequence_wfst(recogniser, tail, phone, n, sl_w, tr_w)
            recogniser.set_final(tail)
            # Loop back so that a sequence of words can be recognised.
            recogniser.add_arc(tail, fst.Arc(0, 0, self.end_weight, root))

        recogniser.set_input_symbols(self.state_table)
        recogniser.set_output_symbols(self.phone_table)
        return recogniser
    #5***: generate lexicon into wfst
    def generate_L_wfst(self):
        """Build the lexicon transducer mapping phone sequences to words.

        Each lexicon entry becomes a linear path from the start state with one
        arc per phone. The arc for the last phone outputs the word; the others
        output epsilon. The end of each path is final and has an epsilon arc
        back to the start state. All arcs are created with weight ``None``.

        Returns:
            the FST, with the phone table as input symbols and the word table
            as output symbols.
        """
        lexicon_fst=fst.Fst()
        root=lexicon_fst.add_state()
        lexicon_fst.set_start(root)

        for word, phones in self.lex.items():
            last_position=len(phones)-1
            tail=root
            for position, phone in enumerate(phones):
                successor=lexicon_fst.add_state()
                phone_label=self.phone_table.find(phone)
                word_label=self.word_table.find(word) if position==last_position else 0
                lexicon_fst.add_arc(tail, fst.Arc(phone_label, word_label, None, successor))
                tail=successor
            lexicon_fst.set_final(tail)
            lexicon_fst.add_arc(tail, fst.Arc(0,0,None,root))

        lexicon_fst.set_input_symbols(self.phone_table)
        lexicon_fst.set_output_symbols(self.word_table)
        return lexicon_fst
    
    #6***: generate linear phone sequence wfst
    def generate_linear_phone_wfst(self,seq_list):
        """Build a linear acceptor for a given sequence of phones.

        One arc is added per phone, with the phone's label on both the input
        and the output side and weight ``None``. The last state is final.

        Args:
            seq_list: iterable of phone symbols (strings found in the phone table).

        Returns:
            the FST, with the phone table as both input and output symbols.
        """
        chain=fst.Fst()
        tail=chain.add_state()
        chain.set_start(tail)

        for phone in seq_list:
            label=self.phone_table.find(phone)
            successor=chain.add_state()
            chain.add_arc(tail, fst.Arc(label,label,None,successor))
            tail=successor

        chain.set_final(tail)
        chain.set_input_symbols(self.phone_table)
        chain.set_output_symbols(self.phone_table)
        return chain
        
    #5***:generate word recognition wfst with phone_n state for each phone
    # word as final output of each word sequence endstate
    def generate_ow_word_recognition_wfst(self, n, sl_w, tr_w):
        """Build a looped word recogniser that outputs words instead of phones.

        From the start state, one epsilon arc per lexicon entry (weighted by
        ``self.st_w[word]``) leads into a chain of n states per phone. Each
        state has a self-loop and a forward arc, both labelled '<phone>_<i>'
        on the input side. The word label is emitted exactly once, on the
        forward arc leaving the last state of the last phone. The end of each
        word is final and has an epsilon arc back to the start state.

        Args:
            n: HMM states per phone.
            sl_w: self-loop weight.
            tr_w: forward-transition weight.

        Returns:
            the FST, with the state table as input symbols and the word table
            as output symbols.
        """
        fw=fst.Fst('log')
        start_state=fw.add_state()
        fw.set_start(start_state)

        for word, phones in self.lex.items():
            word_label=self.word_table.find(word)
            last_phone_index=len(phones)-1
            current_state=fw.add_state()
            fw.add_arc(start_state, fst.Arc(0, 0, self.st_w[word], current_state))
            for phone_index, phone in enumerate(phones):
                for i in range(1, n+1):
                    inlabel=self.state_table.find('{}_{}'.format(phone,i))
                    fw.add_arc(current_state, fst.Arc(inlabel, 0, sl_w, current_state))
                    # Decide by position, not by phone value: a word whose final
                    # phone also occurs earlier must still emit its label once.
                    if i==n and phone_index==last_phone_index:
                        outlabel=word_label
                    else:
                        outlabel=0
                    next_state=fw.add_state()
                    fw.add_arc(current_state, fst.Arc(inlabel, outlabel, tr_w, next_state))
                    current_state=next_state
            fw.set_final(current_state)
            fw.add_arc(current_state, fst.Arc(0, 0, self.end_weight, start_state))
        fw.set_input_symbols(self.state_table)
        fw.set_output_symbols(self.word_table)
        return fw
    
    #6***: generate word recognition wfst whose word-entry arcs carry unigram probabilities,
    # with n states for each phone
    # and the word as the output on the last arc of each word path
    def generate_unigram_word_recognition_wfst(self, n, sl_w, tr_w, unigram_dic):
        """Build a looped word recogniser whose word-entry arcs carry unigram weights.

        Same topology as the uniform word-output recogniser, except that the
        epsilon arc entering each word is weighted by ``unigram_dic``. Starred
        lexicon keys (alternative pronunciations such as 'the*') share the
        weight of the un-starred word, but keep their own output label.

        Args:
            n: HMM states per phone.
            sl_w: self-loop weight.
            tr_w: forward-transition weight.
            unigram_dic: mapping of word -> weight for the word-entry arc.

        Returns:
            the FST, with the state table as input symbols and the word table
            as output symbols.

        Raises:
            KeyError: if a lexicon word (stars stripped) is missing from
                ``unigram_dic``.
        """
        fu=fst.Fst('log')
        start_state=fu.add_state()
        fu.set_start(start_state)

        for word, phones in self.lex.items():
            word_label=self.word_table.find(word)
            last_phone_index=len(phones)-1
            current_state=fu.add_state()
            # rstrip is a no-op for un-starred words, so no branch is needed.
            fu.add_arc(start_state, fst.Arc(0, 0, unigram_dic[word.rstrip('*')], current_state))
            for phone_index, phone in enumerate(phones):
                for i in range(1, n+1):
                    inlabel=self.state_table.find('{}_{}'.format(phone,i))
                    fu.add_arc(current_state, fst.Arc(inlabel, 0, sl_w, current_state))
                    # Decide by position, not by phone value: a word whose final
                    # phone also occurs earlier must still emit its label once.
                    if i==n and phone_index==last_phone_index:
                        outlabel=word_label
                    else:
                        outlabel=0
                    next_state=fu.add_state()
                    fu.add_arc(current_state, fst.Arc(inlabel, outlabel, tr_w, next_state))
                    current_state=next_state
            fu.set_final(current_state)
            fu.add_arc(current_state, fst.Arc(0, 0, self.end_weight, start_state))
        fu.set_input_symbols(self.state_table)
        fu.set_output_symbols(self.word_table)
        return fu
    # generate unigram dic: {word: fst.Weight holding -log P(word)}, where P(word) = count(word) / total token count
    def generate_unigram(self, txt_file_path):
        """Estimate unigram weights from the first line of each transcription file.

        The counts are recomputed from scratch on every call, so calling the
        method again (for example with another glob pattern) is safe.

        Args:
            txt_file_path: glob pattern selecting the transcription files.

        Returns:
            ``self.unigram_dic``: mapping of word -> ``fst.Weight('log', -log P(word))``
            with P(word) = count(word) / total number of tokens. The raw counts
            are stored in ``self.unigram_dic_cp``.
        """
        from collections import Counter

        tokens=[]
        for txt_f in glob.glob(txt_file_path):
            with open(txt_f,'r') as f:
                # Only the first line of each file is used as the transcription.
                tokens+=f.readline().split()

        counts=Counter(tokens)
        total=len(tokens)
        self.unigram_dic_cp=dict(counts)
        # Build a fresh dict rather than updating the old one in place: after a
        # previous call the old dict holds Weight objects, not counts.
        self.unigram_dic={
            word: fst.Weight('log',-math.log(count/total))
            for word, count in counts.items()
        }
        return self.unigram_dic
    
    #7***: generate a 5-state silence wfst
    def add_silence_wfst(self, afst):
        """Add a five-state silence model to ``afst`` in place and return it.

        Registers the symbols 'sil_1' .. 'sil_5' in ``self.state_table``, adds
        five new states chained from state 0, makes the fifth one final and
        connects it back to state 0 with an epsilon arc. All arcs output
        epsilon (label 0). The input symbol table of ``afst`` is set to
        ``self.state_table``; the output symbol table is not touched.

        Args:
            afst: the FST to extend. State 0 is used as the start state
                (NOTE(review): this assumes ``afst`` already contains state 0 —
                confirm against callers).

        Returns:
            the same ``afst`` object.
        """
        n=5 # number of silence states added
        self.state_table.add_symbol('sil_1')
        self.state_table.add_symbol('sil_2')
        self.state_table.add_symbol('sil_3')
        self.state_table.add_symbol('sil_4')
        self.state_table.add_symbol('sil_5')
        # Arc costs as negative log probabilities: 0.3, 1/3 and 0.1.
        wei_next=fst.Weight('log',-math.log(0.3))
        wei_3=fst.Weight('log',-math.log(1/3))
        wei_back=fst.Weight('log',-math.log(0.1))
        
        # The silence chain hangs off state 0, which is (re)declared as the start state.
        start_state=0
        afst.set_start(start_state)
        st_record=[] # the five new states, in creation order
        current_state=start_state
        
        for i in range(1,n+1):
            next_state=afst.add_state()
            st_record.append(next_state)
            if i== 3 or i==4:
                # Third and fourth states: forward arc and self-loop labelled sil_i
                # (cost -log 0.3), plus a backward arc to the previous state
                # labelled sil_(i-1) (cost -log 0.1).
                in_label= self.state_table.find('{}_{}'.format('sil',i))
                afst.add_arc(current_state, fst.Arc(in_label,0,wei_next,next_state))
                afst.add_arc(next_state, fst.Arc(in_label,0,wei_next,next_state))
                in_label= self.state_table.find('{}_{}'.format('sil',i-1))
                afst.add_arc(next_state, fst.Arc(in_label,0,wei_back,current_state))
            else:
                # First, second and fifth states: forward arc and self-loop
                # labelled sil_i (cost -log 1/3), no backward arc.
                in_label= self.state_table.find('{}_{}'.format('sil',i))
                afst.add_arc(current_state, fst.Arc(in_label,0,wei_3,next_state))
                afst.add_arc(next_state, fst.Arc(in_label,0,wei_3,next_state))
            current_state=next_state
        # Skip arcs: each of the first four new states gets an arc labelled sil_5
        # straight to the fifth state (next_state still refers to it here).
        # NOTE(review): the fourth state already has a sil_5 arc to the fifth state
        # from the loop above, so it ends up with two parallel arcs of different
        # weight — confirm this is intended.
        for i in range(len(st_record)-1):
            i_label1=self.state_table.find('sil_5')
            if i ==2 or i==3:
                afst.add_arc(st_record[i], fst.Arc(i_label1,0,wei_next,next_state))
            else:
                afst.add_arc(st_record[i], fst.Arc(i_label1,0,wei_3,next_state))
        # The fifth state is final and loops back to the start with an epsilon arc.
        afst.set_final(next_state)
        afst.add_arc(next_state, fst.Arc(0,0,self.end_weight,start_state))
        afst.set_input_symbols(self.state_table)
        #afst.set_output_symbols(self.phone_table)
        return afst
    
    #8***:generate wfst with bigram language model weight
    def generate_bigram_word_recognition_wfst(self, n, sl_w, tr_w, unigram_dic, bigram_dic):
        """Build a word recogniser whose word-to-word transitions carry bigram weights.

        Topology: start --unigram(w1)--> path of w1 --bigram(w1, w2)--> path of
        w2 --epsilon--> start. The end of every second-word path is final. Each
        word path is a chain of n states per phone with a self-loop and a
        forward arc per state; the word label (the lexicon key, as in the other
        word recognisers) is emitted on the last forward arc.

        Starred lexicon keys (alternative pronunciations) are looked up in
        ``unigram_dic`` / ``bigram_dic`` with the stars stripped. Word pairs
        missing from ``bigram_dic`` get no connecting path, i.e. they are
        treated as having zero probability.

        NOTE(review): because only second-word ends are final, accepted word
        sequences have even length — confirm that this pairing is intended.

        Args:
            n: HMM states per phone.
            sl_w: self-loop weight.
            tr_w: forward-transition weight.
            unigram_dic: mapping of word -> weight for entering a first word.
            bigram_dic: mapping of (word, next_word) -> weight for entering
                the second word.

        Returns:
            the FST, with the state table as input symbols and the word table
            as output symbols.

        Raises:
            KeyError: if a lexicon word (stars stripped) is missing from
                ``unigram_dic``.
        """
        G = fst.Fst('log')
        start_state = G.add_state()
        G.set_start(start_state)

        def add_word_path(source_state, word, phones, entry_weight):
            # Add an epsilon entry arc plus the n-states-per-phone chain of one
            # word; return the last state of the chain.
            current_state=G.add_state()
            G.add_arc(source_state, fst.Arc(0, 0, entry_weight, current_state))
            word_label=self.word_table.find(word)
            last_phone_index=len(phones)-1
            for phone_index, phone in enumerate(phones):
                for i in range(1, n+1):
                    in_label=self.state_table.find('{}_{}'.format(phone, i))
                    G.add_arc(current_state, fst.Arc(in_label, 0, sl_w, current_state))
                    # Emit the word once, on the very last forward arc (decided
                    # by position so repeated phones cannot trigger it early).
                    if i==n and phone_index==last_phone_index:
                        out_label=word_label
                    else:
                        out_label=0
                    next_state=G.add_state()
                    G.add_arc(current_state, fst.Arc(in_label, out_label, tr_w, next_state))
                    current_state=next_state
            return current_state

        for wordi, phones in self.lex.items():
            base_wordi=wordi.rstrip('*')
            first_word_end=add_word_path(start_state, wordi, phones, unigram_dic[base_wordi])
            for wordii, phoneiis in self.lex.items():
                bigram_weight=bigram_dic.get((base_wordi, wordii.rstrip('*')))
                if bigram_weight is None:
                    # Unseen bigram: zero probability, so no path is created.
                    continue
                second_word_end=add_word_path(first_word_end, wordii, phoneiis, bigram_weight)
                # Every second-word path must be able to finish or loop back,
                # not only the last one built.
                G.set_final(second_word_end)
                G.add_arc(second_word_end, fst.Arc(0, 0, self.end_weight, start_state))
        G.set_input_symbols(self.state_table)
        G.set_output_symbols(self.word_table)

        return G
    # compute the bigram dic: {(w_i, w_i+1): fst.Weight holding -log P(w_i+1 | w_i)}, where P(w_i+1 | w_i) = count(w_i, w_i+1) / count(w_i)
    def generate_bigram(self, txt_file_path):
        """Estimate bigram weights from the first line of each transcription file.

        Bigrams are counted within each file only, so the last word of one
        transcription is never paired with the first word of another (glob
        order is arbitrary). The denominator count(w_i) is computed here from
        the same files, so the method does not depend on ``generate_unigram``
        having been called first. Counts are recomputed from scratch on every
        call.

        Args:
            txt_file_path: glob pattern selecting the transcription files.

        Returns:
            ``self.bigram_dic``: mapping of (w_i, w_i+1) ->
            ``fst.Weight('log', -log P(w_i+1 | w_i))`` with
            P(w_i+1 | w_i) = count(w_i, w_i+1) / count(w_i). The raw pair
            counts are stored in ``self.bigram_dic_cp``.
        """
        from collections import Counter

        word_counts=Counter()
        pair_counts=Counter()
        for txt_f in glob.glob(txt_file_path):
            with open(txt_f,'r') as f:
                # Only the first line of each file is used as the transcription.
                tokens=f.readline().split()
            word_counts.update(tokens)
            pair_counts.update(zip(tokens, tokens[1:]))

        self.bigram_dic_cp=dict(pair_counts)
        self.bigram_dic={
            pair: fst.Weight('log',-math.log(count / word_counts[pair[0]]))
            for pair, count in pair_counts.items()
        }
        return self.bigram_dic

# In [2]:
import observation_model
import math

class ViterbiDecoder:
    Infinity=1e10 # define a constant represent -log(0)
    
    def __init__(self, f, audio_file_name):
        """Create a decoder for the WFST ``f`` and prepare the trellis.

        Args:
            f: the recognition WFST to decode with.
            audio_file_name: path of the audio file to load into the
                observation model; nothing is loaded when it is falsy.
        """
        self.forward_count=0  # number of trellis updates performed so far
        self.f=f
        self.om=observation_model.ObservationModel()
        if audio_file_name:
            self.om.load_audio(audio_file_name)
        self.initialise_decoding()
    
    #11***: initialise variables, function calling for assigning initial values
    def initialise_decoding(self):
        
        self.V=[]
        self.B=[]
        self.W=[]
        
        for t in range(self.om.observation_length()+1):
            self.V.append([self.Infinity]*self.f.num_states()) # V[t: 0->initial, 1~T][state number: 0~N] => values: maxP() or min-Log(P)
            self.B.append([-1]*self.f.num_states()) # B[t: 0 will not use, 1~T][state number: 0~N] => values: current state j B[t][j]=last state number i of BEST PATH( min(V[t][i]+ arc[i->j]_weight) )
            self.W.append([[] for i in range(self.f.num_states())]) # W[t: 0 will not use, 1~T][state number: 0~N] => values: [output_label] if output <eps> -> [], and the initial all are [] empty list
        
        self.V[0][self.f.start()]=0.0 #V[0][0]=0.0
        self.traverse_epsilon_arcs(0) #actually, t never has value '0'; this will give initil values of V[0][...]
        # consider every path (arcm, arcn,...) ends into a state j which's V[0][j]= min(arcm_w, arcn_w, ...)
        # consider the normal (word) wfst, state 0 has k arcs means k words (phone sequence sub_wfst)
        # maybe one example: V[0][1, 9, 15, 32] (so k here is 4), and the 1~8(8 as end state) is a phone sequence wfst
        # after do "self.traverse_epsilon_arcs(0)", only V[0][0, 1, 9, 15, 32] has values and V[0][0]=0.0 because
        # each V[0][i] will compare and save the smallest one! V[0][0] will compare V[0][8]+ arc[8->0]_weight; & V[0][14]+ arc[14->0]_weight
        # so I think whatever that V of (state pointer to state 0) and plus a arc_weight will be greater than just 0.0 => so V[0][0]= 0.0
        # Reminder!!! the min path only compare and asign when arc_input_label is <eps>! " if arc.ilabel==0: "
    
    #12***: traverse arcs with <eps> on the input at time t
    def traverse_epsilon_arcs(self, t):
        
        states_to_traverse=list(self.f.states())
        
        while states_to_traverse:
            i=states_to_traverse.pop(0) # always pop out the first (state number) of current state list
                                        # start from the first state (maybe 0)
            
            if self.V[t][i]==self.Infinity:
                continue # 0 probability means a disavaliable path
            for arc in self.f.arcs(i):
                
                if arc.ilabel==0: # if this is <eps> transition
                    j=arc.nextstate
                    
                    if self.V[t][j] > self.V[t][i]+float(arc.weight):
                        # if this path (at time t, V[t][i]+arc.w less than the original saved cost 
                        # at time t state, V[t][j])
                        # we may want to save this lowwer path to update the probability V[][]
                        # also, we need update the previous state number B[][]
                        self.V[t][j]= self.V[t][i]+float(arc.weight)
                        self.B[t][j]=self.B[t][i]
                        self.forward_count+=1
                        
                        if arc.olabel != 0: # also, if there's output, we should record that.
                                            # <eps> input laways have multiple arcs with multiple outputs
                                            # if the outputs is <eps>, it just pass this list to state J;
                                            # if not add the output label in the tail of the list and pass to state J
                            self.W[t][j] = self.W[t][i] + [arc.olabel]
                        else:
                            self.W[t][j] = self.W[t][i]
                        
                        if j not in states_to_traverse:
                            states_to_traverse.append(j)
    
    #13***: forward_step:
    def forward_step(self, t):
        
        for i in self.f.states(): # start with the fist state (maybe 0)
            if not self.V[t-1][i] == self.Infinity: # because the following assign is on state j, if previous t-1 and last i cost V[t-1][i] is 0 probability, pass by
                
                for arc in self.f.arcs(i): # arcs: selfloop, next phone_n, other connection
                    if arc.ilabel !=0: # <eps> transitions don't emit observation, so the <eps> input arc will not calculate here
                        #but the <eps> path will be compared in traverse Fnc, then the end state k (of other paths & <eps> path) 's V[at every t][k] will have the min(-log)
                        
                        j=arc.nextstate # laways assign next state; the 0 state will be (compare) and assign by all the end states' back arc
                        t_p=float(arc.weight) # arc_weigh = transition probability aij (qi->qj)
                        e_p= -self.om.log_observation_probability( self.f.input_symbols().find(arc.ilabel), t) # emission probability: P(label | obsevation Ot) label: state represent <=> also equals to (arc points to state j)'s inlabel 
                        # class om has log_observation_probability object, which input ('string', 'int')
                        p_j = t_p + e_p + self.V[t-1][i] # p_j also means temp V[t][j], the path cost in state j at time t = transition_P + emission_P + V[t-1][i] (becasue last time t-1 state i --arc--> current state j) 
                        
                        if p_j < self.V[t][j]: # Viterbi key not sum(path1 + path2), but compare and save min(paths) at each state
                            self.V[t][j]=p_j # this assign actually contains V[t-1][last state i], so imagine all the (future t)'s cost V will contain previous (t-1)'s cost; so this is how propagation in HMM (from one state to next time's state)
                            self.B[t][j]=i # So we need the best path from: from lower arc's parent state i
                            self.forward_count+=1
                            
                            if arc.olabel !=0: # example n=3, so only the phone_3 state has output: W[any t][phone_3]=[phone label]; else: W[any t][phone_1,2]=[]
                                self.W[t][j]=[arc.olabel]
                            else:
                                self.W[t][j]=[]
    #14***: finalise_decoding sets V[T] of every non-final state to Infinity so that backtrace's min(V[-1]) only considers final states.
    # float(f.final(state)) gives the final weight of a state, which is added to V[T][state]; for a non-final state it is math.inf.
    def finalise_decoding(self):
        for state in self.f.states():
            final_weight=float(self.f.final(state)) #if state isn't final state, output inf
            if self.V[-1][state]!=self.Infinity:
                if final_weight==math.inf:
                    self.V[-1][state]=self.Infinity # inf means the state is not a end_state, so there's no emit
                    #so set the cost to infinity, then we won't compare in backtrce
                else:
                    self.V[-1][state]=self.V[-1][state]+final_weight #if the end state has a weight, it will also be added in V
        finished=[x for x in self.V[-1] if x < self.Infinity]
        if not finished: #means there's no end_state in the V[T] <=> all state_weight is inf (as initial)
            print("No path got to the end of the observations.")
    
    #15***: decode main
    def decode(self):
        self.initialise_decoding()
        t=1
        while t<= self.om.observation_length():
            self.forward_step(t)
            self.traverse_epsilon_arcs(t)
            t+=1
        self.finalise_decoding()
    
    #16***: backtrace through the Viterbi trellis to recover the minimum-cost path
    def backtrace(self):
        best_final_state=self.V[-1].index(min(self.V[-1])) # this is min cost state index
        best_state_sequence=[best_final_state]
        best_out_sequence=[]
        
        t=self.om.observation_length()
        j=best_final_state
        
        while t>=0: # we can print W[t][j] & j & t to understand
            i=self.B[t][j] # as we save in B[t][j] is always the last best state i; here for intuition: there may be several same state here becasue they were in the selfloop, so the previous state is itself.
            best_state_sequence.append(i) #here add the connect state number of best path to the list behind, so it needs reverse.
            best_out_sequence=self.W[t][j]+best_out_sequence #each output will add in the list front, as it's from T to 1, so the final list will be correct sequence.
            j=i
            t-=1
        
        best_state_sequence.reverse() #from 'T to 1' to '1 to T'
        best_out_sequence=' '.join([self.f.output_symbols().find(label) for label in best_out_sequence]) #f.output_symbols() as set before is list of words in wfst
        
        return (best_state_sequence, best_out_sequence)
    
    #17***: pruning method decode & forward
    def decode_pruning(self,p_td):
        self.initialise_decoding()
        t=1
        while t<= self.om.observation_length():
            self.forward_step_pruning(t,p_td)
            self.traverse_epsilon_arcs(t)
            t+=1
        self.finalise_decoding()
    def forward_step_pruning(self, t, pruning_threshold):
        for i in self.f.states(): # start with the fist state (maybe 0)
            if not self.V[t-1][i] == self.Infinity: # because the following assign is on state j, if previous t-1 and last i cost V[t-1][i] is 0 probability, pass by
                
                if self.V[t-1][i]< pruning_threshold: # the path cost is P*P*P... for eahc time, when -log(): it should be v1+v2+v3+...vT
                    for arc in self.f.arcs(i): # arcs: selfloop, next phone_n, other connection
                        if arc.ilabel !=0: # <eps> transitions don't emit observation, so the <eps> input arc will not calculate here
                            #but the <eps> path will be compared in traverse Fnc, then the end state k (of other paths & <eps> path) 's V[at every t][k] will have the min(-log)

                            j=arc.nextstate # laways assign next state; the 0 state will be (compare) and assign by all the end states' back arc
                            t_p=float(arc.weight) # arc_weigh = transition probability aij (qi->qj)
                            e_p= -self.om.log_observation_probability( self.f.input_symbols().find(arc.ilabel), t) # emission probability: P(label | obsevation Ot) label: state represent <=> also equals to (arc points to state j)'s inlabel 
                            # class om has log_observation_probability object, which input ('string', 'int')
                            p_j = t_p + e_p + self.V[t-1][i] # p_j also means temp V[t][j], the path cost in state j at time t = transition_P + emission_P + V[t-1][i] (becasue last time t-1 state i --arc--> current state j) 

                            if p_j < self.V[t][j]: # Viterbi key not sum(path1 + path2), but compare and save min(paths) at each state
                                self.V[t][j]=p_j # this assign actually contains V[t-1][last state i], so imagine all the (future t)'s cost V will contain previous (t-1)'s cost; so this is how propagation in HMM (from one state to next time's state)
                                self.B[t][j]=i # So we need the best path from: from lower arc's parent state i
                                self.forward_count+=1

                                if arc.olabel !=0: # example n=3, so only the phone_3 state has output: W[any t][phone_3]=[phone label]; else: W[any t][phone_1,2]=[]
                                    self.W[t][j]=[arc.olabel]
                                else:
                                    self.W[t][j]=[]
    
    #17***: Beamsearch method decode & forward
    def decode_beamsearch(self,b_td):
        self.initialise_decoding()
        t=1
        while t<= self.om.observation_length():
            self.forward_step_beam(t,b_td)
            self.traverse_epsilon_arcs(t)
            t+=1
        self.finalise_decoding()
    def forward_step_beam(self, t, beam_threshold):
        Vt=[self.V[t-1][i] for i in self.f.states()]#collect all of the probs for states in timestep t-1
        beam_s = []
        for index, value in sorted(enumerate(Vt), key=lambda x:x[1]):
            beam_s.append(index)
        beam_s=beam_s[:beam_threshold]
        for i in beam_s: # start with the fist state (maybe 0)
            if not self.V[t-1][i] == self.Infinity: # because the following assign is on state j, if previous t-1 and last i cost V[t-1][i] is 0 probability, pass by
                
                for arc in self.f.arcs(i): # arcs: selfloop, next phone_n, other connection
                    if arc.ilabel !=0: # <eps> transitions don't emit observation, so the <eps> input arc will not calculate here
                        #but the <eps> path will be compared in traverse Fnc, then the end state k (of other paths & <eps> path) 's V[at every t][k] will have the min(-log)
                        
                        j=arc.nextstate # laways assign next state; the 0 state will be (compare) and assign by all the end states' back arc
                        t_p=float(arc.weight) # arc_weigh = transition probability aij (qi->qj)
                        e_p= -self.om.log_observation_probability( self.f.input_symbols().find(arc.ilabel), t) # emission probability: P(label | obsevation Ot) label: state represent <=> also equals to (arc points to state j)'s inlabel 
                        # class om has log_observation_probability object, which input ('string', 'int')
                        p_j = t_p + e_p + self.V[t-1][i] # p_j also means temp V[t][j], the path cost in state j at time t = transition_P + emission_P + V[t-1][i] (becasue last time t-1 state i --arc--> current state j) 
                        
                        if p_j < self.V[t][j]: # Viterbi key not sum(path1 + path2), but compare and save min(paths) at each state
                            self.V[t][j]=p_j # this assign actually contains V[t-1][last state i], so imagine all the (future t)'s cost V will contain previous (t-1)'s cost; so this is how propagation in HMM (from one state to next time's state)
                            self.B[t][j]=i # So we need the best path from: from lower arc's parent state i
                            self.forward_count+=1
                            
                            if arc.olabel !=0: # example n=3, so only the phone_3 state has output: W[any t][phone_3]=[phone label]; else: W[any t][phone_1,2]=[]
                                self.W[t][j]=[arc.olabel]
                            else:
                                self.W[t][j]=[]

# In [3]:
import glob
import os
import wer
import observation_model
import timeit
import numpy

txt_files = '/group/teaching/asr/labs/recordings/*.txt'
wav_files = '/group/teaching/asr/labs/recordings/*.wav'

#21***: Fnc for get standard transcription from wav matched file
def get_transcription(f_wav):
    """Return the reference transcription that belongs to an audio file.

    The transcription is the first line, stripped of surrounding whitespace,
    of the file that has the same path as ``f_wav`` with the extension
    replaced by '.txt'.
    """
    stem, _extension=os.path.splitext(f_wav)
    with open(stem+'.txt', 'r') as handle:
        return handle.readline().strip()

#22***: Fnc for experiment set up
def experiment_set_up(mode, s_w, t_w, pru_td):
    """Build the recognition WFST for ``mode``, decode every recording, print a report.

    For each file matched by ``wav_files`` the function times the Viterbi
    decode and the backtrace, scores the hypothesis against the reference
    transcription with ``wer.compute_alignment_errors`` and prints the
    per-file figures, followed by totals and averages over all files.

    Supported modes:
        'b_p'      phone-output graph; the recognised phone sequence is
                   composed with the lexicon transducer to obtain words.
        'b_w'      word-output graph, plain decode.
        'uni_w'    graph built with unigram statistics from ``txt_files``.
        'pru_b_w'  word-output graph, ``decode_pruning`` with ``pru_td``.
        'beam_b_w' word-output graph, ``decode_beamsearch`` with ``pru_td``.
        'sil_b_w'  word-output graph passed through ``add_silence_wfst``.
        'tree_b_w' word-output graph passed through ``fst.determinize``.
        'big_b_w'  word-output graph composed with the bigram graph.

    Args:
        mode: One of the mode strings above. Any other value builds the
            ``CreateWFST`` helper and returns without decoding (unchanged
            from the previous behaviour).
        s_w: Self-loop probability; converted to the weight ``-log(s_w)``.
        t_w: Transition probability; converted to the weight ``-log(t_w)``.
        pru_td: Threshold for 'pru_b_w' (converted through
            ``fst.Weight('log', ...)`` to float) or for 'beam_b_w'
            (converted with ``int``). Ignored by every other mode.

    Returns:
        None. All results are printed.
    """
    wfst = CreateWFST()
    lexicon_fst = wfst.generate_L_wfst()
    selfloop_weight = fst.Weight('log', -math.log(s_w))
    transist_weight = fst.Weight('log', -math.log(t_w))

    # --- how to run the forward pass for this mode -------------------------
    if mode == 'pru_b_w':
        pruning_threshold = float(fst.Weight('log', pru_td))

        def decode_step(decoder):
            decoder.decode_pruning(pruning_threshold)
    elif mode == 'beam_b_w':
        beam_threshold = int(pru_td)

        def decode_step(decoder):
            decoder.decode_beamsearch(beam_threshold)
    else:
        def decode_step(decoder):
            decoder.decode()

    # --- how to turn the decoded lattice into a word string ----------------
    if mode == 'b_p':
        def recognise_step(decoder):
            _path_states, phone_seq = decoder.backtrace()
            return _phones_to_words(wfst, lexicon_fst, phone_seq)
    else:
        def recognise_step(decoder):
            _path_states, word_seq = decoder.backtrace()
            return word_seq

    # --- which graph to decode with -----------------------------------------
    # NOTE: every word-output mode prints the same "Baseline Recognition"
    # banner; the text is kept as-is so existing logs stay comparable.
    banner = "Mode: Baseline Recognition, WFST set only word label output on its arc path"
    if mode == 'b_p':
        banner = "Mode: Baseline Recognition, Use Compose To Transfer Recognized Phone Sequecen Into Word Sequence"
        graph = wfst.generate_word_recognition_wfst(3, selfloop_weight, transist_weight)
    elif mode in ('b_w', 'pru_b_w', 'beam_b_w'):
        graph = wfst.generate_ow_word_recognition_wfst(3, selfloop_weight, transist_weight)
    elif mode == 'uni_w':
        unigrams = wfst.generate_unigram(txt_files)
        graph = wfst.generate_unigram_word_recognition_wfst(
            3, selfloop_weight, transist_weight, unigrams)
    elif mode == 'sil_b_w':
        graph = wfst.generate_ow_word_recognition_wfst(3, selfloop_weight, transist_weight)
        graph = wfst.add_silence_wfst(graph)
    elif mode == 'tree_b_w':
        graph = wfst.generate_ow_word_recognition_wfst(3, selfloop_weight, transist_weight)
        graph = fst.determinize(graph)
    elif mode == 'big_b_w':
        word_graph = wfst.generate_ow_word_recognition_wfst(3, selfloop_weight, transist_weight)
        unigrams = wfst.generate_unigram(txt_files)
        bigrams = wfst.generate_bigram(txt_files)
        # Fix: this used to pass `uni_dic_3`, a name that only exists in the
        # 'uni_w' branch, so 'big_b_w' always failed with UnboundLocalError.
        grammar_graph = wfst.generate_bigram_word_recognition_wfst(
            3, selfloop_weight, transist_weight, unigrams, bigrams)
        # Sort the left operand on output labels before composing.
        word_graph.arcsort(sort_type='olabel')
        graph = fst.compose(word_graph, grammar_graph)
    else:
        # Unknown mode: nothing to decode.
        return

    _run_recognition_experiment(
        graph, banner, s_w, t_w, selfloop_weight, transist_weight,
        decode_step, recognise_step)
    return


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    if not denominator:
        return float('nan')
    return numerator / denominator


def _count_states_and_arcs(graph):
    """Return (number of states, number of arcs) of an FST by walking it."""
    n_states = 0
    n_arcs = 0
    for state in graph.states():
        n_states += 1
        n_arcs += sum(1 for _arc in graph.arcs(state))
    return n_states, n_arcs


def _phones_to_words(wfst, lexicon_fst, phone_seq):
    """Map a space-separated phone string to a word string via the lexicon FST.

    Builds a linear phone acceptor, composes it with ``lexicon_fst``, keeps
    the output side, removes epsilons and joins the remaining output labels
    (visited state by state) with single spaces.
    """
    phone_fst = wfst.generate_linear_phone_wfst(phone_seq.split())  # expects a list
    phone_fst.arcsort(sort_type='ilabel')
    words_fst = fst.compose(phone_fst, lexicon_fst)
    # A real boolean: the former string 'True' only worked by being truthy.
    words_fst.project(project_output=True)
    words_fst.rmepsilon()
    words = []
    for state in words_fst.states():
        for arc in words_fst.arcs(state):
            words.append(words_fst.output_symbols().find(arc.olabel))
    return ' '.join(words)


def _run_recognition_experiment(graph, banner, s_w, t_w, selfloop_weight,
                                transist_weight, decode_step, recognise_step):
    """Decode every recording with ``graph`` and print per-file and total results.

    Args:
        graph: The recognition FST handed to ``ViterbiDecoder``.
        banner: First line of the report (mode description).
        s_w, t_w: Raw probabilities, echoed next to their log weights.
        selfloop_weight, transist_weight: The corresponding FST weights.
        decode_step: Callable taking a decoder and running its forward pass.
        recognise_step: Callable taking a decoded decoder and returning the
            recognised word string (may still contain '*' markers).
    """
    print(banner)
    print("Phone per state: 3\nStart Weight: -log(1.0)\nEnd Weight: initial(0.0)")
    print("Selfloop Weight:", selfloop_weight, "-log({})".format(s_w))
    print("Transist Weight:", transist_weight, "-log({})".format(t_w))
    n_states, n_arcs = _count_states_and_arcs(graph)
    print("Memory Cost: States Count: %d  Arcs Count:%d" % (n_states, n_arcs))

    decode_times = []
    backtrace_times = []
    forward_counts = []
    error_totals = []
    reference_lengths = []

    for wav in glob.glob(wav_files):
        decoder = ViterbiDecoder(graph, wav)

        started = timeit.default_timer()
        decode_step(decoder)
        decode_times.append(timeit.default_timer() - started)

        started = timeit.default_timer()
        hypothesis = recognise_step(decoder)
        backtrace_times.append(timeit.default_timer() - started)
        # Only meaningful once the forward pass has run.
        forward_counts.append(decoder.forward_count)

        # Strip the '*' that parse_lexicon appends to repeated lexicon entries.
        hypothesis = hypothesis.replace('*', '')
        reference = get_transcription(wav)
        error_counts = wer.compute_alignment_errors(reference, hypothesis)
        n_errors = sum(error_counts)
        n_reference_words = len(reference.split())
        error_totals.append(n_errors)
        reference_lengths.append(n_reference_words)

        print("\n\n*****************************************")
        print("File:", wav, "\nSpeed of Viterbi Decoder:", decode_times[-1], "s",
              " Speed of Backtrace:", backtrace_times[-1], "s",
              " Forward Count:", forward_counts[-1])
        print("(Substitutions, Deletions, Insertions) N:", error_counts, n_reference_words)
        # An empty reference would otherwise abort the whole run.
        print("WER:", _safe_ratio(n_errors, n_reference_words))
        print("Trascription:", reference)
        print("Recognition:", hypothesis)

    print("Average")
    print("Total Errors:", sum(error_totals))
    print("Total Transcription Words:", sum(reference_lengths))
    print("Total WER:", _safe_ratio(sum(error_totals), sum(reference_lengths)))
    print("The number of forward computations per wav: ",
          _safe_ratio(sum(forward_counts), len(forward_counts)))
    print("Average decode time:", numpy.mean(decode_times))
    print("Average backtrace time:", numpy.mean(backtrace_times))