Introduction To NLP @ Esade BAIB

# NGram Language Models

In [1]:
# pandas dataframes and utils for counting
import pandas as pd
from collections import Counter
import logging

## Load Movie Reviews

In [2]:
movie_reviews_df = pd.read_csv('../data/movie_reviews_text.csv')
movie_reviews_df.sample(10)

Unnamed: 0,text
9911,cruel and inhuman cinematic punishment . . . s...
9757,the pretensions -- and disposable story -- sin...
10208,it has the right approach and the right openin...
4071,""" brown sugar "" admirably aspires to be more t..."
6781,the adventure doesn't contain half the excitem...
3791,"a burst of color , music , and dance that only..."
2398,"the ring is worth a look , if you don't demand..."
10483,is it really an advantage to invest such subtl...
5173,smarter than its commercials make it seem .
2435,"with one exception , every blighter in this pa..."


## Counting ngrams

In [3]:
import nltk
from nltk.util import ngrams
nltk.download('punkt_tab')


[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\rsast\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [4]:
from typing import List, Dict 
# Padding symbol placed before the first tokens of a text so that they still have a full-length context.
start_symbol = "_START_"
# Symbol appended after the last token of every text; predicting it means "the text ends here".
stop_symbol = "_STOP_"

def count_ngrams_up_to(n_max: int, texts: List[str], tokenizer=nltk.word_tokenize) -> Dict[tuple, Counter]:
    """Count next-word occurrences for every context of length 0 .. n_max-1.

    Each text is tokenized, terminated with ``stop_symbol`` and, for an
    n-gram of order ``k``, left-padded with ``k-1`` copies of ``start_symbol``.

    Args:
        n_max: highest n-gram order to count (3 counts uni-, bi- and trigrams).
        texts: iterable of raw text strings.
        tokenizer: callable turning a string into a list of tokens.

    Returns:
        A dict mapping a context tuple (``()`` for unigrams) to a Counter of
        the symbols seen immediately after that context.
    """
    counts_by_context = {}
    for text in texts:
        sentence = tokenizer(text) + [stop_symbol]
        for order in range(1, n_max + 1):
            # an n-gram of this order needs (order - 1) padding symbols up front
            padded = [start_symbol] * (order - 1) + sentence
            for gram in ngrams(padded, order):
                history, word = gram[:-1], gram[-1]
                bucket = counts_by_context.setdefault(history, Counter())
                bucket.update([word])
    return counts_by_context

In [5]:
mr_ngram_counts = count_ngrams_up_to(3, movie_reviews_df.text)

In [6]:
print("Number of ngram contexts: ", len(mr_ngram_counts))

Number of ngram contexts:  133415


In [7]:
# check most common words for "empty" context
mr_ngram_counts[()].most_common(10)

[('.', 14010),
 ('_STOP_', 10662),
 ('the', 10113),
 (',', 10037),
 ('a', 7314),
 ('and', 6201),
 ('of', 6062),
 ('to', 4234),
 ('is', 3559),
 ("'s", 3537)]

In [8]:
# check most common words following "steven"
mr_ngram_counts[('steven', )].most_common(10)

[('spielberg', 9),
 ('soderbergh', 8),
 ('seagal', 4),
 ('shainberg', 2),
 ('segal', 1)]

In [10]:
mr_ngram_counts[('movie', )].most_common(10)

[('.', 164),
 ('is', 152),
 ('that', 133),
 (',', 127),
 ("'s", 64),
 ('with', 34),
 ('about', 27),
 ('in', 26),
 ('has', 25),
 ('of', 19)]

## NGram Language Models

In [11]:
from typing import Dict, Set
from collections import Counter 
from random import choices

class NGramLanguageModel:
    """
    An n-gram language model with an optional back-off discount
    (in the spirit of Katz back-off).

    n is the order of the ngram model, i.e. n=3 for a trigram model.

    ngram_counts is a dictionary of word counts for each ngram context:
       the key is a tuple of preceding tokens (the empty tuple for the
       unigram case) and the value is a Counter of the words that followed.
       Back-off looks up the shorter contexts in this same dictionary.

    back_off_discount is the amount subtracted from the count of every
       observed word; the freed probability mass is handed to words that
       were not observed in the context, in proportion to their probability
       under the next lower-order model. A falsy value (None or 0) disables
       back-off: the raw relative frequencies are used without smoothing.

    NOTE(review): the backed-off probabilities are not renormalised over the
       unobserved words, and the mass discounted at the unigram level has no
       lower-order model to go to, so with a discount the returned values can
       sum to less than 1. Confirm whether this is acceptable for your use.
    """
    def __init__(self, n: int, ngram_counts: Dict[tuple, Counter], back_off_discount: float):
        self.n = n
        self.back_off_discount = back_off_discount
        self.ngram_counts = ngram_counts
        # The vocabulary is every symbol counted under the empty (unigram) context.
        self.vocab = set(self.ngram_counts.get((), {}).keys())

    def p_next_word(self, tokens: tuple, top: int = None, n=None, _vocab: Set[str] = None) -> dict[str, float]:
        """Returns the probability distribution over the next word given the tokens.

        Args:
            tokens: the preceding tokens; only the last ``n-1`` are used and
                the sequence is left-padded with ``start_symbol`` if shorter.
            top: if given, keep only the ``top`` most probable words.
            n: order to use for this call (defaults to the model order); used
                internally when backing off to a lower-order model.
            _vocab: internal; restricts the words that are scored (defaults
                to the full vocabulary).

        Returns:
            A dict from word to probability, ordered from most to least
            probable. Empty when the order is 0 or there is nothing to score.
        """
        logger = logging.getLogger('p_next_word')
        if n is None:
            n = self.n
        if n <= 0:
            # no model below the unigram level: nothing to back off to
            return {}
        ctx_len = n-1   # len of context is the order of the ngram model minus one
        if ctx_len == 0:
            ctx = ()  # take empty context
        else:
            if len(tokens) < ctx_len:
                starts = [start_symbol] * (ctx_len - len(tokens))
                tokens = starts + list(tokens)
            ctx = tuple(tokens[-ctx_len:])
            assert(len(ctx)==ctx_len)
        # take counts of context, or an empty Counter if the context was never seen
        # (this also covers a counts dictionary that has no empty-context entry)
        ctx_counts = self.ngram_counts.get(ctx, Counter())
        total_count = sum(ctx_counts.values())
        logger.debug(f"context_length={ctx_len} context={ctx} total_count={total_count} observed_words={len(ctx_counts)}")

        if _vocab is None:
            _vocab = self.vocab
        if self.back_off_discount:
            _vocab_observed = set(_vocab).intersection(set(ctx_counts.keys()))
            _vocab_unobserved = set(_vocab).difference(_vocab_observed)
            # observed words: discounted relative frequency
            # (if total_count is 0 there are no observed words, so no division happens)
            word_prob_observed = [(word, (ctx_counts.get(word, 0) - self.back_off_discount) / total_count) for word in _vocab_observed]
            total_discount = self.back_off_discount*len(_vocab_observed)
            if total_count == 0:
                # unseen context: rely entirely on the lower-order model
                mass_discount = 1
            else:
                mass_discount = total_discount / total_count
            logger.debug(f"backing-off to context len {n-1} for {len(_vocab_unobserved)} unobserved words")
            # unobserved words: lower-order probability scaled by the discounted mass
            word_prob_unobserved = [(w, p*mass_discount) for w, p in self.p_next_word(tokens, n=n-1, _vocab=_vocab_unobserved).items()]
            word_prob = word_prob_observed + word_prob_unobserved
        elif total_count == 0:
            # no smoothing and unseen context: every word gets probability 0
            word_prob = [(word, 0) for word in _vocab]
        else:
            word_prob = [(word, ctx_counts.get(word, 0) / total_count) for word in _vocab]
        word_prob = sorted(word_prob, key=lambda wp: wp[1], reverse=True)
        if top is not None:
            word_prob = word_prob[:top]
        return {w: p for w, p in word_prob}

def prob_text(ngram_model: NGramLanguageModel, text: str, tokenizer=nltk.word_tokenize) -> float:
    """Probability the model assigns to ``text`` followed by the stop symbol.

    The text is tokenized, ``stop_symbol`` is appended, and the conditional
    probability of each token given all tokens before it is multiplied
    together. A token missing from the model's distribution contributes 0.
    """
    log = logging.getLogger('prob_text')
    sequence = tokenizer(text) + [stop_symbol]
    log.debug(f"Text tokens: {sequence}")
    probability = 1
    for position, token in enumerate(sequence):
        history = sequence[:position]
        distribution = ngram_model.p_next_word(tokens=tuple(history))
        p_token = distribution.get(token, 0)
        log.debug(f"p({token} | {history}) = {p_token} ")
        probability = probability * p_token
    return probability


def text_generator(ngram_model: NGramLanguageModel, tokens: List[str] = None, randomize: bool = False, limit=1000):
    """Extend a token prefix with words drawn from the model until it stops.

    Args:
        ngram_model: model whose ``p_next_word`` supplies the next-word distribution.
        tokens: optional prefix to continue. It is copied, so the caller's
            sequence is left untouched.
        randomize: if True, sample the next word according to its probability;
            otherwise always take the most probable word.
        limit: generation stops once the sequence (prefix included) holds at
            least this many tokens.

    Returns:
        A new list with the prefix followed by the generated tokens; the last
        one is ``stop_symbol`` unless the limit was hit first.

    Raises:
        IndexError: if the model returns an empty distribution.
        ValueError: in ``randomize`` mode, if all probabilities are zero
            (raised by ``random.choices``).
    """
    logger = logging.getLogger('text_generator')
    # work on a copy so the list passed in by the caller is never mutated
    tokens = [] if tokens is None else list(tokens)
    next_token = None
    while next_token != stop_symbol:
        probs = ngram_model.p_next_word(tokens)
        if randomize:
            next_token = choices(list(probs.keys()), list(probs.values()), k=1)[0]
        else:
            next_token = list(probs)[0]  # distributions are ordered, the first word is the most likely one
        tokens.append(next_token)
        # ">=" rather than "==": a prefix already longer than the limit must still stop the loop
        if len(tokens) >= limit:
            logger.info(f"Reached limit of {limit} tokens!")
            break
    return tokens



### Create a few models with the ngram counts

In [12]:
lm_bigram_smoothed = NGramLanguageModel(2, mr_ngram_counts, back_off_discount=0.1)

In [13]:
lm_bigram_rough = NGramLanguageModel(2, mr_ngram_counts, back_off_discount=0)

In [14]:
lm_trigram = NGramLanguageModel(3, mr_ngram_counts, back_off_discount=0.1)

In [15]:
logging.getLogger("p_next_word").setLevel(logging.DEBUG)
lm_bigram_rough.p_next_word(('this',), top=10)

{'is': 0.17117117117117117,
 'movie': 0.08731808731808732,
 'film': 0.07207207207207207,
 'one': 0.03395703395703396,
 'time': 0.013167013167013167,
 'year': 0.012474012474012475,
 '.': 0.011088011088011088,
 'story': 0.009009009009009009,
 ',': 0.009009009009009009,
 'picture': 0.007623007623007623}

In [16]:
lm_bigram_smoothed.p_next_word(('steven',), top=10)

{'spielberg': 0.37083333333333335,
 'soderbergh': 0.32916666666666666,
 'seagal': 0.1625,
 'shainberg': 0.07916666666666666,
 'segal': 0.0375,
 '.': 0.0012120766958466913,
 '_STOP_': 0.0009224220389473043,
 'the': 0.0008749249043482111,
 ',': 0.0008683497090303039,
 'a': 0.0006327673820479172}

In [17]:
lm_trigram.p_next_word(('a', 'movie', 'of', 'steven',), top=10)

{'soderbergh': 0.475,
 'spielberg': 0.475,
 'seagal': 0.008125,
 'shainberg': 0.003958333333333334,
 'segal': 0.001875,
 '.': 3.636230087540075e-05,
 '_STOP_': 2.7672661168419136e-05,
 'the': 2.624774713044634e-05,
 ',': 2.6050491270909125e-05,
 'a': 1.8983021461437524e-05}

### Hack: reverse the language! 

In [18]:
reversed_ngram_counts = count_ngrams_up_to(3, [' '.join(reversed(nltk.word_tokenize(t))) for t in movie_reviews_df.text])

In [19]:
reversed_ngram_counts[('spielberg', 'steven')]

Counter({'of': 2,
         'even': 2,
         'like': 1,
         ',': 1,
         '_STOP_': 1,
         'movie': 1,
         'realizing': 1})

In [20]:
reversed_lm = NGramLanguageModel(n=3, ngram_counts=reversed_ngram_counts, back_off_discount=0)

In [21]:
reversed_lm.p_next_word(('soderbergh', 'steven'), top=10)

{'of': 0.25,
 '_STOP_': 0.25,
 'if': 0.125,
 'and': 0.125,
 'in': 0.125,
 ',': 0.125,
 'fruit': 0.0,
 'carlin': 0.0,
 'sharply': 0.0,
 'cheeky': 0.0}

In [22]:
reversed_lm.p_next_word(('spielberg', 'steven'), top=10)

{'of': 0.2222222222222222,
 'even': 0.2222222222222222,
 'like': 0.1111111111111111,
 'movie': 0.1111111111111111,
 '_STOP_': 0.1111111111111111,
 ',': 0.1111111111111111,
 'realizing': 0.1111111111111111,
 'fruit': 0.0,
 'carlin': 0.0,
 'sharply': 0.0}

In [23]:
lm_trigram.p_next_word(('and', 'steven'), top=10)

{'soderbergh': 0.9,
 'spielberg': 0.037083333333333336,
 'seagal': 0.01625,
 'shainberg': 0.007916666666666667,
 'segal': 0.00375,
 '.': 9.696613566773531e-05,
 '_STOP_': 7.379376311578435e-05,
 'the': 6.999399234785689e-05,
 ',': 6.946797672242431e-05,
 'a': 5.062139056383338e-05}

In [24]:
lm_trigram.p_next_word(('even', 'steven'), top=10)

{'spielberg': 0.95,
 'soderbergh': 0.016458333333333335,
 'seagal': 0.008125,
 'shainberg': 0.003958333333333334,
 'segal': 0.001875,
 '.': 4.8483067833867656e-05,
 '_STOP_': 3.6896881557892176e-05,
 'the': 3.4996996173928444e-05,
 ',': 3.4733988361212156e-05,
 'a': 2.531069528191669e-05}

## Computations with LMs

### Compute p(x)

In [25]:
logging.getLogger("p_next_word").setLevel(logging.INFO)
logging.getLogger("prob_text").setLevel(logging.DEBUG)

prob_text(ngram_model=lm_bigram_rough, text="i like this movie, woody")

0.0

In [26]:
prob_text(ngram_model=lm_bigram_smoothed, text="i like this movie, woody")

2.030024504174122e-16

### Completing a text

In [27]:
text_generator(ngram_model=lm_bigram_rough)

['the', 'film', 'is', 'a', 'movie', '.', '_STOP_']

In [28]:
text_generator(ngram_model=lm_bigram_smoothed)

['the', 'film', 'is', 'a', 'movie', '.', '_STOP_']

In [29]:
text_generator(ngram_model=lm_trigram)

['the',
 'film',
 'is',
 'a',
 'movie',
 'that',
 "'s",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',
 "'re",
 'not',
 'a',
 'bad',
 'sign',
 'when',
 'you',


In [30]:
text_generator(ngram_model=lm_bigram_rough, tokens=['the', 'movie', 'by', 'steven'], randomize=True)

['the',
 'movie',
 'by',
 'steven',
 'seagal',
 'pessimists',
 ':',
 'two',
 'hours',
 'represents',
 'two',
 'literary',
 'purists',
 'might',
 'otherwise',
 'calculated',
 '.',
 '_STOP_']

In [31]:
logging.getLogger("prob_text").setLevel(logging.INFO)
for i in range(10):
    text = text_generator(lm_trigram, tokens=['i', 'like', 'movies', 'by', 'steven'], randomize=True)
    prob = prob_text(lm_trigram, ' '.join(text))
    print(prob, text)

6.882028267321108e-34 ['i', 'like', 'movies', 'by', 'steven', 'spielberg', 'brings', 'us', 'right', 'into', 'ideas', 'and', 'fanciful', 'sexist', 'or', 'mean-spirited', '.', '_STOP_']
2.02277078940983e-21 ['i', 'like', 'movies', 'by', 'steven', 'spielberg', "'s", '1993', 'classic', '.', '_STOP_']
1.2303750683239652e-20 ['i', 'like', 'movies', 'by', 'steven', 'segal', '.', '_STOP_']
1.7095953781368685e-42 ['i', 'like', 'movies', 'by', 'steven', 'soderbergh', "'s", 'solaris', 'so', 'much', 'that', 'is', 'life', 'affirming', 'and', 'heartbreaking', 'to', 'witness', 'the', 'conflict', 'from', 'the', 'french', 'revolution', 'from', 'stark', 'desert', 'to', 'gorgeous', 'beaches', '.', 'the', 'interviews', 'that', 'follow', ',', 'iwai', "'s", 'vaunted', 'empathy', '.', '_STOP_']
1.2303750683239652e-20 ['i', 'like', 'movies', 'by', 'steven', 'segal', '.', '_STOP_']
5.8471700181716176e-62 ['i', 'like', 'movies', 'by', 'steven', 'spielberg', 'has', 'dreamed', 'up', 'such', 'blatant', 'and', 'sic

###  Fill in the gap

We now implement a prediction method to fill in the gap: the input is a sentence with a gap to be filled with a word. We represent it by a prefix, a suffix, and a list of words to choose from. The method will return a distribution over the words of choice. 

This implementation is very slow whenever the number of choices is very large. More efficient implementations are possible, but are significantly more involved to code. 

In [32]:
from tqdm import tqdm
def fill_in_the_gap(ngram_model: NGramLanguageModel, prefix: str, suffix: str, choices: List[str] = None, top=None):
    """Score candidate words for the gap between ``prefix`` and ``suffix``.

    Every candidate is placed in the gap, the whole sentence is scored with
    ``prob_text``, and the scores are normalised over all candidates.

    Args:
        ngram_model: language model used to score the completed sentences.
        prefix: text before the gap.
        suffix: text after the gap.
        choices: candidate words; defaults to the model's full vocabulary.
        top: if given (and non-zero), return only the ``top`` best candidates.
            Normalisation is still done over all candidates.

    Returns:
        A dict from candidate word to normalised probability, ordered from
        most to least probable. If every candidate scores 0 (e.g. with an
        unsmoothed model), all returned probabilities are 0.0.
    """
    if choices is None:
        choices = ngram_model.vocab
    word_prob = []
    for word in tqdm(choices):
        p_text = prob_text(ngram_model=ngram_model, text=prefix + " " + word + " " + suffix)
        word_prob.append((word, p_text))
    sum_probs = sum(p for _, p in word_prob)
    word_prob = sorted(word_prob, key=lambda wp: wp[1], reverse=True)
    if top:
        # honour the requested size instead of a hard-coded 10
        word_prob = word_prob[:top]
    if sum_probs == 0:
        # nothing to normalise by: avoid a ZeroDivisionError and report zero for every candidate
        return {w: 0.0 for w, _ in word_prob}
    return {w: p/sum_probs for w, p in word_prob}


In [33]:
fill_in_the_gap(lm_bigram_smoothed, prefix='steven', suffix='was director of the movie.', choices=['spielberg', 'allen', 'segal', 'soderbergh'])

100%|██████████| 4/4 [00:01<00:00,  2.17it/s]


{'spielberg': 0.48654978137336197,
 'soderbergh': 0.4201293954119751,
 'segal': 0.09137451752918674,
 'allen': 0.0019463056854763252}

In [34]:
p_next_steven = lm_trigram.p_next_word(tokens=['steven'], top=10)
p_next_steven

{'soderbergh': 0.6333333333333333,
 'spielberg': 0.3,
 'seagal': 0.010833333333333334,
 'shainberg': 0.005277777777777777,
 'segal': 0.0025,
 '.': 4.848306783386766e-05,
 '_STOP_': 3.6896881557892176e-05,
 'the': 3.499699617392845e-05,
 ',': 3.473398836121216e-05,
 'a': 2.5310695281916696e-05}

In [35]:
fill_in_the_gap(lm_bigram_smoothed, prefix='steven', suffix='was director of the movie.', choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:03<00:00,  2.67it/s]


{'spielberg': 0.33773583108484606,
 'soderbergh': 0.2916304887079038,
 'seagal': 0.17784465394770904,
 'shainberg': 0.1071213486715338,
 'segal': 0.06342711434498713,
 '_STOP_': 0.015601751503639159,
 ',': 0.004368837561037187,
 '.': 0.0013558499704477795,
 'the': 0.0004953282869056742,
 'a': 0.00041879592099058855}

## Adding more data 

In this section we add additional texts and recompute the counts and recreate the language models. 

We want to show that ngram language models can remember textual patterns as long as these are within the window of "n" words of the ngram model. 

We will show that simple LMs are able to remember which director directed which movie: we first add this data to the text collection, then pose questions in fill-in-the-gap form. 

We will see that bigram, trigram, or even six-gram models may or may not remember these textual patterns, depending on whether the pattern fits within the model's window of "n" words.

In [36]:
# here the "knowledge" between the last name of director and movie title is within 3 words
more_texts = [
    'steven spielberg directed jaws',
    'steven spielberg directed et',
    'steven spielberg directed ai',
    'steven soderbergh directed traffic'
]

In [38]:
new_ngram_counts = count_ngrams_up_to(3, more_texts + list(movie_reviews_df.text))
new_trigram_model = NGramLanguageModel(n=3, ngram_counts=new_ngram_counts, back_off_discount=0.5)
new_bigram_model = NGramLanguageModel(n=2, ngram_counts=new_ngram_counts, back_off_discount=0.5)

In [39]:
fill_in_the_gap(ngram_model=new_trigram_model, prefix="steven", suffix="directed ai", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:03<00:00,  3.29it/s]


{'spielberg': 0.994923484739286,
 'soderbergh': 0.005076140228261662,
 'seagal': 2.16764651023116e-07,
 'shainberg': 9.024487511982789e-08,
 '_STOP_': 3.996722567478841e-08,
 ',': 1.5098132993922466e-08,
 'the': 6.342054757884297e-09,
 'a': 5.362073048926721e-09,
 '.': 1.2534395821493066e-09,
 'segal': 0.0}

In [40]:
fill_in_the_gap(ngram_model=new_trigram_model, prefix="steven", suffix="directed traffic", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:02<00:00,  3.38it/s]


{'soderbergh': 0.9552232786929382,
 'spielberg': 0.04477609118873148,
 'seagal': 3.6420149559667644e-07,
 'shainberg': 1.5162674510555511e-07,
 '_STOP_': 6.715173944138865e-08,
 ',': 2.5367432333409817e-08,
 'the': 1.065573107550253e-08,
 'a': 9.009195063403056e-09,
 '.': 2.105991767127051e-09,
 'segal': 0.0}

In [41]:
fill_in_the_gap(ngram_model=new_bigram_model, prefix="steven", suffix="directed traffic", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:02<00:00,  4.76it/s]


{'spielberg': 0.8365228375749165,
 'soderbergh': 0.1630062288238869,
 'seagal': 0.00022019780264424577,
 'shainberg': 0.00011667623828422371,
 ',': 5.6933766430629166e-05,
 'segal': 4.861509928509321e-05,
 '_STOP_': 2.1530426428643396e-05,
 'the': 3.4164779032233527e-06,
 'a': 2.8885597470367713e-06,
 '.': 6.752304732334426e-07}

In [42]:
fill_in_the_gap(ngram_model=new_bigram_model, prefix="steven", suffix="directed ai", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:02<00:00,  4.81it/s]


{'spielberg': 0.8365228375749169,
 'soderbergh': 0.16300622882388696,
 'seagal': 0.00022019780264424585,
 'shainberg': 0.00011667623828422374,
 ',': 5.693376643062918e-05,
 'segal': 4.861509928509323e-05,
 '_STOP_': 2.1530426428643402e-05,
 'the': 3.4164779032233535e-06,
 'a': 2.8885597470367718e-06,
 '.': 6.752304732334427e-07}

In [44]:
# here the "knowledge" between the last name of director and movie title is within 6 words
even_more_texts = [
    'steven spielberg was the director of jaws',
    'steven spielberg was the director of et',
    'steven spielberg was the director of ai',
    'steven soderbergh was the director of traffic'
]

In [45]:
ngram_counts_v3 = count_ngrams_up_to(6, more_texts + even_more_texts + list(movie_reviews_df.text))
trigram_model_v3 = NGramLanguageModel(n=3, ngram_counts=ngram_counts_v3, back_off_discount=0.1)
sixgram_model = NGramLanguageModel(n=6, ngram_counts=ngram_counts_v3, back_off_discount=0.1)

In [46]:
fill_in_the_gap(ngram_model=sixgram_model, prefix="steven", suffix="was the director of ai", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:07<00:00,  1.37it/s]


{'spielberg': 0.9817584686104596,
 'soderbergh': 0.018241531389540364,
 'seagal': 0.0,
 'shainberg': 0.0,
 'segal': 0.0,
 '.': 0.0,
 '_STOP_': 0.0,
 'the': 0.0,
 ',': 0.0,
 'a': 0.0}

In [47]:
fill_in_the_gap(ngram_model=sixgram_model, prefix="steven", suffix="was the director of traffic", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:07<00:00,  1.34it/s]


{'soderbergh': 0.9083374655992839,
 'spielberg': 0.09166253440071612,
 'seagal': 0.0,
 'shainberg': 0.0,
 'segal': 0.0,
 '.': 0.0,
 '_STOP_': 0.0,
 'the': 0.0,
 ',': 0.0,
 'a': 0.0}

In [48]:
fill_in_the_gap(ngram_model=trigram_model_v3, prefix="steven", suffix="was the director of traffic", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:04<00:00,  2.39it/s]


{'spielberg': 0.8032306918578586,
 'soderbergh': 0.19676930328121087,
 'seagal': 2.0496918378130644e-09,
 '_STOP_': 1.6962856196816857e-09,
 'shainberg': 9.70037309323985e-10,
 'the': 5.381371959803274e-11,
 'a': 4.549904935505647e-11,
 ',': 3.111745503040745e-11,
 '.': 1.4485752752514247e-11,
 'segal': 0.0}

In [49]:
fill_in_the_gap(ngram_model=trigram_model_v3, prefix="steven", suffix="was the director of ai", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:04<00:00,  2.39it/s]


{'spielberg': 0.8032306918578586,
 'soderbergh': 0.19676930328121087,
 'seagal': 2.0496918378130644e-09,
 '_STOP_': 1.6962856196816857e-09,
 'shainberg': 9.70037309323985e-10,
 'the': 5.381371959803274e-11,
 'a': 4.549904935505647e-11,
 ',': 3.111745503040745e-11,
 '.': 1.4485752752514247e-11,
 'segal': 0.0}

In [50]:
fill_in_the_gap(ngram_model=sixgram_model, prefix="steven", suffix="was the director of ai", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:07<00:00,  1.38it/s]


{'spielberg': 0.9817584686104596,
 'soderbergh': 0.018241531389540364,
 'seagal': 0.0,
 'shainberg': 0.0,
 'segal': 0.0,
 '.': 0.0,
 '_STOP_': 0.0,
 'the': 0.0,
 ',': 0.0,
 'a': 0.0}

In [51]:
fill_in_the_gap(ngram_model=sixgram_model, prefix="steven", suffix="was the director of ai", choices=list(p_next_steven.keys()))

100%|██████████| 10/10 [00:07<00:00,  1.37it/s]


{'spielberg': 0.9817584686104596,
 'soderbergh': 0.018241531389540364,
 'seagal': 0.0,
 'shainberg': 0.0,
 'segal': 0.0,
 '.': 0.0,
 '_STOP_': 0.0,
 'the': 0.0,
 ',': 0.0,
 'a': 0.0}