# Exercise: Compare RNN, LSTM, and GRU for Sequence Prediction

**Scenario**: Your keyboard team has a limited compute budget for real-time inference on mobile devices. The product manager wants the smallest model that still handles long-context prediction.

**Your Mission**: Benchmark RNN vs. LSTM vs. GRU on prediction accuracy, training speed, and inference latency to make a data-driven recommendation.

**Estimated Time**: 18 minutes

---

## Learning Objectives
- Implement three recurrent architectures with identical configurations
- Compare learning curves, prediction quality, and computational efficiency  
- Make architecture recommendations based on systematic benchmarking
- Understand trade-offs between model complexity and performance

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
import pandas as pd
import time

# Set device and seeds: prefer a CUDA GPU when PyTorch can see one, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed both RNGs so weight initialisation and any NumPy sampling repeat run to run.
np.random.seed(42)
torch.manual_seed(42)
print(f"Using device: {device}")

## Part A: Implement Three Architectures

**Task**: Create three character-level models with **identical specifications**:
- `embedding_dim=128`
- `hidden_dim=256` 
- `num_layers=2`
- Same prediction head: `Linear(hidden_dim, vocab_size)`

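**Your Goal**: Ensure a fair comparison by matching all hyperparameters except the core recurrent layer. If you end up giving the models different dropout, learning rates, or weight decay, note it when you interpret the results.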

In [None]:
# Data preparation (same as demo)
dataset = load_dataset('wikitext', 'wikitext-2-v1')
# Slice the first 120 rows *before* picking the column: indexing
# dataset['train']['text'] first would materialise the whole split's text
# column only to discard almost all of it.
text = ' '.join(dataset['train'][:120]['text'])
# Keep printable characters only. str.isprintable() is already False for
# '\t' (and '\n'), so no separate tab check is needed.
text = ''.join(c for c in text if c.isprintable())
text = text[:60000]

# Character mapping: sorted so indices are deterministic across runs.
chars = sorted(set(text))
vocab_size = len(chars)
char2idx = {ch: i for i, ch in enumerate(chars)}
idx2char = {i: ch for i, ch in enumerate(chars)}

print(f"Vocabulary size: {vocab_size}")
print(f"Sample characters: {chars[:10]}")

In [None]:
# TODO: Implement CharRNN class
class CharRNN(nn.Module):
    """Character-level RNN using nn.RNN (exercise stub - fill in the TODOs).

    Keep every hyperparameter identical to CharLSTM / CharGRU so that only the
    recurrent layer differs between the three benchmarked models.
    """
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()
        # `dropout` mirrors the CharLSTM / CharGRU signatures so all three
        # models can be built with the same regularisation setting.
        # TODO: Store dimensions and initialize layers
        # TODO: self.hidden_dim = ?
        # TODO: self.num_layers = ?
        # TODO: self.embedding = nn.Embedding(?, ?)
        # TODO: self.rnn = nn.RNN(?, ?, ?, batch_first=True)
        #       (nn.RNN's own `dropout=` argument only acts *between* stacked layers)
        # TODO: self.dropout = nn.Dropout(?)
        # TODO: self.output = nn.Linear(?, ?)
        pass

    def forward(self, x, hidden=None):
        # x: batch of index windows, (batch, seq_len) as built by prepare_sequences
        # TODO: Pass through embedding (optionally apply self.dropout)
        # TODO: Pass through RNN
        # TODO: Use final timestep for prediction: rnn_out[:, -1, :]
        # TODO: Return logits and new hidden state
        pass

    def init_hidden(self, batch_size):
        # TODO: Return tensor of zeros with shape (num_layers, batch_size, hidden_dim)
        # TODO: Create it on the same device as the model's parameters
        #       (e.g. next(self.parameters()).device) - the training helpers
        #       pass it straight into forward().
        pass

# TODO: Implement CharLSTM class
class CharLSTM(nn.Module):
    """Character-level LSTM built on nn.LSTM (exercise stub).

    Follow CharRNN's layout, swapping the recurrent layer for nn.LSTM and
    adding dropout for regularisation.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()
        # TODO: Mirror CharRNN: keep hidden_dim / num_layers, then build
        #       embedding -> nn.LSTM(?, ?, ?, batch_first=True) -> output Linear
        # TODO: Wrap `dropout` in an nn.Dropout module for regularization
        pass

    def forward(self, x, hidden=None):
        # TODO: Same flow as CharRNN, but nn.LSTM's state is an (h, c) tuple
        # TODO: Apply dropout to the embeddings and again right before the output layer
        pass

    def init_hidden(self, batch_size):
        # TODO: Return an (h, c) tuple - both zero tensors shaped like CharRNN's state
        pass

# TODO: Implement CharGRU class
class CharGRU(nn.Module):
    """Character-level GRU built on nn.GRU (exercise stub).

    Same skeleton as CharRNN with nn.GRU as the recurrent layer, plus dropout.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()
        # TODO: Mirror CharRNN's layers, using nn.GRU and an nn.Dropout(dropout)
        pass

    def forward(self, x, hidden=None):
        # TODO: Identical to CharRNN - nn.GRU hands back a single hidden tensor
        # TODO: Apply dropout to the embeddings and again right before the output layer
        pass

    def init_hidden(self, batch_size):
        # TODO: Return one zero tensor for the GRU state (same shape as CharRNN's)
        pass

print("TODO: Complete the three model implementations above")

In [None]:
# Utility function provided for you
def count_parameters(model, trainable_only=True):
    """Count the scalar parameters of ``model``.

    Args:
        model: any ``nn.Module``.
        trainable_only: when True (default) only tensors with
            ``requires_grad=True`` are counted, as the name "trainable"
            promises; pass False to count every parameter (frozen ones too).

    Returns:
        Total number of scalar parameters as an int.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

# TODO: Initialize models and print parameter counts
# Uncomment and complete after implementing the model classes above
# rnn_model = CharRNN(vocab_size).to(device)
# lstm_model = CharLSTM(vocab_size).to(device)  
# gru_model = CharGRU(vocab_size).to(device)

# print("Parameter Comparison:")
# print(f"RNN:  {count_parameters(rnn_model):,}")
# print(f"LSTM: {count_parameters(lstm_model):,}")
# print(f"GRU:  {count_parameters(gru_model):,}")

print("TODO: Uncomment and run after completing model implementations")

## Part B: Train and Compare Learning

**Task**: Train each model for 20 epochs on the same data split and measure:
- Training loss per epoch
- Validation loss per epoch  
- Training time per epoch
- Final perplexity on the test set

**Your Goal**: Generate learning curves that reveal which architecture learns most efficiently.

In [None]:
# Data preparation function provided for you
def prepare_sequences(text, seq_len=50, train_ratio=0.8, val_ratio=0.1, char_to_idx=None):
    """Turn ``text`` into next-character examples and split them contiguously.

    Every window of ``seq_len`` consecutive characters is one input row; the
    character right after it is the target. Splits are taken in order (no
    shuffling): train, validation, then test, which gets whatever remains.

    Args:
        text: source string; every character must be a key of the mapping.
        seq_len: length of each input window (>= 1).
        train_ratio: fraction of examples used for training (default 0.8).
        val_ratio: fraction of examples used for validation (default 0.1).
        char_to_idx: character -> index mapping; defaults to the
            notebook-level ``char2idx``.

    Returns:
        ``(train_X, train_y, val_X, val_y, test_X, test_y)``; the ``*_X``
        tensors are int64 with shape ``(n, seq_len)``, the ``*_y`` tensors
        int64 with shape ``(n,)``.

    Raises:
        ValueError: if ``seq_len`` < 1.
        KeyError: if ``text`` contains a character missing from the mapping.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    num_examples = max(len(text) - seq_len, 0)
    if num_examples == 0:
        # Explicit dtype/shape: torch.tensor([]) would silently be float32 and 1-D.
        X = torch.empty((0, seq_len), dtype=torch.long)
        y = torch.empty((0,), dtype=torch.long)
    else:
        mapping = char2idx if char_to_idx is None else char_to_idx
        # Encode the text once, then take overlapping windows with unfold
        # instead of re-encoding every window in a Python loop.
        encoded = torch.tensor([mapping[ch] for ch in text], dtype=torch.long)
        X = encoded.unfold(0, seq_len, 1)[:num_examples].contiguous()
        y = encoded[seq_len:]

    # Contiguous train/validation/test split (80/10/10 with the defaults).
    total_size = len(X)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)
    val_end = train_size + val_size

    return (
        X[:train_size], y[:train_size],
        X[train_size:val_end], y[train_size:val_end],
        X[val_end:], y[val_end:],
    )

# Prepare data splits: window the cleaned corpus and cut it into train/val/test.
(train_X, train_y,
 val_X, val_y,
 test_X, test_y) = prepare_sequences(text)
print(f"Data splits - Train: {train_X.shape}, Val: {val_X.shape}, Test: {test_X.shape}")

In [None]:
# Training helper functions provided for you
def train_epoch(model, X, y, optimizer, criterion, batch_size=64, clip_norm=0.5):
    """Run one optimisation pass over every example in ``(X, y)``.

    Args:
        model: module called as ``model(batch_x, hidden)`` returning
            ``(logits, hidden)`` and exposing ``init_hidden(batch_size)``.
        X, y: inputs and targets, sliced along dim 0 into batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss returning a batch mean (e.g. ``nn.CrossEntropyLoss()``).
        batch_size: examples per step; the final batch may be smaller.
        clip_norm: max global gradient norm passed to ``clip_grad_norm_``.

    Returns:
        ``(avg_loss, epoch_time)``: sample-weighted mean training loss
        (``inf`` if ``X`` is empty) and wall-clock seconds for the epoch.
    """
    model.train()
    # Move batches to wherever the model's parameters live rather than
    # relying on a notebook-global device.
    target_device = next(model.parameters()).device
    total_loss = 0.0
    total_samples = 0
    start_time = time.perf_counter()

    # Visit *every* example. The old bound `len(X) - batch_size` dropped the
    # trailing batch (a whole full batch when len(X) was a multiple of
    # batch_size) and ran no batches at all when len(X) <= batch_size.
    for start in range(0, len(X), batch_size):
        batch_x = X[start:start + batch_size].to(target_device)
        batch_y = y[start:start + batch_size].to(target_device)

        # Fresh hidden state per batch: each window is a standalone example.
        hidden = model.init_hidden(batch_x.size(0))

        optimizer.zero_grad()
        outputs, _ = model(batch_x, hidden)
        loss = criterion(outputs, batch_y)

        loss.backward()
        # Gradient clipping bounds the size of each update step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the mean.
        n = batch_x.size(0)
        total_loss += loss.item() * n
        total_samples += n

    epoch_time = time.perf_counter() - start_time
    avg_loss = total_loss / total_samples if total_samples else float('inf')
    return avg_loss, epoch_time

def evaluate_model(model, X, y, criterion, batch_size=64):
    """Return the mean loss of ``model`` over every example in ``(X, y)``.

    Runs in eval mode under ``torch.no_grad()`` with a fresh hidden state per
    batch. Batch losses are weighted by batch size, so with a mean-reducing
    criterion (e.g. ``nn.CrossEntropyLoss``) the result is the exact
    per-example average even when the last batch is short.

    Args:
        model: module called as ``model(batch_x, hidden)`` returning
            ``(logits, hidden)`` and exposing ``init_hidden(batch_size)``.
        X, y: inputs and targets, sliced along dim 0 into batches.
        criterion: batch-mean loss function.
        batch_size: examples per forward pass; the final batch may be smaller.

    Returns:
        Mean loss as a float, or ``inf`` when ``X`` is empty.
    """
    model.eval()
    # Use the model's own device instead of a notebook-global one.
    target_device = next(model.parameters()).device
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        # Visit every example; the old `len(X) - batch_size` bound skipped the
        # tail (and everything when len(X) <= batch_size), biasing perplexity.
        for start in range(0, len(X), batch_size):
            batch_x = X[start:start + batch_size].to(target_device)
            batch_y = y[start:start + batch_size].to(target_device)

            hidden = model.init_hidden(batch_x.size(0))
            outputs, _ = model(batch_x, hidden)
            loss = criterion(outputs, batch_y)

            n = batch_x.size(0)
            total_loss += loss.item() * n
            total_samples += n

    return total_loss / total_samples if total_samples else float('inf')

print("✅ Training and evaluation helper functions provided")

In [None]:
# TODO: Train all three models and collect metrics
def train_and_compare_models(models=None, num_epochs=20):
    """Train RNN, LSTM and GRU side by side and collect per-epoch metrics.

    Args:
        models: optional dict mapping 'RNN', 'LSTM' and 'GRU' to already
            constructed models. Passing the notebook-level models trains them
            in place, so the *trained* weights remain available afterwards
            (e.g. for the Part C performance table). Models created inside
            this function are lost when it returns, because only ``metrics``
            is returned.
        num_epochs: number of training epochs (the exercise uses 20).

    Returns:
        dict: model name -> {'train_losses', 'val_losses', 'epoch_times'},
        each a list that the loop below fills with one entry per epoch.
    """
    model_names = ['RNN', 'LSTM', 'GRU']

    # TODO: Initialize your completed models (only needed when `models` is not passed)
    if models is None:
        models = {
            'RNN': None,    # TODO: Initialize CharRNN(vocab_size).to(device)
            'LSTM': None,   # TODO: Initialize CharLSTM(vocab_size).to(device)
            'GRU': None     # TODO: Initialize CharGRU(vocab_size).to(device)
        }

    # TODO: Create optimizers with different learning rates
    # Hint: Use lr=0.002 for RNN, lr=0.0005 for LSTM, lr=0.001 for GRU
    # Hint: Add weight_decay=1e-5 for LSTM and GRU to prevent overfitting
    # NOTE: per-model learning rates / weight decay mean the three runs are no
    # longer configured identically - keep that in mind when comparing results.
    optimizers = {
        'RNN': None,    # TODO: optim.Adam(models['RNN'].parameters(), lr=?)
        'LSTM': None,   # TODO: optim.Adam(models['LSTM'].parameters(), lr=?, weight_decay=?)
        'GRU': None     # TODO: optim.Adam(models['GRU'].parameters(), lr=?, weight_decay=?)
    }

    criterion = nn.CrossEntropyLoss()

    # Metrics tracking: an independent set of lists per model.
    metrics = {
        name: {'train_losses': [], 'val_losses': [], 'epoch_times': []}
        for name in model_names
    }

    # TODO: Training loop
    # Hint: For each epoch, train each model and collect metrics
    # Hint: Use train_epoch() and evaluate_model() functions provided above
    # Hint: Use smaller data subsets for faster training: train_X[:8000], val_X[:2000]
    #       (slice train_y / val_y the same way so inputs and targets stay aligned)

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        for model_name in model_names:
            # TODO: Train the model for this epoch
            # TODO: Evaluate on validation set
            # TODO: Store metrics
            # TODO: Print progress
            pass

    return metrics

# TODO: Run training comparison
# Uncomment after completing the function above and model implementations.
# Passing the Part A models trains them in place so Part C evaluates the
# trained weights rather than fresh, untrained models.
# print("🚀 Starting training comparison...")
# metrics = train_and_compare_models(
#     {'RNN': rnn_model, 'LSTM': lstm_model, 'GRU': gru_model})
# print("✅ Training completed!")

## Part C: Evaluate Prediction Quality  

**Task**: Compare the models on:
- Text generation quality (seed: "The meaning of")
- Test set perplexity 
- Inference speed

**Your Goal**: Determine which model produces the most coherent text and runs fastest.

In [None]:
# Evaluation helper functions provided for you
def calculate_perplexity(model, X, y):
    """Return exp(mean cross-entropy) of ``model`` on ``(X, y)``.

    The loss itself comes from ``evaluate_model`` with a fresh
    ``nn.CrossEntropyLoss``; lower perplexity is better.
    """
    return np.exp(evaluate_model(model, X, y, nn.CrossEntropyLoss()))

def measure_inference_speed(model, X_sample, num_runs=100, warmup=5):
    """Return the mean latency of one forward pass, in milliseconds.

    Args:
        model: module exposing ``init_hidden(batch_size)`` and called as
            ``model(x, hidden)``.
        X_sample: input batch, already on the model's device.
        num_runs: number of timed forward passes to average (must be >= 1).
        warmup: untimed passes run first so one-off start-up costs are not
            folded into the average.

    Returns:
        Average milliseconds per forward pass (float).

    Raises:
        ValueError: if ``num_runs`` < 1 or ``warmup`` < 0.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")

    model.eval()
    # Size the hidden state to the sample instead of assuming batch size 1.
    batch_size = X_sample.size(0)
    # CUDA work is queued asynchronously; synchronise around the timed region
    # so the clock measures completed computation, not just kernel launches.
    on_cuda = X_sample.is_cuda

    def _forward_once():
        hidden = model.init_hidden(batch_size)
        model(X_sample, hidden)

    with torch.no_grad():
        for _ in range(warmup):
            _forward_once()
        if on_cuda:
            torch.cuda.synchronize(X_sample.device)

        start_time = time.perf_counter()
        for _ in range(num_runs):
            _forward_once()
        if on_cuda:
            torch.cuda.synchronize(X_sample.device)
        elapsed = time.perf_counter() - start_time

    return elapsed / num_runs * 1000  # seconds -> milliseconds

print("✅ Evaluation helper functions provided")

In [None]:
# Performance comparison helper function provided for you  
def create_performance_table(metrics, models_dict, test_X, test_y, perplexity_samples=1000):
    """Build a one-row-per-model comparison table.

    Args:
        metrics: per-model dict with non-empty 'train_losses', 'val_losses'
            and 'epoch_times' lists (the last/mean values are reported).
        models_dict: name -> trained model; every name must be in ``metrics``.
        test_X, test_y: test split.
        perplexity_samples: number of leading test examples used for the
            "Test Perplexity" column (default 1000 keeps the table quick);
            ``None`` evaluates the whole test split.

    Returns:
        ``pandas.DataFrame`` whose cells are pre-formatted strings.
    """
    if perplexity_samples is None:
        ppl_X, ppl_y = test_X, test_y
    else:
        ppl_X, ppl_y = test_X[:perplexity_samples], test_y[:perplexity_samples]

    comparison_data = {
        'Model': [],
        'Parameters': [],
        'Final Train Loss': [],
        'Final Val Loss': [],
        'Test Perplexity': [],
        'Avg Epoch Time (s)': [],
        'Inference Speed (ms)': []
    }

    for model_name, model in models_dict.items():
        history = metrics[model_name]
        # Single-example latency probe, placed on this model's own device.
        X_sample = test_X[:1].to(next(model.parameters()).device)

        comparison_data['Model'].append(model_name)
        comparison_data['Parameters'].append(f"{count_parameters(model):,}")
        comparison_data['Final Train Loss'].append(f"{history['train_losses'][-1]:.4f}")
        comparison_data['Final Val Loss'].append(f"{history['val_losses'][-1]:.4f}")
        comparison_data['Test Perplexity'].append(f"{calculate_perplexity(model, ppl_X, ppl_y):.2f}")
        comparison_data['Avg Epoch Time (s)'].append(f"{np.mean(history['epoch_times']):.1f}")
        comparison_data['Inference Speed (ms)'].append(f"{measure_inference_speed(model, X_sample):.2f}")

    return pd.DataFrame(comparison_data)

# TODO: Create and display performance table after training
# Uncomment after completing training. `models_dict` must hold the *trained*
# model objects - freshly constructed ones would report untrained scores.
# models_dict = {'RNN': rnn_model, 'LSTM': lstm_model, 'GRU': gru_model}
# performance_table = create_performance_table(metrics, models_dict, test_X, test_y)
# print("\n" + "="*80)
# print("📊 COMPREHENSIVE PERFORMANCE COMPARISON")
# print("="*80)
# print(performance_table.to_string(index=False))

print("TODO: Uncomment and run after completing training")

## Part D: Visualization & Analysis

Create learning curves and make your recommendation.

In [None]:
# Visualization function provided for you
def plot_learning_curves(metrics):
    """Plot loss curves, epoch times and a final-loss summary for all models.

    Args:
        metrics: dict with keys 'RNN', 'LSTM' and 'GRU', each holding
            'train_losses', 'val_losses' and 'epoch_times' lists with one
            entry per epoch (the summary panel needs at least one epoch).
    """
    model_names = ['RNN', 'LSTM', 'GRU']
    colors = {'RNN': 'red', 'LSTM': 'blue', 'GRU': 'orange'}

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # The three per-epoch line panels differ only in series and labels.
    line_panels = [
        (ax1, 'train_losses', 'Training Loss Over Time', 'Training Loss'),
        (ax2, 'val_losses', 'Validation Loss Over Time', 'Validation Loss'),
        (ax3, 'epoch_times', 'Training Time per Epoch', 'Time (seconds)'),
    ]
    for ax, key, title, ylabel in line_panels:
        ax.set_title(title)
        for name in model_names:
            series = metrics[name][key]
            # 1-based x so the axis matches the "Epoch k/N" training log.
            epochs = range(1, len(series) + 1)
            ax.plot(epochs, series, label=name, color=colors[name], linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Summary: final validation loss per model, annotated with mean epoch time.
    ax4.set_title('Final Performance Summary')
    final_val_losses = [metrics[name]['val_losses'][-1] for name in model_names]
    avg_times = [np.mean(metrics[name]['epoch_times']) for name in model_names]

    bars = ax4.bar(model_names, final_val_losses, color=[colors[name] for name in model_names])
    ax4.set_ylabel('Final Validation Loss')

    for bar, time_val in zip(bars, avg_times):
        ax4.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                 f'{time_val:.1f}s/epoch', ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.show()

# TODO: Generate plots after training
# plot_learning_curves(metrics)

## Key Takeaways

**What you should have discovered**:

1. **LSTM vs GRU**: Similar performance, but GRU is more efficient
2. **Vanilla RNN**: Fastest but struggles with long sequences
3. **Parameter Trade-off**: More parameters ≠ always better
4. **Use Case Matters**: Mobile needs different trade-offs than servers

**Next Steps**:
- Try different sequence lengths to find breaking points
- Experiment with different hidden dimensions
- Test on other datasets (code, music, time series)
- Implement attention mechanisms for even longer sequences