In [3]:
import re
import random
from collections import defaultdict, Counter
import numpy as np
import nltk
from nltk.corpus import brown
nltk.download('brown')

[nltk_data] Downloading package brown to /home/tanzid/nltk_data...
[nltk_data]   Package brown is already up-to-date!


True

#### This time, we will train on the NLTK Brown corpus, keeping all the punctuation, but still using lowercase folding.

In [4]:
# get pre-tokenized sentences from the Brown corpus, materialized as a list
# so they can be indexed and iterated repeatedly
sentences = list(brown.sents())

In [8]:
# make everything lowercase and add start and end tokens
start_token = '<s>'
end_token = '</s>'
# two start tokens are prepended so that the first real word of every
# sentence already has a full two-word trigram context
sentences_tokenized = [[start_token]*2+[w.lower() for w in s]+[end_token] for s in sentences]

# now we split the data into train and test sentences (10% held out for test)
num_sent = len(sentences_tokenized)
num_test = int(0.1 * num_sent)
test_idx = random.sample(range(num_sent), num_test)
# membership is checked once per sentence; a set makes each check O(1)
# instead of scanning the whole list of sampled indices
test_idx_set = set(test_idx)

sentences_train = []
sentences_test = []
for i in range(num_sent):
    if i in test_idx_set:
        sentences_test.append(sentences_tokenized[i])
    else:
        sentences_train.append(sentences_tokenized[i])

print(f"Number of training sentences: {len(sentences_train)}")
print(f"Number of test sentences: {len(sentences_test)}")

Number of training sentences: 51606
Number of test sentences: 5734


In [9]:
# peek at the first few tokenized training sentences
for i in range(5):
    print(sentences_train[i])

['<s>', '<s>', 'the', 'fulton', 'county', 'grand', 'jury', 'said', 'friday', 'an', 'investigation', 'of', "atlanta's", 'recent', 'primary', 'election', 'produced', '``', 'no', 'evidence', "''", 'that', 'any', 'irregularities', 'took', 'place', '.', '</s>']
['<s>', '<s>', 'the', 'jury', 'further', 'said', 'in', 'term-end', 'presentments', 'that', 'the', 'city', 'executive', 'committee', ',', 'which', 'had', 'over-all', 'charge', 'of', 'the', 'election', ',', '``', 'deserves', 'the', 'praise', 'and', 'thanks', 'of', 'the', 'city', 'of', 'atlanta', "''", 'for', 'the', 'manner', 'in', 'which', 'the', 'election', 'was', 'conducted', '.', '</s>']
['<s>', '<s>', 'the', 'september-october', 'term', 'jury', 'had', 'been', 'charged', 'by', 'fulton', 'superior', 'court', 'judge', 'durwood', 'pye', 'to', 'investigate', 'reports', 'of', 'possible', '``', 'irregularities', "''", 'in', 'the', 'hard-fought', 'primary', 'which', 'was', 'won', 'by', 'mayor-nominate', 'ivan', 'allen', 'jr.', '.', '</s>']

