# AIG230 NLP (Week 3 Lab) - Notebook 2: Statistical Language Models (Train, Test, Evaluate)

This notebook focuses on **n-gram Statistical Language Models (SLMs)**:
- Train **unigram**, **bigram**, **trigram** models
- Handle **OOV** with `<UNK>`
- Apply **smoothing** (Add-k)
- Evaluate with **cross-entropy** and **perplexity**
- Do **next-word prediction** and simple **text generation**

> Industry framing: even if modern systems use neural LMs, n-gram LMs remain useful as
> baselines, in constrained domains, and for understanding evaluation.


## 0) Setup


In [1]:

import re
import math
import random
from collections import Counter, defaultdict
from typing import List, Tuple, Dict


## 1) Data: domain text you might see in real systems


We use short texts that resemble:
- release notes
- incident summaries
- operational runbooks
- customer support messaging

In practice, you would load thousands to millions of lines.


In [2]:

corpus = [
    "vpn disconnects frequently after windows update",
    "password reset link expired user cannot login",
    "api requests timeout when latency spikes",
    "portal returns 500 error after deployment",
    "email delivery delayed messages queued",
    "mfa prompt never arrives user stuck at login",
    "wifi drops in meeting rooms access point reboot helps",
    "outlook search not returning results index corrupted",
    "printer driver install fails with error 1603",
    "teams calls choppy audio jitter high",
    "permission denied accessing shared drive though in correct group",
    "battery drains fast after bios update power settings unchanged",
    "push notifications not working on android app",
    "mailbox full cannot receive emails auto archive not running",
]

# Train/test split at sentence level
random.seed(42)
random.shuffle(corpus)
split = int(0.75 * len(corpus))
train_texts = corpus[:split]
test_texts = corpus[split:]

len(train_texts), len(test_texts), train_texts[:2], test_texts[:2]


(10,
 4,
 ['printer driver install fails with error 1603',
  'push notifications not working on android app'],
 ['email delivery delayed messages queued',
  'vpn disconnects frequently after windows update'])

## 2) Tokenization + special tokens


We will:
- lowercase
- keep alphanumerics
- split on whitespace
- add sentence boundary tokens: `<s>` and `</s>`

We will also map rare tokens to `<UNK>` based on training frequency.


In [4]:

def tokenize(text: str) -> List[str]:
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]+", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text.split()

def add_boundaries(tokens: List[str], n: int) -> List[str]:
    # For n-grams, prepend (n-1) start tokens for simpler context handling
    return ["<s>"]*(n-1) + tokens + ["</s>"]

# Example
tokens = tokenize("Printer driver install fails with error 1603")
add_boundaries(tokens, n=3)


['<s>',
 '<s>',
 'printer',
 'driver',
 'install',
 'fails',
 'with',
 'error',
 '1603',
 '</s>']

## 3) Build vocabulary and handle OOV with `<UNK>`


In [5]:

# Build vocab from training data
train_tokens_flat = []
for t in train_texts:
    train_tokens_flat.extend(tokenize(t))

freq = Counter(train_tokens_flat)

# Typical practical rule: map tokens with frequency <= 1 to <UNK> in small corpora
min_count = 2
vocab = {w for w, c in freq.items() if c >= min_count}
vocab |= {"<UNK>", "<s>", "</s>"}

def replace_oov(tokens: List[str], vocab: set) -> List[str]:
    return [tok if tok in vocab else "<UNK>" for tok in tokens]

# Show OOV effect
sample = tokenize(test_texts[0])
sample, replace_oov(sample, vocab)


(['email', 'delivery', 'delayed', 'messages', 'queued'],
 ['<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'])

## 4) Train n-gram counts (unigram, bigram, trigram)


We will compute:
- `ngram_counts[(w1,...,wn)]`
- `context_counts[(w1,...,w_{n-1})]`

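Then the maximum-likelihood estimate (MLE) of the probability is:

$$P(w_n \mid h) = \frac{\text{count}(h, w_n)}{\text{count}(h)}, \qquad h = (w_1, \dots, w_{n-1})$$

This assigns probability 0 to any n-gram that never occurred in training (and is undefined when the context itself was never seen), so a single unseen n-gram in a test sentence makes that sentence's probability 0 and the perplexity infinite. Smoothing avoids this.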
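A minimal sketch of the unsmoothed (MLE) estimate on a made-up two-sentence corpus shows the problem concretely. `mle_prob` and the toy data are hypothetical, not part of the notebook; returning 0.0 for an unseen context is an assumption, since the ratio is really 0/0 there.

```python
from collections import Counter

def mle_prob(ngram, ngram_counts, context_counts):
    """Unsmoothed estimate P(w_n | w_1..w_{n-1}) = count(h, w_n) / count(h)."""
    context = ngram[:-1]
    if context_counts[context] == 0:
        return 0.0  # assumption: treat an unseen context (0/0) as zero probability
    return ngram_counts[ngram] / context_counts[context]

# Hypothetical training sentences, boundary tokens already added
train = [
    ["<s>", "user", "cannot", "login", "</s>"],
    ["<s>", "user", "cannot", "connect", "</s>"],
]

bi_counts, bi_ctx = Counter(), Counter()
for toks in train:
    for i in range(len(toks) - 1):
        bi_counts[(toks[i], toks[i + 1])] += 1  # count the bigram
        bi_ctx[(toks[i],)] += 1                 # count its one-token context

p_seen = mle_prob(("cannot", "login"), bi_counts, bi_ctx)    # 1 / 2
p_unseen = mle_prob(("cannot", "print"), bi_counts, bi_ctx)  # 0 / 2
print(p_seen, p_unseen)  # 0.5 0.0
```

Any test sentence containing "cannot print" would therefore get probability 0, `math.log(0)` raises `ValueError`, and the perplexity is infinite.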


In [6]:
def get_ngrams(tokens: List[str], n: int) -> List[Tuple[str, ...]]:
    return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]

def train_ngram_counts(texts: List[str], n: int, vocab: set) -> Tuple[Counter, Counter]:
    """Return (ngram_counts, context_counts) for all order-n n-grams in texts."""
    ngram_counts = Counter()
    context_counts = Counter()
    for text in texts:
        tokens = replace_oov(tokenize(text), vocab)
        tokens = add_boundaries(tokens, n)
        for ng in get_ngrams(tokens, n):
            ngram_counts[ng] += 1
            # For n=1 the context is (), so context_counts[()] holds the total token count
            context_counts[ng[:-1]] += 1
    return ngram_counts, context_counts


In [7]:
uni_counts, uni_contexts = train_ngram_counts(train_texts, 1, vocab)

In [8]:
uni_counts


Counter({('<UNK>',): 67,
         ('</s>',): 10,
         ('not',): 3,
         ('error',): 2,
         ('after',): 2})

In [9]:
bi_counts, bi_ctx = train_ngram_counts(train_texts, 2, vocab)
bi_counts

Counter({('<UNK>', '<UNK>'): 51,
         ('<s>', '<UNK>'): 10,
         ('<UNK>', '</s>'): 10,
         ('<UNK>', 'not'): 3,
         ('not', '<UNK>'): 3,
         ('<UNK>', 'error'): 2,
         ('after', '<UNK>'): 2,
         ('error', '<UNK>'): 1,
         ('<UNK>', 'after'): 1,
         ('error', 'after'): 1})

In [10]:
tri_counts, tri_ctx = train_ngram_counts(train_texts, 3, vocab)
tri_counts

Counter({('<UNK>', '<UNK>', '<UNK>'): 38,
         ('<s>', '<s>', '<UNK>'): 10,
         ('<s>', '<UNK>', '<UNK>'): 10,
         ('<UNK>', '<UNK>', '</s>'): 7,
         ('<UNK>', '<UNK>', 'not'): 3,
         ('<UNK>', 'not', '<UNK>'): 3,
         ('<UNK>', '<UNK>', 'error'): 2,
         ('not', '<UNK>', '<UNK>'): 2,
         ('<UNK>', 'error', '<UNK>'): 1,
         ('error', '<UNK>', '</s>'): 1,
         ('not', '<UNK>', '</s>'): 1,
         ('<UNK>', '<UNK>', 'after'): 1,
         ('<UNK>', 'after', '<UNK>'): 1,
         ('after', '<UNK>', '<UNK>'): 1,
         ('<UNK>', 'error', 'after'): 1,
         ('error', 'after', '<UNK>'): 1,
         ('after', '<UNK>', '</s>'): 1})

## 5) Add-k smoothing and probability function


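Add-k smoothing (a common baseline):
- Add *k* to the count of every possible next word.
- Normalize by `count(h) + k * |V|` so the distribution still sums to 1.

$$P_k(w \mid h) = \frac{\text{count}(h, w) + k}{\text{count}(h) + k\,|V|}$$

where h is the context (the previous n-1 words) and |V| is the vocabulary size.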
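A quick numeric check of the formula, on a hypothetical five-word vocabulary with made-up counts: the add-k probabilities over every word in V sum to 1 for a given context. It also shows why V should only count tokens that can actually be predicted. The notebook's `vocab` includes `<s>`, which is never a next word, and counting such a token in V leaves the remaining mass below 1.

```python
import math
from collections import Counter

def prob_addk(ngram, ngram_counts, context_counts, V, k=0.5):
    context = ngram[:-1]
    return (ngram_counts[ngram] + k) / (context_counts[context] + k * V)

# Hypothetical vocabulary of predictable tokens and made-up bigram counts
vocab = ["<UNK>", "cannot", "login", "user", "</s>"]
bi_counts = Counter({("user", "cannot"): 2, ("cannot", "login"): 1, ("cannot", "<UNK>"): 1})
bi_ctx = Counter({("user",): 2, ("cannot",): 2})

V = len(vocab)
dist = {w: prob_addk(("cannot", w), bi_counts, bi_ctx, V) for w in vocab}
total = sum(dist.values())  # (2 + 0.5 * 5) / (2 + 0.5 * 5) = 1

# Counting one extra never-predicted token (like "<s>") in V:
total_extra = sum(prob_addk(("cannot", w), bi_counts, bi_ctx, V + 1) for w in vocab)
# (2 + 0.5 * 5) / (2 + 0.5 * 6) = 4.5 / 5 = 0.9

print(round(dist["login"], 4), round(total, 6), round(total_extra, 6))  # 0.3333 1.0 0.9
```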


In [11]:
def prob_addk(ngram: Tuple[str, ...], ngram_counts: Counter, context_counts: Counter, V: int, k: float = 0.5) -> float:
    """
    Compute the add-k estimate P(w_n | w_1 ... w_{n-1}),
    where ngram = (w_1, w_2, ..., w_n).
    Assumes 0 < k <= 1; V is the vocabulary size.
    For unigrams the context is (), so context_counts[()] must hold the total token count.
    """
    context = ngram[:-1]
    return (ngram_counts[ngram] + k) / (context_counts[context] + k * V)
    


In [12]:
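V = len(vocab)
# "login" is not in vocab (it appears fewer than min_count times in training), so this
# bigram is unseen and gets only smoothing mass: 0.5 / (10 + 0.5 * V), because <s>
# starts all 10 training sentences
example = ("<s>", "login")
prob_addk(example, bi_counts, bi_ctx, V, k=0.5)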


0.038461538461538464

## 6) Evaluate: cross-entropy and perplexity on test set


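We evaluate an LM by how well it predicts held-out text.

Cross-entropy (average negative log2 probability per predicted token):

$$H = -\frac{1}{N} \sum_{i=1}^{N} \log_2 P(w_i \mid h_i)$$

where N is the number of predicted tokens (including each `</s>`) and h_i is the context of w_i.

Perplexity:

$$PP = 2^{H}$$

Lower perplexity is better, but only when comparing models that share the same vocabulary and `<UNK>` mapping: mapping more words to `<UNK>` makes prediction easier and lowers perplexity without making the model more useful.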
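Two hand-checkable cases make the definitions concrete: a model that is uniform over |V| words has perplexity exactly |V|, and PP equals the inverse geometric mean of the per-token probabilities. The probabilities below are made-up values, not outputs of the notebook's models, and `perplexity_from_probs` is a hypothetical helper.

```python
import math

def perplexity_from_probs(probs):
    """PP = 2 ** H, where H is the average negative log2 probability per predicted token."""
    H = -sum(math.log2(p) for p in probs) / len(probs)
    return 2 ** H

# A model that is uniform over 8 words has perplexity 8: it is "8-way confused"
pp_uniform = perplexity_from_probs([1 / 8] * 5)

# Made-up per-token probabilities for a 4-token test sequence:
# log2 values are -1, -2, -1, -3, so H = 7 / 4 = 1.75 bits and PP = 2 ** 1.75
pp_example = perplexity_from_probs([0.5, 0.25, 0.5, 0.125])

# Equivalent form: inverse geometric mean of the probabilities
pp_geo = (0.5 * 0.25 * 0.5 * 0.125) ** (-1 / 4)

print(pp_uniform, round(pp_example, 4))  # 8.0 3.3636
```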


In [13]:
def evaluate_perplexity(texts: List[str], n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, k: float = 0.5) -> float:
    """Perplexity of an add-k n-gram model on texts (PP = 2 ** cross-entropy)."""
    V = len(vocab)
    log2_prob_sum = 0.0
    num_predictions = 0
    for text in texts:
        # Same preprocessing as training: tokenize, map OOV words to <UNK>, add boundaries
        tokens = replace_oov(tokenize(text), vocab)
        tokens = add_boundaries(tokens, n)
        for ng in get_ngrams(tokens, n):
            prob = prob_addk(ng, ngram_counts, context_counts, V, k)
            log2_prob_sum += math.log2(prob)
            num_predictions += 1
    H = -log2_prob_sum / num_predictions  # cross-entropy in bits per token
    return 2 ** H
    

In [14]:
pp_uni = evaluate_perplexity(test_texts, 1, uni_counts, uni_contexts, vocab, k=0.5)
pp_bi = evaluate_perplexity(test_texts, 2, bi_counts, bi_ctx, vocab, k=0.5)
pp_tri = evaluate_perplexity(test_texts, 3, tri_counts, tri_ctx, vocab, k=0.5)

pp_uni, pp_bi, pp_tri


(1.8224739937573902, 1.8712095221558307, 1.9552746520172761)

## 7) Next-word prediction (top-k)


Given a context, compute the probability of each candidate next token and return the top-k.

This mirrors:
- autocomplete in constrained domains
- template suggestion systems
- command prediction in runbooks


In [18]:

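def next_word_topk(context_tokens: List[str], n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, k_smooth: float = 0.5, top_k: int = 5) -> List[Tuple[str, float]]:
    V = len(vocab)
    # Keep only the last n-1 tokens as context (empty tuple for a unigram model)
    context = tuple(context_tokens[-(n-1):]) if n > 1 else tuple()
    candidates = []
    for w in vocab:
        if w == "<s>":  # the start token is never a valid next word
            continue
        ng = context + (w,)
        p = prob_addk(ng, ngram_counts, context_counts, V, k=k_smooth)
        candidates.append((w, p))
    # Highest probability first; tied words keep the (run-dependent) set iteration order
    candidates.sort(key=lambda x: -x[1])
    return candidates[:top_k]

# Bigram: context is 1 token. Only 5 tokens can follow (vocab has 6 entries including <s>),
# so at most 5 results come back even with top_k=8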
next_word_topk(["<s>"], n=2, ngram_counts=bi_counts, context_counts=bi_ctx, vocab=vocab, top_k=8)



[('<UNK>', 0.8076923076923077),
 ('not', 0.038461538461538464),
 ('</s>', 0.038461538461538464),
 ('after', 0.038461538461538464),
 ('error', 0.038461538461538464)]

## 8) Simple generation (bigram or trigram)


Text generation is not the main goal in SLMs, but it helps you verify:
- boundary handling
- smoothing
- OOV decisions

We will sample tokens until we hit `</s>`.


In [17]:

def sample_next(context_tokens: List[str], n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, k_smooth: float = 0.5) -> str:
    V = len(vocab)
    context = tuple(context_tokens[-(n-1):]) if n > 1 else tuple()
    # Sort candidates so sampling with a fixed seed is reproducible across runs
    # (the iteration order of a set of strings changes between Python processes)
    words = sorted(w for w in vocab if w != "<s>")
    probs = [prob_addk(context + (w,), ngram_counts, context_counts, V, k=k_smooth) for w in words]
    # Renormalize: V counts <s>, which is never sampled, so the raw probabilities sum to less than 1
    s = sum(probs)
    probs = [p / s for p in probs]
    return random.choices(words, weights=probs, k=1)[0]

def generate(n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, max_len: int = 20, k_smooth: float = 0.5):
    tokens = ["<s>"]*(n-1) if n > 1 else []
    out = []
    for _ in range(max_len):
        w = sample_next(tokens, n, ngram_counts, context_counts, vocab, k_smooth=k_smooth)
        if w == "</s>":
            break
        out.append(w)
        tokens.append(w)
    return " ".join(out)

for _ in range(5):
    print("BIGRAM:", generate(2, bi_counts, bi_ctx, vocab, max_len=18))


BIGRAM: not
BIGRAM: <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>
BIGRAM: <UNK> <UNK> after <UNK>
BIGRAM: <UNK> <UNK> <UNK> <UNK> not <UNK>
BIGRAM: <UNK>


## 9) Model comparison: effect of n and smoothing


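Try different `k` values. Notes:
- `k=1.0` is Laplace (add-one) smoothing, which often moves too much probability mass to unseen n-grams
- smaller `k` (like 0.1 to 0.5) often works better
- strictly, `k` should be chosen on a separate validation (dev) set; comparing values on the test set here is only for illustration

On real corpora, trigrams often beat bigrams, but they need more data because trigram counts are sparser.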
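The comparison cell in this section tunes `k` by looking at test-set perplexity. A minimal sketch of the usual alternative: pick `k` on a held-out dev split, then evaluate that single `k` once on the test set. The toy sentences and the helpers `bigram_counts` and `bigram_pp` are hypothetical and self-contained, not the notebook's functions.

```python
import math
from collections import Counter

def bigram_counts(sents):
    """Bigram and context counts over token lists padded with <s> and </s>."""
    ngrams, ctx = Counter(), Counter()
    for toks in sents:
        toks = ["<s>"] + toks + ["</s>"]
        for a, b in zip(toks, toks[1:]):
            ngrams[(a, b)] += 1
            ctx[(a,)] += 1
    return ngrams, ctx

def bigram_pp(sents, ngrams, ctx, vocab, k):
    """Add-k bigram perplexity; words outside vocab are mapped to <UNK>."""
    V = len(vocab)
    log2_sum, N = 0.0, 0
    for toks in sents:
        toks = ["<s>"] + [t if t in vocab else "<UNK>" for t in toks] + ["</s>"]
        for a, b in zip(toks, toks[1:]):
            log2_sum += math.log2((ngrams[(a, b)] + k) / (ctx[(a,)] + k * V))
            N += 1
    return 2 ** (-log2_sum / N)

# Hypothetical splits (token lists); in the notebook these would come from the corpus
train = [s.split() for s in [
    "user cannot login after update",
    "user cannot connect to vpn",
    "vpn disconnects after update",
    "login fails after password reset",
]]
dev = [s.split() for s in ["user cannot login to vpn", "vpn fails after reset"]]

# Predictable tokens only: training words plus <UNK> and </s> (not <s>)
vocab = {w for s in train for w in s} | {"<UNK>", "</s>"}
ngrams, ctx = bigram_counts(train)

dev_pp = {k: bigram_pp(dev, ngrams, ctx, vocab, k) for k in [1.0, 0.5, 0.1, 0.01]}
best_k = min(dev_pp, key=dev_pp.get)  # k with the lowest dev perplexity
print(dev_pp, best_k)
```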


In [19]:

for k in [1.0, 0.5, 0.1, 0.01]:
    pp_bi_k  = evaluate_perplexity(test_texts, n=2, ngram_counts=bi_counts,  context_counts=bi_ctx,  vocab=vocab, k=k)
    pp_tri_k = evaluate_perplexity(test_texts, n=3, ngram_counts=tri_counts, context_counts=tri_ctx, vocab=vocab, k=k)
    print(f"k={k:>4}:  bigram PP={pp_bi_k:,.2f}   trigram PP={pp_tri_k:,.2f}")



k= 1.0:  bigram PP=1.95   trigram PP=2.09
k= 0.5:  bigram PP=1.87   trigram PP=1.96
k= 0.1:  bigram PP=1.79   trigram PP=1.80
k=0.01:  bigram PP=1.76   trigram PP=1.75


## Exercises (do these during lab)
1) Add 20 more realistic domain sentences to the corpus and re-run training/evaluation.  
2) Change `min_count` (OOV threshold) and explain how perplexity changes.  
3) Implement **backoff**: if a trigram is unseen, fall back to bigram; if unseen, fall back to unigram.  
4) Create a function that returns **top-5 next words** given a phrase like: `"user cannot"`.


In [20]:
# ====== Exercise 1: Expand corpus + retrain + re-evaluate ======

import random
import math
from collections import Counter, defaultdict

# -------------------------------------------------
# 1) Add 20 new realistic domain sentences
# -------------------------------------------------

corpus += [  # re-running this cell appends these sentences again; rebuild `corpus` first
    "dns lookup fails intermittently causing slow page loads",
    "sso login redirects in a loop after recent configuration change",
    "ssl certificate expired users see security warning in browser",
    "database connection pool exhausted under peak traffic",
    "service health check failing pod keeps restarting in kubernetes",
    "disk usage on server is 95 percent logs not rotating",
    "cpu usage spikes to 100 percent after nightly batch job",
    "cannot connect to vpn error 720 on windows",
    "azure ad sync error duplicate attribute detected",
    "password policy change forces reset but emails not sent",
    "api returns 429 rate limit exceeded for valid clients",
    "file upload fails with 413 payload too large",
    "scheduled report job fails permission denied writing to s3",
    "user account locked after multiple failed mfa attempts",
    "mobile app crashes on startup after latest release",
    "latency high between regions packet loss detected",
    "redis cache miss rate increased after deployment",
    "websocket disconnects frequently on unstable network",
    "printer not found on network after ip change",
    "time drift detected on server ntp not syncing",
]

# -------------------------------------------------
# 2) Train test split
# -------------------------------------------------

random.seed(42)
random.shuffle(corpus)

split = int(0.8 * len(corpus))
train_texts = corpus[:split]
test_texts = corpus[split:]

# -------------------------------------------------
# 3) Tokenization helpers
# -------------------------------------------------

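# Simplified tokenizer for the exercises. It shadows the Section 2 version: no lowercasing or
# punctuation stripping (the corpus is already clean), and it adds the boundary tokens itself.
# Only ONE <s> is prepended, so the trigram model never scores the first word of a sentence,
# and the unigram model also "predicts" <s>; each order is scored over a slightly different
# set of tokens.
def tokenize(text):
    return ["<s>"] + text.split() + ["</s>"]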

# -------------------------------------------------
# 4) Build vocab with min_count
# -------------------------------------------------

min_count = 1   # change later in Exercise 2

freq = Counter()
for sent in train_texts:
    freq.update(tokenize(sent))

vocab = {w for w, c in freq.items() if c >= min_count}
vocab.add("<UNK>")

def replace_oov(tokens, vocab):
    return [t if t in vocab else "<UNK>" for t in tokens]

# -------------------------------------------------
# 5) Ngram counters
# -------------------------------------------------

def build_ngram_counts(texts, n, vocab):
    ngram_counts = Counter()
    context_counts = Counter()

    for sent in texts:
        tokens = replace_oov(tokenize(sent), vocab)

        for i in range(len(tokens) - n + 1):
            ng = tuple(tokens[i:i+n])
            ngram_counts[ng] += 1
            context = ng[:-1]
            context_counts[context] += 1

    return ngram_counts, context_counts


uni_counts, _ = build_ngram_counts(train_texts, 1, vocab)
bi_counts, bi_ctx = build_ngram_counts(train_texts, 2, vocab)
tri_counts, tri_ctx = build_ngram_counts(train_texts, 3, vocab)

# -------------------------------------------------
# 6) Add-k probability
# -------------------------------------------------

def prob_addk(ngram, ngram_counts, context_counts, V, k=0.5):
    if len(ngram) == 1:
        return (ngram_counts[ngram] + k) / (sum(ngram_counts.values()) + k * V)

    context = ngram[:-1]
    return (ngram_counts[ngram] + k) / (context_counts[context] + k * V)

# -------------------------------------------------
# 7) Perplexity
# -------------------------------------------------

def perplexity(texts, n, ngram_counts, context_counts, vocab, k=0.5):

    V = len(vocab)
    log_prob_sum = 0
    N = 0

    for sent in texts:
        tokens = replace_oov(tokenize(sent), vocab)

        for i in range(len(tokens) - n + 1):
            ng = tuple(tokens[i:i+n])
            p = prob_addk(ng, ngram_counts, context_counts, V, k)

            log_prob_sum += math.log(p)
            N += 1

    return math.exp(-log_prob_sum / N)


pp_uni = perplexity(test_texts, 1, uni_counts, None, vocab)
pp_bi = perplexity(test_texts, 2, bi_counts, bi_ctx, vocab)
pp_tri = perplexity(test_texts, 3, tri_counts, tri_ctx, vocab)

pp_uni, pp_bi, pp_tri


(193.5384198636684, 158.84278982001533, 163.72100037353383)
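Each exercise cell runs `random.seed(42)` and then `random.shuffle(corpus)` on a list that earlier cells have already shuffled in place, so every exercise sees a different train/test split. This is why Exercise 2's `min_count = 1` row does not reproduce the numbers above even though the code and settings match. A minimal sketch of a fix, assuming a hypothetical helper `make_split` that every cell would call with the same seed: shuffle a copy with a private `random.Random` instance, so the split depends only on the seed and the input order.

```python
import random

def make_split(texts, train_frac=0.8, seed=42):
    """Return (train, test) lists from a shuffled copy of texts; texts itself is not modified."""
    shuffled = list(texts)                 # copy, so repeated calls see the same input order
    random.Random(seed).shuffle(shuffled)  # private RNG: unaffected by other random.* calls
    split = int(train_frac * len(shuffled))
    return shuffled[:split], shuffled[split:]

# Usage with hypothetical sentences
texts = [f"ticket {i}" for i in range(10)]
train_a, test_a = make_split(texts)
train_b, test_b = make_split(texts)
print(len(train_a), len(test_a), (train_a, test_a) == (train_b, test_b))  # 8 2 True
```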

In [None]:
# ====== Exercise 2: Change min_count (OOV threshold) and compare perplexity ======

import random
import math
from collections import Counter

# ---------------------------
# 1) Train test split
# ---------------------------

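# Note: `corpus` was already shuffled in place by earlier cells, so shuffling it again with the
# same seed gives a DIFFERENT split than Exercise 1 (compare the min_count=1 row below with
# Exercise 1's output); the numbers are only comparable within this cell.
random.seed(42)
random.shuffle(corpus)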

split = int(0.8 * len(corpus))
train_texts = corpus[:split]
test_texts = corpus[split:]

# ---------------------------
# 2) Tokenization + OOV helpers
# ---------------------------

def tokenize(text):
    return ["<s>"] + text.split() + ["</s>"]

def replace_oov(tokens, vocab):
    return [t if t in vocab else "<UNK>" for t in tokens]

# ---------------------------
# 3) Build ngram counts
# ---------------------------

def build_ngram_counts(texts, n, vocab):
    ngram_counts = Counter()
    context_counts = Counter()

    for sent in texts:
        tokens = replace_oov(tokenize(sent), vocab)

        for i in range(len(tokens) - n + 1):
            ng = tuple(tokens[i:i+n])
            ngram_counts[ng] += 1
            context = ng[:-1]
            context_counts[context] += 1

    return ngram_counts, context_counts

# ---------------------------
# 4) Add-k probability
# ---------------------------

def prob_addk(ngram, ngram_counts, context_counts, V, k=0.5):
    if len(ngram) == 1:
        return (ngram_counts[ngram] + k) / (sum(ngram_counts.values()) + k * V)
    context = ngram[:-1]
    return (ngram_counts[ngram] + k) / (context_counts[context] + k * V)

# ---------------------------
# 5) Perplexity
# ---------------------------

def perplexity(texts, n, ngram_counts, context_counts, vocab, k=0.5):
    V = len(vocab)
    log_prob_sum = 0.0
    N = 0

    for sent in texts:
        tokens = replace_oov(tokenize(sent), vocab)

        for i in range(len(tokens) - n + 1):
            ng = tuple(tokens[i:i+n])
            p = prob_addk(ng, ngram_counts, context_counts, V, k)
            log_prob_sum += math.log(p)
            N += 1

    return math.exp(-log_prob_sum / N)

# ---------------------------
# 6) Run experiment for different min_count values
# ---------------------------

k_smooth = 0.5
min_counts = [1, 2, 3, 5]  

results = []

for mc in min_counts:
    # build vocab
    freq = Counter()
    for sent in train_texts:
        freq.update(tokenize(sent))

    vocab = {w for w, c in freq.items() if c >= mc}
    vocab.add("<UNK>")

    # train ngrams
    uni_counts, _ = build_ngram_counts(train_texts, 1, vocab)
    bi_counts, bi_ctx = build_ngram_counts(train_texts, 2, vocab)
    tri_counts, tri_ctx = build_ngram_counts(train_texts, 3, vocab)

    # evaluate perplexity
    pp_uni = perplexity(test_texts, 1, uni_counts, None, vocab, k=k_smooth)
    pp_bi  = perplexity(test_texts, 2, bi_counts, bi_ctx, vocab, k=k_smooth)
    pp_tri = perplexity(test_texts, 3, tri_counts, tri_ctx, vocab, k=k_smooth)

    # report oov ratio on test set (helpful for explanation)
    total_tokens = 0
    oov_tokens = 0
    for sent in test_texts:
        toks = tokenize(sent)
        total_tokens += len(toks)
        oov_tokens += sum(1 for t in toks if t not in vocab)

    oov_ratio = oov_tokens / total_tokens

    results.append((mc, len(vocab), oov_ratio, pp_uni, pp_bi, pp_tri))

# ---------------------------
# 7) Print results table
# ---------------------------

print(f"{'min_count':>9} | {'VocabSize':>9} | {'OOV%':>6} | {'PP_uni':>10} | {'PP_bi':>10} | {'PP_tri':>10}")
print("-" * 74)

for mc, vsize, oov_ratio, ppu, ppb, ppt in results:
    print(f"{mc:>9} | {vsize:>9} | {oov_ratio*100:>5.1f}% | {ppu:>10.2f} | {ppb:>10.2f} | {ppt:>10.2f}")

results


min_count | VocabSize |   OOV% |     PP_uni |      PP_bi |     PP_tri
--------------------------------------------------------------------------
        1 |       168 |  52.3% |     201.43 |     168.51 |     168.04
        2 |        30 |  69.2% |       3.98 |       3.73 |       5.87
        3 |        14 |  69.2% |       3.35 |       2.80 |       3.61
        5 |         6 |  75.4% |       2.36 |       1.78 |       1.85


[(1,
  168,
  0.5230769230769231,
  201.43152163872222,
  168.5054504820778,
  168.03898861922355),
 (2,
  30,
  0.6923076923076923,
  3.976190035742321,
  3.728033113080755,
  5.867980977297841),
 (3,
  14,
  0.6923076923076923,
  3.346262972145768,
  2.7950536828537005,
  3.614217671609578),
 (5,
  6,
  0.7538461538461538,
  2.360129610683516,
  1.7836713862003923,
  1.8455029140131973)]

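When I increased `min_count`, the vocabulary shrank sharply (from 168 to 6 types) and more tokens were mapped to `<UNK>`, so the test OOV rate rose (from 52.3% to 75.4%). With `min_count = 1`, every training word stays in the vocabulary, so `<UNK>` never occurs in training; the many unseen test words all map to `<UNK>` and receive only the small smoothing mass k / (count(h) + k·|V|). Together with sparse counts (in a corpus this small, many n-grams appear only once), this produces very high perplexity.
With `min_count = 2` and `3`, rare training words are mapped to `<UNK>`, so `<UNK>` gets real counts and the model learns where unknown words tend to appear. Perplexity drops sharply for all three models, but much of the drop comes from predicting over a much smaller vocabulary (30 and 14 types) rather than from better generalization: perplexities computed with different vocabularies are not directly comparable.
With `min_count = 5`, only 6 types remain and about three quarters of the test tokens become `<UNK>`. Perplexity falls further (to about 1.8 to 2.4), yet the model has lost almost all lexical information and could not, for example, suggest a useful next word. The OOV threshold is therefore a trade-off between a richer vocabulary and less sparse counts, and it should be chosen by looking at both perplexity and OOV rate (or a downstream task), not perplexity alone.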
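The effect of the threshold can be reproduced on a tiny hypothetical example with an add-k unigram model. `unigram_pp` and the sentences are illustrative, not the notebook's code: raising `min_count` until almost everything is `<UNK>` lowers perplexity while the OOV rate rises, even though the model knows less.

```python
import math
from collections import Counter

def unigram_pp(train, test, min_count, k=0.5):
    """Return (perplexity, test OOV rate) of an add-k unigram model at a given min_count."""
    def toks(sents):
        return [w for s in sents for w in s.split() + ["</s>"]]

    freq = Counter(toks(train))
    vocab = {w for w, c in freq.items() if c >= min_count} | {"<UNK>"}
    # Train on OOV-mapped tokens so <UNK> gets counts whenever rare words were removed
    counts = Counter(w if w in vocab else "<UNK>" for w in toks(train))
    total, V = sum(counts.values()), len(vocab)

    test_toks = toks(test)
    mapped = [w if w in vocab else "<UNK>" for w in test_toks]
    H = -sum(math.log2((counts[w] + k) / (total + k * V)) for w in mapped) / len(mapped)
    oov_rate = sum(w not in vocab for w in test_toks) / len(test_toks)
    return 2 ** H, oov_rate

# Hypothetical data
train = ["user cannot login", "user cannot connect", "vpn drops often"]
test = ["user cannot print", "printer drops jobs"]

results = {mc: unigram_pp(train, test, mc) for mc in (1, 3)}
for mc, (pp, oov) in results.items():
    print(f"min_count={mc}: PP={pp:.2f}  OOV={oov:.0%}")
```

At `min_count = 3` only `</s>` and `<UNK>` remain, so the model just has to predict "unknown word or end of sentence"; the lower perplexity reflects that easier task, not better generalization.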

In [22]:
# ====== Exercise 3: Backoff (trigram -> bigram -> unigram) with add-k smoothing ======

import random
import math
from collections import Counter

# ---------------------------
# 1) Train test split
# ---------------------------

random.seed(42)
random.shuffle(corpus)

split = int(0.8 * len(corpus))
train_texts = corpus[:split]
test_texts = corpus[split:]

# ---------------------------
# 2) Tokenization + OOV
# ---------------------------

def tokenize(text):
    return ["<s>"] + text.split() + ["</s>"]

def replace_oov(tokens, vocab):
    return [t if t in vocab else "<UNK>" for t in tokens]

# ---------------------------
# 3) Build vocab
# ---------------------------

min_count = 1  # you can change this
freq = Counter()
for sent in train_texts:
    freq.update(tokenize(sent))

vocab = {w for w, c in freq.items() if c >= min_count}
vocab.add("<UNK>")

# ---------------------------
# 4) Build ngram counts
# ---------------------------

def build_ngram_counts(texts, n, vocab):
    ngram_counts = Counter()
    context_counts = Counter()

    for sent in texts:
        tokens = replace_oov(tokenize(sent), vocab)

        for i in range(len(tokens) - n + 1):
            ng = tuple(tokens[i:i+n])
            ngram_counts[ng] += 1
            context = ng[:-1]
            context_counts[context] += 1

    return ngram_counts, context_counts

uni_counts, _ = build_ngram_counts(train_texts, 1, vocab)
bi_counts, bi_ctx = build_ngram_counts(train_texts, 2, vocab)
tri_counts, tri_ctx = build_ngram_counts(train_texts, 3, vocab)

# ---------------------------
# 5) Add-k probability (base)
# ---------------------------

def prob_addk(ngram, ngram_counts, context_counts, V, k=0.5):
    if len(ngram) == 1:
        return (ngram_counts[ngram] + k) / (sum(ngram_counts.values()) + k * V)
    context = ngram[:-1]
    return (ngram_counts[ngram] + k) / (context_counts[context] + k * V)

# ---------------------------
# 6) Backoff probability
# ---------------------------

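def prob_backoff_addk(trigram, tri_counts, tri_ctx, bi_counts, bi_ctx, uni_counts, vocab, k=0.5):
    """
    trigram is a tuple of length 3: (w1, w2, w3).
    If the trigram context (w1, w2) was seen in training -> add-k trigram estimate.
    Else, if the bigram context (w2,) was seen -> add-k bigram estimate for (w2, w3).
    Else -> add-k unigram estimate for (w3,).
    Note: this backs off on an unseen CONTEXT, not an unseen trigram; a trigram never seen
    after a seen context still gets the add-k trigram estimate (0 + k) / (count(w1, w2) + k*V).
    """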
    V = len(vocab)
    w1, w2, w3 = trigram

    # try trigram if context exists (seen)
    tri_context = (w1, w2)
    if tri_ctx[tri_context] > 0:
        return prob_addk((w1, w2, w3), tri_counts, tri_ctx, V, k)

    # backoff to bigram
    bi_context = (w2,)
    if bi_ctx[bi_context] > 0:
        return prob_addk((w2, w3), bi_counts, bi_ctx, V, k)

    # backoff to unigram
    return prob_addk((w3,), uni_counts, None, V, k)

# ---------------------------
# 7) Perplexity for backoff trigram
# ---------------------------

def perplexity_backoff(test_texts, tri_counts, tri_ctx, bi_counts, bi_ctx, uni_counts, vocab, k=0.5):
    log_prob_sum = 0.0
    N = 0

    for sent in test_texts:
        tokens = replace_oov(tokenize(sent), vocab)

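        # Slide a trigram window; with only one <s> the first window is (<s>, w1, w2),
        # so the first word of each sentence is never scored (same as the plain trigram below)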
        for i in range(len(tokens) - 3 + 1):
            tri = (tokens[i], tokens[i+1], tokens[i+2])
            p = prob_backoff_addk(tri, tri_counts, tri_ctx, bi_counts, bi_ctx, uni_counts, vocab, k=k)
            log_prob_sum += math.log(p)
            N += 1

    return math.exp(-log_prob_sum / N)

# ---------------------------
# 8) Compare PP: plain trigram vs backoff trigram
# ---------------------------

k_smooth = 0.5

def perplexity_plain_trigram(test_texts, tri_counts, tri_ctx, vocab, k=0.5):
    V = len(vocab)
    log_prob_sum = 0.0
    N = 0
    for sent in test_texts:
        tokens = replace_oov(tokenize(sent), vocab)
        for i in range(len(tokens) - 3 + 1):
            tri = (tokens[i], tokens[i+1], tokens[i+2])
            p = prob_addk(tri, tri_counts, tri_ctx, V, k=k)
            log_prob_sum += math.log(p)
            N += 1
    return math.exp(-log_prob_sum / N)

pp_tri_plain = perplexity_plain_trigram(test_texts, tri_counts, tri_ctx, vocab, k=k_smooth)
pp_tri_backoff = perplexity_backoff(test_texts, tri_counts, tri_ctx, bi_counts, bi_ctx, uni_counts, vocab, k=k_smooth)

pp_tri_plain, pp_tri_backoff


(168.11933369555558, 205.88947572965333)
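Exercise 3 uses exactly one order per context and never mixes them. When both the trigram and bigram contexts are unseen (for example right after an `<UNK>` token), it uses the add-k unigram estimate, in which `<UNK>` (never seen in training at `min_count = 1`) gets only k / (N + k·|V|), with N the number of training tokens. That is far below the 1/|V| the plain add-k trigram assigns to every word after an unseen context, which is one plausible contributor to the higher backoff perplexity. A common alternative, swapped in here and not what the exercise asked for, is recursive linear (Jelinek-Mercer style) interpolation, which always mixes lower orders in. The sketch is self-contained, and the λ weights are hypothetical values that would normally be tuned on a dev set.

```python
from collections import Counter

def train_counts(sents):
    """Uni/bi/trigram counts for every predicted token, with two <s> of left padding."""
    uni, bi, tri, ctx1, ctx2 = Counter(), Counter(), Counter(), Counter(), Counter()
    for toks in sents:
        toks = ["<s>", "<s>"] + toks + ["</s>"]
        for i in range(2, len(toks)):
            u, v, w = toks[i - 2], toks[i - 1], toks[i]
            uni[w] += 1
            bi[(v, w)] += 1
            ctx1[v] += 1
            tri[(u, v, w)] += 1
            ctx2[(u, v)] += 1
    return uni, bi, tri, ctx1, ctx2

def interp_prob(u, v, w, counts, V, lam3=0.6, lam2=0.6, k=0.5):
    """P(w | u, v) = lam3 * P_ML(w | u, v) + (1 - lam3) * P(w | v), where P(w | v) mixes the
    bigram MLE with an add-k unigram. An unseen context skips its level, so the result
    remains a normalized distribution over the V predictable tokens."""
    uni, bi, tri, ctx1, ctx2 = counts
    p1 = (uni[w] + k) / (sum(uni.values()) + k * V)  # add-k unigram: never zero
    p2 = p1
    if ctx1[v] > 0:
        p2 = lam2 * bi[(v, w)] / ctx1[v] + (1 - lam2) * p1
    p3 = p2
    if ctx2[(u, v)] > 0:
        p3 = lam3 * tri[(u, v, w)] / ctx2[(u, v)] + (1 - lam3) * p2
    return p3

# Hypothetical training data; vocab = predictable tokens (no <s>), plus <UNK>
train = [["user", "cannot", "login"], ["user", "cannot", "connect"], ["vpn", "drops"]]
vocab = {w for s in train for w in s} | {"</s>", "<UNK>"}
counts = train_counts(train)
V = len(vocab)

seen = {w: interp_prob("user", "cannot", w, counts, V) for w in vocab}
unseen = {w: interp_prob("printer", "jobs", w, counts, V) for w in vocab}
print(round(seen["login"], 4), round(sum(seen.values()), 6), round(sum(unseen.values()), 6))
# 0.436 1.0 1.0
```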

In [23]:
# ====== Exercise 4: Top-5 next word predictions given a phrase (e.g., "user cannot") ======

import random
import math
from collections import Counter

# ---------------------------
# 1) Train test split
# ---------------------------

random.seed(42)
random.shuffle(corpus)

split = int(0.8 * len(corpus))
train_texts = corpus[:split]
test_texts = corpus[split:]

# ---------------------------
# 2) Tokenization + OOV
# ---------------------------

def tokenize(text):
    return ["<s>"] + text.split() + ["</s>"]

def replace_oov(tokens, vocab):
    return [t if t in vocab else "<UNK>" for t in tokens]

# ---------------------------
# 3) Build vocab
# ---------------------------

min_count = 2   # you can change this based on Exercise 2
freq = Counter()
for sent in train_texts:
    freq.update(tokenize(sent))

vocab = {w for w, c in freq.items() if c >= min_count}
vocab.add("<UNK>")

# ---------------------------
# 4) Build ngram counts
# ---------------------------

def build_ngram_counts(texts, n, vocab):
    ngram_counts = Counter()
    context_counts = Counter()

    for sent in texts:
        tokens = replace_oov(tokenize(sent), vocab)

        for i in range(len(tokens) - n + 1):
            ng = tuple(tokens[i:i+n])
            ngram_counts[ng] += 1
            context = ng[:-1]
            context_counts[context] += 1

    return ngram_counts, context_counts

uni_counts, _ = build_ngram_counts(train_texts, 1, vocab)
bi_counts, bi_ctx = build_ngram_counts(train_texts, 2, vocab)
tri_counts, tri_ctx = build_ngram_counts(train_texts, 3, vocab)

# ---------------------------
# 5) Add-k probability
# ---------------------------

def prob_addk(ngram, ngram_counts, context_counts, V, k=0.5):
    if len(ngram) == 1:
        return (ngram_counts[ngram] + k) / (sum(ngram_counts.values()) + k * V)
    context = ngram[:-1]
    return (ngram_counts[ngram] + k) / (context_counts[context] + k * V)

# ---------------------------
# 6) Backoff probability for next word
# ---------------------------

def next_word_prob_backoff(context_tokens, candidate, tri_counts, tri_ctx, bi_counts, bi_ctx, uni_counts, vocab, k=0.5):
    """
    context_tokens: list of tokens (already OOV-replaced)
    candidate: next token to score
    backoff order: trigram -> bigram -> unigram
    """
    V = len(vocab)

    # try trigram if we have >=2 context tokens
    if len(context_tokens) >= 2:
        w1, w2 = context_tokens[-2], context_tokens[-1]
        tri_context = (w1, w2)
        if tri_ctx[tri_context] > 0:
            return prob_addk((w1, w2, candidate), tri_counts, tri_ctx, V, k)

    # backoff to bigram if we have >=1 context token
    if len(context_tokens) >= 1:
        w2 = context_tokens[-1]
        bi_context = (w2,)
        if bi_ctx[bi_context] > 0:
            return prob_addk((w2, candidate), bi_counts, bi_ctx, V, k)

    # backoff to unigram
    return prob_addk((candidate,), uni_counts, None, V, k)

# ---------------------------
# 7) Top-k next word function
# ---------------------------

def topk_next_words(phrase, top_k=5, k_smooth=0.5, include_end_token=False):
    """
    Returns list of (word, prob) for top_k next word predictions.
    """
    # tokenize phrase (do NOT add <s> </s> around it, treat it as context only)
    raw_tokens = phrase.split()
    ctx = replace_oov(raw_tokens, vocab)

    # candidates: vocab minus special start token
    candidates = [w for w in vocab if w != "<s>"]
    if not include_end_token:
        candidates = [w for w in candidates if w != "</s>"]

    scored = []
    for w in candidates:
        p = next_word_prob_backoff(ctx, w, tri_counts, tri_ctx, bi_counts, bi_ctx, uni_counts, vocab, k=k_smooth)
        scored.append((w, p))

    scored.sort(key=lambda x: x[1], reverse=True)  # ties keep set iteration order, which can vary between runs
    return scored[:top_k]

# ---------------------------
# 8) Example usage
# ---------------------------

topk_next_words("user cannot", top_k=5, k_smooth=0.5)


[('login', 0.09090909090909091),
 ('expired', 0.030303030303030304),
 ('on', 0.030303030303030304),
 ('in', 0.030303030303030304),
 ('with', 0.030303030303030304)]
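In the output above, four candidates tie at about 0.0303, and their order comes from iterating over the `vocab` set, which can change between Python processes because string hashing is randomized. A minimal sketch of deterministic ranking, using a hypothetical helper `topk_deterministic` that could replace the final sort in `topk_next_words`: rank by descending probability and break ties alphabetically.

```python
import heapq

def topk_deterministic(scored, top_k=5):
    """Return the top_k (word, prob) pairs: highest prob first, ties broken alphabetically."""
    return heapq.nsmallest(top_k, scored, key=lambda wp: (-wp[1], wp[0]))

# Made-up scores with a four-way tie, listed in arbitrary order
scored = [("with", 0.03), ("login", 0.09), ("on", 0.03), ("at", 0.01), ("expired", 0.03), ("in", 0.03)]
print(topk_deterministic(scored))
# [('login', 0.09), ('expired', 0.03), ('in', 0.03), ('on', 0.03), ('with', 0.03)]
```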