In [29]:
import spacy
from datasets import load_dataset
from collections import defaultdict, Counter
import math

# spaCy pipeline shared by the whole notebook; preprocess() below calls it to
# tokenize text and read each token's `lemma_` / `is_alpha` attributes.
nlp = spacy.load("en_core_web_sm")
# Training split of the "wikitext-2-raw-v1" configuration of the "wikitext"
# dataset. Each record is indexed with the "text" key further down.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

def preprocess(text):
    """Convert raw text into a list of lemmas, keeping alphabetic tokens only.

    The text is run through the module-level spaCy pipeline ``nlp``. Tokens
    whose ``is_alpha`` attribute is false are discarded.

    Args:
        text: The raw string to analyze.

    Returns:
        list[str]: ``lemma_`` of every kept token, in the order the tokens
        appear in the parsed document. Empty when no token is alphabetic.
    """
    lemmas = []
    for token in nlp(text):
        if token.is_alpha:
            lemmas.append(token.lemma_)
    return lemmas

# Preprocess every dataset record once, up front. `documents[i]` is the list of
# lemmas for the i-th record (an empty list when preprocess() keeps no tokens).
documents = [preprocess(record["text"]) for record in dataset]


In [95]:
# --- Raw counts over the preprocessed corpus --------------------------------
unigram_counts = Counter()
# bigram_counts[prev][word] == how many times `word` directly followed `prev`.
bigram_counts = defaultdict(Counter)

for doc in documents:
    unigram_counts.update(doc)
    # Pair each token with its predecessor; the first token of a document is
    # paired with the artificial "<START>" marker.
    for prev_word, word in zip(["<START>"] + doc, doc):
        bigram_counts[prev_word][word] += 1

# Every token was counted exactly once above, so the counter's values add up
# to the total number of tokens in the corpus.
total_unigrams = sum(unigram_counts.values())

# --- Relative-frequency estimates -------------------------------------------
unigram_probs = {
    token: occurrences / total_unigrams
    for token, occurrences in unigram_counts.items()
}

# bigram_probs[prev][word] == count(prev, word) / count(prev, *)
bigram_probs = {}
for prev_word, next_words in bigram_counts.items():
    total_count = sum(next_words.values())
    distribution = {}
    for follower, occurrences in next_words.items():
        distribution[follower] = occurrences / total_count
    bigram_probs[prev_word] = distribution


In [96]:
def predict_next_word(context, bigram_probs):
    """Return the most probable word to follow ``context`` under a bigram model.

    Args:
        context: The preceding word, looked up as a key of ``bigram_probs``.
        bigram_probs: Mapping ``prev_word -> {next_word: probability}``.

    Returns:
        The next word with the highest probability, or ``None`` when
        ``context`` is not a key of ``bigram_probs`` or has no recorded
        continuations. On ties the word that comes first in the inner
        mapping's iteration order wins.
    """
    candidates = bigram_probs.get(context)
    # max() raises ValueError on an empty iterable, so an empty continuation
    # mapping is treated the same way as an unseen context.
    if not candidates:
        return None
    return max(candidates, key=candidates.get)


In [97]:
def compute_sentence_probability(sentence, bigram_probs):
    """Score a sentence under the unsmoothed bigram model, in log space.

    Despite the name, the value returned is the natural-log probability of
    the sentence; apply ``math.exp`` to obtain the probability itself.

    Args:
        sentence: Raw sentence string; it is passed through ``preprocess``.
        bigram_probs: Mapping ``prev_word -> {next_word: probability}``.

    Returns:
        The sum of ``log P(word | previous word)`` over the preprocessed
        tokens, with ``"<START>"`` as the predecessor of the first token.
        ``-inf`` as soon as one bigram has no positive probability, and ``0``
        when preprocessing yields no tokens.
    """
    log_prob = 0
    previous = "<START>"
    for current in preprocess(sentence):
        step_prob = bigram_probs.get(previous, {}).get(current, 0)
        # Unseen bigram: the whole sentence has probability zero.
        if not step_prob > 0:
            return -float("inf")
        log_prob += math.log(step_prob)
        previous = current
    return log_prob


