In [81]:
import pandas as pd
from collections import defaultdict

emission_probs = pd.read_csv('emission_probs.txt')
emission_probs_np = emission_probs.to_numpy()

transition_probs = pd.read_csv('transition_probs.txt')
transition_probs_np = transition_probs.to_numpy()

In [83]:
# Viterbi Algo
# sqeuence: list of tokens/words
def viterbi(sequence, emission_probs, transition_probs):
    emission_probs_tb = defaultdict(lambda: defaultdict(float))
    transition_probs_tb = defaultdict(lambda: defaultdict(float))

    for e in emission_probs:
        emission_probs_tb[e[0]][e[1]] = e[2]
    
    for t in transition_probs:
        transition_probs_tb[t[1]][t[0]] = t[2]

    pi = defaultdict(lambda: defaultdict(float))
    pi[0]['start'] = 1.0

    N = len(sequence)
    for k in range(1, N+1):
        curr_word = sequence[k-1]

        valid_tags = emission_probs_tb[curr_word].keys()
        for t in valid_tags:
            prev_tags = pi[k-1].keys()
            for s in prev_tags:
                pi[k][t] = max(pi[k][t], pi[k-1][s] * transition_probs_tb[t][s] * emission_probs_tb[curr_word][t])
        
        #print(pi)

    for a in pi:
        for b in pi[a]:
            print(a, b, pi[a][b])
    
    return pi

sequence = "time flies like an arrow".split(" ")
pi = viterbi(sequence, emission_probs_np, transition_probs_np)
pi


0 start 1.0
1 VBZ 0.020000000000000004
1 NNZ 0.06
2 VBZ 0.010799999999999999
2 NNZ 0.0024000000000000002
3 VBZ 0.00072
3 IN 0.00216
4 DT 0.0015119999999999999
5 NNZ 0.0007559999999999999


defaultdict(<function __main__.viterbi.<locals>.<lambda>()>,
            {0: defaultdict(float, {'start': 1.0}),
             1: defaultdict(float, {'VBZ': 0.020000000000000004, 'NNZ': 0.06}),
             2: defaultdict(float,
                         {'VBZ': 0.010799999999999999,
                          'NNZ': 0.0024000000000000002}),
             3: defaultdict(float, {'VBZ': 0.00072, 'IN': 0.00216}),
             4: defaultdict(float, {'DT': 0.0015119999999999999}),
             5: defaultdict(float, {'NNZ': 0.0007559999999999999})})

In [87]:
# pi table from viterbi
def viterbi_backtrack_pos(sequence, emission_probs, transition_probs):
    pi = viterbi(sequence, emission_probs, transition_probs)
    pos_tags = []

    for n in range(len(sequence), 0, -1):
        max_tag = max(pi[n], key=lambda x: pi[n][x])
        pos_tags.append(max_tag)
    
    return list(reversed(pos_tags))

viterbi_backtrack_pos(sequence, emission_probs_np, transition_probs_np)

0 start 1.0
1 VBZ 0.020000000000000004
1 NNZ 0.06
2 VBZ 0.010799999999999999
2 NNZ 0.0024000000000000002
3 VBZ 0.00072
3 IN 0.00216
4 DT 0.0015119999999999999
5 NNZ 0.0007559999999999999


['NNZ', 'VBZ', 'IN', 'DT', 'NNZ']

In [72]:
# Forward Algo
# sqeuence: list of tokens/words
def forward(sequence, emission_probs, transition_probs):
    emission_probs_tb = defaultdict(lambda: defaultdict(float))
    transition_probs_tb = defaultdict(lambda: defaultdict(float))

    for e in emission_probs:
        emission_probs_tb[e[0]][e[1]] = e[2]
    
    for t in transition_probs:
        transition_probs_tb[t[1]][t[0]] = t[2]

    pi = defaultdict(lambda: defaultdict(float))
    pi[0]['start'] = 1.0

    N = len(sequence)
    for k in range(1, N+1):
        curr_word = sequence[k-1]

        valid_tags = emission_probs_tb[curr_word].keys()
        for t in valid_tags:
            prev_tags = pi[k-1].keys()
            for s in prev_tags:
                pi[k][t] += pi[k-1][s] * transition_probs_tb[t][s] * emission_probs_tb[curr_word][t]
        

    for a in pi:
        for b in pi[a]:
            print(a, b, pi[a][b])

    print(sum(pi[N].values()))

    return pi

sequence = "time flies like an arrow".split(" ")
forward(sequence, emission_probs_np, transition_probs_np)

0 start 1.0
1 VBZ 0.020000000000000004
1 NNZ 0.06
2 VBZ 0.010799999999999999
2 NNZ 0.004000000000000001
3 VBZ 0.0012000000000000003
3 IN 0.0029600000000000004
4 DT 0.0025520000000000004
5 NNZ 0.0012760000000000002
0.0012760000000000002


