# Stochastic Shortest Path - Static

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pedronahum/stochastic-optimization/blob/master/notebooks/ssp_static.ipynb)

## Problem
Find the shortest path in a graph with uncertain edge costs, using percentile-based risk measures.

## Formulation
- State: current node
- Action: next node to visit
- Cost: stochastic edge traversal cost
- Objective: minimize a chosen percentile of the total path cost

In [None]:
!pip install -q jax jaxlib jaxtyping chex matplotlib networkx
import os
if 'COLAB_GPU' in os.environ or not os.path.exists('problems'):
    !git clone https://github.com/pedronahum/stochastic-optimization.git
    os.chdir('stochastic-optimization')

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx

# Import problem components
from problems.ssp_static import (
    SSPStaticConfig,
    SSPStaticModel,
    GreedyPolicy,
    EpsilonGreedyPolicy,
)

# Report that the imports above worked and which JAX build/backend is active.
print("✓ Imports successful")
jax_version = jax.__version__
print(f"JAX version: {jax_version}")
active_backend = jax.default_backend()
print(f"JAX backend: {active_backend}")

In [None]:
# Build the problem instance: an 8-node graph generated with edge_prob=0.3 and
# cost bounds 1.0-10.0 (presumably the range edge costs are drawn from; see
# SSPStaticConfig for the exact semantics).
config = SSPStaticConfig(
    n_nodes=8, edge_prob=0.3, cost_lower_bound=1.0, cost_upper_bound=10.0,
)
model = SSPStaticModel(config)

# Fixed seed so the initial state is reproducible between runs.
key = jax.random.PRNGKey(42)
state = model.init_state(key)
print('✓ SSP static model ready')

In [None]:
# Run episodes: simulate N_EPISODES rollouts of a greedy policy, record each
# episode's total path cost, then plot the cost distribution and a percentile.

# `policy` is used below but was not defined anywhere else in this notebook.
# NOTE(review): assumes GreedyPolicy takes no constructor arguments; confirm
# against problems.ssp_static.
policy = GreedyPolicy()

N_EPISODES = 100
MAX_STEPS = 20    # cap on steps per episode if no terminal state is reached
PERCENTILE = 10   # percentile of total cost that is highlighted and reported

key = jax.random.PRNGKey(42)
costs = []
for _ in range(N_EPISODES):
    # Split off a dedicated key for init_state so `key` is never consumed
    # twice (JAX PRNG keys must not be reused).
    key, init_key = jax.random.split(key)
    state = model.init_state(init_key)
    total_reward = 0.0
    for t in range(MAX_STEPS):
        if model.is_terminal(state):
            break
        key, k_policy, k_exog = jax.random.split(key, 3)
        decision = policy(None, state, k_policy, model)
        exog = model.sample_exogenous(k_exog, state, t)
        total_reward += float(model.reward(state, decision, exog))
        state = model.transition(state, decision, exog)
    # The total is negated before being recorded as a cost.
    # NOTE(review): presumably model.reward returns the negated edge cost; confirm.
    costs.append(-total_reward)

costs_arr = np.asarray(costs)
pct_cost = float(np.percentile(costs_arr, PERCENTILE))

fig, (ax_hist, ax_sorted) = plt.subplots(1, 2, figsize=(10, 4))
ax_hist.hist(costs_arr, bins=20)
ax_hist.axvline(pct_cost, color='red', label=f'{PERCENTILE}th percentile')
ax_hist.set_title('Path Cost Distribution')
ax_hist.set_xlabel('Total Cost')
ax_hist.set_ylabel('Frequency')
ax_hist.legend()

ax_sorted.plot(np.sort(costs_arr))
ax_sorted.set_title('Sorted Costs')
ax_sorted.set_xlabel('Episode (sorted)')
ax_sorted.set_ylabel('Total Cost')

fig.tight_layout()
plt.show()
print(f'{PERCENTILE}th percentile cost: {pct_cost:.2f}')