In [39]:
class trigram_LM_addk():
    """Trigram language model with add-k smoothing.

    Words whose training count is below ``count_threshold`` are replaced by
    ``<UNK>`` before n-gram counts are collected. Probabilities are

        P(w3 | w1, w2) = (C(w1, w2, w3) + k) / (C(w1, w2) + k * |V|)

    where |V| is the size of ``self.vocab``.
    """

    def __init__(self, count_threshold=2, k=0.1):
        # words seen fewer than this many times are mapped to <UNK>
        self.count_threshold = count_threshold
        # pseudo-count added to every trigram count
        self.k = k
        # the fields below are populated by train()
        self.bigram_counts = None
        self.unigram_counts = None
        self.trigram_counts = None
        self.vocab = None
        self.word2idx = None
        self.num_sentences = None
        self.unk_token = '<UNK>'
        self.start_token = '<s>'
        self.end_token = '</s>'

    def train(self, sentences):
        """Collect n-gram counts from ``sentences`` and build the vocabulary.

        Args:
            sentences: list of token lists, already padded with start/end tokens.
        """
        self.num_sentences = len(sentences)
        _, self.unigram_counts, self.bigram_counts, self.trigram_counts = self.get_counts(sentences)
        # vocab follows the insertion order of the re-collected unigram counts
        self.vocab = list(self.unigram_counts.keys())
        # If no training word fell below the threshold, <UNK> never appears in
        # the counts; keep it in the vocab anyway so it stays a valid outcome.
        if self.unk_token not in self.unigram_counts:
            self.vocab.append(self.unk_token)
        self.word2idx = {word:i for i,word in enumerate(self.vocab)}
        print("Training complete!")

    def get_counts(self, sentences):
        """Build the vocabulary and the unigram/bigram/trigram counters.

        Args:
            sentences: list of token lists.

        Returns:
            (vocab, unigram_counts, bigram_counts, trigram_counts), where
            ``vocab`` is ``[<UNK>] + sorted(kept words)`` and the counts are
            ``Counter`` objects computed after rare words are replaced by <UNK>.
        """
        # collect unigram counts
        print("Collecting unigram counts...")
        unigram_counts = Counter()
        for s in sentences:
            unigram_counts.update(s)

        # keep only the words whose count reaches the threshold
        print("Constructing vocab...")
        kept_words = [w for w, c in unigram_counts.items() if c >= self.count_threshold]
        vocab = [self.unk_token] + sorted(kept_words)
        # set for O(1) membership tests; a list scan here is O(|V|) per token
        vocab_set = set(vocab)

        # replace all oov tokens in training sentences with <UNK>
        print("Replacing with oov tokens in training data...")
        sentences_unk = [
            [word if word in vocab_set else self.unk_token for word in s]
            for s in sentences
        ]

        # re-collect unigram counts
        print("Re-collecting unigram counts...")
        unigram_counts = Counter()
        for s in sentences_unk:
            unigram_counts.update(s)
        print(f"Total num unigrams: {len(unigram_counts)}")

        # collect bigram counts
        print("Collecting bigram counts...")
        bigram_counts = Counter()
        for s in sentences_unk:
            bigram_counts.update(zip(s[:-1], s[1:]))
        print(f"Total num bigrams: {len(bigram_counts)}")

        # collect trigram counts
        print("Collecting trigram counts...")
        trigram_counts = Counter()
        for s in sentences_unk:
            trigram_counts.update(zip(s[:-2], s[1:-1], s[2:]))
        print(f"Total num trigrams: {len(trigram_counts)}")

        return vocab, unigram_counts, bigram_counts, trigram_counts

    def compute_probs(self, word1, word2):
        """Return P(w | word1, word2) for every w in ``self.vocab``, in vocab order."""
        return [self.tg_prob(word1, word2, word3) for word3 in self.vocab]

    def tg_prob(self, word1, word2, word3):
        """Return the add-k smoothed probability P(word3 | word1, word2).

        Unseen trigrams/bigrams contribute a count of 0 (Counter lookup).
        """
        # addk probability
        p = (self.trigram_counts[(word1,word2,word3)] + self.k) / (self.bigram_counts[(word1,word2)] + self.k*len(self.vocab))
        return p
    

