In [1]:
import re
import random
from collections import defaultdict, Counter
import numpy as np
import nltk
from nltk.corpus import brown
nltk.download('brown')

[nltk_data] Downloading package brown to /home/tanzid/nltk_data...
[nltk_data]   Package brown is already up-to-date!


True

#### This time, we will train on the NLTK Brown corpus, keeping all the punctuation but still applying lowercase folding.

In [2]:
# get pre-tokenized sentences from the Brown corpus reader and store them in a list
sentences = list(brown.sents())

In [3]:
# make everything lowercase and add start and end tokens
start_token = '<s>'
end_token = '</s>'
sentences_tokenized = [[start_token]+[w.lower() for w in s]+[end_token] for s in sentences]

# now we split the data into train and test sentences:
# 10% of the sentence indices are drawn at random (without replacement) for the test set
num_sent = len(sentences_tokenized)
num_test = int(0.1 * num_sent)
test_idx = random.sample(range(num_sent), num_test)
# membership is checked once per sentence below; a set makes each check O(1)
# instead of scanning the whole list of sampled indices (test_idx itself stays a list)
test_idx_set = set(test_idx)

sentences_train = []
sentences_test = []
for i, sent in enumerate(sentences_tokenized):
    if i in test_idx_set:
        sentences_test.append(sent)
    else:
        sentences_train.append(sent)

print(f"Number of training sentences: {len(sentences_train)}")
print(f"Number of test sentences: {len(sentences_test)}")

Number of training sentences: 51606
Number of test sentences: 5734


In [4]:
# print the first five training sentences to eyeball the tokenization
for i in range(5):
    print(sentences_train[i])

['<s>', 'the', 'jury', 'further', 'said', 'in', 'term-end', 'presentments', 'that', 'the', 'city', 'executive', 'committee', ',', 'which', 'had', 'over-all', 'charge', 'of', 'the', 'election', ',', '``', 'deserves', 'the', 'praise', 'and', 'thanks', 'of', 'the', 'city', 'of', 'atlanta', "''", 'for', 'the', 'manner', 'in', 'which', 'the', 'election', 'was', 'conducted', '.', '</s>']
['<s>', 'the', 'september-october', 'term', 'jury', 'had', 'been', 'charged', 'by', 'fulton', 'superior', 'court', 'judge', 'durwood', 'pye', 'to', 'investigate', 'reports', 'of', 'possible', '``', 'irregularities', "''", 'in', 'the', 'hard-fought', 'primary', 'which', 'was', 'won', 'by', 'mayor-nominate', 'ivan', 'allen', 'jr.', '.', '</s>']
['<s>', '``', 'only', 'a', 'relative', 'handful', 'of', 'such', 'reports', 'was', 'received', "''", ',', 'the', 'jury', 'said', ',', '``', 'considering', 'the', 'widespread', 'interest', 'in', 'the', 'election', ',', 'the', 'number', 'of', 'voters', 'and', 'the', 'size'

In [None]:
class trigram_LM_addk():

    def __init__(self, count_threshold=2, k=0.1):
        self.count_threshold = count_threshold 
        self.k = k
        self.bigram_counts = None
        self.unigram_counts = None
        self.vocab = None
        self.word2idx = None
        self.bigram_probs = None
        self.num_sentences = None
        self.unk_token = '<UNK>'
        self.start_token = '<s>'        
        self.end_token = '</s>'

    def train(self, sentences):
        self.num_sentences = len(sentences)
        self.vocab, self.unigram_counts, self.bigram_counts = self.get_counts(sentences)
        self.vocab = list(self.unigram_counts.keys())
        self.word2idx = {word:i for i,word in enumerate(self.vocab)}
        self.compute_probs()
        print("Training complete!")         

    def get_counts(self, sentences):
        # collect unigram counts 
        print("Collecting unigram counts...")
        unigram_counts = Counter()
        for s in sentences:
            for word in s:
                unigram_counts[word] += 1 
        
        # remove all words that have count below the threshold    
        print("Constructing vocab...")     
        for w in list(unigram_counts.keys()):
            if unigram_counts[w] < self.count_threshold:
                unigram_counts.pop(w)
        # construct vocab 
        vocab = [self.unk_token] + sorted(list(unigram_counts.keys()))            
        
        # replace all oov tokens in training sentences with <UNK>
        print("Replacing with oov tokens in training data...")
        sentences_unk = []
        for s in sentences:
            sent = []
            for word in s:
                if word in vocab:
                    sent.append(word)
                else:
                    sent.append(self.unk_token)
            sentences_unk.append(sent)            

        # re-collect unigram counts 
        print("Re-collecting unigram counts...")
        unigram_counts = Counter()
        for s in sentences_unk:
            for word in s:
                unigram_counts[word] += 1 
        print(f"Total num unigrams: {len(unigram_counts)}")        

        # collect bigram counts
        print("Collecting bigram counts...")
        bigram_counts = Counter()
        for s in sentences_unk:
            for bigram in zip(s[:-1], s[1:]):
                bigram_counts[bigram] += 1     
        print(f"Total num bigrams: {len(bigram_counts)}")        

        # add an extra start token at the beginning of all sentences
        sentences_unk = [[self.start_token]+s for s in sentences_unk]        
        # collect trigram counts
        print("Collecting trigram counts...")
        trigram_counts = Counter()
        for s in sentences_unk:
            for trigram in zip(s[:-2], s[1:-1], s[2:]):
                trigram_counts[trigram] += 1     
        print(f"Total num trigrams: {len(trigram_counts)}")                

        return vocab, unigram_counts, bigram_counts, trigram_counts
    
    def compute_probs(self):
        print("Computing trigram probabilities...")
        trigram_probs = Counter()
        for word1 in self.vocab:
            probs = []
            for word2 in self.vocab:
                # compute P(word2|word1)
                p = self.bg_prob(word1, word2)
                probs.append(p)
            trigram_probs[word1] = probs 
        self.trigram_probs = trigram_probs   

    def bg_prob(self, word1, word2):
        # addk probability
        p = (self.bigram_counts[(word1, word2)] + self.k) / (self.unigram_counts[word1] + self.k*len(self.vocab)) 
        return p        
    
    def tg_prob(self, word1, word2, word3):
        # addk probability
        p = (self.trigram_counts[(word1,word2,word3)] + self.k) / (self.bigram_counts[(word1,word2)] + self.k*len(self.vocab)) 
        return p        