In [None]:
import numpy as np
from scipy.stats import poisson
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Poisson probability tables, computed once and reused by every Bellman backup.
# Only outcomes 0 .. num_req_ret - 1 are modelled; the tail mass beyond that is
# dropped, so each table sums to slightly less than 1.
num_req_ret = 11
_outcomes = np.arange(num_req_ret)
req1_probs = poisson.pmf(_outcomes, 3)  # rental requests at location 1 (mean 3)
ret1_probs = poisson.pmf(_outcomes, 3)  # returns at location 1 (mean 3)
req2_probs = poisson.pmf(_outcomes, 4)  # rental requests at location 2 (mean 4)
ret2_probs = poisson.pmf(_outcomes, 2)  # returns at location 2 (mean 2)

# State grid: 0..20 cars at each location.
_grid_shape = (21, 21)
V = np.zeros(_grid_shape, dtype=np.float64)      # state values, indexed V[loc1, loc2]
policy = np.zeros(_grid_shape, dtype=np.int32)   # net cars moved loc1 -> loc2 per state

# Snapshots of the policy after each improvement step (filled by policy iteration).
policy_history = []

In [None]:
def compute_state_value_vectorized(loc1, loc2, action, gamma=0.9):
    """Expected one-step return of `action` in state (loc1, loc2), bootstrapped from the global V.

    Sums p * (reward + gamma * V[next_state]) over every combination of
    (requests, returns) at both locations, using the precomputed truncated
    Poisson tables req1_probs / ret1_probs / req2_probs / ret2_probs
    (outcomes 0 .. num_req_ret - 1).

    Args:
        loc1: cars at location 1 at the end of the day (state component).
        loc2: cars at location 2 at the end of the day (state component).
        action: net cars moved overnight from location 1 to location 2
            (negative values move cars from location 2 to location 1).
        gamma: discount factor.

    Returns:
        The expected value, or 0.0 if the action would move more cars out of
        a location than it holds.
    """
    # Location capacity is implied by the value table: V has one row/column per car count.
    max_cars1 = V.shape[0] - 1
    max_cars2 = V.shape[1] - 1

    net_move = action
    loc1_today = loc1 - net_move
    loc2_today = loc2 + net_move

    # Infeasible move: a location cannot give away cars it does not have.
    if loc1_today < 0 or loc2_today < 0:
        return 0.0

    # Cars in excess of capacity after the move are returned to the company and
    # disappear from the problem; without this cap a location could start the
    # day with more cars than it can hold.
    loc1_today = min(loc1_today, max_cars1)
    loc2_today = min(loc2_today, max_cars2)

    # Every (rq1, re1, rq2, re2) outcome as a 4-D grid.
    counts = np.arange(num_req_ret)
    rq1, re1, rq2, re2 = np.meshgrid(counts, counts, counts, counts, indexing='ij')

    # Joint probability of each outcome (the four counts are treated as independent).
    probs = (req1_probs[rq1] * ret1_probs[re1] *
             req2_probs[rq2] * ret2_probs[re2])

    # Requests beyond the available cars are lost business.
    actual_rentals_1 = np.minimum(rq1, loc1_today)
    actual_rentals_2 = np.minimum(rq2, loc2_today)

    # 10 per car rented, minus 2 per car moved.
    rewards = -2 * abs(net_move) + 10 * actual_rentals_1 + 10 * actual_rentals_2

    # Returned cars are available the next day; inventory is capped at capacity.
    new_loc1 = np.clip(loc1_today - actual_rentals_1 + re1, 0, max_cars1)
    new_loc2 = np.clip(loc2_today - actual_rentals_2 + re2, 0, max_cars2)

    next_values = V[new_loc1, new_loc2]

    return np.sum(probs * (rewards + gamma * next_values))

In [None]:
def run_policy_evaluation_vectorized(theta=0.01, gamma=0.9, max_iterations=100, verbose=False):
    """Evaluate the current global `policy`, updating the global V in place of the old array.

    Performs synchronous sweeps: every state's new value is computed from the
    previous sweep's V (via compute_state_value_vectorized), then V is replaced
    wholesale.

    Args:
        theta: stop once the largest per-state change in a sweep is below this.
        gamma: discount factor.
        max_iterations: maximum number of sweeps.
        verbose: print the per-sweep delta and the convergence message.

    Returns:
        The updated value array (also bound to the global V).

    Warns:
        RuntimeWarning if max_iterations sweeps finish without delta < theta,
        so a truncated evaluation is not mistaken for a converged one.
    """
    import warnings

    global V

    n_loc1, n_loc2 = V.shape
    delta = np.inf

    for iteration in range(max_iterations):
        V_new = np.zeros_like(V)

        # The per-state expectation is vectorized; the loop over states stays in Python.
        for loc1 in range(n_loc1):
            for loc2 in range(n_loc2):
                action = policy[loc1, loc2]
                V_new[loc1, loc2] = compute_state_value_vectorized(loc1, loc2, action, gamma)

        delta = np.max(np.abs(V - V_new))
        if verbose:
            print(f"Iteration {iteration + 1}: delta = {delta:.6f}")

        V = V_new

        if delta < theta:
            if verbose:
                print(f"Converged after {iteration + 1} iterations!")
            break
    else:
        warnings.warn(
            f"Policy evaluation did not converge within {max_iterations} sweeps "
            f"(last delta = {delta:.6f}, theta = {theta}).",
            RuntimeWarning,
            stacklevel=2,
        )

    return V