class trigram_LM_interp():
    """Trigram language model smoothed by linear interpolation.

    Words whose training count is below ``count_threshold`` are replaced by
    ``<UNK>`` before n-gram counts are collected. The probability of a word
    is a weighted sum of four estimates, weighted by ``lmda``:

        lmda[0]: uniform over the vocabulary ("zerogram")
        lmda[1]: unigram relative frequency
        lmda[2]: bigram  estimate C(w2, w3) / C(w2)
        lmda[3]: trigram estimate C(w1, w2, w3) / C(w1, w2)
    """

    def __init__(self, count_threshold=2, lmda = [0.01, 0.2, 0.3, 0.49]):
        # words seen fewer than this many times are mapped to <UNK>
        self.count_threshold = count_threshold
        # copy so instances never share (and cannot mutate) the default list
        self.lmda = list(lmda)
        # the fields below are populated by train()
        self.bigram_counts = None
        self.unigram_counts = None
        self.trigram_counts = None
        self.total_tokens = None
        self.vocab = None
        self.word2idx = None
        self.num_sentences = None
        self.unk_token = '<UNK>'
        self.start_token = '<s>'
        self.end_token = '</s>'

    def train(self, sentences):
        """Collect n-gram counts from ``sentences`` and build the vocabulary.

        Args:
            sentences: list of token lists, already padded with start/end tokens.
        """
        self.num_sentences = len(sentences)
        _, self.unigram_counts, self.bigram_counts, self.trigram_counts, self.total_tokens = self.get_counts(sentences)
        # vocab follows the insertion order of the re-collected unigram counts
        self.vocab = list(self.unigram_counts.keys())
        # If no training word fell below the threshold, <UNK> never appears in
        # the counts; keep it in the vocab anyway so it stays a valid outcome.
        if self.unk_token not in self.unigram_counts:
            self.vocab.append(self.unk_token)
        self.word2idx = {word:i for i,word in enumerate(self.vocab)}
        print("Training complete!")

    def get_counts(self, sentences):
        """Build the vocabulary and the unigram/bigram/trigram counters.

        Args:
            sentences: list of token lists.

        Returns:
            (vocab, unigram_counts, bigram_counts, trigram_counts, total_tokens),
            where ``vocab`` is ``[<UNK>] + sorted(kept words)``, the counts are
            ``Counter`` objects computed after rare words are replaced by <UNK>,
            and ``total_tokens`` is the number of tokens in the training data.
        """
        # collect unigram counts
        print("Collecting unigram counts...")
        unigram_counts = Counter()
        for s in sentences:
            unigram_counts.update(s)

        # keep only the words whose count reaches the threshold
        print("Constructing vocab...")
        kept_words = [w for w, c in unigram_counts.items() if c >= self.count_threshold]
        vocab = [self.unk_token] + sorted(kept_words)
        # set for O(1) membership tests; a list scan here is O(|V|) per token
        vocab_set = set(vocab)

        # replace all oov tokens in training sentences with <UNK>
        print("Replacing with oov tokens in training data...")
        sentences_unk = [
            [word if word in vocab_set else self.unk_token for word in s]
            for s in sentences
        ]

        # re-collect unigram counts
        print("Re-collecting unigram counts...")
        unigram_counts = Counter()
        for s in sentences_unk:
            unigram_counts.update(s)
        total_tokens = sum(unigram_counts.values())
        print(f"Total num unigrams: {len(unigram_counts)}")

        # collect bigram counts
        print("Collecting bigram counts...")
        bigram_counts = Counter()
        for s in sentences_unk:
            bigram_counts.update(zip(s[:-1], s[1:]))
        print(f"Total num bigrams: {len(bigram_counts)}")

        # collect trigram counts
        print("Collecting trigram counts...")
        trigram_counts = Counter()
        for s in sentences_unk:
            trigram_counts.update(zip(s[:-2], s[1:-1], s[2:]))
        print(f"Total num trigrams: {len(trigram_counts)}")

        return vocab, unigram_counts, bigram_counts, trigram_counts, total_tokens

    def compute_probs(self, word1, word2):
        """Return P(w | word1, word2) for every w in ``self.vocab``, in vocab order."""
        return [self.tg_prob(word1, word2, word3) for word3 in self.vocab]

    def tg_prob(self, word1, word2, word3):
        """Return the linearly interpolated probability P(word3 | word1, word2).

        A bigram or trigram term whose context was never seen in training
        contributes 0, so in that case the values over the vocabulary sum to
        less than the sum of the lambdas.
        """
        # linearly interpolated probability
        p_zerogram = self.lmda[0] / len(self.vocab)
        p_unigram  = self.lmda[1] * self.unigram_counts[word3] / self.total_tokens
        # guard the bigram context the same way as the trigram context below,
        # so an unseen word2 yields 0 instead of a ZeroDivisionError
        unigram_w2 = self.unigram_counts[word2]
        if unigram_w2 > 0:
            p_bigram = self.lmda[2] * self.bigram_counts[(word2, word3)] / unigram_w2
        else:
            p_bigram = 0
        bigram_w1w2 = self.bigram_counts[(word1, word2)]
        if bigram_w1w2 > 0:
            p_trigram = self.lmda[3] * self.trigram_counts[(word1, word2, word3)] / bigram_w1w2
        else:
            p_trigram = 0
        p = p_zerogram + p_unigram + p_bigram + p_trigram
        return p


In [20]:
def generate_text(model, n=10, max_len=200):
    """Sample up to ``n`` sentences from a trained trigram model.

    Each sentence starts from the context (<s>, <s>) and is extended one word
    at a time by sampling from ``model.compute_probs(w1, w2)`` until the end
    token is drawn or ``max_len`` words have been produced. A sampled start
    token is rejected and the draw is repeated.

    Args:
        model: object exposing ``vocab`` and ``compute_probs(word1, word2)``;
            ``start_token`` / ``end_token`` are used when present, otherwise
            '<s>' / '</s>'.
        n: number of sentences to attempt.
        max_len: maximum number of words per sentence.

    Returns:
        The non-empty generated sentences joined by newlines.
    """
    start_token = getattr(model, 'start_token', '<s>')
    end_token = getattr(model, 'end_token', '</s>')

    sentences = []
    for _ in range(n):
        context_w1 = start_token
        context_w2 = start_token
        words = []
        # probabilities of the next word given the current context,
        # i.e. P(w | context_w1, context_w2)
        probs = model.compute_probs(context_w1, context_w2)
        while True:
            # sample from the vocabulary according to this distribution
            next_word = random.choices(model.vocab, weights=probs, k=1)[0]
            if next_word == end_token or len(words) == max_len:
                break
            if next_word == start_token:
                # context is unchanged, so the same distribution is reused
                continue
            words.append(next_word)
            context_w1 = context_w2
            context_w2 = next_word
            # the context moved on: recompute the distribution
            probs = model.compute_probs(context_w1, context_w2)
        if len(words) > 0:
            sentences.append(" ".join(words))

    return "\n".join(sentences)

