# Reinforcement Learning Plan B - Part 1: RL Fundamentals & Tabular Methods

This notebook introduces the mathematical foundations of Reinforcement Learning and implements core tabular methods. We'll cover Markov Decision Processes, Bellman equations, and fundamental algorithms like value iteration, policy iteration, Q-learning, and SARSA.

**Learning Objectives:**
- Understand the mathematical framework of MDPs
- Derive and implement Bellman equations
- Compare dynamic programming vs temporal difference methods
- Build intuition through GridWorld environments
- Analyze convergence properties and performance trade-offs

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, deque
import pandas as pd
from typing import Tuple, Dict, List, Optional
import time
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Set random seeds for reproducibility
np.random.seed(42)

print("Environment setup complete!")

## 1. Markov Decision Processes (MDPs)

A **Markov Decision Process** is the mathematical framework for modeling decision-making in situations where outcomes are partly random and partly under the control of a decision maker.

### Mathematical Definition

An MDP is defined by the tuple $(\mathcal{S}, \mathcal{A}, P, R, \gamma)$:

- **$\mathcal{S}$**: State space - the set of all possible states
- **$\mathcal{A}$**: Action space - the set of all possible actions  
- **$P$**: Transition probability function: $P(s'|s,a) = \mathbb{P}[S_{t+1} = s' | S_t = s, A_t = a]$
- **$R$**: Reward function: $R(s,a,s') = \mathbb{E}[R_{t+1} | S_t = s, A_t = a, S_{t+1} = s']$
- **$\gamma$**: Discount factor: $\gamma \in [0,1]$
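
To make the tuple concrete, here is a minimal sketch of how a finite MDP could be stored as arrays. The `TabularMDP` class, its field layout `P[a, s, s2]`, and the two-state example are illustrative assumptions, not part of the GridWorld code used later in this notebook.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TabularMDP:
    """Hypothetical container for a finite MDP (S, A, P, R, gamma)."""
    P: np.ndarray   # P[a, s, s2] = P(s2 | s, a)
    R: np.ndarray   # R[a, s, s2] = expected reward for that transition
    gamma: float

    def __post_init__(self):
        if self.P.shape != self.R.shape:
            raise ValueError("P and R must share shape (A, S, S)")
        if not np.allclose(self.P.sum(axis=2), 1.0):
            raise ValueError("each P[a, s, :] must sum to 1")
        if not 0.0 <= self.gamma <= 1.0:
            raise ValueError("gamma must lie in [0, 1]")

    @property
    def num_actions(self) -> int:
        return self.P.shape[0]

    @property
    def num_states(self) -> int:
        return self.P.shape[1]


# Two states, two actions: action 0 stays put, action 1 tries to reach state 1
P = np.array([[[1.0, 0.0], [0.0, 1.0]],
              [[0.5, 0.5], [0.0, 1.0]]])
R = np.zeros_like(P)
R[1, 0, 1] = 1.0  # reward for moving from state 0 to state 1 with action 1
mdp = TabularMDP(P=P, R=R, gamma=0.9)
print(mdp.num_states, mdp.num_actions)
```

The validation in `__post_init__` mirrors the definition: every `P[a, s, :]` must be a probability distribution and $\gamma$ must lie in $[0,1]$.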

### Key Properties

**Markov Property**: The future is independent of the past given the present:
$$\mathbb{P}[S_{t+1} = s' | S_t = s, A_t = a, S_{t-1}, A_{t-1}, \ldots, S_0, A_0] = \mathbb{P}[S_{t+1} = s' | S_t = s, A_t = a]$$

**Return**: The cumulative discounted reward from time $t$:
$$G_t = \sum_{k=0}^{\infty} \gamma^k R_{t+k+1}$$

The discount factor $\gamma$ balances immediate vs future rewards:
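- $\gamma = 0$: Only the immediate reward matters (myopic)
- $0 < \gamma < 1$: Future rewards are discounted geometrically, so a reward $k$ steps ahead is weighted by $\gamma^k$
- $\gamma = 1$: All future rewards count equally; the infinite sum can then diverge unless episodes terminate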
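
As a small worked example of the return, the sketch below evaluates $G_0$ for a finite reward sequence at several discount factors. The helper `discounted_return` and the reward values are illustrative assumptions.

```python
def discounted_return(rewards, gamma):
    """Compute G_t = sum_k gamma^k * R_{t+k+1} for a finite reward sequence."""
    g = 0.0
    # Work backwards using the recursion G_t = R_{t+1} + gamma * G_{t+1}
    for r in reversed(rewards):
        g = r + gamma * g
    return g


rewards = [1.0, 1.0, 1.0, 1.0]  # hypothetical rewards R_1..R_4
for gamma in (0.0, 0.5, 1.0):
    print(f"gamma={gamma}: G_0 = {discounted_return(rewards, gamma):.4f}")
# gamma=0.0 gives 1.0, gamma=0.5 gives 1.875, gamma=1.0 gives 4.0
```

The backward recursion is the same identity that the Bellman equations later apply in expectation.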

## 2. GridWorld Environment Implementation

We'll implement a classic GridWorld environment to demonstrate RL concepts. This serves as our testing ground for various algorithms.

### Environment Dynamics

- **States**: Grid positions $(i,j)$
- **Actions**: Up, Down, Left, Right
- **Transitions**: Deterministic movement (with boundary handling)
- **Rewards**: Reaching the goal gives a positive reward, entering an obstacle gives a negative reward, and every other move incurs a small step penalty (-0.01 by default)
- **Terminal States**: Goal and obstacle states end episodes

In [None]:
class GridWorld:
    """
    GridWorld environment for reinforcement learning experiments.
    
    The agent navigates a grid to reach a goal while avoiding obstacles.
    This environment serves as a perfect testbed for tabular RL methods.
    """
    
    def __init__(self, height: int = 5, width: int = 5, goal_reward: float = 1.0, 
                 step_reward: float = -0.01, obstacle_reward: float = -1.0):
        self.height = height
        self.width = width
        self.goal_reward = goal_reward
        self.step_reward = step_reward
        self.obstacle_reward = obstacle_reward
        
        # Define actions: 0=Up, 1=Down, 2=Left, 3=Right
        self.actions = [(0, -1), (0, 1), (-1, 0), (1, 0)]  # (dx, dy)
        self.action_names = ['Up', 'Down', 'Left', 'Right']
        self.num_actions = len(self.actions)
        
        # Set up grid layout
        self.goal_state = (width-1, height-1)  # Bottom-right corner, stored as (x, y)
        self.start_state = (0, 0)  # Top-left corner
        
        # Define obstacles (can be customized)
        self.obstacles = {(2, 2), (1, 3)} if height >= 4 and width >= 4 else set()
        
        # Current state
        self.current_state = self.start_state
        self.done = False
    
    def reset(self) -> Tuple[int, int]:
        """Reset environment to starting state."""
        self.current_state = self.start_state
        self.done = False
        return self.current_state
    
    def step(self, action: int) -> Tuple[Tuple[int, int], float, bool, dict]:
        """Execute action and return (next_state, reward, done, info)."""
        if self.done:
            return self.current_state, 0, True, {}
        
        # Calculate next state
        dx, dy = self.actions[action]
        next_x = max(0, min(self.width - 1, self.current_state[0] + dx))
        next_y = max(0, min(self.height - 1, self.current_state[1] + dy))
        next_state = (next_x, next_y)
        
        # Calculate reward
        if next_state == self.goal_state:
            reward = self.goal_reward
            self.done = True
        elif next_state in self.obstacles:
            reward = self.obstacle_reward
            self.done = True
        else:
            reward = self.step_reward
        
        self.current_state = next_state
        
        return next_state, reward, self.done, {}
    
    def get_all_states(self) -> List[Tuple[int, int]]:
        """Return all possible states."""
        return [(i, j) for i in range(self.width) for j in range(self.height)]
    
    def is_terminal(self, state: Tuple[int, int]) -> bool:
        """Check if state is terminal."""
        return state == self.goal_state or state in self.obstacles
    
    def get_transition_prob(self, state: Tuple[int, int], action: int, 
                          next_state: Tuple[int, int]) -> float:
        """Get transition probability P(s'|s,a)."""
        if self.is_terminal(state):
            return 1.0 if next_state == state else 0.0
        
        # Calculate expected next state
        dx, dy = self.actions[action]
        expected_next_x = max(0, min(self.width - 1, state[0] + dx))
        expected_next_y = max(0, min(self.height - 1, state[1] + dy))
        expected_next_state = (expected_next_x, expected_next_y)
        
        return 1.0 if next_state == expected_next_state else 0.0
    
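    def get_reward(self, state: Tuple[int, int], action: int, 
                   next_state: Tuple[int, int]) -> float:
        """Get reward R(s,a,s')."""
        # Terminal states are absorbing and yield no further reward,
        # matching step(), which returns 0 once the episode is done.
        if self.is_terminal(state):
            return 0.0
        if next_state == self.goal_state:
            return self.goal_reward
        elif next_state in self.obstacles:
            return self.obstacle_reward
        else:
            return self.step_reward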
    
    def render(self, values: Optional[Dict] = None, policy: Optional[Dict] = None) -> None:
        """Visualize the grid world with optional value function or policy."""
        fig, ax = plt.subplots(figsize=(8, 6))
        
        # Create grid visualization
        grid = np.zeros((self.height, self.width))
        
        if values:
            for (x, y), value in values.items():
                grid[y, x] = value
        
        # Plot heatmap
        im = ax.imshow(grid, cmap='coolwarm', alpha=0.7)
        
        # Add grid lines
        ax.set_xticks(np.arange(self.width + 1) - 0.5, minor=True)
        ax.set_yticks(np.arange(self.height + 1) - 0.5, minor=True)
        ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
        
        # Mark special states
        start_y, start_x = self.start_state[1], self.start_state[0]
        goal_y, goal_x = self.goal_state[1], self.goal_state[0]
        
        ax.text(start_x, start_y, 'S', ha='center', va='center', 
                fontsize=16, fontweight='bold', color='green')
        ax.text(goal_x, goal_y, 'G', ha='center', va='center', 
                fontsize=16, fontweight='bold', color='red')
        
        # Mark obstacles
        for (obs_x, obs_y) in self.obstacles:
            ax.text(obs_x, obs_y, 'X', ha='center', va='center', 
                    fontsize=16, fontweight='bold', color='black')
        
        # Add policy arrows if provided
        if policy:
            arrow_props = dict(arrowstyle='->', lw=2, color='blue')
            for (x, y), action in policy.items():
                if not self.is_terminal((x, y)):
                    dx, dy = self.actions[action]
                    ax.annotate('', xy=(x + dx*0.3, y + dy*0.3), xytext=(x, y),
                              arrowprops=arrow_props)
        
        # Add value labels if provided
        if values:
            for (x, y), value in values.items():
                if not self.is_terminal((x, y)):
                    ax.text(x, y + 0.3, f'{value:.2f}', ha='center', va='center', 
                            fontsize=10, color='white', fontweight='bold')
        
        ax.set_title('GridWorld Environment')
        ax.set_xticks(range(self.width))
        ax.set_yticks(range(self.height))
        
        if values:
            plt.colorbar(im, ax=ax, label='State Value')
        
        plt.tight_layout()
        plt.show()

# Create and visualize the environment
env = GridWorld(height=5, width=5)
print(f"GridWorld created: {env.width}x{env.height}")
print(f"Start: {env.start_state}, Goal: {env.goal_state}")
print(f"Obstacles: {env.obstacles}")
print(f"Actions: {env.action_names}")

env.render()

## 3. Value Functions and Bellman Equations

Value functions are fundamental to RL - they estimate how good it is to be in a particular state or to take a particular action in a state.

### State Value Function

The **state value function** $V^\pi(s)$ gives the expected return when starting from state $s$ and following policy $\pi$:

$$V^\pi(s) = \mathbb{E}_\pi[G_t | S_t = s] = \mathbb{E}_\pi\left[\sum_{k=0}^{\infty} \gamma^k R_{t+k+1} \mid S_t = s\right]$$

### Action Value Function (Q-Function)

The **action value function** $Q^\pi(s,a)$ gives the expected return when starting from state $s$, taking action $a$, then following policy $\pi$:

$$Q^\pi(s,a) = \mathbb{E}_\pi[G_t | S_t = s, A_t = a] = \mathbb{E}_\pi\left[\sum_{k=0}^{\infty} \gamma^k R_{t+k+1} \mid S_t = s, A_t = a\right]$$

### Bellman Equations

The **Bellman equation** for $V^\pi$ expresses the value of a state recursively, in terms of the immediate reward and the discounted values of its successor states:

$$V^\pi(s) = \sum_a \pi(a|s) \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^\pi(s')]$$

The **Bellman equation** for $Q^\pi$:

$$Q^\pi(s,a) = \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma \sum_{a'} \pi(a'|s') Q^\pi(s',a')]$$
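
For a fixed policy, the Bellman equation for $V^\pi$ is linear in $V^\pi$, so on a small state space it can be solved directly as $V^\pi = (I - \gamma P_\pi)^{-1} r_\pi$. The sketch below does this for a hypothetical two-state chain; the matrices `P_pi` and `r_pi` are made-up numbers for illustration.

```python
import numpy as np

gamma = 0.9
# Hypothetical 2-state chain under a fixed policy:
# P_pi[s, s2] is the probability of moving from s to s2
P_pi = np.array([[0.5, 0.5],
                 [0.0, 1.0]])
# r_pi[s] is the expected one-step reward when leaving s
r_pi = np.array([1.0, 0.0])

# Bellman equation in matrix form: V = r_pi + gamma * P_pi @ V
V = np.linalg.solve(np.eye(2) - gamma * P_pi, r_pi)
print(V)

# Sanity check: V is a fixed point of the Bellman backup
assert np.allclose(V, r_pi + gamma * P_pi @ V)
```

State 1 is absorbing with zero reward, so $V^\pi(1) = 0$ and $V^\pi(0) = 1 / (1 - 0.45) \approx 1.818$.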

### Optimal Value Functions

The **optimal state value function** is:
$$V^*(s) = \max_\pi V^\pi(s)$$

The **optimal action value function** is:
$$Q^*(s,a) = \max_\pi Q^\pi(s,a)$$

### Bellman Optimality Equations

$$V^*(s) = \max_a \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^*(s')]$$

$$Q^*(s,a) = \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma \max_{a'} Q^*(s',a')]$$
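
The two optimality equations are linked by $V^*(s) = \max_a Q^*(s,a)$. The following sketch iterates the backup on a hypothetical two-state, two-action MDP and checks that relation numerically; the arrays `P` and `R` hold illustrative numbers only.

```python
import numpy as np

gamma = 0.9
# Hypothetical MDP: P[a, s, s2] = P(s2 | s, a), R[a, s, s2] = R(s, a, s2)
P = np.array([[[0.8, 0.2], [0.1, 0.9]],
              [[0.3, 0.7], [0.6, 0.4]]])
R = np.array([[[1.0, 0.0], [0.0, 2.0]],
              [[0.0, 0.5], [1.5, 0.0]]])


def q_from_v(V):
    """Q(s,a) = sum_s2 P(s2|s,a) [R(s,a,s2) + gamma V(s2)], returned as Q[s, a]."""
    return (P * (R + gamma * V)).sum(axis=2).T


V = np.zeros(2)
for _ in range(500):
    V = q_from_v(V).max(axis=1)  # V(s) = max_a Q(s, a)

Q = q_from_v(V)
print("V* ≈", V)
print("Q* ≈", Q)
```

After enough sweeps, `V` and `Q.max(axis=1)` agree to numerical precision, which is the fixed-point statement of the equations above.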

## 4. Dynamic Programming: Value Iteration

**Value Iteration** is a dynamic programming algorithm that computes the optimal value function by iteratively applying the Bellman optimality equation.

### Algorithm

1. Initialize $V_0(s)$ arbitrarily for all $s \in \mathcal{S}$
2. For $k = 0, 1, 2, \ldots$ until convergence:
   $$V_{k+1}(s) = \max_a \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V_k(s')]$$
3. Extract optimal policy: $\pi^*(s) = \arg\max_a \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^*(s')]$

### Convergence

Value iteration converges to $V^*$ by the **contraction mapping theorem** (Banach fixed-point theorem). The Bellman optimality operator $T$ defined by:
$$TV(s) = \max_a \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V(s')]$$

is a $\gamma$-contraction, meaning:
$$\|TV_1 - TV_2\|_\infty \leq \gamma \|V_1 - V_2\|_\infty$$

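This guarantees that, when $\gamma < 1$, $T$ has a unique fixed point $V^*$ and the iterates $V_k$ converge to it from any initialization, with the error shrinking by at least a factor of $\gamma$ per sweep.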
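
The contraction inequality can be checked numerically. The sketch below builds a random MDP (an assumption made purely for illustration) and compares $\|TV_1 - TV_2\|_\infty$ with $\|V_1 - V_2\|_\infty$ for many random pairs of value functions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_actions, gamma = 5, 3, 0.9

# Random hypothetical MDP: P[a, s, :] is a distribution over next states
P = rng.random((n_actions, n_states, n_states))
P /= P.sum(axis=2, keepdims=True)
R = rng.normal(size=(n_actions, n_states, n_states))


def bellman_optimality_operator(V):
    """(TV)(s) = max_a sum_s2 P(s2|s,a) [R(s,a,s2) + gamma V(s2)]."""
    return (P * (R + gamma * V)).sum(axis=2).max(axis=0)


ratios = []
for _ in range(1000):
    V1, V2 = rng.normal(size=n_states), rng.normal(size=n_states)
    lhs = np.max(np.abs(bellman_optimality_operator(V1)
                        - bellman_optimality_operator(V2)))
    rhs = np.max(np.abs(V1 - V2))
    ratios.append(lhs / rhs)

print(f"largest observed ratio: {max(ratios):.4f} (bound: gamma = {gamma})")
```

Every observed ratio stays at or below $\gamma$, as the inequality requires; this is a numerical illustration, not a proof.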

In [None]:
class ValueIteration:
    """
    Value Iteration algorithm for solving MDPs.
    
    This dynamic programming approach iteratively updates value estimates
    until convergence to the optimal value function.
    """
    
    def __init__(self, env: GridWorld, gamma: float = 0.9, theta: float = 1e-6):
        self.env = env
        self.gamma = gamma
        self.theta = theta  # Convergence threshold
        
        self.states = env.get_all_states()
        self.num_states = len(self.states)
        self.num_actions = env.num_actions
        
        # Initialize value function
        self.V = {state: 0.0 for state in self.states}
        self.policy = {state: 0 for state in self.states}
        
        # Track learning progress
        self.value_history = []
        self.delta_history = []
    
    def bellman_update(self, state: Tuple[int, int]) -> float:
        """Perform Bellman update for a single state."""
        if self.env.is_terminal(state):
            return 0.0
        
        action_values = []
        
        for action in range(self.num_actions):
            action_value = 0.0
            
            # Sum over all possible next states
            for next_state in self.states:
                prob = self.env.get_transition_prob(state, action, next_state)
                if prob > 0:
                    reward = self.env.get_reward(state, action, next_state)
                    action_value += prob * (reward + self.gamma * self.V[next_state])
            
            action_values.append(action_value)
        
        return max(action_values)
    
    def solve(self, max_iterations: int = 1000, verbose: bool = True) -> Tuple[Dict, Dict]:
        """Solve MDP using value iteration."""
        
        for iteration in range(max_iterations):
            # Store current values for convergence check
            old_V = self.V.copy()
            
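            # Update all state values in place: states later in the sweep already
            # see values updated earlier in the same sweep (Gauss-Seidel style).
            # This differs from the synchronous V_k -> V_{k+1} rule but still converges to V*.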
            for state in self.states:
                self.V[state] = self.bellman_update(state)
            
            # Check convergence
            delta = max(abs(self.V[s] - old_V[s]) for s in self.states)
            self.delta_history.append(delta)
            self.value_history.append(self.V.copy())
            
            if verbose and iteration % 10 == 0:
                print(f"Iteration {iteration}, Max value change: {delta:.6f}")
            
            if delta < self.theta:
                if verbose:
                    print(f"\nConverged after {iteration + 1} iterations!")
                break
        
        # Extract optimal policy
        self.extract_policy()
        
        return self.V, self.policy
    
    def extract_policy(self) -> None:
        """Extract optimal policy from value function."""
        for state in self.states:
            if self.env.is_terminal(state):
                continue
            
            action_values = []
            
            for action in range(self.num_actions):
                action_value = 0.0
                
                for next_state in self.states:
                    prob = self.env.get_transition_prob(state, action, next_state)
                    if prob > 0:
                        reward = self.env.get_reward(state, action, next_state)
                        action_value += prob * (reward + self.gamma * self.V[next_state])
                
                action_values.append(action_value)
            
            # Choose action with highest value
            self.policy[state] = np.argmax(action_values)
    
    def plot_convergence(self) -> None:
        """Plot convergence of value iteration."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        
        # Plot max value change over iterations
        ax1.plot(self.delta_history)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Max Value Change')
        ax1.set_title('Value Iteration Convergence')
        ax1.set_yscale('log')
        ax1.grid(True, alpha=0.3)
        
        # Plot value function evolution for a few states
        sample_states = [self.env.start_state, (2, 2), (3, 3)]
        for state in sample_states:
            if state in self.states and not self.env.is_terminal(state):
                values = [v_dict[state] for v_dict in self.value_history]
                ax2.plot(values, label=f'State {state}')
        
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('State Value')
        ax2.set_title('Value Function Evolution')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Solve GridWorld with Value Iteration
print("Solving GridWorld with Value Iteration...")
vi_solver = ValueIteration(env, gamma=0.9)
optimal_values, optimal_policy = vi_solver.solve()

print("\nOptimal Values (sample):")
for state, value in list(optimal_values.items())[:8]:
    print(f"V*{state} = {value:.4f}")

print("\nOptimal Policy (sample):")
for state, action in list(optimal_policy.items())[:8]:
    if not env.is_terminal(state):
        print(f"π*{state} = {env.action_names[action]}")

# Plot convergence
vi_solver.plot_convergence()

In [None]:
# Visualize the solution
print("Optimal Value Function:")
env.render(values=optimal_values)

print("\nOptimal Policy:")
env.render(policy=optimal_policy)

## 5. Dynamic Programming: Policy Iteration

**Policy Iteration** alternates between policy evaluation and policy improvement until convergence to the optimal policy.

### Algorithm

1. **Initialize** policy $\pi_0$ arbitrarily
2. **Repeat** until policy converges:
   - **Policy Evaluation**: Compute $V^{\pi_k}$ by solving:
     $$V^{\pi_k}(s) = \sum_a \pi_k(a|s) \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^{\pi_k}(s')]$$
   - **Policy Improvement**: Update policy greedily:
     $$\pi_{k+1}(s) = \arg\max_a \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^{\pi_k}(s')]$$

### Policy Improvement Theorem

If $\pi'$ is the greedy policy with respect to $V^\pi$, then:
$$V^{\pi'}(s) \geq V^\pi(s) \text{ for all } s$$

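Because a finite MDP has only finitely many deterministic policies, and each improvement step yields a strictly better policy unless the current one is already optimal, policy iteration converges to an optimal policy in a finite number of iterations.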
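
The theorem can be illustrated on a random MDP: evaluate a random deterministic policy exactly, act greedily with respect to its value function, and compare the two value functions state by state. The MDP, the helper names `evaluate` and `greedy`, and the seed are assumptions made for this sketch.

```python
import numpy as np

rng = np.random.default_rng(1)
n_states, n_actions, gamma = 6, 3, 0.9

# Random hypothetical MDP: P[a, s, s2] = P(s2 | s, a)
P = rng.random((n_actions, n_states, n_states))
P /= P.sum(axis=2, keepdims=True)
R = rng.normal(size=(n_actions, n_states, n_states))
r = (P * R).sum(axis=2)  # r[a, s] = expected one-step reward


def evaluate(policy):
    """Exact V^pi for a deterministic policy via the linear Bellman system."""
    idx = np.arange(n_states)
    P_pi = P[policy, idx]  # row s is P(. | s, policy[s])
    r_pi = r[policy, idx]
    return np.linalg.solve(np.eye(n_states) - gamma * P_pi, r_pi)


def greedy(V):
    """Greedy policy with respect to V."""
    return (r + gamma * P @ V).argmax(axis=0)


policy = rng.integers(n_actions, size=n_states)
V_old = evaluate(policy)
V_new = evaluate(greedy(V_old))
print("improvement per state:", np.round(V_new - V_old, 4))
```

Each entry of `V_new - V_old` is non-negative up to floating-point error, which is the statement $V^{\pi'}(s) \geq V^\pi(s)$ for all $s$.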

In [None]:
class PolicyIteration:
    """
    Policy Iteration algorithm for solving MDPs.
    
    Alternates between policy evaluation and policy improvement
    until convergence to the optimal policy.
    """
    
    def __init__(self, env: GridWorld, gamma: float = 0.9, theta: float = 1e-6):
        self.env = env
        self.gamma = gamma
        self.theta = theta
        
        self.states = env.get_all_states()
        self.num_states = len(self.states)
        self.num_actions = env.num_actions
        
        # Initialize random policy and zero values
        self.V = {state: 0.0 for state in self.states}
        self.policy = {state: np.random.randint(self.num_actions) for state in self.states}
        
        # Track progress
        self.policy_history = []
        self.value_history = []
    
    def policy_evaluation(self, max_iterations: int = 1000) -> None:
        """Evaluate current policy until convergence."""
        for iteration in range(max_iterations):
            old_V = self.V.copy()
            
            for state in self.states:
                if self.env.is_terminal(state):
                    self.V[state] = 0.0
                    continue
                
                action = self.policy[state]
                value = 0.0
                
                for next_state in self.states:
                    prob = self.env.get_transition_prob(state, action, next_state)
                    if prob > 0:
                        reward = self.env.get_reward(state, action, next_state)
                        value += prob * (reward + self.gamma * self.V[next_state])
                
                self.V[state] = value
            
            # Check convergence
            delta = max(abs(self.V[s] - old_V[s]) for s in self.states)
            if delta < self.theta:
                break
    
    def policy_improvement(self) -> bool:
        """Improve policy greedily. Returns True if policy changed."""
        old_policy = self.policy.copy()
        
        for state in self.states:
            if self.env.is_terminal(state):
                continue
            
            action_values = []
            
            for action in range(self.num_actions):
                action_value = 0.0
                
                for next_state in self.states:
                    prob = self.env.get_transition_prob(state, action, next_state)
                    if prob > 0:
                        reward = self.env.get_reward(state, action, next_state)
                        action_value += prob * (reward + self.gamma * self.V[next_state])
                
                action_values.append(action_value)
            
            self.policy[state] = np.argmax(action_values)
        
        # Check if policy changed
        policy_changed = any(old_policy[s] != self.policy[s] for s in self.states
                           if not self.env.is_terminal(s))
        
        return policy_changed
    
    def solve(self, max_iterations: int = 100, verbose: bool = True) -> Tuple[Dict, Dict]:
        """Solve MDP using policy iteration."""
        
        for iteration in range(max_iterations):
            if verbose:
                print(f"Policy Iteration {iteration + 1}")
            
            # Policy Evaluation
            self.policy_evaluation()
            self.value_history.append(self.V.copy())
            
            # Policy Improvement
            policy_changed = self.policy_improvement()
            self.policy_history.append(self.policy.copy())
            
            if not policy_changed:
                if verbose:
                    print(f"\nPolicy converged after {iteration + 1} iterations!")
                break
        
        return self.V, self.policy
    
    def plot_progress(self) -> None:
        """Plot policy iteration progress."""
        fig, axes = plt.subplots(1, min(len(self.policy_history), 4), figsize=(16, 4))
        
        if len(self.policy_history) == 1:
            axes = [axes]
        
        for i, policy in enumerate(self.policy_history[:4]):
            ax = axes[i] if len(self.policy_history) > 1 else axes[0]
            
            # Create policy visualization
            policy_grid = np.zeros((self.env.height, self.env.width))
            
            for (x, y), action in policy.items():
                if not self.env.is_terminal((x, y)):
                    policy_grid[y, x] = action
            
            im = ax.imshow(policy_grid, cmap='tab10', alpha=0.7)
            
            # Add arrows
            arrow_props = dict(arrowstyle='->', lw=2, color='blue')
            for (x, y), action in policy.items():
                if not self.env.is_terminal((x, y)):
                    dx, dy = self.env.actions[action]
                    ax.annotate('', xy=(x + dx*0.3, y + dy*0.3), xytext=(x, y),
                              arrowprops=arrow_props)
            
            ax.set_title(f'Policy Iteration {i + 1}')
            ax.set_xticks(range(self.env.width))
            ax.set_yticks(range(self.env.height))
        
        plt.tight_layout()
        plt.show()

# Solve with Policy Iteration
print("Solving GridWorld with Policy Iteration...")
pi_solver = PolicyIteration(env, gamma=0.9)
pi_values, pi_policy = pi_solver.solve()

print(f"\nPolicy Iteration completed!")
print(f"Number of policy iterations: {len(pi_solver.policy_history)}")

# Compare with Value Iteration
print("\nComparison with Value Iteration:")
print(f"Value difference (max): {max(abs(optimal_values[s] - pi_values[s]) for s in env.get_all_states()):.8f}")

policy_diff = sum(1 for s in env.get_all_states() 
                  if not env.is_terminal(s) and optimal_policy[s] != pi_policy[s])
print(f"Policy differences: {policy_diff} states")

# Plot policy evolution
pi_solver.plot_progress()

## 6. Q-Learning: Model-Free Temporal Difference Learning

**Q-Learning** is an off-policy temporal difference learning algorithm that directly learns the optimal action-value function without needing a model of the environment.

### Algorithm

Q-Learning uses the following update rule:

$$Q(s,a) \leftarrow Q(s,a) + \alpha [R + \gamma \max_{a'} Q(s',a') - Q(s,a)]$$

Where:
- $\alpha \in (0,1]$ is the learning rate
- $R$ is the immediate reward
- $s'$ is the next state
- The term $[R + \gamma \max_{a'} Q(s',a') - Q(s,a)]$ is called the **TD error**
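
A single update can be traced by hand. The sketch below applies the rule once to a hypothetical Q-table with two states and two actions; all numbers are made up for illustration.

```python
import numpy as np

alpha, gamma = 0.1, 0.9
# Hypothetical Q-values for the current state "s" and the next state "s_next"
Q = {"s": np.array([0.0, 0.5]), "s_next": np.array([1.0, 2.0])}

state, action, reward, next_state = "s", 1, -0.01, "s_next"

td_target = reward + gamma * np.max(Q[next_state])  # -0.01 + 0.9 * 2.0 = 1.79
td_error = td_target - Q[state][action]             # 1.79 - 0.5 = 1.29
Q[state][action] += alpha * td_error                # 0.5 + 0.1 * 1.29 = 0.629

print(f"TD target={td_target:.3f}, TD error={td_error:.3f}, "
      f"new Q={Q[state][action]:.3f}")
```

The estimate moves a fraction $\alpha$ of the way from its old value toward the TD target.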

### Key Properties

1. **Off-policy**: Q-learning can learn the optimal policy while following any exploratory policy
2. **Model-free**: No need to know transition probabilities or reward function
3. **Convergence**: For a finite MDP, if every state-action pair is visited infinitely often and the learning rates decrease so that $\sum_t \alpha_t = \infty$ and $\sum_t \alpha_t^2 < \infty$, Q-learning converges to $Q^*$ with probability 1
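
One common way to obtain a decreasing learning rate is to use $\alpha_n(s,a) = 1/n$, where $n$ counts visits to the pair $(s,a)$. The class below is a hypothetical sketch of such a schedule, not the one used by the agent in this notebook, which keeps $\alpha$ constant.

```python
from collections import defaultdict


class VisitCountStepSize:
    """Step size alpha_n(s, a) = 1 / n, where n counts visits to (s, a)."""

    def __init__(self):
        self.counts = defaultdict(int)

    def __call__(self, state, action):
        # Increment the visit count first so the first visit uses alpha = 1
        self.counts[(state, action)] += 1
        return 1.0 / self.counts[(state, action)]


step_size = VisitCountStepSize()
print(step_size((0, 0), 1))  # first visit to ((0, 0), 1): 1.0
print(step_size((0, 0), 1))  # second visit: 0.5
print(step_size((0, 1), 1))  # a different pair starts again at 1.0
```

In practice a constant or slowly decaying $\alpha$ is often preferred because $1/n$ shrinks quickly and can make learning slow.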

### Exploration vs Exploitation

Q-learning requires balancing exploration and exploitation. Common strategies:

- **ε-greedy**: Choose random action with probability $\epsilon$, otherwise choose $\arg\max_a Q(s,a)$
- **ε-decay**: Decrease $\epsilon$ over time to reduce exploration
- **Boltzmann exploration**: Choose actions probabilistically based on Q-values
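
Boltzmann exploration is not implemented in this notebook, so here is a minimal sketch of it under the usual softmax form $\pi(a|s) \propto \exp(Q(s,a)/\tau)$. The function name `boltzmann_action`, the temperature values, and the Q-values are assumptions for illustration.

```python
import numpy as np


def boltzmann_action(q_values, temperature, rng):
    """Sample an action with probability proportional to exp(Q / temperature)."""
    # Subtract the max before exponentiating for numerical stability
    logits = (q_values - np.max(q_values)) / temperature
    probs = np.exp(logits)
    probs /= probs.sum()
    return rng.choice(len(q_values), p=probs), probs


rng = np.random.default_rng(0)
q = np.array([0.1, 0.5, 0.2, 0.0])  # hypothetical Q-values for one state
for temperature in (10.0, 1.0, 0.05):
    _, probs = boltzmann_action(q, temperature, rng)
    print(f"T={temperature}: {np.round(probs, 3)}")
```

A high temperature gives a nearly uniform distribution (more exploration), while a low temperature concentrates probability on the greedy action.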

In [None]:
class QLearningAgent:
    """
    Q-Learning agent for model-free reinforcement learning.
    
    Learns optimal action-value function through temporal difference updates
    without requiring knowledge of environment dynamics.
    """
    
    def __init__(self, env: GridWorld, alpha: float = 0.1, gamma: float = 0.9, 
                 epsilon: float = 0.1, epsilon_decay: float = 0.995, 
                 epsilon_min: float = 0.01):
        self.env = env
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        
        self.states = env.get_all_states()
        self.num_actions = env.num_actions
        
        # Initialize Q-table
        self.Q = defaultdict(lambda: np.zeros(self.num_actions))
        
        # Track learning progress
        self.episode_rewards = []
        self.episode_lengths = []
        self.td_errors = []
        self.q_value_history = []
    
    def choose_action(self, state: Tuple[int, int], training: bool = True) -> int:
        """Choose action using ε-greedy policy."""
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        else:
            return np.argmax(self.Q[state])
    
    def update_q_value(self, state: Tuple[int, int], action: int, reward: float, 
                       next_state: Tuple[int, int], done: bool) -> float:
        """Update Q-value using Q-learning rule."""
        # Current Q-value
        current_q = self.Q[state][action]
        
        # Target Q-value
        if done:
            target_q = reward
        else:
            target_q = reward + self.gamma * np.max(self.Q[next_state])
        
        # TD error
        td_error = target_q - current_q
        
        # Q-learning update
        self.Q[state][action] = current_q + self.alpha * td_error
        
        return abs(td_error)
    
    def train(self, num_episodes: int = 1000, verbose: bool = True) -> None:
        """Train the Q-learning agent."""
        
        for episode in range(num_episodes):
            state = self.env.reset()
            episode_reward = 0
            episode_length = 0
            episode_td_errors = []
            
            while True:
                # Choose action
                action = self.choose_action(state, training=True)
                
                # Take action
                next_state, reward, done, _ = self.env.step(action)
                
                # Update Q-value
                td_error = self.update_q_value(state, action, reward, next_state, done)
                episode_td_errors.append(td_error)
                
                episode_reward += reward
                episode_length += 1
                
                if done:
                    break
                
                state = next_state
            
            # Store episode statistics
            self.episode_rewards.append(episode_reward)
            self.episode_lengths.append(episode_length)
            self.td_errors.extend(episode_td_errors)
            
            # Decay epsilon, never dropping below epsilon_min
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            
            # Track Q-values periodically
            if episode % 100 == 0:
                sample_q_values = {state: np.max(self.Q[state]) for state in self.states[:5]}
                self.q_value_history.append(sample_q_values)
                
                if verbose:
                    avg_reward = np.mean(self.episode_rewards[-100:])
                    avg_length = np.mean(self.episode_lengths[-100:])
                    print(f"Episode {episode}: Avg Reward = {avg_reward:.3f}, "
                          f"Avg Length = {avg_length:.1f}, ε = {self.epsilon:.3f}")
    
    def get_policy(self) -> Dict[Tuple[int, int], int]:
        """Extract greedy policy from Q-values."""
        policy = {}
        for state in self.states:
            if not self.env.is_terminal(state):
                policy[state] = np.argmax(self.Q[state])
        return policy
    
    def get_value_function(self) -> Dict[Tuple[int, int], float]:
        """Extract value function from Q-values: V(s) = max_a Q(s,a)."""
        return {state: np.max(self.Q[state]) for state in self.states}
    
    def plot_training_progress(self) -> None:
        """Plot training progress metrics."""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Episode rewards
        axes[0, 0].plot(self.episode_rewards, alpha=0.7)
        # Moving average
        window = min(100, len(self.episode_rewards) // 10)
        if window > 1:
            moving_avg = pd.Series(self.episode_rewards).rolling(window=window).mean()
            axes[0, 0].plot(moving_avg, 'r-', linewidth=2, label=f'{window}-episode average')
            axes[0, 0].legend()
        
        axes[0, 0].set_xlabel('Episode')
        axes[0, 0].set_ylabel('Episode Reward')
        axes[0, 0].set_title('Learning Progress: Episode Rewards')
        axes[0, 0].grid(True, alpha=0.3)
        
        # Episode lengths
        axes[0, 1].plot(self.episode_lengths, alpha=0.7)
        if window > 1:
            moving_avg = pd.Series(self.episode_lengths).rolling(window=window).mean()
            axes[0, 1].plot(moving_avg, 'r-', linewidth=2, label=f'{window}-episode average')
            axes[0, 1].legend()
        
        axes[0, 1].set_xlabel('Episode')
        axes[0, 1].set_ylabel('Episode Length')
        axes[0, 1].set_title('Learning Progress: Episode Lengths')
        axes[0, 1].grid(True, alpha=0.3)
        
        # TD errors
        if len(self.td_errors) > 100:
            td_window = min(1000, len(self.td_errors) // 20)
            moving_td_avg = pd.Series(self.td_errors).rolling(window=td_window).mean()
            axes[1, 0].plot(moving_td_avg)
        else:
            axes[1, 0].plot(self.td_errors)
        
        axes[1, 0].set_xlabel('Update Step')
        axes[1, 0].set_ylabel('Average TD Error')
        axes[1, 0].set_title('Learning Progress: TD Error')
        axes[1, 0].grid(True, alpha=0.3)
        
        # Q-value evolution
        if self.q_value_history:
            for state in list(self.q_value_history[0].keys())[:3]:
                values = [q_dict[state] for q_dict in self.q_value_history if state in q_dict]
                axes[1, 1].plot(values, label=f'State {state}')
            
            axes[1, 1].set_xlabel('Training Checkpoint')
            axes[1, 1].set_ylabel('Max Q-Value')
            axes[1, 1].set_title('Q-Value Evolution')
            axes[1, 1].legend()
            axes[1, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Train Q-Learning agent
print("Training Q-Learning agent...")
q_agent = QLearningAgent(env, alpha=0.1, gamma=0.9, epsilon=0.1, epsilon_decay=0.995)
q_agent.train(num_episodes=1000, verbose=True)

# Extract learned policy and values
q_policy = q_agent.get_policy()
q_values = q_agent.get_value_function()

print("\nQ-Learning training completed!")
print(f"Final exploration rate: {q_agent.epsilon:.4f}")
print(f"Average reward (last 100 episodes): {np.mean(q_agent.episode_rewards[-100:]):.4f}")

In [None]:
# Plot training progress
q_agent.plot_training_progress()

In [None]:
# Compare Q-Learning results with optimal solution
print("Q-Learning Learned Value Function:")
env.render(values=q_values)

print("\nQ-Learning Learned Policy:")
env.render(policy=q_policy)

# Quantitative comparison
print("\nComparison with Optimal Solution:")
value_diff = max(abs(optimal_values[s] - q_values[s]) for s in env.get_all_states())
print(f"Max value difference: {value_diff:.4f}")

policy_diff = sum(1 for s in env.get_all_states() 
                  if not env.is_terminal(s) and optimal_policy[s] != q_policy[s])
print(f"Policy differences: {policy_diff} states")

# Show Q-table for start state
start_state = env.start_state
print(f"\nQ-values for start state {start_state}:")
for action, q_val in enumerate(q_agent.Q[start_state]):
    print(f"  {env.action_names[action]}: {q_val:.4f}")

## 7. SARSA: On-Policy Temporal Difference Learning

**SARSA** (State-Action-Reward-State-Action) is an on-policy temporal difference learning algorithm that learns the Q-value for the policy being followed.

### Algorithm

SARSA uses the following update rule:

$$Q(s,a) \leftarrow Q(s,a) + \alpha [R + \gamma Q(s',a') - Q(s,a)]$$

Here $a'$ is the action actually taken in state $s'$ under the current behavior policy, not the greedy maximum that Q-learning uses.

### Key Differences from Q-Learning

1. **On-policy vs Off-policy**: 
   - SARSA learns about the policy being followed (including exploration)
   - Q-learning learns about the optimal policy regardless of behavior

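2. **Update target**:
   - SARSA: $R + \gamma Q(s',a')$ (uses the action actually taken next)
   - Q-learning: $R + \gamma \max_{a'} Q(s',a')$ (uses the greedy next action)

3. **Convergence** (tabular case, with suitably decaying step sizes):
   - SARSA converges to the optimal policy if every state-action pair keeps being visited and the policy becomes greedy in the limit (for example, $\epsilon \to 0$)
   - Q-learning converges to the optimal Q-function even with fixed exploration, as long as every state-action pair is visited infinitely often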
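
To make the difference in update targets concrete, here is a minimal sketch with made-up numbers. The Q-values, reward, and sampled next action are hypothetical and are not taken from the GridWorld used in this notebook.

```python
import numpy as np

# Hypothetical values chosen for illustration only
gamma = 0.9
reward = -0.1
q_next = np.array([0.5, 2.0, -1.0, 0.0])  # Q(s', .) for four actions
next_action = 2                            # exploratory action actually sampled in s'

# SARSA bootstraps from the action the behavior policy actually took
sarsa_target = reward + gamma * q_next[next_action]

# Q-learning bootstraps from the greedy action, whatever was actually taken
q_learning_target = reward + gamma * np.max(q_next)

print(f"SARSA target:      {sarsa_target:.2f}")
print(f"Q-learning target: {q_learning_target:.2f}")
```

With these numbers the SARSA target is -1.00 and the Q-learning target is 1.70: the exploratory step drags the SARSA estimate down, while Q-learning ignores it.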

### Expected SARSA

A variant that uses the expected value under the current policy:

$$Q(s,a) \leftarrow Q(s,a) + \alpha [R + \gamma \sum_{a'} \pi(a'|s') Q(s',a') - Q(s,a)]$$
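
This notebook does not implement Expected SARSA, so the following is only a sketch of how its target could be computed under an ε-greedy policy. The helper names `epsilon_greedy_probs` and `expected_sarsa_target` and the numeric values are assumptions made for illustration.

```python
import numpy as np

def epsilon_greedy_probs(q_row: np.ndarray, epsilon: float) -> np.ndarray:
    """Action probabilities of an epsilon-greedy policy for a single state."""
    n = len(q_row)
    probs = np.full(n, epsilon / n)                 # uniform exploration mass
    probs[int(np.argmax(q_row))] += 1.0 - epsilon   # remaining mass on the greedy action
    return probs

def expected_sarsa_target(reward: float, q_next: np.ndarray, gamma: float,
                          epsilon: float, done: bool = False) -> float:
    """R + gamma * sum_a' pi(a'|s') Q(s', a'), or just R on terminal transitions."""
    if done:
        return reward
    probs = epsilon_greedy_probs(q_next, epsilon)
    return reward + gamma * float(np.dot(probs, q_next))

# Hypothetical values for illustration
q_next = np.array([0.5, 2.0, -1.0, 0.0])
target = expected_sarsa_target(reward=-0.1, q_next=q_next, gamma=0.9, epsilon=0.1)
greedy_target = expected_sarsa_target(reward=-0.1, q_next=q_next, gamma=0.9, epsilon=0.0)

print(f"Expected SARSA target (epsilon=0.1): {target:.5f}")
print(f"Expected SARSA target (epsilon=0.0): {greedy_target:.5f}")
```

Averaging over the next action removes the sampling noise of plain SARSA, and with `epsilon=0.0` the target coincides with the Q-learning target.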

In [None]:
class SARSAAgent:
    """
    SARSA agent for on-policy temporal difference learning.
    
    Learns Q-values for the policy being followed, making it more
    conservative than Q-learning in stochastic environments.
    """
    
    def __init__(self, env: GridWorld, alpha: float = 0.1, gamma: float = 0.9, 
                 epsilon: float = 0.1, epsilon_decay: float = 0.995, 
                 epsilon_min: float = 0.01):
        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        
        self.states = env.get_all_states()
        self.num_actions = env.num_actions
        
        # Initialize Q-table
        self.Q = defaultdict(lambda: np.zeros(self.num_actions))
        
        # Track learning progress
        self.episode_rewards = []
        self.episode_lengths = []
        self.td_errors = []
    
    def choose_action(self, state: Tuple[int, int], training: bool = True) -> int:
        """Choose action using ε-greedy policy."""
        if training and np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        else:
            return np.argmax(self.Q[state])
    
    def update_q_value(self, state: Tuple[int, int], action: int, reward: float, 
                       next_state: Tuple[int, int], next_action: Optional[int], 
                       done: bool) -> float:
        """Update Q-value using the SARSA rule; next_action is None on terminal transitions."""
        # Current Q-value
        current_q = self.Q[state][action]
        
        # Target Q-value (key difference from Q-learning)
        if done:
            target_q = reward
        else:
            target_q = reward + self.gamma * self.Q[next_state][next_action]
        
        # TD error
        td_error = target_q - current_q
        
        # SARSA update
        self.Q[state][action] = current_q + self.alpha * td_error
        
        return abs(td_error)
    
    def train(self, num_episodes: int = 1000, verbose: bool = True) -> None:
        """Train the SARSA agent."""
        
        for episode in range(num_episodes):
            state = self.env.reset()
            action = self.choose_action(state, training=True)
            
            episode_reward = 0
            episode_length = 0
            episode_td_errors = []
            
            while True:
                # Take action
                next_state, reward, done, _ = self.env.step(action)
                
                # Choose next action (important for SARSA)
                if not done:
                    next_action = self.choose_action(next_state, training=True)
                else:
                    next_action = None
                
                # Update Q-value
                td_error = self.update_q_value(state, action, reward, next_state, next_action, done)
                episode_td_errors.append(td_error)
                
                episode_reward += reward
                episode_length += 1
                
                if done:
                    break
                
                # Move to next state-action pair
                state = next_state
                action = next_action
            
            # Store episode statistics
            self.episode_rewards.append(episode_reward)
            self.episode_lengths.append(episode_length)
            self.td_errors.extend(episode_td_errors)
            
            # Decay epsilon, clamped so it never drops below epsilon_min
            if self.epsilon > self.epsilon_min:
                self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            
            if episode % 100 == 0 and verbose:
                avg_reward = np.mean(self.episode_rewards[-100:])
                avg_length = np.mean(self.episode_lengths[-100:])
                print(f"Episode {episode}: Avg Reward = {avg_reward:.3f}, "
                      f"Avg Length = {avg_length:.1f}, ε = {self.epsilon:.3f}")
    
    def get_policy(self) -> Dict[Tuple[int, int], int]:
        """Extract greedy policy from Q-values."""
        policy = {}
        for state in self.states:
            if not self.env.is_terminal(state):
                policy[state] = np.argmax(self.Q[state])
        return policy
    
    def get_value_function(self) -> Dict[Tuple[int, int], float]:
        """Extract value function from Q-values."""
        return {state: np.max(self.Q[state]) for state in self.states}

# Train SARSA agent
print("Training SARSA agent...")
sarsa_agent = SARSAAgent(env, alpha=0.1, gamma=0.9, epsilon=0.1, epsilon_decay=0.995)
sarsa_agent.train(num_episodes=1000, verbose=True)

# Extract learned policy and values
sarsa_policy = sarsa_agent.get_policy()
sarsa_values = sarsa_agent.get_value_function()

print("\nSARSA training completed!")
print(f"Final exploration rate: {sarsa_agent.epsilon:.4f}")
print(f"Average reward (last 100 episodes): {np.mean(sarsa_agent.episode_rewards[-100:]):.4f}")

## 8. Algorithm Comparison and Analysis

Let's compare all the algorithms we've implemented to understand their strengths, weaknesses, and convergence properties.

In [None]:
# Comprehensive comparison of all algorithms
def compare_algorithms():
    """Compare all implemented algorithms."""
    
    algorithms = {
        'Value Iteration': {
            'values': optimal_values,
            'policy': optimal_policy,
            'type': 'Dynamic Programming',
            'model_free': False,
            'iterations': len(vi_solver.delta_history)
        },
        'Policy Iteration': {
            'values': pi_values,
            'policy': pi_policy,
            'type': 'Dynamic Programming', 
            'model_free': False,
            'iterations': len(pi_solver.policy_history)
        },
        'Q-Learning': {
            'values': q_values,
            'policy': q_policy,
            'type': 'Temporal Difference (Off-policy)',
            'model_free': True,
            'episodes': len(q_agent.episode_rewards)
        },
        'SARSA': {
            'values': sarsa_values,
            'policy': sarsa_policy,
            'type': 'Temporal Difference (On-policy)',
            'model_free': True,
            'episodes': len(sarsa_agent.episode_rewards)
        }
    }
    
    print("=== Algorithm Comparison ===")
    print(f"{'Algorithm':<18} {'Type':<34} {'Model-Free':<12} {'Iterations/Episodes':<20}")
    print("-" * 87)
    
    for name, info in algorithms.items():
        iters = info.get('iterations', info.get('episodes', 'N/A'))
        # str() keeps True/False readable; a bare bool with a width spec prints as 1/0
        print(f"{name:<18} {info['type']:<34} {str(info['model_free']):<12} {iters:<20}")
    
    # Value function comparison
    print("\n=== Value Function Comparison ===")
    baseline_values = optimal_values  # Use Value Iteration as baseline
    
    sample_states = [(0, 0), (1, 1), (2, 2), (3, 3)]
    sample_states = [s for s in sample_states if s in env.get_all_states() and not env.is_terminal(s)]
    
    print(f"\n{'State':<12}", end="")
    for name in algorithms.keys():
        print(f"{name:<18}", end="")
    print()
    print("-" * (12 + 18 * len(algorithms)))
    
    for state in sample_states:
        print(f"{str(state):<12}", end="")
        for name, info in algorithms.items():
            value = info['values'].get(state, 0.0)
            print(f"{value:<18.4f}", end="")
        print()
    
    # Policy agreement
    print("\n=== Policy Comparison ===")
    non_terminal_states = [s for s in env.get_all_states() if not env.is_terminal(s)]
    
    for i, (name1, info1) in enumerate(algorithms.items()):
        for j, (name2, info2) in enumerate(algorithms.items()):
            if i < j:
                agreements = sum(1 for s in non_terminal_states 
                               if info1['policy'][s] == info2['policy'][s])
                total = len(non_terminal_states)
                percentage = (agreements / total) * 100
                print(f"{name1} vs {name2}: {agreements}/{total} states agree ({percentage:.1f}%)")

compare_algorithms()

In [None]:
# Plot learning curves comparison
def plot_learning_comparison():
    """Plot learning curves for temporal difference methods."""
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Episode rewards
    window = 50
    q_rewards_smooth = pd.Series(q_agent.episode_rewards).rolling(window=window).mean()
    sarsa_rewards_smooth = pd.Series(sarsa_agent.episode_rewards).rolling(window=window).mean()
    
    axes[0, 0].plot(q_rewards_smooth, label='Q-Learning', alpha=0.8)
    axes[0, 0].plot(sarsa_rewards_smooth, label='SARSA', alpha=0.8)
    axes[0, 0].set_xlabel('Episode')
    axes[0, 0].set_ylabel('Average Episode Reward')
    axes[0, 0].set_title('Learning Progress: Episode Rewards')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Episode lengths
    q_lengths_smooth = pd.Series(q_agent.episode_lengths).rolling(window=window).mean()
    sarsa_lengths_smooth = pd.Series(sarsa_agent.episode_lengths).rolling(window=window).mean()
    
    axes[0, 1].plot(q_lengths_smooth, label='Q-Learning', alpha=0.8)
    axes[0, 1].plot(sarsa_lengths_smooth, label='SARSA', alpha=0.8)
    axes[0, 1].set_xlabel('Episode')
    axes[0, 1].set_ylabel('Average Episode Length')
    axes[0, 1].set_title('Learning Progress: Episode Lengths')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Value iteration convergence
    axes[1, 0].plot(vi_solver.delta_history)
    axes[1, 0].set_xlabel('Iteration')
    axes[1, 0].set_ylabel('Max Value Change')
    axes[1, 0].set_title('Value Iteration Convergence')
    axes[1, 0].set_yscale('log')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Final performance comparison
    algorithms = ['Value Iter.', 'Policy Iter.', 'Q-Learning', 'SARSA']
    final_rewards = [
        1.0,  # Placeholder: assumes the optimal policy earns a return of 1.0 (not measured)
        1.0,  # Placeholder: same assumption as for Value Iteration
        np.mean(q_agent.episode_rewards[-100:]),
        np.mean(sarsa_agent.episode_rewards[-100:])
    ]
    
    colors = ['blue', 'green', 'red', 'orange']
    bars = axes[1, 1].bar(algorithms, final_rewards, color=colors, alpha=0.7)
    axes[1, 1].set_ylabel('Average Reward (Last 100 Episodes)')
    axes[1, 1].set_title('Final Performance Comparison')
    axes[1, 1].set_ylim(min(final_rewards) - 0.1, max(final_rewards) + 0.1)
    
    # Add value labels on bars
    for bar, reward in zip(bars, final_rewards):
        axes[1, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                       f'{reward:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()

plot_learning_comparison()

## 9. Key Takeaways and Insights

From our implementation and comparison of tabular RL methods, we can draw several important insights:

### Mathematical Foundations
1. **Bellman Equations** are the cornerstone of RL, providing both the theoretical foundation and practical algorithms
2. **Contraction Mapping** properties ensure convergence under appropriate conditions
3. **Optimal Value Functions** satisfy the Bellman optimality equations
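
The contraction property can be checked numerically. The sketch below builds a small random MDP (purely hypothetical, unrelated to the GridWorld) and verifies that one application of the Bellman optimality operator shrinks the sup-norm distance between two value vectors by at least a factor of $\gamma$. For brevity it assumes an expected reward $R(s,a)$ rather than $R(s,a,s')$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_actions, gamma = 5, 3, 0.9

# Random MDP (hypothetical): P[s, a, :] is a distribution over next states
P = rng.random((n_states, n_actions, n_states))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((n_states, n_actions))  # expected reward R(s, a)

def bellman_optimality(V: np.ndarray) -> np.ndarray:
    """(TV)(s) = max_a [ R(s,a) + gamma * sum_s' P(s'|s,a) V(s') ]."""
    return np.max(R + gamma * (P @ V), axis=1)

# Two arbitrary value vectors
V = rng.normal(size=n_states)
U = rng.normal(size=n_states)

before = np.max(np.abs(V - U))
after = np.max(np.abs(bellman_optimality(V) - bellman_optimality(U)))

print(f"||V - U||_inf   = {before:.4f}")
print(f"||TV - TU||_inf = {after:.4f}  (bound: gamma * before = {gamma * before:.4f})")
```

Because the distance shrinks on every application, repeated sweeps drive any starting vector toward the same fixed point, which is why value iteration converges.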

### Algorithm Characteristics

**Dynamic Programming (Value/Policy Iteration)**:
- Requires complete model knowledge
- Guaranteed optimal solutions
- Fast convergence (few iterations)
- Computationally expensive per iteration

**Q-Learning**:
- Model-free and off-policy
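- Learns the optimal policy regardless of the behavior policy
- Tolerates more aggressive exploration, since the update target ignores exploratory actions
- Converges to the optimal Q-function in the tabular case, given enough visits to every state-action pair and decaying step sizes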

**SARSA**:
- Model-free and on-policy
- Learns about the policy being followed
- More conservative in risky environments
- Often preferred when exploratory mistakes during learning are costly

### Practical Considerations
- **Exploration vs Exploitation**: Critical for TD methods
- **Learning Rate**: Affects convergence speed and stability
- **Discount Factor**: Controls planning horizon
- **Environment Complexity**: Affects sample complexity
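
A quick back-of-the-envelope calculation shows what these hyperparameters imply. The values mirror the settings passed to the agents in this notebook (`epsilon=0.1`, `epsilon_decay=0.995`, default `epsilon_min=0.01`, `gamma=0.9`). The "effective horizon" $1/(1-\gamma)$ is a common rule of thumb, not an exact quantity.

```python
import math

epsilon, epsilon_min, epsilon_decay = 0.1, 0.01, 0.995
gamma = 0.9

# Episodes of multiplicative decay until epsilon first reaches epsilon_min
episodes_to_floor = math.ceil(math.log(epsilon_min / epsilon) / math.log(epsilon_decay))

# Rough planning horizon implied by the discount factor
effective_horizon = 1.0 / (1.0 - gamma)

# Number of steps after which a reward's weight has dropped to one half
half_life = math.log(0.5) / math.log(gamma)

print(f"Episodes until epsilon reaches its floor: {episodes_to_floor}")
print(f"Effective horizon for gamma={gamma}: about {effective_horizon:.0f} steps")
print(f"Discount half-life: about {half_life:.1f} steps")
```

With these settings exploration bottoms out after 460 episodes, so in a 1000-episode run more than half of training happens at the minimum exploration rate.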

In [None]:
# Test learned policies by running episodes
def test_policy(agent, num_episodes: int = 10, policy_name: str = "Agent") -> None:
    """Test a learned policy by running episodes."""
    
    episode_rewards = []
    episode_lengths = []
    
    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        episode_length = 0
        
        print(f"\n{policy_name} Episode {episode + 1}: ", end="")
        
        while episode_length < 50:  # Prevent infinite loops
            if hasattr(agent, 'choose_action'):
                action = agent.choose_action(state, training=False)  # No exploration
            else:
                # For dictionaries (optimal policies from DP)
                action = agent[state]
            
            next_state, reward, done, _ = env.step(action)
            episode_reward += reward
            episode_length += 1
            
            print(f"{env.action_names[action][0]}", end="")
            
            if done:
                print(f" -> Goal! Reward: {episode_reward:.3f}, Length: {episode_length}")
                break
            
            state = next_state
        
        if not done:
            print(f" -> Timeout. Reward: {episode_reward:.3f}, Length: {episode_length}")
        
        episode_rewards.append(episode_reward)
        episode_lengths.append(episode_length)
    
    print(f"\n{policy_name} Average Performance:")
    print(f"  Average Reward: {np.mean(episode_rewards):.4f} ± {np.std(episode_rewards):.4f}")
    print(f"  Average Length: {np.mean(episode_lengths):.2f} ± {np.std(episode_lengths):.2f}")
    print(f"  Success Rate: {sum(1 for r in episode_rewards if r > 0.5) / len(episode_rewards) * 100:.1f}%")

print("Testing learned policies...")
print("=" * 50)

# Test optimal policy
test_policy(optimal_policy, num_episodes=5, policy_name="Optimal (Value Iteration)")

# Test Q-learning policy
test_policy(q_agent, num_episodes=5, policy_name="Q-Learning")

# Test SARSA policy  
test_policy(sarsa_agent, num_episodes=5, policy_name="SARSA")

## Summary

In this notebook, we've covered the fundamental concepts and algorithms of reinforcement learning:

### **Mathematical Framework**
- **Markov Decision Processes**: The mathematical foundation of sequential decision making
- **Bellman Equations**: Recursive relationships that define optimal value functions
- **Value Functions**: Tools for evaluating states and actions

### **Dynamic Programming**
- **Value Iteration**: Direct application of Bellman optimality equations
- **Policy Iteration**: Alternating policy evaluation and improvement
- **Convergence Guarantees**: Theoretical foundation for algorithm correctness

### **Temporal Difference Learning**
- **Q-Learning**: Off-policy learning of optimal action-value function
- **SARSA**: On-policy learning with exploration considerations
- **Model-Free Learning**: Learning without environment dynamics knowledge

### **Key Insights**
- **Exploration vs Exploitation**: Fundamental trade-off in RL
- **On-Policy vs Off-Policy**: Different learning paradigms with distinct properties
- **Sample Complexity**: Model-free methods require more interaction data
- **Convergence Properties**: Different algorithms have different theoretical guarantees

These tabular methods form the foundation for understanding more advanced RL algorithms. In the next notebook, we'll explore Monte Carlo methods and more sophisticated temporal difference techniques that will bridge us toward function approximation and deep reinforcement learning.