# Practice 1: RNN Fundamentals - The Vanishing Gradient Problem

## SECTION 1: Setup and Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
import time

# Pin the NumPy and PyTorch random generators so repeated runs are comparable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.manual_seed(SEED)  # seed the CUDA generator as well
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

## SECTION 2: Data Loading and Preprocessing

In [None]:
print("Loading IMDB dataset...")

# TODO: Load the IMDB dataset
# NOTE: the statements below read df['sentiment'] and df['review'], so the
# loaded frame has to expose both columns.
df = ...

# Quick sanity check of the raw data: row count and per-class counts
print(f"Dataset loaded: {len(df)} reviews")
print(f"Sentiment distribution:\n{df['sentiment'].value_counts()}")

# TODO: Convert sentiment labels to binary
# NOTE: later code wraps each label in a float32 tensor (IMDBDataset) and
# tests `label == 1` when coloring plots, so 1 is presumably the positive
# class and 0 the negative one - confirm against the chosen encoding.
df['label'] = ...

# TODO: Use 30,000 samples
# NOTE: the train/val/test split further down slices rows by position, so the
# row order left by this step decides which reviews land in which split.
df = ...

# Simple text preprocessing
# `re` is used by preprocess_text below; `Counter` is presumably meant for the
# vocabulary-building step (it is not used by any visible code yet).
import re
from collections import Counter

def preprocess_text(text):
    """Normalize a raw review string before tokenization.

    Lower-cases the input, runs two cleanup steps that are still exercise
    placeholders (HTML-tag removal and letter-only filtering), then collapses
    every run of whitespace into a single space and strips both ends.

    Args:
        text: Raw review text; must provide ``lower()``.

    Returns:
        The cleaned text.
    """
    text = text.lower()
    # TODO
    text = ... # Remove HTML tags
    text = ...  # Keep only letters
    # NOTE: until the two placeholders above are filled in, `text` is the
    # Ellipsis object at this point and re.sub() below cannot process it.
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# Run the cleanup on every review and keep the result in its own column
df['cleaned_review'] = df['review'].apply(preprocess_text)

# Build vocabulary
print("Building vocabulary...")

# TODO: Tokenize reviews and build vocabulary
# NOTE: the next step iterates `most_common` as (word, count) pairs, so this
# TODO has to define that name in addition to filling `all_words`.
all_words = []
...

# Indices 0 and 1 are reserved for the padding and unknown-word tokens; the
# entries of `most_common` are numbered from 2 upwards in iteration order.
word_to_idx = {'<PAD>': 0, '<UNK>': 1}
word_to_idx.update(
    (word, idx) for idx, (word, _count) in enumerate(most_common, start=2)
)

vocab_size = len(word_to_idx)
print(f"Vocabulary size: {vocab_size}")

# TODO: Tokenize and convert to indices
def text_to_indices(text, word_to_idx, max_len=None):
    """Convert text to sequence of indices.

    The text is split on whitespace; mapping the resulting words through
    `word_to_idx` and honouring `max_len` are still exercise placeholders.

    Args:
        text: Preprocessed review text.
        word_to_idx: Mapping from word to integer index (built with '<PAD>'
            at 0 and '<UNK>' at 1 elsewhere in this file).
        max_len: Optional length limit. Presumably the sequence should be
            truncated/padded to this length - confirm when implementing.

    Returns:
        The index sequence. IMDBDataset passes it straight to
        ``torch.LongTensor``, so it needs to be a sequence of integers.
    """
    words = text.split()
    # Placeholder: expected to map each entry of `words` to its index
    indices = ...
    # Placeholder: expected to apply `max_len` (see docstring)
    ...
    return indices

# Split data by row position: first 80% train, next 10% validation, rest test
train_size, val_size = (int(share * len(df)) for share in (0.8, 0.1))

train_df = df[:train_size]
val_df = df[train_size:][:val_size]
test_df = df[train_size + val_size:]

print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")


## SECTION 3: Dataset Class and DataLoaders

