# ASR Assignment 2024-25

This notebook has been provided as a template to get you started on the assignment.  Feel free to use it for your development, or do your development directly in Python.

You can find a full description of the assignment [here](https://www.inf.ed.ac.uk/teaching/courses/asr/coursework-2025.html).

You are provided with two Python modules `observation_model.py` and `wer.py`.  The first was described in [Lab 3](https://github.com/yiwang454/asr_labs/blob/main/asr_lab3_4.ipynb).  The second can be used to compute the number of substitution, deletion and insertion errors between ASR output and a reference text.

It can be used as follows:

```python
import wer

my_reference = 'A B C'
my_output = 'A C C D'

wer.compute_alignment_errors(my_reference, my_output)
```

This produces a tuple $(s,d,i)$ giving counts of substitution, deletion and insertion errors respectively; in this example, (1, 0, 1).  The function accepts either two strings, as in the example above, or two lists.  Matching is case sensitive.
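
As a minimal sketch of how these counts turn into a word error rate, the snippet below divides the total number of errors by the number of words in the reference. The helper name `word_error_rate` is an assumption made for illustration and is not part of `wer.py`; the counts are written out by hand rather than obtained from `wer.compute_alignment_errors`.

```python
def word_error_rate(counts, reference):
    """Return WER as a percentage from an (s, d, i) tuple and a reference string.

    Hypothetical helper for illustration, not part of wer.py.
    """
    s, d, i = counts
    n_ref = len(reference.split())  # number of words in the reference
    return 100.0 * (s + d + i) / n_ref


# counts from the example above: 1 substitution, 0 deletions, 1 insertion
rate = word_error_rate((1, 0, 1), 'A B C')
print(round(rate, 2))
```

Because insertions count as errors but do not add to the reference length, a WER computed this way can exceed 100%.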

## Template code

Assuming that you have already made a function to generate a WFST, `create_wfst()`, and a decoder class, `MyViterbiDecoder`, you can perform recognition on all the audio files as follows:


In [7]:
import glob
import os
import wer
import observation_model
import openfst_python as fst
import math
import time
from collections import Counter

# ... (add your code to create WFSTs and Viterbi Decoder)

In [2]:
class MyViterbiDecoder:
    
    NLL_ZERO = 1e10  # define a constant representing -log(0).  This is really infinite, but approximate
                     # it here with a very large number
    
    def __init__(self, f, audio_file_name):
        """Set up the decoder class with an audio file and WFST f
        """
        self.om = observation_model.ObservationModel()
        self.f = f
        
        if audio_file_name:
            self.om.load_audio(audio_file_name)
        else:
            self.om.load_dummy_audio()
        
        self.initialise_decoding()

        
    def initialise_decoding(self):
        """set up the values for V_j(0) (as negative log-likelihoods)
        
        """
        
        self.V = []   # stores likelihood along best path reaching state j
        self.B = []   # stores identity of best previous state reaching state j
        self.W = []   # stores output labels sequence along arc reaching j - this removes need for 
                      # extra code to read the output sequence along the best path
        
        for t in range(self.om.observation_length()+1):
            self.V.append([self.NLL_ZERO]*self.f.num_states())
            self.B.append([-1]*self.f.num_states())
            self.W.append([[] for i in range(self.f.num_states())])  # a comprehension is needed: multiplying [[]] would repeat one shared list, not make independent ones
        
        # The above code means that self.V[t][j] for t = 0, ... T gives the Viterbi cost
        # of state j, time t (in negative log-likelihood form)
        # Initialising the costs to NLL_ZERO effectively means zero probability    
        
        # give the WFST start state a probability of 1.0   (NLL = 0.0)
        self.V[0][self.f.start()] = 0.0
        
        # some WFSTs might have arcs with epsilon on the input (you might have already created 
        # examples of these in earlier labs) these correspond to non-emitting states, 
        # which means that we need to process them without stepping forward in time.  
        # Don't worry too much about this!  
        self.traverse_epsilon_arcs(0)        
        
    def traverse_epsilon_arcs(self, t):
        """Traverse arcs with <eps> on the input at time t
        
        These correspond to transitions that don't emit an observation
        
        We've implemented this function for you as it's slightly trickier than
        the normal case.  You might like to look at it to see what's going on, but
        don't worry if you can't fully follow it.
        
        """
        
        states_to_traverse = list(self.f.states()) # traverse all states
        while states_to_traverse:
            
            # Set i to the ID of the current state, the first 
            # item in the list (and remove it from the list)
            i = states_to_traverse.pop(0)   
        
            # don't bother traversing states which have zero probability
            if self.V[t][i] == self.NLL_ZERO:
                    continue
        
            for arc in self.f.arcs(i):
                
                if arc.ilabel == 0:     # if <eps> transition
                  
                    j = arc.nextstate   # ID of next state  
                
                    if self.V[t][j] > self.V[t][i] + float(arc.weight):
                        
                        # this means we've found a lower-cost path to
                        # state j at time t.  We might need to add it
                        # back to the processing queue.
                        self.V[t][j] = self.V[t][i] + float(arc.weight)
                        
                        # save backtrace information.  In the case of an epsilon transition, 
                        # we save the identity of the best state at t-1.  This means we may not
                        # be able to fully recover the best path, but to do otherwise would
                        # require a more complicated way of storing backtrace information
                        self.B[t][j] = self.B[t][i] 
                        
                        # and save the output labels encountered - this is a list, because
                        # there could be multiple output labels (in the case of <eps> arcs)
                        if arc.olabel != 0:
                            self.W[t][j] = self.W[t][i] + [arc.olabel]
                        else:
                            self.W[t][j] = self.W[t][i]
                        
                        if j not in states_to_traverse:
                            states_to_traverse.append(j)

    
    def forward_step(self, t):
          
        for i in self.f.states():
            
            if not self.V[t-1][i] == self.NLL_ZERO:   # no point in propagating states with zero probability
                
                for arc in self.f.arcs(i):
                    
                    if arc.ilabel != 0: # <eps> transitions don't emit an observation
                        j = arc.nextstate
                        tp = float(arc.weight)  # transition prob
                        ep = -self.om.log_observation_probability(self.f.input_symbols().find(arc.ilabel), t)  # emission negative log prob
                        prob = tp + ep + self.V[t-1][i] # they're logs
                        if prob < self.V[t][j]:
                            self.V[t][j] = prob
                            self.B[t][j] = i
                            
                            # store the output labels encountered too
                            if arc.olabel !=0:
                                self.W[t][j] = [arc.olabel]
                            else:
                                self.W[t][j] = []
                            
    
    def finalise_decoding(self):
        """ this incorporates the probability of terminating at each state
        """
        
        for state in self.f.states():
            final_weight = float(self.f.final(state))
            if self.V[-1][state] != self.NLL_ZERO:
                if final_weight == math.inf:
                    self.V[-1][state] = self.NLL_ZERO  # effectively says that we can't end in this state
                else:
                    self.V[-1][state] += final_weight
                    
        # get a list of all states where there was a path ending with non-zero probability
        finished = [x for x in self.V[-1] if x < self.NLL_ZERO]
        if not finished:  # if empty
            print("No path got to the end of the observations.")
        
        
    def decode(self):
        start = time.perf_counter()
        self.initialise_decoding()
        t = 1
        while t <= self.om.observation_length():
            self.forward_step(t)
            self.traverse_epsilon_arcs(t)
            t += 1
        self.finalise_decoding()
        
        end = time.perf_counter()
        elapsed = end - start
        
        return t, elapsed
    
    def backtrace(self):
        start = time.perf_counter()
        
        best_final_state = self.V[-1].index(min(self.V[-1])) # argmin
        best_state_sequence = [best_final_state]
        best_out_sequence = []
        
        t = self.om.observation_length()   # ie T
        j = best_final_state
        
        while t >= 0:
            i = self.B[t][j]
            best_state_sequence.append(i)
            best_out_sequence = self.W[t][j] + best_out_sequence  # computer scientists might like
                                                                                # to make this more efficient!

            # continue the backtrace at state i, time t-1
            j = i  
            t-=1
            
        best_state_sequence.reverse()
        
        # convert the best output sequence from FST integer labels into strings
        best_out_sequence = ' '.join([self.f.output_symbols().find(label) for label in best_out_sequence])
        # convert the phone sequence into words by matching against the lexicon
        temp_phones = []
        word_sequence = []
        
        for phone in best_out_sequence.split(" "):
            temp_phones.append(phone)
            
            for word, phones in lex.items():
                if phones == temp_phones:
                    word_sequence.append(word)
                    temp_phones.clear()
        word_sequence = " ".join(word_sequence)
        
        end = time.perf_counter()
        elapsed = end - start
        
        return (best_state_sequence, word_sequence, elapsed)
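
The phone-to-word step at the end of `backtrace()` can be tried in isolation. The sketch below is one possible standalone version, under some assumptions: the function name `phones_to_words` and the toy lexicon are hypothetical, only the pronunciation of `peppers` is taken from the lexicon format described in this notebook, and the greedy matching only works when no pronunciation is a prefix of another.

```python
def phones_to_words(phone_sequence, lexicon):
    """Greedily map a space-separated phone string to words.

    Hypothetical helper: emits a word as soon as the buffered phones
    match a pronunciation exactly, then clears the buffer.
    """
    # invert the lexicon: pronunciation (as a tuple) -> word
    pron_to_word = {tuple(phones): word for word, phones in lexicon.items()}

    words = []
    buffer = []
    for phone in phone_sequence.split():
        buffer.append(phone)
        word = pron_to_word.get(tuple(buffer))
        if word is not None:
            words.append(word)
            buffer = []
    return ' '.join(words)


# toy lexicon; the pronunciation of 'peck' is assumed for illustration
toy_lexicon = {
    'peppers': ['p', 'eh', 'p', 'er', 'z'],
    'peck': ['p', 'eh', 'k'],
}
decoded = phones_to_words('p eh k p eh p er z', toy_lexicon)
print(decoded)
```

Inverting the lexicon once gives a dictionary lookup per phone instead of a scan over every lexicon entry, which is the main difference from the loop in `backtrace()`.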

In [5]:
def parse_lexicon(lex_file):
    """
    Parse the lexicon file and return it in dictionary form.
    
    Args:
        lex_file (str): filename of lexicon file with structure '<word> <phone1> <phone2>...'
                        eg. peppers p eh p er z

    Returns:
        lex (dict): dictionary mapping words to list of phones
    """
    
    lex = {}  # create a dictionary for the lexicon entries (this could be a problem with larger lexica)
    with open(lex_file, 'r') as f:
        for line in f:
            line = line.split()  # split at each space
            lex[line[0]] = line[1:]  # first field the word, the rest is the phones
    return lex

def generate_symbol_tables(lexicon, n=3):
    '''
    Return word, phone and state symbol tables based on the supplied lexicon
        
    Args:
        lexicon (dict): lexicon to use, created from the parse_lexicon() function
        n (int): number of states for each phone HMM
        
    Returns:
        word_table (fst.SymbolTable): table of words
        phone_table (fst.SymbolTable): table of phones
        state_table (fst.SymbolTable): table of HMM phone-state IDs
    '''
    
    state_table = fst.SymbolTable()
    phone_table = fst.SymbolTable()
    word_table = fst.SymbolTable()
    
    # add empty <eps> symbol to all tables
    state_table.add_symbol('<eps>')
    phone_table.add_symbol('<eps>')
    word_table.add_symbol('<eps>')
    
    for word, phones  in lexicon.items():
        
        word_table.add_symbol(word)
        
        for p in phones: # for each phone
            
            phone_table.add_symbol(p)
            for i in range(1,n+1): # for each state 1 to n
                state_table.add_symbol('{}_{}'.format(p, i))
            
    return word_table, phone_table, state_table


# call these two functions
lex = parse_lexicon('lexicon.txt')
word_table, phone_table, state_table = generate_symbol_tables(lex)

def generate_phone_wfst(f, start_state, phone, n):
    """
    Generate a WFST representing an n-state left-to-right phone HMM
    
    Args:
        f (fst.Fst()): an FST object, assumed to exist already
        start_state (int): the index of the first state, assumed to exist already
        phone (str): the phone label 
        n (int): number of states for each phone HMM
        
    Returns:
        the final state of the FST
    """
    
    current_state = start_state
    
    for i in range(1, n+1):
        
        in_label = state_table.find('{}_{}'.format(phone, i))
        
        sl_weight = fst.Weight('log', -math.log(0.1))  # weight for self-loop
        # self-loop back to current state
        f.add_arc(current_state, fst.Arc(in_label, 0, sl_weight, current_state))
        
        # transition to next state
        
        # we want to output the phone label on the final state
        # note: if outputting words instead this code should be modified
        if i == n:
            out_label = phone_table.find(phone)
        else:
            out_label = 0   # output empty <eps> label
            
        next_state = f.add_state()
        next_weight = fst.Weight('log', -math.log(0.9)) # weight to next state
        f.add_arc(current_state, fst.Arc(in_label, out_label, next_weight, next_state))    
       
        current_state = next_state
        
    return current_state

In [14]:
# Load words from lexicon
word_list = []
with open("lexicon.txt", "r") as lexicon:
    for line in lexicon:
        word = line.split()[0].lower()  # Extract only the first word
        word_list.append(word)

# Count word occurrences
word_counts = Counter(word_list)

# Total number of words
total_words = sum(word_counts.values())

# Compute unigram probabilities
word_probs = {word: count / total_words for word, count in word_counts.items()}

# Print top words
# for word, prob in sorted(word_probs.items(), key=lambda x: x[1], reverse=True):
#     print(f"{word}: {prob:.5f}")

a: 0.16667
the: 0.16667
of: 0.08333
peck: 0.08333
peppers: 0.08333
peter: 0.08333
picked: 0.08333
pickled: 0.08333
piper: 0.08333
where's: 0.08333
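
WFST arc weights here are negative log probabilities, so probabilities like the ones above are converted with `-math.log(p)` before being placed on an arc. The following is a small sketch of that conversion on a made-up word list; the function name `unigram_neg_log_weights` and the input list are assumptions for illustration, and no `openfst_python` objects are created.

```python
import math
from collections import Counter


def unigram_neg_log_weights(words):
    """Return {word: -log(relative frequency)} for a list of words.

    Hypothetical helper for illustration only.
    """
    counts = Counter(words)
    total = sum(counts.values())
    return {word: -math.log(count / total) for word, count in counts.items()}


# made-up list: 'a' occurs twice out of four entries, so p('a') = 0.5
weights = unigram_neg_log_weights(['a', 'a', 'the', 'peck'])
for word, weight in weights.items():
    print(f"{word}: {weight:.5f}")
```

A more probable word gets a smaller cost, which is why the Viterbi decoder looks for the minimum total weight rather than the maximum.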


In [None]:
def create_wfst():
    """Generate a WFST that recognises any sequence of words in the lexicon.

    Each word is a chain of phone HMMs (n = 3 states per phone, fixed below),
    entered from the start state with its unigram weight from word_probs.

    Returns:
        the constructed WFST
    """
    
    f = fst.Fst('log')
    n = 3
    f.set_input_symbols(state_table)
    f.set_output_symbols(phone_table)
    
    # create a single start state
    start_state = f.add_state()
    f.set_start(start_state)
    # weight on the <eps> arc from the end of each word back to the start state
    f_weight = fst.Weight('log', -math.log(0.1))
    
    for word, phones in lex.items():
        current_state = f.add_state()
        trans_weight = fst.Weight('log', -math.log(word_probs.get(word)))
        f.add_arc(start_state, fst.Arc(0, 0, trans_weight, current_state))
        
        for phone in phones: 
            current_state = generate_phone_wfst(f, current_state, phone, n)
        # note: new current_state is now set to the final state of the previous phone WFST
        
        f.set_final(current_state)
        f.add_arc(current_state, fst.Arc(0, 0, f_weight, start_state))
        
    return f

In [15]:
def read_transcription(wav_file):
    """
    Get the transcription corresponding to wav_file.
    """
    
    transcription_file = os.path.splitext(wav_file)[0] + '.txt'
    
    with open(transcription_file, 'r') as f:
        transcription = f.readline().strip()
    
    return transcription


with open("Task 2 Trans Weight Prob.txt", "w") as file:

    f = create_wfst()

    state_count = 0
    arc_count = 0
    for state in f.states():
        state_count += 1
        arc_count += len([arc for arc in f.arcs(state)])

    print(f'State Count: {state_count}, \nArc Count: {arc_count} \n')
    file.write(f"State Count: {state_count}, Arc Count: {arc_count} \n")

    for wav_file in glob.glob('/group/teaching/asr/labs/recordings/*.wav'):    # replace path if using your own
                                                                               # audio files

        decoder = MyViterbiDecoder(f, wav_file)

        decode_steps, decode_time = decoder.decode()
        (state_path, words, backtrace_time) = decoder.backtrace()  # you'll need to modify the backtrace() from Lab 4
                                                   # to return the words along the best path

        transcription = read_transcription(wav_file)
        error_counts = wer.compute_alignment_errors(transcription, words)
        word_count = len(transcription.split())

        # WER (%) = (substitutions + deletions + insertions) / reference word count
        total_errors = sum(error_counts)
        word_error_rate = (total_errors / word_count) * 100
        wav_file_name = os.path.basename(wav_file)

        print(f'File: {wav_file_name}, \nErrors: {error_counts}, \nWER: {word_error_rate}, \nExpected: "{transcription}",\nactual: "{words}", \nSteps Taken: {decode_steps}, \nDecode Time: {decode_time}, \nBacktrace Time: {backtrace_time} \n')
        file.write(f'File: {wav_file_name}, \nErrors: {error_counts}, \nWER: {word_error_rate}, \nExpected: "{transcription}",\nActual: "{words}", \nSteps Taken: {decode_steps}, \nDecode Time: {decode_time}, \nBacktrace Time: {backtrace_time} \n')

# you'll need to accumulate the per-file counts to produce an overall Word Error Rate
print("\nDone")

State Count: 116, 
Arc Count: 230 

File: 0000.wav, 
Errors: (5, 0, 6), 
WER: 137.5, 
Expected: "peter piper pickled a peck of picked peppers",
actual: "the where's of a of of pickled of picked of picked a of where's", 
Steps Taken: 758, 
Decode Time: 1.5685891180764884, 
Backtrace Time: 0.0004459129413589835 

File: 0001.wav, 
Errors: (3, 0, 7), 
WER: 166.66666666666669, 
Expected: "peter the piper of pickled peppers",
actual: "the the the piper the of the a where's where's the picked the", 
Steps Taken: 878, 
Decode Time: 1.8224703250452876, 
Backtrace Time: 0.0004793819971382618 

File: 0002.wav, 
Errors: (3, 0, 6), 
WER: 100.0, 
Expected: "picked piper peter peppers pickled of peck the where's",
actual: "piper the picked a of peter pickled where's pickled where's peck a the where's the", 
Steps Taken: 784, 
Decode Time: 1.6280001089908183, 
Backtrace Time: 0.0004796029534190893 

File: 0003.wav, 
Errors: (2, 0, 4), 
WER: 100.0, 
Expected: "peppers pickled peter a peck pickled",
act

File: 0030.wav, 
Errors: (1, 0, 5), 
WER: 200.0, 
Expected: "where's peter piper",
actual: "picked pickled where's peter peter the picked the", 
Steps Taken: 383, 
Decode Time: 0.7827339869691059, 
Backtrace Time: 0.00026195100508630276 

File: 0031.wav, 
Errors: (5, 0, 5), 
WER: 166.66666666666669, 
Expected: "where's the peter peppers peppers a",
actual: "piper the where's of where's of peck where's where's piper the", 
Steps Taken: 477, 
Decode Time: 0.9763596700504422, 
Backtrace Time: 0.000304638990201056 

File: 0032.wav, 
Errors: (3, 0, 3), 
WER: 100.0, 
Expected: "peck peter the peppers of piper",
actual: "picked a the picked where's of piper picked the", 
Steps Taken: 357, 
Decode Time: 0.7288802179973572, 
Backtrace Time: 0.00022898707538843155 

File: 0033.wav, 
Errors: (4, 0, 4), 
WER: 114.28571428571428, 
Expected: "peter picked peter peppers where's the piper",
actual: "picked of picked where's of where's where's the picked picked the", 
Steps Taken: 374, 
Decode Time: 0.

File: 0058.wav, 
Errors: (1, 0, 4), 
WER: 125.0, 
Expected: "pickled peppers peter picked",
actual: "the pickled peck peck where's peter picked the", 
Steps Taken: 383, 
Decode Time: 0.7797851730138063, 
Backtrace Time: 0.00025431206449866295 

File: 0059.wav, 
Errors: (1, 0, 2), 
WER: 60.0, 
Expected: "a peck of pickled peppers",
actual: "the of peck of pickled peppers the", 
Steps Taken: 306, 
Decode Time: 0.6227480689994991, 
Backtrace Time: 0.0002115600509569049 

File: 0060.wav, 
Errors: (0, 1, 2), 
WER: 75.0, 
Expected: "peter picked a peck",
actual: "the peter picked peck the", 
Steps Taken: 340, 
Decode Time: 0.6885773680405691, 
Backtrace Time: 0.0001946109114214778 

File: 0061.wav, 
Errors: (8, 0, 6), 
WER: 82.35294117647058, 
Expected: "peter piper picked a peck of pickled peppers where's the peck of pickled peppers peter piper picked",
actual: "the pickled of picked of peck of pickled where's where's where's picked peck of pickled peck of where's the of pickled picked the"

File: 0085.wav, 
Errors: (4, 0, 7), 
WER: 122.22222222222223, 
Expected: "where's the peck of peter piper pickled peppers picked",
actual: "picked of where's of pickled peter of piper where's of peck pickled peck where's picked the", 
Steps Taken: 562, 
Decode Time: 1.1547867649933323, 
Backtrace Time: 0.0004593359772115946 

File: 0086.wav, 
Errors: (3, 0, 7), 
WER: 125.0, 
Expected: "peter peppers picked pickled peck of peter piper",
actual: "picked peck peter of piper where's picked the peck pickled pickled peter of piper the", 
Steps Taken: 528, 
Decode Time: 1.0957471509464085, 
Backtrace Time: 0.00038306310307234526 

File: 0087.wav, 
Errors: (0, 1, 4), 
WER: 100.0, 
Expected: "pickled piper picked pickled peter",
actual: "peter the pickled piper picked peter picked the", 
Steps Taken: 409, 
Decode Time: 0.8544250380946323, 
Backtrace Time: 0.0002744460944086313 

File: 0088.wav, 
Errors: (2, 1, 6), 
WER: 128.57142857142858, 
Expected: "pickled peter picked piper picked a peck",


File: 0116.wav, 
Errors: (1, 1, 2), 
WER: 57.14285714285714, 
Expected: "peter piper picked a peck of peppers",
actual: "the peter piper peter peck of peppers the", 
Steps Taken: 417, 
Decode Time: 0.8651765499962494, 
Backtrace Time: 0.0002738900948315859 

File: 0117.wav, 
Errors: (1, 0, 2), 
WER: 50.0, 
Expected: "where's the peck peter piper picked",
actual: "the where's the peck peter piper peter the", 
Steps Taken: 371, 
Decode Time: 0.7685688770143315, 
Backtrace Time: 0.0002521689748391509 

File: 0118.wav, 
Errors: (1, 0, 4), 
WER: 100.0, 
Expected: "peter piper picked pickled peppers",
actual: "the the peter of piper peter pickled peppers the", 
Steps Taken: 445, 
Decode Time: 0.9218945860629901, 
Backtrace Time: 0.0002988490741699934 

File: 0119.wav, 
Errors: (0, 0, 2), 
WER: 33.33333333333333, 
Expected: "where's the peck of pickled peppers",
actual: "the where's the peck of pickled peppers the", 
Steps Taken: 371, 
Decode Time: 0.7680725560057908, 
Backtrace Time: 0.00025

File: 0148.wav, 
Errors: (0, 0, 2), 
WER: 40.0, 
Expected: "where's the pickled peter piper",
actual: "pickled where's the pickled peter piper picked", 
Steps Taken: 349, 
Decode Time: 0.7236056169494987, 
Backtrace Time: 0.00017343705985695124 

File: 0149.wav, 
Errors: (1, 0, 3), 
WER: 80.0, 
Expected: "a peck of peter peppers",
actual: "peck where's of peck of peter peppers the", 
Steps Taken: 272, 
Decode Time: 0.5489523010328412, 
Backtrace Time: 0.00013788801152259111 

File: 0150.wav, 
Errors: (0, 0, 3), 
WER: 50.0, 
Expected: "piper peter piper peter piper picked",
actual: "the piper peter piper peter piper picked the picked", 
Steps Taken: 741, 
Decode Time: 1.5307285450398922, 
Backtrace Time: 0.0004337819991633296 

File: 0151.wav, 
Errors: (2, 0, 4), 
WER: 100.0, 
Expected: "of the pickled peppers piper picked",
actual: "the of picked pickled picked where's piper peck picked the", 
Steps Taken: 511, 
Decode Time: 1.045903660939075, 
Backtrace Time: 0.0003431210061535239 

F

File: 0180.wav, 
Errors: (4, 1, 6), 
WER: 122.22222222222223, 
Expected: "where's the peck of pickled peppers peter piper picked",
actual: "the where's of peter pickled of of where's the of of of picked the", 
Steps Taken: 417, 
Decode Time: 0.8685527039924636, 
Backtrace Time: 0.0002344090025871992 

File: 0181.wav, 
Errors: (5, 0, 5), 
WER: 58.82352941176471, 
Expected: "peter piper picked a peck of pickled peppers where's the peck of pickled peppers peter piper picked",
actual: "the peck piper of peck of peck of peck of peppers where's the peck of pickled of where's peter piper picked the", 
Steps Taken: 648, 
Decode Time: 1.3309113350696862, 
Backtrace Time: 0.0004622309934347868 

File: 0182.wav, 
Errors: (5, 0, 4), 
WER: 52.94117647058824, 
Expected: "peter piper picked a peck of pickled peppers where's the peck of pickled peppers peter piper picked",
actual: "the pickled of of picked of peck of pickled peppers where's the peck of peck peppers peter peck picked picked the", 
Step

File: 0209.wav, 
Errors: (3, 0, 6), 
WER: 150.0, 
Expected: "peck pickled peppers peter piper picked",
actual: "picked the picked peck pickled peck where's picked picked of picked the", 
Steps Taken: 571, 
Decode Time: 1.1870789609383792, 
Backtrace Time: 0.00037787307519465685 

File: 0210.wav, 
Errors: (2, 0, 7), 
WER: 150.0, 
Expected: "peter piper picked peck pickled peppers",
actual: "picked the picked picked piper picked picked peck the pickled peck where's the", 
Steps Taken: 579, 
Decode Time: 1.1865820829989389, 
Backtrace Time: 0.0003737939987331629 

File: 0211.wav, 
Errors: (4, 0, 5), 
WER: 150.0, 
Expected: "peck pickled peppers peter piper picked",
actual: "picked pickled peck of where's picked picked peck a picked picked", 
Steps Taken: 520, 
Decode Time: 1.07915281096939, 
Backtrace Time: 0.00035947992000728846 

File: 0212.wav, 
Errors: (3, 0, 8), 
WER: 183.33333333333331, 
Expected: "peter piper picked peck pickled peppers",
actual: "the of the peck a picked peck pick

File: 0239.wav, 
Errors: (6, 0, 4), 
WER: 125.0, 
Expected: "a peck of pickled peppers picked peter piper",
actual: "picked the of picked of peck where's a peppers piper picked the", 
Steps Taken: 491, 
Decode Time: 1.004885394941084, 
Backtrace Time: 0.0002482059644535184 

File: 0240.wav, 
Errors: (3, 0, 4), 
WER: 116.66666666666667, 
Expected: "a peck of piper pickled peter",
actual: "picked the of peter of piper pickled the pickled picked", 
Steps Taken: 389, 
Decode Time: 0.795195312006399, 
Backtrace Time: 0.0002825409173965454 


Done
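
The per-file results above can be pooled into a single overall Word Error Rate by summing the error counts and the reference word counts across files before dividing. The sketch below does this for three of the files listed above, with the counts and transcriptions copied by hand from the output; the variable names are assumptions for illustration, and a full run would accumulate inside the decoding loop instead.

```python
# (s, d, i) counts and reference transcriptions copied from the output above
results = [
    ((1, 0, 2), "a peck of pickled peppers"),            # 0059.wav
    ((0, 1, 2), "peter picked a peck"),                  # 0060.wav
    ((0, 0, 2), "where's the peck of pickled peppers"),  # 0119.wav
]

total_errors = 0
total_words = 0
for counts, reference in results:
    total_errors += sum(counts)             # s + d + i for this file
    total_words += len(reference.split())   # reference length in words

overall_wer = 100.0 * total_errors / total_words
print(f"Overall WER: {overall_wer:.2f}%")
```

Pooling the counts weights each file by its length, so the result generally differs from the plain average of the per-file WER percentages.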
