In [2]:
import numpy as np

In [3]:
# Load up text as list of lists of string tokens.
with open("15pctmasked.txt") as f:
    lines = f.readlines()
    sentences = []
    for line in lines:
        sentences.append(line.split())

In [4]:
# Load up bigram model as dict of dicts mapping to log probabilities, as well as vocabulary.
with open("lm.txt") as f:
    model = {}
    vocab = {}
    vocab_list = []
    index = 0
    for line in f.readlines():
        components = line.split()
        if components[0] not in model:
            model[components[0]] = {}
        model[components[0]][components[1]] = np.log(float(components[2]))
        if components[0] not in vocab:
            vocab[components[0]] = index
            vocab_list.append(components[0])
            index += 1
            
    # Fill out unseen words.
    for key1 in model:
        for key2 in vocab:
            if key2 not in model[key1]:
                model[key1][key2] = -np.inf

In [5]:
MASK = "<mask>"
START = "<start>"
STOP = "<eos>"

In [75]:
def make_score_table(sentence, model, vocab):
    seq_len = len(sentence)
    vocab_size = len(vocab)
    s = np.zeros((seq_len - 1, vocab_size, vocab_size))
    
    for i in range(seq_len - 1):
        if sentence[i + 1] == MASK:
            for prev_key in vocab:
                for key in vocab:
                    s[i, vocab[prev_key], vocab[key]] = model[prev_key][key]
        else:
            for prev_key in vocab:
                for key in vocab:
                    s[i, vocab[prev_key], vocab[key]] = 0 if key == sentence[i + 1] else -np.inf
                    
    return s
            

def decode(sentence, model, vocab, vocab_list):
    """
    Decode '<mask>' tokens using Viterbi algorithm.
    """
    seq_len = len(sentence)
    vocab_size = len(vocab)

    decoded = []
    
    # Viterbi variables.
    v = np.zeros((seq_len - 2, vocab_size))
    b = np.zeros((seq_len - 2, vocab_size), dtype = int)
    s = make_score_table(sentence, model, vocab)
    
    v[0, :] = s[0, vocab[START], :]
        
    for i in range(1, seq_len - 2):
        v[i, :] = np.max(s[i, :, :] + np.expand_dims(v[i - 1, :], 1), axis = 0)
        b[i, :] = np.argmax(s[i, :, :] + np.expand_dims(v[i - 1, :], 1), axis = 0)
        
    # Append last word.
    decoded.append(STOP)
    word = vocab_list[np.argmax(s[seq_len - 2, :, vocab[STOP]] + v[seq_len - 3, :])]
    decoded.append(word)
    for i in range(seq_len - 3, 0, -1):
        word = vocab_list[b[i, vocab[word]]]
        decoded.append(word)
    decoded.append(START)
    decoded.reverse()
    
    return decoded

In [76]:
decoded = decode(sentences[0], model, vocab, vocab_list)
print(decoded)

['<start>', 'I', '<s>', 'p', 'e', '<s>', 'm', 'e', 'n', 't', 'a', 't', 'i', 'o', 'n', '<s>', 'o', 'f', '<s>', 'G', 'e', 'o', 'r', 'g', 'i', 'a', "'", '<s>', '<s>', 'a', 'u', 'r', 'o', 'm', 'o', 'b', 'i', 'l', 'e', '<s>', 't', 'i', 't', 'l', 'e', '<s>', 'l', 'a', 'w', '<s>', 'w', 'a', 's', '<s>', 'a', 'l', 'y', '<s>', '<s>', 't', 'e', 'c', 'o', 'm', 'm', 'e', 'n', 'd', 'e', 'd', '<s>', 'b', 'e', '<s>', 't', 'h', 'e', '<s>', 'o', 'u', 't', 'g', 'o', 'i', 'n', 'g', '<s>', 'j', 'u', 'r', 'y', '<s>', '.', '<eos>']


In [77]:
try:
    with open("output.txt", "w") as f:
        for i, sentence in enumerate(sentences):
            if i % 100 == 0:
                print("Writing sentence %d ..." % i)
            decoded = decode(sentence, model, vocab, vocab_list)
            f.write("%s\n" % (" ".join(decoded)))
except KeyboardInterrupt:
    print("Graceful Exit")

Writing sentence 0 ...
Writing sentence 100 ...
Writing sentence 200 ...
Writing sentence 300 ...
Writing sentence 400 ...
Writing sentence 500 ...
Writing sentence 600 ...
Writing sentence 700 ...
Writing sentence 800 ...
Writing sentence 900 ...
Writing sentence 1000 ...
Writing sentence 1100 ...
Writing sentence 1200 ...
Writing sentence 1300 ...
Writing sentence 1400 ...
Writing sentence 1500 ...
Writing sentence 1600 ...
Writing sentence 1700 ...
Writing sentence 1800 ...
Writing sentence 1900 ...
Writing sentence 2000 ...
Writing sentence 2100 ...
Writing sentence 2200 ...
Writing sentence 2300 ...
Writing sentence 2400 ...
Writing sentence 2500 ...
Writing sentence 2600 ...
Writing sentence 2700 ...
Writing sentence 2800 ...
Writing sentence 2900 ...
Writing sentence 3000 ...
Writing sentence 3100 ...
Writing sentence 3200 ...
Writing sentence 3300 ...
Writing sentence 3400 ...
Writing sentence 3500 ...
Writing sentence 3600 ...
Writing sentence 3700 ...
Writing sentence 3800 ..