def compute_perplexity(sentences, bigram_probs):
    """Compute the perplexity of a set of sentences under the bigram model.

    Perplexity is ``exp(-(1/N) * sum(log P(word | previous word)))`` where the
    sum runs over all preprocessed tokens of all sentences and ``N`` is the
    number of those tokens. Each sentence starts from the ``"<START>"``
    marker.

    Args:
        sentences: Iterable of raw sentence strings; each one is passed
            through ``preprocess``.
        bigram_probs: Mapping ``prev_word -> {next_word: probability}``.

    Returns:
        float: The perplexity, or ``inf`` if any bigram in the test set has
        no positive probability under the model.

    Raises:
        ValueError: If preprocessing yields no tokens at all (for example an
            empty list of sentences), since perplexity is undefined then.
    """
    total_log_prob = 0.0
    total_word_count = 0

    for sentence in sentences:
        prev_word = "<START>"
        for word in preprocess(sentence):
            word_prob = bigram_probs.get(prev_word, {}).get(word, 0)
            # A single zero-probability bigram makes the average log
            # probability -inf, i.e. the perplexity infinite; stop early.
            if not word_prob > 0:
                return float("inf")
            total_log_prob += math.log(word_prob)
            total_word_count += 1
            prev_word = word

    # Avoid an opaque ZeroDivisionError when there is nothing to score.
    if total_word_count == 0:
        raise ValueError("cannot compute perplexity: the sentences contain no tokens after preprocessing")

    avg_log_prob = total_log_prob / total_word_count
    return math.exp(-avg_log_prob)


In [98]:
#2 -- most likely next word after a one-word context, using the bigram model
# (predict_next_word returns None if the context was never seen as a predecessor)
context = "in"
next_word = predict_next_word(context, bigram_probs)
print(f"The next word is: {next_word}")

The next word is: the


In [99]:
#3 -- sentence probabilities and test-set perplexity with the unsmoothed bigram model
sentences = [
    "Brad Pitt was born in Oklahoma",
    "The actor was born in USA",
]

# compute_sentence_probability returns a log-probability, hence the exp().
probabilities = []
for sentence in sentences:
    sentence_log_prob = compute_sentence_probability(sentence, bigram_probs)
    probabilities.append(math.exp(sentence_log_prob))

perplexity = compute_perplexity(sentences, bigram_probs)

for number, (sentence, probability) in enumerate(zip(sentences, probabilities), start=1):
    print(f"Probability of sentence {number} ('{sentence}'): {probability}")
print(f"Perplexity of the test set: {perplexity}")


Probability of sentence 1 ('Brad Pitt was born in Oklahoma'): 0.0
Probability of sentence 2 ('The actor was born in USA'): 1.286901194526109e-13
Perplexity of the test set: inf


In [100]:
def interpolated_probability(prev_word, word, bigram_probs, unigram_probs, lambda_bigram, lambda_unigram):
    """Linearly interpolate the bigram and unigram estimates for ``word``.

    Args:
        prev_word: The preceding word (bigram context).
        word: The word being scored.
        bigram_probs: Mapping ``prev_word -> {next_word: probability}``.
        unigram_probs: Mapping ``word -> probability``.
        lambda_bigram: Weight applied to the bigram estimate.
        lambda_unigram: Weight applied to the unigram estimate.

    Returns:
        ``lambda_bigram * P(word | prev_word) + lambda_unigram * P(word)``,
        where an estimate missing from its table counts as 0.
    """
    followers = bigram_probs.get(prev_word, {})
    weighted_bigram = lambda_bigram * followers.get(word, 0)
    weighted_unigram = lambda_unigram * unigram_probs.get(word, 0)
    return weighted_bigram + weighted_unigram


