# Task 2 - N-gram Language Models with BPE

In this task we implement and evaluate n-gram language models based on a cleaned Shakespeare corpus.  
The corpus is split into training, validation, and test sets (predefined).  

Key requirements:
- Models over BPE subword tokens
- Unigram → Bigram → Trigram → 4-gram
- Intrinsic evaluation: Perplexity
- Bigram analysis across different numbers of BPE merges *k*
- Laplace (add-one) smoothing
- Simple interpolation/backoff
- Extrinsic evaluation: sentence generation


## Data Loading and Tokenization

We start by loading the cleaned Shakespeare dataset and applying **Byte-Pair Encoding (BPE)**.  
The dataset is split into **train, validation, and test** to ensure consistent comparison across models.  
The number of BPE merges (*k*, e.g. 1600) determines the vocabulary size and granularity of subword units.
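
To make the role of the merge list concrete, here is a minimal sketch that replays a tiny merge list on a single word. The merges in `toy_merges` and the helper name `apply_merges` are made up for illustration; they are not taken from the Task 1 output files.

```python
from typing import List, Tuple

WORD_END = "</w>"

def apply_merges(word: str, merges: List[Tuple[str, str]]) -> List[str]:
    """Replay merges in their learned order over the characters of one word."""
    symbols = list(word) + [WORD_END]
    for a, b in merges:
        out, i = [], 0
        while i < len(symbols):
            # Merge the adjacent pair (a, b) wherever it occurs
            if i < len(symbols) - 1 and symbols[i] == a and symbols[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        symbols = out
    return symbols

# Hypothetical merge list, in learned order
toy_merges = [("t", "h"), ("th", "e"), ("the", "</w>")]

few = apply_merges("the", toy_merges[:1])   # only the first merge
full = apply_merges("the", toy_merges)      # all three merges
print(few)
print(full)
```

With one merge the word stays split into several subwords; with all three it collapses into a single token. This is the sense in which a larger *k* gives coarser units and a larger vocabulary.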


In [None]:
# =============================================================================
# TASK 2 — BLOCK 1: IMPORTS AND SETUP
# =============================================================================

import math
import os
import random
import re
import time
from collections import Counter
from typing import Any, Dict, Iterable, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd

# Essential constants and functions from Task 1
CORPUS_DIR = "Corpus"
GENERATED_DIR = "Generated_tokens"
TRAIN_CLEAN = os.path.join(CORPUS_DIR, "Shakespeare_clean_train.txt")
VALID_CLEAN = os.path.join(CORPUS_DIR, "Shakespeare_clean_valid.txt")
TEST_CLEAN = os.path.join(CORPUS_DIR, "Shakespeare_clean_test.txt")
K_LIST = [1000, 1200, 1400, 1600, 1800, 2000]
WORD_END = "</w>"

_wsre = re.compile(r"\s+")

# Task 2 specific tokens
EOS = "<eos>"
BOS = "<bos>"
random.seed(42)


## Utility Functions

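This block defines helper functions that are reused across the notebook:

- `read_text` and `words_from_text`: read a file and split it into lowercased, whitespace-separated words  
- `load_merges`: read BPE merge pairs from a merges file  
- `apply_merges_to_word`: replay the merges in order on a single word terminated by `</w>`  
- `tokenize_lines_with_merges`: tokenize the text line by line and append `<eos>` to each non-empty line  

Centralizing these functions keeps the n-gram model code cleaner and easier to extend.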


In [2]:
# =============================================================================
# TASK 2 — BLOCK 2: UTILITY FUNCTIONS
# =============================================================================

def read_text(path: str) -> str:
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

def words_from_text(text: str, lowercase: bool = True) -> List[str]:
    if lowercase:
        text = text.lower()
    return [w for w in _wsre.split(text.strip()) if w]

def load_merges(merges_path: str) -> List[Tuple[str, str]]:
    merges = []
    with open(merges_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 2:
                merges.append((parts[0], parts[1]))
    return merges

def apply_merges_to_word(word: str, merges: List[Tuple[str, str]]) -> List[str]:
    symbols = tuple(list(word) + [WORD_END])
    for a, b in merges:
        out = []
        i, L = 0, len(symbols)
        while i < L:
            if i < L-1 and symbols[i] == a and symbols[i+1] == b:
                out.append(a + b); i += 2
            else:
                out.append(symbols[i]); i += 1
        symbols = tuple(out)
    return list(symbols)

def tokenize_lines_with_merges(text: str, merges: List[Tuple[str, str]]) -> List[List[str]]:
    """Convert text to line-based token sequences for n-gram training."""
    token_lines: List[List[str]] = []
    for line in text.strip().splitlines():
        words = words_from_text(line)
        if not words:
            continue
        toks: List[str] = []
        for w in words:
            toks.extend(apply_merges_to_word(w, merges))
        toks.append(EOS)
        token_lines.append(toks)
    return token_lines


## N-gram Building Functions

- `add_bos_context(line_tokens, n)`: pads each line with `n-1` `<bos>` tokens for n-gram context.
- `build_ngrams`: builds vocabulary, n-gram counts, and context counts from tokenized lines.  
These counts are the basis for probability estimation and perplexity.
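
As a small illustration of the counting scheme, the sketch below pads two toy lines with `<bos>` and counts bigrams and their contexts. The helper `count_ngrams` and the toy lines are assumptions made for this example only.

```python
from collections import Counter
from typing import List, Tuple

BOS, EOS = "<bos>", "<eos>"

def count_ngrams(lines: List[List[str]], n: int) -> Tuple[Counter, Counter]:
    """Count n-grams and their (n-1)-token contexts, padding each line with BOS."""
    ngrams, contexts = Counter(), Counter()
    for line in lines:
        padded = [BOS] * (n - 1) + line
        for i in range(n - 1, len(padded)):
            ctx = tuple(padded[i - n + 1:i])
            ngrams[ctx + (padded[i],)] += 1
            contexts[ctx] += 1
    return ngrams, contexts

# Toy lines, already tokenized and terminated with <eos>
toy_lines = [
    ["to</w>", "be</w>", EOS],
    ["to</w>", "be</w>", "or</w>", EOS],
]
bigrams, contexts = count_ngrams(toy_lines, n=2)
print(bigrams[(BOS, "to</w>")], contexts[("be</w>",)])
```

Every predicted token, including `<eos>`, contributes exactly one n-gram, so the total n-gram count equals the total number of tokens in the lines.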


In [3]:
# N-GRAM BUILDING FUNCTIONS

def add_bos_context(line_tokens: List[str], n: int) -> List[str]:
    """Add beginning-of-sentence tokens for n-gram context."""
    if n <= 1:
        return line_tokens
    return [BOS] * (n-1) + line_tokens

def build_ngrams(token_lines: List[List[str]], n: int):
    """Build n-gram counts and vocabulary from tokenized lines."""
    vocab = set()
    ngram_counts = Counter()
    context_counts = Counter()

    for line in token_lines:
        line = add_bos_context(line, n)
        vocab.update(line)
        for i in range(n-1, len(line)):
            context = tuple(line[i-n+1:i])
            token   = line[i]
            ngram_counts[context + (token,)] += 1
            context_counts[context]          += 1

    return ngram_counts, context_counts, sorted(vocab)

## N-gram Language Model

`NGramLM(n)` implements several estimators:
- **ML** (`p_ml`) and **Laplace** (`p_laplace`) smoothing.  
- **Linear interpolation** (`p_interpolated`) over orders 1…n (defaults to Laplace components).  
- **Katz-like backoff** (`p_backoff_katz`, simplified) and **stupid backoff** (`p_backoff`, not normalized).  
Models are chained recursively so lower-order distributions are available.
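
The difference between the estimators is easiest to see on toy counts. The following sketch builds bigram statistics from a four-token sequence and compares ML, Laplace, and a two-way interpolation. The sequence, the function names, and the weights are illustrative assumptions and do not reuse the `NGramLM` class.

```python
from collections import Counter

# Hypothetical toy sequence (not from the Shakespeare corpus)
seq = ["a", "a", "b", "a"]
unigram = Counter(seq)
bigram = Counter(zip(seq, seq[1:]))   # counts of (prev, token)
ctx = Counter(seq[:-1])               # how often each token serves as context
V = len(unigram)                      # vocabulary size

def p_ml(prev: str, tok: str) -> float:
    """Maximum likelihood: zero for unseen bigrams or unseen contexts."""
    c = ctx[prev]
    return bigram[(prev, tok)] / c if c else 0.0

def p_laplace(prev: str, tok: str) -> float:
    """Add-one smoothing: every bigram gets one pseudo-count."""
    return (bigram[(prev, tok)] + 1) / (ctx[prev] + V)

def p_uni_laplace(tok: str) -> float:
    return (unigram[tok] + 1) / (sum(unigram.values()) + V)

def p_interp(prev: str, tok: str, lam_uni: float, lam_bi: float) -> float:
    """Weighted mix of unigram and bigram Laplace estimates."""
    return lam_uni * p_uni_laplace(tok) + lam_bi * p_laplace(prev, tok)

print(p_ml("b", "b"), p_laplace("b", "b"), p_interp("a", "b", 0.2, 0.8))
```

The unseen bigram `("b", "b")` has ML probability zero but a positive Laplace probability, which is what keeps perplexity finite on held-out text.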


In [4]:
# N-GRAM LANGUAGE MODEL CLASS

class NGramLM:
    """N-gram Language Model with multiple smoothing techniques."""

    def __init__(self, n: int, token_lines: List[List[str]]):
        assert n >= 1
        self.n = n
        self.ng_counts, self.ctx_counts, self.vocab = build_ngrams(token_lines, n)
        self.V = len(self.vocab)
        self.lower = NGramLM(n-1, token_lines) if n > 1 else None

    def p_ml(self, context: Tuple[str, ...], token: str) -> float:
        """Maximum Likelihood probability estimation."""
        if self.n == 1:
            total = sum(self.ng_counts.values())
            return self.ng_counts.get((token,), 0) / max(1, total)
        c = self.ctx_counts.get(context, 0)
        if c == 0:
            return 0.0
        return self.ng_counts.get(context + (token,), 0) / c

    def p_laplace(self, context: Tuple[str, ...], token: str) -> float:
        """Laplace (add-one) smoothing probability estimation."""
        if self.n == 1:
            num = self.ng_counts.get((token,), 0) + 1
            den = sum(self.ng_counts.values()) + self.V
            return num / den
        c   = self.ctx_counts.get(context, 0)
        num = self.ng_counts.get(context + (token,), 0) + 1
        den = c + self.V
        return num / max(1, den)

    def p_interpolated(self, context: Tuple[str, ...], token: str,
                       lambdas: Optional[List[float]] = None, use_laplace: bool = True) -> float:
        """Linear interpolation of different n-gram orders."""
        if lambdas is None:
            lambdas = [1.0/self.n] * self.n
        assert len(lambdas) == self.n

        prob = 0.0
        current_model = self

        for order in range(self.n, 0, -1):
            if order == 1:
                p = current_model.p_laplace((), token) if use_laplace else current_model.p_ml((), token)
            else:
                need = order - 1
                if len(context) >= need:
                    ctx = context[-need:]
                else:
                    padding_needed = need - len(context)
                    ctx = tuple([BOS] * padding_needed) + context

                p = current_model.p_laplace(ctx, token) if use_laplace else current_model.p_ml(ctx, token)

            prob += lambdas[order-1] * p

            if current_model.lower is not None:
                current_model = current_model.lower

        return prob

    def p_backoff_katz(self, context: Tuple[str, ...], token: str) -> float:
        """Simplified Katz backoff (without Good-Turing discounting)."""
        if self.n == 1:
            return self.p_laplace((), token)

        need = self.n - 1
        if len(context) >= need:
            ctx = context[-need:]
        else:
            padding_needed = need - len(context)
            ctx = tuple([BOS] * padding_needed) + context

        c_ctx = self.ctx_counts.get(ctx, 0)
        c_ng = self.ng_counts.get(ctx + (token,), 0)

        if c_ng > 0:
            # Discounted ML estimate (simplified)
            discount = 0.75  # Simple absolute discounting
            prob_discounted = max(c_ng - discount, 0) / c_ctx
            return prob_discounted
        else:
            # Backoff with alpha weight
            alpha = 0.4  # Simplified backoff weight
            return alpha * self.lower.p_backoff_katz(context, token)

    def p_backoff(self, context: Tuple[str, ...], token: str) -> float:
        """Stupid Backoff (not a true probability distribution)."""
        if self.n == 1:
            return self.p_laplace((), token)

        need = self.n - 1
        if len(context) >= need:
            ctx = context[-need:]
        else:
            padding_needed = need - len(context)
            ctx = tuple([BOS] * padding_needed) + context

        c_ctx = self.ctx_counts.get(ctx, 0)
        c_ng = self.ng_counts.get(ctx + (token,), 0)

        if c_ctx > 0 and c_ng > 0:
            return c_ng / c_ctx  # ML estimate if seen
        else:
            # Backoff with penalty
            return 0.4 * self.lower.p_backoff(context, token)

## Perplexity

`perplexity(model, token_lines, mode)` computes PPL over BPE tokens, resetting context at `<eos>`.  
Supported modes: `ml`, `laplace`, `interp`, `backoff`, `katz`.  
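**Note:** Neither backoff variant is a proper probability distribution. Stupid backoff is unnormalized by design, and the simplified Katz version uses a fixed discount and a fixed backoff weight, so treat their PPL as a *relative* score.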
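
For reference, perplexity is the exponential of the average negative log-probability per token. The sketch below computes it from a plain list of per-token probabilities; `perplexity_from_probs` is a hypothetical helper written for this illustration, separate from the notebook's `perplexity` function.

```python
import math
from typing import Sequence

def perplexity_from_probs(probs: Sequence[float]) -> float:
    """PPL = exp(-(1/N) * sum(log p_i)); infinite if any p_i is zero."""
    if not probs:
        raise ValueError("need at least one probability")
    if any(p <= 0.0 for p in probs):
        return float("inf")
    avg_log_prob = sum(math.log(p) for p in probs) / len(probs)
    return math.exp(-avg_log_prob)

uniform = perplexity_from_probs([0.25] * 8)   # model that is uniform over 4 outcomes
unseen = perplexity_from_probs([0.5, 0.0])    # one zero-probability token
print(uniform, unseen)
```

A model that assigns probability 1/4 to every token has perplexity 4, and a single zero-probability token makes the perplexity infinite, which is why unsmoothed ML estimates are usually unusable on held-out data.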


In [5]:
# PERPLEXITY CALCULATION

def flatten_for_eval(token_lines: List[List[str]]) -> List[str]:
    """Flatten token lines for evaluation (BOS handling done in perplexity)."""
    flat: List[str] = []
    for line in token_lines:
        flat.extend(line)
    return flat

def perplexity(model: NGramLM, token_lines: List[List[str]], mode: str = "laplace",
               lambdas: Optional[List[float]] = None) -> float:
    """Compute per-token perplexity, exp(-mean log p), following the slide methodology.

    Returns inf if any token has zero probability (possible in "ml" mode).
    """
    log_prob_sum = 0.0
    count = 0

    for line in token_lines:
        # Initialize context with BOS tokens at start of each sentence
        context = [BOS] * (model.n - 1) if model.n > 1 else []

        # Process each token in the line (including EOS)
        for token in line:
            # Create context tuple for probability calculation
            ctx = tuple(context[-(model.n-1):]) if model.n > 1 else tuple()

            # Calculate probability based on smoothing method
            if mode == "ml":
                p = model.p_ml(ctx, token)
            elif mode == "laplace":
                p = model.p_laplace(ctx, token)
            elif mode == "interp":
                p = model.p_interpolated(ctx, token, lambdas=lambdas, use_laplace=True)
            elif mode == "backoff":
                p = model.p_backoff(ctx, token)
            elif mode == "katz":
                p = model.p_backoff_katz(ctx, token)
            else:
                raise ValueError("mode must be one of: ml, laplace, interp, backoff, katz")

            # Add log probability
            if p > 0:
                log_prob_sum += math.log(p)
            else:
                log_prob_sum += float('-inf')
            count += 1

            # Update context window
            context = (context + [token])[-(model.n - 1):] if model.n > 1 else context

            # Reset context at sentence boundary
            if token == EOS:
                context = [BOS] * (model.n - 1) if model.n > 1 else []

    # Calculate final perplexity
    avg_log_prob = log_prob_sum / count if count > 0 else float('-inf')
    return math.exp(-avg_log_prob)


## Data Loading Functions

- `find_merges_file(k)`: locates a merges file in `Generated_tokens/` (flexible naming).  
- `load_token_lines_for_k(k)`: loads splits, applies merges, and returns tokenized lines for train/valid/test.  
Per the task, the **validation set is used for tuning interpolation weights, not for choosing `k`**.


In [6]:
# DATA LOADING FUNCTIONS

def find_merges_file(k: int, verbose: bool = True) -> str:
    """Find BPE merges file with flexible naming conventions."""
    candidates = [
        os.path.join(GENERATED_DIR, f"bpe_merges with k = {k}.txt"),
        os.path.join(GENERATED_DIR, f"standard_bpe_merges_k{k}.txt"),
        os.path.join(GENERATED_DIR, f"aggressive_clean_bpe_merges_k{k}.txt"),
        os.path.join(GENERATED_DIR, f"bpe_merges_k{k}.txt"),
        os.path.join(GENERATED_DIR, f"bpe_merges_k{k}_webtext_clean.txt"),
    ]
    for path in candidates:
        if os.path.exists(path):
            if verbose:
                print(f"[Found] Using merges file: {path}")
            return path
    raise FileNotFoundError(f"No merges file found for k={k}. Tried: {candidates}")

def load_token_lines_for_k(k: int):
    """Load and tokenize train/validation/test data for given k."""
    merges_path = find_merges_file(k, verbose=True)
    merges = load_merges(merges_path)

    tr_text = read_text(TRAIN_CLEAN)
    va_text = read_text(VALID_CLEAN)
    te_text = read_text(TEST_CLEAN)

    tr_tok = tokenize_lines_with_merges(tr_text, merges)
    va_tok = tokenize_lines_with_merges(va_text, merges)
    te_tok = tokenize_lines_with_merges(te_text, merges)
    return merges, tr_tok, va_tok, te_tok

## Training & Evaluation Utilities

- `grid_simplex_lambdas(n, step)`: enumerates interpolation weights that sum to 1.  
- `train_and_eval_for_k(k, n_max, tune_interp)`: trains n=1…n_max; reports PPL for ML/Laplace/Backoff;  
  tunes **interpolation weights on the validation set** and then evaluates on test.  
- `bigram_vs_k(k_list)`: evaluates bigram performance across different BPE merge counts.
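
One way to picture the weight grid is as all vectors whose entries are multiples of the step and sum to 1. The sketch below enumerates them with integer arithmetic, which avoids accumulating floating-point error. `simplex_grid` and its `parts` parameter are assumptions for illustration (`parts=5` corresponds to a step of 0.2).

```python
from itertools import product
from typing import List

def simplex_grid(n: int, parts: int = 5) -> List[List[float]]:
    """All length-n weight vectors with entries j/parts that sum to 1."""
    grid = []
    # Choose the first n-1 weights freely; the last one takes the remainder
    for combo in product(range(parts + 1), repeat=n - 1):
        rest = parts - sum(combo)
        if rest >= 0:
            grid.append([c / parts for c in combo] + [rest / parts])
    return grid

weights_2 = simplex_grid(n=2, parts=5)
weights_3 = simplex_grid(n=3, parts=5)
print(len(weights_2), len(weights_3))
```

The grid grows quickly with the order `n`, so a coarse step keeps the validation search cheap at the cost of less precise weights.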


In [8]:
# TRAINING AND EVALUATION FUNCTIONS

def grid_simplex_lambdas(n: int, step: float = 0.2) -> List[List[float]]:
    """Generate lambda weight combinations that sum to 1.0."""
    if n == 1:
        return [[1.0]]

    grids = []
    def rec(prefix, remaining, slots):
        if slots == 1:
            grids.append(prefix + [round(remaining, 10)])
            return
        t = 0.0
        while t <= remaining + 1e-9:
            rec(prefix + [round(t,10)], round(remaining - t,10), slots-1)
            t = round(t + step, 10)

    rec([], 1.0, n)
    return [g for g in grids if abs(sum(g) - 1.0) < 1e-6]

def train_and_eval_for_k(k: int, n_max: int = 4, tune_interp: bool = True) -> pd.DataFrame:
    """Train and evaluate n-gram models for given BPE vocabulary size k."""
    print(f"\n=== Processing k={k} ===")
    try:
        _, tr_tok, va_tok, te_tok = load_token_lines_for_k(k)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return pd.DataFrame()

    results = []
    for n in range(1, n_max+1):
        print(f"Training {n}-gram model...")
        lm = NGramLM(n, tr_tok)

        # Basic evaluations (ML, Laplace)
        pp_valid_ml = perplexity(lm, va_tok, mode="ml")
        pp_valid_laplace = perplexity(lm, va_tok, mode="laplace")
        pp_test_ml = perplexity(lm, te_tok, mode="ml")
        pp_test_laplace = perplexity(lm, te_tok, mode="laplace")

        results.extend([
            {"k":k, "n":n, "mode":"ml", "lambdas":"N/A", "split":"valid", "perplexity":pp_valid_ml},
            {"k":k, "n":n, "mode":"laplace", "lambdas":"N/A", "split":"valid", "perplexity":pp_valid_laplace},
            {"k":k, "n":n, "mode":"ml", "lambdas":"N/A", "split":"test", "perplexity":pp_test_ml},
            {"k":k, "n":n, "mode":"laplace", "lambdas":"N/A", "split":"test", "perplexity":pp_test_laplace},
        ])

        # Interpolation with lambda tuning on validation set
        if tune_interp and n > 1:
            print(f"Tuning interpolation for {n}-gram...")
            best_pp, best_lmb = float("inf"), None
            for lambdas in grid_simplex_lambdas(n=n, step=0.2):
                pp = perplexity(lm, va_tok, mode="interp", lambdas=lambdas)
                if pp < best_pp:
                    best_pp, best_lmb = pp, lambdas

            if best_lmb is not None:
                pp_test_interp = perplexity(lm, te_tok, mode="interp", lambdas=best_lmb)
                results.extend([
                    {"k":k, "n":n, "mode":"interp", "lambdas":str(best_lmb), "split":"valid", "perplexity":best_pp},
                    {"k":k, "n":n, "mode":"interp", "lambdas":str(best_lmb), "split":"test", "perplexity":pp_test_interp},
                ])

        # Backoff evaluation (Stupid Backoff implementation)
        pp_valid_backoff = perplexity(lm, va_tok, mode="backoff")
        pp_test_backoff = perplexity(lm, te_tok, mode="backoff")
        results.extend([
            {"k":k, "n":n, "mode":"backoff", "lambdas":"N/A", "split":"valid", "perplexity":pp_valid_backoff},
            {"k":k, "n":n, "mode":"backoff", "lambdas":"N/A", "split":"test", "perplexity":pp_test_backoff},
        ])

    return pd.DataFrame(results)

def bigram_vs_k(k_list: List[int]) -> pd.DataFrame:
    """Analyze bigram performance across different k values."""
    rows = []
    for k in k_list:
        # Reuse the flexible lookup so every supported merges file name is accepted
        try:
            find_merges_file(k, verbose=False)
        except FileNotFoundError:
            print(f"[Skip] No merges file for k={k} in {GENERATED_DIR}")
            continue
        # n_max=2 trains the unigram and bigram models; only bigram rows are kept
        dfk = train_and_eval_for_k(k=k, n_max=2, tune_interp=True)
        if not dfk.empty:
            rows.append(dfk[dfk["n"] == 2])
    return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()


## Text Generation (Extrinsic Evaluation)

- **Encoding/decoding:** `bpe_encode_words`, `bpe_decode_to_words` convert between words and BPE tokens.  
- **Fallbacks:** `_unigram_fallback('most'|'avg')` handles unseen contexts.  
- **Decoding:** `_next_token_argmax_or_sample` supports argmax or temperature sampling.  
- **Driver:** `generate_sentence(...)` continues from a prompt until `<eos>` or a word budget is reached.
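
Temperature sampling can be illustrated in isolation: each probability is raised to the power `1/T` and the result is renormalized, so `T < 1` sharpens the distribution and `T > 1` flattens it. The distribution `toy` and the helper `apply_temperature` below are hypothetical and exist only for this sketch.

```python
import random
from typing import Dict

def apply_temperature(dist: Dict[str, float], temperature: float) -> Dict[str, float]:
    """Rescale probabilities as p ** (1 / T) and renormalize."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = {tok: p ** (1.0 / temperature) for tok, p in dist.items()}
    z = sum(scaled.values())
    return {tok: p / z for tok, p in scaled.items()}

# Hypothetical next-token distribution
toy = {"the</w>": 0.6, "king</w>": 0.3, "<eos>": 0.1}
sharp = apply_temperature(toy, 0.5)   # more mass on the top token
flat = apply_temperature(toy, 2.0)    # closer to uniform

# Draw one token from the sharpened distribution with a seeded generator
rng = random.Random(0)
tokens, weights = zip(*sharp.items())
sampled = rng.choices(tokens, weights=weights, k=1)[0]
print(sharp, flat, sampled)
```

Argmax decoding ignores the temperature entirely, since rescaling does not change which token has the highest probability.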


In [None]:

# TEXT GENERATION (Extrinsic Evaluation)

def bpe_encode_words(words: Iterable[str], merges: List[Tuple[str, str]]) -> List[str]:
    """Lowercase and convert words to BPE subwords (including </w>)."""
    toks: List[str] = []
    for w in (w.lower() for w in words):
        toks.extend(apply_merges_to_word(w, merges))
    return toks

def bpe_decode_to_words(token_stream: List[str]) -> List[str]:
    """Convert BPE subword list back to words (split at </w>)."""
    words: List[str] = []
    buf: List[str] = []
    for t in token_stream:
        if t == EOS:
            break
        buf.append(t)
        if t.endswith(WORD_END):  # word boundary
            chars: List[str] = []
            for sub in buf:
                if sub.endswith(WORD_END):
                    chars.extend(list(sub[:-len(WORD_END)]))
                else:
                    chars.extend(list(sub))
            words.append("".join(chars))
            buf = []
    # Handle any leftover fragments
    if buf:
        chars = []
        for sub in buf:
            if sub.endswith(WORD_END):
                chars.extend(list(sub[:-len(WORD_END)]))
            else:
                chars.extend(list(sub))
        if chars:
            words.append("".join(chars))
    return words

def _get_unigram_model(model: NGramLM) -> NGramLM:
    """Return the base unigram model (n=1) from an n-gram chain."""
    m = model
    while m.lower is not None:
        m = m.lower
    return m

def _unigram_fallback(model: NGramLM, strategy: str = "most") -> str:
    """
    Fallback token choice when distribution is empty:
      - 'most': most frequent unigram (excluding BOS/EOS)
      - 'avg' : token whose Laplace probability is closest to average unigram probability
    """
    um = _get_unigram_model(model)
    total = sum(um.ng_counts.values())
    V = um.V if hasattr(um, "V") else len(um.vocab)

    if strategy == "most":
        best_tok, best_count = None, -1
        for (tok_tuple, cnt) in um.ng_counts.items():
            t = tok_tuple[0]
            if t in (BOS, EOS):
                continue
            if cnt > best_count:
                best_count = cnt
                best_tok = t
        return best_tok if best_tok is not None else EOS

    else:  # 'avg' strategy
        avg_p = 1.0 / V  # Laplace average probability
        best_tok, best_gap = None, float("inf")
        for (tok_tuple, cnt) in um.ng_counts.items():
            t = tok_tuple[0]
            if t in (BOS, EOS):
                continue
            p = (cnt + 1) / (total + V)
            gap = abs(p - avg_p)
            if gap < best_gap:
                best_gap, best_tok = gap, t
        return best_tok if best_tok is not None else EOS

def _next_token_argmax_or_sample(dist: Dict[str, float],
                                 temperature: float = 1.0,
                                 sample: bool = False) -> Optional[str]:
    """Choose next token from a probability distribution (argmax or temperature sampling).

    Returns None if the distribution is empty; temperature <= 0 falls back to argmax.
    """
    if not dist:
        return None
    if not sample:
        return max(dist.items(), key=lambda kv: kv[1])[0]
    tokens, probs = zip(*dist.items())
    probs = list(probs)
    if temperature <= 0:
        return tokens[int(max(range(len(probs)), key=lambda i: probs[i]))]
    if temperature != 1.0:
        probs = [p ** (1.0/temperature) for p in probs]
        Z = sum(probs) or 1.0
        probs = [p / Z for p in probs]
    return random.choices(tokens, weights=probs, k=1)[0]

def generate_sentence(
    k: int,
    n: int,
    prompt_words: List[str],
    mode: str = "interp",
    lambdas: Optional[List[float]] = None,
    max_new_words: int = 30,
    temperature: float = 1.0,
    sample: bool = False,
    fallback_strategy: str = "most",   # 'most' or 'avg'
) -> str:
    """
    Extrinsic evaluation: generate continuation from a prompt using an n-gram model.
    - mode: 'ml' | 'laplace' | 'interp' | 'backoff' | 'katz'
    - sample=False → argmax; sample=True → temperature sampling
    - fallback: unigram choice ('most' or 'avg') if no distribution is found
    - Stops when EOS appears or max_new_words is reached
    """
    merges, tr_tok, _, _ = load_token_lines_for_k(k)
    lm = NGramLM(n, tr_tok)

    # 1) Encode prompt to BPE tokens
    prompt_toks = bpe_encode_words(prompt_words, merges)

    # 2) Initialize context (BOS*(n-1) + prompt)
    context: List[str] = ([BOS] * (n - 1)) + prompt_toks if n > 1 else prompt_toks[:]

    out_tokens: List[str] = []
    words_generated = 0

    def _dist(ctx_tokens: List[str]) -> Dict[str, float]:
        ctx = tuple(ctx_tokens[-(lm.n - 1):]) if lm.n > 1 else tuple()
        probs: Dict[str, float] = {}
        for tok in lm.vocab:
            if tok == BOS:
                continue
            if mode == "ml":
                p = lm.p_ml(ctx, tok)
            elif mode == "laplace":
                p = lm.p_laplace(ctx, tok)
            elif mode == "interp":
                p = lm.p_interpolated(ctx, tok, lambdas=lambdas, use_laplace=True)
            elif mode == "backoff":
                p = lm.p_backoff(ctx, tok)
            elif mode == "katz":
                p = lm.p_backoff_katz(ctx, tok)
            else:
                raise ValueError("Invalid mode.")
            if p > 0:
                probs[tok] = p
        Z = sum(probs.values())
        if Z > 0:
            for t in probs:
                probs[t] /= Z
        return probs

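    # A run of non-word-final subwords never increments words_generated, so also
    # cap the raw token count (heuristic limit) to guarantee termination.
    max_new_tokens = 10 * max_new_words
    while words_generated < max_new_words and len(out_tokens) < max_new_tokens:
        dist = _dist(context)

        if not dist:
            next_tok = _unigram_fallback(lm, strategy=fallback_strategy)
        else:
            next_tok = _next_token_argmax_or_sample(dist, temperature=temperature, sample=sample)

        out_tokens.append(next_tok)
        context.append(next_tok)

        if next_tok == EOS:
            break
        if next_tok.endswith(WORD_END):
            words_generated += 1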

    return " ".join(bpe_decode_to_words(out_tokens))



## Model Preparation

`prepare_models(k, n_max)` trains and caches n-gram models (1…n_max) for a fixed `k`,  
and prints Laplace perplexities on train/valid/test—used by the generation suite to avoid retraining.


In [16]:
def prepare_models(k: int, n_max: int = 4) -> Tuple[List[Tuple[str, str]], Dict[int, NGramLM]]:
    """
    Load merges and train n-gram models up to order n_max.
    Also print Laplace perplexities on train, validation, and test.
    Returns:
      merges: BPE merges list
      models: dict {n: NGramLM}
    """
    # Load merges + tokenized splits
    merges, tr_tok, va_tok, te_tok = load_token_lines_for_k(k)
    models = {}

    for n in range(1, n_max + 1):
        print(f"\n[prepare_models] Training {n}-gram model for k={k}")
        model = NGramLM(n, tr_tok)

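        # Evaluate Laplace perplexities (None when a split is empty)
        ppl_train = perplexity(model, tr_tok, mode="laplace")
        ppl_valid = perplexity(model, va_tok, mode="laplace") if va_tok else None
        ppl_test  = perplexity(model, te_tok, mode="laplace") if te_tok else None

        # Format defensively: applying ":.2f" to None would raise a TypeError
        shown = [f"{p:.2f}" if p is not None else "n/a" for p in (ppl_train, ppl_valid, ppl_test)]
        print(f"[n={n}] train ppl={shown[0]} | valid ppl={shown[1]} | test ppl={shown[2]}")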

        models[n] = model

    return merges, models


## Helpers for Reporting

Lightweight analysis for generated text:
- Diversity: `distinct-1/2`.  
- Repetition: adjacent duplicates and longest repeat run.  
- Optional self-scoring hook (`score_sequence_logprob`) to compute avg NLL / PPL if available.  
Also maps external mode names (e.g., `"simple"` → stupid backoff) and optionally uses a fast generator.
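
To show what the diversity and repetition indicators measure, the sketch below computes distinct-1, distinct-2, and the longest repeat run on a short word sequence. The helpers `distinct_n` and `longest_run` and the example sentence are assumptions for this illustration.

```python
from typing import Sequence

def distinct_n(tokens: Sequence[str], n: int) -> float:
    """Ratio of unique n-grams to total n-grams (0.0 if there are none)."""
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0

def longest_run(tokens: Sequence[str]) -> int:
    """Length of the longest run of identical consecutive tokens."""
    best = cur = 1 if tokens else 0
    for prev, tok in zip(tokens, tokens[1:]):
        cur = cur + 1 if tok == prev else 1
        best = max(best, cur)
    return best

words = "to be or not to be be be".split()
d1 = distinct_n(words, 1)   # 4 unique words out of 8
d2 = distinct_n(words, 2)   # 5 unique bigrams out of 7
run = longest_run(words)    # "be be be"
print(d1, d2, run)
```

Low distinct-n values and long runs are typical symptoms of argmax decoding getting stuck in a loop, which is why they are reported next to each generated sample.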


In [17]:
# =============================== Helpers ======================================
def _clip(s: str, max_chars: int = 160) -> str:
    """We trim and single-line the sample text so the console stays readable."""
    s = (s or "").strip().replace("\n", " ")
    return s if len(s) <= max_chars else s[:max_chars - 3] + "..."

def _ensure_lambdas(n: int, lambdas: Optional[List[float]], mode: str):
    """
    We provide interpolation weights:
      - if explicit weights are given, we use them,
      - else if default_lambdas_for(n) exists, we use that,
      - else we fall back to uniform weights.
    """
    if mode != "interp":
        return None
    if lambdas is not None:
        return lambdas
    f = globals().get("default_lambdas_for", None)
    if callable(f):
        return f(n)
    return [1.0 / n] * n

def _distinct_ratio(seq, n=1):
    """We compute distinct-n / total-n ratio as a simple diversity metric."""
    if n == 1:
        total = len(seq)
        return (len(set(seq)) / total) if total else 0.0
    ngrams = list(zip(*[seq[i:] for i in range(n)]))
    total = len(ngrams)
    return (len(set(ngrams)) / total) if total else 0.0

def _max_repeat_run(seq):
    """We measure the longest run of identical consecutive tokens."""
    if not seq:
        return 0
    mx = cur = 1
    for i in range(1, len(seq)):
        if seq[i] == seq[i - 1]:
            cur += 1
            mx = max(mx, cur)
        else:
            cur = 1
    return mx

def _summarize_text(txt: str) -> Dict[str, Any]:
    """We summarize the generated text with lightweight quality indicators."""
    tokens = txt.strip().split()
    n_tok = len(tokens)
    d1 = _distinct_ratio(tokens, 1)
    d2 = _distinct_ratio(tokens, 2)
    rep_pairs = sum(1 for i in range(1, n_tok) if tokens[i] == tokens[i - 1])
    rep_ratio = rep_pairs / max(1, (n_tok - 1))
    return {
        "len_words": n_tok,
        "distinct1": round(d1, 4),
        "distinct2": round(d2, 4),
        "repeat_ratio": round(rep_ratio, 4),
        "max_repeat_run": _max_repeat_run(tokens),
        # bpe_decode_to_words drops <eos>, so this is False for decoded generator output
        "ends_with_eos": (tokens[-1] == EOS) if tokens else False,
    }

def _maybe_score_ppl(txt: str, merges, models, n: int, mode: str, lambdas: Optional[List[float]]):
    """
    We optionally self-score the generated text if a scorer is available:
    expects score_sequence_logprob(tokens, merges, models, n, mode, lambdas) → log p.
    Returns (avg_nll, ppl) or (None, None) if not available.
    """
    scorer = globals().get("score_sequence_logprob", None)
    if not callable(scorer):
        return None, None
    tokens = txt.strip().split()
    if not tokens:
        return None, None
    try:
        logp = scorer(tokens, merges=merges, models=models, n=n, mode=mode, lambdas=lambdas)
        avg_nll = -logp / len(tokens)
        ppl = math.exp(avg_nll)
        return round(avg_nll, 4), round(ppl, 4)
    except Exception:
        return None, None

def _map_mode(mode: str) -> str:
    """
    We accept external mode names and map them to the implementation:
      - "simple" → "backoff" (stupid backoff)
      - "interp", "ml", "laplace" pass through
    We do not use Katz here.
    """
    if mode == "simple":
        return "backoff"
    return mode

def _call_generate(merges, models, k: int, n: int,
                   prompt_words: List[str], mode: str,
                   lambdas: Optional[List[float]], sample: bool,
                   temperature: float, max_new_words: int = 20,
                   fallback_strategy: str = "most") -> str:
    """
    We prefer generate_sentence_fast(...) if present (uses prebuilt models),
    and fall back to generate_sentence(...) otherwise.
    """
    impl_mode = _map_mode(mode)
    gen_fast = globals().get("generate_sentence_fast", None)
    if callable(gen_fast):
        return gen_fast(
            merges, models, k, n,
            prompt_words=prompt_words, mode=impl_mode,
            lambdas=lambdas, sample=sample, temperature=temperature,
            fallback_strategy=fallback_strategy, max_new_words=max_new_words
        )
    return generate_sentence(
        k=k, n=n, prompt_words=prompt_words, mode=impl_mode,
        lambdas=lambdas, max_new_words=max_new_words,
        sample=sample, temperature=temperature,
        fallback_strategy=fallback_strategy
    )

# =============================== Main suite ===================================
def run_generation_suite(k: int = 1600,
                         temperatures = (0.5, 0.7, 1.0),
                         max_new_words: int = 20,
                         show_text: bool = True,
                         text_max_chars: int = 160) -> List[Dict[str, Any]]:
    """
    We run a compact generation & reporting suite:
      1) Two fixed examples for quick inspection,
      2) A small grid over (prompt, mode, n) × {argmax, sample} × temperatures,
      3) Speed + lightweight quality metrics (len, distinct-1/2, repetition).
    If show_text=True, we also print each generated sentence (truncated).
    Returns a list of dicts (ready for DataFrame/CSV).
    """
    # We assume prepare_models(k) exists and returns (merges, {order: NGramLM})
    merges, models = prepare_models(k)
    results = []

    # (A) Two fixed examples we show up front
    special_tests = [
        dict(prompt=["to","be","or","not"], mode="interp",  n=2, lambdas=[0.2, 0.8], sample=False, temperature=1.0, label="spec_interp_argmax"),
        dict(prompt=["my","lord"],          mode="simple",  n=3, lambdas=None,       sample=True,  temperature=0.7, label="spec_simple_sample"),
    ]

    # (B) Grid similar to the earlier quick test
    grid_tests = [
        dict(prompt=["to","be","or"],     mode="simple",  n=2, lambdas=None),
        dict(prompt=["the","king","of"],  mode="laplace", n=2, lambdas=None),
        dict(prompt=["fair","is","foul"], mode="interp",  n=3, lambdas=None),
    ]

    def run_one(test_cfg: Dict[str, Any], temperature: float, sample: bool, label: str):
        prompt = test_cfg["prompt"]
        mode   = test_cfg["mode"]        # we keep the external label ("simple" etc.)
        n      = test_cfg["n"]
        lamb   = _ensure_lambdas(n, test_cfg.get("lambdas"), mode)

        # perf_counter() is a monotonic, high-resolution clock, better suited to short timings
        t0 = time.perf_counter()
        txt = _call_generate(
            merges, models, k, n,
            prompt_words=prompt, mode=mode, lambdas=lamb,
            sample=sample, temperature=temperature,
            max_new_words=max_new_words,
            fallback_strategy=test_cfg.get("fallback_strategy", "most"),
        )
        dt = time.perf_counter() - t0

        summary = _summarize_text(txt)
        avg_nll, ppl = _maybe_score_ppl(txt, merges, models, n, mode, lamb)

        rec = {
            "label": label,
            "prompt": " ".join(prompt),
            "mode": mode,  # external name ("simple" not Katz)
            "n": n,
            "sample": sample,
            "temperature": temperature,
            "lambdas": lamb,
            "text": txt,
            "gen_time_sec": round(dt, 4),
            "tok_per_sec_est": round(summary["len_words"]/dt, 2) if dt > 0 else None,
            "len_words": summary["len_words"],
            "distinct1": summary["distinct1"],
            "distinct2": summary["distinct2"],
            "repeat_ratio": summary["repeat_ratio"],
            "max_repeat_run": summary["max_repeat_run"],
            "ends_with_eos": summary["ends_with_eos"],
            "avg_nll": avg_nll,
            "ppl_self": ppl
        }
        results.append(rec)

        # We print one example line per run (truncated) so the instructor sees outputs.
        if show_text:
            dec = "sample" if sample else "argmax"
            print(f"\n[{label}] mode={mode} | n={n} | T={temperature} | {dec}")
            print("  " + _clip(txt, text_max_chars))

    # (A) run the two fixed examples
    for cfg in special_tests:
        run_one(cfg, cfg["temperature"], cfg["sample"], cfg["label"])

    # (B) grid: argmax and sample across temperatures
    for cfg in grid_tests:
        for T in temperatures:
            run_one(cfg, T, False, f"grid_argmax_T{T}")
        for T in temperatures:
            run_one(cfg, T, True,  f"grid_sample_T{T}")

    # Short summary
    print("\n=== (label | mode | n | temp | sample | len | d1 | d2 | rep) ===")
    for r in results:
        print(f"{r['label']:18s} | {r['mode']:7s} | {r['n']} | {r['temperature']:>3} | "
              f"{'S' if r['sample'] else 'A'} | {r['len_words']:3d} | "
              f"{r['distinct1']:.2f} | {r['distinct2']:.2f} | {r['repeat_ratio']:.2f}")

    return results

if __name__ == "__main__":
    # The function defaults to k=1600; here we run with k=1000. Tweak temperatures/max_new_words as needed.
    _ = run_generation_suite(k=1000)


[Found] Using merges file: Generated_tokens\bpe_merges with k = 1000.txt

[prepare_models] Training 1-gram model for k=1000
[n=1] train ppl=282.43 | valid ppl=278.58 | test ppl=277.44

[prepare_models] Training 2-gram model for k=1000
[n=2] train ppl=72.68 | valid ppl=80.55 | test ppl=80.21

[prepare_models] Training 3-gram model for k=1000
[n=3] train ppl=164.65 | valid ppl=217.90 | test ppl=221.55

[prepare_models] Training 4-gram model for k=1000
[n=4] train ppl=306.63 | valid ppl=464.27 | test ppl=479.42
[Found] Using merges file: Generated_tokens\bpe_merges with k = 1000.txt

[spec_interp_argmax] mode=interp | n=2 | T=1.0 | argmax
  to the  and the  and the  and the  and the  and the  and the
[Found] Using merges file: Generated_tokens\bpe_merges with k = 1000.txt

[spec_simple_sample] mode=simple | n=3 | T=0.7 | sample
  othello i do estate upon me. desdemona i am cruel cassio. othello o ialack the couples that lives in the
[Found] Using merges file: Generated_tokens\bpe_merges w

## Results

We evaluated continuations generated with **k = 1600 BPE merges** under three smoothing modes: stupid backoff (mode `simple`), Laplace, and interpolation (n = 3), each decoded with argmax and with sampling.  
Each run generated up to 20 new words (`max_new_words = 20`).

### Diversity
- **Sampling** yielded high diversity:  
  - distinct-1 ≈ 0.90–1.00  
  - distinct-2 ≈ 0.95–1.00  
  - Example: simple backoff, T = 0.7 → distinct-1 = 0.95, distinct-2 = 1.00
- Laplace (n = 2) and interpolation (n = 3) showed the same pattern under sampling.
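
To make the diversity numbers concrete, here is a minimal sketch of a distinct-n computation. It assumes distinct-n is the number of unique word n-grams divided by the total number of n-grams in the generated text; the notebook's own `_summarize_text` is not shown here and may differ in detail. The helper name `distinct_n` and the two example sentences are illustrative only.

```python
from typing import List


def distinct_n(words: List[str], n: int) -> float:
    """Fraction of unique n-grams among all n-grams in `words` (0.0 if there are none)."""
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0


# Hypothetical outputs: one looping line and one varied line.
looping = "to the and the and the and the".split()
varied = "my lord i do beseech you hear me".split()

for name, words in [("looping", looping), ("varied", varied)]:
    print(f"{name}: distinct-1={distinct_n(words, 1):.2f}, distinct-2={distinct_n(words, 2):.2f}")
```

Under this definition a line that cycles through a few words scores well below 1.0 on both measures, while a line with no repeated words scores exactly 1.0.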

### Degeneracy
- **Argmax** reduced diversity and caused loops:  
  - Laplace bigrams: distinct-1 ≈ 0.20, distinct-2 ≈ 0.21  
  - Simple bigrams: distinct-1 ≈ 0.30, distinct-2 ≈ 0.32  
  - Interpolation (n = 3) improved slightly (distinct-1 ≈ 0.35, distinct-2 ≈ 0.37) but still repeated phrases.

### Repetition Metric
- Adjacent-token repetition was 0.00 in all cases.  
- However, phrase-level loops (e.g., “the matters of the matters of …”) were common, so the metric underestimates degeneracy: it only compares each token with its immediate predecessor.
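
The gap between the two observations can be sketched in a few lines. The first function below is an assumed reading of the adjacent-repetition metric (the notebook's actual implementation is not shown); the second, `repeated_ngram_fraction`, is a hypothetical complementary measure that counts how many n-gram occurrences belong to an n-gram seen more than once.

```python
from collections import Counter
from typing import List


def adjacent_repeat_ratio(words: List[str]) -> float:
    """Share of positions where a word equals the word directly before it."""
    if len(words) < 2:
        return 0.0
    repeats = sum(1 for prev, cur in zip(words, words[1:]) if prev == cur)
    return repeats / (len(words) - 1)


def repeated_ngram_fraction(words: List[str], n: int = 3) -> float:
    """Share of n-gram occurrences whose n-gram appears more than once in `words`."""
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    counts = Counter(ngrams)
    return sum(c for c in counts.values() if c > 1) / len(ngrams)


loop = "the matters of the matters of the matters of".split()
print("adjacent repeat ratio:  ", adjacent_repeat_ratio(loop))
print("repeated trigram share: ", repeated_ngram_fraction(loop, n=3))
```

On the looping phrase the adjacent measure stays at zero because no word directly follows itself, whereas every trigram in the line recurs, so a phrase-level measure of this kind would flag the loop.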

### Qualitative Examples
- Argmax: repeated clause fragments (e.g., “and the moor: I am not …”).  
- Sampling (T = 0.7): varied, syntactically richer lines (e.g., “polonius give him directions …”).

**Conclusion:**  
Deterministic decoding collapses onto frequent n-grams, reducing diversity.  
Stochastic sampling restores lexical richness without immediate repetition.  
Interpolation alleviates loops under argmax but does not eliminate them.
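
The contrast between deterministic and stochastic decoding can be illustrated on a toy next-word distribution. This is a sketch under stated assumptions, not the notebook's `_call_generate`: temperature is assumed to rescale each probability to p^(1/T) before renormalising, and the names `apply_temperature`, `pick`, and the distribution `toy` are hypothetical.

```python
import math
import random
from typing import Dict


def apply_temperature(probs: Dict[str, float], temperature: float) -> Dict[str, float]:
    """Rescale each p to p**(1/T) and renormalise; T < 1 sharpens, T > 1 flattens."""
    scaled = {w: math.exp(math.log(p) / temperature) for w, p in probs.items() if p > 0}
    total = sum(scaled.values())
    return {w: s / total for w, s in scaled.items()}


def pick(probs: Dict[str, float], sample: bool, temperature: float = 1.0,
         rng: random.Random = random.Random(42)) -> str:
    """Return the most probable word (argmax) or draw one from the tempered distribution."""
    if not sample:
        return max(probs, key=probs.get)
    dist = apply_temperature(probs, temperature)
    words = list(dist)
    return rng.choices(words, weights=[dist[w] for w in words], k=1)[0]


# Hypothetical next-word distribution after some context.
toy = {"the": 0.5, "king": 0.3, "moor": 0.2}

print("argmax :", [pick(toy, sample=False) for _ in range(5)])
for T in (0.5, 0.7, 1.0):
    print(f"T={T}  :", [pick(toy, sample=True, temperature=T) for _ in range(5)])
```

Argmax returns the same word every time the same context recurs, which is how loops form. Under this rescaling the argmax choice is also independent of T, so temperature only matters when sampling.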
