In [1]:
"""
Implementation of the stochastic Epigraph Robust Constrained Policy Gradient Search
(S‑EpiRC‑PGS).

This module provides a simple reference implementation of the algorithm
presented in Kitamura et al. (2025).  In contrast to the deterministic
version of EpiRC‑PGS contained in ``epirc-pgs.py``, the stochastic
counterpart replaces exact gradient and value computations with Monte
Carlo sampling.  The inner loop therefore uses a REINFORCE style
policy gradient estimator, while the outer loop performs a noisy
bisection search based on the PBA (probabilistic bisection algorithm).

Key features of this implementation:
  * An uncertainty set of transition kernels ``U`` is generated
    randomly.  The user may swap the ``make_U`` function for one
    generating a KL‑ball of kernels if desired.  Each element of ``U``
    is a three‑dimensional array of shape ``(S, A, S)`` encoding
    conditional next‑state probabilities.
  * Cost functions for the objective and constraints are drawn at
    random.  The thresholds for the constraints (entries 1–N‑1 of
    ``B``) are sampled uniformly from the effective horizon range.
  * The inner loop (Algorithm 2 in the paper) applies a projection
    step after each gradient update to ensure the policy remains a
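    valid probability distribution for each state.  The gradient is
    estimated using the REINFORCE trick: for the worst case cost
    component the algorithm samples a trajectory from the worst
    environment and uses discounted returns to construct a
    score‑function estimate of the gradient of the expected cost.
    Because every trajectory is truncated after a finite number of
    steps, the estimate is biased with respect to the
    infinite‑horizon discounted cost.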
  * The outer loop (Algorithm 1) implements a discrete version of
    PBA.  A finite grid over the search interval ``[0,H]`` defines a
    belief distribution over the optimal cost.  At each iteration the
    median of this belief determines the test threshold ``b0``.  The
    sign of the maximum constraint violation returned by the inner
    loop updates the belief, favouring the half interval suggested by
    the outcome.  A reliability parameter ``p > 0.5`` encodes the
    probability that the inner loop’s answer is correct; following
    Waeber et al. (2013) the belief update multiplies the weights on
    one side of ``b0`` by ``p`` and the other side by ``1‑p``.

This code is intended for educational purposes and has been kept
compact rather than optimised for speed.  For small state and action
spaces it should run in a reasonable amount of time.  To port the
algorithm to more complex environments (e.g. with function
approximation) one would need to extend the sampling and gradient
estimation routines appropriately.
"""

import random
import numpy as np

# ------------------------- Global hyper-parameters -------------------------
# Every name defined in this section is a module-level global that the
# functions below read directly.
# NOTE: no random seed is set in this section, so the random draws made
# here are not pinned by this code.

# Number of states and actions in the tabular MDP
S = 3  # number of states
A = 2  # number of actions

# Number of cost functions (one objective + constraints)
N = 4

# Number of uncertainty kernels in the uncertainty set U
U_num = 3

# Discount factor and effective horizon
gamma = 0.95
H = int(np.round(1 / (1 - gamma)))  # effective horizon: 1 / (1 - gamma) rounded to an integer

# Learning rate for the inner loop
alpha = 0.01

# Number of iterations in the inner loop (per call of Algorithm 2)
T_inner = 50

# Number of iterations of the outer loop (Algorithm 1)
K_outer = 100

# Reliability parameter for PBA (must be > 0.5).  A higher value
# implies more trust in the sign returned by the inner loop.
p_reliability = 0.7

# Resolution of the belief grid over [0, H] used in PBA.  A finer
# grid yields a more accurate estimate of the optimal value but
# increases computational cost.
belief_grid_size = 101

# Initial policy π: a random (S, A) matrix whose rows are normalised to
# sum to one, plus a flattened copy of length S*A.
pi_table = np.random.rand(S, A)
pi_table = pi_table / np.sum(pi_table, axis=1, keepdims=True)
pi_flat = pi_table.flatten()

# Random initial state distribution μ (length S), normalised to sum to one
mu = np.random.rand(S)
mu = mu / np.sum(mu)


# ------------------------- Uncertainty set and costs -------------------------

def make_U():
    """
    Build the uncertainty set as a list of random transition kernels.

    ``U_num`` kernels are drawn independently.  Each kernel starts as
    an array of uniform(0, 1) samples of shape ``(S, A, S)`` and is
    then divided by its sum over the last axis, so every
    ``kernel[s, a, :]`` is a probability vector over successor states.
    To use a structured set instead (for example a KL-ball around a
    nominal kernel), replace this routine.

    Returns:
        list: ``U_num`` arrays of shape ``(S, A, S)``
    """
    def _random_kernel():
        # Normalise over the successor-state axis only.
        raw = np.random.rand(S, A, S)
        return raw / raw.sum(axis=2, keepdims=True)

    return [_random_kernel() for _ in range(U_num)]


def make_C_and_B():
    """
    Draw random cost matrices c_n(s, a) and thresholds b_n.

    One cost matrix of shape ``(S, A)`` with uniform(0, 1) entries is
    drawn for each of the ``N`` costs.  Index 0 is the objective: its
    threshold is a ``0.0`` placeholder that the outer loop overrides.
    Each remaining (constraint) threshold is an integer drawn
    uniformly from ``0..H`` inclusive and stored as a float.

    Returns:
        tuple: ``(C_list, B_list)``, both of length ``N``
    """
    costs = []
    thresholds = []
    for cost_index in range(N):
        costs.append(np.random.rand(S, A))
        # Only the constraints (index >= 1) consume a random threshold.
        threshold = 0.0 if cost_index == 0 else float(random.randint(0, H))
        thresholds.append(threshold)
    return costs, thresholds


# Instantiate the uncertainty set and cost functions once, when the
# module is executed.  The module-level names U, C and B are read
# directly by functions further down (e.g. compute_delta, inner_loop).
U = make_U()
C, B = make_C_and_B()


# ------------------------- Projection utilities -------------------------

def projection_to_simplex(y_vec: np.ndarray) -> np.ndarray:
    """
    Euclidean projection of a vector onto the probability simplex.

    Sort-and-threshold scheme (Duchi et al., 2008): sort the entries
    in descending order, find the largest rank whose entry stays
    strictly positive after the shift implied by the running sum, and
    subtract that shift from every entry before clipping at zero.

    Args:
        y_vec (np.ndarray): vector of shape (A,) to project
    Returns:
        np.ndarray: vector of shape (A,) with non-negative entries;
        if no rank passes the positivity test the uniform vector is
        returned instead
    """
    size = y_vec.shape[0]
    descending = np.sort(y_vec)[::-1]
    running_sum = np.cumsum(descending)
    # Scan from the smallest entry upwards; the first hit is the
    # largest rank that satisfies the positivity condition.
    for rank in reversed(range(size)):
        if descending[rank] + (1 - running_sum[rank]) / (rank + 1) > 0:
            break
    else:
        # No rank qualified: fall back to the uniform distribution.
        return np.ones(size) / size
    shift = (running_sum[rank] - 1) / (rank + 1)
    return np.maximum(y_vec - shift, 0)


def proj_policy_matrix(policy_matrix: np.ndarray) -> np.ndarray:
    """
    Apply ``projection_to_simplex`` to every row of a policy matrix.

    Args:
        policy_matrix (np.ndarray): array of shape (S,A)
    Returns:
        np.ndarray: new array of the same shape and dtype whose rows
        are the projected rows of the input
    """
    result = np.zeros_like(policy_matrix)
    for row_index, row in enumerate(policy_matrix):
        result[row_index, :] = projection_to_simplex(row)
    return result


def update_and_project_policy(current_policy_table: np.ndarray,
                              gradient_table: np.ndarray,
                              lr: float) -> np.ndarray:
    """
    Take one projected gradient-descent step on the policy table.

    The step ``current_policy_table - lr * gradient_table`` is
    computed first and each row of the result is then projected back
    onto the probability simplex.

    Args:
        current_policy_table (np.ndarray): policy matrix, shape (S,A)
        gradient_table (np.ndarray): gradient estimate, shape (S,A)
        lr (float): learning rate
    Returns:
        np.ndarray: projected policy matrix of shape (S,A)
    Raises:
        ValueError: if the two tables do not have the same shape
    """
    shapes_agree = current_policy_table.shape == gradient_table.shape
    if not shapes_agree:
        raise ValueError(
            "Shape mismatch between current_policy_table and gradient_table.")
    stepped = current_policy_table - lr * gradient_table
    return proj_policy_matrix(stepped)


# ------------------------- Value computation (exact) -------------------------

def compute_Q_pi_c_P_matrix(P_kernel: np.ndarray,
                            Q_initial: np.ndarray,
                            cost_sa: np.ndarray,
                            pi_flat: np.ndarray,
                            max_iterations: int = 1000,
                            tolerance: float = 1e-3) -> np.ndarray:
    """
    Solve for the Q-function Q^π(s,a) under fixed P and cost c via
    successive approximation of the Bellman equation

        Q(s,a) <- c(s,a) + gamma * sum_s' P(s'|s,a) * sum_a' π(a'|s') Q(s',a').

    Iteration stops as soon as the largest absolute change between
    two successive iterates falls below ``tolerance``; otherwise the
    last iterate is returned after ``max_iterations`` sweeps.

    Args:
        P_kernel (np.ndarray): transition kernel of shape (S,A,S)
        Q_initial (np.ndarray): initial guess for Q, shape (S,A); it
            is copied and never modified
        cost_sa (np.ndarray): cost matrix, shape (S,A)
        pi_flat (np.ndarray): flattened policy of length S*A
        max_iterations (int): maximum number of iterations
        tolerance (float): convergence tolerance on the sup-norm
            difference of successive iterates
    Returns:
        np.ndarray: floating-point Q table of shape (S,A)
    """
    # The policy is fixed during the iteration, so reshape it once.
    pi_matrix = pi_flat.reshape(S, A)
    # Work on a floating-point copy: keeping the dtype of an
    # integer-typed Q_initial would truncate every Bellman update.
    Q_k = np.array(Q_initial, dtype=float)
    for _ in range(max_iterations):
        # V^π(s) = sum_a π(a|s) Q(s,a)
        V = np.sum(pi_matrix * Q_k, axis=1)
        # np.dot contracts the last axis of P_kernel with V -> (S,A)
        Q_next = cost_sa + gamma * np.dot(P_kernel, V)
        if np.max(np.abs(Q_next - Q_k)) < tolerance:
            return Q_next
        Q_k = Q_next
    return Q_k


def compute_J_c_P(P_kernel: np.ndarray,
                  Q_initial: np.ndarray,
                  cost_sa: np.ndarray,
                  pi_flat: np.ndarray) -> float:
    """
    Evaluate J_{c,P}(π) = μ^T V^π for one kernel and one cost matrix.

    The Q table is obtained from ``compute_Q_pi_c_P_matrix``; state
    values are the policy-weighted row sums of that table, and the
    result is their expectation under the initial distribution μ.

    Args:
        P_kernel (np.ndarray): transition kernel of shape (S,A,S)
        Q_initial (np.ndarray): initial guess for Q, shape (S,A)
        cost_sa (np.ndarray): cost matrix, shape (S,A)
        pi_flat (np.ndarray): flattened policy of length S*A
    Returns:
        float: expected discounted cost under μ
    """
    q_table = compute_Q_pi_c_P_matrix(P_kernel, Q_initial, cost_sa, pi_flat)
    policy = pi_flat.reshape(S, A)
    # V^π(s) = <π(.|s), Q(s,.)> for each state
    state_values = np.array(
        [np.dot(policy[state], q_table[state]) for state in range(S)])
    return float(np.dot(mu, state_values))


def compute_J_c_U_b_and_its_max(U_list, C_list, Q_initial_list, B_list, pi_flat):
    """
    Exact worst-case violation of every cost over the uncertainty set.

    For each cost index n the expected cost is evaluated under every
    kernel, the largest value is kept, and the threshold ``B_list[n]``
    is subtracted from it.

    Args:
        U_list (list): list of transition kernels
        C_list (list): list of cost matrices c_n
        Q_initial_list (list): list of initial Q matrices, one per cost
        B_list (list): list of thresholds b_n
        pi_flat (np.ndarray): current policy
    Returns:
        tuple: (J_results, J_max_index, J_max_index_U) where
            J_results[n] = max_i J_{c_n,U_i}(π) - B[n]
            J_max_index is the index n with the largest J_results[n]
            J_max_index_U[n] is the kernel index attaining the
            maximum for cost n
    """
    violations = []
    worst_kernels = []
    for cost_idx in range(N):
        per_kernel = [
            compute_J_c_P(U_list[kernel_idx], Q_initial_list[cost_idx],
                          C_list[cost_idx], pi_flat)
            for kernel_idx in range(U_num)
        ]
        arg_worst = int(np.argmax(per_kernel))
        worst_kernels.append(arg_worst)
        violations.append(per_kernel[arg_worst] - B_list[cost_idx])
    return violations, int(np.argmax(violations)), worst_kernels


def compute_delta(pi_flat: np.ndarray, B_list):
    """
    Exact value of Δ(π,B) = max_n [ J_{c_n,U}(π) − B[n] ].

    The module-level uncertainty set ``U`` and costs ``C`` are used,
    and every Q iteration is started from an all-zero table.

    Args:
        pi_flat (np.ndarray): current policy (flattened)
        B_list (list or np.ndarray): threshold list
    Returns:
        float: maximum violation across costs
    """
    zero_tables = [np.zeros((S, A)) for _ in range(N)]
    violations = compute_J_c_U_b_and_its_max(
        U, C, zero_tables, B_list, pi_flat)[0]
    return float(max(violations))


# ------------------------- Sampling utilities -------------------------

def sample_trajectory(pi_flat: np.ndarray,
                      P_kernel: np.ndarray,
                      horizon: int = H) -> tuple[list[int], list[int]]:
    """
    Roll out one trajectory under a policy and a transition kernel.

    The first state is drawn from μ.  At every step the current state
    is recorded, an action is drawn from the policy row of that state
    and recorded, and the successor state is drawn from the kernel.

    Args:
        pi_flat (np.ndarray): flattened policy of length S*A
        P_kernel (np.ndarray): transition kernel of shape (S,A,S)
        horizon (int): number of steps to simulate
    Returns:
        tuple: (states, actions) lists of length ``horizon``
    """
    policy = pi_flat.reshape(S, A)
    visited: list = []
    chosen: list = []
    state = np.random.choice(S, p=mu)
    for _ in range(horizon):
        visited.append(state)
        action = np.random.choice(A, p=policy[state])
        chosen.append(action)
        # A successor is drawn on every step, including the last one.
        state = np.random.choice(S, p=P_kernel[state, action, :])
    return visited, chosen


def sample_return_for_cost(pi_flat: np.ndarray,
                           P_kernel: np.ndarray,
                           cost_sa: np.ndarray,
                           horizon: int = H) -> float:
    """
    Single-trajectory estimate of J_{c,P}(π).

    One trajectory of ``horizon`` steps is sampled and the costs
    along it are summed with weights 1, gamma, gamma**2, ...

    Args:
        pi_flat (np.ndarray): flattened policy
        P_kernel (np.ndarray): transition kernel
        cost_sa (np.ndarray): cost matrix of shape (S,A)
        horizon (int): number of steps to simulate
    Returns:
        float: sampled discounted return
    """
    visited, chosen = sample_trajectory(pi_flat, P_kernel, horizon)
    total = 0.0
    weight = 1.0  # gamma ** t for the current step
    for state, action in zip(visited, chosen):
        total += weight * cost_sa[state, action]
        weight *= gamma
    return total


def approximate_J_c_P(pi_flat: np.ndarray,
                      P_kernel: np.ndarray,
                      cost_sa: np.ndarray,
                      num_samples: int = 1,
                      horizon: int = H) -> float:
    """
    Monte Carlo estimate of the expected discounted cost J_{c,P}(π).

    Averages ``num_samples`` independent single-trajectory returns.

    Args:
        pi_flat (np.ndarray): flattened policy
        P_kernel (np.ndarray): transition kernel
        cost_sa (np.ndarray): cost matrix
        num_samples (int): number of trajectories to sample
        horizon (int): horizon length
    Returns:
        float: average of the sampled discounted returns
    """
    sampled_returns = [
        sample_return_for_cost(pi_flat, P_kernel, cost_sa, horizon)
        for _ in range(num_samples)
    ]
    accumulated = 0.0
    for value in sampled_returns:
        accumulated += value
    return accumulated / float(num_samples)


def approximate_J_c_U(pi_flat: np.ndarray,
                      C_list: list,
                      U_list: list,
                      num_samples: int = 1,
                      horizon: int = H) -> tuple[list[float], list[int]]:
    """
    Sampled worst-case cost over the uncertainty set, for every cost.

    For each cost index n a Monte Carlo estimate is computed under
    every kernel, and the largest estimate together with the index of
    the kernel that produced it is kept.

    Args:
        pi_flat (np.ndarray): current policy
        C_list (list): list of cost matrices c_n
        U_list (list): list of transition kernels
        num_samples (int): number of trajectories per environment
        horizon (int): horizon length
    Returns:
        tuple: (J_estimates, worst_env_indices)
            J_estimates[n] ≈ max_i J_{c_n,U_i}(π)
            worst_env_indices[n] = index i of the largest estimate
    """
    worst_values = []
    worst_kernels = []
    for cost_idx in range(N):
        per_kernel = [
            approximate_J_c_P(pi_flat, U_list[kernel_idx], C_list[cost_idx],
                              num_samples, horizon)
            for kernel_idx in range(U_num)
        ]
        arg_worst = int(np.argmax(per_kernel))
        worst_values.append(per_kernel[arg_worst])
        worst_kernels.append(arg_worst)
    return worst_values, worst_kernels


def approximate_delta(pi_flat: np.ndarray,
                      B_list,
                      C_list: list,
                      U_list: list,
                      num_samples: int = 1,
                      horizon: int = H) -> float:
    """
    Sampled estimate of the maximum constraint violation Δ.

    The sampled worst-case cost of each cost index is reduced by its
    threshold and the largest difference is returned.

    Args:
        pi_flat (np.ndarray): current policy
        B_list (list): thresholds b_n
        C_list (list): list of cost matrices c_n
        U_list (list): list of transition kernels
        num_samples (int): number of samples per environment
        horizon (int): horizon length
    Returns:
        float: estimated Δ
    """
    worst_costs = approximate_J_c_U(
        pi_flat, C_list, U_list, num_samples, horizon)[0]
    return float(max(worst_costs[n] - B_list[n] for n in range(N)))


def estimate_gradient_REINFORCE(pi_flat: np.ndarray,
                                P_kernel: np.ndarray,
                                cost_sa: np.ndarray,
                                num_samples: int = 1,
                                horizon: int = H) -> np.ndarray:
    """
    Score-function (REINFORCE style) estimate of the gradient of
    J_{c,P}(π) with respect to the entries of the policy table.

    For each sampled trajectory the discounted cost-to-go from every
    time step is computed.  The entry of every visited (state, action)
    pair is then increased by that cost-to-go divided by
    ``π(s,a) + 1e-8``, i.e. multiplied by the derivative of
    ``log π(a|s)`` with respect to ``π(s,a)``.  The per-step terms are
    not weighted by ``gamma**t``.  The sum is finally divided by
    ``num_samples``.

    Args:
        pi_flat (np.ndarray): flattened policy
        P_kernel (np.ndarray): transition kernel
        cost_sa (np.ndarray): cost matrix
        num_samples (int): number of trajectories to sample
        horizon (int): horizon length
    Returns:
        np.ndarray: gradient estimate of shape (S,A)
    """
    policy = pi_flat.reshape(S, A)
    grad_sum = np.zeros((S, A))
    for _ in range(num_samples):
        visited, chosen = sample_trajectory(pi_flat, P_kernel, horizon)
        steps = len(visited)
        # cost_to_go[t] = sum_{k=t}^{T-1} gamma^{k-t} c(s_k,a_k),
        # filled backwards with the recursion G_t = c_t + gamma * G_{t+1}.
        cost_to_go = np.zeros(steps)
        tail = 0.0
        for step in range(steps - 1, -1, -1):
            tail = cost_sa[visited[step], chosen[step]] + gamma * tail
            cost_to_go[step] = tail
        # The small constant guards against division by zero.
        for state, action, ret in zip(visited, chosen, cost_to_go):
            grad_sum[state, action] += ret / (policy[state, action] + 1e-8)
    return grad_sum / float(num_samples)


# ------------------------- Inner and outer loops -------------------------

# The following module-level lists record learning statistics for
# downstream visualisation.  ``pba_outer_loop`` appends to them on
# every call and never clears them, so reset them manually before
# running several experiments in one interpreter session.
outer_b0_history: list[float] = []  # threshold b0 tested at each outer iteration
outer_delta_history: list[float] = []  # exact Δ of the policy returned at each outer iteration
outer_sign_history: list[int] = []  # +1 if that Δ was > 0, otherwise -1
inner_deltas_history: list[list[float]] = []  # sampled Δ per inner iteration, one list per inner_loop call

def inner_loop(b0: float,
               initial_pi_flat: np.ndarray,
               record_list: list[float] | None = None) -> np.ndarray:
    """
    Run the stochastic inner loop (Algorithm 2) for a fixed objective
    threshold ``b0``.

    A copy of the module-level thresholds ``B`` is made with entry 0
    replaced by ``b0``.  Each of the ``T_inner`` iterations then
      1. samples worst-case cost estimates for every cost,
      2. selects the cost with the largest sampled violation and the
         kernel that produced its estimate,
      3. takes one projected gradient step using a single-trajectory
         REINFORCE estimate for that (cost, kernel) pair,
      4. re-samples Δ for the new policy and keeps the policy if this
         estimate is the smallest seen so far.
    The incumbent starts as the initial policy with its own sampled Δ.

    Args:
        b0 (float): objective threshold
        initial_pi_flat (np.ndarray): starting policy vector; it is
            copied and never modified
        record_list (list[float] | None): if given, the sampled Δ of
            every iteration is appended to it
    Returns:
        np.ndarray: flattened policy with the smallest sampled Δ
    """
    thresholds = B.copy()
    thresholds[0] = b0

    policy = initial_pi_flat.copy()
    incumbent = policy.copy()
    incumbent_delta = approximate_delta(policy, thresholds, C, U)

    for _ in range(T_inner):
        # Sampled worst-case costs and the kernels that attain them
        estimates, arg_kernels = approximate_J_c_U(policy, C, U)
        violations = [estimates[n] - thresholds[n] for n in range(N)]
        active_cost = int(np.argmax(violations))
        active_kernel = arg_kernels[active_cost]

        step_gradient = estimate_gradient_REINFORCE(
            policy, U[active_kernel], C[active_cost],
            num_samples=1, horizon=H)
        policy = update_and_project_policy(
            policy.reshape(S, A), step_gradient, alpha).flatten()

        # Fresh sample of Δ for the updated policy
        delta_now = approximate_delta(policy, thresholds, C, U)
        if delta_now < incumbent_delta:
            incumbent_delta = delta_now
            incumbent = policy.copy()
        if record_list is not None:
            record_list.append(delta_now)
    return incumbent


def pba_outer_loop(initial_pi_flat: np.ndarray, return_logs: bool = False) -> np.ndarray | tuple[np.ndarray, dict]:
    """
    Execute the probabilistic bisection outer loop (Algorithm 1).

    A belief distribution over the interval [0, H] is maintained on a
    finite grid.  At each of the ``K_outer`` iterations the median of
    the belief gives the test threshold b0, the inner loop is run at
    that threshold, and the sign of the exact Δ of the returned policy
    reweights the belief: the side of b0 indicated by the sign is
    multiplied by ``p_reliability`` and the other side by
    ``1 - p_reliability``.  Afterwards the median of the final belief
    is used for one last inner-loop call whose policy is returned.

    Side effects: every iteration is also appended to the module-level
    history lists (``outer_b0_history`` etc.), which are never cleared
    here.  The returned ``logs`` contain only the data of this call.

    Args:
        initial_pi_flat (np.ndarray): initial policy vector; it is
            copied and never modified
        return_logs (bool): if True, also return a dictionary of logs
    Returns:
        np.ndarray: flattened policy returned by S-EpiRC-PGS, or
        tuple: ``(policy, logs)`` when ``return_logs`` is True, where
        ``logs`` has the keys ``"outer_b0"``, ``"outer_delta"``,
        ``"outer_sign"``, ``"inner_deltas"`` (one list per inner-loop
        call, the final call last) and ``"final_b"``
    """
    # Uniform initial belief over a grid on [0, H]
    grid = np.linspace(0.0, float(H), belief_grid_size)
    belief = np.ones(belief_grid_size)
    belief /= belief.sum()

    # Logs of this call only.  Returning the shared module-level lists
    # would mix in the history of earlier calls in the same session.
    run_b0: list[float] = []
    run_delta: list[float] = []
    run_sign: list[int] = []
    run_inner_deltas: list[list[float]] = []

    # Current policy to be refined by successive inner loop calls
    current_pi = initial_pi_flat.copy()
    for _ in range(K_outer):
        # The test threshold is the median of the current belief
        cdf = np.cumsum(belief)
        median_idx = int(np.searchsorted(cdf, 0.5))
        b0 = grid[median_idx]

        # Run inner loop to (approximately) solve the auxiliary problem
        inner_deltas: list[float] = []
        pi_k = inner_loop(b0, current_pi, record_list=inner_deltas)

        # Evaluate Δ exactly (dynamic programming) for the returned policy
        delta_val = compute_delta(pi_k, [b0] + B[1:])
        # +1 if Δ > 0 (b0 too small) else -1
        Z_k = 1 if delta_val > 0 else -1

        # Record statistics, both per call and in the module-level lists
        run_b0.append(float(b0))
        run_delta.append(float(delta_val))
        run_sign.append(int(Z_k))
        run_inner_deltas.append(inner_deltas)
        outer_b0_history.append(float(b0))
        outer_delta_history.append(float(delta_val))
        outer_sign_history.append(int(Z_k))
        inner_deltas_history.append(inner_deltas)

        # PBA update.  Z_k = +1 suggests J* > b0, so the grid points
        # above b0 are upweighted; Z_k = -1 suggests J* <= b0, so the
        # points at or below b0 are upweighted.
        at_or_below = grid <= b0
        if Z_k == 1:
            weight_left, weight_right = 1.0 - p_reliability, p_reliability
        else:
            weight_left, weight_right = p_reliability, 1.0 - p_reliability
        belief[at_or_below] *= weight_left
        belief[~at_or_below] *= weight_right

        # Normalise the belief distribution
        belief_sum = belief.sum()
        if belief_sum > 0:
            belief /= belief_sum
        else:
            # Avoid degenerate case: reinitialise to uniform
            belief = np.ones_like(belief) / belief.size

        # Warm-start the next inner loop from the policy just found
        current_pi = pi_k.copy()

    # Final threshold is the median of the final belief
    cdf_final = np.cumsum(belief)
    median_idx_final = int(np.searchsorted(cdf_final, 0.5))
    b_final = grid[median_idx_final]

    # Final inner loop call to get the returned policy.  Its Δ trace is
    # logged, but no b0/Δ/sign entry is recorded for it.
    final_inner_deltas: list[float] = []
    final_pi = inner_loop(b_final, current_pi, record_list=final_inner_deltas)
    run_inner_deltas.append(final_inner_deltas)
    inner_deltas_history.append(final_inner_deltas)

    if not return_logs:
        return final_pi
    logs = {
        "outer_b0": run_b0,
        "outer_delta": run_delta,
        "outer_sign": run_sign,
        "inner_deltas": run_inner_deltas,
        "final_b": float(b_final)
    }
    return final_pi, logs


if __name__ == "__main__":
    # Run the algorithm and collect logs for visualisation
    final_policy, logs = pba_outer_loop(pi_flat, return_logs=True)
    print("Final policy (flattened) from S‑EpiRC‑PGS:")
    print(final_policy)
    # Plotting is best-effort: any failure (including a missing
    # matplotlib) is reported on stdout and does not abort the script.
    try:
        import matplotlib
        matplotlib.use('Agg')  # Use a non-interactive backend
        import matplotlib.pyplot as plt
        # Outer loop metrics
        outer_iters = range(len(logs["outer_b0"]))
        fig, ax = plt.subplots(2, 1, figsize=(8, 8), constrained_layout=True)
        ax[0].plot(outer_iters, logs["outer_b0"], marker='o', label='Threshold b0')
        ax[0].plot(outer_iters, logs["outer_delta"], marker='x', label='Δ (violation)')
        # Use scatter to indicate sign of Δ (positive/negative)
        colors = ['tab:red' if s > 0 else 'tab:green' for s in logs["outer_sign"]]
        ax[0].scatter(outer_iters, logs["outer_delta"], c=colors, label='Sign of Δ')
        ax[0].set_xlabel('Outer iteration')
        ax[0].set_ylabel('Value')
        ax[0].set_title('Outer loop evolution')
        ax[0].legend()
        # Inner loop curves: plot Δ per inner iteration for each outer
        # loop (the last entry belongs to the final inner-loop call and
        # is not plotted).
        inner_curves = logs["inner_deltas"][:-1]
        # One legend entry per outer iteration is unreadable when there
        # are many of them, so only about ``max_legend_entries`` evenly
        # spaced curves are labelled.  Labels starting with an
        # underscore are left out of the automatic legend.
        max_legend_entries = 10
        label_stride = max(1, -(-len(inner_curves) // max_legend_entries))
        for idx, inner_deltas in enumerate(inner_curves):
            curve_label = f'outer {idx}' if idx % label_stride == 0 else '_nolegend_'
            ax[1].plot(range(len(inner_deltas)), inner_deltas, label=curve_label)
        ax[1].set_xlabel('Inner iteration')
        ax[1].set_ylabel('Δ')
        ax[1].set_title('Inner loop Δ over iterations')
        if inner_curves:
            ax[1].legend(loc='upper right', fontsize='small', ncol=2)
        # Save the figure summarising the learning process
        fig.savefig('learning_results.png')
        plt.close(fig)
        # Visualise final policy as both a heatmap and bar chart
        policy_matrix = final_policy.reshape(S, A)
        # Heatmap
        fig2, ax2 = plt.subplots(figsize=(4, 3))
        im = ax2.imshow(policy_matrix, cmap='viridis', aspect='auto')
        ax2.set_xticks(range(A))
        ax2.set_yticks(range(S))
        ax2.set_xlabel('Action')
        ax2.set_ylabel('State')
        ax2.set_title('Final policy heatmap')
        for i in range(S):
            for j in range(A):
                val = policy_matrix[i, j]
                ax2.text(j, i, f"{val:.2f}", ha='center', va='center', color='white' if val > 0.5 else 'black')
        fig2.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
        fig2.savefig('final_policy_heatmap.png')
        plt.close(fig2)
        # Bar chart: one group of bars per action, one bar per state
        fig3, ax3 = plt.subplots(figsize=(6, 4))
        x = np.arange(A)
        width = 0.8 / S
        for s_idx in range(S):
            ax3.bar(x + s_idx * width, policy_matrix[s_idx], width, label=f'state {s_idx}')
        ax3.set_xticks(x + width * (S - 1) / 2)
        ax3.set_xticklabels([f'action {a}' for a in range(A)])
        ax3.set_ylabel('Probability')
        ax3.set_title('Final policy bar chart')
        ax3.legend()
        fig3.savefig('final_policy_barchart.png')
        plt.close(fig3)
        print("Learning curves and final policy visualisations have been saved as PNG files.")
    except Exception as e:
        print("Matplotlib is not available or an error occurred while plotting:", e)

Final policy (flattened) from S‑EpiRC‑PGS:
[0.84303778 0.15696222 0.99069685 0.00930315 1.         0.        ]


  fig.savefig('learning_results.png')


Learning curves and final policy visualisations have been saved as PNG files.