In [None]:
class IMDBDataset(Dataset):
    """Map-style dataset yielding ``(token_indices, label)`` tensor pairs.

    Args:
        texts: Iterable of preprocessed review strings.
        labels: Iterable of numeric labels, aligned with `texts`.
        word_to_idx: Word-to-index mapping forwarded to ``text_to_indices``.
        max_len: Sequence length forwarded to ``text_to_indices``.
    """

    def __init__(self, texts, labels, word_to_idx, max_len):
        # Materialize both inputs so __getitem__ always indexes by position.
        # A pandas Series taken from a sliced DataFrame (such as val_df or
        # test_df) keeps its original index labels, so series[0] would look
        # up the *label* 0 and raise KeyError instead of returning the
        # first row.
        self.texts = list(texts)
        self.labels = list(labels)
        self.word_to_idx = word_to_idx
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the idx-th sample as (LongTensor indices, float32 label)."""
        text = self.texts[idx]
        label = self.labels[idx]
        indices = text_to_indices(text, self.word_to_idx, self.max_len)
        # The label is stored as float32 to match the float model output
        return torch.LongTensor(indices), torch.tensor(label, dtype=torch.float32)

# TODO
def create_dataloaders(train_df, val_df, test_df, word_to_idx, max_len, batch_size=32):
    """Create dataloaders for a specific sequence length.

    Every assignment below is still an exercise placeholder.

    Args:
        train_df, val_df, test_df: Data splits produced earlier in this file.
        word_to_idx: Word-to-index mapping.
        max_len: Sequence length the datasets should use.
        batch_size: Batch size for the loaders (default 32).

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # Presumably IMDBDataset instances built from each split - confirm.
    # NOTE(review): val_df/test_df are row slices and keep their original
    # index labels, while IMDBDataset reads items with `self.texts[idx]`;
    # check that positional containers are passed in.
    train_dataset = ...
    val_dataset = ...
    test_dataset = ...

    # Presumably DataLoader objects wrapping the datasets above - confirm.
    train_loader = ...
    val_loader = ...
    test_loader = ...

    return train_loader, val_loader, test_loader


## SECTION 4: Model Definitions

In [None]:
class SimpleRNNClassifier(nn.Module):
    """Binary classifier template built around a simple RNN layer.

    The layer definitions are still exercise placeholders; the visible
    contract is token indices in, one sigmoid output per sequence out.

    Args:
        vocab_size: Vocabulary size (intended for the embedding layer).
        embedding_dim: Embedding size (default 128).
        hidden_dim: Hidden-state size, stored as ``self.hidden_dim``.
        num_layers: Number of recurrent layers (default 1).
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=128, num_layers=1):
        super().__init__()
        # TODO
        self.embedding = ...
        self.hidden_dim = hidden_dim

        # Define Simple RNN layer
        self.rnn = ...

        self.fc = ...
        self.sigmoid = ...

    def forward(self, x):
        """Return one sigmoid output per input sequence.

        Args:
            x: Token-index tensor of shape (batch_size, seq_len).
        """
        # x: (batch_size, seq_len)
        embedded = ...  # (batch_size, seq_len, embedding_dim)

        # TODO: Pass through RNN and get output
        ...

        # TODO: Use the last timestep's output
        ...

        out = self.fc(rnn_out)
        out = self.sigmoid(out)
        # Drop only the trailing dimension. A bare squeeze() would also drop
        # the batch dimension for a batch of one, producing a 0-d tensor whose
        # shape no longer matches the (1,)-shaped target.
        # Assumes self.fc yields (batch_size, 1) - confirm once the TODOs are done.
        return out.squeeze(-1)


class LSTMClassifier(nn.Module):
    """Binary classifier template built around an LSTM layer.

    The layer definitions are still exercise placeholders; the visible
    contract is token indices in, one sigmoid output per sequence out.

    Args:
        vocab_size: Vocabulary size (intended for the embedding layer).
        embedding_dim: Embedding size (default 128).
        hidden_dim: Hidden-state size, stored as ``self.hidden_dim``.
        num_layers: Number of recurrent layers (default 1).
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=128, num_layers=1):
        super().__init__()
        # TODO
        self.embedding = ...
        self.hidden_dim = hidden_dim

        # TODO: Define LSTM layer
        self.lstm = ...

        self.fc = ...
        self.sigmoid = ...

    def forward(self, x):
        """Return one sigmoid output per input sequence.

        Args:
            x: Token-index tensor of shape (batch_size, seq_len).
        """
        # x: (batch_size, seq_len)
        embedded = ...

        # TODO: Pass through LSTM and get output
        ...

        # Use the last timestep's output
        ...

        out = self.fc(lstm_out)
        out = self.sigmoid(out)
        # Drop only the trailing dimension. A bare squeeze() would also drop
        # the batch dimension for a batch of one, producing a 0-d tensor whose
        # shape no longer matches the (1,)-shaped target.
        # Assumes self.fc yields (batch_size, 1) - confirm once the TODOs are done.
        return out.squeeze(-1)

