# Stochastic Shortest Path with Dynamic Programming

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pedronahum/stochastic-optimization/blob/master/notebooks/ssp_dynamic.ipynb)

## Problem Overview

This notebook demonstrates stochastic shortest path optimization using JAX. The problem involves:
- **Random graph navigation** with uncertain edge costs
- **Multi-step lookahead** using backward induction
- **Running average cost estimation** - learning edge costs over time
- **Risk-sensitive decision making** using percentile-based policies

The agent must navigate from a starting node to a target node while:
- Minimizing total path cost
- Learning edge costs through experience
- Planning ahead with dynamic programming

---

## Mathematical Formulation

### State Space
The state at time $t$ is:
$$s_t = [\text{node}_t, t, \hat{C}, N]$$

where:
- $\text{node}_t \in \{0, 1, \ldots, n-1\}$ is the current node
- $t$ is the current time step
- $\hat{C} \in \mathbb{R}^{n \times n}$ is the matrix of estimated edge costs
- $N \in \mathbb{N}^{n \times n}$ is the observation count matrix (for running averages)

### Graph Structure
The graph is represented by:
- **Adjacency matrix** $A \in \{0,1\}^{n \times n}$: $A[i,j] = 1$ if edge $(i,j)$ exists
- **Mean costs** $\mu \in \mathbb{R}_+^{n \times n}$: Expected cost for each edge
- **Spreads** $\sigma \in [0,1)^{n \times n}$: Relative variability of edge costs
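
As a rough illustration of these three matrices, the plain-Python sketch below draws a random directed graph. The function name `make_random_graph` and the sampling scheme (independent Bernoulli edges, uniformly drawn means and spreads, no self-loops) are assumptions made for illustration; the repository's `SSPDynamicModel` may build its graph differently.

```python
import random

def make_random_graph(n_nodes, edge_prob, cost_min, cost_max, max_spread, seed=0):
    """Hypothetical generator returning (A, mu, sigma) as nested lists."""
    rng = random.Random(seed)
    A = [[0] * n_nodes for _ in range(n_nodes)]
    mu = [[0.0] * n_nodes for _ in range(n_nodes)]
    sigma = [[0.0] * n_nodes for _ in range(n_nodes)]
    for i in range(n_nodes):
        for j in range(n_nodes):
            # Assumption: no self-loops, each edge exists independently
            if i != j and rng.random() < edge_prob:
                A[i][j] = 1
                mu[i][j] = rng.uniform(cost_min, cost_max)
                sigma[i][j] = rng.uniform(0.0, max_spread)
    return A, mu, sigma

A, mu, sigma = make_random_graph(10, 0.3, 1.0, 10.0, 0.3, seed=42)
n_edges = sum(map(sum, A))
print(f"edges: {n_edges}, average out-degree: {n_edges / len(A):.2f}")
```

Entries of `mu` and `sigma` are only meaningful where `A[i][j] == 1`; the sketch leaves them at zero elsewhere.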

### Dynamics
When the agent moves from node $i$ to node $j$:

1. **Sample cost**: $c_{i,j} \sim \text{Uniform}[\mu_{i,j}(1-\sigma_{i,j}), \mu_{i,j}(1+\sigma_{i,j})]$
2. **Increment observations**: $N[i,j] \leftarrow N[i,j] + 1$
3. **Update estimate** (running average, using the incremented count so the weight $1/N[i,j]$ is always defined):
   $$\hat{C}[i,j] \leftarrow \left(1 - \frac{1}{N[i,j]}\right) \hat{C}[i,j] + \frac{1}{N[i,j]} c_{i,j}$$
4. **Move**: $\text{node}_{t+1} = j$, $t \leftarrow t + 1$
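
A minimal sketch of the cost sample and the estimate update for a single edge, in plain Python. It increments the count before averaging, so the first observation fully replaces the initial estimate and the weight is never `1/0`. This ordering, the zero initial estimate, and the function names are assumptions for illustration, not necessarily what the repository does.

```python
import random

def sample_cost(mu_ij, sigma_ij, rng):
    """Draw c ~ Uniform[mu * (1 - sigma), mu * (1 + sigma)]."""
    return rng.uniform(mu_ij * (1.0 - sigma_ij), mu_ij * (1.0 + sigma_ij))

def update_estimate(c_hat_ij, n_ij, cost):
    """Running average: bump the count, then blend in the new observation."""
    n_new = n_ij + 1
    alpha = 1.0 / n_new
    return (1.0 - alpha) * c_hat_ij + alpha * cost, n_new

rng = random.Random(0)
c_hat, n = 0.0, 0  # assumed initial estimate and count for one edge
observations = [sample_cost(5.0, 0.2, rng) for _ in range(1000)]
for c in observations:
    c_hat, n = update_estimate(c_hat, n, c)

print(f"estimate after {n} observations: {c_hat:.3f}")
```

With this weighting the estimate equals the plain arithmetic mean of all costs observed on that edge.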

### Reward Function
The single-step reward is the negative edge cost:
$$R(s_t, a_t, w_t) = -c_{i,j}$$

where $a_t = j$ is the decision to move to node $j$.

### Policy: Lookahead with Backward Induction

The **LookaheadPolicy** uses dynamic programming over horizon $H$:

1. **Value function initialization** (terminal condition):
   $$V_H[k] = \begin{cases} 0 & \text{if } k = \text{target} \\ \infty & \text{otherwise} \end{cases}$$

2. **Backward induction** for $t = H-1, H-2, \ldots, 0$:
   $$V_t[i] = \min_{j : A[i,j]=1} \left\{ \tilde{C}_{\theta}[i,j] + V_{t+1}[j] \right\}$$

3. **Risk-adjusted costs** (percentile-based):
   $$\tilde{C}_{\theta}[i,j] = \hat{C}[i,j] \cdot \left[(1 - \sigma[i,j]) + 2\sigma[i,j] \cdot \theta\right]$$
   
   - $\theta = 0.0$: Optimistic (uses $(1-\sigma)\hat{C}$, the lowest cost in the range, i.e. the best case)
   - $\theta = 0.5$: Risk-neutral (uses $\hat{C}$, the estimated mean cost)
   - $\theta = 1.0$: Pessimistic (uses $(1+\sigma)\hat{C}$, the highest cost in the range, i.e. the worst case)

4. **Decision**:
   $$a^*_t = \arg\min_{j : A[i,j]=1} \left\{ \tilde{C}_{\theta}[i,j] + V_1[j] \right\}$$
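
The four steps above can be sketched in plain Python as follows. This is an illustrative sketch, not the repository's `LookaheadPolicy`: the function names are hypothetical, matrices are nested lists, and the target is assumed to be absorbing with value 0 at every stage (without that assumption only paths of exactly $H$ edges would have finite value). The final argmin adds the edge cost to the successor's value one stage later.

```python
import math

def risk_adjusted(c_hat, sigma, theta):
    """C_theta[i][j] = c_hat[i][j] * ((1 - sigma[i][j]) + 2 * sigma[i][j] * theta)."""
    n = len(c_hat)
    return [[c_hat[i][j] * ((1.0 - sigma[i][j]) + 2.0 * sigma[i][j] * theta)
             for j in range(n)] for i in range(n)]

def lookahead_decision(A, cost, node, target, horizon):
    """Backward induction over `horizon` stages, then the best move from `node`."""
    n = len(A)
    # Terminal condition V_H: 0 at the target, infinity elsewhere
    V = [0.0 if k == target else math.inf for k in range(n)]
    # Each pass computes V_t from V_{t+1}; after horizon - 1 passes V holds V_1
    for _ in range(horizon - 1):
        V_next = V
        V = [0.0 if i == target else  # assumption: target is absorbing
             min((cost[i][j] + V_next[j] for j in range(n) if A[i][j]),
                 default=math.inf)
             for i in range(n)]
    # Decision: edge cost plus value-to-go of the successor
    best_j, best_val = None, math.inf
    for j in range(n):
        if A[node][j] and cost[node][j] + V[j] < best_val:
            best_j, best_val = j, cost[node][j] + V[j]
    return best_j, best_val

# Toy graph: 0 -> 1 -> 3 costs 1 + 10, while 0 -> 2 -> 3 costs 4 + 2
A = [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]]
c_hat = [[0.0, 1.0, 4.0, 0.0], [0.0, 0.0, 0.0, 10.0],
         [0.0, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0]]
sigma = [[0.2] * 4 for _ in range(4)]

cost = risk_adjusted(c_hat, sigma, theta=0.5)
move, value = lookahead_decision(A, cost, node=0, target=3, horizon=3)
print(f"move to node {move}, estimated cost-to-go {value:.2f}")
```

In the toy graph a one-step rule would take the cheap edge to node 1, whereas the lookahead prefers node 2 because the total estimated cost to the target is lower.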

### Objective
Minimize expected total cost to reach target:
$$\min_{\pi} \mathbb{E}\left[\sum_{t=0}^{T-1} c_{\text{node}_t, a_t}\right]$$

where $T$ is the (random) time when the target is reached.
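
For a fixed path the expectation is easy to check: each cost is uniform on an interval symmetric about its mean, so the expected total is the sum of the means along the path. The sketch below verifies this by Monte Carlo for an illustrative three-edge path; the `(mu, sigma)` values and the function name are made up for the example. Under an adaptive policy the path itself is random, and the expectation would instead be estimated by averaging over simulated episodes.

```python
import random

def simulate_path_cost(path_edges, rng):
    """Sum one sampled cost per edge; each edge is a (mu, sigma) pair."""
    return sum(rng.uniform(mu * (1.0 - s), mu * (1.0 + s)) for mu, s in path_edges)

rng = random.Random(0)
path_edges = [(1.0, 0.2), (4.0, 0.2), (2.0, 0.2)]  # illustrative (mu, sigma) values
n_samples = 20_000

estimate = sum(simulate_path_cost(path_edges, rng) for _ in range(n_samples)) / n_samples
expected = sum(mu for mu, _ in path_edges)

print(f"Monte Carlo estimate: {estimate:.3f}, sum of means: {expected:.1f}")
```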

---

## Setup and Installation

First, let's install the required packages and clone the repository.

In [None]:
# Install JAX and dependencies
!pip install -q jax jaxlib jaxtyping chex numpy matplotlib networkx

# Clone repository (only needed for Colab)
import os
if 'COLAB_GPU' in os.environ or not os.path.exists('problems'):
    !git clone https://github.com/pedronahum/stochastic-optimization.git
    os.chdir('stochastic-optimization')

# Clear Python import cache to ensure latest code is loaded
import sys
for key in list(sys.modules.keys()):
    if key.startswith('problems'):
        del sys.modules[key]

print("✓ Setup complete!")

## Imports

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
from typing import List, Dict

# Import problem components
from problems.ssp_dynamic import (
    SSPDynamicConfig,
    SSPDynamicModel,
    ExogenousInfo,
    LookaheadPolicy,
    GreedyLookaheadPolicy,
    RandomPolicy,
)

print("✓ Imports successful")
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")

## Problem Configuration

Let's set up the stochastic shortest path parameters.

In [None]:
# Create configuration
config = SSPDynamicConfig(
    n_nodes=10,          # 10 nodes in the graph
    horizon=15,          # 15-step lookahead for dynamic programming
    edge_prob=0.3,       # 30% probability of edge existing
    cost_min=1.0,        # Minimum edge cost
    cost_max=10.0,       # Maximum edge cost
    max_spread=0.3,      # Maximum 30% spread around mean cost
    seed=42
)

print("Configuration:")
print(f"  Nodes: {config.n_nodes}")
print(f"  Lookahead horizon: {config.horizon}")
print(f"  Edge probability: {config.edge_prob:.1%}")
print(f"  Cost range: [{config.cost_min}, {config.cost_max}]")
print(f"  Max spread: ±{config.max_spread:.1%}")
print(f"  State size: {2 + 2 * config.n_nodes * config.n_nodes} (node + time + costs + counts)")

## Initialize Model and Policies

In [None]:
# Create model
model = SSPDynamicModel(config)

# Create policies
lookahead_policy = LookaheadPolicy(theta=0.5)  # Risk-neutral
greedy_policy = GreedyLookaheadPolicy()        # Single-step lookahead
random_policy = RandomPolicy()                 # Random baseline

print(f"✓ Model initialized")
print(f"  Target node: {model.target_node}")
print(f"  Total edges: {int(jnp.sum(model.adjacency))}")
print(f"  Average degree: {jnp.sum(model.adjacency) / config.n_nodes:.1f}")
print(f"\n✓ Policies created:")
print(f"  - LookaheadPolicy (θ=0.5, risk-neutral, {config.horizon}-step backward induction)")
print(f"  - GreedyLookaheadPolicy (single-step, uses estimated costs)")
print(f"  - RandomPolicy (baseline)")

## Visualize Graph Structure

Let's visualize the directed graph with edge costs.

In [None]:
def visualize_graph(model, show_costs=True):
    """Visualize the graph structure with NetworkX."""
    # Create directed graph
    G = nx.DiGraph()
    
    # Add nodes
    for i in range(model.config.n_nodes):
        G.add_node(i)
    
    # Add edges with mean costs
    edge_labels = {}
    for i in range(model.config.n_nodes):
        for j in range(model.config.n_nodes):
            if model.adjacency[i, j]:
                cost = float(model.mean_costs[i, j])
                spread = float(model.spreads[i, j])
                G.add_edge(i, j, weight=cost)
                if show_costs:
                    edge_labels[(i, j)] = f"{cost:.1f}±{spread*100:.0f}%"
    
    # Layout
    pos = nx.spring_layout(G, seed=42, k=2, iterations=50)
    
    # Draw
    fig, ax = plt.subplots(figsize=(14, 10))
    
    # Color nodes: start (green), target (red), others (lightblue)
    node_colors = ['lightblue'] * model.config.n_nodes
    node_colors[0] = 'lightgreen'
    node_colors[model.target_node] = 'lightcoral'
    
    # Draw nodes
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, 
                           node_size=800, ax=ax, alpha=0.9)
    
    # Draw edges
    nx.draw_networkx_edges(G, pos, edge_color='gray', 
                           arrows=True, arrowsize=20, 
                           arrowstyle='->', ax=ax, alpha=0.6,
                           connectionstyle='arc3,rad=0.1')
    
    # Draw labels
    nx.draw_networkx_labels(G, pos, font_size=14, font_weight='bold', ax=ax)
    
    # Draw edge labels (costs)
    if show_costs:
        nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=8, ax=ax)
    
    ax.set_title(f'Graph Structure (Node 0 → Node {model.target_node})', 
                 fontsize=16, fontweight='bold')
    ax.axis('off')
    
    # Legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='lightgreen', label='Start (Node 0)'),
        Patch(facecolor='lightcoral', label=f'Target (Node {model.target_node})'),
        Patch(facecolor='lightblue', label='Intermediate Nodes')
    ]
    ax.legend(handles=legend_elements, loc='upper left', fontsize=12)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nGraph Statistics:")
    print(f"  Total edges: {G.number_of_edges()}")
    print(f"  Average out-degree: {G.number_of_edges() / G.number_of_nodes():.2f}")
    print(f"  Is weakly connected: {nx.is_weakly_connected(G)}")

visualize_graph(model, show_costs=True)

## Run Simulation

Let's simulate path finding until the agent reaches the target node.

In [None]:
def run_episode(model, policy, max_steps=50, key=None, verbose=True):
    """Run a single episode until target is reached or max steps."""
    if key is None:
        key = jax.random.PRNGKey(42)
    
    # Initialize
    key, subkey = jax.random.split(key)
    state = model.init_state(subkey)
    
    # Track history
    history = {
        'path': [0],  # Start at node 0
        'costs': [],
        'cumulative_cost': [],
        'states': [state],
    }
    
    total_cost = 0.0
    
    # Run simulation
    for t in range(max_steps):
        # Check if reached target
        if model.is_terminal(state):
            if verbose:
                print(f"✓ Reached target node {model.target_node} in {t} steps!")
            break
        
        # Get decision
        key, subkey = jax.random.split(key)
        decision = policy(None, state, subkey, model)
        
        # Sample exogenous events
        key, subkey = jax.random.split(key)
        exog = model.sample_exogenous(subkey, state, t)
        
        # Compute reward (negative cost)
        reward = model.reward(state, decision, exog)
        cost = -float(reward)
        total_cost += cost
        
        # Record
        history['path'].append(int(decision))
        history['costs'].append(cost)
        history['cumulative_cost'].append(total_cost)
        
        # Transition
        state = model.transition(state, decision, exog)
        history['states'].append(state)
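    else:
        # Loop ran out of steps; the final transition may still have hit the target
        if verbose:
            if model.is_terminal(state):
                print(f"✓ Reached target node {model.target_node} in {max_steps} steps!")
            else:
                print(f"⚠ Did not reach target in {max_steps} steps")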
    
    if verbose:
        print(f"  Total cost: {total_cost:.2f}")
        print(f"  Path: {' → '.join(map(str, history['path']))}")
    
    return history

# Run simulation with LookaheadPolicy
print("Running simulation with LookaheadPolicy (θ=0.5)...\n")
key = jax.random.PRNGKey(42)
history = run_episode(model, lookahead_policy, max_steps=50, key=key)

## Visualize Path Taken

Let's visualize the path the agent took through the graph.

In [None]:
def visualize_path(model, history):
    """Visualize the path taken through the graph."""
    path = history['path']
    costs = history['costs']
    
    # Create directed graph
    G = nx.DiGraph()
    for i in range(model.config.n_nodes):
        G.add_node(i)
    
    # Add all edges
    for i in range(model.config.n_nodes):
        for j in range(model.config.n_nodes):
            if model.adjacency[i, j]:
                G.add_edge(i, j)
    
    # Layout
    pos = nx.spring_layout(G, seed=42, k=2, iterations=50)
    
    # Draw
    fig, ax = plt.subplots(figsize=(14, 10))
    
    # Color nodes: visited (yellow), start (green), target (red), others (lightgray)
    visited_nodes = set(path)
    node_colors = []
    for i in range(model.config.n_nodes):
        if i == 0:
            node_colors.append('lightgreen')
        elif i == model.target_node:
            node_colors.append('lightcoral')
        elif i in visited_nodes:
            node_colors.append('gold')
        else:
            node_colors.append('lightgray')
    
    # Draw all nodes
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, 
                           node_size=800, ax=ax, alpha=0.6)
    
    # Draw all edges (faint)
    nx.draw_networkx_edges(G, pos, edge_color='lightgray', 
                           arrows=True, arrowsize=15, 
                           arrowstyle='->', ax=ax, alpha=0.2,
                           connectionstyle='arc3,rad=0.1')
    
    # Draw path edges (bold)
    path_edges = [(path[i], path[i+1]) for i in range(len(path)-1)]
    nx.draw_networkx_edges(G, pos, edgelist=path_edges, 
                           edge_color='red', width=3,
                           arrows=True, arrowsize=25, 
                           arrowstyle='->', ax=ax, alpha=0.8,
                           connectionstyle='arc3,rad=0.1')
    
    # Draw labels
    nx.draw_networkx_labels(G, pos, font_size=14, font_weight='bold', ax=ax)
    
    # Draw edge labels for path costs
    edge_labels = {(path[i], path[i+1]): f"{costs[i]:.1f}" 
                   for i in range(len(costs))}
    nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=10, 
                                 font_color='red', ax=ax)
    
    total_cost = sum(costs)
    ax.set_title(f'Path Taken (Total Cost: {total_cost:.2f})', 
                 fontsize=16, fontweight='bold')
    ax.axis('off')
    
    # Legend
    from matplotlib.patches import Patch
    from matplotlib.lines import Line2D
    legend_elements = [
        Patch(facecolor='lightgreen', label='Start'),
        Patch(facecolor='lightcoral', label='Target'),
        Patch(facecolor='gold', label='Visited'),
        Line2D([0], [0], color='red', linewidth=3, label='Path Taken')
    ]
    ax.legend(handles=legend_elements, loc='upper left', fontsize=12)
    
    plt.tight_layout()
    plt.show()

visualize_path(model, history)

## Visualize Cost Evolution

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

steps = range(len(history['costs']))

# Step costs
ax1.bar(steps, history['costs'], color='steelblue', alpha=0.7)
ax1.set_title('Edge Costs Per Step', fontsize=12, fontweight='bold')
ax1.set_xlabel('Step')
ax1.set_ylabel('Cost')
ax1.grid(alpha=0.3, axis='y')

# Cumulative cost
ax2.plot(steps, history['cumulative_cost'], 'o-', linewidth=2, 
         color='darkgreen', markersize=6)
ax2.set_title('Cumulative Cost', fontsize=12, fontweight='bold')
ax2.set_xlabel('Step')
ax2.set_ylabel('Total Cost')
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Summary:")
print(f"  Steps taken: {len(history['costs'])}")
print(f"  Average cost per step: {np.mean(history['costs']):.2f}")
print(f"  Total cost: {sum(history['costs']):.2f}")

## Policy Comparison

Let's compare the three policies over multiple episodes.

In [None]:
policies = {
    'Lookahead (θ=0.5)': lookahead_policy,
    'Greedy': greedy_policy,
    'Random': random_policy,
}

# Run multiple episodes for each policy
n_episodes = 20
results = {}

for name, policy in policies.items():
    print(f"\nRunning {name} policy ({n_episodes} episodes)...")
    episode_results = []
    
    for i in range(n_episodes):
        key = jax.random.PRNGKey(i)  # Different seed per episode
        history = run_episode(model, policy, max_steps=50, key=key, verbose=False)
        
        # Check if reached target
        reached_target = (history['path'][-1] == model.target_node)
        total_cost = sum(history['costs']) if history['costs'] else float('inf')
        steps = len(history['costs'])
        
        episode_results.append({
            'reached': reached_target,
            'cost': total_cost,
            'steps': steps,
        })
    
    results[name] = episode_results
    
    # Summary
    success_rate = sum(r['reached'] for r in episode_results) / n_episodes
    successful_costs = [r['cost'] for r in episode_results if r['reached']]
    avg_cost = np.mean(successful_costs) if successful_costs else float('inf')
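    # Guard against an empty list so np.mean does not return NaN with a warning
    successful_steps = [r['steps'] for r in episode_results if r['reached']]
    avg_steps = np.mean(successful_steps) if successful_steps else float('inf')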
    
    print(f"  Success rate: {success_rate:.1%}")
    print(f"  Avg cost (successful): {avg_cost:.2f}")
    print(f"  Avg steps (successful): {avg_steps:.1f}")

# Compare statistics
print("\n" + "="*60)
print("Policy Comparison Summary")
print("="*60)
print(f"{'Policy':<20} {'Success Rate':>15} {'Avg Cost':>12} {'Avg Steps':>12}")
print("-"*60)
for name, episode_results in results.items():
    success_rate = sum(r['reached'] for r in episode_results) / n_episodes
    successful_costs = [r['cost'] for r in episode_results if r['reached']]
    avg_cost = np.mean(successful_costs) if successful_costs else float('inf')
    successful_steps = [r['steps'] for r in episode_results if r['reached']]
    avg_steps = np.mean(successful_steps) if successful_steps else float('inf')
    
    print(f"{name:<20} {success_rate:>14.1%} {avg_cost:>12.2f} {avg_steps:>12.1f}")

## Visualize Policy Comparison

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

policy_names = list(results.keys())
colors = ['steelblue', 'coral', 'lightgreen']

# Success rates
success_rates = [sum(r['reached'] for r in results[name]) / n_episodes 
                 for name in policy_names]
axes[0].bar(policy_names, success_rates, color=colors, alpha=0.7)
axes[0].set_title('Success Rate', fontsize=12, fontweight='bold')
axes[0].set_ylabel('Proportion')
axes[0].set_ylim([0, 1.1])
axes[0].grid(alpha=0.3, axis='y')
for i, v in enumerate(success_rates):
    axes[0].text(i, v + 0.02, f'{v:.1%}', ha='center', fontweight='bold')

# Average costs (successful episodes only)
avg_costs = []
for name in policy_names:
    successful_costs = [r['cost'] for r in results[name] if r['reached']]
    avg_costs.append(np.mean(successful_costs) if successful_costs else 0)

axes[1].bar(policy_names, avg_costs, color=colors, alpha=0.7)
axes[1].set_title('Average Total Cost\n(Successful Episodes)', 
                  fontsize=12, fontweight='bold')
axes[1].set_ylabel('Cost')
axes[1].grid(alpha=0.3, axis='y')
for i, v in enumerate(avg_costs):
    if v > 0:
        axes[1].text(i, v + max(avg_costs)*0.02, f'{v:.1f}', 
                    ha='center', fontweight='bold')

# Average steps (successful episodes only)
avg_steps = []
for name in policy_names:
    successful_steps = [r['steps'] for r in results[name] if r['reached']]
    avg_steps.append(np.mean(successful_steps) if successful_steps else 0)

axes[2].bar(policy_names, avg_steps, color=colors, alpha=0.7)
axes[2].set_title('Average Steps to Target\n(Successful Episodes)', 
                  fontsize=12, fontweight='bold')
axes[2].set_ylabel('Steps')
axes[2].grid(alpha=0.3, axis='y')
for i, v in enumerate(avg_steps):
    if v > 0:
        axes[2].text(i, v + max(avg_steps)*0.02, f'{v:.1f}', 
                    ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

## Distribution of Costs Across Episodes

In [None]:
fig, ax = plt.subplots(figsize=(12, 6))

# Box plot of costs for successful episodes
data_to_plot = []
labels = []
for name in policy_names:
    successful_costs = [r['cost'] for r in results[name] if r['reached']]
    if successful_costs:
        data_to_plot.append(successful_costs)
        labels.append(name)

bp = ax.boxplot(data_to_plot, patch_artist=True,
                showmeans=True, meanline=True)
# Set tick labels separately: the `labels` keyword is deprecated in Matplotlib 3.9+
ax.set_xticklabels(labels)

# Color boxes
for patch, color in zip(bp['boxes'], colors[:len(data_to_plot)]):
    patch.set_facecolor(color)
    patch.set_alpha(0.6)

ax.set_title('Distribution of Total Costs (Successful Episodes)', 
             fontsize=14, fontweight='bold')
ax.set_ylabel('Total Cost', fontsize=12)
ax.grid(alpha=0.3, axis='y')

# Add legend
from matplotlib.lines import Line2D
legend_elements = [
    # Matplotlib defaults: solid orange median line, dashed green mean line
    Line2D([0], [0], color='orange', linewidth=2, linestyle='-', label='Median'),
    Line2D([0], [0], color='green', linewidth=2, linestyle='--', label='Mean')
]
ax.legend(handles=legend_elements, loc='upper right')

plt.tight_layout()
plt.show()

## Key Insights

The simulation demonstrates:

1. **LookaheadPolicy with backward induction** typically performs best by:
   - Planning multiple steps ahead using dynamic programming
   - Using learned cost estimates to make informed decisions
   - Computing value functions that are optimal for the current cost estimates over the lookahead horizon

2. **Learning dynamics**:
   - The agent improves cost estimates through running averages
   - Early decisions may be suboptimal due to uncertain estimates
   - Later decisions benefit from accumulated knowledge

3. **Risk sensitivity** (via θ parameter):
   - θ = 0.0: Optimistic, plans with the lowest cost in each edge's range
   - θ = 0.5: Risk-neutral, uses expected costs (demonstrated here)
   - θ = 1.0: Pessimistic, plans with the highest cost in each edge's range

4. **Planning horizon tradeoffs**:
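   - Longer horizons can improve decisions but cost more computation, since each extra backward-induction step is another pass over the edges
   - Backward induction yields the path that is optimal for the current cost estimates within the horizon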
   - Greedy (1-step) policy is fast but may miss better long-term paths

5. **Graph structure impact**:
   - Dense graphs (high edge_prob) offer more route options
   - Sparse graphs may force suboptimal paths
   - Cost variance (spread) affects risk-sensitive decisions

---

## Extensions

Try modifying:
- `horizon`: See how lookahead depth affects performance
- `theta`: Test risk-seeking (θ=0.2) vs. risk-averse (θ=0.8) behavior
- `edge_prob`: Create denser or sparser graphs
- `max_spread`: Increase uncertainty to see adaptation
- `n_nodes`: Scale to larger graphs
- Implement your own policy using learned costs!
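
As a possible starting point for a custom policy, here is a hedged sketch. It assumes the call signature `policy(params, state, key, model)` seen in `run_episode`, and that the state is a flat vector laid out as `[node, t, C_hat (row-major), N (row-major)]`, as the state-size printout suggests. Both are assumptions to verify against the repository. The class name is hypothetical, and a stub stands in for `SSPDynamicModel` so the snippet runs on its own.

```python
from types import SimpleNamespace

class CheapestEstimatedEdgePolicy:
    """Hypothetical policy: take the outgoing edge with the lowest estimated cost."""

    def __call__(self, params, state, key, model):
        n = model.config.n_nodes
        node = int(state[0])  # assumed layout: [node, t, C_hat (n*n), N (n*n)]
        best_j, best_cost = node, float("inf")
        for j in range(n):
            if model.adjacency[node][j]:
                est = float(state[2 + node * n + j])  # assumed row-major C_hat
                if est < best_cost:
                    best_j, best_cost = j, est
        return best_j  # stays at `node` if it has no outgoing edges

# Stub model and state for a 3-node graph (illustration only)
stub_model = SimpleNamespace(
    config=SimpleNamespace(n_nodes=3),
    adjacency=[[0, 1, 1], [0, 0, 1], [0, 0, 0]],
)
c_hat_flat = [0.0, 5.0, 2.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
counts_flat = [0.0] * 9
stub_state = [0.0, 0.0] + c_hat_flat + counts_flat

policy = CheapestEstimatedEdgePolicy()
decision = policy(None, stub_state, None, stub_model)
print(f"chosen next node: {decision}")
```

This rule ignores everything beyond the next edge, so it is closest in spirit to the single-step baseline; swapping the cost lookup for a value-to-go estimate is one way to extend it.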

## References

- Repository: https://github.com/pedronahum/stochastic-optimization
- JAX Documentation: https://jax.readthedocs.io/
- Stochastic Shortest Path: Bertsekas, D. P. (2012). Dynamic Programming and Optimal Control, Vol. II