In [None]:
def compute_action_values_vectorized(loc1, loc2, gamma=0.9, max_move=5):
    """Return the greedy action for state (loc1, loc2) under the global V.

    Despite the name, only the best action is returned, not the Q values.
    Q(s, a) is evaluated with compute_state_value_vectorized for every
    a in [-max_move, max_move] that the state can afford. Infeasible moves keep
    a value of -inf and are never selected. Ties go to the lowest (most
    negative) action, because np.argmax returns the first maximum.

    Args:
        loc1: cars at location 1.
        loc2: cars at location 2.
        gamma: discount factor.
        max_move: largest number of cars that may be moved in one night.

    Returns:
        The selected action (net cars moved from location 1 to location 2).
    """
    actions = np.arange(-max_move, max_move + 1)
    action_values = np.full(actions.shape, -np.inf)

    for i, action in enumerate(actions):
        # Skip moves that would take more cars than a location holds.
        if loc1 - action < 0 or loc2 + action < 0:
            continue
        action_values[i] = compute_state_value_vectorized(loc1, loc2, action, gamma)

    return actions[np.argmax(action_values)]

def run_policy_improvement_vectorized(gamma=0.9, tie_tol=1e-9):
    """Greedy policy improvement over the current global V; rebinds the global `policy`.

    For each state the greedy action comes from compute_action_values_vectorized.
    The current action is kept unless the greedy action is better by more than
    `tie_tol`. Otherwise, exact or floating-point ties could make the policy
    switch forever between equally good actions and never be reported stable.

    Args:
        gamma: discount factor.
        tie_tol: minimum improvement in Q required to replace the current action.

    Returns:
        True if no state's action changed (the policy is stable).
    """
    global policy

    policy_new = policy.copy()
    n_loc1, n_loc2 = policy.shape

    for loc1 in range(n_loc1):
        for loc2 in range(n_loc2):
            current = policy[loc1, loc2]
            greedy = compute_action_values_vectorized(loc1, loc2, gamma)
            if greedy == current:
                continue

            # Only keep the current action if it is feasible and (near-)tied with the greedy one.
            current_feasible = loc1 - current >= 0 and loc2 + current >= 0
            if current_feasible:
                q_greedy = compute_state_value_vectorized(loc1, loc2, greedy, gamma)
                q_current = compute_state_value_vectorized(loc1, loc2, current, gamma)
                if q_greedy <= q_current + tie_tol:
                    continue

            policy_new[loc1, loc2] = greedy

    # Check stability
    policy_changes = np.sum(policy_new != policy)
    print(f"Policy changes: {policy_changes}")

    policy = policy_new

    return policy_changes == 0

In [None]:
def run_policy_iteration_vectorized(theta=0.01, gamma=0.9, max_iterations=100,
                                   eval_max_iterations=100, reset=True):
    """Alternate policy evaluation and greedy improvement until the policy is stable.

    Rebinds the globals V, policy and policy_history. policy_history receives a
    copy of the starting policy plus a copy after every improvement step.

    Args:
        theta: convergence threshold passed to policy evaluation.
        gamma: discount factor.
        max_iterations: maximum number of evaluation/improvement rounds.
        eval_max_iterations: maximum sweeps per policy evaluation.
        reset: start from V = 0 and the all-zero ("move nothing") policy.
            Without this, re-running the cell would continue from the previous
            run's state, and the recorded "initial" policy would be the old result.

    Returns:
        (V, policy) after the final round.
    """
    global V, policy, policy_history

    if reset:
        V = np.zeros_like(V)
        policy = np.zeros_like(policy)

    # Save initial policy
    policy_history = [policy.copy()]

    for iteration in range(max_iterations):
        print(f"\n=== Policy Iteration {iteration + 1} ===")

        print("Running policy evaluation...")
        run_policy_evaluation_vectorized(theta, gamma, max_iterations=eval_max_iterations)

        print("Running policy improvement...")
        is_stable = run_policy_improvement_vectorized(gamma)

        # `policy` was rebound by the improvement step; snapshot the new one.
        policy_history.append(policy.copy())

        if is_stable:
            print("Policy is stable! Converged.")
            break
    else:
        print(f"Policy iteration stopped after {max_iterations} iterations without a stable policy.")

    return V, policy

