# AIG230 NLP (Week 3 Lab) — Notebook 2: Statistical Language Models (Train, Test, Evaluate)

This notebook focuses on **n-gram Statistical Language Models (SLMs)**:
- Train **unigram**, **bigram**, **trigram** models
- Handle **OOV** with `<UNK>`
- Apply **smoothing** (Add-k)
- Evaluate with **cross-entropy** and **perplexity**
- Do **next-word prediction** and simple **text generation**

> Industry framing: even if modern systems use neural LMs, n-gram LMs are still useful for
baselines, constrained domains, and for understanding evaluation.


### What is smoothing?

Smoothing is a way to stop a language model from saying “this can never happen.”

When we train a language model from data, it only knows what it has seen before.
If it never saw a particular word sequence, the model would normally give it a probability of zero.

Smoothing fixes that.

### Why is this a problem without smoothing?

Imagine the model learned English only by reading a small number of news articles.

If it never saw:

- “oil prices explode”

the model would conclude:

- “That sentence is impossible.”

But as humans, we know it could happen. The model just hasn’t seen it yet.

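Without smoothing:

- One unseen n-gram (for example, a word pair never seen together in training) makes the whole sentence probability zero

- Evaluation breaks: log(0) is undefined, so cross-entropy and perplexity become infinite

- The model is overconfident and brittle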
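A minimal sketch with made-up numbers (not from this notebook's corpus): a sentence's probability is the product of its per-step conditional probabilities, so one zero factor makes the whole product zero, and the log2 needed for cross-entropy fails outright.

```python
import math

# Hypothetical conditional probabilities for each step of one sentence:
# P(oil | <s>), P(prices | oil), P(explode | prices), P(</s> | explode).
# "prices explode" never occurred in training, so its unsmoothed estimate is 0.
step_probs = [0.20, 0.35, 0.0, 0.40]

# Sentence probability is the product of the step probabilities.
sentence_prob = math.prod(step_probs)
print(sentence_prob)  # 0.0: a single zero factor wipes out the whole sentence

# Cross-entropy needs log2 of every factor, and log2(0) is undefined.
log_failed = False
try:
    total_log2 = sum(math.log2(p) for p in step_probs)
except ValueError:
    log_failed = True
print("log2 failed:", log_failed)  # True
```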

## 0) Setup


In [1]:

import re
import math
import random
from collections import Counter, defaultdict
from typing import List, Tuple, Dict


## 1) Data: domain text you might see in real systems


We use short texts that resemble:
- release notes
- incident summaries
- operational runbooks
- customer support messaging

In practice, you would load thousands to millions of lines.


In [2]:

corpus = [
    "vpn disconnects frequently after windows update",
    "password reset link expired user cannot login",
    "api requests timeout when latency spikes",
    "portal returns 500 error after deployment",
    "email delivery delayed messages queued",
    "mfa prompt never arrives user stuck at login",
    "wifi drops in meeting rooms access point reboot helps",
    "outlook search not returning results index corrupted",
    "printer driver install fails with error 1603",
    "teams calls choppy audio jitter high",
    "permission denied accessing shared drive though in correct group",
    "battery drains fast after bios update power settings unchanged",
    "push notifications not working on android app",
    "mailbox full cannot receive emails auto archive not running",
]

# Train/test split at sentence level
random.seed(42)
random.shuffle(corpus)
split = int(0.75 * len(corpus))
train_texts = corpus[:split]
test_texts = corpus[split:]

len(train_texts), len(test_texts), train_texts[:2], test_texts[:2]


(10,
 4,
 ['printer driver install fails with error 1603',
  'push notifications not working on android app'],
 ['email delivery delayed messages queued',
  'vpn disconnects frequently after windows update'])

## 2) Tokenization + special tokens


We will:
- lowercase
- keep only letters and digits (other characters become spaces)
- split on whitespace
- add sentence boundary tokens: `<s>` and `</s>`

We will also map rare tokens to `<UNK>` based on training frequency.


In [3]:

def tokenize(text: str) -> List[str]:
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]+", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text.split()

def add_boundaries(tokens: List[str], n: int) -> List[str]:
    # For n-grams, prepend (n-1) start tokens for simpler context handling
    return ["<s>"]*(n-1) + tokens + ["</s>"]

# Example
tokens = tokenize("Printer driver install fails with error 1603")
add_boundaries(tokens, n=3)


['<s>',
 '<s>',
 'printer',
 'driver',
 'install',
 'fails',
 'with',
 'error',
 '1603',
 '</s>']

## 3) Build vocabulary and handle OOV with `<UNK>`


In [4]:

# Build vocab from training data
train_tokens_flat = []
for t in train_texts:
    train_tokens_flat.extend(tokenize(t))

freq = Counter(train_tokens_flat)

# Typical practical rule: map tokens with frequency <= 1 to <UNK> in small corpora
min_count = 2
vocab = {w for w, c in freq.items() if c >= min_count}
vocab |= {"<UNK>", "<s>", "</s>"}

def replace_oov(tokens: List[str], vocab: set) -> List[str]:
    return [tok if tok in vocab else "<UNK>" for tok in tokens]

# Show OOV effect
sample = tokenize(test_texts[0])
sample, replace_oov(sample, vocab)


(['email', 'delivery', 'delayed', 'messages', 'queued'],
 ['<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'])

## 4) Train n-gram counts (unigram, bigram, trigram)


We will compute:
- `ngram_counts[(w1,...,wn)]`
- `context_counts[(w1,...,w_{n-1})]`

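Then the unsmoothed (maximum likelihood) probability is:

$$
P(w_n \mid w_1,\dots,w_{n-1}) = \frac{C(w_1,\dots,w_n)}{C(w_1,\dots,w_{n-1})}
$$

This gives probability 0 to any n-gram not seen in training (and is undefined when the context itself was never seen), so we add smoothing.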
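A small sketch of the unsmoothed estimate on two hypothetical, pre-tokenized sentences (the data and the helper `p_mle` are illustrative assumptions, not part of the notebook): it shows a normal estimate, an unseen bigram getting exactly 0, and an unseen context making the ratio undefined.

```python
from collections import Counter

# Two hypothetical training sentences, already tokenized with boundary tokens.
sents = [
    ["<s>", "user", "cannot", "login", "</s>"],
    ["<s>", "user", "stuck", "at", "login", "</s>"],
]

bigram_counts = Counter()
context_counts = Counter()
for toks in sents:
    for w1, w2 in zip(toks, toks[1:]):
        bigram_counts[(w1, w2)] += 1
        context_counts[(w1,)] += 1

def p_mle(w1: str, w2: str) -> float:
    """Unsmoothed MLE: count(w1 w2) / count(w1 as a context)."""
    return bigram_counts[(w1, w2)] / context_counts[(w1,)]

print(p_mle("user", "cannot"))  # 0.5: 1 of the 2 continuations of "user"
print(p_mle("user", "login"))   # 0.0: never seen right after "user"
try:
    p_mle("printer", "fails")   # "printer" never appeared as a context
except ZeroDivisionError:
    print("undefined: context count is 0")
```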


In [5]:
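def get_ngrams(tokens: List[str], n: int) -> List[Tuple[str, ...]]:
    # Slide a window of width n over the token list
    return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]

def train_ngram_counts(texts: List[str], n: int, vocab: set) -> Tuple[Counter, Counter]:
    """Return (ngram_counts, context_counts) for order-n n-grams.

    context_counts[h] counts how often h appears as the first n-1 tokens
    of an n-gram, which is the denominator of P(w_n | h).
    For n=1 the context is the empty tuple, so context_counts[()] is the
    total number of predicted tokens.
    """
    ngram_counts = Counter()
    context_counts = Counter()
    for text in texts:
        toks = replace_oov(tokenize(text), vocab)
        toks = add_boundaries(toks, n)
        for ng in get_ngrams(toks, n):
            ngram_counts[ng] += 1
            context = ng[:-1]
            context_counts[context] += 1
    return ngram_counts, context_counts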


In [6]:
uni_counts, uni_ctx = train_ngram_counts(train_texts, n=1, vocab=vocab)

In [7]:
uni_counts

Counter({('<UNK>',): 67,
         ('</s>',): 10,
         ('not',): 3,
         ('error',): 2,
         ('after',): 2})

In [8]:
bi_counts, bi_ctx = train_ngram_counts(train_texts, n=2, vocab=vocab)

In [9]:
bi_counts

Counter({('<UNK>', '<UNK>'): 51,
         ('<s>', '<UNK>'): 10,
         ('<UNK>', '</s>'): 10,
         ('<UNK>', 'not'): 3,
         ('not', '<UNK>'): 3,
         ('<UNK>', 'error'): 2,
         ('after', '<UNK>'): 2,
         ('error', '<UNK>'): 1,
         ('<UNK>', 'after'): 1,
         ('error', 'after'): 1})

In [10]:
tri_counts, tri_ctx = train_ngram_counts(train_texts, n=3, vocab=vocab)


In [11]:
tri_counts

Counter({('<UNK>', '<UNK>', '<UNK>'): 38,
         ('<s>', '<s>', '<UNK>'): 10,
         ('<s>', '<UNK>', '<UNK>'): 10,
         ('<UNK>', '<UNK>', '</s>'): 7,
         ('<UNK>', '<UNK>', 'not'): 3,
         ('<UNK>', 'not', '<UNK>'): 3,
         ('<UNK>', '<UNK>', 'error'): 2,
         ('not', '<UNK>', '<UNK>'): 2,
         ('<UNK>', 'error', '<UNK>'): 1,
         ('error', '<UNK>', '</s>'): 1,
         ('not', '<UNK>', '</s>'): 1,
         ('<UNK>', '<UNK>', 'after'): 1,
         ('<UNK>', 'after', '<UNK>'): 1,
         ('after', '<UNK>', '<UNK>'): 1,
         ('<UNK>', 'error', 'after'): 1,
         ('error', 'after', '<UNK>'): 1,
         ('after', '<UNK>', '</s>'): 1})

## 5) Add-k smoothing and probability function


### What does Add-k smoothing do?
Add-k smoothing tells the model:

- “Even if you didn’t see something, assume it could still happen a little bit.”

It does this by giving every possible next word a small amount of probability, not just the words seen in training.

So instead of:

- seen → possible

- unseen → impossible

We get:

- seen → more likely

- unseen → less likely, but still possible


### Why is it called Add-k?

Because we add a small number k to the count of every possible n-gram, not just the ones observed in training.

Think of it as:

- adding a tiny “imaginary observation” for every word

- so no word ever has zero probability

When k is small (like 0.1 or 0.5), it gently smooths the probabilities instead of overpowering real data.
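In symbols, writing $h = (w_1,\dots,w_{n-1})$ for the context, $C(\cdot)$ for training counts, and $V$ for the vocabulary size (the quantities `prob_addk` uses below), add-k estimates:

$$
P_{\text{add-}k}(w_n \mid h) = \frac{C(h, w_n) + k}{C(h) + kV}
$$

The $kV$ in the denominator keeps this a valid distribution: summing the numerator over all $V$ possible next words gives $\sum_{w} C(h, w) + kV = C(h) + kV$, so the probabilities still sum to 1. Setting $k = 1$ gives Laplace (add-one) smoothing.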

In [12]:
def prob_addk(n_gram: Tuple[str, ...], ngram_counts: Counter, context_counts: Counter, V: int, k: float = 0.5) -> float:
    """
    Add-k estimate of P(w_n | w_1 ... w_(n-1)), where n_gram = (w_1, w_2, ..., w_n).
    The result is never zero, even for n-grams unseen in training.
    0 < k <= 1
    V is the vocabulary size
    """
    context = n_gram[:-1]
    return (ngram_counts[n_gram] + k) / (context_counts[context] + k * V)

In [13]:
V = len(vocab)
example = ("<s>", "login")
prob_addk(example, bi_counts, bi_ctx, V, k=0.5)

0.038461538461538464
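As a sanity check, the sketch below rebuilds that number by hand from the outputs above, assuming: the context `<s>` occurs 10 times in training (once per training sentence), `V = 6` (`not`, `error`, `after` plus the three special tokens), and `(<s>, <UNK>)` is the only bigram seen after `<s>`. The helper `addk` is a hypothetical standalone re-implementation of the formula in `prob_addk`.

```python
# Values read off the outputs above (assumption: the vocab is exactly these 6 tokens).
ctx_count = 10                      # C(<s>): one per training sentence
k = 0.5
vocab = ["<UNK>", "not", "error", "after", "<s>", "</s>"]
V = len(vocab)                      # 6
seen_after_start = {"<UNK>": 10}    # the only bigram starting with <s>

def addk(count: int) -> float:
    """Hypothetical helper: add-k estimate for one next word after <s>."""
    return (count + k) / (ctx_count + k * V)

unseen = addk(0)    # any unseen next word, e.g. P(login | <s>) = 0.5 / 13
seen = addk(10)     # P(<UNK> | <s>) = 10.5 / 13
total = sum(addk(seen_after_start.get(w, 0)) for w in vocab)
print(round(unseen, 4), round(seen, 4), round(total, 10))  # 0.0385 0.8077 1.0
```

Note that `<s>` is in the vocabulary, so add-k also reserves a little mass for `<s>` as a "next word"; that is why `sample_next` below drops `<s>` and renormalizes.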

## 6) Evaluate: cross-entropy and perplexity on test set


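We evaluate an LM by how well it predicts held-out text.

Cross-entropy (average negative log2 probability per predicted token, where N counts every predicted token including `</s>`):

$$
H = -\frac{1}{N}\sum_{i=1}^{N} \log_2 P(w_i \mid \text{context}_i)
$$

Perplexity:

$$
PP = 2^{H}
$$

Lower perplexity is better, but only when comparing models that share the same vocabulary (including the same `<UNK>` mapping) and the same test set.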
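A worked example with hypothetical per-token probabilities (the function name `cross_entropy_and_pp` is an assumption, not part of the notebook). Perplexity is the inverse geometric mean of the per-token probabilities, so a model that is uniform over V tokens has perplexity exactly V.

```python
import math
from typing import List, Tuple

def cross_entropy_and_pp(probs: List[float]) -> Tuple[float, float]:
    """Return (H in bits per token, perplexity 2**H) for per-token probabilities."""
    H = -sum(math.log2(p) for p in probs) / len(probs)
    return H, 2 ** H

# Hypothetical per-token probabilities for a 4-token test sequence.
H, pp = cross_entropy_and_pp([0.5, 0.25, 0.125, 0.125])
print(H, round(pp, 3))  # 2.25 4.757

# A model that is uniform over V = 6 tokens has perplexity 6:
# on average it is as uncertain as a fair choice among 6 options.
print(round(cross_entropy_and_pp([1 / 6] * 4)[1], 6))  # 6.0
```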


In [14]:
def evaluate_perplexity(texts: List[str], n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, k: float = 0.5) -> float:
    V = len(vocab)
    log2_probs = []
    token_count = 0

    for text in texts:
        toks = replace_oov(tokenize(text), vocab)
        toks = add_boundaries(toks, n)
        ngrams = get_ngrams(toks, n)
        for ng in ngrams:
            p = prob_addk(ng, ngram_counts, context_counts, V, k=k)
            log2_probs.append(math.log(p, 2))
            token_count += 1

    H = -sum(log2_probs) / token_count
    PP = 2 ** H
    return PP

In [15]:
pp_uni = evaluate_perplexity(test_texts, n=1, ngram_counts=uni_counts, context_counts=uni_ctx, vocab=vocab, k=0.5)
pp_bi  = evaluate_perplexity(test_texts, n=2, ngram_counts=bi_counts,  context_counts=bi_ctx,  vocab=vocab, k=0.5)
pp_tri = evaluate_perplexity(test_texts, n=3, ngram_counts=tri_counts, context_counts=tri_ctx, vocab=vocab, k=0.5)

pp_uni, pp_bi, pp_tri

(1.8224739937573897, 1.8712095221558311, 1.9552746520172757)

## 7) Next-word prediction (top-k)


Given a context, compute the probability of each candidate next token and return the top-k.

This mirrors:
- autocomplete in constrained domains
- template suggestion systems
- command prediction in runbooks


In [16]:
def next_word_topk(context_tokens: List[str], n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, k_smooth: float = 0.5, top_k: int = 5):
    V = len(vocab)
    # Map OOV context words to <UNK> and left-pad with <s> so the context always has n-1 tokens
    padded = ["<s>"] * (n - 1) + replace_oov(list(context_tokens), vocab)
    context = tuple(padded[-(n-1):]) if n > 1 else tuple()
    candidates = []
    for w in vocab:
        if w == "<s>":  # <s> is never a valid next token
            continue
        ng = context + (w,)
        p = prob_addk(ng, ngram_counts, context_counts, V, k=k_smooth)
        candidates.append((w, p))
    candidates.sort(key=lambda x: -x[1])
    return candidates[:top_k]

# Bigram: context is 1 token.
# Only 5 candidates exist (the vocab minus <s>), so top_k=8 returns 5.
next_word_topk(["<s>"], n=2, ngram_counts=bi_counts, context_counts=bi_ctx, vocab=vocab, top_k=8)

[('<UNK>', 0.8076923076923077),
 ('not', 0.038461538461538464),
 ('</s>', 0.038461538461538464),
 ('after', 0.038461538461538464),
 ('error', 0.038461538461538464)]

## 8) Simple generation (bigram or trigram)


Text generation is not the main goal in SLMs, but it helps you verify:
- boundary handling
- smoothing
- OOV decisions

We will sample tokens until we hit `</s>`.


In [17]:
def sample_next(context_tokens: List[str], n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, k_smooth: float = 0.5):
    V = len(vocab)
    context = tuple(context_tokens[-(n-1):]) if n > 1 else tuple()
    words = [w for w in vocab if w != "<s>"]
    probs = []
    for w in words:
        ng = context + (w,)
        probs.append(prob_addk(ng, ngram_counts, context_counts, V, k=k_smooth))
    # Renormalize: <s> was excluded, so the remaining add-k probabilities sum to less than 1
    s = sum(probs)
    probs = [p/s for p in probs]
    return random.choices(words, weights=probs, k=1)[0]

def generate(n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, max_len: int = 20, k_smooth: float = 0.5):
    tokens = ["<s>"]*(n-1) if n > 1 else []
    out = []
    for _ in range(max_len):
        w = sample_next(tokens, n, ngram_counts, context_counts, vocab, k_smooth=k_smooth)
        if w == "</s>":
            break
        out.append(w)
        tokens.append(w)
    return " ".join(out)

for _ in range(5):
    print("BIGRAM:", generate(2, bi_counts, bi_ctx, vocab, max_len=18))

BIGRAM: <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>
BIGRAM: <UNK> <UNK> <UNK> <UNK> <UNK> not <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>
BIGRAM: <UNK> <UNK> <UNK>
BIGRAM: <UNK> <UNK> <UNK> <UNK> not error after after <UNK> <UNK>
BIGRAM: <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>


## 9) Model comparison: effect of n and smoothing


Try different `k` values. Notes:
- `k=1.0` is Laplace (add-one) smoothing; it is often too strong because it shifts too much probability to unseen events
- smaller `k` (like 0.1 to 0.5) is often better

In real corpora, trigrams often beat bigrams, but require more data.
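Why can `k=1` be too strong? Under add-k, each unseen next word gets k / (C(h) + kV), so the total probability given to the V - T words never seen after a context h (T = number of distinct words seen after h) is (V - T)·k / (C(h) + kV). The sketch below evaluates that expression for a hypothetical context seen 10 times with 2 distinct continuations; the numbers are illustrative, not from this corpus, and `unseen_mass` is a hypothetical helper.

```python
def unseen_mass(context_count: int, distinct_seen: int, V: int, k: float) -> float:
    """Total add-k probability assigned to next words never seen after the context."""
    return (V - distinct_seen) * k / (context_count + k * V)

for V in (6, 1000):
    for k in (1.0, 0.1, 0.01):
        print(f"V={V:>4}  k={k:<4}  unseen mass={unseen_mass(10, 2, V, k):.3f}")
```

With a realistic vocabulary (V=1000), k=1 hands about 99% of the probability to unseen words, swamping the 10 real observations; even k=0.01 still gives them about half. With this notebook's tiny vocabulary the effect is much milder, which is why the perplexities below differ only slightly across k.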


In [18]:
for k in [1.0, 0.5, 0.1, 0.01]:
    pp_bi_k  = evaluate_perplexity(test_texts, n=2, ngram_counts=bi_counts,  context_counts=bi_ctx,  vocab=vocab, k=k)
    pp_tri_k = evaluate_perplexity(test_texts, n=3, ngram_counts=tri_counts, context_counts=tri_ctx, vocab=vocab, k=k)
    print(f"k={k:>4}:  bigram PP={pp_bi_k:,.2f}   trigram PP={pp_tri_k:,.2f}")


k= 1.0:  bigram PP=1.95   trigram PP=2.09
k= 0.5:  bigram PP=1.87   trigram PP=1.96
k= 0.1:  bigram PP=1.79   trigram PP=1.80
k=0.01:  bigram PP=1.76   trigram PP=1.75


## Exercises (do these during lab)
1) Add 20 more realistic domain sentences to the corpus and re-run training/evaluation.  
2) Change `min_count` (OOV threshold) and explain how perplexity changes.  
3) Implement **backoff**: if a trigram is unseen, fall back to bigram; if unseen, fall back to unigram.


In [19]:
# add 20 more realistic domain sentences to corpus
additional_sentences = [
    "chrome browser crashes when opening multiple tabs memory leak detected",
    "laptop overheating during video calls fan speed insufficient",
    "external monitor not detected after docking station firmware update",
    "onedrive sync stuck at processing changes local cache needs clearing",
    "citrix session disconnects every fifteen minutes gateway timeout error",
    "adobe acrobat freezes when printing to network printer spooler service",
    "skype for business signs out randomly credential cache corrupted",
    "vpn tunnel drops when switching from wifi to ethernet adapter conflict",
    "sql server connection failed tcp ip protocol not enabled in configuration",
    "git pull hangs indefinitely authentication token expired needs regeneration",
    "jenkins build fails maven dependencies not resolving from repository",
    "docker container exits immediately port already in use by another process",
    "kubernetes pod stuck in pending state insufficient cluster resources",
    "splunk dashboard shows no data forwarder misconfigured on host",
    "ssl certificate expired causing browser security warnings for users",
    "active directory replication failing between domain controllers dns issue",
    "sharepoint document library not loading list view threshold exceeded",
    "zoom screen share displays black window graphics driver outdated",
    "backup job failed destination storage volume ran out of disk space",
    "patch management server offline endpoints not receiving security updates",
]

# Extend original corpus (shuffled and split in the next cell)
corpus_extended = corpus + additional_sentences

In [20]:
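# No reseed here, so the split depends on how many random draws earlier cells made
# (e.g. the generation cell). Call random.seed(...) first for a reproducible split.
random.shuffle(corpus_extended)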
split = int(0.75 * len(corpus_extended))
train_texts_extended = corpus_extended[:split]
test_texts_extended = corpus_extended[split:]

print(f"Extended corpus: {len(corpus_extended)} sentences")
print(f"Train: {len(train_texts_extended)}, Test: {len(test_texts_extended)}")

Extended corpus: 34 sentences
Train: 25, Test: 9


In [21]:
# Retrain with extended corpus (min_count=2 as baseline)
train_tokens_flat_ext = []
for t in train_texts_extended:
    train_tokens_flat_ext.extend(tokenize(t))
freq_ext = Counter(train_tokens_flat_ext)

min_count_baseline = 2
vocab_extended = {w for w, c in freq_ext.items() if c >= min_count_baseline}
vocab_extended |= {"<UNK>", "<s>", "</s>"}
V_ext = len(vocab_extended)

# Retrain n-gram models
uni_counts_ext, uni_ctx_ext = train_ngram_counts(train_texts_extended, n=1, vocab=vocab_extended)
bi_counts_ext, bi_ctx_ext = train_ngram_counts(train_texts_extended, n=2, vocab=vocab_extended)
tri_counts_ext, tri_ctx_ext = train_ngram_counts(train_texts_extended, n=3, vocab=vocab_extended)

In [22]:
# Evaluate extended models
pp_uni_ext = evaluate_perplexity(test_texts_extended, n=1, ngram_counts=uni_counts_ext, 
                                  context_counts=uni_ctx_ext, vocab=vocab_extended, k=0.5)
pp_bi_ext = evaluate_perplexity(test_texts_extended, n=2, ngram_counts=bi_counts_ext, 
                                 context_counts=bi_ctx_ext, vocab=vocab_extended, k=0.5)
pp_tri_ext = evaluate_perplexity(test_texts_extended, n=3, ngram_counts=tri_counts_ext, 
                                  context_counts=tri_ctx_ext, vocab=vocab_extended, k=0.5)

print(f"\nExtended Corpus Results (min_count={min_count_baseline}):")
print(f"  Unigram PP: {pp_uni_ext:.2f}")
print(f"  Bigram PP:  {pp_bi_ext:.2f}")
print(f"  Trigram PP: {pp_tri_ext:.2f}")

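# Compare with original. Caution: the models use different vocabularies (V=6 vs V=24)
# and different test sets, so these perplexities are not directly comparable.
print(f"\nComparison with Original (14 sentences) vs Extended (34 sentences):")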
print(f"  Original Trigram PP: {pp_tri:.2f}")
print(f"  Extended Trigram PP: {pp_tri_ext:.2f}")



Extended Corpus Results (min_count=2):
  Unigram PP: 2.78
  Bigram PP:  3.14
  Trigram PP: 3.75

Comparison with Original (14 sentences) vs Extended (34 sentences):
  Original Trigram PP: 1.96
  Extended Trigram PP: 3.75


In [26]:
# Change min_count and explain perplexity changes
print("\n" + "="*60)
print("Effect of min_count (OOV threshold) on Perplexity")
print("="*60)

for min_count in [1, 2, 3, 4]:
    # Build vocab with different threshold
    vocab_test = {w for w, c in freq_ext.items() if c >= min_count}
    vocab_test |= {"<UNK>", "<s>", "</s>"}
    V_test = len(vocab_test)
    
    # Retrain all n-grams with new vocab
    uni_c, uni_ctx_c = train_ngram_counts(train_texts_extended, n=1, vocab=vocab_test)
    bi_c, bi_ctx_c = train_ngram_counts(train_texts_extended, n=2, vocab=vocab_test)
    tri_c, tri_ctx_c = train_ngram_counts(train_texts_extended, n=3, vocab=vocab_test)
    # Evaluate
    pp_u = evaluate_perplexity(test_texts_extended, n=1, ngram_counts=uni_c, 
                               context_counts=uni_ctx_c, vocab=vocab_test, k=0.5)
    pp_b = evaluate_perplexity(test_texts_extended, n=2, ngram_counts=bi_c, 
                               context_counts=bi_ctx_c, vocab=vocab_test, k=0.5)
    pp_t = evaluate_perplexity(test_texts_extended, n=3, ngram_counts=tri_c, 
                               context_counts=tri_ctx_c, vocab=vocab_test, k=0.5)
    # Count OOV occurrences in test set
    test_tokens_flat = []
    for t in test_texts_extended:
        test_tokens_flat.extend(tokenize(t))
    oov_count = sum(1 for tok in test_tokens_flat if tok not in vocab_test and tok != "<UNK>")
    oov_rate = oov_count / len(test_tokens_flat) * 100
    
    print(f"\nmin_count = {min_count}:")
    print(f"  Vocab size: {V_test-3} (+ 3 special tokens)")
    print(f"  OOV rate in test: {oov_rate:.1f}%")
    print(f"  Unigram PP: {pp_u:6.2f}")
    print(f"  Bigram PP:  {pp_b:6.2f}")
    print(f"  Trigram PP: {pp_t:6.2f}")
          


Effect of min_count (OOV threshold) on Perplexity

min_count = 1:
  Vocab size: 188 (+ 3 special tokens)
  OOV rate in test: 69.9%
  Unigram PP: 307.40
  Bigram PP:  191.60
  Trigram PP: 195.94

min_count = 2:
  Vocab size: 21 (+ 3 special tokens)
  OOV rate in test: 89.0%
  Unigram PP:   2.78
  Bigram PP:    3.14
  Trigram PP:   3.75

min_count = 3:
  Vocab size: 4 (+ 3 special tokens)
  OOV rate in test: 98.6%
  Unigram PP:   1.60
  Bigram PP:    1.60
  Trigram PP:   1.58

min_count = 4:
  Vocab size: 2 (+ 3 special tokens)
  OOV rate in test: 98.6%
  Unigram PP:   1.55
  Bigram PP:    1.54
  Trigram PP:   1.53


Explanation of min_count effects (numbers from the output above):
- min_count=1: every training word stays in the vocabulary, so `<UNK>` never occurs in training. Test words unseen in training (69.9% of test tokens) are still mapped to `<UNK>`, but the model has a zero count for it and gives it only the add-k floor. Because these tokens dominate the test set, perplexity is very high (about 190 to 310).

- min_count=2: hapax legomena (words seen once) are mapped to `<UNK>` in training, so the model learns how often rare words occur. Perplexity drops sharply (2.78 to 3.75).

- min_count=3 and 4: the vocabulary shrinks to 4 and then 2 words, and 98.6% of test tokens become `<UNK>`. Perplexity falls further (about 1.5 to 1.6), but the model now mostly predicts `<UNK>` and can barely distinguish domain concepts.

- Takeaway: perplexity keeps falling as min_count rises because the prediction task gets easier (fewer, coarser classes), not because the model gets better. Perplexity is only comparable between models that share the same vocabulary; choose min_count by vocabulary coverage and downstream usefulness, not by perplexity alone.
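A minimal sketch of the effect, using hypothetical toy data and an unsmoothed unigram MLE model (not the notebook's add-k models); `unigram_mle_perplexity` and `collapse` are hypothetical helpers. Collapsing more words into `<UNK>` lowers perplexity all the way to 1, even though the model distinguishes fewer and fewer words.

```python
import math
from collections import Counter
from typing import List, Set

def unigram_mle_perplexity(train: List[str], test: List[str]) -> float:
    """Unigram MLE perplexity; assumes every test token was seen in training."""
    counts = Counter(train)
    total = sum(counts.values())
    H = -sum(math.log2(counts[t] / total) for t in test) / len(test)
    return 2 ** H

def collapse(tokens: List[str], keep: Set[str]) -> List[str]:
    """Map every token outside `keep` to <UNK>, as a higher min_count would."""
    return [t if t in keep else "<UNK>" for t in tokens]

# Hypothetical data: four equally frequent words.
train = ["vpn", "email", "printer", "wifi"] * 5
test = ["vpn", "email", "printer", "wifi"]

for keep in ({"vpn", "email", "printer", "wifi"}, {"vpn"}, set()):
    pp = unigram_mle_perplexity(collapse(train, keep), collapse(test, keep))
    print(f"kept words={len(keep)}  PP={pp:.2f}")  # 4.00, then 1.75, then 1.00
```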

In [None]:
# Implement Backoff (trigram -> bigram -> unigram)
# =============================================================================

print("\n" + "="*60)
print("Backoff Smoothing Implementation")
print("="*60)

def prob_with_backoff(n_gram: Tuple[str, ...], 
                     ngram_counts: Counter, context_counts: Counter,
                     backoff_counts: Counter, backoff_context: Counter,
                     unigram_counts: Counter, V: int, 
                     alpha: float = 0.4) -> float:
    """
    if n-gram unseen, back off to (n-1)-gram.
    For trigram -> bigram -> unigram.
    Unigrams use add-k smoothing to avoid zero.
    """
    n = len(n_gram)
    context = n_gram[:-1]
    
    if n == 3:
        # Try trigram first
        if ngram_counts[n_gram] > 0:
            return ngram_counts[n_gram] / context_counts[context]
        else:
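            # Backoff to bigram with discount alpha. In the recursive call the bigram
            # tables take the trigram slots; the backoff slots are unused at n == 2.
            backoff_ngram = n_gram[1:]  # (w2, w3)
            return alpha * prob_with_backoff(backoff_ngram,
                                           backoff_counts, backoff_context,
                                           unigram_counts, None,
                                           unigram_counts, V, alpha)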
    
    elif n == 2:
        # Try bigram
        if ngram_counts[n_gram] > 0:
            return ngram_counts[n_gram] / context_counts[context]
        else:
            # Backoff to unigram with discount
            word = n_gram[-1]
            k = 0.01  # Small add-k for unigram
            total_unigrams = sum(unigram_counts.values())
            return alpha * (unigram_counts[(word,)] + k) / (total_unigrams + k * V)
    
    else:
        # Unigram base case (shouldn't reach here in recursive calls)
        word = n_gram[-1]
        k = 0.01
        total_unigrams = sum(unigram_counts.values())
        return (unigram_counts[(word,)] + k) / (total_unigrams + k * V)

def evaluate_perplexity_backoff(test_texts: List[str], 
                                tri_counts: Counter, tri_ctx: Counter,
                                bi_counts: Counter, bi_ctx: Counter,
                                uni_counts: Counter, vocab: set, 
                                alpha: float = 0.4) -> float:
    """Evaluate trigram model with backoff to bigram/unigram"""
    V = len(vocab)
    log2_probs = []
    token_count = 0
    
    for text in test_texts:
        toks = replace_oov(tokenize(text), vocab)
        toks = add_boundaries(toks, n=3)
        ngrams = get_ngrams(toks, 3)
        
        for ng in ngrams:
            p = prob_with_backoff(ng, tri_counts, tri_ctx,
                                 bi_counts, bi_ctx,
                                 uni_counts, V, alpha)
            # Safety check
            if p <= 0:
                p = 1e-10
            log2_probs.append(math.log(p, 2))
            token_count += 1
    
    H = -sum(log2_probs) / token_count
    return 2 ** H


Backoff Smoothing Implementation


In [29]:
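# Compare backoff vs add-k. Caution: backoff scores are unnormalized (they can sum to
# more than 1), so its "PP" is optimistic and not directly comparable to add-k.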
print("\nComparing Backoff vs Add-k Smoothing (Trigram):")
for alpha in [0.3, 0.4, 0.5]:
    pp_backoff = evaluate_perplexity_backoff(test_texts_extended,
                                            tri_counts_ext, tri_ctx_ext,
                                            bi_counts_ext, bi_ctx_ext,
                                            uni_counts_ext, vocab_extended, alpha)
    print(f"  Backoff (alpha={alpha}): PP = {pp_backoff:.2f}")

pp_addk = evaluate_perplexity(test_texts_extended, n=3, ngram_counts=tri_counts_ext,
                             context_counts=tri_ctx_ext, vocab=vocab_extended, k=0.5)
print(f"  Add-k (k=0.5):            PP = {pp_addk:.2f}")

# Demonstrate backoff prediction
def next_word_backoff(context_tokens: List[str], 
                     tri_counts: Counter, tri_ctx: Counter,
                     bi_counts: Counter, bi_ctx: Counter,
                     uni_counts: Counter, vocab: set, 
                     alpha: float = 0.4, top_k: int = 5):
    """Predict next word using backoff from trigram to unigram"""
    V = len(vocab)
    context = tuple(context_tokens[-2:])  # Last 2 for trigram context
    candidates = []
    
    for w in vocab:
        if w in {"<s>"}:
            continue
        ng = context + (w,)
        p = prob_with_backoff(ng, tri_counts, tri_ctx,
                             bi_counts, bi_ctx,
                             uni_counts, V, alpha)
        candidates.append((w, p))
    
    candidates.sort(key=lambda x: -x[1])
    return candidates[:top_k]



Comparing Backoff vs Add-k Smoothing (Trigram):
  Backoff (alpha=0.3): PP = 2.66
  Backoff (alpha=0.4): PP = 2.61
  Backoff (alpha=0.5): PP = 2.57
  Add-k (k=0.5):            PP = 3.75


In [31]:
# Test backoff prediction
print("\nBackoff Prediction Examples:")
test_contexts = [
    ["<s>", "vpn"],
    ["password", "reset"],
    ["email", "delivery"]
]

for ctx in test_contexts:
    preds = next_word_backoff(ctx, tri_counts_ext, tri_ctx_ext,
                             bi_counts_ext, bi_ctx_ext,
                             uni_counts_ext, vocab_extended, alpha=0.4)
    print(f"\nContext: '{' '.join(ctx)}'")
    for word, prob in preds:
        status = ""
        tri_check = tuple(ctx[-2:] + [word]) if len(ctx) >= 2 else tuple()
        bi_check = tuple(ctx[-1:] + [word]) if len(ctx) >= 1 else tuple()
        
        if tri_check in tri_counts_ext and tri_counts_ext[tri_check] > 0:
            status = "(trigram)"
        elif bi_check in bi_counts_ext and bi_counts_ext[bi_check] > 0:
            status = "(bigram)"
        else:
            status = "(unigram)"
            
        print(f"  {word:15} {prob:.6f} {status}")




Backoff Prediction Examples:

Context: '<s> vpn'
  <UNK>           0.500000 (trigram)
  disconnects     0.500000 (trigram)
  </s>            0.016451 (unigram)
  not             0.004611 (unigram)
  after           0.002638 (unigram)

Context: 'password reset'
  <UNK>           0.109857 (unigram)
  </s>            0.016451 (unigram)
  not             0.004611 (unigram)
  after           0.002638 (unigram)
  error           0.001980 (unigram)

Context: 'email delivery'
  <UNK>           0.109857 (unigram)
  </s>            0.016451 (unigram)
  not             0.004611 (unigram)
  after           0.002638 (unigram)
  error           0.001980 (unigram)


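Backoff vs Add-k:
- This backoff keeps the unsmoothed MLE for seen n-grams and applies the alpha penalty only when it has to fall back to a lower order
- Because seen n-grams are not discounted, the scores for a context can sum to more than 1. The backoff "perplexity" above is therefore optimistic and not a fair comparison with add-k, whose probabilities do sum to 1
- Alpha is a fixed penalty multiplier, not a reserved share of probability mass: it leaves seen-n-gram scores unchanged and only scales the scores of backed-off (unseen) n-grams
  - Lower alpha (0.3): unseen n-grams are penalized more
  - Higher alpha (0.5): unseen n-grams are penalized less, which adds mass without removing any from seen events; this is why the reported PP falls as alpha grows
- Properly normalized schemes, such as Katz backoff (which discounts seen counts and redistributes the freed mass) or interpolation, typically beat add-k on real corpora while still yielding valid perplexities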
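The sketch below checks normalization on a hypothetical toy model. It re-implements a two-level version of the notebook's scheme (bigram MLE if seen, otherwise `alpha` times a unigram MLE, using plain MLE at the unigram level instead of the notebook's add-k, for simplicity) and sums its scores over the vocabulary. For comparison it also sums linear interpolation, λ·P_bigram + (1 − λ)·P_unigram, a different, normalized smoothing method swapped in here for illustration. The data, function names, and λ = 0.7 are assumptions.

```python
from collections import Counter

# Hypothetical toy training data with boundary tokens.
sents = [["<s>", "vpn", "drops", "</s>"],
         ["<s>", "vpn", "reset", "</s>"],
         ["<s>", "email", "drops", "</s>"]]

uni = Counter(w for s in sents for w in s[1:])   # every token that can be predicted
bi = Counter((a, b) for s in sents for a, b in zip(s, s[1:]))
ctx = Counter(a for s in sents for a, _ in zip(s, s[1:]))
N = sum(uni.values())
vocab = list(uni)

def p_uni(w: str) -> float:
    return uni[w] / N

def p_bi_mle(h: str, w: str) -> float:
    return bi[(h, w)] / ctx[h] if ctx[h] else 0.0

def backoff_score(h: str, w: str, alpha: float = 0.4) -> float:
    """Fixed-weight backoff: MLE if the bigram was seen, else alpha * unigram."""
    return p_bi_mle(h, w) if bi[(h, w)] > 0 else alpha * p_uni(w)

def interp(h: str, w: str, lam: float = 0.7) -> float:
    """Linear interpolation: a mixture of two distributions, so it sums to 1."""
    return lam * p_bi_mle(h, w) + (1 - lam) * p_uni(w)

for h in ("<s>", "vpn"):
    backoff_total = sum(backoff_score(h, w) for w in vocab)
    interp_total = sum(interp(h, w) for w in vocab)
    print(h, round(backoff_total, 3), round(interp_total, 3))  # backoff 1.267, interp 1.0
```

The backoff total is 1 plus alpha times the unigram mass of the unseen continuations, so raising `alpha` can only push it further above 1.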