---
title: Transformers and Self-Attention
exports:
  - format: pdf
    template: plain_latex
    output: exports/Transformers.pdf
    logo: false
    link: true
downloads:
  - file: exports/Transformers.pdf
  - file: Transformers.ipynb
math:
    '\calA': '{\cal A}'
    '\calB': '{\cal B}'
    '\calC': '{\cal C}'
    '\calD': '{\cal D}'
    '\calE': '{\cal E}'
    '\calF': '{\cal F}'
    '\calG': '{\cal G}'
    '\calH': '{\cal H}'
    '\calI': '{\cal I}'
    '\calJ': '{\cal J}'
    '\calK': '{\cal K}'
    '\calL': '{\cal L}'
    '\calM': '{\cal M}'
    '\calN': '{\cal N}'
    '\calO': '{\cal O}'
    '\calP': '{\cal P}'
    '\calQ': '{\cal Q}'
    '\calR': '{\cal R}'
    '\calS': '{\cal S}'
    '\calT': '{\cal T}'
    '\calU': '{\cal U}'
    '\calV': '{\cal V}'
    '\calW': '{\cal W}'
    '\calX': '{\cal X}'
    '\calY': '{\cal Y}'
    '\calZ': '{\cal Z}'
    '\bfa': '\mathbf{a}'
    '\bfb': '\mathbf{b}'
    '\bfc': '\mathbf{c}'
    '\bfd': '\mathbf{d}'
    '\bfe': '\mathbf{e}'
    '\bff': '\mathbf{f}'
    '\bfg': '\mathbf{g}'
    '\bfh': '\mathbf{h}'
    '\bfi': '\mathbf{i}'
    '\bfj': '\mathbf{j}'
    '\bfk': '\mathbf{k}'
    '\bfl': '\mathbf{l}'
    '\bfm': '\mathbf{m}'
    '\bfn': '\mathbf{n}'
    '\bfo': '\mathbf{o}'
    '\bfp': '\mathbf{p}'
    '\bfq': '\mathbf{q}'
    '\bfr': '\mathbf{r}'
    '\bfs': '\mathbf{s}'
    '\bft': '\mathbf{t}'
    '\bfu': '\mathbf{u}'
    '\bfv': '\mathbf{v}'
    '\bfw': '\mathbf{w}'
    '\bfx': '\mathbf{x}'
    '\bfy': '\mathbf{y}'
    '\bfz': '\mathbf{z}'
    '\bfW': '\mathbf{W}'
    '\bfX': '\mathbf{X}'
    '\bfY': '\mathbf{Y}'
    '\bfZ': '\mathbf{Z}'
    '\bfK': '\mathbf{K}'
    '\bfQ': '\mathbf{Q}'
    '\bfV': '\mathbf{V}'
    '\bftheta': '\boldsymbol{\theta}'
    '\bbR': '\mathbb{R}'
    '\bbE': '\mathbb{E}'
    '\p': '\partial'
---

# Transformers and Self-Attention

In this notebook, we will explore:
1. Attention as soft dictionary lookup
2. Scaled dot-product attention
3. Self-attention mechanism
4. Multi-head attention
5. Positional encoding
6. Building a complete Transformer (MiniGPT)
7. Training on Shakespeare text

**Prerequisites**: MLP notebook (05-mlp), basic linear algebra

**Reference**: "Attention Is All You Need" (Vaswani et al., 2017)

In [None]:
# Install dependencies
!pip install otter-grader torch matplotlib

In [None]:
# Setup otter-grader
URL = "https://raw.githubusercontent.com/wecacuee/ECE490-F25-Neural-Networks/refs/heads/master/notebooks/12-transformers/TransformersTests.zip"
fname = "TransformersTests.zip"
import urllib.request  # the request submodule must be imported explicitly
from zipfile import ZipFile
try:
    urllib.request.urlretrieve(URL, fname)
    ZipFile(fname).extractall()
except Exception as e:
    print(f"Could not download tests ({e}). Grading may not work.")
import otter
grader = otter.Notebook(tests_dir="./tests")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math

# Set device
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

# For reproducibility
torch.manual_seed(42)

## 1. Motivation: Why Attention?

**Problems with RNNs:**
- Sequential processing: $O(n)$ dependent steps for a sequence of length $n$, which cannot be parallelized across time
- Long-range dependencies are hard to learn
- Vanishing/exploding gradients over long sequences

**Attention's solution:**
- Process all tokens in parallel: $O(1)$ sequential steps per layer
- Direct connections between any two positions
- Gradients flow directly between any two positions instead of through every intermediate step

## 2. Attention as Soft Dictionary Lookup

### Hard Lookup (Traditional Dictionary)
Given a query $\bfq$, find the matching key and return its value:
$$\text{lookup}(\bfq) = \bfv_i \text{ where } \bfk_i = \bfq$$

### Soft Lookup (Attention)
Return a **weighted average** of all values, where weights depend on query-key similarity:
$$\text{Attention}(\bfq, \bfK, \bfV) = \sum_{i=1}^{n} \alpha_i \bfv_i$$

The attention weights $\alpha_i$ are computed as follows, where the keys $\bfk_i$ are the rows of $\bfK$:
$$\alpha_i = \frac{\exp(\bfq^\top \bfk_i)}{\sum_j \exp(\bfq^\top \bfk_j)} = \text{softmax}(\bfK \bfq)_i$$
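
To see how soft lookup relates to hard lookup, the sketch below multiplies the scores by a factor `beta` (a hypothetical sharpness parameter, not part of the formula above). As `beta` grows, the softmax weights concentrate on the best-matching key, so the output approaches what a hard lookup would return. The keys, values, and query are illustrative toy data.

```python
import torch
import torch.nn.functional as F

# Toy key-value store (illustrative data): three orthogonal keys
keys = torch.eye(3)                      # (n=3, d=3)
values = torch.tensor([[1., 0.],
                       [0., 1.],
                       [1., 1.]])        # (n=3, d_v=2)
query = torch.tensor([0.9, 0.1, 0.0])    # closest to the first key

# beta is a hypothetical sharpness factor applied to the scores
for beta in [1.0, 10.0, 100.0]:
    weights = F.softmax(beta * (keys @ query), dim=-1)  # (n,)
    output = weights @ values                            # (d_v,)
    print(f"beta={beta:6.1f}  weights={weights.tolist()}  output={output.tolist()}")

# For large beta the weights are nearly one-hot on the first key
sharp_weights = F.softmax(100.0 * (keys @ query), dim=-1)
sharp_output = sharp_weights @ values
```

With `beta = 1` the three weights are of similar size; with `beta = 100` nearly all weight falls on the first key and the output is close to the first value.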

In [None]:
def simple_attention(query, keys, values):
    """
    Simple attention mechanism for a single query.
    
    Args:
        query: (d,) single query vector
        keys: (n, d) n key vectors
        values: (n, d_v) n value vectors
        
    Returns:
        output: (d_v,) weighted sum of values
        weights: (n,) attention weights
    """
    # Compute attention scores
    scores = query @ keys.T  # (n,)
    
    # Convert to probabilities
    weights = F.softmax(scores, dim=-1)  # (n,)
    
    # Weighted sum of values
    output = weights @ values  # (d_v,)
    
    return output, weights

# Example: simple key-value store
keys = torch.tensor([[1., 0., 0.],
                     [0., 1., 0.],
                     [0., 0., 1.]])
values = torch.tensor([[1., 0.],  # "apple"
                       [0., 1.],  # "banana" 
                       [1., 1.]]) # "cherry"

# Query close to first key
query = torch.tensor([0.9, 0.1, 0.0])
output, weights = simple_attention(query, keys, values)
print(f"Query: {query}")
print(f"Attention weights: {weights}")
print(f"Output: {output}")

In [None]:
# Visualize attention weights
def visualize_attention_weights(weights, query_labels, key_labels, title="Attention Weights"):
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(weights, cmap='Blues', aspect='auto')
    
    ax.set_xticks(range(len(key_labels)))
    ax.set_xticklabels(key_labels)
    ax.set_yticks(range(len(query_labels)))
    ax.set_yticklabels(query_labels)
    
    ax.set_xlabel('Keys')
    ax.set_ylabel('Queries')
    ax.set_title(title)
    
    plt.colorbar(im)
    plt.tight_layout()
    return fig

# Multiple queries
queries = torch.tensor([[0.9, 0.1, 0.0],
                        [0.1, 0.9, 0.0],
                        [0.3, 0.3, 0.4]])

all_weights = []
for q in queries:
    _, w = simple_attention(q, keys, values)
    all_weights.append(w.numpy())
all_weights = np.array(all_weights)

visualize_attention_weights(all_weights, 
                           ['query_1', 'query_2', 'query_3'],
                           ['key_1', 'key_2', 'key_3'])
plt.show()

## 3. Scaled Dot-Product Attention

The standard attention formula used in Transformers:

$$\text{Attention}(\bfQ, \bfK, \bfV) = \text{softmax}\left(\frac{\bfQ \bfK^\top}{\sqrt{d_k}}\right) \bfV$$

### Why scale by $\sqrt{d_k}$?

If $q_i, k_i \sim \calN(0, 1)$ independently:
$$\bfq^\top \bfk = \sum_{i=1}^{d_k} q_i k_i$$

$$\bbE[\bfq^\top \bfk] = 0, \quad \text{Var}(\bfq^\top \bfk) = d_k$$

A large variance pushes the softmax into saturation: the distribution becomes very peaked and its gradients become very small. Dividing the scores by $\sqrt{d_k}$ brings their variance back to 1.
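
The variance claim can be checked numerically. The sketch below samples queries and keys with i.i.d. standard normal entries (the assumption used above) and compares the empirical variance of the raw and scaled dot products. The sample count and the values of `d_k` are arbitrary choices.

```python
import math
import torch

torch.manual_seed(0)
n_samples = 20000  # arbitrary; larger gives a tighter estimate

results = {}
for d_k in [4, 64, 256]:
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)          # one dot product per sample
    scaled = dots / math.sqrt(d_k)      # scaled as in the attention formula
    results[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  Var(q.k)={results[d_k][0]:8.2f}  "
          f"Var(q.k/sqrt(d_k))={results[d_k][1]:.3f}")
```

The raw variance should grow roughly in proportion to `d_k`, while the scaled variance should stay near 1 for every `d_k`.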

<!-- BEGIN QUESTION -->

### Exercise 1: Implement Scaled Dot-Product Attention (15 points)

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: (batch, seq_len_q, d_k) query vectors
        K: (batch, seq_len_k, d_k) key vectors
        V: (batch, seq_len_k, d_v) value vectors
        mask: optional (seq_len_q, seq_len_k) additive mask, 0 for allowed positions and -inf for masked positions
        
    Returns:
        output: (batch, seq_len_q, d_v) attention output
        attention_weights: (batch, seq_len_q, seq_len_k) attention weights
    """
    d_k = K.size(-1)
    
    # TODO: Implement scaled dot-product attention
    # 1. Compute attention scores: Q @ K^T
    # 2. Scale by sqrt(d_k)
    # 3. If mask is provided, add it to scores (masked positions should become -inf)
    # 4. Apply softmax to get attention weights
    # 5. Multiply weights by V to get output
    
    # YOUR CODE HERE
    scores = ...  # Step 1-2
    
    if mask is not None:
        scores = ...  # Step 3
    
    attention_weights = ...  # Step 4
    output = ...  # Step 5
    
    return output, attention_weights

In [None]:
# Test scaled dot-product attention
def test_scaled_attention():
    B, T, d_k, d_v = 2, 5, 8, 16
    Q = torch.randn(B, T, d_k)
    K = torch.randn(B, T, d_k)
    V = torch.randn(B, T, d_v)
    
    output, weights = scaled_dot_product_attention(Q, K, V)
    
    assert output.shape == (B, T, d_v), f"Expected output shape {(B, T, d_v)}, got {output.shape}"
    assert weights.shape == (B, T, T), f"Expected weights shape {(B, T, T)}, got {weights.shape}"
    
    # Check weights sum to 1
    weight_sums = weights.sum(dim=-1)
    assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), "Weights should sum to 1"
    
    print("Scaled dot-product attention tests passed!")

test_scaled_attention()

In [None]:
grader.check("scaled_attention")

## 4. Self-Attention

In **self-attention**, queries, keys, and values all come from the same sequence $\bfX \in \bbR^{n \times d}$:

$$\bfQ = \bfX \bfW^Q, \quad \bfK = \bfX \bfW^K, \quad \bfV = \bfX \bfW^V$$

Where $\bfW^Q, \bfW^K \in \bbR^{d \times d_k}$ and $\bfW^V \in \bbR^{d \times d_v}$ are learned projections.

**Self-Attention Output**:
$$\text{SelfAttn}(\bfX) = \text{softmax}\left(\frac{\bfX \bfW^Q (\bfX \bfW^K)^\top}{\sqrt{d_k}}\right) \bfX \bfW^V$$

In [None]:
class SelfAttention(nn.Module):
    """Single-head self-attention layer."""
    
    def __init__(self, d_model, d_k=None, d_v=None):
        super().__init__()
        d_k = d_k or d_model
        d_v = d_v or d_model
        
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)
        self.d_k = d_k
        
    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model) input sequence
            mask: optional attention mask
        Returns:
            output: (batch, seq_len, d_v)
            attention_weights: (batch, seq_len, seq_len)
        """
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        return scaled_dot_product_attention(Q, K, V, mask)

# Test self-attention
sa = SelfAttention(d_model=64, d_k=32, d_v=32)
x = torch.randn(2, 10, 64)
out, weights = sa(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention weights shape: {weights.shape}")

## 5. Multi-Head Attention

A single attention head computes only one set of attention weights per query, which limits it to one pattern of relationships at a time. **Multi-head attention** runs several heads in parallel, each with its own projections, so the model can attend to information from different representation subspaces.

$$\text{MultiHead}(\bfQ, \bfK, \bfV) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \bfW^O$$

Where each head is:
$$\text{head}_i = \text{Attention}(\bfQ \bfW_i^Q, \bfK \bfW_i^K, \bfV \bfW_i^V)$$

**Dimensions**:
- Input: $d_{model}$
- Per-head: $d_k = d_v = d_{model} / h$
- Output projection: $\bfW^O \in \bbR^{h \cdot d_v \times d_{model}}$
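
Because each head works in a subspace of size $d_{model} / h$, the total number of projection parameters does not depend on the number of heads. The sketch below checks this with PyTorch's built-in `nn.MultiheadAttention`, used here only as a stand-in for counting parameters. It includes bias terms, which the layers in this notebook may or may not use, and `d_model = 64` is an illustrative size.

```python
import torch.nn as nn

d_model = 64  # illustrative size

counts = {}
for n_heads in [1, 2, 4, 8]:
    # Built-in layer, used only to count parameters (not the notebook's class)
    mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads)
    counts[n_heads] = sum(p.numel() for p in mha.parameters())
    print(f"n_heads={n_heads}  d_k={d_model // n_heads:2d}  parameters={counts[n_heads]:,}")

# Q, K, V, and output projections: four d_model x d_model matrices plus four biases
expected = 4 * d_model * d_model + 4 * d_model
```

Adding heads therefore changes how the model dimension is partitioned, not how many weights the layer has.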

### Exercise 2: Implement Multi-Head Attention (20 points)

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention as described in "Attention Is All You Need".
    
    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        dropout: Dropout probability (default: 0.1)
    """
    
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # TODO: Define the following layers:
        # self.W_qkv: Linear layer that projects to Q, K, V concatenated (d_model -> 3*d_model)
        # self.W_o: Output projection (d_model -> d_model)
        # self.dropout: Dropout layer
        
        self.W_qkv = ...  # YOUR CODE HERE
        self.W_o = ...    # YOUR CODE HERE
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model) input
            mask: optional attention mask
            
        Returns:
            output: (batch, seq_len, d_model)
        """
        B, T, C = x.size()
        
        # TODO: Implement multi-head attention
        # 1. Project x to Q, K, V using W_qkv
        # 2. Split into n_heads: reshape to (B, T, n_heads, d_k) then transpose to (B, n_heads, T, d_k)
        # 3. Apply scaled dot-product attention
        # 4. Concatenate heads: transpose back and reshape to (B, T, d_model)
        # 5. Apply output projection
        
        # YOUR CODE HERE
        
        return output

In [None]:
# Test multi-head attention
def test_multihead():
    B, T, d_model, n_heads = 2, 10, 64, 8
    mha = MultiHeadAttention(d_model, n_heads)
    x = torch.randn(B, T, d_model)
    out = mha(x)
    
    assert out.shape == (B, T, d_model), f"Expected {(B, T, d_model)}, got {out.shape}"
    print("Multi-head attention tests passed!")

test_multihead()

In [None]:
grader.check("multihead")

<!-- END QUESTION -->

## 6. Positional Encoding

Self-attention is **permutation equivariant**: shuffling input tokens shuffles output tokens the same way. But word order matters!

"The cat sat on the mat" vs "The mat sat on the cat"

We inject position information using **positional encodings** added to the input embeddings.
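
Permutation equivariance can be observed directly. The sketch below uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for a self-attention layer with no positional information (an assumption made for brevity; it is not the class built in this notebook). It shuffles the tokens of a random input and checks that the outputs are shuffled in the same way. All sizes are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, T = 16, 4, 6   # illustrative sizes

# Built-in layer used as a stand-in for self-attention without positional encoding
attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
attn.eval()

x = torch.randn(1, T, d_model)
perm = torch.randperm(T)
x_perm = x[:, perm]              # same tokens, shuffled order

with torch.no_grad():
    out, _ = attn(x, x, x)                      # original order
    out_perm, _ = attn(x_perm, x_perm, x_perm)  # shuffled order

# Shuffling the inputs shuffles the outputs in exactly the same way
equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-5)
print(equivariant)
```

Since the layer cannot tell the two orderings apart, any information about word order has to be supplied through the inputs.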

### Sinusoidal Positional Encoding

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

**Why sinusoidal?**
- Unique encoding for each position
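- May extrapolate to sequences longer than those seen during training (a hypothesis in the original paper, not a guarantee)
- For any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$, so relative positions are easy for the model to represent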
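
The last point can be checked for a single frequency. For one sine/cosine pair with angular frequency `omega`, shifting the position by `k` is a rotation by the angle `omega * k`, which is a linear map that does not depend on the position. The sketch below verifies this numerically; `omega`, `pos`, and `k` are arbitrary illustrative values.

```python
import math

omega = 1.0 / (10000 ** (2 * 3 / 64))  # frequency for pair index i=3, d_model=64 (illustrative)
pos, k = 17, 5                          # arbitrary position and offset

# Encoding pair at position pos and at position pos + k
s, c = math.sin(omega * pos), math.cos(omega * pos)
s_shift, c_shift = math.sin(omega * (pos + k)), math.cos(omega * (pos + k))

# Rotation by omega * k: depends only on the offset k, not on pos
cos_k, sin_k = math.cos(omega * k), math.sin(omega * k)
s_pred = cos_k * s + sin_k * c   # sin(a + b) = sin a cos b + cos a sin b
c_pred = cos_k * c - sin_k * s   # cos(a + b) = cos a cos b - sin a sin b

print(abs(s_pred - s_shift) < 1e-12, abs(c_pred - c_shift) < 1e-12)
```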

### Exercise 3: Implement Positional Encoding (10 points)

In [None]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding.
    
    Args:
        d_model: Model dimension
        max_len: Maximum sequence length (default: 5000)
        dropout: Dropout probability (default: 0.1)
    """
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # TODO: Create positional encoding matrix
        # pe shape: (max_len, d_model)
        # Use the sinusoidal formulas above
        # Register as buffer (not a parameter) using self.register_buffer('pe', pe)
        
        # YOUR CODE HERE
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Fill in sin for even indices, cos for odd indices
        pe[:, 0::2] = ...  # YOUR CODE HERE
        pe[:, 1::2] = ...  # YOUR CODE HERE
        
        # Add batch dimension and register
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model) input embeddings
        Returns:
            (batch, seq_len, d_model) embeddings + positional encoding
        """
        # TODO: Add positional encoding to x
        # Remember to only use positions up to seq_len
        
        x = ...  # YOUR CODE HERE
        return self.dropout(x)

In [None]:
# Visualize positional encodings
pe = PositionalEncoding(d_model=64, dropout=0.0)
x = torch.zeros(1, 100, 64)
pe_values = pe(x)[0].numpy()

plt.figure(figsize=(12, 4))
plt.imshow(pe_values.T, aspect='auto', cmap='RdBu')
plt.colorbar()
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Sinusoidal Positional Encoding')
plt.show()

In [None]:
grader.check("positional")

## 7. Transformer Block

A complete Transformer block consists of:
1. Multi-head self-attention with residual connection and layer norm
2. Feed-forward network with residual connection and layer norm

**Pre-LN variant** (used in GPT-2):
$$\bfx' = \bfx + \text{MultiHeadAttn}(\text{LayerNorm}(\bfx))$$
$$\bfx'' = \bfx' + \text{FFN}(\text{LayerNorm}(\bfx'))$$

**Feed-Forward Network**:
$$\text{FFN}(\bfx) = \text{GELU}(\bfx \bfW_1 + \bfb_1) \bfW_2 + \bfb_2$$

Typically $d_{ff} = 4 \cdot d_{model}$.
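
To get a feel for the size of the feed-forward network, the sketch below builds the two linear layers from the formula above and counts their parameters. With $d_{ff} = 4 \cdot d_{model}$ the two weight matrices hold $8 \cdot d_{model}^2$ entries, about twice the $4 \cdot d_{model}^2$ used by the four attention projections. The value `d_model = 128` is illustrative.

```python
import torch.nn as nn

d_model = 128           # illustrative size
d_ff = 4 * d_model      # the typical choice described above

ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),   # W_1, b_1
    nn.GELU(),
    nn.Linear(d_ff, d_model),   # W_2, b_2
)

n_params = sum(p.numel() for p in ffn.parameters())
# Two weight matrices of d_model * d_ff entries each, plus the two biases
expected = 2 * d_model * d_ff + d_ff + d_model
print(n_params, expected)
```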

In [None]:
class TransformerBlock(nn.Module):
    """
    A single Transformer block with pre-layer normalization.
    
    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
        d_ff: Feed-forward hidden dimension (default: 4*d_model)
        dropout: Dropout probability
    """
    
    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Pre-LN: LayerNorm before attention/FFN
        x = x + self.dropout(self.attn(self.ln1(x), mask))
        x = x + self.ffn(self.ln2(x))
        return x

# Test transformer block
block = TransformerBlock(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)
out = block(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")

## 8. Causal Masking for Language Modeling

For **autoregressive** language modeling (like GPT), position $i$ can only attend to positions $\leq i$.

We achieve this with a **causal mask**:
$$\text{mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

After softmax, future positions have zero attention weight.
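
The effect of the mask on the softmax can be seen on a small example. The sketch below builds an additive mask with an equivalent one-line construction and applies it to all-zero scores, so each query spreads its weight uniformly over the positions it is allowed to see. The sequence length `T = 4` is illustrative.

```python
import torch
import torch.nn.functional as F

T = 4  # illustrative sequence length

# Additive mask: 0 on and below the diagonal, -inf above it
mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)

# With all-zero scores, each query weights its allowed keys uniformly
scores = torch.zeros(T, T)
weights = F.softmax(scores + mask, dim=-1)
print(weights)
```

Row $i$ of the result holds $1/(i+1)$ in its first $i+1$ entries and exact zeros afterwards, because $\exp(-\infty) = 0$.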

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal (lower-triangular) attention mask.
    
    Args:
        seq_len: Sequence length
        
    Returns:
        mask: (seq_len, seq_len) tensor with 0 for allowed positions, -inf for masked
    """
    # Create upper triangular matrix of ones (excluding diagonal)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    # Replace 1s with -inf
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask

# Visualize causal mask
mask = create_causal_mask(8)
plt.figure(figsize=(6, 6))
plt.imshow(mask.numpy(), cmap='RdBu', vmin=-10, vmax=0)
plt.colorbar()
plt.xlabel('Key position')
plt.ylabel('Query position')
plt.title('Causal Mask (blue=allowed, red=-inf)')
plt.show()

## 9. Building MiniGPT

<!-- BEGIN QUESTION -->

### Exercise 4: Implement MiniGPT (20 points)

Assemble all components into a complete GPT-style language model.

In [None]:
class MiniGPT(nn.Module):
    """
    A minimal GPT-style language model.
    
    Args:
        vocab_size: Size of vocabulary
        d_model: Model dimension (default: 128)
        n_heads: Number of attention heads (default: 4)
        n_layers: Number of transformer blocks (default: 4)
        block_size: Maximum sequence length (default: 64)
        dropout: Dropout probability (default: 0.1)
    """
    
    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=4,
                 block_size=64, dropout=0.1):
        super().__init__()
        self.block_size = block_size
        
        # TODO: Define the following layers:
        # self.tok_emb: Token embedding (vocab_size -> d_model)
        # self.pos_emb: Learnable position embedding (block_size -> d_model)
        # self.dropout: Dropout layer
        # self.blocks: ModuleList of n_layers TransformerBlocks
        # self.ln_f: Final LayerNorm
        # self.head: Output projection (d_model -> vocab_size)
        
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.dropout = nn.Dropout(dropout)
        
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout=dropout)
            for _ in range(n_layers)
        ])
        
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        
        # Weight tying (optional but common)
        self.tok_emb.weight = self.head.weight
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
            
    def forward(self, idx, targets=None):
        """
        Args:
            idx: (batch, seq_len) token indices
            targets: (batch, seq_len) target indices for loss computation
            
        Returns:
            logits: (batch, seq_len, vocab_size) output logits
            loss: Cross-entropy loss if targets provided, else None
        """
        B, T = idx.size()
        assert T <= self.block_size, f"Sequence length {T} exceeds block_size {self.block_size}"
        
        # TODO: Implement forward pass
        # 1. Get token embeddings: (B, T, d_model)
        # 2. Get position embeddings for positions 0..T-1: (T, d_model)
        # 3. Add token + position embeddings and apply dropout
        # 4. Create causal mask
        # 5. Pass through all transformer blocks
        # 6. Apply final layer norm
        # 7. Project to vocabulary size
        # 8. If targets provided, compute cross-entropy loss
        
        # YOUR CODE HERE
        tok_emb = self.tok_emb(idx)  # (B, T, d_model)
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # (T,)
        pos_emb = self.pos_emb(pos)  # (T, d_model)
        
        x = self.dropout(tok_emb + pos_emb)
        
        # Causal mask
        mask = create_causal_mask(T).to(idx.device)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x, mask)
            
        x = self.ln_f(x)
        logits = self.head(x)  # (B, T, vocab_size)
        
        # Compute loss if targets provided
        loss = None
        if targets is not None:
            # Reshape for cross-entropy
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            
        return logits, loss
    
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        """
        Generate new tokens autoregressively.
        
        Args:
            idx: (batch, seq_len) conditioning tokens
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            
        Returns:
            (batch, seq_len + max_new_tokens) generated sequence
        """
        for _ in range(max_new_tokens):
            # Crop to block_size if needed
            idx_cond = idx[:, -self.block_size:]
            
            # Get predictions
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature  # (B, vocab_size)
            
            # Sample from distribution
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            
            # Append to sequence
            idx = torch.cat([idx, idx_next], dim=1)
            
        return idx

In [None]:
# Test MiniGPT
def test_minigpt():
    vocab_size = 100
    model = MiniGPT(vocab_size, d_model=64, n_heads=4, n_layers=2, block_size=32)
    
    # Test forward pass
    idx = torch.randint(0, vocab_size, (2, 20))
    targets = torch.randint(0, vocab_size, (2, 20))
    logits, loss = model(idx, targets)
    
    assert logits.shape == (2, 20, vocab_size), f"Expected logits shape (2, 20, {vocab_size}), got {logits.shape}"
    assert loss is not None, "Loss should not be None when targets provided"
    
    # Test generation
    prompt = torch.randint(0, vocab_size, (1, 5))
    generated = model.generate(prompt, max_new_tokens=10)
    assert generated.shape == (1, 15), f"Expected generated shape (1, 15), got {generated.shape}"
    
    print("MiniGPT tests passed!")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

test_minigpt()

In [None]:
grader.check("minigpt")

## 10. Training on Shakespeare

### Exercise 5: Train and Generate (15 points)

Train MiniGPT on Shakespeare text and generate new text.
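
A quick sanity check before training: a model that predicts a uniform distribution over a vocabulary of size $V$ has a cross-entropy of $\ln V$, so the first reported loss of a freshly initialized model should be roughly that value. The sketch below assumes $V = 65$ purely for illustration; the actual vocabulary size is printed when the dataset is created.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 65  # assumed for illustration; use dataset.vocab_size in the notebook

# All-zero logits correspond to a uniform prediction over the vocabulary
logits = torch.zeros(8, vocab_size)
targets = torch.randint(0, vocab_size, (8,))

uniform_loss = F.cross_entropy(logits, targets).item()
print(f"Loss of a uniform prediction: {uniform_loss:.4f}")
print(f"ln(vocab_size):               {math.log(vocab_size):.4f}")
```

A first loss far above this value usually points to a problem with initialization or with the loss computation.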

In [None]:
# Download Shakespeare dataset
!mkdir -p data
!wget -q -O data/shakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

# Load and inspect
with open('data/shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(f"Dataset size: {len(text):,} characters")
print(f"\nFirst 500 characters:\n{text[:500]}")

In [None]:
class CharDataset(Dataset):
    """Character-level dataset for language modeling."""
    
    def __init__(self, text, block_size):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.data = torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
        
    def __len__(self):
        return len(self.data) - self.block_size
    
    def __getitem__(self, idx):
        x = self.data[idx:idx+self.block_size]
        y = self.data[idx+1:idx+self.block_size+1]
        return x, y
    
    def encode(self, s):
        return torch.tensor([self.stoi[c] for c in s], dtype=torch.long)
    
    def decode(self, t):
        return ''.join([self.itos[i.item()] for i in t])

# Create dataset
block_size = 64
dataset = CharDataset(text, block_size)
print(f"Vocabulary size: {dataset.vocab_size}")
print(f"Vocabulary: {''.join(dataset.itos.values())}")

In [None]:
def train_minigpt(model, dataset, epochs=5, batch_size=64, lr=3e-4):
    """
    Train MiniGPT on character dataset.
    
    Args:
        model: MiniGPT model
        dataset: CharDataset
        epochs: Number of training epochs
        batch_size: Batch size
        lr: Learning rate
        
    Returns:
        model: Trained model
        losses: List of training losses
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    
    model.train()
    losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.to(DEVICE), y.to(DEVICE)
            
            logits, loss = model(x, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}")
                
        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1} complete. Average loss: {avg_loss:.4f}")
        
    return model, losses

In [None]:
# Create and train model
model = MiniGPT(
    vocab_size=dataset.vocab_size,
    d_model=128,
    n_heads=4,
    n_layers=4,
    block_size=block_size,
    dropout=0.1
).to(DEVICE)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train (use more epochs for better results)
trained_model, losses = train_minigpt(model, dataset, epochs=3, batch_size=64)

In [None]:
# Plot training loss
plt.figure(figsize=(8, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

In [None]:
# Generate text
def generate_text(model, dataset, prompt, max_new_tokens=200, temperature=0.8):
    """Generate text from a prompt."""
    model.eval()
    prompt_enc = dataset.encode(prompt).unsqueeze(0).to(DEVICE)
    generated = model.generate(prompt_enc, max_new_tokens=max_new_tokens, temperature=temperature)
    return dataset.decode(generated[0])

# Generate with different prompts
prompts = [
    "ROMEO:",
    "To be or not to be",
    "The king"
]

for prompt in prompts:
    print(f"\n{'='*60}")
    print(f"Prompt: {prompt}")
    print(f"{'='*60}")
    generated = generate_text(trained_model, dataset, prompt, max_new_tokens=200)
    print(generated)

In [None]:
grader.check("generate")

<!-- END QUESTION -->

## 11. Attention Visualization

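Let's visualize how the causal mask changes an attention pattern. The cell below uses random queries and keys rather than the trained model's weights, because `MultiHeadAttention` as defined above does not return its attention weights.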

In [None]:
def visualize_attention(model, text, dataset, layer=0, head=0):
    """
    Compare bidirectional and causal attention patterns for a given text length.

    Note: this is a placeholder visualization. It uses random queries and keys,
    not the weights learned by `model`, because MultiHeadAttention does not
    return its attention weights. The `model`, `dataset`, `layer`, and `head`
    arguments are currently unused.
    """
    # Random queries and keys, one per character of the text
    T = len(text)
    Q = torch.randn(1, T, 64)
    K = torch.randn(1, T, 64)
    # Only the weights are used, so Q doubles as a dummy value tensor
    _, weights = scaled_dot_product_attention(Q, K, Q)

    # Apply causal mask
    mask = create_causal_mask(T)
    _, weights_causal = scaled_dot_product_attention(Q, K, Q, mask)
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Without causal mask
    ax = axes[0]
    im = ax.imshow(weights[0].detach().numpy(), cmap='Blues')
    ax.set_title('Bidirectional Attention')
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')
    
    # With causal mask
    ax = axes[1]
    im = ax.imshow(weights_causal[0].detach().numpy(), cmap='Blues')
    ax.set_title('Causal Attention (GPT-style)')
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')
    
    plt.tight_layout()
    plt.show()

visualize_attention(trained_model, "To be or not", dataset)

## 12. BERT vs GPT Architectures

| Aspect | BERT (Encoder) | GPT (Decoder) |
|--------|----------------|---------------|
| Attention | Bidirectional | Causal (left-to-right) |
| Training | Masked Language Modeling | Autoregressive LM |
| Use case | Classification, NER, QA | Text generation |
| Mask | No mask (sees all tokens) | Causal mask |

**BERT** (Bidirectional Encoder Representations from Transformers):
- Trained to predict masked tokens
- Can see context in both directions
- Good for understanding tasks

**GPT** (Generative Pre-trained Transformer):
- Trained to predict next token
- Can only see past context
- Good for generation tasks

## Summary

1. **Attention** is a soft dictionary lookup: weighted average based on query-key similarity
2. **Scaled dot-product attention**: $\text{softmax}(\bfQ\bfK^\top/\sqrt{d_k})\bfV$
3. **Self-attention**: Q, K, V all derived from the same input
4. **Multi-head attention**: Multiple parallel attention heads capture different relationships
5. **Positional encoding**: Inject position information (sinusoidal or learned)
6. **Transformer block**: Multi-head attention + FFN with residual connections and layer norm
7. **Causal masking**: Prevent attending to future tokens for autoregressive generation

## Submission

Make sure you have run all cells in order before exporting.

In [None]:
# Save your notebook first, then run this cell to export
grader.export(run_tests=True)