# Notebook 1: The Mouse in the Maze
## Understanding Sequential Processing with RNNs

---

**Learning Objectives:**
- Understand how RNNs process sequences step-by-step
- Visualize hidden state evolution through time
- Observe how memory degrades over long sequences
- Compare simple RNN vs LSTM performance
- Establish a baseline for comparison with transformers

**Prerequisites:** Basic understanding of neural networks

**Estimated Time:** 45-60 minutes

---

## The Metaphor: The Mouse in the Maze

Imagine you're a mouse trying to navigate a maze to find cheese:

- 🐭 **You can only see your immediate surroundings** - no bird's-eye view
- 🧠 **You must remember where you've been** - but your memory is limited
- 🚶 **You must process the maze step-by-step** - you can't teleport
- 📉 **The longer the path, the hazier your memory** - early turns fade away

This is exactly how RNNs work: **sequential processing with limited memory**.

### Why This Matters

Before transformers, RNNs and their gated variants such as LSTMs were the standard way to model sequences, and this step-by-step constraint is built into how they compute. Understanding *why* it's limiting will help you appreciate what transformers achieve.

---

## Setup: Import Our Tools

In [None]:
# Add src to path
import sys
sys.path.insert(0, '../src')

# Core imports
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Our modules
from maze_envs import generate_simple_maze, MazeDataset, MazeConfig, Maze
from visualizations import MazeVisualizer, TrainingVisualizer, set_style
from rnn_solver import create_simple_rnn, create_lstm, RNNTrainer, RNNMazeSolver

# Set consistent style
set_style()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✓ Setup complete!")

## Part 1: Understanding the Problem

### 1.1 Generate a Sample Maze

Let's create a simple maze and see what the "mouse" needs to solve.

In [None]:
# Generate a maze
maze = generate_simple_maze(size=15, seed=42)
solution = maze.solve()

print(f"Maze size: {maze.config.height}x{maze.config.width}")
print(f"Start: {maze.start}")
print(f"Goal: {maze.goal}")
print(f"Solution length: {len(solution)} steps")

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

MazeVisualizer.plot_maze(maze, ax=axes[0], title="Unsolved Maze")
MazeVisualizer.plot_maze(maze, ax=axes[1], show_solution=True, title="Optimal Solution")

plt.tight_layout()
plt.show()

print(f"\nOptimal path (first 10 moves): {maze.path_to_actions(solution)[:10]}")

### 1.2 The Sequential Challenge

The optimal path above is `len(solution)` steps long (printed by the previous cell). An RNN must:

1. **Start at S** with an initial hidden state
2. **Process each position** one at a time
3. **Update hidden state** after each step (trying to remember history)
4. **Predict the next move** based only on current position + hidden state
5. **Repeat** until reaching G
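
The five steps above form a simple loop. Here is a minimal, self-contained sketch of that loop in plain Python; it is not the `generate_path` method from `rnn_solver`, whose internals aren't shown. The `rollout` and `toy_step` names are hypothetical, the action ids mirror the `action_map` used later in `prepare_batch`, and treating UP as "row - 1" is an assumption about the maze's coordinate convention.

```python
from typing import Any, Callable, List, Tuple

Pos = Tuple[int, int]

# Action ids follow prepare_batch's action_map; the (row, col) deltas are an assumption
MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # UP, DOWN, LEFT, RIGHT


def rollout(step_fn: Callable[[Pos, Any], Tuple[int, Any]], h0: Any,
            start: Pos, goal: Pos, max_steps: int = 100) -> List[Pos]:
    """Greedy step-by-step rollout: each prediction sees only (current position, hidden state)."""
    h, pos, path = h0, start, [start]          # 1. start at S with an initial hidden state
    for _ in range(max_steps):
        action, h = step_fn(pos, h)            # 2-4. read position, update memory, pick a move
        dr, dc = MOVES[action]
        pos = (pos[0] + dr, pos[1] + dc)
        path.append(pos)
        if pos == goal:                        # 5. repeat until reaching G (or give up)
            break
    return path


def toy_step(pos: Pos, h: int) -> Tuple[int, int]:
    """Hypothetical stand-in for a trained model; its 'hidden state' is just a step counter."""
    action = 3 if pos[1] < 3 else 1            # RIGHT until column 3, then DOWN
    return action, h + 1


path = rollout(toy_step, h0=0, start=(0, 0), goal=(2, 3))
print(path)  # → [(0, 0), (0, 1), (0, 2), (0, 3), (1, 3), (2, 3)]
```

In the notebook, `step_fn` would be the trained network and `h` its hidden-state tensor; the point is that each prediction sees only the current position and whatever the hidden state has retained.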

The critical question: **Can the hidden state remember the path from 20 steps ago?**

---

## Part 2: The Math Behind RNNs

### 2.1 Simple RNN: The Basic Mouse

The fundamental RNN equation:

$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$

where:
- $h_t$ = hidden state at time $t$ (the "memory")
- $h_{t-1}$ = previous hidden state (what we remember)
- $x_t$ = current input (where we are now)
- $W_{hh}, W_{xh}, b_h$ = learned parameters

**Key insight**: $h_t$ depends ONLY on $h_{t-1}$ and $x_t$. To remember step 1 at step 20, information must pass through 19 successive applications of this update, and during training the gradient must flow back through the same 19 steps!
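
To make this 19-step chain concrete, here is a small NumPy sketch (illustrative random weights, not the trained model) that unrolls the equation above and tracks $\|\partial h_t / \partial h_1\|_2$, the norm of the product of per-step Jacobians $\mathrm{diag}(1 - h_k^2)\,W_{hh}$. Rescaling $W_{hh}$ to spectral norm 0.9 is an assumption chosen for the demo: since $\tanh' \le 1$, every factor then has norm at most 0.9, so the product is bounded by $0.9^{t-1}$ (about 0.135 at $t = 20$).

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, inp, T = 16, 4, 20                      # illustrative sizes

W = rng.normal(size=(hidden, hidden))
W_hh = W * (0.9 / np.linalg.norm(W, 2))          # rescale so the spectral norm of W_hh is 0.9
W_xh = rng.normal(scale=0.5, size=(hidden, inp))
b_h = np.zeros(hidden)
x_seq = rng.normal(size=(T, inp))

# Unroll h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h) from h_0 = 0
h = np.zeros(hidden)
states = []
for x_t in x_seq:
    h = np.tanh(W_hh @ h + W_xh @ x_t + b_h)
    states.append(h)
states = np.stack(states)                        # row t-1 holds h_t

# dh_t/dh_1 is a product of per-step Jacobians dh_k/dh_{k-1} = diag(1 - h_k^2) W_hh
J = np.eye(hidden)
norms = [np.linalg.norm(J, 2)]                   # t = 1: the identity
for h_k in states[1:]:
    J = (np.diag(1.0 - h_k**2) @ W_hh) @ J
    norms.append(np.linalg.norm(J, 2))

for t in (1, 5, 10, 20):
    print(f"||dh_{t}/dh_1||_2 = {norms[t - 1]:.2e}")
```

With a larger $W_{hh}$ the same product can grow instead of shrink (exploding gradients); either way, distant steps are coupled only through this long chain of multiplications.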

### 2.2 LSTM: The Mouse with Better Memory

LSTM adds memory "gates" to combat vanishing gradients:

$$
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) \quad \text{(forget gate)} \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) \quad \text{(input gate)} \\
o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o) \quad \text{(output gate)} \\
\tilde{C}_t &= \tanh(W_C [h_{t-1}, x_t] + b_C) \quad \text{(candidate cell)} \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \quad \text{(cell state)} \\
h_t &= o_t \odot \tanh(C_t) \quad \text{(hidden state)}
\end{aligned}
$$

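**Key insight**: The cell state $C_t$ is updated additively ($C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$) instead of being squashed through a $\tanh$ at every step. When the forget gate $f_t$ stays close to 1, information and gradients can travel along this "highway" with little decay, so memory persists longer.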
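
A minimal NumPy sketch of one LSTM step, written directly from the equations above. The weights are random and untrained, and the gate ordering inside `W` and the bias values are illustrative choices, not taken from `rnn_solver`. It starts from $C_0 = 1$ and runs 50 steps twice, once with the forget gate pushed toward 1 and once toward 0, with the input gate mostly closed in both runs.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W, b):
    """One LSTM step following the equations above.
    W acts on the concatenation [h_{t-1}, x_t]; its rows are stacked as (f, i, o, C~),
    an ordering chosen for this sketch."""
    H = h_prev.shape[0]
    z = W @ np.concatenate([h_prev, x_t]) + b
    f = sigmoid(z[0:H])                  # forget gate
    i = sigmoid(z[H:2 * H])              # input gate
    o = sigmoid(z[2 * H:3 * H])          # output gate
    C_tilde = np.tanh(z[3 * H:4 * H])    # candidate cell
    C = f * C_prev + i * C_tilde         # cell state: additive update
    h = o * np.tanh(C)                   # hidden state
    return h, C

def carry_cell_state(forget_bias, T=50, H=8, D=4, seed=0):
    """Run T steps starting from C_0 = 1 and return C_T."""
    rng = np.random.default_rng(seed)
    W = rng.normal(scale=0.1, size=(4 * H, H + D))
    b = np.zeros(4 * H)
    b[0:H] = forget_bias                 # push the forget gate toward 1 (keep) or 0 (erase)
    b[H:2 * H] = -6.0                    # keep the input gate mostly closed
    h, C = np.zeros(H), np.ones(H)
    for x_t in rng.normal(size=(T, D)):
        h, C = lstm_step(x_t, h, C, W, b)
    return C

kept = carry_cell_state(forget_bias=6.0)
erased = carry_cell_state(forget_bias=-6.0)
print(f"mean C_50, forget gate near 1: {kept.mean():.3f}")
print(f"mean C_50, forget gate near 0: {erased.mean():.3f}")
```

With $f_t \approx 1$ the initial cell content survives 50 steps largely intact; with $f_t \approx 0$ it is wiped out almost immediately. Learning when to keep and when to forget is what the gates are for.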

But even LSTMs struggle with very long sequences!

---

## Part 3: Training Data Preparation

### 3.1 Generate Training Dataset

In [None]:
# Configuration
maze_config = MazeConfig(
    height=15,
    width=15,
    wall_probability=0.25,
    ensure_solvable=True,
    seed=42
)

# Generate datasets
print("Generating training data...")
train_dataset = MazeDataset(num_mazes=200, config=maze_config)

maze_config.seed = 1000
val_dataset = MazeDataset(num_mazes=50, config=maze_config)

print(f"✓ Training mazes: {len(train_dataset)}")
print(f"✓ Validation mazes: {len(val_dataset)}")

# Analyze path lengths
train_lengths = [len(sol) for _, sol in train_dataset]
val_lengths = [len(sol) for _, sol in val_dataset]

print(f"\nPath length statistics:")
print(f"  Training: {np.mean(train_lengths):.1f} ± {np.std(train_lengths):.1f} steps")
print(f"  Range: [{min(train_lengths)}, {max(train_lengths)}]")

# Visualize distribution
plt.figure(figsize=(10, 5))
plt.hist(train_lengths, bins=20, alpha=0.7, label='Training', edgecolor='black')
plt.hist(val_lengths, bins=20, alpha=0.7, label='Validation', edgecolor='black')
plt.xlabel('Path Length (steps)')
plt.ylabel('Count')
plt.title('Distribution of Solution Path Lengths')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

### 3.2 Prepare Training Batches

We need to convert mazes into tensor format for training.

In [None]:
def prepare_batch(dataset, indices, device):
    """
    Prepare a batch of mazes and solutions for training.
    
    Returns:
        maze_grids: [batch_size, height, width]
        position_sequences: [batch_size, max_len, 2]
        action_sequences: [batch_size, max_len]
    """
    batch_mazes = [dataset.mazes[i] for i in indices]
    batch_solutions = [dataset.solutions[i] for i in indices]
    
    # Find max sequence length in batch
    max_len = max(len(sol) for sol in batch_solutions)
    
    maze_grids = []
    position_seqs = []
    action_seqs = []
    
    for maze, solution in zip(batch_mazes, batch_solutions):
        # Maze grid
        maze_grids.append(maze.grid)
        
        # Position sequence (padded)
        positions = np.array(solution)
        pad_len = max_len - len(positions)
        if pad_len > 0:
            positions = np.vstack([positions, np.zeros((pad_len, 2), dtype=int)])
        position_seqs.append(positions)
        
        # Action sequence (padded with -1 for ignore)
        actions = maze.path_to_actions(solution)
        action_map = {'UP': 0, 'DOWN': 1, 'LEFT': 2, 'RIGHT': 3}
        action_ids = [action_map[a] for a in actions]
        
        # Pad with -1 (will be ignored in loss)
        action_ids.extend([-1] * (max_len - 1 - len(action_ids)))
        action_seqs.append(action_ids)
    
    # Convert to tensors
    maze_grids = torch.tensor(np.stack(maze_grids), dtype=torch.long, device=device)
    position_seqs = torch.tensor(np.stack(position_seqs), dtype=torch.long, device=device)
    action_seqs = torch.tensor(np.stack(action_seqs), dtype=torch.long, device=device)
    
    return maze_grids, position_seqs, action_seqs

# Test the function
test_batch = prepare_batch(train_dataset, [0, 1, 2], device)
print(f"Batch shapes:")
print(f"  Maze grids: {test_batch[0].shape}")
print(f"  Positions: {test_batch[1].shape}")
print(f"  Actions: {test_batch[2].shape}")
print(f"\n✓ Data preparation working!")
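
`prepare_batch` pads action sequences with -1 and notes that these entries "will be ignored in loss". How `RNNTrainer` does this isn't shown; the standard PyTorch mechanism is the `ignore_index` argument of `nn.CrossEntropyLoss`. A standalone sketch with random logits, assuming the trainer works this way, confirming that padded targets drop out of both the loss and the accuracy:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_actions = 4
logits = torch.randn(2, 4, num_actions)          # [batch, steps, actions]
# The second sequence is one move shorter, so its last target is the padding value -1
targets = torch.tensor([[0, 3, 3, 1],
                        [2, 2, 0, -1]])

loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
# CrossEntropyLoss wants class scores on dim 1, so flatten (batch, steps) into one axis
loss = loss_fn(logits.reshape(-1, num_actions), targets.reshape(-1))

# The same value, computed only over the 7 real targets
mask = targets.reshape(-1) != -1
manual = nn.functional.cross_entropy(logits.reshape(-1, num_actions)[mask],
                                     targets.reshape(-1)[mask])

# Accuracy must skip padded steps too
preds = logits.argmax(dim=-1)
valid = targets != -1
acc = (preds[valid] == targets[valid]).float().mean()
print(f"loss={loss.item():.4f}  manual={manual.item():.4f}  acc={acc.item():.3f}")
```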

---

## Part 4: Train Simple RNN

### 4.1 Create and Configure Model

In [None]:
# Create simple RNN
simple_rnn = create_simple_rnn(maze_size=15, hidden_dim=128)
simple_rnn = simple_rnn.to(device)

print(f"Model architecture:")
print(simple_rnn)
print(f"\nTotal parameters: {sum(p.numel() for p in simple_rnn.parameters()):,}")

# Create trainer
rnn_trainer = RNNTrainer(simple_rnn, learning_rate=1e-3)

### 4.2 Training Loop

In [None]:
# Training configuration
num_epochs = 20
batch_size = 16

# History tracking
rnn_history = {
    'train_loss': [],
    'val_loss': [],
    'train_acc': [],
    'val_acc': []
}

print("Training Simple RNN...\n")

for epoch in range(num_epochs):
    # Training
    train_losses = []
    simple_rnn.train()
    
    # Shuffle training data
    train_indices = np.random.permutation(len(train_dataset))
    
    for i in range(0, len(train_dataset), batch_size):
        batch_idx = train_indices[i:i+batch_size]
        maze_grids, positions, actions = prepare_batch(train_dataset, batch_idx, device)
        
        # Remove last position (no action to predict)
        positions = positions[:, :-1, :]
        
        loss = rnn_trainer.train_step(maze_grids, positions, actions)
        train_losses.append(loss)
    
    # Validation (eval mode disables dropout, if the model uses any)
    simple_rnn.eval()
    val_losses = []
    val_accs = []
    
    for i in range(0, len(val_dataset), batch_size):
        batch_idx = list(range(i, min(i+batch_size, len(val_dataset))))
        maze_grids, positions, actions = prepare_batch(val_dataset, batch_idx, device)
        positions = positions[:, :-1, :]
        
        val_loss, val_acc = rnn_trainer.evaluate(maze_grids, positions, actions)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    
    # Record history
    rnn_history['train_loss'].append(np.mean(train_losses))
    rnn_history['val_loss'].append(np.mean(val_losses))
    rnn_history['val_acc'].append(np.mean(val_accs))
    
    # Print progress
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:2d}/{num_epochs} | "
              f"Train Loss: {rnn_history['train_loss'][-1]:.4f} | "
              f"Val Loss: {rnn_history['val_loss'][-1]:.4f} | "
              f"Val Acc: {rnn_history['val_acc'][-1]:.3f}")

print("\n✓ Training complete!")
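
The loop above delegates the actual optimization to `RNNTrainer.train_step` and `RNNTrainer.evaluate`, which aren't shown here. As a hedged illustration of what such methods commonly contain (not the `rnn_solver` implementation), here is a self-contained version built around a hypothetical `TinySeqModel` on random toy data. It shows the usual ingredients: `train()`/`eval()` mode switches, `torch.no_grad()` for evaluation, the -1 padding masked out of loss and accuracy, and gradient-norm clipping, a common safeguard for RNNs that `RNNTrainer` may or may not apply.

```python
import torch
import torch.nn as nn

class TinySeqModel(nn.Module):
    """Hypothetical toy model: nn.RNN over per-step features, then a linear action head."""
    def __init__(self, in_dim=2, hidden=16, num_actions=4):
        super().__init__()
        self.rnn = nn.RNN(in_dim, hidden, batch_first=True)
        self.head = nn.Linear(hidden, num_actions)

    def forward(self, x):
        out, _ = self.rnn(x)            # out: [batch, steps, hidden], one h_t per step
        return self.head(out)           # [batch, steps, num_actions]

def train_step(model, optimizer, x, targets, max_grad_norm=1.0):
    model.train()
    logits = model(x)
    loss = nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                       targets.reshape(-1), ignore_index=-1)
    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to guard against exploding gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(model, x, targets):
    model.eval()                        # switch off training-only behavior such as dropout
    logits = model(x)
    loss = nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                       targets.reshape(-1), ignore_index=-1)
    valid = targets != -1
    acc = (logits.argmax(dim=-1)[valid] == targets[valid]).float().mean()
    return loss.item(), acc.item()

torch.manual_seed(0)
model = TinySeqModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.randn(4, 6, 2)                # 4 toy sequences of 6 steps
targets = torch.randint(0, 4, (4, 6))
targets[0, -2:] = -1                    # pretend the first sequence is shorter

losses = [train_step(model, optimizer, x, targets) for _ in range(200)]
val_loss, val_acc = evaluate(model, x, targets)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, acc {val_acc:.2f}")
```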

### 4.3 Visualize Training Progress

In [None]:
fig, axes = TrainingVisualizer.plot_training_curves(rnn_history, figsize=(15, 5))
fig.suptitle('Simple RNN Training Progress', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

---

## Part 5: Train LSTM (Better Memory)

### 5.1 Create LSTM Model

In [None]:
# Create LSTM
lstm_model = create_lstm(maze_size=15, hidden_dim=128)
lstm_model = lstm_model.to(device)

print(f"LSTM parameters: {sum(p.numel() for p in lstm_model.parameters()):,}")
print(f"Simple RNN parameters: {sum(p.numel() for p in simple_rnn.parameters()):,}")
print(f"\nLSTM is ~{sum(p.numel() for p in lstm_model.parameters()) / sum(p.numel() for p in simple_rnn.parameters()):.1f}x larger (gates require more parameters)")

# Create trainer
lstm_trainer = RNNTrainer(lstm_model, learning_rate=1e-3)
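
Where the extra parameters come from: for the recurrent core alone, PyTorch's `nn.LSTM` has exactly four times the parameters of `nn.RNN` with the same sizes, because it computes four gate pre-activations (input, forget, cell candidate, output) instead of one. The sizes below are illustrative. The ratio printed above also counts whatever embedding and output layers `create_simple_rnn` and `create_lstm` add (not shown), which is why it can come out below 4.

```python
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

input_size, hidden = 32, 128                     # illustrative sizes
rnn_core = nn.RNN(input_size, hidden, batch_first=True)
lstm_core = nn.LSTM(input_size, hidden, batch_first=True)

# One layer: W_ih [H, I], W_hh [H, H], b_ih [H], b_hh [H]; the LSTM stacks 4 of each (i, f, g, o)
expected_rnn = hidden * input_size + hidden * hidden + 2 * hidden
print(count_params(rnn_core), expected_rnn)                 # 20736 20736
print(count_params(lstm_core) / count_params(rnn_core))     # 4.0
```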

### 5.2 Train LSTM

In [None]:
# History tracking
lstm_history = {
    'train_loss': [],
    'val_loss': [],
    'train_acc': [],
    'val_acc': []
}

print("Training LSTM...\n")

for epoch in range(num_epochs):
    # Training
    train_losses = []
    lstm_model.train()
    
    train_indices = np.random.permutation(len(train_dataset))
    
    for i in range(0, len(train_dataset), batch_size):
        batch_idx = train_indices[i:i+batch_size]
        maze_grids, positions, actions = prepare_batch(train_dataset, batch_idx, device)
        positions = positions[:, :-1, :]
        
        loss = lstm_trainer.train_step(maze_grids, positions, actions)
        train_losses.append(loss)
    
    # Validation (eval mode disables dropout, if the model uses any)
    lstm_model.eval()
    val_losses = []
    val_accs = []
    
    for i in range(0, len(val_dataset), batch_size):
        batch_idx = list(range(i, min(i+batch_size, len(val_dataset))))
        maze_grids, positions, actions = prepare_batch(val_dataset, batch_idx, device)
        positions = positions[:, :-1, :]
        
        val_loss, val_acc = lstm_trainer.evaluate(maze_grids, positions, actions)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    
    # Record history
    lstm_history['train_loss'].append(np.mean(train_losses))
    lstm_history['val_loss'].append(np.mean(val_losses))
    lstm_history['val_acc'].append(np.mean(val_accs))
    
    # Print progress
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:2d}/{num_epochs} | "
              f"Train Loss: {lstm_history['train_loss'][-1]:.4f} | "
              f"Val Loss: {lstm_history['val_loss'][-1]:.4f} | "
              f"Val Acc: {lstm_history['val_acc'][-1]:.3f}")

print("\n✓ Training complete!")

### 5.3 Compare RNN vs LSTM

In [None]:
# Side-by-side comparison
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss comparison
axes[0].plot(rnn_history['val_loss'], label='Simple RNN', linewidth=2, marker='o')
axes[0].plot(lstm_history['val_loss'], label='LSTM', linewidth=2, marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Validation Loss')
axes[0].set_title('Validation Loss: RNN vs LSTM', fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy comparison
axes[1].plot(rnn_history['val_acc'], label='Simple RNN', linewidth=2, marker='o')
axes[1].plot(lstm_history['val_acc'], label='LSTM', linewidth=2, marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Validation Accuracy')
axes[1].set_title('Validation Accuracy: RNN vs LSTM', fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Results:")
print(f"  Simple RNN - Val Loss: {rnn_history['val_loss'][-1]:.4f}, Val Acc: {rnn_history['val_acc'][-1]:.3f}")
print(f"  LSTM       - Val Loss: {lstm_history['val_loss'][-1]:.4f}, Val Acc: {lstm_history['val_acc'][-1]:.3f}")
acc_gap = lstm_history['val_acc'][-1] - rnn_history['val_acc'][-1]
print(f"\n  LSTM minus RNN validation accuracy: {acc_gap * 100:+.1f} percentage points")

---

## Part 6: Test on Example Mazes

### 6.1 Generate Test Paths

In [None]:
# Generate a test maze
test_maze = generate_simple_maze(size=15, seed=999)
optimal_solution = test_maze.solve()

print(f"Test maze difficulty: {len(optimal_solution)} steps")

# Generate paths from both models
print("\nGenerating paths...")
rnn_path = simple_rnn.generate_path(
    torch.tensor(test_maze.grid, dtype=torch.long).unsqueeze(0).to(device),
    test_maze.start,
    max_steps=100
)

lstm_path = lstm_model.generate_path(
    torch.tensor(test_maze.grid, dtype=torch.long).unsqueeze(0).to(device),
    test_maze.start,
    max_steps=100
)

print(f"✓ Simple RNN path: {len(rnn_path)} steps")
print(f"✓ LSTM path: {len(lstm_path)} steps")
print(f"✓ Optimal path: {len(optimal_solution)} steps")

### 6.2 Visualize Results

In [None]:
# Create comparison visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Plot each path on its own copy of the test maze
import copy
test_maze_rnn = copy.deepcopy(test_maze)
test_maze_lstm = copy.deepcopy(test_maze)
test_maze_opt = copy.deepcopy(test_maze)

# Mark paths (tuple() makes grid[...] index a single cell even if positions are lists)
for pos in rnn_path[1:-1]:
    if test_maze_rnn.grid[tuple(pos)] == 1:
        test_maze_rnn.grid[tuple(pos)] = 4

for pos in lstm_path[1:-1]:
    if test_maze_lstm.grid[tuple(pos)] == 1:
        test_maze_lstm.grid[tuple(pos)] = 4

# Restore start and goal
test_maze_rnn.grid[test_maze.start] = 2
test_maze_rnn.grid[test_maze.goal] = 3
test_maze_lstm.grid[test_maze.start] = 2
test_maze_lstm.grid[test_maze.goal] = 3

MazeVisualizer.plot_maze(test_maze_rnn, ax=axes[0], 
                        title=f"Simple RNN ({len(rnn_path)} steps)")
MazeVisualizer.plot_maze(test_maze_lstm, ax=axes[1], 
                        title=f"LSTM ({len(lstm_path)} steps)")
MazeVisualizer.plot_maze(test_maze_opt, ax=axes[2], show_solution=True,
                        title=f"Optimal ({len(optimal_solution)} steps)")

plt.tight_layout()
plt.show()

# Check whether each path ended at the goal (tuple() guards against list-vs-tuple mismatches)
rnn_reached = tuple(rnn_path[-1]) == tuple(test_maze.goal)
lstm_reached = tuple(lstm_path[-1]) == tuple(test_maze.goal)

print(f"\nResults:")
print(f"  Simple RNN: {'✓ Reached goal' if rnn_reached else '✗ Did not reach goal'}")
print(f"  LSTM:       {'✓ Reached goal' if lstm_reached else '✗ Did not reach goal'}")
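
Checking only the final position can hide problems: a generated path might reach G by passing through a wall, or wander far longer than the optimal route. A hedged helper sketch follows. `score_path` is hypothetical, and treating grid value 0 as a wall is an assumption (the cells above show only 1 = open, 2 = start, 3 = goal, 4 = marked path).

```python
import numpy as np

def score_path(grid, path, goal, wall_value=0):
    """Check a generated path and report whether it legally reaches the goal.

    A path is valid if every move changes exactly one coordinate by 1, stays inside
    the grid, and never lands on a cell equal to wall_value (assumed wall encoding).
    """
    path = [tuple(int(v) for v in p) for p in path]   # normalize lists/arrays to tuples
    valid = True
    for (r0, c0), (r1, c1) in zip(path, path[1:]):
        unit_move = abs(r1 - r0) + abs(c1 - c0) == 1
        in_bounds = 0 <= r1 < grid.shape[0] and 0 <= c1 < grid.shape[1]
        if not unit_move or not in_bounds or grid[r1, c1] == wall_value:
            valid = False
            break
    return {"valid": valid,
            "reached_goal": valid and path[-1] == tuple(goal),
            "moves": len(path) - 1}

# Toy 3x3 grid using the notebook's visible codes (1 = open, 2 = start, 3 = goal); 0 = wall (assumed)
grid = np.array([[2, 1, 0],
                 [0, 1, 0],
                 [0, 1, 3]])
print(score_path(grid, [(0, 0), (0, 1), (1, 1), (2, 1), (2, 2)], goal=(2, 2)))
print(score_path(grid, [(0, 0), (1, 0)], goal=(2, 2)))
```

Applied to the notebook's results this would be `score_path(test_maze.grid, rnn_path, test_maze.goal)`, assuming `test_maze.grid` is a NumPy array; comparing `moves` with `len(optimal_solution) - 1` shows how far a path strays from optimal.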

---

## Part 7: The Critical Test - Performance vs Path Length

### 7.1 Generate Mazes of Varying Difficulty

**This is the key experiment**: How does performance degrade as paths get longer?

In [None]:
# Test on mazes with different path lengths
length_bins = [0, 10, 15, 20, 25, 100]  # Path length ranges
bin_labels = ['<10', '10-15', '15-20', '20-25', '25+']

# Categorize validation mazes by path length
binned_mazes = {label: [] for label in bin_labels}

for idx, (maze, solution) in enumerate(val_dataset):
    path_len = len(solution)
    for i, (low, high) in enumerate(zip(length_bins[:-1], length_bins[1:])):
        if low <= path_len < high:
            binned_mazes[bin_labels[i]].append(idx)
            break

print("Mazes per difficulty bin:")
for label, indices in binned_mazes.items():
    print(f"  {label:6s}: {len(indices):2d} mazes")
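
Note that the bins are half-open: a 15-step path lands in '15-20', not '10-15', and paths of 100 or more steps are not counted at all. `np.digitize` expresses the same binning in one call, which makes the boundary behavior easy to check (a small illustration with made-up lengths):

```python
import numpy as np

length_bins = [0, 10, 15, 20, 25, 100]
bin_labels = ['<10', '10-15', '15-20', '20-25', '25+']

path_lengths = np.array([7, 10, 14, 15, 22, 25, 40, 100])
# With right=False (the default), digitize returns i such that bins[i-1] <= x < bins[i]
bin_idx = np.digitize(path_lengths, length_bins) - 1
for n, i in zip(path_lengths, bin_idx):
    label = bin_labels[i] if 0 <= i < len(bin_labels) else 'not binned'
    print(n, label)
```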

### 7.2 Test Models on Each Difficulty Level

In [None]:
# Test each model on each bin (eval mode, in case the models use dropout)
simple_rnn.eval()
lstm_model.eval()
rnn_results = {label: [] for label in bin_labels}
lstm_results = {label: [] for label in bin_labels}

print("Testing models on different path lengths...\n")

for label, indices in binned_mazes.items():
    if len(indices) == 0:
        continue
        
    print(f"Testing on {label} step mazes...")
    
    for idx in indices:
        maze, solution = val_dataset[idx]
        maze_tensor = torch.tensor(maze.grid, dtype=torch.long).unsqueeze(0).to(device)
        
        # RNN (tuple() guards against list-vs-tuple mismatches in the comparison)
        rnn_path = simple_rnn.generate_path(maze_tensor, maze.start, max_steps=100)
        rnn_success = tuple(rnn_path[-1]) == tuple(maze.goal)
        rnn_results[label].append(rnn_success)

        # LSTM
        lstm_path = lstm_model.generate_path(maze_tensor, maze.start, max_steps=100)
        lstm_success = tuple(lstm_path[-1]) == tuple(maze.goal)
        lstm_results[label].append(lstm_success)
    
    # Print results for this bin
    rnn_acc = np.mean(rnn_results[label])
    lstm_acc = np.mean(lstm_results[label])
    print(f"  RNN:  {rnn_acc:.2%} success rate")
    print(f"  LSTM: {lstm_acc:.2%} success rate\n")

print("✓ Testing complete!")

### 7.3 Visualize the Sequential Bottleneck

In [None]:
# Calculate mean accuracy per bin
rnn_acc_by_length = [np.mean(rnn_results[label]) if rnn_results[label] else 0 
                     for label in bin_labels]
lstm_acc_by_length = [np.mean(lstm_results[label]) if lstm_results[label] else 0 
                      for label in bin_labels]

# Plot
fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(bin_labels))
width = 0.35

bars1 = ax.bar(x - width/2, rnn_acc_by_length, width, label='Simple RNN', 
               color='#E74C3C', alpha=0.8, edgecolor='black')
bars2 = ax.bar(x + width/2, lstm_acc_by_length, width, label='LSTM',
               color='#3498DB', alpha=0.8, edgecolor='black')

# Add value labels
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{height:.1%}',
               ha='center', va='bottom', fontsize=10, fontweight='bold')

ax.set_xlabel('Solution Path Length (steps)', fontsize=12, fontweight='bold')
ax.set_ylabel('Success Rate', fontsize=12, fontweight='bold')
ax.set_title('The Sequential Bottleneck: Performance Degrades with Path Length', 
            fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(bin_labels)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')
ax.set_ylim([0, 1.1])

plt.tight_layout()
plt.show()

print("\n🔍 Key Observation:")
print("   If success rates fall as paths get longer, that is the sequential bottleneck:")
print("   the 'memory backpack' (the hidden state) can only hold so much.")
print("\n   Attention, the mechanism transformers are built on, is designed to remove this bottleneck.")
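
With 50 validation mazes spread over five bins, some bars may rest on only a handful of mazes (see the counts printed in 7.1), so small differences between them can be noise. One standard way to show this, not used elsewhere in the notebook, is a binomial confidence interval for each success rate. A minimal sketch of the Wilson score interval:

```python
import math

def wilson_interval(successes, n, z=1.96):
    """Approximate 95% Wilson score interval for a binomial success rate."""
    if n == 0:
        return (0.0, 1.0)                  # no data: the rate could be anything
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = (z / denom) * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2))
    return (max(0.0, center - half), min(1.0, center + half))

print(wilson_interval(3, 5))    # same 60% rate, few mazes: wide interval
print(wilson_interval(30, 50))  # same 60% rate, more mazes: narrower interval
```

For example, `wilson_interval(int(np.sum(rnn_results[label])), len(rnn_results[label]))` gives the interval for one bar; heavily overlapping intervals are a sign that a difference between bars may not be meaningful.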

---

## Part 8: Hidden State Analysis

### 8.1 Visualize How Hidden State Evolves

Let's look inside the "memory backpack" to see what's happening.

In [None]:
# Get a long maze
long_maze_idx = None
for idx, (maze, solution) in enumerate(val_dataset):
    if len(solution) > 20:
        long_maze_idx = idx
        break

if long_maze_idx is not None:
    test_maze, test_solution = val_dataset[long_maze_idx]
    print(f"Analyzing maze with path length: {len(test_solution)}")
    
    # Prepare data
    maze_grid = torch.tensor(test_maze.grid, dtype=torch.long).unsqueeze(0).to(device)
    positions = torch.tensor(test_solution, dtype=torch.long).unsqueeze(0).to(device)
    
    # Forward pass. As called here, the model returns only the final (h_n, c_n);
    # returning the hidden state at every step would require modifying forward().
    lstm_model.eval()
    with torch.no_grad():
        action_logits, (h_n, c_n) = lstm_model(maze_grid, positions)

    print(f"\n✓ Forward pass complete!")
    print(f"   Final hidden state shape: {h_n.shape}")
    print(f"   Final cell state shape: {c_n.shape}")

    # Note: visualizing the full hidden-state evolution requires the intermediate
    # hidden states from every step of the forward pass, not just the final ones
    print("\n   (Full hidden state evolution requires model modification)")
else:
    print("No sufficiently long maze found in validation set")
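
If the LSTM model in `rnn_solver` is built on `torch.nn.LSTM` (an assumption; its code isn't shown), that modification may be small: `nn.LSTM` already returns an `output` tensor holding $h_t$ for every step of the last layer, alongside the final `(h_n, c_n)`. A standalone sketch with a fresh, untrained LSTM and random inputs (sizes are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, input_size, hidden = 30, 8, 32             # illustrative sizes
lstm = nn.LSTM(input_size, hidden, batch_first=True)
x = torch.randn(1, T, input_size)             # one sequence of T steps

with torch.no_grad():
    outputs, (h_n, c_n) = lstm(x)             # outputs[:, t] is h_{t+1}; h_n is the last one

hs = outputs[0]                               # [T, hidden]: the hidden state after every step
# How similar is each hidden state to the first one, and how much does it move per step?
sim_to_first = nn.functional.cosine_similarity(hs, hs[0:1], dim=-1)   # [T]
step_change = (hs[1:] - hs[:-1]).norm(dim=-1)                         # [T-1]
print(hs.shape, sim_to_first[:5], step_change[:5])
```

`plt.plot(sim_to_first)` would then show how quickly the state drifts from its starting point. Per-step cell states $c_t$ are not part of `outputs`; collecting those would mean stepping an `nn.LSTMCell` manually.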

---

## Summary: Key Takeaways

### What We Learned

1. **Sequential Processing is Limiting**
   - RNNs must process inputs one at a time
   - Information must flow through every intermediate state
   - No shortcuts or "teleportation" between distant positions

2. **Memory Degrades Over Distance**
   - Both RNN and LSTM performance drops on longer paths
   - The "memory backpack" has finite capacity
   - Early information gets compressed/forgotten

3. **LSTM Helps But Doesn't Solve the Problem**
   - Gates allow better gradient flow
   - Cell state provides a "memory highway"
   - But still fundamentally sequential

### The Question for Next Time

**What if we could see the entire maze at once?**

Instead of processing step-by-step like a mouse:
- 🗺️ View from above (bird's-eye perspective)
- 🔗 Connect any two positions directly
- ⚡ Process all positions in parallel

This is what **attention** enables. In Notebook 2, we'll build this mechanism from scratch!

---

## Exercises (Optional)

1. **Experiment with hidden dimensions**: Try hidden_dim=64 and hidden_dim=256. How does this affect performance?

2. **Add GRU**: Implement a GRU variant and compare to RNN and LSTM

3. **Longer mazes**: Generate 20×20 mazes. How much does performance degrade?

4. **Attention preview**: Can you think of how to let position 20 "see" position 1 directly?

---

## Next: Notebook 2 - The Map (Attention)

We'll implement attention from scratch and see how it transforms this sequential bottleneck into a parallel process!