In [37]:
# add-k trigram model with a tiny smoothing constant (k = 1e-7)
model = trigram_LM_addk(k=0.0000001)
model.train(sentences_train)

Collecting unigram counts...
Constructing vocab...
Replacing with oov tokens in training data...
Re-collecting unigram counts...
Total num unigrams: 26362
Collecting bigram counts...
Total num bigrams: 369300
Collecting trigram counts...
Total num trigrams: 770624
Training complete!


In [40]:
# sample 5 sentences from the trained model and show them one per line
text = generate_text(model, n=5)
print(text)

a description of what their neighbors , studying interviews with members of the cake with chocolate <UNK> .
in spite of this abuse , he was sympathetic to the boarding ladder together .
grabski looked at lawrence with a heavy list to dictionary forms and shapes and made a world-shaking contribution to hollywood , who gave hope and courage to the private detective ( at the right tone of kindly <UNK> : term of office buildings in the air
the check , mercer collaborated with many early theologians , especially since he was `` stung '' with the <UNK> sexual coming of the dead people , their murderers dismembered the bodies are missing , for hosses had <UNK> from saratoga up to the editor of the new spirit , these blocks were set one within the system increases , and she hoped he was still there and waited for it is important here is truly majestic and an orchestra will take a look at penny .
he hated them too .


In [29]:
from multiprocessing import Pool
from functools import partial

In [30]:
def compute_log_prob(model, s):
    """Sum the log trigram probabilities of one sentence under ``model``.

    Args:
        model: object exposing ``vocab``, ``unk_token`` and
            ``tg_prob(word1, word2, word3)``; ``word2idx`` is used for fast
            membership tests when it is available.
        s: list of tokens (already padded with start/end tokens).

    Returns:
        (sum_log_probs, n): the summed natural-log probabilities and the
        number of trigrams scored. A sentence shorter than three tokens
        yields (0.0, 0).
    """
    # dict/set membership is O(1); testing against the vocab list would scan
    # the whole vocabulary three times per trigram
    known_words = getattr(model, 'word2idx', None)
    if known_words is None:
        known_words = set(model.vocab)
    unk = model.unk_token

    sum_log_probs = 0.0
    n = 0
    for w1,w2,w3 in zip(s[:-2], s[1:-1], s[2:]):
        # replace any oov token with <UNK>
        if w1 not in known_words:
            w1 = unk
        if w2 not in known_words:
            w2 = unk
        if w3 not in known_words:
            w3 = unk
        sum_log_probs += np.log(model.tg_prob(w1, w2, w3))
        n += 1
    return sum_log_probs, n

def compute_perplexity(model, test_sentences, num_procs=8):
    """Compute the per-trigram perplexity of ``model`` on ``test_sentences``.

    Sentences are scored in parallel with ``compute_log_prob``; the perplexity
    is exp(-(total log probability) / (total number of trigrams)).

    Args:
        model: trained language model passed through to ``compute_log_prob``.
        test_sentences: list of token lists.
        num_procs: number of worker processes in the pool.

    Returns:
        The perplexity as a float.
    """
    # bind the model so that workers only receive the sentence argument
    score_sentence = partial(compute_log_prob, model)

    # fan the sentences out over a pool of worker processes
    with Pool(num_procs) as pool:
        per_sentence = pool.map(score_sentence, test_sentences)

    total_log_prob = sum(entry[0] for entry in per_sentence)
    total_trigrams = sum(entry[1] for entry in per_sentence)

    # average negative log-likelihood per trigram, then exponentiate
    avg_neg_log_prob = total_log_prob * (-1/total_trigrams)
    return np.exp(avg_neg_log_prob)

### Now evaluate the add-k trigram model.

In [32]:
# train once: the n-gram counts do not depend on k, which is only read when
# probabilities are computed, so k can be swapped on the trained model
model = trigram_LM_addk(k=0.01)
model.train(sentences_train)

# compute perplexity on both the training and test data for different k values
kvals = [1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001]
for k in kvals:
    model.k = k
    pp_train = compute_perplexity(model, sentences_train)
    pp_test = compute_perplexity(model, sentences_test)

    print(f"\nk = {k}")
    print(f"Perplexity computed on training set: {pp_train:.3f}")
    print(f"Perplexity computed on test set: {pp_test:.3f}")

Collecting unigram counts...
Constructing vocab...
Replacing with oov tokens in training data...
Re-collecting unigram counts...
Total num unigrams: 26362
Collecting bigram counts...
Total num bigrams: 369300
Collecting trigram counts...
Total num trigrams: 770624
Training complete!

k = 1.0
Perplexity computed on training set: 6423.655
Perplexity computed on test set: 9944.344

k = 0.1
Perplexity computed on training set: 1164.288
Perplexity computed on test set: 5172.969

k = 0.01
Perplexity computed on training set: 177.566
Perplexity computed on test set: 3058.485

k = 0.001
Perplexity computed on training set: 36.185
Perplexity computed on test set: 2512.401

k = 0.0001
Perplexity computed on training set: 12.569
Perplexity computed on test set: 3249.263

k = 1e-05
Perplexity computed on training set: 8.315
Perplexity computed on test set: 6355.444


#### Note that the perplexity on the training set gets really low (below 10 for the smallest k), indicating that the model is able to fit the training data quite well. However, also note that the test set perplexity values are very large (much larger than what we saw for the bigram model), which indicates that the model may be severely overfitting to the training set and does not generalize well to the test set.

In [41]:
# train once; the interpolation weights are only read when probabilities are
# computed, so they can be swapped on the trained model
model = trigram_LM_interp()
model.train(sentences_train)

# sweep the trigram weight l3; the zerogram (0.01) and unigram (0.1) weights
# stay fixed and the bigram weight is set to 0.89 - l3 so the four sum to 1
lambda3_vals = [0.3, 0.5, 0.7, 0.8, 0.85]
for l3 in lambda3_vals:
    model.lmda = [0.01, 0.1, 0.89-l3 ,l3]
    pp_train = compute_perplexity(model, sentences_train)
    pp_test = compute_perplexity(model, sentences_test)

    print(f"\nlambdas = {model.lmda}")
    print(f"Perplexity computed on training set: {pp_train:.3f}")
    print(f"Perplexity computed on test set: {pp_test:.3f}")

Collecting unigram counts...
Constructing vocab...
Replacing with oov tokens in training data...
Re-collecting unigram counts...
Total num unigrams: 26362
Collecting bigram counts...
Total num bigrams: 369300
Collecting trigram counts...
Total num trigrams: 770624
Training complete!

lambdas = [0.01, 0.1, 0.5900000000000001, 0.3]
Perplexity computed on training set: 16.401
Perplexity computed on test set: 299.602

lambdas = [0.01, 0.1, 0.39, 0.5]
Perplexity computed on training set: 12.249
Perplexity computed on test set: 326.992

lambdas = [0.01, 0.1, 0.19000000000000006, 0.7]
Perplexity computed on training set: 9.949
Perplexity computed on test set: 398.139

lambdas = [0.01, 0.1, 0.08999999999999997, 0.8]
Perplexity computed on training set: 9.134
Perplexity computed on test set: 490.479

lambdas = [0.01, 0.1, 0.040000000000000036, 0.85]
Perplexity computed on training set: 8.782
Perplexity computed on test set: 600.247


#### With interpolation smoothing, the test set perplexities are substantially lower, indicating that the model does not overfit the training set as severely as it did with add-k smoothing. It seems to do a better job of generalizing to the test set.

In [46]:
# interpolated trigram model with weights [zerogram, unigram, bigram, trigram]
model = trigram_LM_interp(lmda=[0.01, 0.1, 0.39 ,0.5])
model.train(sentences_train)

Collecting unigram counts...
Constructing vocab...
Replacing with oov tokens in training data...


In [None]:
# sample 10 sentences from the interpolated model and show them one per line
text = generate_text(model, n=10)
print(text)

bathing the itching parts of the painting , unless it made only presidential appointees to concern its watershed lands from erosion and <UNK> showed that operations continued through the hymen .
about noon they came out durability peculiar to it , did the heated air , organized to furnish a statement was also its weakness .
`` from its dullest season .
she became intelligible .
yes '' area at least be safe to assume a center of the see considerably narrower ground of the open scaffold .