In [None]:
def plot_policy_progression():
    """Plot every policy in the global policy_history, one panel per snapshot.

    All panels share one discrete color scale, with one band per integer action
    and symmetric about 0. The same color therefore means the same action in
    every panel, and "move no cars" is always the colormap's midpoint.
    Location 2 is on the x-axis and location 1 on the y-axis. The figure is
    saved to 'policy_progression.png' and shown.
    """
    n_policies = len(policy_history)
    if n_policies == 0:
        print("No policies recorded; run policy iteration first.")
        return

    # Grid layout: at most 3 columns.
    n_cols = min(3, n_policies)
    n_rows = -(-n_policies // n_cols)  # ceiling division

    # squeeze=False always yields a 2-D array of Axes, even for a single panel.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows), squeeze=False)
    axes = axes.ravel()

    # Shared, symmetric, integer-centred levels for every panel.
    max_abs = max(1, int(np.abs(np.asarray(policy_history)).max()))
    levels = np.arange(-max_abs - 0.5, max_abs + 1.5)

    for idx, pol in enumerate(policy_history):
        ax = axes[idx]
        n_loc1, n_loc2 = pol.shape

        # contourf(x, y, Z) draws Z[i, j] at (x[j], y[i]):
        # columns (location 2) on x, rows (location 1) on y.
        im = ax.contourf(range(n_loc2), range(n_loc1), pol, levels=levels, cmap='RdYlGn')
        ax.contour(range(n_loc2), range(n_loc1), pol, levels=levels,
                   colors='black', alpha=0.3, linewidths=0.5)

        ax.set_xlabel('# Cars at Location 2', fontsize=10)
        ax.set_ylabel('# Cars at Location 1', fontsize=10)

        if idx == 0:
            ax.set_title('Initial Policy (π₀)', fontsize=12, fontweight='bold')
        elif idx == n_policies - 1:
            ax.set_title(f'Final Policy (π*) - Iteration {idx}', fontsize=12, fontweight='bold')
        else:
            ax.set_title(f'Policy after Iteration {idx}', fontsize=12, fontweight='bold')

        cbar = plt.colorbar(im, ax=ax)
        cbar.set_ticks(np.arange(-max_abs, max_abs + 1))
        cbar.set_label('Action (cars moved)', rotation=270, labelpad=15)

        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for ax in axes[n_policies:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('policy_progression.png', dpi=150, bbox_inches='tight')
    print("\nSaved policy progression to 'policy_progression.png'")
    plt.show()

def plot_optimal_value_function():
    """Plot the global state-value array V as a 3D surface and as an annotated heatmap.

    Both panels put location 2 on the x-axis and location 1 on the y-axis, so
    V[loc1, loc2] is drawn at (x=loc2, y=loc1). The figure is saved to
    'optimal_value_function.png' and shown.
    """
    fig = plt.figure(figsize=(16, 6))
    n_loc1, n_loc2 = V.shape

    # 3D Surface Plot
    ax1 = fig.add_subplot(121, projection='3d')

    # meshgrid's default 'xy' indexing gives X[i, j] = j and Y[i, j] = i.
    # Z must therefore be V itself (not V.T) for x=loc2, y=loc1 to match the
    # axis labels and the heatmap.
    X, Y = np.meshgrid(np.arange(n_loc2), np.arange(n_loc1))

    surf = ax1.plot_surface(X, Y, V, cmap='viridis',
                            edgecolor='none', alpha=0.9)

    ax1.set_xlabel('# Cars at Location 2', fontsize=11)
    ax1.set_ylabel('# Cars at Location 1', fontsize=11)
    ax1.set_zlabel('State Value', fontsize=11)
    ax1.set_title('Optimal State-Value Function V*(s)', fontsize=13, fontweight='bold')

    fig.colorbar(surf, ax=ax1, shrink=0.5, aspect=5)
    ax1.view_init(elev=25, azim=45)

    # 2D Heatmap: imshow puts row i (location 1) on y and column j (location 2) on x;
    # origin='lower' places (0, 0) at the bottom-left.
    ax2 = fig.add_subplot(122)
    im = ax2.imshow(V, cmap='viridis', origin='lower', aspect='auto')

    ax2.set_xlabel('# Cars at Location 2', fontsize=11)
    ax2.set_ylabel('# Cars at Location 1', fontsize=11)
    ax2.set_title('Optimal State-Value Function V*(s) - Heatmap', fontsize=13, fontweight='bold')

    cbar = plt.colorbar(im, ax=ax2)
    cbar.set_label('State Value', rotation=270, labelpad=20)

    # Annotate every 3rd cell in each direction to avoid clutter.
    for i in range(0, n_loc1, 3):      # i = location 1 (y-axis)
        for j in range(0, n_loc2, 3):  # j = location 2 (x-axis)
            ax2.text(j, i, f'{V[i, j]:.0f}',
                     ha="center", va="center", color="white", fontsize=7)

    ax2.set_xticks(range(0, n_loc2, 2))
    ax2.set_yticks(range(0, n_loc1, 2))
    ax2.grid(True, alpha=0.3, color='white', linewidth=0.5)

    plt.tight_layout()
    plt.savefig('optimal_value_function.png', dpi=150, bbox_inches='tight')
    print("Saved optimal value function to 'optimal_value_function.png'")
    plt.show()

def plot_optimal_policy_detailed():
    """Plot the global `policy` as a filled contour map annotated with the action values.

    No arrows are drawn. Colors use a discrete scale, symmetric about 0, with
    one band per integer action, so the diverging colormap's midpoint always
    means "move no cars". Black contour lines are labelled with their level,
    and every second cell (in both directions) with a non-zero action is
    annotated with that action. The figure is saved to
    'optimal_policy_detailed.png' and shown.
    """
    fig, ax = plt.subplots(figsize=(12, 10))

    n_loc1, n_loc2 = policy.shape
    xs, ys = range(n_loc2), range(n_loc1)

    # Symmetric integer-centred levels so 0 maps to the colormap's neutral colour.
    max_abs = max(1, int(np.abs(policy).max()))
    levels = np.arange(-max_abs - 0.5, max_abs + 1.5)

    # X-axis: location 2 (columns of policy), Y-axis: location 1 (rows of policy).
    im = ax.contourf(xs, ys, policy, levels=levels, cmap='RdYlGn')
    contours = ax.contour(xs, ys, policy, levels=11,
                          colors='black', alpha=0.4, linewidths=0.5)
    ax.clabel(contours, inline=True, fontsize=8)

    ax.set_xlabel('# Cars at Location 2', fontsize=12)
    ax.set_ylabel('# Cars at Location 1', fontsize=12)
    ax.set_title('Optimal Policy π*(s) - Car Movement Strategy', fontsize=14, fontweight='bold')

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_ticks(np.arange(-max_abs, max_abs + 1))
    cbar.set_label('Action (# cars moved: positive = loc1→loc2, negative = loc2→loc1)',
                   rotation=270, labelpad=25)

    # Label every 2nd cell that moves cars; large moves get white text for contrast.
    for i in range(0, n_loc1, 2):      # i = location 1 (y-axis)
        for j in range(0, n_loc2, 2):  # j = location 2 (x-axis)
            action = policy[i, j]
            if action == 0:
                continue
            ax.text(j, i, f'{int(action)}',
                    ha="center", va="center",
                    color="black" if abs(action) < 3 else "white",
                    fontsize=8, fontweight='bold')

    ax.grid(True, alpha=0.3)
    ax.set_xticks(range(0, n_loc2, 2))
    ax.set_yticks(range(0, n_loc1, 2))

    plt.tight_layout()
    plt.savefig('optimal_policy_detailed.png', dpi=150, bbox_inches='tight')
    print("Saved optimal policy to 'optimal_policy_detailed.png'")
    plt.show()

In [None]:
# Solve the problem, report a few representative states, then draw all figures.
print("Starting vectorized policy iteration...\n")
V_optimal, policy_optimal = run_policy_iteration_vectorized(theta=0.01, gamma=0.9)

_rule = '=' * 60
print(f"\n{_rule}")
print("RESULTS:")
print(_rule)
for _cars1, _cars2 in ((10, 10), (5, 5), (15, 15)):
    print(f"Optimal policy at state ({_cars1}, {_cars2}): {policy_optimal[_cars1, _cars2]} cars")
    print(f"Optimal value at state ({_cars1}, {_cars2}): {V_optimal[_cars1, _cars2]:.2f}")
print(f"{_rule}\n")

print("Generating visualizations...\n")
for _make_plot in (plot_policy_progression,
                   plot_optimal_value_function,
                   plot_optimal_policy_detailed):
    _make_plot()

print("\nAll visualizations complete!")