defaultdict(<function __main__.forward.<locals>.<lambda>()>,
            {0: defaultdict(float, {'start': 1.0}),
             1: defaultdict(float, {'VBZ': 0.020000000000000004, 'NNZ': 0.06}),
             2: defaultdict(float,
                         {'VBZ': 0.010799999999999999,
                          'NNZ': 0.004000000000000001}),
             3: defaultdict(float,
                         {'VBZ': 0.0012000000000000003,
                          'IN': 0.0029600000000000004}),
             4: defaultdict(float, {'DT': 0.0025520000000000004}),
             5: defaultdict(float, {'NNZ': 0.0012760000000000002})})

In [27]:
import os
from collections import defaultdict

def get_ngrams(sequence, n):
    """
    Given a sequence, return a list of n-grams, where each n-gram is a Python tuple.
    This should work for arbitrary values of 1 <= n < len(sequence).
    """

    # Modify sequence to include START/STOP
    if n <= 2:
        sequence = ['START'] + sequence
    else:
        sequence = ['START'] * (n-1) + sequence

    sequence += ['STOP']

    n_grams = []

    for i in range(len(sequence) - n + 1):
        n_gram_tup = tuple(sequence[i:i+n])
        n_grams.append(n_gram_tup)

    return n_grams


def estimate_hmm_params():
    unigram_counts = defaultdict(int)
    bigram_counts = defaultdict(int)

    word_emission_counts = defaultdict(lambda: defaultdict(int))

    for fname in os.listdir('tagged'):
        file_reader = open(os.path.join('tagged', fname), 'r')

        curr_words = []
        curr_tags = []

        for line in file_reader:
            if line[0] == "=" or len(line) == 0:
                continue

            tags = line.strip().split(" ")

            for t in tags:
                if t == '[' or t == ']' or len(t) == 0:
                    continue
                
                word_tag = t.split("/")
                word, tag = word_tag[0].lower(), word_tag[1]

                curr_words.append(word)
                curr_tags.append(tag)

                word_emission_counts[word][tag] += 1

                if word == "." and tag == ".":
                    unigrams = get_ngrams(curr_tags, 1)
                    bigrams = get_ngrams(curr_tags, 2)

                    for u in unigrams:
                        unigram_counts[u] += 1
                    
                    for b in bigrams:
                        bigram_counts[b] += 1
                    
                    curr_words = []
                    curr_tags = []

                
                #print(word, tag)
                
    return unigram_counts, bigram_counts, word_emission_counts

In [28]:
unigram_counts, bigram_counts, word_emission_counts = estimate_hmm_params()

In [31]:
transition_probs_tb = defaultdict(lambda: defaultdict(float))
emission_probs_tb = defaultdict(lambda: defaultdict(float))

for bigram in bigram_counts.keys():
    transition_probs_tb[bigram[1]][bigram[0]] = bigram_counts[bigram] / unigram_counts[(bigram[0],)]
    print(bigram, bigram_counts[bigram] / unigram_counts[(bigram[0],)])

for word in word_emission_counts.keys():
    for tag in word_emission_counts[word].keys():
        emission_probs_tb[word][tag] = word_emission_counts[word][tag] / unigram_counts[(tag, )]
        print(word, tag, word_emission_counts[word][tag] / unigram_counts[(tag, )])

('START', 'CD') 0.0073145245559038665
('CD', 'NNS') 0.1530791788856305
('NNS', 'PRP') 0.005954349983460139
('PRP', 'MD') 0.12645687645687645
('MD', 'VB') 0.8155339805825242
('VB', 'IN') 0.11942051683633516
('IN', '$') 0.026809920674766542
('$', 'CD') 0.988950276243094
('CD', 'CC') 0.020821114369501466
('CC', 'JJR') 0.010596026490066225
('JJR', ':') 0.005249343832020997
(':', 'LS') 0.008880994671403197
('LS', '.') 0.38461538461538464
('.', 'STOP') 0.9881259679917398
('START', 'VB') 0.004440961337513062
('VB', 'DT') 0.2364917776037588
('DT', 'JJ') 0.20269442743417024
('JJ', 'NNP') 0.033773358477627295
('NNP', '.') 0.05184160102192889
('START', 'LS') 0.0013061650992685476
('JJ', 'NN') 0.44848277044402535
('NN', '.') 0.10622293138819239
('VB', 'PRP$') 0.039545810493343776
('PRP$', 'NN') 0.43994778067885115
('NN', 'TO') 0.03996656788997797
('TO', 'DT') 0.12987608994951813
('DT', 'NN') 0.470789957134109
('NN', 'NN') 0.12628219740141328
('NN', 'IN') 0.24747359623128942
('IN', 'DT') 0.31679887