# Understanding Attention Mechanism and GPT Transformer

## A Complete Educational Guide

This notebook provides a barebones implementation of:
1. **Attention Mechanism** - The core innovation
2. **GPT-style Decoder-Only Transformer** - For language modeling

### Table of Contents
1. [Introduction](#introduction)
2. [Scaled Dot-Product Attention](#attention)
3. [Multi-Head Attention](#multihead)
4. [Positional Encoding](#positional)
5. [Layer Normalization](#layernorm)
6. [Feed-Forward Networks](#ffn)
7. [Residual Connections](#residual)
8. [Complete Decoder Block](#decoder)
9. [Full GPT Model](#gpt)
10. [Training Example](#training)
11. [Text Generation](#generation)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("PyTorch version:", torch.__version__)

---

## 1. Introduction <a id='introduction'></a>

### What is a Transformer?

The Transformer architecture, introduced in "Attention is All You Need" (2017), revolutionized NLP by replacing recurrence with **attention mechanisms**.

### Key Innovation: Attention

Instead of processing sequences sequentially (like RNNs), transformers process all positions **simultaneously** and let each position **attend** to all other positions.

### GPT vs BERT

- **BERT**: Encoder-only (bidirectional, for understanding)
- **GPT**: Decoder-only (unidirectional, for generation)
- **Original Transformer**: Encoder-decoder (for translation)

We'll implement a GPT-style decoder-only transformer for language modeling.

---

## 2. Scaled Dot-Product Attention <a id='attention'></a>

### Intuition

Attention answers: "Given a query, which values should I focus on?"

**Analogy**: You're in a library (keys/values) looking for information (query):
1. You scan book titles (keys) to see which match your query
2. You calculate relevance scores (attention weights)
3. You read the most relevant books (weighted sum of values)

### Mathematical Formulation

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $Q$ = Query (what we're looking for)
- $K$ = Key (what we're comparing against)
- $V$ = Value (what we want to retrieve)
- $d_k$ = Dimension of keys (for scaling)

### Why Scale by $\sqrt{d_k}$?

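If the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean 0 and variance $d_k$. For large $d_k$ the scores therefore spread widely, which pushes softmax into saturated regions (one weight close to 1, the rest close to 0) where gradients are very small. Dividing by $\sqrt{d_k}$ brings the variance back to 1.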
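The variance argument can be checked numerically. The sketch below draws query/key vectors with independent standard-normal components (the assumption stated above, not real model activations) and compares the spread of raw and scaled dot products. `dot_product_variance` is a helper name introduced here for illustration.

```python
import torch

torch.manual_seed(0)

def dot_product_variance(d_k, num_samples=10_000):
    """Empirical variance of q.k and of q.k / sqrt(d_k) for q, k with i.i.d. N(0, 1) entries."""
    q = torch.randn(num_samples, d_k)
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)            # one dot product per sample
    return dots.var().item(), (dots / d_k ** 0.5).var().item()

results = {d_k: dot_product_variance(d_k) for d_k in (8, 64, 512)}
for d_k, (raw_var, scaled_var) in results.items():
    print(f"d_k={d_k:4d}  var(q.k)={raw_var:8.1f}  var(q.k/sqrt(d_k))={scaled_var:.2f}")

# Saturation: the same scores, with and without the 1/sqrt(d_k) factor
d_k = 512
scores = torch.randn(10) * d_k ** 0.5     # spread typical of unscaled scores at d_k = 512
p_unscaled = torch.softmax(scores, dim=-1).max().item()
p_scaled = torch.softmax(scores / d_k ** 0.5, dim=-1).max().item()
print(f"largest attention weight: unscaled {p_unscaled:.4f}, scaled {p_scaled:.4f}")
```

The raw variance should track $d_k$ while the scaled variance stays near 1, and the unscaled softmax typically puts nearly all of its weight on a single key.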

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    The fundamental attention mechanism.
    
    Any leading dimensions are allowed, e.g. (batch, seq_len, d_k) or
    (batch, num_heads, seq_len, d_k) as used by multi-head attention below.
    
    Args:
        query: (..., seq_len_q, d_k) - What we're looking for
        key:   (..., seq_len_k, d_k) - What we're comparing against
        value: (..., seq_len_k, d_v) - What we want to retrieve
        mask:  broadcastable to (..., seq_len_q, seq_len_k); 1 = attend, 0 = block
               (optional, e.g. a causal mask)
    
    Returns:
        output: (..., seq_len_q, d_v)
        attention_weights: (..., seq_len_q, seq_len_k)
    """
    d_k = query.size(-1)  # Dimension of key/query
    
    # Step 1: Calculate attention scores
    # For each position i, calculate similarity with all positions j
    scores = torch.matmul(query, key.transpose(-2, -1))  # (batch, seq_len, seq_len)
    
    # Step 2: Scale by sqrt(d_k)
    scores = scores / math.sqrt(d_k)
    
    # Step 3: Apply mask (if provided)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Step 4: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
    
    # Step 5: Apply attention weights to values
    output = torch.matmul(attention_weights, value)  # (batch, seq_len, d_k)
    
    return output, attention_weights
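As a cross-check, the same computation can be compared with PyTorch's built-in `F.scaled_dot_product_attention` (available in PyTorch 2.0 and later; this sketch assumes such a version). The helper `attention` below is a condensed restatement of the function above, repeated so the snippet runs on its own. It masks with `-inf` instead of `-1e9`, which is safe here because a causal mask always leaves the diagonal unmasked.

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v, mask=None):
    # Same steps as scaled_dot_product_attention above: scores, scale, mask, softmax, weighted sum
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 5, 8) for _ in range(3))
causal = torch.tril(torch.ones(5, 5))

ours = attention(q, k, v, causal)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # PyTorch >= 2.0
print("max abs difference:", (ours - builtin).abs().max().item())

# Sanity check: an all-zero query scores every key equally,
# so the (unmasked) output is simply the mean of the values
zero_q = torch.zeros(1, 1, 8)
uniform_out = attention(zero_q, k[:1], v[:1])
print(torch.allclose(uniform_out, v[:1].mean(dim=1, keepdim=True), atol=1e-6))  # True
```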

### Visualizing Attention

Let's see attention in action with a simple example:

In [None]:
# Example: Simple attention
batch_size = 1
seq_len = 5
d_k = 8

# Create random Q, K, V
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

# Apply attention
output, attn_weights = scaled_dot_product_attention(Q, K, V)

print("Input shapes:")
print(f"  Query: {Q.shape}")
print(f"  Key:   {K.shape}")
print(f"  Value: {V.shape}")
print(f"\nOutput shapes:")
print(f"  Output: {output.shape}")
print(f"  Attention weights: {attn_weights.shape}")

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(attn_weights[0].detach().numpy(), annot=True, fmt='.2f', cmap='Blues')
plt.title('Attention Weights (How much each position attends to others)')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("Each row shows how much that position attends to all other positions.")
print("Rows sum to 1.0 (probability distribution).")

### Causal Masking

For **autoregressive** generation (like GPT), we need **causal masking**: position $i$ can only attend to positions $\leq i$.

This prevents the model from "cheating" by looking at future tokens during training.

In [None]:
# Create causal mask
def create_causal_mask(seq_len):
    """Lower triangular matrix - position i can only see positions <= i"""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# Example
seq_len = 5
mask = create_causal_mask(seq_len)

print("Causal Mask (1 = can attend, 0 = cannot attend):")
print(mask.numpy().astype(int))

# Apply attention with causal mask
output_masked, attn_weights_masked = scaled_dot_product_attention(Q, K, V, mask=mask.unsqueeze(0))

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

sns.heatmap(attn_weights[0].detach().numpy(), annot=True, fmt='.2f', cmap='Blues', ax=axes[0])
axes[0].set_title('Without Causal Mask (can see future)')
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')

sns.heatmap(attn_weights_masked[0].detach().numpy(), annot=True, fmt='.2f', cmap='Blues', ax=axes[1])
axes[1].set_title('With Causal Mask (cannot see future)')
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')

plt.tight_layout()
plt.show()

print("\nNotice: With causal mask, upper triangle is zero (cannot attend to future tokens).")
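The "no cheating" property can also be tested directly: if only the last token is changed, the outputs at all earlier positions must stay the same. The sketch uses a hypothetical `causal_attention` helper (Q = K = V = x, no learned projections) to keep it short.

```python
import math
import torch

def causal_attention(x):
    """Scaled dot-product self-attention (Q = K = V = x) with a lower-triangular mask."""
    seq_len = x.size(-2)
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.size(-1))
    scores = scores.masked_fill(~allowed, float("-inf"))
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(1, 5, 8)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(8)        # overwrite only the last ("future") token

out_before = causal_attention(x)
out_after = causal_attention(x_changed)

print(torch.allclose(out_before[:, :-1], out_after[:, :-1]))  # True: earlier positions never see it
print(torch.allclose(out_before[:, -1], out_after[:, -1]))    # False: the last position does
```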

---

## 3. Multi-Head Attention <a id='multihead'></a>

### Why Multiple Heads?

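A single attention head computes one set of weights per query, so it can capture only one kind of relationship at a time. Multiple heads let the model attend to different aspects **simultaneously**. Trained heads often specialize, for example in:

- **Head 1**: Syntactic relationships (subject-verb agreement)
- **Head 2**: Semantic relationships (word meaning)
- **Head 3**: Positional relationships (word order)
- etc.

These roles are learned rather than assigned, so the head numbers above are only illustrative.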

### Architecture

Instead of one attention with dimension $d_{model}$, we use $h$ heads with dimension $d_k = d_{model}/h$ each:

```
Input (d_model=512)
    ↓
Linear projections (Q, K, V)
    ↓
Split into h=8 heads (each d_k=64)
    ↓
Parallel attention for each head
    ↓
Concatenate heads
    ↓
Final linear projection
    ↓
Output (d_model=512)
```
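The split and concatenate steps in the diagram are pure reshapes. This sketch uses the diagram's sizes (d_model = 512, h = 8) and the same `view`/`transpose` calls that `split_heads` and `combine_heads` use in the class below, and checks that nothing is lost: head h simply owns a contiguous slice of d_k features of every token.

```python
import torch

batch, seq_len, d_model, num_heads = 2, 6, 512, 8
d_k = d_model // num_heads                     # 64 features per head

x = torch.randn(batch, seq_len, d_model)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)            # split
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)  # concatenate

print(heads.shape)                                        # torch.Size([2, 8, 6, 64])
print(torch.equal(merged, x))                             # True
print(torch.equal(heads[:, 3], x[..., 3 * d_k:4 * d_k]))  # True: head 3 = features 192..255
```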

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention allows the model to attend to different aspects
    of the sequence simultaneously.
    """
    
    def __init__(self, d_model, num_heads):
        super().__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
    
    def split_heads(self, x):
        """Split the last dimension into (num_heads, d_k)"""
        batch_size, seq_len, d_model = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, num_heads, seq_len, d_k)
    
    def combine_heads(self, x):
        """Inverse of split_heads"""
        batch_size, num_heads, seq_len, d_k = x.size()
        x = x.transpose(1, 2)  # (batch, seq_len, num_heads, d_k)
        return x.contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, query, key, value, mask=None):
        # Step 1: Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # Step 2: Split into multiple heads
        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Step 3: Apply attention for each head
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # Step 4: Concatenate heads
        output = self.combine_heads(attn_output)
        
        # Step 5: Final linear projection
        output = self.W_o(output)
        
        return output, attn_weights

In [None]:
# Test multi-head attention
d_model = 64
num_heads = 4
batch_size = 1
seq_len = 6

mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)
mask = create_causal_mask(seq_len).unsqueeze(0).unsqueeze(0)

output, attn_weights = mha(x, x, x, mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"  (batch, num_heads, seq_len, seq_len)")

# Visualize attention for different heads
fig, axes = plt.subplots(1, num_heads, figsize=(16, 4))
for i in range(num_heads):
    sns.heatmap(attn_weights[0, i].detach().numpy(), annot=True, fmt='.2f', 
                cmap='Blues', ax=axes[i], cbar=False)
    axes[i].set_title(f'Head {i+1}')
    axes[i].set_xlabel('Key')
    if i == 0:
        axes[i].set_ylabel('Query')

plt.suptitle('Attention Patterns Across Different Heads', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("\nObservation: even with random (untrained) weights, each head produces a different")
print("attention pattern; training is what makes heads specialize.")

---

## 4. Positional Encoding <a id='positional'></a>

### The Problem

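Without a mask, self-attention has **no notion of position**: it is permutation-equivariant, so shuffling the input tokens only shuffles the outputs in the same way. It treats the sequence as a set, not a sequence:
- "The cat sat on the mat" ≈ "mat the on sat cat The"

This is bad for language, because word order matters.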
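This is easy to verify. The sketch strips attention down to unmasked self-attention with no learned projections (a simplification: Q/K/V projections act on each token independently, so they preserve the property) and shows that shuffling the input merely shuffles the output rows.

```python
import math
import torch

def self_attention(x):
    # Unmasked scaled dot-product self-attention with Q = K = V = x (no positions, no projections)
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.size(-1))
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(6, 16)              # 6 "word" vectors without positional information
perm = torch.randperm(6)            # shuffle the word order

out = self_attention(x)
out_shuffled = self_attention(x[perm])

# Shuffling the input only reorders the output rows the same way
print(torch.allclose(out_shuffled, out[perm], atol=1e-6))   # True
```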

### The Solution: Positional Encoding

Add position information to the embeddings using sine/cosine functions:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

Where:
- $pos$ = position in sequence (0, 1, 2, ...)
- $i$ = index of the sine/cosine pair ($0, 1, \ldots, d_{model}/2 - 1$)

### Why Sine/Cosine?

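1. **Different frequencies** for different dimensions (from high to low)
2. **Deterministic** (no learned parameters)
3. **Extrapolation**: the encoding can be computed for any position, including positions beyond the training length (whether the model generalizes to them is a separate question)
4. **Relative positions**: for a fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$ (a rotation within each sine/cosine pair)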
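The relative-position property can be made concrete. For each sine/cosine pair with frequency $\omega_i = 1/10000^{2i/d_{model}}$, the angle-addition formulas give $PE_{pos+k}$ from $PE_{pos}$ through a rotation that depends only on $k$. The sketch below (plain NumPy; `pe` and `shift` are helper names introduced here) checks this for a few positions.

```python
import numpy as np

d_model, offset = 16, 5
two_i = np.arange(0, d_model, 2)
freqs = 1.0 / (10000 ** (two_i / d_model))     # one frequency per (sin, cos) pair

def pe(pos):
    """Sinusoidal encoding of one position, interleaved as [sin, cos, sin, cos, ...]."""
    out = np.empty(d_model)
    out[0::2] = np.sin(pos * freqs)
    out[1::2] = np.cos(pos * freqs)
    return out

def shift(encoding, k):
    """Map PE(pos) to PE(pos + k) using only k: a 2x2 rotation per frequency."""
    s, c = encoding[0::2], encoding[1::2]
    out = np.empty_like(encoding)
    out[0::2] = s * np.cos(k * freqs) + c * np.sin(k * freqs)   # sin(a + b)
    out[1::2] = c * np.cos(k * freqs) - s * np.sin(k * freqs)   # cos(a + b)
    return out

for pos in (0, 3, 40):
    print(pos, np.allclose(shift(pe(pos), offset), pe(pos + offset)))  # True for each pos
```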

In [None]:
class PositionalEncoding(nn.Module):
    """
    Adds positional information to embeddings using sine/cosine functions.
    """
    
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        assert d_model % 2 == 0, "d_model must be even (sine/cosine pairs)"
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                            -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd indices
        
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """Add positional encoding to input embeddings"""
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return x

In [None]:
# Visualize positional encodings
d_model = 128
max_len = 100

pos_enc = PositionalEncoding(d_model, max_len)
pe_values = pos_enc.pe[0, :max_len, :].numpy()

fig, axes = plt.subplots(2, 1, figsize=(14, 8))

# Plot 1: Heatmap of positional encodings
im = axes[0].imshow(pe_values.T, cmap='RdBu', aspect='auto', vmin=-1, vmax=1)
axes[0].set_xlabel('Position in Sequence')
axes[0].set_ylabel('Embedding Dimension')
axes[0].set_title('Positional Encoding Matrix')
plt.colorbar(im, ax=axes[0])

# Plot 2: Encoding values for specific positions
positions_to_plot = [0, 10, 30, 60]
for pos in positions_to_plot:
    axes[1].plot(pe_values[pos, :64], label=f'Position {pos}', alpha=0.7)
axes[1].set_xlabel('Embedding Dimension (first 64)')
axes[1].set_ylabel('Encoding Value')
axes[1].set_title('Positional Encodings for Different Positions')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nKey Observations:")
print("1. Different dimensions have different frequencies (wavelengths)")
print("2. Lower dimensions = higher frequency (change rapidly)")
print("3. Higher dimensions = lower frequency (change slowly)")
print("4. Each position has a unique encoding pattern")

---

## 5. Layer Normalization <a id='layernorm'></a>

### What is Layer Normalization?

Layer normalization normalizes inputs across the **feature dimension** (not the batch dimension like Batch Norm).

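For each token's feature vector $x \in \mathbb{R}^{d_{model}}$ (independently of the other tokens and of the batch):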
$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where:
- $\mu$ = mean across features
- $\sigma^2$ = variance across features
- $\gamma, \beta$ = learnable parameters (scale and shift)
- $\epsilon$ = small constant for numerical stability
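The formula can be checked against PyTorch's `nn.LayerNorm`, which uses the biased (divide by $d_{model}$) variance and, by default, $\epsilon = 10^{-5}$ with $\gamma$ initialized to 1 and $\beta$ to 0. A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 5, 8)                 # (batch, seq_len, d_model)
ln = nn.LayerNorm(8)

mu = x.mean(dim=-1, keepdim=True)                   # mean over features, per token
var = x.var(dim=-1, unbiased=False, keepdim=True)   # biased variance, as in the formula
manual = ln.weight * (x - mu) / torch.sqrt(var + ln.eps) + ln.bias

print(torch.allclose(manual, ln(x), atol=1e-6))     # True
```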

### Why Layer Norm?

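1. **Stabilizes training** - keeps activation scales consistent from layer to layer (originally motivated as reducing "internal covariate shift", an explanation that is now debated)
2. **Speeds up convergence** - often tolerates higher learning rates
3. **Independent of batch size** - statistics are computed per token, so it behaves the same in training and inference and works with a batch size of 1 (unlike Batch Norm)
4. **Helps gradient flow** in deep networks, together with residual connections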

### Pre-LN vs Post-LN

**Post-LN** (Original Transformer):
```
x = LayerNorm(x + Sublayer(x))
```

**Pre-LN** (Modern, more stable):
```
x = x + Sublayer(LayerNorm(x))
```

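We use Pre-LN because it is generally more stable and easier to train. GPT-2 and later GPT models also use it, with an extra Layer Norm after the last block (the `final_norm` in Section 9).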
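The two orderings differ in what the residual stream looks like. In this sketch a single `nn.Linear` stands in for the sublayer (an assumption for brevity; attention or the FFN would sit in the same place), and the input is deliberately scaled up by 10.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)   # stand-in for attention or the FFN

def post_ln(x):
    return norm(x + sublayer(x))         # normalize after the residual add

def pre_ln(x):
    return x + sublayer(norm(x))         # identity path bypasses the norm

x = torch.randn(2, 4, d_model) * 10
post_std = post_ln(x).std().item()
pre_std = pre_ln(x).std().item()
print(f"Post-LN output std: {post_std:.2f}")   # about 1: the stream is renormalized
print(f"Pre-LN output std:  {pre_std:.2f}")    # stays near the input scale (about 10)
```

Post-LN renormalizes the stream after every block, while Pre-LN leaves the identity path untouched; this is one common explanation for why gradients flow more easily through deep Pre-LN stacks.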

In [None]:
# Demonstrate Layer Normalization
batch_size = 2
seq_len = 5
d_model = 8

# Random input with varying scales
x = torch.randn(batch_size, seq_len, d_model) * torch.tensor([0.1, 1.0, 10.0, 0.5, 2.0, 5.0, 0.2, 8.0])

# Apply layer norm
layer_norm = nn.LayerNorm(d_model)
x_normalized = layer_norm(x)

# LayerNorm normalizes each token's feature vector (the last dimension),
# so statistics are reported per token (dim=-1) for sample 0.
print("Before Layer Norm:")
print(f"Shape: {x.shape}")
print(f"Per-token mean (sample 0): {x[0].mean(dim=-1).detach().numpy().round(4)}")
print(f"Per-token std  (sample 0): {x[0].std(dim=-1, unbiased=False).detach().numpy().round(4)}")
print("\nFirst sample, first token:")
print(x[0, 0].detach().numpy())

print("\n" + "="*60)
print("\nAfter Layer Norm:")
print(f"Shape: {x_normalized.shape}")
print(f"Per-token mean (sample 0): {x_normalized[0].mean(dim=-1).detach().numpy().round(4)}")
print(f"Per-token std  (sample 0): {x_normalized[0].std(dim=-1, unbiased=False).detach().numpy().round(4)}")
print("\nFirst sample, first token:")
print(x_normalized[0, 0].detach().numpy())

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

axes[0].boxplot([x[0, :, i].detach().numpy() for i in range(d_model)])
axes[0].set_title('Before Layer Norm (varying scales)')
axes[0].set_xlabel('Feature Dimension')
axes[0].set_ylabel('Value')
axes[0].grid(True, alpha=0.3)

axes[1].boxplot([x_normalized[0, :, i].detach().numpy() for i in range(d_model)])
axes[1].set_title('After Layer Norm (normalized)')
axes[1].set_xlabel('Feature Dimension')
axes[1].set_ylabel('Value')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

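print("\nKey Point: every token now has mean 0 and std 1 across its features.")
print("Layer Norm does not equalize individual features: a feature that was large relative")
print("to the others within a token stays comparatively large after normalization.")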

---

## 6. Feed-Forward Networks <a id='ffn'></a>

### What is the Feed-Forward Network?

After attention, each position passes through a simple 2-layer MLP:

$$\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$$

- **Layer 1**: Expand to higher dimension (typically $4 \times d_{model}$)
- **ReLU**: Non-linear activation (GPT models use the smoother GELU instead)
- **Layer 2**: Project back to $d_{model}$

### Why FFN?

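1. **Non-linearity**: Apart from the softmax that computes the weights, attention outputs a weighted average of linearly projected values; the FFN adds an element-wise non-linearity
2. **Feature transformation**: Processes the information gathered by attention
3. **Position-wise**: Applied independently to each position, with the same weights at every position
4. **Capacity**: With $d_{ff} = 4 \times d_{model}$, the FFN holds about two-thirds of each block's parameters ($\approx 8d_{model}^2$ vs. $\approx 4d_{model}^2$ for attention)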
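The capacity claim is simple arithmetic. This sketch counts the parameters of four `d_model × d_model` projections versus a `d_model → 4·d_model → d_model` MLP (both with biases, as in the classes in this notebook), using d_model = 512 for illustration.

```python
import torch.nn as nn

def count(module):
    return sum(p.numel() for p in module.parameters())

d_model = 512
d_ff = 4 * d_model
# W_q, W_k, W_v, W_o
attention_params = count(nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)]))
ffn_params = count(nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)))

print(f"attention: {attention_params:,}")   # 1,050,624
print(f"FFN:       {ffn_params:,}")         # 2,099,712
print(f"FFN share: {ffn_params / (attention_params + ffn_params):.2f}")  # 0.67
```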

In [None]:
class FeedForward(nn.Module):
    """
    Simple 2-layer MLP applied to each position independently.
    """
    
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
    
    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = F.relu(self.linear1(x))  # (batch, seq_len, d_ff)
        x = self.linear2(x)           # (batch, seq_len, d_model)
        return x

In [None]:
# Test Feed-Forward Network
d_model = 64
d_ff = 256  # Typically 4x d_model

ffn = FeedForward(d_model, d_ff)
x = torch.randn(2, 10, d_model)
output = ffn(x)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nFFN expands to d_ff={d_ff}, then projects back to d_model={d_model}")

# Count parameters
total_params = sum(p.numel() for p in ffn.parameters())
print(f"\nTotal parameters in FFN: {total_params:,}")
print(f"  Linear1 (d_model → d_ff): {d_model * d_ff + d_ff:,}")
print(f"  Linear2 (d_ff → d_model): {d_ff * d_model + d_model:,}")

---

## 7. Residual Connections <a id='residual'></a>

### What are Residual Connections?

Instead of:
```
x = Layer(x)
```

We do:
```
x = x + Layer(x)  # Residual connection
```

### Why Residual Connections?

1. **Gradient flow**: Direct path for gradients to earlier layers
2. **Deep networks**: Enables training very deep models (100+ layers)
3. **Identity mapping**: Layer can learn to do nothing (identity)
4. **Stable training**: Mitigates vanishing gradients

### Intuition

The layer only needs to learn the **residual** (difference) from the input, not the entire transformation. This is often easier.
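Both the identity-mapping and the gradient-flow points can be seen with a sublayer whose weights are all zero (a contrived setting chosen for illustration): the plain layer wipes out its input, while the residual block returns the input unchanged and still passes the gradient straight through.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
sublayer = nn.Linear(d_model, d_model)
nn.init.zeros_(sublayer.weight)   # a sublayer whose output is currently all zeros
nn.init.zeros_(sublayer.bias)

x = torch.randn(3, d_model, requires_grad=True)
plain = sublayer(x)               # without a skip connection the input is lost
residual = x + sublayer(x)        # with one, the block starts out as the identity

print(torch.count_nonzero(plain).item())   # 0
print(torch.equal(residual, x))            # True

# The gradient reaches x through the identity path even though the sublayer is zero
residual.sum().backward()
print(x.grad)                              # a (3, 8) tensor of ones
```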

In [None]:
# Demonstrate residual connections with a toy model of gradient flow.
# Each sublayer is reduced to a single scalar derivative g (a deliberate simplification):
#   plain layer     y = f(x)      -> local derivative g
#   residual layer  y = x + f(x)  -> local derivative 1 + g
def simulate_deep_network(layer_grads, use_residual=True):
    """Return the gradient magnitude after backpropagating through each layer (starts at 1.0)."""
    gradient = 1.0
    gradients = [gradient]
    
    for g in layer_grads:
        if use_residual:
            # Identity path contributes 1, sublayer path contributes g
            gradient = gradient * (1.0 + g)
        else:
            # Only the sublayer path: small g compounds geometrically -> vanishing
            gradient = gradient * g
        gradients.append(abs(gradient))
    
    return gradients

# Compare with and without residual connections, using the SAME sublayer derivatives
depth = 50
rng = np.random.default_rng(0)
layer_grads = rng.normal(0.0, 0.2, size=depth)  # small, sign-varying toy derivatives
grads_with_residual = simulate_deep_network(layer_grads, use_residual=True)
grads_without_residual = simulate_deep_network(layer_grads, use_residual=False)

plt.figure(figsize=(12, 5))
plt.plot(grads_without_residual, label='Without Residual (vanishes)', linewidth=2)
plt.plot(grads_with_residual, label='With Residual (no geometric decay)', linewidth=2)
plt.xlabel('Layer Depth')
plt.ylabel('Gradient Magnitude')
plt.title('Gradient Flow: With vs Without Residual Connections (toy model)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.tight_layout()
plt.show()

print("Observation:")
print(f"Without residual: gradient at layer {depth} = {grads_without_residual[-1]:.2e}")
print(f"With residual:    gradient at layer {depth} = {grads_with_residual[-1]:.2e}")
print("\nResidual connections keep the gradient from vanishing with depth.")
print("(If sublayer derivatives were large, the 1 + g factors could instead make gradients grow;")
print(" normalization and careful initialization keep g small in practice.)")

---

## 8. Complete Decoder Block <a id='decoder'></a>

### Architecture

A decoder block combines everything we've learned:

```
Input
  ↓
Layer Norm
  ↓
Multi-Head Self-Attention (with causal mask)
  ↓
Residual Connection (+)
  ↓
Layer Norm
  ↓
Feed-Forward Network
  ↓
Residual Connection (+)
  ↓
Output
```

This is **Pre-LN** architecture (Layer Norm before sub-layers).

In [None]:
class DecoderBlock(nn.Module):
    """
    A single decoder block: Attention + FFN with residual connections.
    """
    
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        # Multi-head self-attention
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        
        # Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff)
        
        # Layer normalization (Pre-LN)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask):
        # Sub-layer 1: Self-Attention
        # Pre-LN: Normalize before attention (computed once, used as Q, K and V)
        h = self.norm1(x)
        attn_output, _ = self.self_attention(h, h, h, mask)
        x = x + self.dropout(attn_output)  # Residual connection
        
        # Sub-layer 2: Feed-Forward
        # Pre-LN: Normalize before FFN
        ff_output = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_output)  # Residual connection
        
        return x

In [None]:
# Test Decoder Block
d_model = 128
num_heads = 8
d_ff = 512
batch_size = 2
seq_len = 10

decoder_block = DecoderBlock(d_model, num_heads, d_ff)
x = torch.randn(batch_size, seq_len, d_model)
mask = create_causal_mask(seq_len).unsqueeze(0).unsqueeze(0)

output = decoder_block(x, mask)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nShape is preserved through the decoder block!")

# Count parameters
total_params = sum(p.numel() for p in decoder_block.parameters())
print(f"\nTotal parameters in one decoder block: {total_params:,}")

---

## 9. Full GPT Model <a id='gpt'></a>

### Complete Architecture

```
Input Tokens
    ↓
Token Embedding
    ↓
Positional Encoding
    ↓
Decoder Block 1
    ↓
Decoder Block 2
    ↓
    ...
    ↓
Decoder Block N
    ↓
Layer Norm
    ↓
Linear Projection (to vocabulary)
    ↓
Output Logits
```

In [None]:
class GPTTransformer(nn.Module):
    """
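    GPT-style Decoder-Only Transformer for language modeling.
    
    Note: the GPT models use learned positional embeddings; this educational
    version uses the fixed sinusoidal encoding from Section 4 instead.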
    """
    
    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, 
                 d_ff=2048, max_len=5000, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        
        # Stack of decoder blocks
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model)
        
        # Output projection to vocabulary
        self.output_projection = nn.Linear(d_model, vocab_size)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        # Xavier-uniform init for every weight matrix (including the token embedding);
        # biases and LayerNorm parameters keep their PyTorch defaults
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def create_causal_mask(self, seq_len):
        """Create causal mask for autoregressive generation"""
        mask = torch.tril(torch.ones(seq_len, seq_len))
        return mask.unsqueeze(0).unsqueeze(0)
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len) - Token indices
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        batch_size, seq_len = x.size()
        
        # Token embedding
        x = self.token_embedding(x)
        x = x * math.sqrt(self.d_model)  # Scale embeddings
        
        # Add positional encoding
        x = self.positional_encoding(x)
        x = self.dropout(x)
        
        # Create causal mask
        mask = self.create_causal_mask(seq_len).to(x.device)
        
        # Pass through decoder blocks
        for decoder_block in self.decoder_blocks:
            x = decoder_block(x, mask)
        
        # Final layer norm
        x = self.final_norm(x)
        
        # Project to vocabulary
        logits = self.output_projection(x)
        
        return logits

In [None]:
# Create GPT model
vocab_size = 10000
model = GPTTransformer(
    vocab_size=vocab_size,
    d_model=256,
    num_heads=8,
    num_layers=4,
    d_ff=1024,
    dropout=0.1
)

print("="*70)
print("GPT TRANSFORMER MODEL")
print("="*70)
print(f"\nModel Configuration:")
print(f"  Vocabulary Size: {vocab_size:,}")
print(f"  Model Dimension: 256")
print(f"  Number of Heads: 8")
print(f"  Number of Layers: 4")
print(f"  Feed-Forward Dimension: 1024")

# Test forward pass
batch_size = 2
seq_len = 20
input_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))

print(f"\nInput:")
print(f"  Shape: {input_tokens.shape}")
print(f"  Sample tokens: {input_tokens[0, :10].tolist()}")

with torch.no_grad():
    logits = model(input_tokens)

print(f"\nOutput:")
print(f"  Shape: {logits.shape}")
print(f"  (batch_size, seq_len, vocab_size)")

# Get predictions (the model is untrained, so these are essentially random)
predictions = torch.argmax(logits, dim=-1)
print(f"\nPredicted tokens: {predictions[0, :10].tolist()}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Size:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / (1024**2):.2f} MB (float32)")

print("\n" + "="*70)
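To see where the parameters of this model live, the count can be split by component. This sketch recomputes it from the configuration used above (d_model = 256, 4 layers, d_ff = 1024, vocabulary 10,000); `gpt_param_breakdown` is a hypothetical helper, and its total should match the `Total parameters` line printed by the cell above. The positional encoding is a registered buffer, not a parameter, so it contributes nothing.

```python
def gpt_param_breakdown(vocab_size, d_model, num_layers, d_ff):
    """Parameter counts per component for the GPTTransformer layout in this notebook."""
    embedding = vocab_size * d_model                                   # token embedding table
    block = (4 * (d_model * d_model + d_model)                         # attention projections
             + (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)    # feed-forward
             + 4 * d_model)                                            # two LayerNorms
    return {
        "embedding": embedding,
        "decoder_blocks": num_layers * block,
        "final_norm": 2 * d_model,
        "output_projection": d_model * vocab_size + vocab_size,        # weight + bias
    }

parts = gpt_param_breakdown(vocab_size=10_000, d_model=256, num_layers=4, d_ff=1024)
total = sum(parts.values())
for name, n in parts.items():
    print(f"{name:18s} {n:>10,}  ({n / total:.1%})")
print(f"{'total':18s} {total:>10,}")   # 8,289,552
```

With a small model and a 10,000-token vocabulary, the embedding table and the output projection together hold over 60% of the parameters. GPT-2 ties (shares) these two matrices; this notebook keeps them separate.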

---

## 10. Training Example <a id='training'></a>

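Let's create a simple training loop for a toy task: learning to copy sequences.

### Task: Sequence Copying

Each training sequence is a random first half followed by an exact copy of it, e.g. `[7, 3, 9, 7, 3, 9]`. The model is trained on next-token prediction: for the input `[7, 3, 9, 7, 3]` the target is `[3, 9, 7, 3, 9]`, i.e. the same sequence shifted left by one position.

Tokens in the first half are random and cannot be predicted, so the loss cannot reach zero; tokens in the second half can be predicted by copying from the first half.

In [None]:
# Create a simple toy dataset
def create_toy_dataset(vocab_size, num_samples, seq_len):
    """Random first half followed by an exact copy of it: [a, b, c, a, b, c]."""
    half = seq_len // 2
    first_half = torch.randint(1, vocab_size, (num_samples, half))
    return torch.cat([first_half, first_half], dim=1)  # (num_samples, 2 * half)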

# Training configuration
vocab_size = 100
num_samples = 1000
seq_len = 20
batch_size = 32
num_epochs = 10

# Create dataset
train_data = create_toy_dataset(vocab_size, num_samples, seq_len)

# Create smaller model for faster training
model = GPTTransformer(
    vocab_size=vocab_size,
    d_model=128,
    num_heads=4,
    num_layers=2,
    d_ff=512,
    dropout=0.1
)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("Training Configuration:")
print(f"  Vocabulary Size: {vocab_size}")
print(f"  Training Samples: {num_samples}")
print(f"  Sequence Length: {seq_len}")
print(f"  Batch Size: {batch_size}")
print(f"  Epochs: {num_epochs}")
print(f"  Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
print("\nStarting training...\n")

In [None]:
# Training loop
model.train()
losses = []

for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = len(train_data) // batch_size
    
    for i in range(num_batches):
        # Get batch
        batch = train_data[i * batch_size:(i + 1) * batch_size]
        
        # Input and target (shifted by one position)
        input_seq = batch[:, :-1]  # All except last
        target_seq = batch[:, 1:]  # All except first
        
        # Forward pass
        logits = model(input_seq)
        
        # Calculate loss
        # Reshape for cross-entropy: (batch * seq_len, vocab_size)
        logits_flat = logits.reshape(-1, vocab_size)
        target_flat = target_seq.reshape(-1)
        loss = criterion(logits_flat, target_flat)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / num_batches
    losses.append(avg_loss)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

print("\nTraining complete!")

In [None]:
# Plot training loss
plt.figure(figsize=(10, 5))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Initial Loss: {losses[0]:.4f}")
print(f"Final Loss: {losses[-1]:.4f}")
print(f"Improvement: {(1 - losses[-1]/losses[0]) * 100:.1f}%")
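A useful reference point for these loss values: a model that assigns equal probability to all $V$ vocabulary entries has cross-entropy $\ln V$, about 4.61 nats for $V = 100$. The sketch below computes that baseline and also shows the one-position shift used to build inputs and targets, on a tiny hand-written sequence (the tokens are arbitrary).

```python
import torch
import torch.nn.functional as F

batch = torch.tensor([[7, 3, 9, 7, 3, 9]])
input_seq, target_seq = batch[:, :-1], batch[:, 1:]
print(input_seq.tolist())    # [[7, 3, 9, 7, 3]]
print(target_seq.tolist())   # [[3, 9, 7, 3, 9]]

# Uniform prediction: all-zero logits give every token probability 1 / vocab_size
vocab_size = 100
uniform_logits = torch.zeros(input_seq.numel(), vocab_size)
baseline = F.cross_entropy(uniform_logits, target_seq.reshape(-1))
print(f"{baseline.item():.3f}")   # 4.605, i.e. ln(100)
```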

---

## 11. Text Generation <a id='generation'></a>

### Autoregressive Generation

GPT generates text one token at a time:

1. Start with a prompt (seed text)
2. Feed prompt through model → get next token probabilities
3. Sample next token
4. Append to prompt
5. Repeat

### Sampling Strategies

- **Greedy**: Always pick highest probability token (deterministic)
- **Temperature Sampling**: Scale logits before softmax (higher = more random)
- **Top-k Sampling**: Sample from top k tokens
- **Top-p (Nucleus) Sampling**: Sample from smallest set with cumulative prob ≥ p
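Top-p sampling is listed above but not implemented in `generate_text` below. A minimal sketch of the filtering step, assuming a 1D logits vector as in `generate_text`; `top_p_filter` is a hypothetical helper, not a PyTorch function. Its output could be passed to the softmax in place of the top-k filtered logits.

```python
import torch

def top_p_filter(logits, p):
    """Nucleus filtering for a 1D logits vector: keep the smallest set of most
    probable tokens whose cumulative probability reaches p; set the rest to -inf."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    prob_before = probs.cumsum(dim=-1) - probs     # mass of strictly higher-ranked tokens
    keep_sorted = prob_before < p                  # the top token is always kept
    keep = torch.zeros_like(keep_sorted)
    keep[sorted_idx] = keep_sorted                 # back to vocabulary order
    return logits.masked_fill(~keep, float("-inf"))

logits = torch.tensor([0.15, 0.5, 0.05, 0.3]).log()   # a toy 4-token distribution
kept_by_p = {}
for p in (0.4, 0.7, 0.9):
    kept_by_p[p] = torch.isfinite(top_p_filter(logits, p)).nonzero().flatten().tolist()
    print(p, kept_by_p[p])
# 0.4 -> [1]; 0.7 -> [1, 3]; 0.9 -> [0, 1, 3]

torch.manual_seed(0)
next_token = torch.multinomial(torch.softmax(top_p_filter(logits, 0.7), dim=-1), num_samples=1)
print(next_token.item())   # always 1 or 3
```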

In [None]:
def generate_text(model, start_tokens, max_length=50, temperature=1.0, top_k=None):
    """
    Generate text autoregressively.
    
    Args:
        model: GPT model
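        start_tokens: Initial tokens (1D tensor of token ids)
        max_length: Total length of the returned sequence, prompt included
        temperature: Sampling temperature (> 0; higher = more random)
        top_k: If set, sample from top k tokens only
    
    Returns:
        1D tensor of token ids (the prompt followed by the sampled tokens)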
    """
    model.eval()
    generated = start_tokens.clone()
    
    with torch.no_grad():
        for _ in range(max_length - len(start_tokens)):
            # Get predictions for current sequence
            logits = model(generated.unsqueeze(0))[0, -1, :]  # Last position
            
            # Apply temperature
            logits = logits / temperature
            
            # Apply top-k filtering if specified
            if top_k is not None:
                top_k_logits, top_k_indices = torch.topk(logits, top_k)
                logits = torch.full_like(logits, float('-inf'))
                logits[top_k_indices] = top_k_logits
            
            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append to generated sequence
            generated = torch.cat([generated, next_token])
    
    return generated

In [None]:
# Test generation
model.eval()

# Create a simple prompt
prompt = torch.tensor([1, 2, 3, 4, 5])

print("Text Generation Demo")
print("="*70)
print(f"\nPrompt tokens: {prompt.tolist()}")
print("\nGenerating with different settings:\n")

# Generate with different temperatures
for temp in [0.5, 1.0, 1.5]:
    generated = generate_text(model, prompt, max_length=30, temperature=temp)
    print(f"Temperature {temp}: {generated.tolist()[:30]}")

print("\nNote: Lower temperature = more conservative, Higher temperature = more random")

---

## Summary and Key Takeaways

### Core Components

1. **Attention Mechanism**
   - Allows each position to look at all others
   - Scaled dot-product: $\text{Attention}(Q,K,V) = \text{softmax}(QK^T/\sqrt{d_k})V$
   - Causal masking for autoregressive generation

2. **Multi-Head Attention**
   - Multiple attention mechanisms in parallel
   - Each head learns different patterns
   - Concatenate and project outputs

3. **Positional Encoding**
   - Injects position information using sine/cosine
   - Different frequencies for different dimensions
   - Allows model to understand word order

4. **Layer Normalization**
   - Normalizes activations across features
   - Stabilizes training of deep networks
   - Pre-LN is usually more stable than Post-LN

5. **Feed-Forward Networks**
   - Simple 2-layer MLP
   - Adds non-linearity
   - Applied position-wise

6. **Residual Connections**
   - Direct gradient flow
   - Enables very deep networks
   - $x = x + \text{Layer}(x)$

### Why GPT Works

- **Scalability**: Can train on massive datasets
- **Parallelization**: All positions are processed simultaneously during training (generation is still token by token)
- **Context**: Can attend to entire sequence
- **Flexibility**: Same architecture for many tasks

### Modern Improvements

- **Larger models**: GPT-3 (175B params), GPT-4 (parameter count not officially disclosed)
- **Better tokenization**: BPE, SentencePiece
- **Efficient attention**: Flash Attention, Sparse Attention
- **Architecture tweaks**: RoPE, GQA, MoE

---

## Further Reading

1. **Original Papers**:
   - "Attention Is All You Need" (Vaswani et al., 2017)
   - "Improving Language Understanding by Generative Pre-Training" (Radford et al., 2018)

2. **Resources**:
   - The Illustrated Transformer (Jay Alammar)
   - Andrej Karpathy's "Let's build GPT" video
   - HuggingFace Transformers library

3. **Advanced Topics**:
   - Flash Attention for efficiency
   - Rotary Position Embeddings (RoPE)
   - Grouped-Query Attention (GQA)
   - Mixture of Experts (MoE)

---

## Exercises

Try these to deepen your understanding:

1. **Modify the attention mechanism**:
   - Implement local attention (only attend to nearby positions)
   - Try different attention patterns

2. **Experiment with architecture**:
   - Change number of heads
   - Try different d_model/d_ff ratios
   - Implement Post-LN instead of Pre-LN

3. **Training improvements**:
   - Add learning rate scheduling
   - Implement gradient clipping
   - Try different optimizers

4. **Generation strategies**:
   - Implement beam search
   - Try nucleus (top-p) sampling
   - Add repetition penalty

5. **Visualization**:
   - Visualize attention patterns on real text
   - Plot embedding space with t-SNE
   - Analyze learned positional encodings