## SECTION 5: Training and Evaluation Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train for one epoch.

    Puts the model in training mode and iterates over `train_loader` with a
    progress bar. The per-batch steps are still exercise placeholders; they
    must define `loss` and `output`, which the bookkeeping below reads.

    Args:
        model: Module to train.
        train_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        criterion: Loss function.
        optimizer: Optimizer used for the update step.
        device: Device the batches should be moved to.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``;
        ``(0.0, 0.0)`` when the loader produced no samples.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in tqdm(train_loader, desc="Training"):
        # TODO
        data, target = ...

        # TODO: Complete the training steps
        ...

        total_loss += loss.item()
        # Outputs are probabilities; 0.5 is the decision threshold
        predicted = (output > 0.5).float()
        total += target.size(0)
        correct += (predicted == target).sum().item()

    # Nothing was processed: avoid dividing by zero below
    if total == 0:
        return 0.0, 0.0
    return total_loss / len(train_loader), 100. * correct / total


def evaluate(model, data_loader, criterion, device):
    """Evaluate model without tracking gradients.

    Puts the model in eval mode and iterates over `data_loader` inside
    ``torch.no_grad()``. The per-batch steps are still exercise placeholders;
    they must define `loss` and `output`, which the bookkeeping below reads.

    Args:
        model: Module to evaluate.
        data_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        criterion: Loss function.
        device: Device the batches should be moved to.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``;
        ``(0.0, 0.0)`` when the loader produced no samples.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in data_loader:
            # TODO:
            data, target = ...
            ...

            total_loss += loss.item()
            # Outputs are probabilities; 0.5 is the decision threshold
            predicted = (output > 0.5).float()
            total += target.size(0)
            correct += (predicted == target).sum().item()

    # Nothing was processed: avoid dividing by zero below
    if total == 0:
        return 0.0, 0.0
    return total_loss / len(data_loader), 100. * correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs, device):
    """Run `num_epochs` rounds of training, each followed by validation.

    Every round calls train_epoch() on `train_loader`, then evaluate() on
    `val_loader`, and prints the loss and accuracy of both.

    Returns:
        Tuple ``(train_losses, train_accs, val_losses, val_accs)`` of lists
        holding one entry per epoch.
    """
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        # Record this epoch's four metrics
        epoch_metrics = (
            ('train_loss', train_loss),
            ('train_acc', train_acc),
            ('val_loss', val_loss),
            ('val_acc', val_acc),
        )
        for metric_name, metric_value in epoch_metrics:
            history[metric_name].append(metric_value)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    return (history['train_loss'], history['train_acc'],
            history['val_loss'], history['val_acc'])


## EXPERIMENT A: Sequence Length Impact Analysis

In [None]:
print("\n" + "="*80)
print("EXPERIMENT A: Sequence Length Impact Analysis")
print("="*80)

# Sequence lengths to test
sequence_lengths = [50, 100, 200, 400]
# One entry is appended to each list per sequence length, so the three lists
# stay index-aligned for the plot and table that follow the loop.
results = {
    'sequence_length': [],
    'simple_rnn_acc': [],
    'lstm_acc': []
}

# Hyperparameters
EMBEDDING_DIM = 128
HIDDEN_DIM = 128
NUM_EPOCHS = 10
BATCH_SIZE = 64
LEARNING_RATE = 0.001

# Train models for each sequence length
for seq_len in sequence_lengths:
    print(f"\n{'='*80}")
    print(f"Training models with sequence length: {seq_len}")
    print(f"{'='*80}")

    # TODO: Create dataloaders
    ...

    # TODO: Train SimpleRNN
    print(f"\nTraining SimpleRNN (seq_len={seq_len})...")
    simple_rnn = ...
    # NOTE: `criterion` is assigned only here; the LSTM section below assigns
    # a new optimizer but no new criterion.
    criterion = ...
    optimizer = ...

    # Wall-clock time of whatever the placeholder between the two time.time()
    # calls ends up running
    start_time = time.time()
    ...
    rnn_train_time = time.time() - start_time

    # TODO: Evaluate SimpleRNN on test set
    # NOTE: this step must define `rnn_test_acc`, which is printed and stored below.
    ...
    print(f"SimpleRNN Test Accuracy: {rnn_test_acc:.2f}%")
    print(f"Training time: {rnn_train_time:.2f}s")

    # TODO: Train LSTM
    print(f"\nTraining LSTM (seq_len={seq_len})...")
    lstm_model = ...
    optimizer = ...

    start_time = time.time()
    ...
    lstm_train_time = time.time() - start_time

    # TODO: Evaluate LSTM on test set
    # NOTE: this step must define `lstm_test_acc`, which is printed and stored below.
    ...
    print(f"LSTM Test Accuracy: {lstm_test_acc:.2f}%")
    print(f"Training time: {lstm_train_time:.2f}s")

    # Store results
    results['sequence_length'].append(seq_len)
    results['simple_rnn_acc'].append(rnn_test_acc)
    results['lstm_acc'].append(lstm_test_acc)

# Accuracy-vs-length figure: one line per model, saved to disk and then shown
plt.figure(figsize=(10, 6))
accuracy_axes = plt.gca()

for result_key, marker_shape, legend_name in (
    ('simple_rnn_acc', 'o', 'Simple RNN'),
    ('lstm_acc', 's', 'LSTM'),
):
    accuracy_axes.plot(results['sequence_length'], results[result_key],
                       marker=marker_shape, linewidth=2, markersize=8,
                       label=legend_name)

accuracy_axes.set_xlabel('Sequence Length (tokens)', fontsize=12)
accuracy_axes.set_ylabel('Test Accuracy (%)', fontsize=12)
accuracy_axes.set_title('Model Performance vs Sequence Length',
                        fontsize=14, fontweight='bold')
accuracy_axes.legend(fontsize=11)
accuracy_axes.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('experiment_a_accuracy_vs_length.png', dpi=300, bbox_inches='tight')
plt.show()

# Print results table: one row per tested sequence length
print("\nResults Summary:")
print("-" * 60)
print(f"{'Seq Length':<15} {'SimpleRNN Acc':<20} {'LSTM Acc':<20}")
print("-" * 60)
# The three result lists are appended to together, so zip walks them row by row
for length, rnn_acc, lstm_acc in zip(results['sequence_length'],
                                     results['simple_rnn_acc'],
                                     results['lstm_acc']):
    print(f"{length:<15} {rnn_acc:<20.2f} {lstm_acc:<20.2f}")


## EXPERIMENT B: Gradient Flow Measurement

In [None]:
print("\n" + "="*80)
print("EXPERIMENT B: Gradient Flow Measurement")
print("="*80)

# We'll use sequence length 200 for this experiment
# NOTE: the analysis further down reads timesteps 10, 100 and 190 and array
# position 19, so SEQ_LEN must stay at 190 or above for that code to work.
SEQ_LEN = 200

# TODO: Create dataloaders
# NOTE: `val_loader` is read after measure_gradient_flow is defined, so this
# step has to define it.
...

# TODO: Initialize fresh models
...

# TODO: Train models
print("Training SimpleRNN for gradient analysis...")
...

print("\nTraining LSTM for gradient analysis...")
...

# Measure gradients w.r.t. embeddings at each timestep
def measure_gradient_flow(model, data, target, seq_len):
    """
    Measure gradient magnitude at each timestep by computing gradients
    with respect to the embedding at each position.

    This properly captures the vanishing gradient problem by showing
    how gradient signal decays as it propagates backwards through time.

    Args:
        model: Classifier exposing ``fc`` and ``sigmoid`` (both used below).
        data: Batch of inputs; moved to the module-level `device`.
        target: Batch of targets; moved to the module-level `device`.
        seq_len: Length of the zero array returned when no gradient was
            recorded on the embeddings.

    Returns:
        Per-timestep gradient magnitudes (to be produced by the final TODO),
        or ``np.zeros(seq_len)`` when ``embedded.grad`` is None.
    """
    # NOTE: the model is switched to training mode and not switched back in
    # this function. Presumably needed for the recurrent backward pass - confirm.
    model.train()
    data = data.to(device)
    target = target.to(device)

    # Get embeddings with gradient tracking
    embedded = ...  # (batch_size, seq_len, embedding_dim)
    embedded.retain_grad()  # Important: retain gradients for intermediate tensor

    # TODO: Forward pass through RNN/LSTM
    ...

    # TODO: Use last timestep for prediction (as in original model)
    # NOTE: this step must define `last_output`, which is fed to model.fc below.
    ...
    out = model.fc(last_output)
    out = model.sigmoid(out)

    # TODO: Compute loss
    loss = ...

    # TODO: Backward pass
    ...

    # Extract gradients at each timestep position
    if embedded.grad is not None:
        # TODO: Compute L2 norm at each timestep, averaged over batch and embedding dimension
        # NOTE: this branch needs to return its result; as written it falls
        # through and the function returns None.
        ...
    else:
        return np.zeros(seq_len)


# Get a batch from validation set
# Only the first batch yielded by val_loader is used for the measurement
val_iter = iter(val_loader)
data_batch, target_batch = next(val_iter)

# TODO
# NOTE: the plots below draw rnn_grads/lstm_grads against SEQ_LEN timesteps,
# so both results need one value per timestep.
print("\nCapturing gradients for SimpleRNN...")
rnn_grads = ...

# TODO
print("Capturing gradients for LSTM...")
lstm_grads = ...

# VISUALIZATION
# Two panels of the same curves: linear y-axis on the left, log y-axis on the right
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
ax1, ax2 = axes
timesteps = np.arange(1, SEQ_LEN + 1)

for panel, panel_title, grid_lines in (
    (ax1, 'Gradient Flow Through Time - Linear Scale', 'major'),
    (ax2, 'Gradient Flow Through Time - Log Scale', 'both'),
):
    panel.plot(timesteps, rnn_grads, linewidth=2, label='Simple RNN', alpha=0.8, color='#e74c3c')
    panel.plot(timesteps, lstm_grads, linewidth=2, label='LSTM', alpha=0.8, color='#3498db')

    panel.set_xlabel('Timestep (earlier ← → later)', fontsize=12)
    panel.set_ylabel('Gradient Magnitude (L2 Norm)', fontsize=12)
    panel.set_title(panel_title, fontsize=14, fontweight='bold')
    if panel is ax2:
        # Log scale (better for seeing vanishing gradients)
        panel.set_yscale('log')
    panel.legend(fontsize=11, loc='upper right')
    panel.grid(True, alpha=0.3, which=grid_lines)

# Arrow on the linear panel pointing at timestep 20 of the SimpleRNN curve
ax1.annotate('Early timesteps\n(vanishing gradients)',
             xy=(20, rnn_grads[19]), xytext=(50, rnn_grads[19] * 3),
             arrowprops=dict(arrowstyle='->', color='red', lw=1.5),
             fontsize=10, color='red')

plt.tight_layout()
plt.savefig('experiment_b_gradient_flow_corrected.png', dpi=300, bbox_inches='tight')
plt.show()

# ANALYSIS AND STATISTICS
print("\n" + "="*80)
print("GRADIENT FLOW ANALYSIS")
print("="*80)

# 1-based timesteps probed below; arrays are indexed with timestep - 1
early_timestep = 10
mid_timestep = 100
late_timestep = 190

print(f"\n1. GRADIENT MAGNITUDES AT KEY TIMESTEPS:")
print(f"   {'Timestep':<15} {'SimpleRNN':<20} {'LSTM':<20}")
print(f"   {'-'*55}")
for step in (early_timestep, mid_timestep, late_timestep):
    print(f"   {f't={step}':<15} {rnn_grads[step-1]:<20.8f} {lstm_grads[step-1]:<20.8f}")

# Ratio of the early gradient to the late one; smaller means stronger decay
rnn_early_to_late, lstm_early_to_late = (
    grads[early_timestep-1] / grads[late_timestep-1]
    for grads in (rnn_grads, lstm_grads)
)

if rnn_early_to_late < 0.01:
    rnn_verdict = 'SEVERE VANISHING'
else:
    rnn_verdict = 'Moderate vanishing'

if lstm_early_to_late > 0.1:
    lstm_verdict = 'Good gradient flow'
else:
    lstm_verdict = 'Some vanishing'

print(f"\n2. GRADIENT DECAY (Early timestep / Late timestep):")
print(f"   SimpleRNN: {rnn_early_to_late:.6f}x  ({rnn_verdict})")
print(f"   LSTM:      {lstm_early_to_late:.6f}x  ({lstm_verdict})")

## EXPERIMENT C: Hidden State Dynamics Analysis

In [None]:
print("\n" + "="*80)
print("EXPERIMENT C: Hidden State Dynamics Analysis")
print("="*80)

# TODO: Select test sequences (5 positive, 5 negative, all at least 150 tokens)
test_samples_filtered = ...
# NOTE: this placeholder must define `positive_samples` and `negative_samples`,
# which are concatenated right below.
...
# Positives come first in the combined frame, negatives after them
test_samples = pd.concat([positive_samples, negative_samples])

print(f"Selected {len(test_samples)} test sequences for analysis")

# Extract hidden states during inference
def extract_hidden_states(model, texts, labels, word_to_idx, model_type='rnn'):
    """Extract hidden states at each timestep.

    Runs the model in eval mode, without gradient tracking, on one sequence
    at a time (batch of one). Several steps are still exercise placeholders.

    Args:
        model: Trained classifier; the LSTM branch calls ``model.lstm``.
        texts: Iterable of review texts.
        labels: Iterable of labels, zipped with `texts`.
        word_to_idx: Word-to-index mapping for converting text to indices.
        model_type: ``'rnn'`` selects the first branch; any other value takes
            the LSTM branch.

    Returns:
        Tuple ``(all_hidden_norms, all_cell_norms, all_labels)`` of lists.
        `all_cell_norms` only receives an entry when `cell_states` is not None.
    """
    model.eval()

    all_hidden_norms = []
    all_cell_norms = []
    all_labels = []

    with torch.no_grad():
        for text, label in zip(texts, labels):
            # TODO: Convert text to indices (don't pad, use natural length)
            indices = ...
            # Wrap in a list so the tensor carries a leading batch dimension of 1
            x = torch.LongTensor([indices]).to(device)

            # TODO: Get embeddings
            if model_type == 'rnn':
                embedded = ...
                ...
                hidden_states = ...
                # NOTE: the check further down skips cell norms only when this
                # is None, so presumably None is the intended value here.
                cell_states = ...
            else:  # lstm
                embedded = ...
                ...
                hidden_states = ...

                # To get all cell states, we need to process step by step
                # Reinitialize and process manually
                all_cells = []
                # TODO: Initialize with correct number of layers
                num_layers = ...
                h_t = torch.zeros(...).to(device)
                c_t = torch.zeros(...).to(device)

                # Feed one timestep at a time, carrying (h_t, c_t) forward,
                # so the cell state after every step can be recorded
                for t in range(embedded.size(1)):
                    out, (h_t, c_t) = model.lstm(embedded[:, t:t+1, :], (h_t, c_t))
                    # TODO: Store only the last layer's cell state
                    all_cells.append(...)

                cell_states = np.array(all_cells)

            # TODO: Compute L2 norms at each timestep
            # NOTE: this step must define `hidden_norms`, appended right below.
            ...
            all_hidden_norms.append(hidden_norms)

            if cell_states is not None:
                # axis=1 takes the norm across each row of cell_states
                cell_norms = np.linalg.norm(cell_states, axis=1)
                all_cell_norms.append(cell_norms)

            all_labels.append(label)

    return all_hidden_norms, all_cell_norms, all_labels

# TODO
# NOTE: the plots below read `rnn_hidden_norms` and `rnn_labels`, so this
# step has to define both.
print("Extracting hidden states from SimpleRNN...")
...

# TODO
# NOTE: the plots below read `lstm_hidden_norms`, `lstm_cell_norms` and
# `lstm_labels`, so this step has to define all three.
print("Extracting hidden states from LSTM...")
...

# Create visualizations
def _plot_state_norms(ax, norm_traces, trace_labels, ylabel, title):
    """Draw one norm-vs-timestep line per sequence, colored by its label.

    Lines whose label equals 1 are blue ('Positive'); all others are red
    ('Negative'). Only the first line of each class gets a legend entry, so
    the legend always pairs each class name with that class's color.
    """
    classes_in_legend = set()
    for norms, label in zip(norm_traces, trace_labels):
        is_positive = label == 1
        label_text = 'Positive' if is_positive else 'Negative'
        # Labels starting with an underscore are left out of the legend
        legend_label = '_nolegend_' if label_text in classes_in_legend else label_text
        classes_in_legend.add(label_text)
        ax.plot(norms, color='blue' if is_positive else 'red',
                alpha=0.6, linewidth=1.5, label=legend_label)

    ax.set_xlabel('Timestep', fontsize=11)
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontsize=13, fontweight='bold')
    # Build the legend from the line labels set above instead of a fixed
    # list, which named the first two lines drawn regardless of their class.
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(3, 1, figsize=(14, 12))
ax1, ax2, ax3 = axes

# Plot 1: SimpleRNN Hidden States
_plot_state_norms(ax1, rnn_hidden_norms, rnn_labels,
                  'Hidden State Norm (L2)', 'SimpleRNN: Hidden State Dynamics')

# Plot 2: LSTM Hidden States
_plot_state_norms(ax2, lstm_hidden_norms, lstm_labels,
                  'Hidden State Norm (L2)', 'LSTM: Hidden State Dynamics')

# Plot 3: LSTM Cell States
_plot_state_norms(ax3, lstm_cell_norms, lstm_labels,
                  'Cell State Norm (L2)', 'LSTM: Cell State Dynamics')

plt.tight_layout()
plt.savefig('experiment_c_hidden_state_dynamics.png', dpi=300, bbox_inches='tight')
plt.show()

# Summarize how much each sequence's state norm varies over its timesteps
print("\nHidden State Statistics:")
print("-" * 60)

# One variance per analysed sequence
rnn_variances = list(map(np.var, rnn_hidden_norms))
lstm_h_variances = list(map(np.var, lstm_hidden_norms))
lstm_c_variances = list(map(np.var, lstm_cell_norms))

for row_name, variances in (
    ("SimpleRNN hidden state variance", rnn_variances),
    ("LSTM hidden state variance", lstm_h_variances),
    ("LSTM cell state variance", lstm_c_variances),
):
    print(f"{row_name}: Mean={np.mean(variances):.4f}, Std={np.std(variances):.4f}")