def compute_interpolated_sentence_probability(sentence, bigram_probs, unigram_probs, lambda_bigram=2 / 3,
                                              lambda_unigram=1 / 3):
    """Score a sentence with the interpolated bigram/unigram model, in log space.

    Despite the name, the value returned is the natural-log probability of
    the sentence; apply ``math.exp`` to obtain the probability itself.

    Args:
        sentence: Raw sentence string; it is passed through ``preprocess``.
        bigram_probs: Mapping ``prev_word -> {next_word: probability}``.
        unigram_probs: Mapping ``word -> probability``.
        lambda_bigram: Weight of the bigram estimate (default 2/3).
        lambda_unigram: Weight of the unigram estimate (default 1/3).

    Returns:
        The sum of the log interpolated probabilities of the preprocessed
        tokens, with ``"<START>"`` as the predecessor of the first token.
        ``-inf`` as soon as one token has no positive interpolated
        probability, and ``0`` when preprocessing yields no tokens.
    """
    log_prob = 0
    previous = "<START>"
    for current in preprocess(sentence):
        step_prob = interpolated_probability(
            previous, current, bigram_probs, unigram_probs, lambda_bigram, lambda_unigram
        )
        # Still zero after interpolation: the sentence has probability zero.
        if not step_prob > 0:
            return -float("inf")
        log_prob += math.log(step_prob)
        previous = current
    return log_prob

In [101]:
def compute_interpolated_perplexity(sentences, bigram_probs, unigram_probs, lambda_bigram, lambda_unigram):
    """Compute test-set perplexity under the interpolated bigram/unigram model.

    Perplexity is ``exp(-(1/N) * sum(log P_interp(word | previous word)))``
    where the sum runs over all preprocessed tokens of all sentences and ``N``
    is the number of those tokens. Each sentence starts from the ``"<START>"``
    marker.

    Args:
        sentences: Iterable of raw sentence strings; each one is passed
            through ``preprocess``.
        bigram_probs: Mapping ``prev_word -> {next_word: probability}``.
        unigram_probs: Mapping ``word -> probability``.
        lambda_bigram: Weight of the bigram estimate.
        lambda_unigram: Weight of the unigram estimate.

    Returns:
        float: The perplexity, or ``inf`` if any token has no positive
        interpolated probability.

    Raises:
        ValueError: If preprocessing yields no tokens at all (for example an
            empty list of sentences), since perplexity is undefined then.
    """
    total_log_prob = 0.0
    total_word_count = 0

    for sentence in sentences:
        prev_word = "<START>"
        for word in preprocess(sentence):
            interpolated_prob = interpolated_probability(
                prev_word, word, bigram_probs, unigram_probs, lambda_bigram, lambda_unigram
            )
            # One zero-probability token makes the perplexity infinite.
            if not interpolated_prob > 0:
                return float("inf")
            total_log_prob += math.log(interpolated_prob)
            total_word_count += 1
            prev_word = word

    # Avoid an opaque ZeroDivisionError when there is nothing to score.
    if total_word_count == 0:
        raise ValueError("cannot compute perplexity: the sentences contain no tokens after preprocessing")

    avg_log_prob = total_log_prob / total_word_count
    return math.exp(-avg_log_prob)

In [102]:
# Interpolation weights for the bigram and unigram estimates.
lambda_bigram = 2/3
lambda_unigram = 1/3

# compute_interpolated_sentence_probability returns a log-probability, hence the exp().
interpolated_probabilities = []
for sentence in sentences:
    sentence_log_prob = compute_interpolated_sentence_probability(
        sentence, bigram_probs, unigram_probs, lambda_bigram, lambda_unigram
    )
    interpolated_probabilities.append(math.exp(sentence_log_prob))

interpolated_perplexity = compute_interpolated_perplexity(
    sentences, bigram_probs, unigram_probs, lambda_bigram, lambda_unigram
)

for number, (sentence, probability) in enumerate(zip(sentences, interpolated_probabilities), start=1):
    # A probability of exactly 0 is reported with an explicit note.
    if probability == 0:
        shown = "0 (log prob = -inf)"
    else:
        shown = probability
    print(f"Probability of sentence {number} ('{sentence}'): {shown}")

print(f"Perplexity of the test set: {interpolated_perplexity}")

Probability of sentence 1 ('Brad Pitt was born in Oklahoma'): 1.9053847478951965e-16
Probability of sentence 2 ('The actor was born in USA'): 3.4324627566807654e-14
Perplexity of the test set: 270.4171866251256
