# Recurrent Neural Networks: From Vanilla RNNs to LSTMs

## 🎯 Introduction

Welcome to the world of sequential processing! This notebook will take you through the evolution of recurrent neural networks, from simple vanilla RNNs to sophisticated LSTMs and GRUs. Understanding RNNs is crucial for appreciating why transformers were such a breakthrough.

### 🧠 What You'll Understand

This comprehensive guide covers:
- **Vanilla RNNs**: The basic recurrent mechanism and why it struggles
- **The vanishing gradient problem**: Why deep-in-time training fails
- **LSTM architecture**: How gates solve the memory problem
- **GRU simplification**: A streamlined alternative to LSTMs
- **Bidirectional processing**: Using context from both directions

### 🎓 Prerequisites

- Solid understanding of MLPs and backpropagation
- Familiarity with sequence modeling concepts
- Basic knowledge of PyTorch modules and training loops
- Understanding of gradient flow in neural networks

### 🚀 Why RNNs Were Revolutionary

RNNs introduced several key innovations:
- **Sequential processing**: Handle variable-length sequences naturally
- **Parameter sharing**: Same weights across all time steps
- **Memory mechanism**: Hidden state carries information through time
- **Temporal modeling**: Understand patterns that unfold over time
- **Contextual understanding**: Each output depends on entire history

---

## 📚 Table of Contents

1. **[Vanilla RNN Mechanics](#vanilla-rnn-mechanics)** - Understanding the basic recurrent mechanism
2. **[Minimal Sequence Classifier with LSTM](#minimal-sequence-classifier-with-lstm)** - Training an LSTM on padded sequences
3. **[Variable Length Sequence Handling](#variable-length-sequence-handling)** - Padding versus packed sequences
4. **[RNN vs LSTM vs GRU Comparison](#rnn-vs-lstm-vs-gru-comparison)** - Parameter counts and gradient flow across architectures

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import numpy as np
import matplotlib.pyplot as plt

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Vanilla RNN Mechanics

### 🔄 The Basic Recurrent Mechanism

Vanilla RNNs introduced the revolutionary idea of memory in neural networks. Unlike feedforward networks that process inputs independently, RNNs maintain a hidden state that gets updated at each time step, allowing them to remember past information.
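
The update described above can be sketched for a single time step. This is a minimal illustration, assuming PyTorch's built-in `nn.RNNCell` (which includes bias terms); the sizes and variable names are arbitrary choices for this example only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch only)
n_batch, n_in, n_hidden = 2, 3, 4

cell = nn.RNNCell(n_in, n_hidden)  # tanh nonlinearity by default

x_t = torch.randn(n_batch, n_in)         # input at the current step
h_prev = torch.zeros(n_batch, n_hidden)  # hidden state from the previous step

# Manual update: h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh)
h_manual = torch.tanh(
    x_t @ cell.weight_ih.T + cell.bias_ih
    + h_prev @ cell.weight_hh.T + cell.bias_hh
)

# The built-in cell performs the same computation
h_builtin = cell(x_t, h_prev)

print(torch.allclose(h_manual, h_builtin, atol=1e-6))
print(h_builtin.shape)
```

Feeding `h_builtin` back in as `h_prev` for the next input is all that "recurrence" means here: the same weights are reused at every step.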

In [None]:
# =============================================================================
# VANILLA RNN: THE FOUNDATION OF SEQUENTIAL MODELING
# =============================================================================

print("🔄 Vanilla RNN Implementation")
print("=" * 50)

class VanillaRNN(nn.Module):
    """
    Basic RNN implementation to understand the core mechanism.
    
    The key insight: At each time step, combine current input with
    previous hidden state to produce new hidden state.
    """
    
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        
        self.hidden_size = hidden_size
        
        # Weight matrices for RNN computation
        # W_ih: input-to-hidden transformation
        # W_hh: hidden-to-hidden transformation (the "recurrent" part)
        # W_ho: hidden-to-output transformation
        self.W_ih = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_ho = nn.Linear(hidden_size, output_size)
        
        # Activation function (tanh is traditional for vanilla RNNs)
        self.activation = nn.Tanh()
    
    def forward(self, input_sequence, h_0=None):
        """
        Process a sequence through the RNN.
        
        Args:
            input_sequence: [batch_size, seq_len, input_size]
            h_0: Initial hidden state [batch_size, hidden_size] (optional)
        
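        Returns:
            outputs: [batch_size, seq_len, output_size]
            final_hidden: [batch_size, hidden_size]
            hidden_states: list of seq_len + 1 tensors, each [batch_size, hidden_size]
                (the initial state followed by the state after every step)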
        """
        batch_size, seq_len, input_size = input_sequence.shape
        
        # Initialize hidden state if not provided (match the input's device and dtype)
        if h_0 is None:
            h_t = torch.zeros(
                batch_size, self.hidden_size,
                device=input_sequence.device, dtype=input_sequence.dtype
            )
        else:
            h_t = h_0
        
        outputs = []
        hidden_states = [h_t]  # Track hidden state evolution
        
        print(f"Processing sequence: batch_size={batch_size}, seq_len={seq_len}")
        
        # Process each time step sequentially (this is the key limitation!)
        for t in range(seq_len):
            # Get input at current time step
            x_t = input_sequence[:, t, :]  # [batch_size, input_size]
            
            # Core RNN computation: h_t = tanh(W_ih * x_t + W_hh * h_{t-1})
            # This is where the "memory" happens - h_t depends on h_{t-1}
            input_contribution = self.W_ih(x_t)      # Input → Hidden
            hidden_contribution = self.W_hh(h_t)     # Previous Hidden → Current Hidden
            
            # Combine input and hidden contributions
            h_t = self.activation(input_contribution + hidden_contribution)
            
            # Compute output at this time step
            y_t = self.W_ho(h_t)  # Hidden → Output
            
            outputs.append(y_t)
            hidden_states.append(h_t)
            
            if t < 3:  # Show first few steps in detail
                print(f"  Step {t}: input {x_t.shape} + hidden {h_t.shape} → output {y_t.shape}")
        
        # Stack outputs across time dimension
        outputs = torch.stack(outputs, dim=1)  # [batch_size, seq_len, output_size]
        
        print(f"Final output shape: {outputs.shape}")
        
        return outputs, h_t, hidden_states

# Demonstrate RNN with a simple sequence
print(f"\n🎯 RNN in Action")
print("=" * 50)

# Create a simple RNN
batch_size, seq_len, input_size = 2, 5, 3
hidden_size, output_size = 4, 2

rnn = VanillaRNN(input_size, hidden_size, output_size)

# Create a random input sequence (values drawn from a standard normal distribution)
input_seq = torch.randn(batch_size, seq_len, input_size)

print(f"Input sequence shape: {input_seq.shape}")
print(f"Network: {input_size} → {hidden_size} → {output_size}")

# Process through RNN
outputs, final_hidden, hidden_states = rnn(input_seq)

print(f"\n📊 RNN Processing Analysis")
print("=" * 50)

print(f"Hidden state evolution (shape {hidden_states[0].shape}):")
print("Time | Hidden State Mean | Hidden State Std")
print("-----|-------------------|------------------")

for t, h in enumerate(hidden_states):
    mean_val = h.mean().item()
    std_val = h.std().item()
    print(f"{t:4d} | {mean_val:16.4f} | {std_val:16.4f}")

print(f"\n💡 Key RNN Properties")
print("=" * 50)
print("1. **Sequential Processing**: Must process t=0 before t=1, etc.")
print("2. **Parameter Sharing**: Same W_ih, W_hh, W_ho for all time steps")
print("3. **Variable Length**: Can handle any sequence length")
print("4. **Memory**: Hidden state h_t carries information from all previous steps")
print("5. **Bottleneck**: All history compressed into fixed-size hidden state")

print(f"\n⚠️ Vanilla RNN Limitations")
print("=" * 50)
print("1. **Vanishing Gradients**: Hard to learn long-term dependencies")
print("2. **No Parallelization**: Time steps must run in order, so work cannot be parallelized across the sequence")
print("3. **Memory Overwriting**: New information can erase old information")
print("4. **Gradient Explosion**: Gradients can grow exponentially with sequence length")
print("5. **Limited Context**: Fixed hidden size limits memory capacity")

# Demonstrate the parameter sharing aspect
print(f"\n🔍 Parameter Sharing Demonstration")
print("=" * 50)

total_params = sum(p.numel() for p in rnn.parameters())
print(f"Total parameters: {total_params}")
print(f"Parameters used at EVERY time step - this is efficient!")

print(f"\nParameter breakdown:")
print(f"W_ih (input→hidden): {rnn.W_ih.weight.shape} = {rnn.W_ih.weight.numel()} params")
print(f"W_hh (hidden→hidden): {rnn.W_hh.weight.shape} = {rnn.W_hh.weight.numel()} params")
print(f"W_ho (hidden→output): {rnn.W_ho.weight.shape} = {rnn.W_ho.weight.numel()} params")
print(f"W_ho bias: {rnn.W_ho.bias.shape} = {rnn.W_ho.bias.numel()} params")

print(f"\n✅ RNN achieves memory with constant parameters!")
print(f"Sequence length doesn't affect parameter count.")
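
The vanishing-gradient limitation listed above can be observed directly. The following is a rough sketch, assuming PyTorch's built-in `nn.RNN` with its default initialization; the sizes and the `demo_` names are made up for this example. It measures how strongly the final hidden state depends on the input at each time step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch only)
demo_len, demo_in, demo_hidden = 50, 3, 16

demo_rnn = nn.RNN(demo_in, demo_hidden, batch_first=True)

# Track gradients with respect to the input at every time step
demo_x = torch.randn(1, demo_len, demo_in, requires_grad=True)
_, demo_h_n = demo_rnn(demo_x)

# Backpropagate a scalar that depends only on the final hidden state
demo_h_n.sum().backward()

# Gradient norm per time step: the influence of each input on the final state
grad_norms = demo_x.grad.norm(dim=-1).squeeze(0)  # [demo_len]

for t in (0, 10, 25, 40, demo_len - 1):
    print(f"t={t:2d}  grad norm = {grad_norms[t].item():.3e}")
```

With the default initialization, the norms typically shrink by many orders of magnitude toward the start of the sequence, which is why early inputs are hard to learn from.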

## Minimal Sequence Classifier with LSTM

In [None]:
class SequenceClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, n_layers=1):
        super(SequenceClassifier, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)  # index 0 is the padding token
        self.lstm = nn.LSTM(embed_dim, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        # x shape: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        
        # LSTM returns output for all time steps and final hidden state
        lstm_out, (hidden, cell) = self.lstm(embedded)
        
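        # Use the final hidden state of the last layer for classification.
        # hidden shape: (n_layers, batch, hidden_dim)
        # Note: the input is padded but not packed, so for shorter sequences
        # this state has also been updated on the trailing padding tokens.
        last_hidden = hidden[-1]  # (batch, hidden_dim)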
        
        # Apply dropout and classify
        output = self.fc(self.dropout(last_hidden))
        
        return output

# Create synthetic sequence data for sentiment analysis
def create_synthetic_sequences(vocab_size=100, n_samples=1000, min_len=5, max_len=20):
    """Create synthetic sequences with binary sentiment labels"""
    torch.manual_seed(42)
    
    sequences = []
    labels = []
    
    for _ in range(n_samples):
        # Random sequence length
        length = torch.randint(min_len, max_len + 1, (1,)).item()
        
        # Generate random sequence
        seq = torch.randint(1, vocab_size, (length,))  # Start from 1 (0 is padding)
        
        # Simple rule: if sum of tokens is even -> positive (1), else negative (0)
        label = int(seq.sum().item() % 2 == 0)
        
        sequences.append(seq)
        labels.append(label)
    
    return sequences, torch.tensor(labels)

# Generate data
vocab_size = 50
sequences, labels = create_synthetic_sequences(vocab_size, n_samples=1000)

print(f"Generated {len(sequences)} sequences")
print(f"Example sequence: {sequences[0]}")
print(f"Example label: {labels[0]}")
print(f"Label distribution: {torch.bincount(labels)}")

# Pad sequences to same length
def collate_batch(sequences, labels, pad_token=0):
    """Pad sequences and create batch"""
    # Pad sequences
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_token)
    return padded_sequences, labels

# Split data
train_size = int(0.8 * len(sequences))
train_sequences = sequences[:train_size]
train_labels = labels[:train_size]
val_sequences = sequences[train_size:]
val_labels = labels[train_size:]

# Create batches
train_X, train_y = collate_batch(train_sequences, train_labels)
val_X, val_y = collate_batch(val_sequences, val_labels)

print(f"\nTrain data shape: {train_X.shape}, {train_y.shape}")
print(f"Val data shape: {val_X.shape}, {val_y.shape}")

# Create model
model = SequenceClassifier(
    vocab_size=vocab_size,
    embed_dim=32,
    hidden_dim=64,
    output_dim=2,  # Binary classification
    n_layers=2
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Training loop (full-batch; relies on the global `criterion` and `optimizer` defined above)
def train_model(model, train_X, train_y, val_X, val_y, epochs=30):
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    
    for epoch in range(epochs):
        # Training
        model.train()
        optimizer.zero_grad()
        
        outputs = model(train_X)
        loss = criterion(outputs, train_y)
        loss.backward()
        
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        # Calculate training accuracy (from the same dropout-enabled forward pass)
        with torch.no_grad():
            predicted = outputs.argmax(dim=1)
            train_acc = (predicted == train_y).sum().item() / len(train_y)
        
        # Validation
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_X)
            val_loss = criterion(val_outputs, val_y)
            
            val_predicted = val_outputs.argmax(dim=1)
            val_acc = (val_predicted == val_y).sum().item() / len(val_y)
        
        train_losses.append(loss.item())
        val_losses.append(val_loss.item())
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        
        if epoch % 5 == 0:
            print(f"Epoch {epoch:2d}: Train Loss: {loss.item():.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss.item():.4f}, Val Acc: {val_acc:.4f}")
    
    return train_losses, val_losses, train_accs, val_accs

# Train the model
train_losses, val_losses, train_accs, val_accs = train_model(
    model, train_X, train_y, val_X, val_y, epochs=30
)

print(f"\nFinal validation accuracy: {val_accs[-1]:.4f}")
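
The gradient clipping step in the training loop can be examined in isolation. This is a small sketch using a hypothetical toy model (`toy_model` and `global_grad_norm` are names invented for this example); it shows that `clip_grad_norm_` rescales gradients in place and returns the norm measured before clipping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny model, used only to produce some gradients
toy_model = nn.Linear(10, 1)
toy_loss = (toy_model(torch.randn(8, 10)) * 100.0).pow(2).mean()  # scaled up to give large gradients
toy_loss.backward()

def global_grad_norm(module):
    """L2 norm over all parameter gradients, treated as one long vector."""
    return torch.sqrt(sum(p.grad.pow(2).sum() for p in module.parameters()))

norm_before = global_grad_norm(toy_model)

# Rescales gradients in place; the return value is the norm before clipping
returned_norm = torch.nn.utils.clip_grad_norm_(toy_model.parameters(), max_norm=1.0)

norm_after = global_grad_norm(toy_model)

print(f"before: {norm_before.item():.2f}, after: {norm_after.item():.4f}")
```

Clipping changes only the length of the gradient vector, not its direction, so the update still points the same way but cannot take an arbitrarily large step.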

## Variable Length Sequence Handling

In [None]:
# Efficient variable length sequence processing
def create_variable_length_batch():
    """Create a batch with different sequence lengths"""
    sequences = [
        torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),      # length 8
        torch.tensor([10, 11, 12, 13]),               # length 4  
        torch.tensor([20, 21, 22, 23, 24, 25]),      # length 6
        torch.tensor([30, 31]),                       # length 2
        torch.tensor([40, 41, 42, 43, 44])           # length 5
    ]
    
    lengths = [len(seq) for seq in sequences]
    return sequences, lengths

sequences, lengths = create_variable_length_batch()
print("Original sequences:")
for i, (seq, length) in enumerate(zip(sequences, lengths)):
    print(f"Seq {i}: {seq.tolist()} (length: {length})")

# Method 1: Simple padding (inefficient)
print("\n=== Method 1: Simple Padding ===")
padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0)
print(f"Padded sequences shape: {padded_sequences.shape}")
print("Padded sequences:")
print(padded_sequences.numpy())

# Method 2: Packed sequences (efficient)
print("\n=== Method 2: Packed Sequences ===")

class EfficientLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(EfficientLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, sequences, lengths):
        # Pack padded sequences
        packed = pack_padded_sequence(
            sequences, lengths, batch_first=True, enforce_sorted=False
        )
        
        # Run LSTM on packed sequences
        packed_output, (hidden, cell) = self.lstm(packed)
        
        # Unpack sequences
        output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)
        
        # Use the last valid output for each sequence
        batch_size = output.size(0)
        last_outputs = []
        
        for i in range(batch_size):
            # Get the last valid time step for this sequence
            last_idx = output_lengths[i] - 1
            last_outputs.append(output[i, last_idx, :])
        
        last_outputs = torch.stack(last_outputs)
        
        # Final classification
        return self.fc(last_outputs)

# Create embeddings for our sequences (treat as token indices)
vocab_size = 50
embed_dim = 16
embedding = nn.Embedding(vocab_size, embed_dim)

# Convert sequences to embeddings
embedded_sequences = [embedding(seq) for seq in sequences]
embedded_padded = pad_sequence(embedded_sequences, batch_first=True)

print(f"Embedded padded shape: {embedded_padded.shape}")  # (batch, max_seq_len, embed_dim)

# Test efficient LSTM
efficient_model = EfficientLSTM(embed_dim, 32, 2)
output = efficient_model(embedded_padded, lengths)

print(f"Output shape: {output.shape}")  # (batch_size, output_size)
print("Efficient processing completed - no computation wasted on padding!")

# Show the difference in computation
print(f"\nTotal padded length: {padded_sequences.shape[0] * padded_sequences.shape[1]}")
print(f"Total actual length: {sum(lengths)}")
print(f"Efficiency gain: {(1 - sum(lengths)/(padded_sequences.shape[0] * padded_sequences.shape[1]))*100:.1f}% less computation")
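
Packing also affects which hidden state comes back, not just how much work is done. The sketch below, with made-up sizes and `demo_` names, compares the final hidden state of an LSTM run on padded input against the same LSTM run on packed input.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Illustrative sizes and lengths (assumptions for this sketch only)
feat_dim, hid_dim = 4, 6
demo_lengths = torch.tensor([5, 3, 2])
demo_padded = torch.randn(3, 5, feat_dim)
for i, n in enumerate(demo_lengths.tolist()):
    demo_padded[i, n:] = 0.0  # zero out the padding positions

demo_lstm = nn.LSTM(feat_dim, hid_dim, batch_first=True)

# Unpacked: the LSTM keeps updating its state on the padding steps
_, (h_unpacked, _) = demo_lstm(demo_padded)

# Packed: each sequence stops at its true length
demo_packed = pack_padded_sequence(
    demo_padded, demo_lengths, batch_first=True, enforce_sorted=False
)
packed_out, (h_packed, _) = demo_lstm(demo_packed)
unpacked_out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Output at the last valid step of each sequence
last_valid = unpacked_out[torch.arange(3), out_lengths - 1]

# Packed h_n equals the output at each sequence's last valid step
print(torch.allclose(h_packed[-1], last_valid, atol=1e-5))
# The full-length sequence gives the same state either way
print(torch.allclose(h_unpacked[-1][0], h_packed[-1][0], atol=1e-5))
# A shorter sequence does not: padding steps changed its state
print(torch.allclose(h_unpacked[-1][1], h_packed[-1][1], atol=1e-5))
```

For a single-layer, unidirectional LSTM this means `h_packed[-1]` can replace the per-sequence loop over `output_lengths`.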

## RNN vs LSTM vs GRU Comparison

In [None]:
# Compare different RNN architectures
class RNNComparison(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, rnn_type='LSTM'):
        super(RNNComparison, self).__init__()
        self.rnn_type = rnn_type
        
        if rnn_type == 'RNN':
            self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        elif rnn_type == 'LSTM':
            self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        if self.rnn_type == 'LSTM':
            output, (hidden, cell) = self.rnn(x)
            last_hidden = hidden[-1]
        else:  # RNN or GRU
            output, hidden = self.rnn(x)
            last_hidden = hidden[-1]
        
        return self.fc(last_hidden)

# Create test models
input_size, hidden_size, output_size = 32, 64, 2

models = {
    'RNN': RNNComparison(input_size, hidden_size, output_size, 'RNN'),
    'LSTM': RNNComparison(input_size, hidden_size, output_size, 'LSTM'),
    'GRU': RNNComparison(input_size, hidden_size, output_size, 'GRU')
}

# Compare parameter counts
print("=== Parameter Comparison ===")
for name, model in models.items():
    param_count = sum(p.numel() for p in model.parameters())
    print(f"{name}: {param_count:,} parameters")

# Forward pass on a long sequence (shape check only; gradients are examined below)
seq_len = 100
batch_size = 16
x = torch.randn(batch_size, seq_len, input_size)

print(f"\n=== Forward Pass Test ===")
print(f"Input shape: {x.shape}")

for name, model in models.items():
    output = model(x)
    print(f"{name} output shape: {output.shape}")

# Gradient analysis
print(f"\n=== Gradient Flow Analysis ===")

# Create a simple task: remember the first input value
def create_memory_task(seq_len=50, batch_size=32, input_size=1):
    """Create a task that requires remembering the first time step"""
    x = torch.randn(batch_size, seq_len, input_size)
    # Target is based on the sign of the first feature at the first time step
    y = (x[:, 0, 0] > 0).long()
    return x, y

# Test gradient flow (input_size must match the models created above)
x, y = create_memory_task(seq_len=50, batch_size=32, input_size=input_size)
criterion = nn.CrossEntropyLoss()

for name, model in models.items():
    # Reset model
    for param in model.parameters():
        param.grad = None
    
    # Forward and backward
    output = model(x)
    loss = criterion(output, y)
    loss.backward()
    
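    # Global L2 norm over all parameter gradients (a coarse measure: it does not
    # show how far back in time the gradient reaches)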
    total_norm = 0
    for param in model.parameters():
        if param.grad is not None:
            param_norm = param.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    
    print(f"{name}: Loss = {loss.item():.4f}, Gradient norm = {total_norm:.4f}")

print("\n=== Architecture Summary ===")
print("RNN: Simple and fast, but prone to vanishing gradients on long sequences")
print("LSTM: Three gates plus a cell state, handles long-range dependencies well, most parameters")
print("GRU: Two gates, fewer parameters than LSTM, often comparable accuracy")
print("\nRule of thumb (rough heuristics, not hard limits):")
print("- Short sequences (<20): RNN might be sufficient")
print("- Long sequences (>50): LSTM or GRU")
print("- When in doubt: try GRU first (good balance of performance/complexity)")
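
The parameter comparison above follows a simple pattern. As a sketch, assuming PyTorch's layout of one input weight matrix, one recurrent weight matrix, and two bias vectors per gate block, the counts for a single recurrent layer can be predicted by hand. The helper `recurrent_param_count` is a name invented for this example, and the final linear layer of `RNNComparison` is not included.

```python
import torch.nn as nn

def recurrent_param_count(input_size, hidden_size, num_gate_blocks):
    """Parameters in one unidirectional PyTorch recurrent layer.

    Each gate block has an input weight matrix, a recurrent weight matrix,
    and two bias vectors (bias_ih and bias_hh).
    """
    per_block = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
    return num_gate_blocks * per_block

# Same sizes as the comparison above
in_dim, hid_dim = 32, 64

# Weight blocks per layer: RNN has 1, GRU has 3, LSTM has 4
demo_layers = {
    'RNN': (nn.RNN(in_dim, hid_dim), 1),
    'GRU': (nn.GRU(in_dim, hid_dim), 3),
    'LSTM': (nn.LSTM(in_dim, hid_dim), 4),
}

counts = {}
for layer_name, (layer, blocks) in demo_layers.items():
    actual = sum(p.numel() for p in layer.parameters())
    predicted = recurrent_param_count(in_dim, hid_dim, blocks)
    counts[layer_name] = (actual, predicted)
    print(f"{layer_name}: actual={actual:,}  formula={predicted:,}")
```

The LSTM's four blocks are the input, forget, and output gates plus the candidate cell update; the GRU's three are the reset and update gates plus the candidate hidden state.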