# RNNs, LSTMs, GRUs in PyTorch

This notebook demonstrates recurrent neural networks and their variants.

## Table of Contents
1. [Why Gating Mechanisms?](#why-gating-mechanisms)
2. [Minimal Sequence Classifier with LSTM](#minimal-sequence-classifier-with-lstm)
3. [Variable Length Sequence Handling](#variable-length-sequence-handling)
4. [RNN vs LSTM vs GRU Comparison](#rnn-vs-lstm-vs-gru-comparison)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import numpy as np
import matplotlib.pyplot as plt

# Fix both RNG streams used in this notebook so repeated runs print the same numbers
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

## Why Gating Mechanisms?

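Gating mechanisms let a recurrent network decide, at every time step, what to keep in memory, what to overwrite, and what to expose as output. This helps it retain long-term dependencies and mitigates (but does not fully eliminate) the vanishing-gradient problem of vanilla RNNs.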

In [None]:
# Simple vanilla RNN implementation
class VanillaRNN(nn.Module):
    """Hand-written Elman RNN with a linear readout at every time step.

    The recurrence is ``h_t = tanh(W_ih x_t + W_hh h_{t-1})`` and each hidden
    state is projected to ``output_size`` features by ``W_ho``.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Width of the recurrent hidden state.
        output_size: Number of features emitted per time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # RNN cell parameters
        self.W_ih = nn.Linear(input_size, hidden_size)   # input to hidden
        self.W_hh = nn.Linear(hidden_size, hidden_size)  # hidden to hidden
        self.W_ho = nn.Linear(hidden_size, output_size)  # hidden to output

    def forward(self, x, hidden=None):
        """Run the recurrence over a batch-first sequence.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)``.
            hidden: Optional initial hidden state of shape
                ``(batch, hidden_size)``. Defaults to zeros.

        Returns:
            A tuple ``(outputs, hidden)`` where ``outputs`` has shape
            ``(batch, seq_len, output_size)`` and ``hidden`` is the state
            after the last time step, shape ``(batch, hidden_size)``.
        """
        batch_size, seq_len, _ = x.size()

        if hidden is None:
            # new_zeros inherits both device AND dtype from x, so the default
            # state also works for float64 / half-precision models instead of
            # always being float32.
            hidden = x.new_zeros(batch_size, self.hidden_size)

        outputs = []
        for t in range(seq_len):
            # h_t = tanh(W_ih * x_t + W_hh * h_{t-1})
            hidden = torch.tanh(self.W_ih(x[:, t, :]) + self.W_hh(hidden))
            outputs.append(self.W_ho(hidden))

        return torch.stack(outputs, dim=1), hidden

# LSTM implementation (simplified)
class SimpleLSTM(nn.Module):
    """Hand-written LSTM with one linear layer per gate and a per-step readout.

    Every gate sees the concatenation ``[x_t, h_{t-1}]``. The cell state is
    updated as ``c_t = f_t * c_{t-1} + i_t * g_t`` and the hidden state as
    ``h_t = o_t * tanh(c_t)``; each ``h_t`` is projected to ``output_size``
    features by ``output_layer``.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Width of the hidden and cell states.
        output_size: Number of features emitted per time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # LSTM gates, each mapping [x_t, h_{t-1}] -> hidden_size
        self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.candidate_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)

        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x, states=None):
        """Run the LSTM over a batch-first sequence.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)``.
            states: Optional ``(h, c)`` tuple of initial hidden and cell
                states, each of shape ``(batch, hidden_size)``. Defaults to
                zeros.

        Returns:
            A tuple ``(outputs, (h, c))`` where ``outputs`` has shape
            ``(batch, seq_len, output_size)`` and ``h`` / ``c`` are the states
            after the last time step.
        """
        batch_size, seq_len, _ = x.size()

        if states is None:
            # new_zeros inherits device AND dtype from x; a plain torch.zeros
            # is always float32 and breaks torch.cat / the gate layers for
            # float64 or half-precision models.
            h = x.new_zeros(batch_size, self.hidden_size)
            c = x.new_zeros(batch_size, self.hidden_size)
        else:
            h, c = states

        outputs = []
        for t in range(seq_len):
            # Concatenate input and hidden state
            combined = torch.cat([x[:, t, :], h], dim=1)

            # LSTM gates
            forget = torch.sigmoid(self.forget_gate(combined))
            input_g = torch.sigmoid(self.input_gate(combined))
            candidate = torch.tanh(self.candidate_gate(combined))
            output_g = torch.sigmoid(self.output_gate(combined))

            # Update cell state: keep part of the old memory, add new content
            c = forget * c + input_g * candidate

            # Update hidden state: gated view of the squashed cell state
            h = output_g * torch.tanh(c)

            outputs.append(self.output_layer(h))

        return torch.stack(outputs, dim=1), (h, c)

# Test both models
# Smoke test: push one random batch through each hand-written cell and compare
# output shapes and parameter counts.
input_size, hidden_size, output_size = 10, 32, 5
seq_len, batch_size = 15, 3

# Both modules and the random input below draw from the same global torch RNG,
# so the order of these statements determines the exact numbers produced.
rnn_model = VanillaRNN(input_size, hidden_size, output_size)
lstm_model = SimpleLSTM(input_size, hidden_size, output_size)

x = torch.randn(batch_size, seq_len, input_size)  # (batch, seq_len, input_size)

# Each model returns (per-step outputs, final recurrent state)
rnn_out, rnn_hidden = rnn_model(x)
lstm_out, lstm_states = lstm_model(x)

print(f"Input shape: {x.shape}")
print(f"RNN output shape: {rnn_out.shape}")
print(f"LSTM output shape: {lstm_out.shape}")
# Parameter totals include the per-step readout layer of each model
print(f"\nRNN has {sum(p.numel() for p in rnn_model.parameters())} parameters")
print(f"LSTM has {sum(p.numel() for p in lstm_model.parameters())} parameters")
print("LSTM has more parameters due to gating mechanisms")

## Minimal Sequence Classifier with LSTM

In [None]:
class SequenceClassifier(nn.Module):
    """Embedding -> LSTM -> linear classifier over padded token sequences.

    The prediction is made from the top-layer hidden state at the last *real*
    token of every sequence. Trailing padding is skipped by packing the batch,
    so the logits for a sequence do not depend on how much padding the rest of
    the batch forces onto it.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width.
        hidden_dim: LSTM hidden width.
        output_dim: Number of output classes (logits).
        n_layers: Number of stacked LSTM layers.
        pad_idx: Token id used for padding. Its embedding is kept at zero and
            trailing occurrences are ignored by the LSTM. Pass ``None`` to
            treat every position as a real token.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, n_layers=1, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, output_dim)``.
        """
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)

        if self.pad_idx is None:
            _, (hidden, _) = self.lstm(embedded)
        else:
            # Effective length = 1-based position of the last non-pad token.
            # Using the last position (not the count) tolerates pad ids that
            # happen to appear inside a sequence.
            positions = torch.arange(1, x.size(1) + 1, device=x.device)
            lengths = ((x != self.pad_idx) * positions).max(dim=1).values
            # Packing rejects zero-length rows; an all-pad row is processed
            # as a single (zero-embedded) step instead.
            lengths = lengths.clamp(min=1)

            # pack_padded_sequence expects the lengths on the CPU
            packed = pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            _, (hidden, _) = self.lstm(packed)

        # hidden shape: (n_layers, batch, hidden_dim); take the top layer
        last_hidden = hidden[-1]

        # Apply dropout and classify
        return self.fc(self.dropout(last_hidden))

# Create synthetic sequence data for sentiment analysis
def create_synthetic_sequences(vocab_size=100, n_samples=1000, min_len=5, max_len=20, seed=42):
    """Create random token sequences with a binary parity label.

    Args:
        vocab_size: Tokens are drawn uniformly from ``1 .. vocab_size - 1``
            (id 0 is reserved for padding).
        n_samples: Number of sequences to generate.
        min_len: Minimum sequence length (inclusive).
        max_len: Maximum sequence length (inclusive).
        seed: Value passed to ``torch.manual_seed`` before sampling so the
            data set is reproducible. Note that this reseeds the *global*
            torch RNG. Pass ``None`` to leave the RNG state untouched.

    Returns:
        A tuple ``(sequences, labels)``: a list of 1-D integer tensors and an
        int64 tensor of shape ``(n_samples,)``. A label is 1 when the sum of
        the sequence's tokens is odd and 0 when it is even.

    Raises:
        ValueError: If ``vocab_size < 2`` or ``min_len > max_len``.
    """
    if vocab_size < 2:
        raise ValueError("vocab_size must be at least 2 (token 0 is reserved for padding)")
    if min_len > max_len:
        raise ValueError("min_len must not exceed max_len")

    if seed is not None:
        torch.manual_seed(seed)

    sequences = []
    labels = []

    for _ in range(n_samples):
        # Random sequence length in [min_len, max_len]
        length = torch.randint(min_len, max_len + 1, (1,)).item()

        # Generate random sequence
        seq = torch.randint(1, vocab_size, (length,))  # Start from 1 (0 is padding)

        # Parity rule: odd token sum -> 1, even token sum -> 0
        label = int(seq.sum().item() % 2)

        sequences.append(seq)
        labels.append(label)

    # Explicit dtype keeps the labels int64 even when n_samples == 0
    return sequences, torch.tensor(labels, dtype=torch.long)

# Generate data
# vocab_size is reused further down when the classifier's embedding table is built
vocab_size = 50
sequences, labels = create_synthetic_sequences(vocab_size, n_samples=1000)

# Quick sanity check of what was generated
print(f"Generated {len(sequences)} sequences")
print(f"Example sequence: {sequences[0]}")
print(f"Example label: {labels[0]}")
# bincount reports how many samples carry each label value (index = label)
print(f"Label distribution: {torch.bincount(labels)}")

# Pad sequences to same length
def collate_batch(sequences, labels, pad_token=0):
    """Right-pad ``sequences`` to a common length and pair them with ``labels``.

    Returns a ``(batch, max_len)`` tensor filled with ``pad_token`` past the
    end of each sequence, together with ``labels`` passed through unchanged.
    """
    batch = pad_sequence(
        sequences,
        batch_first=True,
        padding_value=pad_token,
    )
    return batch, labels

# 80/20 train/validation split, taken in generation order (no shuffling)
train_size = int(0.8 * len(sequences))
train_sequences, val_sequences = sequences[:train_size], sequences[train_size:]
train_labels, val_labels = labels[:train_size], labels[train_size:]

# Each split is padded on its own, so the two tensors may differ in width
train_X, train_y = collate_batch(train_sequences, train_labels)
val_X, val_y = collate_batch(val_sequences, val_labels)

print(f"\nTrain data shape: {train_X.shape}, {train_y.shape}")
print(f"Val data shape: {val_X.shape}, {val_y.shape}")

# Loss for the two-class problem; it takes raw logits and integer class labels
criterion = nn.CrossEntropyLoss()

# Classifier hyper-parameters collected in one place
classifier_kwargs = dict(
    vocab_size=vocab_size,
    embed_dim=32,
    hidden_dim=64,
    output_dim=2,  # Binary classification
    n_layers=2,
)
model = SequenceClassifier(**classifier_kwargs)

# The optimizer must be created after the model so it tracks its parameters
optimizer = optim.Adam(model.parameters(), lr=0.001)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Training loop
def train_model(model, train_X, train_y, val_X, val_y, epochs=30,
                optimizer=None, criterion=None):
    """Train ``model`` with full-batch gradient steps and record metrics.

    Every epoch performs exactly one optimizer step on the whole training
    set, followed by one evaluation pass over the validation set. Progress is
    printed every fifth epoch.

    Args:
        model: Module mapping ``train_X`` / ``val_X`` to class logits.
        train_X: Training inputs.
        train_y: Integer class labels for ``train_X``.
        val_X: Validation inputs.
        val_y: Integer class labels for ``val_X``.
        epochs: Number of full-batch updates.
        optimizer: Optimizer bound to ``model``'s parameters. Defaults to a
            new ``Adam(model.parameters(), lr=0.001)``. Taking it as an
            argument (instead of reading a module-level global) guarantees
            that the parameters being stepped belong to ``model``.
        criterion: Loss called as ``criterion(logits, labels)``. Defaults to
            ``nn.CrossEntropyLoss()``.

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs)``: four lists with
        one Python float per epoch.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(epochs):
        # Training
        model.train()
        optimizer.zero_grad()

        outputs = model(train_X)
        loss = criterion(outputs, train_y)
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Training accuracy comes from the same training-mode forward pass
        # (dropout active, weights from before this step).
        with torch.no_grad():
            predicted = outputs.argmax(dim=1)
            train_acc = (predicted == train_y).sum().item() / len(train_y)

        # Validation
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_X)
            val_loss = criterion(val_outputs, val_y)

            val_predicted = val_outputs.argmax(dim=1)
            val_acc = (val_predicted == val_y).sum().item() / len(val_y)

        train_losses.append(loss.item())
        val_losses.append(val_loss.item())
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        if epoch % 5 == 0:
            print(f"Epoch {epoch:2d}: Train Loss: {loss.item():.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss.item():.4f}, Val Acc: {val_acc:.4f}")

    return train_losses, val_losses, train_accs, val_accs

# Run 30 full-batch epochs and keep the per-epoch metric curves
history = train_model(model, train_X, train_y, val_X, val_y, epochs=30)
train_losses, val_losses, train_accs, val_accs = history

# The last entry is the validation accuracy after the final epoch
print(f"\nFinal validation accuracy: {val_accs[-1]:.4f}")

## Variable Length Sequence Handling

In [None]:
# Efficient variable length sequence processing
def create_variable_length_batch():
    """Return five fixed token sequences of unequal length plus their lengths.

    The sequences are 1-D integer tensors of lengths 8, 4, 6, 2 and 5; the
    second return value lists those lengths in the same order.
    """
    token_rows = (
        [1, 2, 3, 4, 5, 6, 7, 8],     # length 8
        [10, 11, 12, 13],             # length 4
        [20, 21, 22, 23, 24, 25],     # length 6
        [30, 31],                     # length 2
        [40, 41, 42, 43, 44],         # length 5
    )
    sequences = [torch.tensor(row) for row in token_rows]
    lengths = list(map(len, sequences))
    return sequences, lengths

sequences, lengths = create_variable_length_batch()

# Show the raw, unpadded input first
print("Original sequences:")
for i, seq in enumerate(sequences):
    length = lengths[i]
    print(f"Seq {i}: {seq.tolist()} (length: {length})")

# Method 1: pad everything to the longest sequence (simple, but wasteful)
print("\n=== Method 1: Simple Padding ===")
padded_sequences = pad_sequence(sequences, padding_value=0, batch_first=True)
print(f"Padded sequences shape: {padded_sequences.shape}")
print("Padded sequences:")
print(padded_sequences.numpy())

# Method 2: pack the batch so the LSTM skips padded steps (efficient)
print("\n=== Method 2: Packed Sequences ===")

class EfficientLSTM(nn.Module):
    """Single-layer LSTM classifier that skips padded time steps.

    The padded batch is packed before it reaches the LSTM, and the
    classification is made from each sequence's output at its last valid
    time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, sequences, lengths):
        """Classify a padded batch.

        Args:
            sequences: Padded input of shape ``(batch, max_len, input_size)``.
            lengths: True length of every sequence in the batch, in batch
                order (no sorting required).

        Returns:
            Logits of shape ``(batch, output_size)``.
        """
        # Pack so the LSTM only iterates over real time steps
        packed_input = pack_padded_sequence(
            sequences, lengths, batch_first=True, enforce_sorted=False
        )
        packed_states, _ = self.lstm(packed_input)

        # Back to a padded (batch, max_len, hidden) layout
        padded_states, valid_lengths = pad_packed_sequence(packed_states, batch_first=True)

        # Pick, for every row, the state at its last valid time step
        row_index = torch.arange(padded_states.size(0))
        final_states = padded_states[row_index, valid_lengths - 1]

        # Final classification
        return self.fc(final_states)

# Create embeddings for our sequences (treat as token indices)
# The demo sequences above only contain ids up to 44, so a 50-row table suffices
vocab_size = 50
embed_dim = 16
embedding = nn.Embedding(vocab_size, embed_dim)

# Convert sequences to embeddings
# Each sequence is embedded on its own, then the results are padded to the longest one
embedded_sequences = [embedding(seq) for seq in sequences]
embedded_padded = pad_sequence(embedded_sequences, batch_first=True)

print(f"Embedded padded shape: {embedded_padded.shape}")  # (batch, max_seq_len, embed_dim)

# Test efficient LSTM
# Arguments are (input_size, hidden_size, output_size); `lengths` holds the true lengths
efficient_model = EfficientLSTM(embed_dim, 32, 2)
output = efficient_model(embedded_padded, lengths)

print(f"Output shape: {output.shape}")  # (batch_size, output_size)
print("Efficient processing completed - no computation wasted on padding!")

# Show the difference in computation
# Padded cost counts every slot of the (batch, max_len) grid; the actual cost counts
# only real tokens. The "gain" is the fraction of grid slots that are pure padding.
print(f"\nTotal padded length: {padded_sequences.shape[0] * padded_sequences.shape[1]}")
print(f"Total actual length: {sum(lengths)}")
print(f"Efficiency gain: {(1 - sum(lengths)/(padded_sequences.shape[0] * padded_sequences.shape[1]))*100:.1f}% less computation")

## RNN vs LSTM vs GRU Comparison

In [None]:
# Compare different RNN architectures
class RNNComparison(nn.Module):
    """Recurrent classifier whose recurrent core is selectable by name.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Width of the recurrent hidden state.
        output_size: Number of output logits.
        rnn_type: One of ``'RNN'``, ``'LSTM'`` or ``'GRU'``.

    Raises:
        ValueError: If ``rnn_type`` is not one of the supported names.
    """

    # Maps the public rnn_type names to the PyTorch module that implements them
    _RNN_CLASSES = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}

    def __init__(self, input_size, hidden_size, output_size, rnn_type='LSTM'):
        super().__init__()

        # Fail at construction time; silently skipping the recurrent layer
        # would only surface later as an AttributeError inside forward().
        if rnn_type not in self._RNN_CLASSES:
            supported = ", ".join(repr(name) for name in self._RNN_CLASSES)
            raise ValueError(
                f"Unsupported rnn_type {rnn_type!r}; expected one of {supported}"
            )

        self.rnn_type = rnn_type
        self.rnn = self._RNN_CLASSES[rnn_type](input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return logits computed from the final hidden state.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)``.

        Returns:
            Logits of shape ``(batch, output_size)``.
        """
        _, state = self.rnn(x)
        # nn.LSTM returns (h_n, c_n); nn.RNN and nn.GRU return h_n directly
        hidden = state[0] if self.rnn_type == 'LSTM' else state
        # hidden is (num_layers, batch, hidden_size); classify from the last layer
        return self.fc(hidden[-1])

# Create test models
input_size, hidden_size, output_size = 32, 64, 2

models = {
    'RNN': RNNComparison(input_size, hidden_size, output_size, 'RNN'),
    'LSTM': RNNComparison(input_size, hidden_size, output_size, 'LSTM'),
    'GRU': RNNComparison(input_size, hidden_size, output_size, 'GRU')
}

# Compare parameter counts
# The loop variable is deliberately NOT called `model`: at notebook top level
# that would overwrite the trained SequenceClassifier bound to `model`.
print("=== Parameter Comparison ===")
for name, net in models.items():
    param_count = sum(p.numel() for p in net.parameters())
    print(f"{name}: {param_count:,} parameters")

# Forward-pass smoke test on a long sequence (this checks shapes only; it does
# not by itself demonstrate vanishing gradients)
seq_len = 100
batch_size = 16
x = torch.randn(batch_size, seq_len, input_size)

print("\n=== Forward Pass Test ===")
print(f"Input shape: {x.shape}")

for name, net in models.items():
    output = net(x)
    print(f"{name} output shape: {output.shape}")

# Gradient analysis
print("\n=== Gradient Flow Analysis ===")

# Create a simple task: remember the first input value
def create_memory_task(seq_len=50, batch_size=32):
    """Build a toy task whose label depends only on the first time step.

    Returns ``(x, y)`` where ``x`` is standard-normal noise of shape
    ``(batch_size, seq_len, 1)`` and ``y`` is an int64 tensor of shape
    ``(batch_size,)`` holding 1 where the first value of a sequence is
    positive and 0 otherwise.
    """
    inputs = torch.randn(batch_size, seq_len, 1)
    first_step = inputs[:, 0, 0]
    # Class 1 <=> the very first value is strictly positive
    targets = first_step.gt(0).long()
    return inputs, targets

# Test gradient flow
x, y = create_memory_task(seq_len=50, batch_size=32)
criterion = nn.CrossEntropyLoss()

# The memory task feeds ONE feature per time step, whereas the networks in
# `models` were built for 32 input features and would reject this input with
# a size-mismatch error. Build a matching set of networks for this task.
memory_input_size = x.size(-1)
memory_models = {
    name: RNNComparison(memory_input_size, hidden_size, output_size, name)
    for name in models
}

for name, net in memory_models.items():
    # Drop any stale gradients so re-running the cell starts from a clean slate
    for param in net.parameters():
        param.grad = None

    # Forward and backward
    output = net(x)
    loss = criterion(output, y)
    loss.backward()

    # Global L2 norm of the gradient: sqrt of the summed squared per-parameter norms
    squared_norms = [
        param.grad.norm(2).item() ** 2
        for param in net.parameters()
        if param.grad is not None
    ]
    total_norm = sum(squared_norms) ** 0.5

    print(f"{name}: Loss = {loss.item():.4f}, Gradient norm = {total_norm:.4f}")

# Closing cheat-sheet, printed one line at a time in the order listed
summary_lines = (
    "\n=== Architecture Summary ===",
    "RNN: Simple, fast, but suffers from vanishing gradients",
    "LSTM: Complex gating, best for long sequences, most parameters",
    "GRU: Simplified gating, good compromise between RNN and LSTM",
    "\nRule of thumb:",
    "- Short sequences (<20): RNN might be sufficient",
    "- Long sequences (>50): LSTM or GRU",
    "- When in doubt: try GRU first (good balance of performance/complexity)",
)
for line in summary_lines:
    print(line)