# AIG230 NLP (Week 3 Lab) — Notebook 2: Statistical Language Models (Train, Test, Evaluate)

This notebook focuses on **n-gram Statistical Language Models (SLMs)**:
- Train **unigram**, **bigram**, **trigram** models
- Handle **OOV** with `<UNK>`
- Apply **smoothing** (Add-k)
- Evaluate with **cross-entropy** and **perplexity**
- Do **next-word prediction** and simple **text generation**

> Industry framing: even if modern systems use neural LMs, n-gram LMs are still useful for
baselines, constrained domains, and for understanding evaluation.


## 0) Setup


In [1]:

import re
import math
import random
from collections import Counter, defaultdict
from typing import List, Tuple, Dict


## 1) Data: domain text you might see in real systems


We use short texts that resemble:
- release notes
- incident summaries
- operational runbooks
- customer support messaging

In practice, you would load thousands to millions of lines.


In [2]:

corpus = [
    "vpn disconnects frequently after windows update",
    "password reset link expired user cannot login",
    "api requests timeout when latency spikes",
    "portal returns 500 error after deployment",
    "email delivery delayed messages queued",
    "mfa prompt never arrives user stuck at login",
    "wifi drops in meeting rooms access point reboot helps",
    "outlook search not returning results index corrupted",
    "printer driver install fails with error 1603",
    "teams calls choppy audio jitter high",
    "permission denied accessing shared drive though in correct group",
    "battery drains fast after bios update power settings unchanged",
    "push notifications not working on android app",
    "mailbox full cannot receive emails auto archive not running",
]

# Train/test split at sentence level
random.seed(42)
random.shuffle(corpus)
split = int(0.75 * len(corpus))
train_texts = corpus[:split]
test_texts = corpus[split:]

len(train_texts), len(test_texts), train_texts[:2], test_texts[:2]


(10,
 4,
 ['printer driver install fails with error 1603',
  'push notifications not working on android app'],
 ['email delivery delayed messages queued',
  'vpn disconnects frequently after windows update'])

## 2) Tokenization + special tokens


We will:
- lowercase
- keep alphanumerics
- split on whitespace
- add sentence boundary tokens: `<s>` and `</s>`

We will also map rare tokens to `<UNK>` based on training frequency.


In [3]:

def tokenize(text: str) -> List[str]:
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]+", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text.split()

def add_boundaries(tokens: List[str], n: int) -> List[str]:
    # For n-grams, prepend (n-1) start tokens for simpler context handling
    return ["<s>"]*(n-1) + tokens + ["</s>"]

# Example
tokens = tokenize("Printer driver install fails with error 1603")
add_boundaries(tokens, n=3)


['<s>',
 '<s>',
 'printer',
 'driver',
 'install',
 'fails',
 'with',
 'error',
 '1603',
 '</s>']

## 3) Build vocabulary and handle OOV with `<UNK>`


In [4]:

# Build vocab from training data
train_tokens_flat = []
for t in train_texts:
    train_tokens_flat.extend(tokenize(t))

freq = Counter(train_tokens_flat)

# Typical practical rule: map tokens with frequency <= 1 to <UNK> in small corpora
min_count = 2
vocab = {w for w, c in freq.items() if c >= min_count}
vocab |= {"<UNK>", "<s>", "</s>"}

def replace_oov(tokens: List[str], vocab: set) -> List[str]:
    return [tok if tok in vocab else "<UNK>" for tok in tokens]

# Show OOV effect
sample = tokenize(test_texts[0])
sample, replace_oov(sample, vocab)


(['email', 'delivery', 'delayed', 'messages', 'queued'],
 ['<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'])

## 4) Train n-gram counts (unigram, bigram, trigram)


We will compute:
- `ngram_counts[(w1,...,wn)]`
- `context_counts[(w1,...,w_{n-1})]`

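Then the unsmoothed (maximum-likelihood) probability is:

P(w_n | context) = count(context + w_n) / count(context)

This estimate assigns zero probability to any n-gram not seen in training (and is undefined when the context itself was never seen). A single zero-probability test n-gram makes cross-entropy and perplexity infinite, so we add smoothing.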

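The zero-probability problem is easy to reproduce on a toy corpus. A minimal sketch, assuming two hypothetical pre-tokenized sentences and a hypothetical helper `p_mle` (not part of the notebook) that computes the unsmoothed bigram estimate:

```python
import math
from collections import Counter

# Hypothetical toy training data, already tokenized and padded with <s> / </s>
train = [
    ["<s>", "user", "cannot", "login", "</s>"],
    ["<s>", "user", "stuck", "at", "login", "</s>"],
]

bigram_counts = Counter()
context_counts = Counter()
for toks in train:
    for w1, w2 in zip(toks, toks[1:]):
        bigram_counts[(w1, w2)] += 1
        context_counts[(w1,)] += 1  # w1 used as a context once per bigram

def p_mle(w1: str, w2: str) -> float:
    """Unsmoothed estimate: count(w1, w2) / count(w1 as a context)."""
    return bigram_counts[(w1, w2)] / context_counts[(w1,)]

print(p_mle("user", "cannot"))  # seen bigram: 1 / 2 = 0.5
print(p_mle("user", "login"))   # unseen bigram: 0 / 2 = 0.0

# log(0) is undefined, so one unseen test bigram breaks cross-entropy
try:
    math.log(p_mle("user", "login"))
except ValueError:
    print("math.log(0.0) raises ValueError")

# An unseen context is worse: the denominator itself is 0
try:
    p_mle("printer", "error")
except ZeroDivisionError:
    print("unseen context: division by zero")
```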

In [5]:
def get_ngrams(tokens: List[str], n: int) -> List[Tuple[str, ...]]:
    return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]

def train_ngram_counts(texts: List[str], n: int, vocab: set) -> Tuple[Counter, Counter]:
    """Count n-grams and their (n-1)-token contexts over OOV-mapped, boundary-padded texts."""
    ngram_counts = Counter()
    context_counts = Counter()
    for text in texts:
        toks = replace_oov(tokenize(text), vocab)
        toks = add_boundaries(toks, n)
        for ng in get_ngrams(toks, n):
            ngram_counts[ng] += 1
            # Count the context once for every token it predicts
            context_counts[ng[:-1]] += 1
    return ngram_counts, context_counts




In [6]:
uni_counts, uni_ctx = train_ngram_counts(train_texts, 1, vocab)
uni_counts

Counter({('<UNK>',): 67,
         ('</s>',): 10,
         ('not',): 3,
         ('error',): 2,
         ('after',): 2})

In [7]:
bi_counts, bi_ctx = train_ngram_counts(train_texts, 2, vocab)
bi_counts

Counter({('<UNK>', '<UNK>'): 51,
         ('<s>', '<UNK>'): 10,
         ('<UNK>', '</s>'): 10,
         ('<UNK>', 'not'): 3,
         ('not', '<UNK>'): 3,
         ('<UNK>', 'error'): 2,
         ('after', '<UNK>'): 2,
         ('error', '<UNK>'): 1,
         ('<UNK>', 'after'): 1,
         ('error', 'after'): 1})

In [8]:
tri_counts, tri_ctx = train_ngram_counts(train_texts, 3, vocab)
tri_counts

Counter({('<UNK>', '<UNK>', '<UNK>'): 38,
         ('<s>', '<s>', '<UNK>'): 10,
         ('<s>', '<UNK>', '<UNK>'): 10,
         ('<UNK>', '<UNK>', '</s>'): 7,
         ('<UNK>', '<UNK>', 'not'): 3,
         ('<UNK>', 'not', '<UNK>'): 3,
         ('<UNK>', '<UNK>', 'error'): 2,
         ('not', '<UNK>', '<UNK>'): 2,
         ('<UNK>', 'error', '<UNK>'): 1,
         ('error', '<UNK>', '</s>'): 1,
         ('not', '<UNK>', '</s>'): 1,
         ('<UNK>', '<UNK>', 'after'): 1,
         ('<UNK>', 'after', '<UNK>'): 1,
         ('after', '<UNK>', '<UNK>'): 1,
         ('<UNK>', 'error', 'after'): 1,
         ('error', 'after', '<UNK>'): 1,
         ('after', '<UNK>', '</s>'): 1})

## 5) Add-k smoothing and probability function


Add-k smoothing (a common baseline):

a) Add *k* to the count of every possible next word  
b) Normalize by count(h) + k * |V| so the probabilities over the vocabulary sum to 1

P_k(w|h) = (count(h,w) + k) / (count(h) + k*|V|)

where h is the context (history) and |V| is the vocabulary size.

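This arithmetic can be checked by hand. A minimal sketch, assuming the counts from the bigram cell above (10 training sentences, all starting with `<UNK>`, and a 6-type vocabulary) and a hypothetical helper `addk_from_counts` that takes raw counts rather than n-gram tuples:

```python
import math
from collections import Counter

def addk_from_counts(count_hw: int, count_h: int, vocab_size: int, k: float) -> float:
    """P_k(w|h) = (count(h,w) + k) / (count(h) + k*|V|), from raw counts."""
    return (count_hw + k) / (count_h + k * vocab_size)

# Assumed counts for h = ("<s>",), mirroring the bigram cell
toy_vocab = ["<UNK>", "<s>", "</s>", "not", "error", "after"]
next_counts = Counter({"<UNK>": 10})
count_h = sum(next_counts.values())  # 10
vocab_size, k_smooth = len(toy_vocab), 0.5

toy_probs = {w: addk_from_counts(next_counts[w], count_h, vocab_size, k_smooth)
             for w in toy_vocab}
print(toy_probs["<UNK>"])  # 10.5 / 13 ≈ 0.8077
print(toy_probs["not"])    # 0.5 / 13 ≈ 0.0385

# The k*|V| term makes the estimates sum to 1 over the full vocabulary...
print(math.fsum(toy_probs.values()))
# ...but <s> is never predicted, so the predictable tokens share only 12.5 / 13
print(math.fsum(p for w, p in toy_probs.items() if w != "<s>"))
```

The last line is why sampling code that skips `<s>` needs to renormalize before drawing a token.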

In [9]:

def prob_addk(ngram: Tuple[str, ...], ngram_counts: Counter,
              context_counts: Counter, vocab_size: int, k: float) -> float:
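    """Add-k estimate P(w | h) = (count(h, w) + k) / (count(h) + k * |V|).

    Note: vocab_size = len(vocab) includes <s>, which is never predicted, so the
    probabilities over predictable tokens sum to slightly less than 1.
    """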
    context = ngram[:-1]
    return (ngram_counts[ngram] + k) / (context_counts[context] + k * vocab_size)

In [11]:
V = len(vocab)
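# "login" is not in the 6-type vocabulary, so its bigram count is 0 and it gets
# only the add-k floor: 0.5 / (10 + 0.5 * 6) ≈ 0.0385
example = ("<s>", "login")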
prob_addk(example, bi_counts, bi_ctx, V, k = 0.5)

0.038461538461538464

## 6) Evaluate: cross-entropy and perplexity on test set


We evaluate an LM by how well it predicts held-out text.

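Cross-entropy (average negative log probability, in bits):
H = - (1/N) * sum log2 P(w_i | context)

where N is the number of predicted tokens in the test set, including each `</s>`.

Perplexity:
PP = 2^H

`evaluate_perplexity` computes the same quantity with natural logs, PP = exp(-(1/N) * sum ln P(w_i | context)); the base cancels out, so the perplexity is identical.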

Lower perplexity is better.
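A quick numeric check that the base-2 and natural-log formulas agree, and that perplexity is the geometric mean of the inverse token probabilities. A minimal sketch, assuming four hypothetical per-token probabilities:

```python
import math

# Hypothetical per-token probabilities P(w_i | context) for N = 4 predicted tokens
token_probs = [0.8, 0.05, 0.5, 0.25]
N = len(token_probs)

# Cross-entropy in bits, then PP = 2^H
H_bits = -sum(math.log2(p) for p in token_probs) / N
pp_base2 = 2 ** H_bits

# Natural-log version, as evaluate_perplexity computes it
pp_base_e = math.exp(-sum(math.log(p) for p in token_probs) / N)

# Equivalent view: geometric mean of the inverse probabilities
pp_geo = math.prod(1 / p for p in token_probs) ** (1 / N)

print(f"H = {H_bits:.4f} bits")
print(f"PP = {pp_base2:.4f} (base 2), {pp_base_e:.4f} (base e), {pp_geo:.4f} (geometric mean)")
# The product of the probabilities is 0.005, so all three equal 200 ** 0.25 ≈ 3.7606
```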


In [10]:
def evaluate_perplexity(texts: List[str], n: int,
                        ngram_counts: Counter, context_counts: Counter,
                        vocab: set, k: float) -> float:
    V = len(vocab)
    total_log_prob = 0.0
    total_ngrams = 0
    for text in texts:
        toks = replace_oov(tokenize(text), vocab)
        toks = add_boundaries(toks, n)
        ngrams = get_ngrams(toks, n)
        for ng in ngrams:
            prob = prob_addk(ng, ngram_counts, context_counts, V, k)
            total_log_prob += math.log(prob)
            total_ngrams += 1
    avg_log_prob = total_log_prob / total_ngrams
    perplexity = math.exp(-avg_log_prob)
    return perplexity


## 7) Next-word prediction (top-k)


Given a context, compute the probability of each candidate next token and return the top-k.

This mirrors:
- autocomplete in constrained domains
- template suggestion systems
- command prediction in runbooks
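`next_word_topk` sorts the whole candidate list, which is fine for a 6-type vocabulary. For larger vocabularies, one alternative (an assumption, not the notebook's method) is `heapq.nlargest`, which keeps only the k best candidates; the Python docs describe it as equivalent to `sorted(..., reverse=True)[:k]`. A minimal sketch with a hypothetical helper `topk_heap` and a made-up distribution:

```python
import heapq
from typing import Dict, List, Tuple

def topk_heap(dist: Dict[str, float], k: int) -> List[Tuple[str, float]]:
    """Return the k most probable (token, prob) pairs via a bounded heap."""
    return heapq.nlargest(k, dist.items(), key=lambda item: item[1])

# Hypothetical next-token distribution for one context
next_dist = {"login": 0.125, "at": 0.125, "<UNK>": 0.125, "after": 0.04, "</s>": 0.30}
print(topk_heap(next_dist, 3))  # [('</s>', 0.3), ('login', 0.125), ('at', 0.125)]
# Same result as a full sort; ties keep their original order in both
print(sorted(next_dist.items(), key=lambda item: item[1], reverse=True)[:3])
```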


In [18]:
def next_word_topk(context_tokens: List[str], n: int,
                   ngram_counts: Counter, context_counts: Counter, vocab: set,
                   k_smooth: float = 0.5, top_k: int = 5):
    # Context length should be n-1
    V = len(vocab)
    context = tuple(context_tokens[-(n-1):]) if n > 1 else tuple()
    candidates = []
    for w in vocab:
        if w == "<s>":  # <s> is never a valid next token
            continue
        ng = context + (w,)
        p = prob_addk(ng, ngram_counts, context_counts, V, k=k_smooth)
        candidates.append((w, p))
    candidates.sort(key=lambda x: -x[1])
    return candidates[:top_k]

# Bigram: context is 1 token
next_word_topk(["<s>"], n=2, ngram_counts=bi_counts, context_counts=bi_ctx, vocab=vocab, top_k=8)



[('<UNK>', 0.8076923076923077),
 ('after', 0.038461538461538464),
 ('error', 0.038461538461538464),
 ('not', 0.038461538461538464),
 ('</s>', 0.038461538461538464)]

## 8) Simple generation (bigram or trigram)


Text generation is not the main goal in SLMs, but it helps you verify:
- boundary handling
- smoothing
- OOV decisions

We will sample tokens until we hit `</s>`.
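When generation is used for debugging, reproducible samples help. The notebook seeds the global `random` module, so any other random call (such as a shuffle) shifts later samples. A minimal sketch, assuming a hand-written hypothetical bigram table instead of the trained counts and a hypothetical `sample_sentence` helper that draws from a dedicated `random.Random` instance:

```python
import random
from typing import Dict, List

# Hypothetical bigram table P(next | prev); weights need not sum to 1,
# because random.choices treats them as relative weights
toy_bigrams: Dict[str, Dict[str, float]] = {
    "<s>": {"user": 0.6, "vpn": 0.4},
    "user": {"cannot": 0.5, "</s>": 0.5},
    "cannot": {"login": 1.0},
    "login": {"</s>": 1.0},
    "vpn": {"disconnects": 1.0},
    "disconnects": {"</s>": 1.0},
}

def sample_sentence(rng: random.Random, max_len: int = 20) -> List[str]:
    """Sample after <s> until </s> (not emitted) or max_len tokens."""
    prev, out = "<s>", []
    for _ in range(max_len):
        row = toy_bigrams[prev]
        nxt = rng.choices(list(row), weights=list(row.values()), k=1)[0]
        if nxt == "</s>":
            break
        out.append(nxt)
        prev = nxt
    return out

# Two generators with the same seed produce identical samples,
# independent of any other code that uses the global random module
rng_a, rng_b = random.Random(7), random.Random(7)
run_a = [sample_sentence(rng_a) for _ in range(5)]
run_b = [sample_sentence(rng_b) for _ in range(5)]
print(run_a)
print(run_a == run_b)  # True
```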


In [11]:

def sample_next(context_tokens: List[str], n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, k_smooth: float = 0.5):
    V = len(vocab)
    context = tuple(context_tokens[-(n-1):]) if n > 1 else tuple()
    words = [w for w in vocab if w != "<s>"]
    probs = []
    for w in words:
        ng = context + (w,)
        probs.append(prob_addk(ng, ngram_counts, context_counts, V, k=k_smooth))
    # Normalize
    s = sum(probs)
    probs = [p/s for p in probs]
    return random.choices(words, weights=probs, k=1)[0]

def generate(n: int, ngram_counts: Counter, context_counts: Counter, vocab: set, max_len: int = 20, k_smooth: float = 0.5):
    tokens = ["<s>"]*(n-1) if n > 1 else []
    out = []
    for _ in range(max_len):
        w = sample_next(tokens, n, ngram_counts, context_counts, vocab, k_smooth=k_smooth)
        if w == "</s>":
            break
        out.append(w)
        tokens.append(w)
    return " ".join(out)

for _ in range(5):
    print("BIGRAM:", generate(2, bi_counts, bi_ctx, vocab, max_len=18))


BIGRAM: after after not <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>
BIGRAM: <UNK> <UNK>
BIGRAM: <UNK> not <UNK> <UNK> <UNK> <UNK> error not error <UNK> not
BIGRAM: <UNK> <UNK> <UNK>
BIGRAM: error


## 9) Model comparison: effect of n and smoothing


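Try different `k` values. Notes:
- `k=1.0` is Laplace smoothing, which is often too strong: it gives a large share of each context's probability mass to words never seen after that context
- smaller `k` (like 0.1 to 0.5) is often better
- with this tiny corpus almost every token maps to `<UNK>` (the vocabulary has only 6 types), so all perplexities stay between roughly 1.7 and 2.1 and the differences are small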

In real corpora, trigrams often beat bigrams, but require more data.
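To see why `k=1.0` is often too strong, compute how much probability add-k hands to words that never followed a context. A minimal sketch with a hypothetical helper `unseen_mass` and made-up counts (a context seen 3 times, 2 distinct followers, and a 142-type vocabulary, roughly the `min_count=1` vocabulary of Exercise 2):

```python
def unseen_mass(count_h: int, seen_types: int, vocab_size: int, k: float) -> float:
    """Total add-k probability assigned to next words never seen after context h.

    Each of the (vocab_size - seen_types) unseen words gets k / (count_h + k * vocab_size).
    """
    return k * (vocab_size - seen_types) / (count_h + k * vocab_size)

# Hypothetical sparse context: seen 3 times, followed by 2 distinct words
for k in [1.0, 0.5, 0.1, 0.01]:
    print(f"k={k:>4}: unseen words get {unseen_mass(3, 2, 142, k):.1%} of the mass")
# Roughly 96.6%, 94.6%, 81.4% and 31.7%
```

With Laplace smoothing, almost all of the mass for this context goes to words never observed after it; smaller k keeps more mass on the observed continuations.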


In [17]:
for k in [1.0, 0.5, 0.1, 0.01]:
    pp_bi_k  = evaluate_perplexity(test_texts, n=2, ngram_counts=bi_counts,  context_counts=bi_ctx,  vocab=vocab, k=k)
    pp_tri_k = evaluate_perplexity(test_texts, n=3, ngram_counts=tri_counts, context_counts=tri_ctx, vocab=vocab, k=k)
    print(f"k={k:>4}:  bigram PP={pp_bi_k:,.2f}   trigram PP={pp_tri_k:,.2f}")

k= 1.0:  bigram PP=1.95   trigram PP=2.09
k= 0.5:  bigram PP=1.87   trigram PP=1.96
k= 0.1:  bigram PP=1.79   trigram PP=1.80
k=0.01:  bigram PP=1.76   trigram PP=1.75


## Exercises (do these during lab)
1) Add 20 more realistic domain sentences to the corpus and re-run training/evaluation.  
2) Change `min_count` (OOV threshold) and explain how perplexity changes.  
3) Implement **backoff**: if a trigram is unseen, fall back to bigram; if unseen, fall back to unigram.  
4) Create a function that returns **top-5 next words** given a phrase like: `"user cannot"`.
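When implementing item 3, it is worth checking whether the backed-off values still form a probability distribution. A minimal sketch, assuming hypothetical toy counts and a two-level (bigram to unigram) version of a fixed-weight scheme with add-k at each level: seen bigrams already receive add-k mass meant for unseen words, and unseen words then get extra backed-off mass, so the scores after a context can sum to more than 1. Katz backoff avoids this by discounting seen counts and choosing a per-context weight so the total is exactly 1.

```python
from collections import Counter
from typing import Tuple

# Hypothetical toy counts over a 4-type vocabulary
toy_vocab = ["a", "b", "c", "</s>"]
TOY_V = len(toy_vocab)
toy_bi = Counter({("a", "b"): 2, ("a", "</s>"): 1})
toy_bi_ctx = Counter({("a",): 3})
toy_uni = Counter({("a",): 3, ("b",): 2, ("c",): 1, ("</s>",): 3})
toy_uni_ctx = Counter({(): 9})

def addk(ngram: Tuple[str, ...], counts: Counter, ctx: Counter, k: float = 0.1) -> float:
    """Add-k estimate (count(ngram) + k) / (count(context) + k*|V|)."""
    return (counts[ngram] + k) / (ctx[ngram[:-1]] + k * TOY_V)

def backoff_score(w1: str, w2: str, alpha: float = 0.75, k: float = 0.1) -> float:
    """Fixed-weight backoff: seen bigram -> add-k bigram, else alpha * add-k unigram."""
    if toy_bi[(w1, w2)] > 0:
        return addk((w1, w2), toy_bi, toy_bi_ctx, k)
    return alpha * addk((w2,), toy_uni, toy_uni_ctx, k)

total = sum(backoff_score("a", w) for w in toy_vocab)
print(f"Sum of backoff scores after 'a': {total:.4f}")  # ≈ 1.2763, not 1
```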


In [22]:
# Exercise 1: Add 20 more realistic domain sentences to the corpus
new_sentences = [
    "database connection timeout during peak hours",
    "ssl certificate expired causing https failures",
    "memory leak in background service causes crash",
    "authentication fails after password policy update",
    "network bandwidth exceeded rate limit exceeded",
    "file permissions issue read only access denied",
    "server out of disk space cleanup required",
    "dns resolution failing for internal services",
    "backup failed corrupted checkpoint detected",
    "load balancer unhealthy instances removed",
    "cache invalidation causing stale data issues",
    "thread pool exhausted too many concurrent requests",
    "firewall rule blocking legitimate traffic",
    "disk io latency causing query timeouts",
    "session expired user needs to re login",
    "certificate verification failed untrusted issuer",
    "queue overflow messages dropped",
    "resource exhaustion cpu at 100 percent",
    "replication lag causing inconsistency",
    "plugin incompatible version mismatch detected"
]

# Combine with original corpus
expanded_corpus = corpus + new_sentences

# Re-split train/test
random.seed(42)
random.shuffle(expanded_corpus)
split_exp = int(0.75 * len(expanded_corpus))
train_texts_exp = expanded_corpus[:split_exp]
test_texts_exp = expanded_corpus[split_exp:]

print(f"Original corpus size: {len(corpus)}")
print(f"Expanded corpus size: {len(expanded_corpus)}")
print(f"New train/test split: {len(train_texts_exp)}/{len(test_texts_exp)}")
print("\nSample new sentences:")
for s in new_sentences[:5]:
    print(f"  - {s}")

Original corpus size: 14
Expanded corpus size: 34
New train/test split: 25/9

Sample new sentences:
  - database connection timeout during peak hours
  - ssl certificate expired causing https failures
  - memory leak in background service causes crash
  - authentication fails after password policy update
  - network bandwidth exceeded rate limit exceeded


In [23]:
# Retrain models on expanded corpus with original min_count
train_tokens_flat_exp = []
for t in train_texts_exp:
    train_tokens_flat_exp.extend(tokenize(t))

freq_exp = Counter(train_tokens_flat_exp)
vocab_exp = {w for w, c in freq_exp.items() if c >= min_count}
vocab_exp |= {"<UNK>", "<s>", "</s>"}

print(f"\nOriginal vocab size: {len(vocab)}")
print(f"Expanded vocab size: {len(vocab_exp)}")

# Retrain n-gram models
uni_counts_exp, uni_ctx_exp = train_ngram_counts(train_texts_exp, 1, vocab_exp)
bi_counts_exp, bi_ctx_exp = train_ngram_counts(train_texts_exp, 2, vocab_exp)
tri_counts_exp, tri_ctx_exp = train_ngram_counts(train_texts_exp, 3, vocab_exp)

# Evaluate the original models on the original test set (for reference)
pp_bi_orig = evaluate_perplexity(test_texts, 2, bi_counts, bi_ctx, vocab, k=0.5)
pp_tri_orig = evaluate_perplexity(test_texts, 3, tri_counts, tri_ctx, vocab, k=0.5)

# Evaluate the expanded models on the expanded test set
pp_bi_exp = evaluate_perplexity(test_texts_exp, 2, bi_counts_exp, bi_ctx_exp, vocab_exp, k=0.5)
pp_tri_exp = evaluate_perplexity(test_texts_exp, 3, tri_counts_exp, tri_ctx_exp, vocab_exp, k=0.5)

print(f"Bigram perplexity (original corpus): {pp_bi_orig:.4f}")
print(f"Bigram perplexity (expanded corpus): {pp_bi_exp:.4f}")
print(f"Trigram perplexity (original corpus): {pp_tri_orig:.4f}")
print(f"Trigram perplexity (expanded corpus): {pp_tri_exp:.4f}")
print(f"\nNote: the test sets and vocabularies differ ({len(vocab)} vs {len(vocab_exp)} types), so these")
print("perplexities are not directly comparable; a larger vocabulary can raise perplexity.")


Original vocab size: 6
Expanded vocab size: 18
Bigram perplexity (original corpus): 1.8712
Bigram perplexity (expanded corpus): 2.4029
Trigram perplexity (original corpus): 1.9553
Trigram perplexity (expanded corpus): 2.6369

Note: the test sets and vocabularies differ (6 vs 18 types), so these
perplexities are not directly comparable; a larger vocabulary can raise perplexity.


In [26]:
# Exercise 2: Change min_count and analyze OOV threshold effect
min_counts = [1, 2, 3, 5]
results_bigram = []
results_trigram = []

for mc in min_counts:
    # Build vocab with different min_count
    vocab_mc = {w for w, c in freq_exp.items() if c >= mc}
    vocab_mc |= {"<UNK>", "<s>", "</s>"}
    
    # Retrain models
    bi_counts_mc, bi_ctx_mc = train_ngram_counts(train_texts_exp, 2, vocab_mc)
    tri_counts_mc, tri_ctx_mc = train_ngram_counts(train_texts_exp, 3, vocab_mc)
    
    # Evaluate
    pp_bi = evaluate_perplexity(test_texts_exp, 2, bi_counts_mc, bi_ctx_mc, vocab_mc, k=0.5)
    pp_tri = evaluate_perplexity(test_texts_exp, 3, tri_counts_mc, tri_ctx_mc, vocab_mc, k=0.5)
    
    results_bigram.append(pp_bi)
    results_trigram.append(pp_tri)
    
    print(f"min_count={mc:2d}: vocab_size={len(vocab_mc):3d} | Bigram PP={pp_bi:.4f} | Trigram PP={pp_tri:.4f}")

print("\nAnalysis:")
print("- Perplexity is only comparable between models that share the same vocabulary.")
print("- Higher min_count: more tokens collapse to <UNK> and the vocabulary shrinks")
print("  → fewer outcomes to predict → lower perplexity, but a less informative model")
print("- min_count=1: <UNK> never occurs in the training data")
print("  → every OOV test token gets only smoothing mass → very high perplexity")


min_count= 1: vocab_size=142 | Bigram PP=143.6633 | Trigram PP=147.9632
min_count= 2: vocab_size= 18 | Bigram PP=2.4029 | Trigram PP=2.6369
min_count= 3: vocab_size=  8 | Bigram PP=1.7361 | Trigram PP=1.7346
min_count= 5: vocab_size=  3 | Bigram PP=1.4690 | Trigram PP=1.4393

Analysis:
- Perplexity is only comparable between models that share the same vocabulary.
- Higher min_count: more tokens collapse to <UNK> and the vocabulary shrinks
  → fewer outcomes to predict → lower perplexity, but a less informative model
- min_count=1: <UNK> never occurs in the training data
  → every OOV test token gets only smoothing mass → very high perplexity


In [27]:
# Exercise 3: Implement backoff smoothing (Trigram -> Bigram -> Unigram)
def prob_backoff(ngram: Tuple[str, ...], 
                 uni_counts: Counter, uni_ctx: Counter,
                 bi_counts: Counter, bi_ctx: Counter,
                 tri_counts: Counter, tri_ctx: Counter,
                 vocab_size: int, k: float = 0.1,
                 backoff_weight: float = 0.75) -> float:
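    """Fixed-weight backoff score (trigram -> bigram -> unigram).

    Each backoff step multiplies by a constant backoff_weight instead of a
    per-context normalizer (as in Katz backoff), so the scores over the vocabulary
    do not in general sum to 1 and perplexities computed from them are approximate.
    """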
    if len(ngram) == 3:
        context_3 = ngram[:2]  # (w1, w2)
        w3 = ngram[2]
        
        # Check if trigram exists
        if tri_counts[ngram] > 0:
            # Use smoothed trigram probability
            return prob_addk(ngram, tri_counts, tri_ctx, vocab_size, k)
        else:
            # Back off to bigram
            bigram = (context_3[1], w3)  # (w2, w3)
            return backoff_weight * prob_backoff(bigram, uni_counts, uni_ctx, 
                                                 bi_counts, bi_ctx, tri_counts, tri_ctx,
                                                 vocab_size, k, backoff_weight)
    
    elif len(ngram) == 2:
        context_2 = ngram[:1]  # (w1,)
        w2 = ngram[1]
        
        # Check if bigram exists
        if bi_counts[ngram] > 0:
            return prob_addk(ngram, bi_counts, bi_ctx, vocab_size, k)
        else:
            # Back off to unigram
            unigram = (w2,)
            return backoff_weight * prob_backoff(unigram, uni_counts, uni_ctx,
                                                 bi_counts, bi_ctx, tri_counts, tri_ctx,
                                                 vocab_size, k, backoff_weight)
    
    elif len(ngram) == 1:
        # Unigram: always use smoothed probability
        return prob_addk(ngram, uni_counts, uni_ctx, vocab_size, k)
    
    return 1.0 / vocab_size  # Fallback uniform

def evaluate_perplexity_backoff(texts: List[str], n: int,
                               uni_counts: Counter, uni_ctx: Counter,
                               bi_counts: Counter, bi_ctx: Counter,
                               tri_counts: Counter, tri_ctx: Counter,
                               vocab: set, k: float = 0.1) -> float:
    """Evaluate perplexity using backoff smoothing"""
    V = len(vocab)
    total_log_prob = 0.0
    total_ngrams = 0
    
    for text in texts:
        toks = replace_oov(tokenize(text), vocab)
        toks = add_boundaries(toks, n)
        ngrams = get_ngrams(toks, n)
        for ng in ngrams:
            prob = prob_backoff(ng, uni_counts, uni_ctx, bi_counts, bi_ctx,
                               tri_counts, tri_ctx, V, k)
            total_log_prob += math.log(max(prob, 1e-10))  # Avoid log(0)
            total_ngrams += 1
    
    avg_log_prob = total_log_prob / total_ngrams
    perplexity = math.exp(-avg_log_prob)
    return perplexity

# Compare: Add-k vs Backoff
pp_addk_bi = evaluate_perplexity(test_texts_exp, 2, bi_counts_exp, bi_ctx_exp, vocab_exp, k=0.5)
pp_addk_tri = evaluate_perplexity(test_texts_exp, 3, tri_counts_exp, tri_ctx_exp, vocab_exp, k=0.5)

pp_backoff = evaluate_perplexity_backoff(test_texts_exp, 3, 
                                        uni_counts_exp, uni_ctx_exp,
                                        bi_counts_exp, bi_ctx_exp,
                                        tri_counts_exp, tri_ctx_exp,
                                        vocab_exp, k=0.1)

print(f"Add-k (k=0.5) Bigram Perplexity:     {pp_addk_bi:.4f}")
print(f"Add-k (k=0.5) Trigram Perplexity:    {pp_addk_tri:.4f}")
print(f"Backoff (k=0.1) Trigram Perplexity:  {pp_backoff:.4f}")
print("\nBackoff: falls back to bigram/unigram when a trigram is unseen.")
print("Caveat: fixed-weight backoff scores are not normalized and k differs (0.1 vs 0.5),")
print("so this is not a strict like-for-like comparison.")

Add-k (k=0.5) Bigram Perplexity:     2.4029
Add-k (k=0.5) Trigram Perplexity:    2.6369
Backoff (k=0.1) Trigram Perplexity:  2.1498

Backoff: falls back to bigram/unigram when a trigram is unseen.
Caveat: fixed-weight backoff scores are not normalized and k differs (0.1 vs 0.5),
so this is not a strict like-for-like comparison.


In [28]:
# Exercise 4: Top-5 next words given a phrase (like "user cannot")
def get_top_k_next_words(phrase: str, n: int, k: int = 5,
                         bi_counts: Counter = None, bi_ctx: Counter = None,
                         tri_counts: Counter = None, tri_ctx: Counter = None,
                         vocab: set = None, smoothing: str = "addk", k_param: float = 0.5,
                         uni_counts: Counter = None, uni_ctx: Counter = None) -> List[Tuple[str, float]]:
    # k = number of suggestions to return; k_param = smoothing constant

    # Tokenize and apply OOV replacement
    tokens = replace_oov(tokenize(phrase), vocab)

    # Extract context for n-gram
    if n == 2:
        # Bigram: use last token as context
        context = (tokens[-1],) if tokens else ("<s>",)
        counts = bi_counts
        ctx_counts = bi_ctx
    elif n == 3:
        # Trigram: use last 2 tokens as context (pad with <s> for short phrases)
        context = tuple(tokens[-2:]) if len(tokens) >= 2 else ("<s>", tokens[-1] if tokens else "<s>")
        counts = tri_counts
        ctx_counts = tri_ctx
    else:
        return []

    # Compute probabilities for all vocabulary items
    candidates = []
    V = len(vocab)

    for w in vocab:
        if w == "<s>":  # Don't predict sentence start
            continue

        ngram = context + (w,)

        if smoothing == "backoff" and n == 3:
            # Use the counts passed in rather than global *_exp variables
            prob = prob_backoff(ngram, uni_counts, uni_ctx,
                                bi_counts, bi_ctx,
                                tri_counts, tri_ctx, V, k_param)
        else:
            # Add-k (also used when backoff is requested with n=2)
            prob = prob_addk(ngram, counts, ctx_counts, V, k_param)
        
        candidates.append((w, prob))
    
    # Sort by probability descending, return top-k
    candidates.sort(key=lambda x: x[1], reverse=True)
    return candidates[:k]

# Test on example phrases
test_phrases = [
    "user cannot",
    "server",
    "authentication fails",
    "request timeout"
]

for phrase in test_phrases:
    print(f"\nPhrase: '{phrase}'")
    print("Top-5 next words (Trigram model with Add-k smoothing):")
    
    top5 = get_top_k_next_words(phrase, n=3, k=5,
                              bi_counts=bi_counts_exp, bi_ctx=bi_ctx_exp,
                              tri_counts=tri_counts_exp, tri_ctx=tri_ctx_exp,
                              vocab=vocab_exp, smoothing="addk", k_param=0.5)
    
    for i, (word, prob) in enumerate(top5, 1):
        print(f"  {i}. {word:20s} (prob: {prob:.6f})")



Phrase: 'user cannot'
Top-5 next words (Trigram model with Add-k smoothing):
  1. login                (prob: 0.125000)
  2. at                   (prob: 0.125000)
  3. <UNK>                (prob: 0.125000)
  4. after                (prob: 0.041667)
  5. certificate          (prob: 0.041667)

Phrase: 'server'
Top-5 next words (Trigram model with Add-k smoothing):
  1. <UNK>                (prob: 0.609375)
  2. certificate          (prob: 0.046875)
  3. denied               (prob: 0.046875)
  4. expired              (prob: 0.046875)
  5. failed               (prob: 0.046875)

Phrase: 'authentication fails'
Top-5 next words (Trigram model with Add-k smoothing):
  1. <UNK>                (prob: 0.528409)
  2. </s>                 (prob: 0.176136)
  3. causing              (prob: 0.039773)
  4. after                (prob: 0.028409)
  5. error                (prob: 0.028409)

Phrase: 'request timeout'
Top-5 next words (Trigram model with Add-k smoothing):
  1. <UNK>                (prob: 0.