## Kneser-Ney Smoothing

When building a language model to predict the next word given the previous N words, it helps to smooth the counts taken from the training data so that the model generalizes better. Kneser-Ney smoothing rests on four ideas:  
1. If we come across an unseen example, we should back off to (n-1)-grams. (e.g., if we are using a bigram model and we come across 'cat', a word we've never seen before, we should back off from the bigram model to a unigram model to predict the word that follows 'cat'.)
2. If we back off, we should do so "smoothly", blending the higher-order and lower-order estimates rather than switching abruptly between them.
3. We should reduce our counts slightly, because training data tends to overestimate how often the ngrams it contains really occur. (e.g., our training data may tell us that 'great' follows 'today is' 3 times, but scaled to a larger body of text the count would most likely be lower than 3.)
4. The fertility of a word is more useful than its ngram probability. (e.g., in an article about San Francisco, the unigram 'Francisco' will have a very high probability of being the next word under a unigram model. This is a gross overestimate, because 'Francisco' almost only ever follows 'San'. We correct it by setting each word's probability to the number of unique ngrams it completes / the total number of unique ngrams. One can see how this would drop the probability of 'Francisco' dramatically.)
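
Point 4 can be made concrete with a minimal sketch. The toy corpus and the names `toy_tokens`, `unigram_counts` and `continuation_counts` are assumptions made up for this illustration; they are not part of the notebook that follows.

```python
from collections import Counter

# Hypothetical toy corpus: 'francisco' is frequent but only ever follows 'san'.
toy_tokens = ("san francisco is foggy . i love san francisco . "
              "san francisco is hilly . the dog is happy . the cat is happy .").split()

# Raw unigram frequency: how often each word occurs.
unigram_counts = Counter(toy_tokens)

# Fertility (continuation count): how many distinct words precede each word.
distinct_bigrams = set(zip(toy_tokens, toy_tokens[1:]))
continuation_counts = Counter(second for _, second in distinct_bigrams)

print(unigram_counts["francisco"], continuation_counts["francisco"])  # 3 1
print(unigram_counts["is"], continuation_counts["is"])                # 4 3
```

'francisco' and 'is' occur about equally often, yet 'is' completes three distinct bigrams while 'francisco' completes only one, so a fertility-based estimate ranks 'is' as the far more plausible word after an unfamiliar context.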

## Kneser-Ney Algorithm

P(word | previous_n_words) =  
[max(number of times word follows the previous n words - sigma, 0) / number of times the previous n words are followed by any word]  

plus [(sigma / number of times the previous n words are followed by any word) * number of distinct words that follow the previous n words]  

multiplied by [number of distinct ngrams that word completes / total number of distinct ngrams]

The second bracket is the probability mass freed up by the discount. It multiplies only the third bracket (the word's fertility), and their product is added to the first bracket.
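
For readers who prefer symbols, the expression above can be written as follows for the bigram case used in this notebook. The notation is an assumption introduced here: $c(\cdot)$ is a raw count, $N_{1+}(w_{i-1}\,\bullet)$ is the number of distinct words that follow $w_{i-1}$, $N_{1+}(\bullet\, w_i)$ is the number of distinct words that precede $w_i$, and $N_{1+}(\bullet\,\bullet)$ is the number of distinct bigrams.

$$
P_{KN}(w_i \mid w_{i-1}) =
\frac{\max\big(c(w_{i-1} w_i) - \sigma,\ 0\big)}{\sum_{w} c(w_{i-1} w)}
+ \frac{\sigma}{\sum_{w} c(w_{i-1} w)} \, N_{1+}(w_{i-1}\,\bullet) \cdot
\frac{N_{1+}(\bullet\, w_i)}{N_{1+}(\bullet\,\bullet)}
$$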

To use Kneser-Ney we need to define 7 variables:  
1. Our vocabulary of distinct words
2. The number of ngrams that each word completes (not distinct)
3. The total number of ngrams (not distinct) that start with the previous n words
4. The number of distinct ngrams
5. Two distinct counts for each word:
   - 5a. The number of distinct ngrams that the word starts
   - 5b. The number of distinct ngrams that the word completes
6. Sigma, the discount that we subtract from each ngram count

In [11]:
import pandas as pd
from collections import defaultdict

In [20]:
#create some fake text to apply Kneser-Ney to
#I have built in a few things that normal preprocessing would
#take care of:
#<s> is an indicator of a new sentence
#<s/> is an indicator of the end of a sentence
#I have made everything lower case
text = "<s> the dog jumped down the stairs <s/> <s> the dog ran down the street <s/> "

#split the text into a list for ease of iterating
#to get necessary counts
text_split = text.split()

Let's build our variables one at a time.

## 1. Build Distinct Vocab

In [21]:
#build vocab of distinct words
vocab = []
for i in text.split():
    if i not in vocab:
        vocab.append(i)

In [22]:
#look at vocab
vocab

['<s>', 'the', 'dog', 'jumped', 'down', 'stairs', '<s/>', 'ran', 'street']

In [23]:
#length of vocab
len(vocab)

9

## 2. Ngrams that each word completes

In [41]:
#create dict to store ngrams
#this is not distinct
#key is word that completes ngram
#value is dict where key is prior word
#and value is number of times word
#completes (follows) prior word
ngrams_by_word = defaultdict(dict)

for idx, item in enumerate(text_split):
    if idx == 0:
        continue
    try:
        ngrams_by_word[item][text_split[idx - 1]] += 1 
    except KeyError: 
        ngrams_by_word[item][text_split[idx - 1]] = 1

In [42]:
#look at dict
ngrams_by_word

defaultdict(dict,
            {'<s/>': {'stairs': 1, 'street': 1},
             '<s>': {'<s/>': 1},
             'dog': {'the': 2},
             'down': {'jumped': 1, 'ran': 1},
             'jumped': {'dog': 1},
             'ran': {'dog': 1},
             'stairs': {'the': 1},
             'street': {'the': 1},
             'the': {'<s>': 2, 'down': 2}})
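
As a cross-check on the nested dictionary above, the same bigram counts can be gathered in one line. This is a sketch of an alternative layout, not the notebook's method: `bigram_counts` is a name introduced here, and it is keyed by `(previous_word, word)` tuples rather than nested by word.

```python
from collections import Counter

text = "<s> the dog jumped down the stairs <s/> <s> the dog ran down the street <s/> "
text_split = text.split()

# Pair every token with its successor and count each (previous, word) bigram.
bigram_counts = Counter(zip(text_split, text_split[1:]))

print(bigram_counts[("the", "dog")])     # 2
print(bigram_counts[("dog", "jumped")])  # 1
print(bigram_counts[("dog", "street")])  # 0, Counter returns 0 for unseen keys
```

A flat `Counter` returns 0 for unseen pairs, so no `try`/`except KeyError` is needed when looking up a bigram that never occurred.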

## 3. The total number of ngrams started by w_{i-n+1}, ..., w_{i-1}

In [74]:
#count how many (non-distinct) ngrams start with each word;
#the final token is skipped because nothing follows it
ngrams_started_by_wi_minus = defaultdict(int)

for word in text_split[:-1]:
    ngrams_started_by_wi_minus[word] += 1
    

In [75]:
ngrams_started_by_wi_minus

defaultdict(int,
            {'<s/>': 1,
             '<s>': 2,
             'dog': 2,
             'down': 2,
             'jumped': 1,
             'ran': 1,
             'stairs': 1,
             'street': 1,
             'the': 4})
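
These totals should equal the bigram counts summed over every possible next word, which is exactly the denominator in the formula. The sketch below checks that under the assumption that bigrams are stored as `(previous_word, word)` tuples; `bigram_counts` and `context_totals` are hypothetical names, not variables from the cells above.

```python
from collections import Counter

text = "<s> the dog jumped down the stairs <s/> <s> the dog ran down the street <s/> "
text_split = text.split()

bigram_counts = Counter(zip(text_split, text_split[1:]))

# Sum the bigram counts over every word that follows a given context word.
context_totals = Counter()
for (previous, _), count in bigram_counts.items():
    context_totals[previous] += count

# Same result as counting every token except the last one.
assert context_totals == Counter(text_split[:-1])
print(context_totals["the"], context_totals["<s/>"])  # 4 1
```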

## 4. The distinct list of ngrams

In [67]:
#create list to store ngrams:
distinct_ngrams = []

for i,j in enumerate(text_split):
    if i == 0:
        continue
    if (text_split[i-1], j) not in distinct_ngrams:
        distinct_ngrams.append((text_split[i-1], j))

In [68]:
#look at list
distinct_ngrams

[('<s>', 'the'),
 ('the', 'dog'),
 ('dog', 'jumped'),
 ('jumped', 'down'),
 ('down', 'the'),
 ('the', 'stairs'),
 ('stairs', '<s/>'),
 ('<s/>', '<s>'),
 ('dog', 'ran'),
 ('ran', 'down'),
 ('the', 'street'),
 ('street', '<s/>')]

## 5a. The distinct number of ngrams started by each word

In [72]:
ngrams_started_by_word = defaultdict(int)

for i in vocab:
    for j in distinct_ngrams:
        if j[0] == i:
            ngrams_started_by_word[i] += 1

In [73]:
#look
ngrams_started_by_word

defaultdict(int,
            {'<s/>': 1,
             '<s>': 1,
             'dog': 2,
             'down': 1,
             'jumped': 1,
             'ran': 1,
             'stairs': 1,
             'street': 1,
             'the': 3})

## 5b. The distinct number of ngrams each word completes

In [69]:
ngrams_completed_by_word = defaultdict(int)

for i in vocab:
    for j in distinct_ngrams:
        if j[1] == i:
            ngrams_completed_by_word[i] += 1

In [70]:
ngrams_completed_by_word

defaultdict(int,
            {'<s/>': 2,
             '<s>': 1,
             'dog': 1,
             'down': 2,
             'jumped': 1,
             'ran': 1,
             'stairs': 1,
             'street': 1,
             'the': 2})
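
Both distinct counts, and the fertility probability built from the second one, can also be derived from a set of bigrams. This is a compact sketch with names chosen here for illustration (`distinct_bigrams`, `started`, `completed`, `fertility`); it should agree with the dictionaries printed above.

```python
from collections import Counter

text = "<s> the dog jumped down the stairs <s/> <s> the dog ran down the street <s/> "
text_split = text.split()

# A set keeps each (previous, word) pair once, i.e. the distinct bigrams.
distinct_bigrams = set(zip(text_split, text_split[1:]))

# 5a: distinct ngrams started by each word; 5b: distinct ngrams completed.
started = Counter(first for first, _ in distinct_bigrams)
completed = Counter(second for _, second in distinct_bigrams)

# Fertility as a probability: share of all distinct bigrams a word completes.
fertility = {word: count / len(distinct_bigrams) for word, count in completed.items()}

print(started["the"], completed["the"], len(distinct_bigrams))  # 3 2 12
```

Because every distinct bigram is completed by exactly one word, the fertility values sum to 1 and form a valid backoff distribution.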

## 6. Sigma

In [71]:
#this one's easy because you set it to whatever you want!
#the value is usually between 0 and 1
sigma = .5
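
To get a feel for what sigma controls, the sketch below computes the share of probability mass that the discount sets aside for the backoff term: sigma * (distinct followers) / (total followers). The helper `reserved_mass` is a hypothetical name introduced here; the counts 3 and 4 are the values for 'the' from sections 5a and 3.

```python
def reserved_mass(sigma, distinct_followers, total_followers):
    """Probability mass moved from seen ngrams to the backoff term."""
    return sigma * distinct_followers / total_followers

# 'the' starts 4 ngrams in total, 3 of them distinct.
for candidate_sigma in (0.1, 0.5, 0.9):
    print(candidate_sigma, reserved_mass(candidate_sigma, 3, 4))
```

A larger sigma shifts more weight from the observed counts onto fertility, so the model trusts the training counts less.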

## Putting it all together

Before running the entire smoothing algorithm across all of the text, let's run it for a single context word and examine the results.

In [96]:
#given a context word (e.g. partial_text = 'dog'), compute
#for each word in our vocabulary the probability that it
#will follow the context

def kneser_ney(partial_text):
    #build dict to hold probabilities
    example_probabilities = {}

    #get denominator: total ngrams that start with the context
    #(.get avoids adding unseen contexts to the defaultdict)
    denominator = ngrams_started_by_wi_minus.get(partial_text, 0)

    #how much probability mass did we take away
    #when we subtracted sigma from each numerator
    if denominator == 0:
        #unseen context: back off entirely to fertility
        missing_prob = 1
    else:
        missing_prob = sigma / denominator * ngrams_started_by_word.get(partial_text, 0)

    for i in vocab:

        #get numerator: discounted count of i following the
        #context, floored at 0 as in the formula
        count = ngrams_by_word.get(i, {}).get(partial_text, 0)
        numerator = max(count - sigma, 0)

        #first part of algorithm
        if denominator == 0:
            part_one = 0
        else:
            part_one = numerator / denominator

        #backoff part / fertility
        fertility = ngrams_completed_by_word.get(i, 0) / len(distinct_ngrams)

        example_probabilities[i] = part_one + missing_prob * fertility

    return example_probabilities
    

In [98]:
test_dog = kneser_ney('dog')
test_dog

{'<s/>': 0.08333333333333333,
 '<s>': 0.041666666666666664,
 'dog': 0.041666666666666664,
 'down': 0.08333333333333333,
 'jumped': 0.2916666666666667,
 'ran': 0.2916666666666667,
 'stairs': 0.041666666666666664,
 'street': 0.041666666666666664,
 'the': 0.08333333333333333}
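
Two of these values can be checked by hand with the counts built earlier: 'dog' starts 2 ngrams in total, 2 of them distinct, and there are 12 distinct ngrams overall. 'jumped' follows 'dog' once and completes 1 distinct ngram; 'the' never follows 'dog' but completes 2 distinct ngrams.

$$
P(\text{jumped} \mid \text{dog}) = \frac{\max(1 - 0.5,\ 0)}{2} + \frac{0.5}{2} \cdot 2 \cdot \frac{1}{12} = 0.25 + 0.0417 \approx 0.2917
$$

$$
P(\text{the} \mid \text{dog}) = \frac{\max(0 - 0.5,\ 0)}{2} + \frac{0.5}{2} \cdot 2 \cdot \frac{2}{12} = 0 + 0.0833 \approx 0.0833
$$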

In [99]:
#check and make sure they sum to 1
#(floating-point rounding can leave the total off in the last digit)
total = 0
for i in test_dog:
    total += test_dog[i]
total

1.0000000000000002

In [100]:
#now that we've built that, let's say we
#get some input "The cat" and we want
#to predict what comes next; 'cat' never
#appears in our text, so we back off
test_cat = kneser_ney('cat')
test_cat

{'<s/>': 0.16666666666666666,
 '<s>': 0.08333333333333333,
 'dog': 0.08333333333333333,
 'down': 0.16666666666666666,
 'jumped': 0.08333333333333333,
 'ran': 0.08333333333333333,
 'stairs': 0.08333333333333333,
 'street': 0.08333333333333333,
 'the': 0.16666666666666666}

In [101]:
#check and make sure they sum to 1
total = 0
for i in test_cat:
    total += test_cat[i]
total

1.0
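
For reuse outside the notebook, the steps above can be folded into one self-contained function. The following is a sketch under the same assumptions as the walkthrough (bigram model, a single discount sigma, fertility as the backoff distribution); `build_kneser_ney` and its internal names are hypothetical and not taken from any library.

```python
from collections import Counter

def build_kneser_ney(tokens, sigma=0.5):
    """Return a function mapping a context word to {word: probability}."""
    bigram_counts = Counter(zip(tokens, tokens[1:]))
    vocab = list(dict.fromkeys(tokens))  # distinct words, first-seen order

    context_totals = Counter()      # ngrams started by each word (not distinct)
    distinct_followers = Counter()  # distinct ngrams started by each word
    completed = Counter()           # distinct ngrams completed by each word
    for (previous, word), count in bigram_counts.items():
        context_totals[previous] += count
        distinct_followers[previous] += 1
        completed[word] += 1
    n_distinct = len(bigram_counts)

    def probabilities(context):
        total = context_totals.get(context, 0)
        # Mass freed by the discount; an unseen context backs off completely.
        reserved = sigma * distinct_followers.get(context, 0) / total if total else 1.0
        result = {}
        for word in vocab:
            discounted = max(bigram_counts.get((context, word), 0) - sigma, 0)
            direct = discounted / total if total else 0.0
            result[word] = direct + reserved * completed.get(word, 0) / n_distinct
        return result

    return probabilities

tokens = "<s> the dog jumped down the stairs <s/> <s> the dog ran down the street <s/>".split()
probabilities = build_kneser_ney(tokens, sigma=0.5)

after_the = probabilities("the")
print(max(after_the, key=after_the.get))  # dog
```

Precomputing the counts once and returning a closure keeps each lookup cheap, and using `.get` everywhere means unseen contexts such as 'cat' never modify the stored counts.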