# Programming Assignment 3: N-gram Language Models
## Building Language Models from Scratch on Tiny Shakespeare

**Objectives:**
- Understand text preprocessing and tokenization for language modeling
- Implement a **Unigram** model (frequency-based word generation)
- Implement a **Bigram** model (next-word prediction with one word of context)
- Generalize to an **N-gram** model with Laplace smoothing
- Evaluate language models using **perplexity**
- Generate Shakespeare-like text from your trained models

---
## Background: What is a Language Model?

### The Big Picture

A **language model** assigns probabilities to sequences of words. Given a sentence, a language model tells you how likely that sentence is. Given the start of a sentence, it can predict what comes next.

Language modeling is foundational to much of modern Natural Language Processing (NLP), from phone autocomplete to ChatGPT. Before neural networks took over, **N-gram models** were the dominant approach, and they remain an excellent way to understand the core concepts.

### From Counting to Predicting

The simplest idea: **estimate word probabilities by counting**. If you read all of Shakespeare and count how often each word appears, you get a basic language model. The word "the" appears very frequently, while "zephyr" is rare. This is a **unigram model**, which has the probability distribution:

$$P(w_i) = \frac{\text{count}(w_i)}{\sum_{w}\text{count}(w)}$$

In plain English: the probability that the next word is $w_i$ is the number of occurrences of $w_i$ divided by the total number of word tokens in the corpus.
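To make the formula concrete, here is a minimal sketch on a toy six-token corpus (the corpus and variable names are illustrative assumptions, not part of the assignment's API). Note that the denominator is the number of tokens (6), not the number of distinct words (4).

```python
from collections import Counter

# Toy corpus (illustrative only): six word tokens, four distinct words
tokens = "to be or not to be".split()

counts = Counter(tokens)
total = sum(counts.values())  # total number of tokens = 6

# P(w) = count(w) / total
unigram_probs = {w: c / total for w, c in counts.items()}

for w, p in unigram_probs.items():
    print(f"P({w}) = {counts[w]}/{total} = {p:.3f}")
```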

But language has *structure*. Just counting the occurrences of each word throws away the context embedded in language. For example, after "to", the word "be" is much more likely than "the". A **bigram model** captures this by looking at pairs of consecutive words:

$$P(\text{be} | \text{to}) = \frac{\text{count}(\text{"to be"})}{\text{count}(\text{"to"})}$$

In this bigram example, the probability that the next word is "be" given that the previous word is "to" is the number of times "be" directly follows "to" in the corpus, divided by the number of occurrences of "to".
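Here is a minimal sketch of that ratio on a small toy corpus (the corpus is an illustrative assumption). Pairing each token with its successor via `zip(tokens, tokens[1:])` enumerates every bigram once.

```python
from collections import Counter

# Toy corpus (illustrative only)
tokens = "to be or not to be to the end".split()

word_counts = Counter(tokens)
# Count adjacent pairs (w_{i-1}, w_i)
pair_counts = Counter(zip(tokens, tokens[1:]))

# P(be | to) = count("to be") / count("to")
p_be_given_to = pair_counts[("to", "be")] / word_counts["to"]
print(f'P(be | to) = {pair_counts[("to", "be")]}/{word_counts["to"]} = {p_be_given_to:.3f}')
```

Here "to" occurs three times and is followed by "be" twice, so the estimate is 2/3.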

We can generalize to **N-grams** — conditioning on the previous $(N-1)$ words:

$$P(w_i | w_{i-n+1}, \ldots, w_{i-1}) = \frac{\text{count}(w_{i-n+1}, \ldots, w_i)}{\text{count}(w_{i-n+1}, \ldots, w_{i-1})}$$

That is, the probability that the next word is $w_i$, given the previous $N-1$ words, is the count of the full $N$-word sequence divided by the count of the $(N-1)$-word context (the same sequence without $w_i$).
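The same ratio works for any order. The sketch below is a hypothetical helper (not one of the assignment's functions) that counts contiguous windows of the full sequence and of its context; for an empty context it falls back to the unigram case and divides by the total token count.

```python
from collections import Counter
from typing import List, Tuple


def count_sequences(tokens: List[str], k: int) -> Counter:
    """Count every contiguous k-token window in `tokens`."""
    return Counter(tuple(tokens[i:i + k]) for i in range(len(tokens) - k + 1))


def mle_ngram_prob(tokens: List[str], context: Tuple[str, ...], word: str) -> float:
    """Unsmoothed estimate count(context + word) / count(context)."""
    n = len(context) + 1
    full = count_sequences(tokens, n)[context + (word,)]
    if not context:
        ctx = len(tokens)  # unigram case: normalize by the total token count
    else:
        ctx = count_sequences(tokens, n - 1)[context]
    return full / ctx if ctx > 0 else 0.0


tokens = "to be or not to be to the end".split()
print(mle_ngram_prob(tokens, ("to",), "be"))        # bigram estimate, 2/3
print(mle_ngram_prob(tokens, ("not", "to"), "be"))  # trigram estimate, 1/1
```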

### The Sparsity Problem

Here's the catch: as $N$ grows, most possible N-grams **never appear** in the training data. Shakespeare never wrote "thou doth yeet", but should our model assign it probability *zero*? That seems too harsh. One solution is **smoothing**, which redistributes a small amount of probability mass from observed N-grams to unseen ones.

### Evaluating Language Models: Perplexity

How do we measure if one language model is better than another? We use **perplexity**:

$$\text{Perplexity} = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | \text{context})\right)$$

**Note:** The context depends on which n-gram model is used. For a unigram model there is no context, so the term reduces to $\log P(w_i)$. For a bigram model the context is just $w_{i-1}$, so the term becomes $\log P(w_i | w_{i-1})$. For n-gram models with $n > 2$, the context is the $n-1$ words preceding $w_i$. Here $N$ is the total number of predicted tokens in the evaluation data, not the n-gram order.

**Intuition:** Perplexity measures how "surprised" the model is by the test data. Lower perplexity = better model. A perplexity of $k$ roughly means the model is as uncertain as if it were choosing uniformly among $k$ words at each step.
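A minimal sketch of the formula, taking a plain list of per-token probabilities as input (a simplifying assumption; the assignment's perplexity functions compute these from a model). A model that always spreads its probability uniformly over 4 choices gets a perplexity of 4, matching the intuition above.

```python
import math
from typing import List


def perplexity(token_probs: List[float]) -> float:
    """exp of the average negative log-probability over the N predicted tokens."""
    n = len(token_probs)
    avg_neg_log = -sum(math.log(p) for p in token_probs) / n
    return math.exp(avg_neg_log)


# Uniform over 4 choices at every step: perplexity is 4 (up to rounding)
uniform_4 = [0.25] * 8
print(perplexity(uniform_4))

# More confident on some tokens: lower perplexity, (2 * 2 * 4 * 4) ** 0.25 = 2 * sqrt(2)
mixed = [0.5, 0.5, 0.25, 0.25]
print(perplexity(mixed))
```

Perplexity is the geometric mean of $1/P(w_i | \text{context})$, which is why a uniform choice among $k$ words yields exactly $k$.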

Let's get started!

## Table of Contents

1. [Part 1: Text Preprocessing and Tokenization](#part1)
2. [Part 2: Unigram Language Model](#part2)
3. [Part 3: Bigram Language Model](#part3)
4. [Part 4 (Extra Credit): General N-gram Model](#part4)

**Note:** Some parts are already implemented; these are marked **(Done for you)** and are worth **0 points**.

## Setup

Run the cells below to install dependencies and import libraries. You only need to run the `pip install` cell once.

In [None]:
# Install dependencies (run once and then comment out)
!pip install numpy matplotlib

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import re
import random
import math
from collections import Counter, defaultdict
from typing import List, Dict, Tuple, Optional
random.seed(42)
np.random.seed(42)

---
<a id="part1"></a>
## Part 1: Text Preprocessing and Tokenization (10 points)

### The Big Picture

Before building any model, we need to transform raw text into a clean sequence of tokens. Shakespeare's text contains stage directions (`[Enter ROMEO]`), character name prefixes (`ROMEO:`), and mixed punctuation — none of which help our language model learn word patterns.

Our preprocessing pipeline will:
1. Normalize the text (lowercase)
2. Remove non-speech content (stage directions, character names)
3. Split into sentences
4. Tokenize each sentence with special boundary markers

### 1.0: Load the Dataset

The Tiny Shakespeare dataset contains ~1.1 million characters from various Shakespeare plays. We download it from Andrej Karpathy's GitHub repository. (Done for you)

**Note:** This cell saves the dataset to a local file called `tiny_shakespeare.txt`, so after the first run it loads from disk instead of downloading again.

In [None]:
import urllib.error  # explicit import for the URLError handler below
import urllib.request
import os

data_path = "tiny_shakespeare.txt"
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

# 1. Check if the local file is available
if os.path.exists(data_path):
    print(f"Local copy found at '{data_path}'. Loading from disk...")
    with open(data_path, 'r', encoding='utf-8') as f:
        raw_text = f.read()

# 2. If not, load via URL
else:
    print(f"Local copy not found. Loading via URL...")
    try:
        # Fetch the data straight from the web
        with urllib.request.urlopen(url) as response:
            raw_text = response.read().decode('utf-8')
        
        # Save a local copy for the next time the script runs
        with open(data_path, 'w', encoding='utf-8') as f:
            f.write(raw_text)
        print(f"Successfully loaded and saved a fallback copy to '{data_path}'.")
        
    except urllib.error.URLError as e:
        raise Exception(
            f"Failed to load dataset from URL. Please check your connection or the URL. Error: {e}"
        )

# 3. Print out the stats to verify
print(f"\nDataset size: {len(raw_text):,} characters")
print(f"\nFirst 500 characters:\n{'='*50}")
print(raw_text[:500])

### Task 1.1: Implement `preprocess_text` (Done for you)

This function transforms raw Shakespeare text into a list of clean sentences. Each sentence should contain only lowercase alphabetic words.

**The preprocessing steps:**

```
Raw text:
    "ROMEO: O, she doth teach the torches to burn bright!
     [Aside] It seems she hangs upon the cheek of night."

After preprocessing:
    ["she doth teach the torches to burn bright",
     "it seems she hangs upon the cheek of night"]
```

**Step-by-step:**

| Step | Operation | Tool |
|------|-----------|------|
| 1 | Convert to lowercase | `text.lower()` |
| 2 | Remove stage directions `[...]` | `re.sub(r'\[.*?\]', ' ', text)` |
| 3 | Remove character prefixes (e.g., `romeo:`) | `re.sub()` with `re.MULTILINE` |
| 4 | Split into sentences on `.`, `!`, `?` | `re.split(r'[.!?]+', text)` |
| 5 | Extract only alphabetic words | `re.findall(r'\b[a-z]+\b', sent)` |
| 6 | Filter out sentences with < 2 words | `if len(words) >= 2` |

**Why remove character names?** Lines like `ROMEO:` are not part of the spoken dialogue, and including them would teach our model that "romeo" frequently starts sentences, which isn't useful.

**Why filter short sentences?** Single-word "sentences" (or empty ones) don't provide useful bigram/trigram data.

In [None]:
def preprocess_text(text: str) -> List[str]:
    """Preprocess raw Shakespeare text into a list of cleaned sentences."""
    text = text.lower()
    text = re.sub(r'\[.*?\]', ' ', text)
    text = re.sub(r'^[a-z ]+:', '', text, flags=re.MULTILINE)
    raw_sentences = re.split(r'[.!?]+', text)
    sentences = []
    for sent in raw_sentences:
        words = re.findall(r'\b[a-z]+\b', sent)
        if len(words) >= 2:
            sentences.append(' '.join(words))
    return sentences

sentences = preprocess_text(raw_text)

print(f"Preprocessed {len(sentences):,} sentences") # 12,198 sentences
for i, s in enumerate(sentences[:5]):
    print(f"  {i}: {s}")

### Task 1.2: Implement `tokenize_sentences` (10 points)

Language models need to know where sentences **begin** and **end**. We add special boundary tokens:
- `<s>` — Start-of-sentence token
- `</s>` — End-of-sentence token

**Example:**
```
Input sentence:  "to be or not to be"
Tokenized:       ['<s>', 'to', 'be', 'or', 'not', 'to', 'be', '</s>']
```

**Why boundary tokens?** They let the model learn:
- Which words typically **start** sentences (e.g., after `<s>`, "the" and "i" are common)
- Which words typically **end** sentences (e.g., before `</s>`, "so" and "me" are common)
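As a small illustration of why the markers help, the sketch below wraps three toy sentences (an illustrative assumption, not the real data) in `<s>` / `</s>` and reads off which words sit right after the start marker and right before the end marker.

```python
from collections import Counter

# Toy sentences (illustrative only)
sentences = ["to be or not to be", "the king is dead", "the queen is fair"]

# Wrap each sentence in boundary tokens
tokenized = [["<s>"] + s.split() + ["</s>"] for s in sentences]

# The word right after <s> starts a sentence; the word right before </s> ends one
starters = Counter(sent[1] for sent in tokenized)
enders = Counter(sent[-2] for sent in tokenized)

print(starters)  # Counter({'the': 2, 'to': 1})
print(enders)    # Counter({'be': 1, 'dead': 1, 'fair': 1})
```

In a bigram model these counts appear directly as the rows for `<s>` and the columns for `</s>`.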

You also need to build a **vocabulary** — mappings between words and integer indices. This is standard practice in NLP and will be useful for analysis.

In [None]:
def tokenize_sentences(sentences: List[str]) -> Tuple[List[List[str]], dict, dict]:
    """
    Tokenize sentences and build vocabulary mappings.
    
    For each sentence string like "the torches burn bright":
        1. Split into words
        2. Add <s> at the beginning and </s> at the end
    
    Also build:
        word_to_idx: dict mapping each unique word to an integer index
        idx_to_word: dict mapping each index back to the word
    
    Args:
        sentences: List of preprocessed sentence strings
    
    Returns:
        Tuple of (tokenized_sentences, word_to_idx, idx_to_word)
    """
    # ============================================================
    # TODO: Implement tokenization
    #
    # 1. For each sentence, split on spaces and add <s> / </s>
    # 2. Collect all unique tokens into a vocabulary set
    # 3. Build word_to_idx: sorted vocab -> enumerate
    # 4. Build idx_to_word: reverse of word_to_idx
    # ============================================================
    pass


# Test
tokenized, word_to_idx, idx_to_word = tokenize_sentences(sentences)

assert tokenized[0][0] == '<s>', "First token should be <s>"
assert tokenized[0][-1] == '</s>', "Last token should be </s>"
assert '<s>' in word_to_idx, "Vocabulary should contain <s>"
assert '</s>' in word_to_idx, "Vocabulary should contain </s>"

print(f"Vocabulary size: {len(word_to_idx):,}")
print(f"Total tokens: {sum(len(s) for s in tokenized):,}")
print(f"\nSample tokenized sentence: {tokenized[0]}")

### Task 1.3: Train/Test Split and Exploratory Analysis

We split the data into training (80%), validation (10%), and test (10%) sets. This is done for you.

In [None]:
# Split data: 80% train, 10% validation, 10% test
random.shuffle(tokenized)
n = len(tokenized)
train_data = tokenized[:int(0.8 * n)]
val_data = tokenized[int(0.8 * n):int(0.9 * n)]
test_data = tokenized[int(0.9 * n):]

print(f"Train: {len(train_data):,} sentences")
print(f"Val:   {len(val_data):,} sentences")
print(f"Test:  {len(test_data):,} sentences")

---
<a id="part2"></a>
## Part 2: Unigram Language Model (40 points)

### How the Unigram Model Works

A **unigram model** is the simplest possible language model. It assumes every word is generated independently, according to its frequency in the training data:

$$P(w_i) = \frac{\text{count}(w_i)}{\sum_{w} \text{count}(w)}$$

**Strengths:** Simple, fast, captures which words are common in Shakespeare.

**Weaknesses:** Completely ignores word order. "To be or not to be" and "be not or to to be" have the same probability. Generated text is just random common words.
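A quick check of that claim, as a minimal sketch: probabilities are estimated from the toy sentence itself and `</s>` is ignored for brevity (both simplifying assumptions). Because the unigram log-probability is a plain sum over independent words, any reordering of the same words gets the same score.

```python
import math
from collections import Counter

# Unigram probabilities estimated from the toy sentence itself (illustrative only)
tokens = "to be or not to be".split()
counts = Counter(tokens)
probs = {w: c / len(tokens) for w, c in counts.items()}


def unigram_log_prob(words):
    # Words are treated as independent, so log P(sentence) is a plain sum
    return sum(math.log(probs[w]) for w in words)


original = "to be or not to be".split()
shuffled = "be not or to to be".split()

print(unigram_log_prob(original))
print(unigram_log_prob(shuffled))
# Same multiset of words, so the same score (up to floating-point rounding)
print(math.isclose(unigram_log_prob(original), unigram_log_prob(shuffled)))  # True
```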

### What You Need to Implement

You will implement the unigram model as a set of **standalone functions** that operate on a shared dictionary:

| Function | What It Does |
|----------|-------------|
| `unigram_train(data)` | Count words, compute probabilities, return model dict |
| `unigram_probability(model, word)` | Return P(word) for any word |
| `unigram_sentence_log_prob(model, sentence)` | Compute log P(sentence) as sum of log P(word) |
| `unigram_perplexity(model, data)` | Evaluate the model on a dataset (Done for you) |
| `unigram_generate(model, max_length)` | Sample random sentences from the model |

**Key detail:** We count all words *except* `<s>` (since we never generate the start token — it's always given). We *do* count `</s>` since the model needs to learn when to stop.

### Task 2.1: Implement `unigram_train` (10 points)

In [None]:
def unigram_train(data: List[List[str]]) -> dict:
    """
    Train a unigram model on tokenized sentences.
    
    Compute the probability of each word as:
        P(w) = count(w) / total_word_count
    
    Include </s> but exclude <s> from counts.
    
    Args:
        data: List of tokenized sentences (each is a list of strings)
    
    Returns:
        A dictionary with keys:
            'word_probs': dict mapping word -> probability
            'vocab': set of all words in the vocabulary
    """
    # ============================================================
    # TODO: Implement training
    # ============================================================
    pass


# Test
unigram_model = unigram_train(train_data)

assert isinstance(unigram_model, dict), "Should return a dict"
assert 'word_probs' in unigram_model, "Should contain 'word_probs'"
assert 'vocab' in unigram_model, "Should contain 'vocab'"
assert '</s>' in unigram_model['word_probs'], "Should contain </s>"
assert '<s>' not in unigram_model['word_probs'], "Should NOT contain <s>"

print(f"Vocabulary size: {len(unigram_model['vocab']):,}")
print(f"\nTop 10 unigram probabilities:")
top_words = sorted(unigram_model['word_probs'].items(), key=lambda x: -x[1])[:10]
for word, prob in top_words:
    print(f"  P({word:10s}) = {prob:.5f}")

### Task 2.2: Implement `unigram_probability` (10 points)

In [None]:
def unigram_probability(model: dict, word: str) -> float:
    """
    Return P(word) from the unigram model.
    Return 1e-10 for unknown words.
    
    Args:
        model: The unigram model dictionary from unigram_train()
        word: A word string
    
    Returns:
        The probability P(word)
    """
    # ============================================================
    # TODO: Implement (~1 line)
    # ============================================================
    pass


# Test
assert unigram_probability(unigram_model, 'the') > 0.01, "Common word should have high prob"
assert unigram_probability(unigram_model, 'xyzzy') == 1e-10, "Unknown word should return 1e-10"
print(f"P('the') = {unigram_probability(unigram_model, 'the'):.5f}")
print(f"P('xyzzy') = {unigram_probability(unigram_model, 'xyzzy')}")

### Task 2.3: Implement `unigram_sentence_log_prob` (10 points)

In [None]:
def unigram_sentence_log_prob(model: dict, sentence: List[str]) -> float:
    """
    Compute log P(sentence) = sum of log P(w_i) for each word.
    Exclude <s> from the computation.
    
    Args:
        model: The unigram model dictionary
        sentence: A tokenized sentence (list of strings)
    
    Returns:
        The log probability of the sentence
    """
    # ============================================================
    # TODO: Implement (~2 lines)
    # ============================================================
    pass


# Test
test_sent = train_data[0]
log_prob = unigram_sentence_log_prob(unigram_model, test_sent)
assert log_prob < 0, "Log probability should be negative"
print(f"Log P('{' '.join(test_sent)}') = {log_prob:.4f}")

### Task 2.4: `unigram_perplexity` (Done for you)

In [None]:
def unigram_perplexity(model: dict, data: List[List[str]]) -> float:
    total_log_prob = 0
    total_tokens = 0
    for sentence in data:
        total_log_prob += unigram_sentence_log_prob(model, sentence)
        total_tokens += len(sentence) - 1
    return math.exp(-total_log_prob / total_tokens)

train_ppl = unigram_perplexity(unigram_model, train_data)
val_ppl = unigram_perplexity(unigram_model, val_data)
assert train_ppl > 0, "Perplexity should be positive"
print(f"Train perplexity: {train_ppl:.2f}")
print(f"Val perplexity:   {val_ppl:.2f}")

### Task 2.5: Implement `unigram_generate` (10 points)

In [None]:
def unigram_generate(model: dict, max_length: int = 20) -> str:
    """
    Generate a sentence by sampling words independently from the unigram distribution.
    Stop when </s> is generated or max_length is reached.
    
    Args:
        model: The unigram model dictionary
        max_length: Maximum number of words to generate
    
    Returns:
        Generated sentence as a string (without <s> and </s>)
    """
    # ============================================================
    # TODO: Implement
    # 1. Build lists of words and their probabilities from model['word_probs']
    # 2. Use np.random.choice(words, p=probs) to sample each word
    # 3. Stop at </s> or max_length
    # 4. Return ' '.join(generated_words)
    # ============================================================
    pass


# Test
print("Generated sentences (unigram):")
for i in range(5):
    print(f"  {i+1}. {unigram_generate(unigram_model)}")

---
<a id="part3"></a>
## Part 3: Bigram Language Model (50 points)

### How the Bigram Model Works

A **bigram model** conditions each word on the immediately preceding word:

$$P(w_i | w_{i-1}) = \frac{\text{count}(w_{i-1}, w_i)}{\text{count}(w_{i-1})}$$

This captures basic word-to-word transitions. In Shakespeare, after "thou", words like "art", "hast", and "shalt" are far more likely than "computer" or "pizza".
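Generating from a bigram model means sampling the next word in proportion to these transition counts. The sketch below uses made-up counts for words following "thou" (hypothetical numbers, for illustration only) and NumPy's `Generator.choice`; the legacy `np.random.choice` used in the assignment hints takes the same `p` argument, which must sum to 1.

```python
import numpy as np

# Hypothetical counts of words observed after "thou" (made-up numbers)
next_counts = {"art": 50, "hast": 30, "shalt": 20}

words = list(next_counts)
counts = np.array([next_counts[w] for w in words], dtype=float)
probs = counts / counts.sum()  # normalize so the probabilities sum to 1

rng = np.random.default_rng(0)
samples = rng.choice(words, size=10_000, p=probs)

# Empirical frequencies should land close to 0.5 / 0.3 / 0.2
for w in words:
    print(f"{w:6s} {np.mean(samples == w):.3f}")
```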

**How it builds the graph of transitions:**

```
Training sentence: ['<s>', 'to', 'be', 'or', 'not', 'to', 'be', '</s>']

Bigrams extracted:
    (<s>, to)     → count(<s>, to) += 1
    (to, be)      → count(to, be) += 1     (first occurrence)
    (be, or)      → count(be, or) += 1
    (or, not)     → count(or, not) += 1
    (not, to)     → count(not, to) += 1
    (to, be)      → count(to, be) += 1     (second occurrence: count(to, be) = 2)
    (be, </s>)    → count(be, </s>) += 1
```
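Normalizing each context's row of counts turns this table into conditional distributions. A minimal sketch for the single training sentence above, using the same nested `defaultdict(Counter)` layout described for `bigram_counts` (dividing by the row total is an assumption that matches counting each word's uses as a context):

```python
from collections import Counter, defaultdict

sentence = ['<s>', 'to', 'be', 'or', 'not', 'to', 'be', '</s>']

# bigram_counts[context][word] = count
bigram_counts = defaultdict(Counter)
for prev, word in zip(sentence, sentence[1:]):
    bigram_counts[prev][word] += 1

# Normalize each row to get P(word | context)
for context, nexts in bigram_counts.items():
    total = sum(nexts.values())
    dist = {w: c / total for w, c in nexts.items()}
    print(f"{context:4s} -> {dist}")
```

Only "be" has two different successors here ("or" and `</s>`, each with probability 0.5); every other context has a single deterministic transition, which shows how little one sentence constrains the model.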

### The Zero-Probability Problem

Without smoothing, if a bigram like ("thou", "computer") never appeared in training, we get:

$$P(\text{computer} | \text{thou}) = \frac{0}{\text{count}(\text{thou})} = 0$$

A single zero makes the probability of the *entire sentence* zero (and its log-probability $-\infty$)! This is catastrophic when evaluating on new data, where unseen bigrams are almost certain to appear. **Laplace (add-α) smoothing** fixes this:

$$P_{\text{smooth}}(w_i | w_{i-1}) = \frac{\text{count}(w_{i-1}, w_i) + \alpha}{\text{count}(w_{i-1}) + \alpha \cdot |V|}$$

This redistributes a small amount of probability mass to every possible bigram, ensuring nothing gets probability zero.
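A worked sketch of the smoothed formula with hypothetical numbers (a context seen 100 times, $|V| = 10{,}000$, $\alpha = 1$), plus a check that the smoothed probabilities over a small made-up vocabulary still sum to 1, since the numerator adds $\alpha$ once per vocabulary word.

```python
from typing import Dict


def laplace_prob(pair_count: int, context_count: int, alpha: float, vocab_size: int) -> float:
    """Add-alpha estimate (count(c, w) + alpha) / (count(c) + alpha * |V|)."""
    return (pair_count + alpha) / (context_count + alpha * vocab_size)


# Hypothetical numbers: context seen 100 times, |V| = 10,000, alpha = 1
print(laplace_prob(0, 100, 1.0, 10_000))   # unseen bigram: 1 / 10100
print(laplace_prob(20, 100, 1.0, 10_000))  # seen 20 times: 21 / 10100

# The smoothed distribution over the whole vocabulary still sums to 1
vocab = ["art", "hast", "shalt"] + [f"w{i}" for i in range(7)]  # |V| = 10
row: Dict[str, int] = {"art": 5, "hast": 3}                       # count(c) = 8
alpha = 0.5
total = sum(laplace_prob(row.get(w, 0), 8, alpha, len(vocab)) for w in vocab)
print(total)  # 1.0 up to floating-point rounding
```

Note how much mass $\alpha = 1$ takes from the seen bigram when $|V|$ is large (0.2 unsmoothed vs. about 0.002 smoothed), which is why it can pay to tune $\alpha$ on validation data (Task 3.5).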

### Task 3.1: Implement `bigram_train` (10 points)

In [None]:
def bigram_train(data: List[List[str]], smoothing: float = 0.0) -> dict:
    """
    Train a bigram model on tokenized sentences.
    
    For each consecutive pair (w_{i-1}, w_i) in each sentence:
        - Increment bigram_counts[w_{i-1}][w_i]
        - Increment unigram_counts[w_{i-1}]
    Also collect the full vocabulary.
    
    Args:
        data: List of tokenized sentences
        smoothing: Laplace smoothing parameter (alpha). If 0, no smoothing.
    
    Returns:
        A dictionary with keys:
            'bigram_counts': defaultdict(Counter) — bigram_counts[context][word] = count
            'unigram_counts': Counter — count of each word used as context
            'vocab': set of all unique tokens
            'smoothing': the smoothing parameter
    """
    # ============================================================
    # TODO: Implement training
    # ============================================================
    pass


# Test
bigram_model = bigram_train(train_data, smoothing=0.0)

assert isinstance(bigram_model, dict), "Should return a dict"
assert 'bigram_counts' in bigram_model
assert 'thou' in bigram_model['bigram_counts'], "'thou' should appear as a context"

print(f"Vocabulary size: {len(bigram_model['vocab']):,}")
print(f"Unique contexts: {len(bigram_model['unigram_counts']):,}")

### Task 3.2: Implement `bigram_probability` (20 points)

In [None]:
def bigram_probability(model: dict, word: str, context: str) -> float:
    """
    Compute P(word | context) with optional Laplace smoothing.
    
    Without smoothing (alpha=0):
        P(w | c) = count(c, w) / count(c)
        Return 1e-10 if count is 0.
    
    With Laplace smoothing:
        P(w | c) = (count(c, w) + alpha) / (count(c) + alpha * |V|)
    
    Args:
        model: The bigram model dictionary
        word: The word to compute probability for
        context: The preceding word
    
    Returns:
        P(word | context)
    """
    # ============================================================
    # TODO: Implement probability computation
    # Handle both smoothed and unsmoothed cases
    # ============================================================
    pass


# Test
print("Words most likely to follow 'thou':")
thou_probs = {w: bigram_probability(bigram_model, w, 'thou') 
              for w in bigram_model['bigram_counts']['thou']}
for word, prob in sorted(thou_probs.items(), key=lambda x: -x[1])[:10]:
    print(f"  P({word:12s} | thou) = {prob:.4f}")

### Task 3.3: Implement `bigram_sentence_log_prob` and `bigram_perplexity` (10 points)

In [None]:
def bigram_sentence_log_prob(model: dict, sentence: List[str]) -> float:
    """
    Compute log P(sentence) = sum of log P(w_i | w_{i-1}).
    
    Args:
        model: The bigram model dictionary
        sentence: A tokenized sentence
    
    Returns:
        Log probability of the sentence
    """
    # ============================================================
    # TODO: Implement (~3 lines)
    # ============================================================
    pass


def bigram_perplexity(model: dict, data: List[List[str]]) -> float:
    total_log_prob = 0
    total_tokens = 0
    for sentence in data:
        total_log_prob += bigram_sentence_log_prob(model, sentence)
        total_tokens += len(sentence) - 1
    return math.exp(-total_log_prob / total_tokens)

# Test
train_ppl = bigram_perplexity(bigram_model, train_data)
print(f"Train perplexity (no smoothing): {train_ppl:.2f}")

### Task 3.4: Implement `bigram_generate` (10 points)

In [None]:
def bigram_generate(model: dict, max_length: int = 30) -> str:
    """
    Generate a sentence word by word, starting from <s>.
    
    At each step, sample the next word from P(w | previous_word).
    Stop at </s> or max_length.
    
    Args:
        model: The bigram model dictionary
        max_length: Maximum number of words to generate
    
    Returns:
        Generated sentence as a string (without <s> and </s>)
    """
    # ============================================================
    # TODO: Implement
    # 1. Start with context = '<s>'
    # 2. At each step, get candidates from bigram_counts[context]
    # 3. Sample proportional to counts using np.random.choice
    # 4. Update context to the sampled word
    # 5. Stop at '</s>' or max_length
    # ============================================================
    pass


# Test
print("Generated sentences (bigram):")
for i in range(5):
    print(f"  {i+1}. {bigram_generate(bigram_model)}")

### Task 3.5: Smoothing Experiment (Done for you)

The following code tunes the Laplace smoothing parameter $\alpha$ on the validation set. We then use the $\alpha$ with the lowest validation perplexity for the heatmap visualization in the next task.

In [None]:
# Smoothing experiment
alphas = [0.001, 0.01, 0.1, 0.5, 1.0, 2.0]
val_perplexities = []

for alpha in alphas:
    m = bigram_train(train_data, smoothing=alpha)
    ppl = bigram_perplexity(m, val_data)
    val_perplexities.append(ppl)
    print(f"  alpha={alpha:.3f}  val_ppl={ppl:.2f}")

plt.figure(figsize=(8, 5))
plt.semilogx(alphas, val_perplexities, 'bo-', linewidth=2, markersize=8)
plt.xlabel('Smoothing Parameter (α)', fontsize=12)
plt.ylabel('Validation Perplexity', fontsize=12)
plt.title('Bigram Model: Smoothing vs. Perplexity', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

best_idx = np.argmin(val_perplexities)
print(f"\n Best alpha: {alphas[best_idx]}, Val PPL: {val_perplexities[best_idx]:.2f}")

### Task 3.6: Visualize Bigram Probabilities (Done for you)

This heatmap shows transition probabilities between common words. Study the patterns — you should see that `<s>` transitions mostly to common sentence-starting words, while function words like "the" and "of" have broad distributions.

What are some of the most common bigrams? Do they make sense to you?

In [None]:
# Visualize bigram probabilities as a heatmap
best_bigram = bigram_train(train_data, smoothing=alphas[np.argmin(val_perplexities)])

focus_words = ['<s>', 'the', 'thou', 'my', 'and', 'to', 'of', 'is', 'in', 'a',
               'not', 'i', 'love', 'shall', 'what', 'his', 'her', '</s>']

prob_matrix = np.zeros((len(focus_words), len(focus_words)))
for i, ctx in enumerate(focus_words):
    for j, word in enumerate(focus_words):
        prob_matrix[i, j] = bigram_probability(best_bigram, word, ctx)

plt.figure(figsize=(12, 10))
plt.imshow(prob_matrix, cmap='YlOrRd', aspect='auto')
plt.colorbar(label='P(column | row)')
plt.xticks(range(len(focus_words)), focus_words, rotation=45, ha='right')
plt.yticks(range(len(focus_words)), focus_words)
plt.xlabel('Next Word', fontsize=12)
plt.ylabel('Context Word', fontsize=12)
plt.title('Bigram Transition Probabilities', fontsize=14)

for i in range(len(focus_words)):
    for j in range(len(focus_words)):
        val = prob_matrix[i, j]
        if val > 0.01:
            plt.text(j, i, f'{val:.2f}', ha='center', va='center', fontsize=7)

plt.tight_layout()
plt.show()

---
<a id="part4"></a>
## Part 4 (Extra Credit): General N-gram Language Model (20 points)

### Generalizing Beyond Bigrams

Now we extend to **N-grams**, where each word is conditioned on the previous $(N-1)$ words:

$$P(w_i | w_{i-n+1}, \ldots, w_{i-1}) = \frac{\text{count}(w_{i-n+1}, \ldots, w_i)}{\text{count}(w_{i-n+1}, \ldots, w_{i-1})}$$

**The context is now a tuple** rather than a single word. For a trigram model ($N=3$):

```
Sentence: ['<s>', '<s>', 'to', 'be', 'or', 'not', '</s>']
                  ^^^^ padded with extra <s>

Trigrams:
    context=('<s>', '<s>')  → word='to'
    context=('<s>', 'to')   → word='be'
    context=('to', 'be')    → word='or'
    context=('be', 'or')    → word='not'
    context=('or', 'not')   → word='</s>'
```

**Important:** We pad the beginning of each sentence with $(N-1)$ `<s>` tokens so that even the first word has a full context window.

### The Context-Length Trade-off

| N | Context | Captures | Problem |
|---|---------|----------|---------|
| 1 (unigram) | None | Word frequencies | No word order |
| 2 (bigram) | 1 word | Local pairs | Limited context |
| 3 (trigram) | 2 words | Short phrases | Data sparsity grows |
| 4+ | 3+ words | Longer patterns | Most contexts never seen |

As $N$ increases, the model captures more context but faces exponentially worse data sparsity: there are $|V|^N$ possible N-grams, and most of them never appear in training.
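To see the $|V|^N$ blow-up on a tiny scale, this sketch (toy corpus assumed for illustration) compares the number of distinct N-grams actually observed with the number of possible ones.

```python
# Toy corpus (illustrative only)
tokens = "to be or not to be that is the question".split()
V = len(set(tokens))  # 8 distinct words

coverage = {}
for n in (1, 2, 3):
    # Distinct contiguous n-grams that actually occur in the corpus
    observed = {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}
    possible = V ** n  # every sequence of n vocabulary words
    coverage[n] = len(observed) / possible
    print(f"n={n}: {len(observed)} observed / {possible} possible ({coverage[n]:.1%})")
```

On this 10-token corpus, coverage drops from 100% of unigrams to 12.5% of bigrams and about 1.6% of trigrams; the same effect, at a much larger scale, drives the sparsity problem on Tiny Shakespeare.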

### Task 4.1: Implement `get_ngrams` (4 points)

In [None]:
def get_ngrams(sentence: List[str], n: int) -> List[Tuple[tuple, str]]:
    """
    Extract (context, word) pairs from a sentence for an N-gram model.
    
    For a trigram model (n=3) and sentence ['<s>', 'to', 'be', 'or', '</s>']:
        - Pad beginning: ['<s>', '<s>', 'to', 'be', 'or', '</s>']
        - context = tuple of previous (n-1) = 2 words
        - Pairs: (('<s>', '<s>'), 'to'), (('<s>', 'to'), 'be'), ...
    
    For unigram (n=1): context is always empty tuple ()
    For bigram (n=2): context is 1-tuple of previous word
    
    Args:
        sentence: A tokenized sentence (list of strings)
        n: The order of the N-gram model
    
    Returns:
        List of (context_tuple, next_word) pairs
    """
    # ============================================================
    # TODO: Implement N-gram extraction
    # 1. Pad beginning: pad = ['<s>'] * (n - 1)
    #    padded = pad + sentence[1:]  (avoid double-counting existing <s>)
    # 2. For each position i in range(n - 1, len(padded)):
    #    - context = tuple of padded[i-n+1 : i] if n > 1 else ()
    #    - word = padded[i]
    #    - Append (context, word) to results
    # ============================================================
    pass


# Test
test_sent = ['<s>', 'to', 'be', 'or', '</s>']
bigrams = get_ngrams(test_sent, n=2)
trigrams = get_ngrams(test_sent, n=3)
unigrams = get_ngrams(test_sent, n=1)

print("Bigrams:", bigrams)
print("Trigrams:", trigrams)
print("Unigrams:", unigrams)

assert bigrams[0] == (('<s>',), 'to'), f"First bigram wrong: {bigrams[0]}"
assert trigrams[0] == (('<s>', '<s>'), 'to'), f"First trigram wrong: {trigrams[0]}"
assert unigrams[0] == ((), 'to'), f"First unigram wrong: {unigrams[0]}"
print("\n✓ All ngram extraction tests passed!")

### Task 4.2: Implement `ngram_train` (4 points)

In [None]:
def ngram_train(data: List[List[str]], n: int, smoothing: float = 0.1) -> dict:
    """
    Train a general N-gram model.
    
    Args:
        data: List of tokenized sentences
        n: The N-gram order (1=unigram, 2=bigram, 3=trigram, etc.)
        smoothing: Laplace smoothing parameter
    
    Returns:
        A dictionary with keys:
            'n': the N-gram order
            'smoothing': the smoothing parameter
            'ngram_counts': defaultdict(Counter) — ngram_counts[context_tuple][word]
            'context_counts': Counter — total count for each context tuple
            'vocab': set of all unique tokens
    """
    # ============================================================
    # TODO: Implement training
    # For each sentence, extract ngrams using get_ngrams() and count them:
    #   for context, word in get_ngrams(sentence, n):
    #       ngram_counts[context][word] += 1
    #       context_counts[context] += 1
    #       vocab.add(word) + add all words in context
    # ============================================================
    pass


# Test
trigram_model = ngram_train(train_data, n=3, smoothing=0.1)
assert isinstance(trigram_model, dict)
assert trigram_model['n'] == 3
print(f"Trigram model: {len(trigram_model['context_counts']):,} unique contexts, "
      f"{len(trigram_model['vocab']):,} vocab size")

### Task 4.3: Implement `ngram_probability` (4 points)

In [None]:
def ngram_probability(model: dict, word: str, context: tuple) -> float:
    """
    Compute P(word | context) with Laplace smoothing.
    
    P(w | c) = (count(c, w) + alpha) / (count(c) + alpha * |V|)
    
    Args:
        model: The N-gram model dictionary
        word: The word to compute probability for
        context: A tuple of preceding words
    
    Returns:
        P(word | context)
    """
    # ============================================================
    # TODO: Implement
    # ============================================================
    pass


# Test
p = ngram_probability(trigram_model, 'be', ('<s>', 'to'))
print(f"P(be | '<s>', 'to') = {p:.4f}")
assert p > 0, "Probability should be positive"
# Unknown context should still give nonzero probability (due to smoothing)
p_unk = ngram_probability(trigram_model, 'be', ('xyzzy', 'abcde'))
assert p_unk > 0, "Smoothed probability of unseen context should be > 0"
print(f"P(be | 'xyzzy', 'abcde') = {p_unk:.6f}  (smoothed)")

### Task 4.4: Implement `ngram_perplexity` (4 points)

In [None]:
def ngram_perplexity(model: dict, data: List[List[str]]) -> float:
    """
    Compute perplexity of the N-gram model on a dataset.
    
    Args:
        model: The N-gram model dictionary
        data: List of tokenized sentences
    
    Returns:
        Perplexity (positive float; lower is better)
    """
    # ============================================================
    # TODO: Implement
    # For each sentence, use get_ngrams(sentence, model['n']) to
    # extract (context, word) pairs, then compute log probabilities
    # ============================================================
    pass


# Test
train_ppl = ngram_perplexity(trigram_model, train_data)
val_ppl = ngram_perplexity(trigram_model, val_data)
print(f"Trigram — Train perplexity: {train_ppl:.2f}")
print(f"Trigram — Val perplexity:   {val_ppl:.2f}")

### Task 4.5: Implement `ngram_generate` (4 points)

In [None]:
def ngram_generate(model: dict, max_length: int = 30) -> str:
    """
    Generate a sentence from the N-gram model.
    
    Start with context of (n-1) <s> tokens.
    At each step, sample next word given current context.
    Slide context window forward after each word.
    
    Args:
        model: The N-gram model dictionary
        max_length: Maximum number of words to generate
    
    Returns:
        Generated sentence as a string (without <s> and </s>)
    """
    # ============================================================
    # TODO: Implement
    # 1. n = model['n']
    # 2. context = tuple(['<s>'] * (n-1)) if n > 1 else ()
    # 3. Loop: sample word from ngram_counts[context]
    # 4. Update context: context = tuple(list(context[1:]) + [word])
    # 5. Stop at '</s>' or max_length
    # ============================================================
    pass


# Test
print("Generated sentences (trigram):")
for i in range(5):
    print(f"  {i+1}. {ngram_generate(trigram_model)}")

### 4.6: N-gram Comparison (Done for you)

Run the following code, which compares N-gram models of different orders.

**Conceptual Questions (not graded)**

- Does the generated text seem like something Shakespeare would write?
- How does the text quality change as we increase N (unigram vs. bigram vs. trigram, and so on)?


In [None]:
# Compare N-gram orders
n_values = [1, 2, 3, 4, 5]
train_ppls = []
val_ppls = []
models = {}

for nv in n_values:
    m = ngram_train(train_data, n=nv, smoothing=0.1)
    models[nv] = m
    t_ppl = ngram_perplexity(m, train_data)
    v_ppl = ngram_perplexity(m, val_data)
    train_ppls.append(t_ppl)
    val_ppls.append(v_ppl)
    print(f"N={nv}: Train PPL={t_ppl:8.2f}  Val PPL={v_ppl:8.2f}")

plt.figure(figsize=(8, 5))
plt.plot(n_values, train_ppls, 'bo-', label='Train', linewidth=2, markersize=8)
plt.plot(n_values, val_ppls, 'rs-', label='Validation', linewidth=2, markersize=8)
plt.xlabel('N-gram Order (N)', fontsize=12)
plt.ylabel('Perplexity', fontsize=12)
plt.title('Perplexity vs. N-gram Order', fontsize=14)
plt.legend(fontsize=11)
plt.xticks(n_values)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\n" + "=" * 60)
print("GENERATED TEXT COMPARISON")
print("=" * 60)
for nv in n_values:
    print(f"\n--- {nv}-gram Model ---")
    for i in range(3):
        print(f"  {ngram_generate(models[nv])}")

---
## Submission Checklist

Before submitting, make sure:

- [ ] **Part 1:** `tokenize_sentences` works correctly
- [ ] **Part 2:** All 4 required unigram functions implemented (`unigram_perplexity` is provided)
- [ ] **Part 3:** All bigram functions implemented
- [ ] **Part 4 (Extra Credit):** All N-gram functions implemented

**Grading:**

| Part | Points |
|------|--------|
| Part 1: Preprocessing & Tokenization | 10 |
| Part 2: Unigram Model | 40 |
| Part 3: Bigram Model | 50 |
| Part 4 (Extra Credit): N-gram Model | 20 |
| **Total** | **100** (+20 extra credit) |