# EpiRC-PGSの実装

[Near-Optimal Policy Identification in Robust Constrained Markov Decision Processes via Epigraph Form](https://openreview.net/forum?id=G5sPv4KSjR)において提案された，**Epigraph Robust Constrained Policy Gradient Search**(EpiRC-PGS)の再現実装を行いました．

* ほとんど自力で作ったので，ひどいコードです．
* 外側ループの二分探索は正しいと思います．
* 内側ループの勾配収束が現状の問題です．
  * 内側ループ内の反復で方策がすぐに決定論的になってしまって，勾配が収束しない問題が発生しているっぽいです．直したいです．
* numpyで実装したので反復計算が結構遅いです．jaxで書いたほうがいいかも．
* コメントは適宜改良します．




In [2]:
import random
from jax import grad
import numpy as np

S = 3  # number of states
A = 2  # number of actions
N = 4  # number of cost functions (objective + constraints)
T = 100  # iteration count; read by both compute_inner_loop and compute_outer_loop
U_num = 3  # number of transition kernels in the uncertainty set
alpha = 0.01  # learning rate of the projected gradient step
gamma = 0.95  # discount factor

# Policy initialisation.
# A random (S, A) table is normalised row-wise so every state's action
# probabilities sum to 1, then flattened: the global ``pi`` is a 1-D vector of
# length S*A, laid out state-major (entries s*A .. (s+1)*A-1 belong to state s).
pi_temp_table = np.random.rand(S, A)
pi_temp_table = pi_temp_table / np.sum(pi_temp_table, axis=1, keepdims=True)
pi = pi_temp_table.flatten()

# Initial state distribution: a random S-dimensional vector normalised to sum to 1.
mu = np.random.rand(S)
mu = mu / np.sum(mu)

# Effective horizon 1 / (1 - gamma), rounded to the nearest integer.
H = int(np.round(1 / (1 - gamma)))

# 不確実性集合の作成
def make_U():
    """Build the uncertainty set as a list of random transition kernels.

    Reads the module-level ``U_num``, ``S`` and ``A``.

    Returns:
        list: ``U_num`` arrays of shape (S, A, S). Each one is drawn with
        ``np.random.rand`` and normalised along the last axis, so every
        ``[s, a, :]`` slice sums to 1.
    """
    def _random_kernel():
        raw = np.random.rand(S, A, S)
        return raw / np.sum(raw, axis=2, keepdims=True)

    return [_random_kernel() for _ in range(U_num)]

# Module-level uncertainty set: the list of transition kernels returned by make_U().
U = make_U() 

# N個(制約数個)のコスト関数、初期化されたQ行列、閾値を作成
def make_C_and_Q_pi_b():
    """Create the cost functions, zero-initialised Q tables and thresholds.

    Reads the module-level ``N``, ``S``, ``A`` and ``H``.

    Returns:
        tuple: ``(C_list, Q_list, B_list)``, each a list of length ``N``:
        random (S, A) cost tables from ``np.random.rand``, all-zero (S, A)
        Q tables, and integer thresholds drawn with ``random.randint(0, H)``.
    """
    # The costs use NumPy's generator and the thresholds use the stdlib
    # ``random`` generator, so drawing them in separate passes yields the same
    # values as drawing them interleaved.
    cost_tables = [np.random.rand(S, A) for _ in range(N)]
    q_tables = [np.zeros((S, A)) for _ in range(N)]
    thresholds = [random.randint(0, H) for _ in range(N)]
    return cost_tables, q_tables, thresholds

# Module-level cost tables, initial Q tables and thresholds (one entry per cost function).
C, Q_initial_list, B = make_C_and_Q_pi_b() 
                                       


def compute_dot_pi_Q(s_idx, current_Q_sa, pi):
    """Return V(s) = sum_a pi(a|s) Q(s, a) for a single state.

    Args:
        s_idx: state index.
        current_Q_sa: Q table indexed as ``[s, a]``.
        pi: flat policy vector; the ``A`` entries starting at ``s_idx * A``
            are the action probabilities of state ``s_idx`` (``A`` is the
            module-level action count).

    Returns:
        Scalar inner product of the state's policy slice and its Q row.
    """
    start = s_idx * A
    action_probs = pi[start:start + A]
    q_row = current_Q_sa[s_idx, :]
    return np.dot(action_probs, q_row)

def compute_sum_dot_pi_Q(current_Q_sa, pi):
    """Return the state-value vector V(s) = sum_a pi(a|s) Q(s, a) for all states.

    Args:
        current_Q_sa: Q table indexed as ``[s, a]``.
        pi: flat policy vector of length S*A.

    Returns:
        np.ndarray: float array of length ``S`` (module-level state count),
        one value per state, each computed by ``compute_dot_pi_Q``.
    """
    per_state_values = [compute_dot_pi_Q(state, current_Q_sa, pi) for state in range(S)]
    return np.array(per_state_values, dtype=float)

def compute_sum_P_dot_dot_pi_Q(s_idx, a_idx, P_kernel_sas, current_Q_sa, pi):
    """Return the expected next-state value sum_{s'} P(s'|s, a) V(s').

    Args:
        s_idx: state index.
        a_idx: action index.
        P_kernel_sas: transition kernel indexed as ``[s, a, s']``.
        current_Q_sa: Q table indexed as ``[s, a]``.
        pi: flat policy vector of length S*A.

    Returns:
        Scalar: inner product of the transition row ``P[s, a, :]`` with the
        state values obtained from ``compute_sum_dot_pi_Q``.
    """
    state_values = compute_sum_dot_pi_Q(current_Q_sa, pi)
    transition_row = P_kernel_sas[s_idx, a_idx, :]
    return np.dot(transition_row, state_values)

def compute_Q_pi_c_P(s_idx, a_idx, P_kernel_sas, current_Q_sa, cost_sa_current, pi):
    r"""Apply one Bellman backup to a single state-action pair.

    Computes

        Q^\pi_{c,P}(s, a) = c(s, a) + gamma * sum_{s'} P(s'|s, a) V(s')

    where ``gamma`` is the module-level discount factor and the expectation
    term comes from ``compute_sum_P_dot_dot_pi_Q``.

    The docstring is a raw string because it contains a backslash; in a
    normal string ``\p`` is an invalid escape sequence.

    Args:
        s_idx: state index.
        a_idx: action index.
        P_kernel_sas: transition kernel indexed as ``[s, a, s']``.
        current_Q_sa: current Q table indexed as ``[s, a]``.
        cost_sa_current: cost table indexed as ``[s, a]``.
        pi: flat policy vector of length S*A.

    Returns:
        Scalar one-step updated value of Q(s, a).
    """
    expected_future_val = compute_sum_P_dot_dot_pi_Q(s_idx, a_idx, P_kernel_sas, current_Q_sa, pi)
    return cost_sa_current[s_idx, a_idx] + gamma * expected_future_val

def compute_Q_pi_c_P_matrix(P_kernel_sas, Q_sa_initial_guess, cost_sa_current, pi, max_iterations=1000, tolerance=0.001):
    """Evaluate the policy's Q table by iterating the Bellman equation.

    Starting from ``Q_sa_initial_guess``, repeatedly applies

        Q(s, a) <- c(s, a) + gamma * sum_{s'} P(s'|s, a) * sum_{a'} pi(a'|s') Q(s', a')

    to the whole table until the largest absolute change falls below
    ``tolerance`` or ``max_iterations`` sweeps have run. ``gamma``, ``S`` and
    ``A`` are read from module level.

    The state values V(s') are the same for every (s, a) in a sweep, so they
    are computed once per sweep and the backup is done as one array expression
    rather than per entry.

    Args:
        P_kernel_sas: transition kernel of shape (S, A, S).
        Q_sa_initial_guess: initial Q table of shape (S, A); not modified.
        cost_sa_current: cost table of shape (S, A).
        pi: flat policy vector of length S*A.
        max_iterations: maximum number of sweeps.
        tolerance: convergence threshold on the max-norm of the update.

    Returns:
        np.ndarray: the (S, A) Q table after convergence, or the last iterate
        if ``max_iterations`` is reached first (a copy of the initial guess
        when ``max_iterations`` is 0).
    """
    policy_sa = np.asarray(pi).reshape(S, A)
    Q_k_sa = np.copy(Q_sa_initial_guess)

    for iteration in range(max_iterations):
        # V(s') under the current Q estimate, shared by every (s, a) backup.
        V_s = np.sum(policy_sa * Q_k_sa, axis=1)
        # (S, A, S) @ (S,) -> (S, A): expected next-state value per (s, a).
        Q_k_plus_1_sa = cost_sa_current + gamma * (P_kernel_sas @ V_s)

        if np.max(np.abs(Q_k_plus_1_sa - Q_k_sa)) < tolerance:
            print(f"Q converged at iteration {iteration + 1} for P: {P_kernel_sas[0,0,0]:.2f} c: {cost_sa_current[0,0]:.2f}")
            return Q_k_plus_1_sa
        Q_k_sa = Q_k_plus_1_sa
    return Q_k_sa

def compute_J_c_P(P_kernel_sas, Q_sa_initial_guess, cost_sa_current, pi):
    """Return the expected discounted cost J of the policy under one kernel.

    The Q table is obtained from ``compute_Q_pi_c_P_matrix``, turned into
    state values V(s) = sum_a pi(a|s) Q(s, a), and averaged with the
    module-level initial state distribution ``mu``.

    Args:
        P_kernel_sas: transition kernel.
        Q_sa_initial_guess: initial Q table for the policy evaluation.
        cost_sa_current: cost table.
        pi: flat policy vector of length S*A.

    Returns:
        Scalar ``mu . V``.
    """
    q_table = compute_Q_pi_c_P_matrix(P_kernel_sas, Q_sa_initial_guess, cost_sa_current, pi)
    policy_rows = pi.reshape(S, A)
    state_values = np.array(
        [np.dot(policy_rows[state, :], q_table[state, :]) for state in range(S)],
        dtype=float,
    )
    return np.dot(mu, state_values)

def compute_Q_pi_for_all_U(C, Q_initial_list, U, pi):
    """Evaluate the policy's Q table for every (cost, kernel) pair.

    Args:
        C: sequence of ``N`` cost tables.
        Q_initial_list: sequence of ``N`` initial Q tables.
        U: sequence of ``U_num`` transition kernels.
        pi: flat policy vector of length S*A.

    Returns:
        np.ndarray: object array of shape (N, U_num); entry ``[j, i]`` is the
        Q table returned by ``compute_Q_pi_c_P_matrix`` for cost ``j`` under
        kernel ``i``.
    """
    Q_cn_U_array = np.zeros((N, U_num), dtype=object)
    for j in range(N):
        for i in range(U_num):
            Q_cn_U_array[j, i] = compute_Q_pi_c_P_matrix(U[i], Q_initial_list[j], C[j], pi)
    return Q_cn_U_array

def compute_J_c_U_b_and_its_max(U, C, Q_initial_list, B, pi):
    """Compute the worst-case violation of every cost function.

    For each cost index ``n`` the return J is evaluated under every kernel in
    ``U`` (via ``compute_J_c_P``); the largest of those minus ``B[n]`` is
    that cost's violation.

    Args:
        U: sequence of ``U_num`` transition kernels.
        C: sequence of ``N`` cost tables.
        Q_initial_list: sequence of ``N`` initial Q tables.
        B: sequence of ``N`` thresholds.
        pi: flat policy vector of length S*A.

    Returns:
        tuple: ``(J_results, J_max_index, J_max_index_U)`` where
        ``J_results[n]`` is ``max_i J(n, U[i]) - B[n]``, ``J_max_index`` is
        the index ``n`` with the largest violation, and ``J_max_index_U[n]``
        is the index of the kernel attaining the maximum for cost ``n``.
    """
    violations = []
    worst_kernel_per_cost = []
    for n in range(N):
        returns_over_U = [
            compute_J_c_P(U[i], Q_initial_list[n], C[n], pi) for i in range(U_num)
        ]
        worst_kernel_per_cost.append(np.argmax(returns_over_U))
        violations.append(np.max(returns_over_U) - B[n])
    most_violated = np.argmax(violations)
    return violations, most_violated, worst_kernel_per_cost


def compute_occupancy_measure_d_P_pi(
    P_kernel_sas,
    current_gamma,
    initial_dist_mu_s,
    num_states,
    num_actions,
    pi
    ):
    """Compute the discounted state occupancy measure d_P^pi(s).

    Builds the state-to-state transition matrix induced by the policy,
    P_pi(s'|s) = sum_a pi(a|s) P(s'|s, a), and solves

        (I - gamma * P_pi^T) d = (1 - gamma) * mu.

    Args:
        P_kernel_sas: transition kernel of shape (S, A, S').
        current_gamma: discount factor.
        initial_dist_mu_s: initial state distribution, length ``num_states``.
        num_states: number of states.
        num_actions: number of actions.
        pi: policy, reshaped here to (num_states, num_actions).

    Returns:
        np.ndarray: vector of length ``num_states``. If ``np.linalg.solve``
        raises ``LinAlgError`` the pseudo-inverse is used instead; if that
        also raises, the uniform distribution is returned.
    """
    policy_table = pi.reshape(num_states, num_actions)

    # 'ijk,ij->ik': sum over actions j of P[i, j, k] * pi[i, j].
    induced_chain = np.einsum('ijk,ij->ik', P_kernel_sas, policy_table)
    lhs = np.eye(num_states) - current_gamma * induced_chain.T
    rhs = (1 - current_gamma) * initial_dist_mu_s

    try:
        return np.linalg.solve(lhs, rhs)
    except np.linalg.LinAlgError:
        print("Warning: Linear system solve failed for occupancy measure. Using pseudo-inverse.")

    try:
        return np.linalg.pinv(lhs) @ rhs
    except np.linalg.LinAlgError:
        print("Error: Pseudo-inverse also failed for occupancy measure. Returning uniform distribution.")

    # Last-resort fallback.
    return np.ones(num_states) / num_states

def compute_gradient_J_c_U(
    P,
    initial_dist_mu_s,
    num_states,
    num_actions,
    pi,
    current_gamma,
    Q,
    H,
    index_U,
    jindex_c
):
    """Compute the (sub)gradient table used for the policy update.

    For the selected cost ``jindex_c`` and kernel ``index_U`` the entries are

        grad[s, a] = H * d(s) * Q[jindex_c][index_U][s, a]

    where ``d`` is the occupancy measure returned by
    ``compute_occupancy_measure_d_P_pi`` for kernel ``P[index_U]``.

    The result is sized from ``num_states`` / ``num_actions`` rather than
    from the module-level ``S`` / ``A``, so it always agrees with the
    arguments that drive the computation.

    Args:
        P: sequence of transition kernels.
        initial_dist_mu_s: initial state distribution.
        num_states: number of states.
        num_actions: number of actions.
        pi: flat policy vector.
        current_gamma: discount factor.
        Q: container indexed as ``Q[jindex_c][index_U]`` yielding a Q table.
        H: scalar multiplier (the effective horizon at the call site).
        index_U: index of the selected kernel.
        jindex_c: index of the selected cost function.

    Returns:
        np.ndarray: gradient table of shape (num_states, num_actions).
    """
    d_P_worst_pi_s = compute_occupancy_measure_d_P_pi(
        P[index_U], current_gamma, initial_dist_mu_s, num_states, num_actions, pi)

    Q_worst = np.asarray(Q[jindex_c][index_U], dtype=float)[:num_states, :num_actions]
    state_weights = H * np.asarray(d_P_worst_pi_s, dtype=float)[:num_states]
    # Broadcast the per-state weight across the action axis.
    return state_weights[:, None] * Q_worst


#　射影による方策更新(ここだけは北村さんのコード+GPTでnumpy用に生成しました．でもjaxにしたほうがいいと思う)
def projection_to_simplex_numpy(y_vector):
    """Project a 1-D vector onto the probability simplex (Euclidean projection).

    Uses the sort-based algorithm (Duchi et al., 2008): with the entries
    sorted in descending order ``u`` and their cumulative sums ``cssv``, find
    the last index ``rho`` with ``u[rho] > (cssv[rho] - 1) / (rho + 1)``, set
    ``theta = (cssv[rho] - 1) / (rho + 1)`` and return ``max(y - theta, 0)``.

    The input is first shifted by its maximum. The projection is unchanged by
    adding a constant to every entry, and after the shift the largest entry is
    0, so the support condition always holds for the first index even when
    the raw values are so large that ``u - 1 == u`` in floating point.

    Args:
        y_vector (np.ndarray): 1-D vector to project.

    Returns:
        np.ndarray: new 1-D array with non-negative entries summing to 1.

    Raises:
        ValueError: if ``y_vector`` is not a non-empty 1-D array or contains
            NaN / infinite values (previously such input silently produced
            ``None``).
    """
    y = np.asarray(y_vector, dtype=float)
    if y.ndim != 1 or y.size == 0:
        raise ValueError("y_vector must be a non-empty 1-D array.")
    if not np.all(np.isfinite(y)):
        raise ValueError("y_vector must contain only finite values.")

    shifted = y - np.max(y)
    u = np.sort(shifted)[::-1]
    cssv = np.cumsum(u)
    k_vals = np.arange(1, y.size + 1)
    in_support = u > (cssv - 1.0) / k_vals
    # in_support[0] is always True after the shift (0 > -1), so rho exists.
    rho = np.nonzero(in_support)[0][-1]
    theta = (cssv[rho] - 1.0) / (rho + 1)
    return np.maximum(shifted - theta, 0.0)


def proj_to_Pi_numpy(policy_matrix_SA):
    """Project every row of a policy table onto the probability simplex.

    Args:
        policy_matrix_SA (np.ndarray): 2-D array with one row per state.

    Returns:
        np.ndarray: array of the same shape and dtype whose rows are the
        results of ``projection_to_simplex_numpy`` applied to each input row.

    Raises:
        ValueError: if the input is not 2-D.
    """
    if policy_matrix_SA.ndim != 2:
        raise ValueError("Input policy_matrix_SA must be a 2D array (S, A).")

    result = np.zeros_like(policy_matrix_SA)
    for row_idx, state_row in enumerate(policy_matrix_SA):
        result[row_idx, :] = projection_to_simplex_numpy(state_row)
    return result

def update_and_project_policy(
    current_policy_table_SA,
    gradient_table_SA,
    learning_rate_alpha
    ):
    """Take one projected-gradient step on the policy table.

    Args:
        current_policy_table_SA (np.ndarray): current policy table (S, A).
        gradient_table_SA (np.ndarray): gradient table of the same shape.
        learning_rate_alpha (float): step size.

    Returns:
        np.ndarray: ``proj_to_Pi_numpy(policy - alpha * gradient)``, i.e. the
        descent step followed by a row-wise projection onto the simplex.

    Raises:
        ValueError: if the two tables differ in shape.
    """
    shapes_agree = current_policy_table_SA.shape == gradient_table_SA.shape
    if not shapes_agree:
        raise ValueError("Shape mismatch between current_policy_table_SA and gradient_table_SA.")

    # Descent step followed by the row-wise simplex projection.
    stepped = current_policy_table_SA - learning_rate_alpha * gradient_table_SA
    return proj_to_Pi_numpy(stepped)



def compute_inner_loop(b_0, pi):
    """Run ``T`` projected-subgradient steps for a fixed threshold ``b_0``.

    Each iteration evaluates the current policy (worst-case violation of each
    cost over the uncertainty set), takes the gradient for the most violated
    cost under its worst kernel, and performs a projected gradient step.

    The returned policy is the evaluated iterate with the smallest maximum
    violation, not the last iterate: a subgradient step is not guaranteed to
    decrease the objective, so the last iterate can be worse than an earlier
    one. If ``T`` is 0, or no violation compares smaller than infinity (for
    example all NaN), the input policy is returned unchanged.

    Side effect: overwrites the module-level ``B[0]`` with ``b_0``.

    Args:
        b_0: threshold for cost function 0.
        pi: flat initial policy vector of length S*A.

    Returns:
        np.ndarray: flat policy vector of length S*A.
    """
    B[0] = b_0
    best_pi = pi
    best_violation = np.inf
    for t in range(T):
        # Policy evaluation: violations J_n - b_n and the worst kernel per cost.
        J_results, J_max_index, J_max_index_U = compute_J_c_U_b_and_its_max(U, C, Q_initial_list, B, pi)

        # J_results describes the policy *before* this iteration's update.
        current_violation = J_results[J_max_index]
        if current_violation < best_violation:
            best_violation = current_violation
            best_pi = pi

        Q_array = compute_Q_pi_for_all_U(C, Q_initial_list, U, pi)
        gradient_J = compute_gradient_J_c_U(
            U, mu, S, A, pi, gamma, Q_array, H, J_max_index_U[J_max_index], J_max_index
        )
        # The update returns a new array, so best_pi is never modified in place.
        pi = update_and_project_policy(pi.reshape((S, A)), gradient_J, alpha)
        pi = pi.flatten()

        if t % 10 == 0:  # progress report every 10 iterations
            print(f"Episode {t}: J_max = {J_results[J_max_index]:.4f}, pi = {pi}")
            print(f"Q_max for U[{J_max_index_U[J_max_index]}]: {np.max(Q_array[J_max_index, J_max_index_U[J_max_index]]):.4f}")
            print(f"Gradient at episode {t}: {gradient_J}")
    return best_pi


def compute_delta(pi, B):
    """Return the largest violation ``max_n (J_n - B[n])`` of the policy.

    ``compute_J_c_U_b_and_its_max`` already returns each cost's worst-case
    return with its threshold subtracted, so the selected entry is returned
    as is; subtracting ``B`` again here would count the threshold twice.

    Args:
        pi: flat policy vector of length S*A.
        B: sequence of thresholds, one per cost function.

    Returns:
        Scalar maximum violation (positive when some threshold is exceeded).
    """
    J_results, J_max_index, _ = compute_J_c_U_b_and_its_max(U, C, Q_initial_list, B, pi)
    return J_results[J_max_index]

def compute_i_j(i_curren, j_curren, delta, B):
    """Update the bisection bracket ``[i, j]`` from the sign of ``delta``.

    Args:
        i_curren: current lower end of the bracket.
        j_curren: current upper end of the bracket.
        delta: violation measured at the current threshold ``B[0]``.
        B: sequence of thresholds; only ``B[0]`` is read.

    Returns:
        tuple: ``(B[0], j_curren)`` when ``delta > 0`` (raise the lower end),
        otherwise ``(i_curren, B[0])`` (lower the upper end).
    """
    b0 = B[0]
    return (b0, j_curren) if delta > 0 else (i_curren, b0)

# Bisection bracket [i_current, j_current] for the threshold B[0]; it starts as [0, H].
i_current = 0  
j_current = H

# Start the search at the midpoint of the bracket.
B[0]= (i_current + j_current) / 2  

def compute_outer_loop():
    """Run the outer bisection over the threshold ``B[0]``.

    Each of the ``T`` iterations runs the inner loop at the current ``B[0]``,
    measures the resulting maximum violation ``delta``, shrinks the bracket
    ``[i_current, j_current]`` with ``compute_i_j`` and moves ``B[0]`` to the
    new midpoint.

    Operates on the module-level ``pi``, ``i_current``, ``j_current`` and
    ``B`` and updates them in place. The final policy is also stored in the
    module-level ``pi_new``, the name the return statement uses (it was
    previously never assigned, so the return raised NameError).

    Returns:
        The policy produced by the last inner-loop run.
    """
    global pi_new, i_current, j_current, B, pi
    for t in range(T):
        pi = compute_inner_loop(B[0], pi)
        delta = compute_delta(pi, B)
        i_current, j_current = compute_i_j(i_current, j_current, delta, B)
        B[0] = (i_current + j_current) / 2
        print(f"Outer loop iteration {t}: i = {i_current}, j = {j_current}, delta = {delta:.4f}, B[0] = {B[0]:.4f}")

    pi_new = pi
    return pi_new

compute_outer_loop()  # run the outer (bisection) loop

Q converged at iteration 117 for P: 0.21 c: 0.57
Q converged at iteration 118 for P: 0.33 c: 0.57
Q converged at iteration 120 for P: 0.48 c: 0.57
Q converged at iteration 118 for P: 0.21 c: 0.03
Q converged at iteration 117 for P: 0.33 c: 0.03
Q converged at iteration 115 for P: 0.48 c: 0.03
Q converged at iteration 111 for P: 0.21 c: 0.28
Q converged at iteration 111 for P: 0.33 c: 0.28
Q converged at iteration 109 for P: 0.48 c: 0.28
Q converged at iteration 119 for P: 0.21 c: 0.11
Q converged at iteration 118 for P: 0.33 c: 0.11
Q converged at iteration 117 for P: 0.48 c: 0.11
Q converged at iteration 117 for P: 0.21 c: 0.57
Q converged at iteration 118 for P: 0.33 c: 0.57
Q converged at iteration 120 for P: 0.48 c: 0.57
Q converged at iteration 118 for P: 0.21 c: 0.03
Q converged at iteration 117 for P: 0.33 c: 0.03
Q converged at iteration 115 for P: 0.48 c: 0.03
Q converged at iteration 111 for P: 0.21 c: 0.28
Q converged at iteration 111 for P: 0.33 c: 0.28
Q converged at itera

KeyboardInterrupt: 

## 得られた知見
内側ルーチンでのベルマン方程式の計算は，収束するまで反復的に計算をしている．
* 価値反復法は1反復で1ステップの更新だけ
* 方策反復法は1反復で収束するまで反復更新をする

割引率γが1に近い(=有効ホライゾン H = 1/(1-γ) が大きい)と，収束に時間がかかる＋勾配がめっちゃでかくなる

### 内側ループ
* T回反復する
  * 現在の方策を用いて方策評価を行う
      * 制約ごとに，J_n_U - bnを求める
        * 遷移確率ごとに，J_c_Pを計算する
          * ベルマン方程式を，Qの更新差分が収束するまで繰り返し適用する
        * 最もJ_c_Pが大きなPとJn_Uを求める．
        * J_n_U - bnを求める．
        * 最もJ_n_U - bnが大きくなる制約nを求める
  * 射影勾配法を用いて方策を更新する
    * 方策評価で求めた制約nにおけるJ_n_Uの，現在の方策での劣勾配を求める．
      * ∇J_n_U(pi)(s,a) = H * d_pi(s) * Q(s,a) (勾配の計算が怪しい)
      * (劣)勾配は方策と同じ次元
    * 方策から勾配を定数倍したものを引く
    * 引いたものを入力とする最適化(最小化)問題を解く
      * 最適化問題を最小化するような方策を求める
  * 方策を更新し，現在の方策とする．
  * 方策は，J_n_U - bnが最も小さくなったtにおける方策とする

### 外側ループ
i =0,j = Hとして初期化
* K回反復する
  * b_0 = (i+j)/2
  * b0を入力として，内側ループで得られた方策を獲得する
  * 得られた方策を用いて，J_n_U -bnの最大値を求める
  * 最大値を用いてiとjを更新する．
    * 最大値が0超過なら iはb0，それ以外ならそのまま
    * 最大値が0超過なら jはそのまま，それ以外ならb0


In [5]:
import jax
import jax.numpy as jnp
import functools

# --- Global constants (inside JAX functions they are passed as arguments or bound via partial) ---
# S, A, N, T, U_num, alpha, gamma and H should be set as appropriate.
# Example:
S = 3
A = 2
N_constraints = 4 # author's note: the original N counted constraints + objective, so this is the constraint count. NOTE(review): the __main__ block passes this value as the total number of cost functions (index 0 is the one whose threshold is bisected) — confirm the intended meaning
T_inner = 100
T_outer = 10 # kept small for the demo
U_num = 3
alpha_lr = 0.01
gamma_discount = 0.95
H_horizon = int(jnp.round(1 / (1 - gamma_discount)))
MAX_ITER_Q = 1000 # maximum number of sweeps of the Q fixed-point (policy evaluation) iteration
TOL_Q = 0.001    # convergence threshold of the Q fixed-point iteration


# --- 初期化関数 (JAXのキーを使用) ---
def init_pi(key, s_dim, a_dim):
    """Sample a random policy table.

    Args:
        key: JAX PRNG key. It is split in two and only the first half is used.
        s_dim: number of states.
        a_dim: number of actions.

    Returns:
        (s_dim, a_dim) array of uniform samples normalised so each row sums to 1.
    """
    sample_key, _unused_key = jax.random.split(key)
    raw = jax.random.uniform(sample_key, (s_dim, a_dim))
    row_totals = jnp.sum(raw, axis=1, keepdims=True)
    return raw / row_totals

def init_mu(key, s_dim):
    """Sample a random initial state distribution.

    Args:
        key: JAX PRNG key.
        s_dim: number of states.

    Returns:
        Length-``s_dim`` vector of uniform samples normalised to sum to 1.
    """
    raw = jax.random.uniform(key, (s_dim,))
    total = jnp.sum(raw)
    return raw / total

def make_U(key, u_num, s_dim, a_dim):
    """Sample the uncertainty set as a stacked array of transition kernels.

    Args:
        key: JAX PRNG key; split into ``u_num`` sub-keys, one per kernel.
        u_num: number of kernels.
        s_dim: number of states.
        a_dim: number of actions.

    Returns:
        Array of shape (u_num, s_dim, a_dim, s_dim) whose last axis sums to 1.
    """
    def _sample_kernel(single_key):
        raw = jax.random.uniform(single_key, (s_dim, a_dim, s_dim))
        return raw / jnp.sum(raw, axis=2, keepdims=True)

    per_kernel_keys = jax.random.split(key, u_num)
    return jax.vmap(_sample_kernel)(per_kernel_keys)

def make_C_and_Q_initial_and_B(key, n_funcs, s_dim, a_dim, h_horizon):
    """Sample cost tables and thresholds and create zero Q tables.

    Args:
        key: JAX PRNG key; split three ways (the middle sub-key is unused).
        n_funcs: number of cost functions.
        s_dim: number of states.
        a_dim: number of actions.
        h_horizon: largest threshold value that can be drawn (inclusive).

    Returns:
        tuple: uniform costs of shape (n_funcs, s_dim, a_dim), zeros of the
        same shape, and float32 thresholds of shape (n_funcs,) drawn as
        integers from ``[0, h_horizon]``.
    """
    cost_key, _unused_key, threshold_key = jax.random.split(key, 3)
    table_shape = (n_funcs, s_dim, a_dim)
    costs = jax.random.uniform(cost_key, table_shape)
    q_zeros = jnp.zeros(table_shape)
    int_thresholds = jax.random.randint(threshold_key, (n_funcs,), 0, h_horizon + 1)
    return costs, q_zeros, int_thresholds.astype(jnp.float32)


# --- Q関数とJ関数の計算 (JAX版) ---
@functools.partial(jax.jit, static_argnames=["s_dim", "a_dim", "max_iterations", "tolerance"])
def compute_Q_pi_c_P_matrix_jax(P_kernel_sas, Q_sa_initial_guess, cost_sa_current, pi_sa,
                                gamma_val, s_dim, a_dim, max_iterations, tolerance):
    """Evaluate the policy's Q table by iterating the Bellman equation.

    Repeats ``Q <- c + gamma * P V`` with ``V(s) = sum_a pi(a|s) Q(s, a)``
    inside ``jax.lax.while_loop``. At least one sweep is always performed;
    afterwards the loop continues while the max-norm change exceeds
    ``tolerance`` and fewer than ``max_iterations`` sweeps have run.

    Args:
        P_kernel_sas: transition kernel (S, A, S').
        Q_sa_initial_guess: initial Q table (S, A).
        cost_sa_current: cost table (S, A).
        pi_sa: policy, reshaped to (s_dim, a_dim).
        gamma_val: discount factor.
        s_dim, a_dim: static dimensions.
        max_iterations: static sweep limit.
        tolerance: static convergence threshold.

    Returns:
        The (S, A) Q table after the loop terminates.
    """
    policy = jnp.reshape(pi_sa, (s_dim, a_dim))

    def _keep_iterating(carry):
        q_prev, q_curr, n_done = carry
        gap = jnp.max(jnp.abs(q_curr - q_prev))
        still_running = (gap > tolerance) & (n_done < max_iterations)
        # The first sweep is unconditional.
        return (n_done < 1) | still_running

    def _bellman_sweep(carry):
        _, q_curr, n_done = carry
        state_values = jnp.einsum('sa,sa->s', policy, q_curr)
        expected_next = jnp.einsum('ijk,k->ij', P_kernel_sas, state_values)
        q_next = cost_sa_current + gamma_val * expected_next
        return q_curr, q_next, n_done + 1

    # The "previous" table starts at +inf so the first measured gap is infinite.
    previous_init = jnp.full_like(Q_sa_initial_guess, jnp.inf)
    final_carry = jax.lax.while_loop(
        _keep_iterating, _bellman_sweep, (previous_init, Q_sa_initial_guess, 0)
    )
    return final_carry[1]

@functools.partial(jax.jit, static_argnames=["s_dim", "a_dim", "max_iter_q", "tol_q"])
def compute_J_c_P_jax(P_kernel_sas, Q_sa_initial_guess, cost_sa_current, pi_sa, mu_s,
                      gamma_val, s_dim, a_dim, max_iter_q, tol_q):
    """Return the expected discounted cost J of the policy under one kernel.

    Runs ``compute_Q_pi_c_P_matrix_jax`` to get the Q table, forms
    V(s) = sum_a pi(a|s) Q(s, a) and returns ``mu_s . V``.

    Args:
        P_kernel_sas: transition kernel (S, A, S').
        Q_sa_initial_guess: initial Q table (S, A).
        cost_sa_current: cost table (S, A).
        pi_sa: policy, reshaped to (s_dim, a_dim).
        mu_s: initial state distribution (S,).
        gamma_val: discount factor.
        s_dim, a_dim: static dimensions.
        max_iter_q, tol_q: static limits forwarded to the Q iteration.

    Returns:
        Scalar J.
    """
    policy = jnp.reshape(pi_sa, (s_dim, a_dim))
    q_table = compute_Q_pi_c_P_matrix_jax(
        P_kernel_sas, Q_sa_initial_guess, cost_sa_current, policy,
        gamma_val, s_dim, a_dim, max_iter_q, tol_q,
    )
    state_values = jnp.einsum('sa,sa->s', policy, q_table)
    return jnp.dot(mu_s, state_values)

# --- 占有尺度と勾配計算 (JAX版) ---
@functools.partial(jax.jit, static_argnames=["num_states", "num_actions"])
def compute_occupancy_measure_d_P_pi_jax(P_kernel_sas, gamma_val, initial_dist_mu_s,
                                         pi_sa, num_states, num_actions):
    """Compute the discounted state occupancy measure d_P^pi.

    Solves ``(I - gamma * P_pi^T) d = (1 - gamma) * mu`` where
    ``P_pi(s'|s) = sum_a pi(a|s) P(s'|s, a)``. The linear solve is assumed to
    succeed; no fallback is attempted.

    Args:
        P_kernel_sas: transition kernel (S, A, S').
        gamma_val: discount factor.
        initial_dist_mu_s: initial state distribution (S,).
        pi_sa: policy, reshaped to (num_states, num_actions).
        num_states, num_actions: static dimensions.

    Returns:
        Length-``num_states`` vector d.
    """
    policy = jnp.reshape(pi_sa, (num_states, num_actions))
    induced_chain = jnp.einsum('ijk,ij->ik', P_kernel_sas, policy)  # (S, S')
    lhs = jnp.eye(num_states) - gamma_val * induced_chain.T
    rhs = (1 - gamma_val) * initial_dist_mu_s
    return jnp.linalg.solve(lhs, rhs)

@functools.partial(jax.jit, static_argnames=["s_dim", "a_dim", "h_horizon_val"])
def compute_gradient_term_jax(d_P_worst_pi_s, Q_worst_sa, h_horizon_val, s_dim, a_dim):
    """Return the gradient table grad[s, a] = H * d(s) * Q(s, a).

    Args:
        d_P_worst_pi_s: occupancy measure (S,).
        Q_worst_sa: Q table (S, A).
        h_horizon_val: static scalar multiplier H.
        s_dim, a_dim: static dimensions (not used in the computation).

    Returns:
        (S, A) gradient table.
    """
    # Turn d into a column so it broadcasts across the action axis.
    occupancy_column = d_P_worst_pi_s[:, None]
    return h_horizon_val * occupancy_column * Q_worst_sa

# --- 射影関数 (JAX版) ---
# (Duchi et al., Efficient Projections onto the l1-Ball for Learning in High Dimensions, ICML 2008)
@jax.jit
def projection_to_simplex_jax(y_vector):
    """Project a 1-D vector onto the probability simplex.

    Sort-based algorithm: with the entries in descending order ``u`` and
    cumulative sums ``cssv``, count the positions where
    ``u + (1 - cssv) / k > 0``; the last such position (clamped to 0 if there
    are none) gives the threshold ``(cssv[rho] - 1) / (rho + 1)`` that is
    subtracted before clipping at zero.

    Args:
        y_vector: 1-D array.

    Returns:
        1-D array of the same length.
    """
    dim = y_vector.shape[0]
    descending = jnp.sort(y_vector)[::-1]
    running_sum = jnp.cumsum(descending)
    ranks = jnp.arange(dim) + 1
    active = descending + (1.0 - running_sum) / ranks > 0

    # Number of active positions minus one, clamped so it is never -1.
    last_active = jnp.maximum(active.astype(jnp.int32).sum() - 1, 0)

    threshold = (running_sum[last_active] - 1.0) / (last_active + 1.0)
    return jnp.maximum(y_vector - threshold, 0.0)

@jax.jit
def proj_to_Pi_jax(policy_matrix_SA):
    """Apply ``projection_to_simplex_jax`` to every row of a policy table."""
    project_rows = jax.vmap(projection_to_simplex_jax)
    return project_rows(policy_matrix_SA)

@jax.jit
def update_and_project_policy_jax(current_policy_table_SA, gradient_table_SA, learning_rate_alpha):
    """Take one projected-gradient step on the policy table.

    Args:
        current_policy_table_SA: current policy table (S, A).
        gradient_table_SA: gradient table (S, A).
        learning_rate_alpha: step size.

    Returns:
        ``proj_to_Pi_jax(policy - alpha * gradient)``.
    """
    descent_step = learning_rate_alpha * gradient_table_SA
    return proj_to_Pi_jax(current_policy_table_SA - descent_step)


# --- メインループ関数 (状態を引数と返り値で管理) ---

# Inner loop structure
# s_dim, a_dim, n_constr_val, u_n_val (and the other scalar settings) are captured by closure
def inner_loop_body_fn_factory(s_dim, a_dim, n_constr_val, u_n_val, gamma_val, h_h, max_iter_q_l, tol_q_l, alpha_l):
    """Build the body function of the inner (policy update) loop.

    The returned ``_inner_loop_body_fn(t, state)`` has the ``(index, carry)``
    signature expected by ``jax.lax.fori_loop``. ``state`` is the 7-tuple
    ``(pi, key, B_thresholds, U_kernels, C_costs, Q_initials, mu)``; one call
    returns the same tuple with only ``pi`` replaced.

    Args:
        s_dim, a_dim: number of states / actions.
        n_constr_val: number of cost functions looped over.
        u_n_val: number of kernels in the uncertainty set.
        gamma_val: discount factor.
        h_h: scalar multiplier H used in the gradient.
        max_iter_q_l, tol_q_l: limits forwarded to the Q fixed-point iteration.
        alpha_l: step size of the projected gradient step.

    Returns:
        The loop body function.
    """
    
    def n_loop_body_factory(pi_sa_current_arg, mu_dist_arg, U_kernels_arg, Q_initials_arg, C_costs_arg): # bind the per-step arrays
        def body_fn(n_idx, carry_n_loop):
            # Accumulators: Q tables (n, u, S, A) and returns J (n, u).
            q_list_acc, j_list_acc = carry_n_loop

            # Q table of cost n_idx under every kernel (vmap over the kernel axis only).
            q_n_all_u_single_n = jax.vmap(compute_Q_pi_c_P_matrix_jax,
                                          in_axes=(0, None, None, None, None, None, None, None, None))(
                U_kernels_arg, Q_initials_arg[n_idx], C_costs_arg[n_idx], pi_sa_current_arg,
                gamma_val, s_dim, a_dim, max_iter_q_l, tol_q_l)

            # Return J of cost n_idx under every kernel.
            # NOTE: compute_J_c_P_jax runs the Q fixed-point iteration again internally.
            j_n_all_u_single_n = jax.vmap(compute_J_c_P_jax,
                                          in_axes=(0, None, None, None, None, None, None, None, None, None))(
                U_kernels_arg, Q_initials_arg[n_idx], C_costs_arg[n_idx], pi_sa_current_arg, mu_dist_arg,
                gamma_val, s_dim, a_dim, max_iter_q_l, tol_q_l)

            q_list_acc = q_list_acc.at[n_idx].set(q_n_all_u_single_n)
            j_list_acc = j_list_acc.at[n_idx].set(j_n_all_u_single_n)
            return q_list_acc, j_list_acc
        return body_fn

    def _inner_loop_body_fn(t, state_inner_dynamic):
        # Unpack the carry. Only pi changes in this function; the key and the
        # remaining entries are returned unchanged. ``t`` is not used.
        pi_sa_current, key_inner, B_thresholds_current, \
        U_kernels_static, C_costs_static, Q_initials_static, mu_dist_static = state_inner_dynamic

        q_accumulator_init = jnp.zeros((n_constr_val, u_n_val, s_dim, a_dim)) # shapes come from the closed-over Python ints
        j_accumulator_init = jnp.zeros((n_constr_val, u_n_val))

        # Evaluate every (cost, kernel) pair for the current policy.
        current_n_loop_body = n_loop_body_factory(pi_sa_current, mu_dist_static, U_kernels_static, Q_initials_static, C_costs_static)
        Q_all_n_u, J_all_n_u = jax.lax.fori_loop(0, n_constr_val, current_n_loop_body, (q_accumulator_init, j_accumulator_init))

        # Worst-case return per cost, and its violation of the threshold.
        J_worst_case_over_U_for_n = jnp.max(J_all_n_u, axis=1)
        violations_n = J_worst_case_over_U_for_n - B_thresholds_current

        # Most violated cost, and the kernel attaining its worst case.
        J_max_overall_violation_idx_n = jnp.argmax(violations_n)
        idx_U_for_worst_n = jnp.argmax(J_all_n_u[J_max_overall_violation_idx_n, :])

        P_worst = U_kernels_static[idx_U_for_worst_n]
        Q_worst = Q_all_n_u[J_max_overall_violation_idx_n, idx_U_for_worst_n]

        # Gradient H * d(s) * Q(s, a) for that pair, then one projected step.
        d_P_worst = compute_occupancy_measure_d_P_pi_jax(P_worst, gamma_val, mu_dist_static,
                                                         pi_sa_current, s_dim, a_dim)
        grad_sa = compute_gradient_term_jax(d_P_worst, Q_worst, h_h, s_dim, a_dim)
        pi_sa_new = update_and_project_policy_jax(pi_sa_current, grad_sa, alpha_l)

        return pi_sa_new, key_inner, B_thresholds_current, \
               U_kernels_static, C_costs_static, Q_initials_static, mu_dist_static # same structure as the input carry

    return _inner_loop_body_fn


def outer_loop_body_fn(k_outer, state_outer):
    """Perform one bisection step on the threshold of cost function 0.

    ``state_outer`` is a flat tuple. The first five entries are
    ``(pi, i_bisection, j_bisection, key, B_orig)``; the remaining fourteen
    are, in order, ``(U, C, Q_initials, mu, gamma, S, A, H, max_iter_q, tol_q,
    n_costs, u_num, alpha, t_inner)`` and are passed through unchanged.

    Steps: set ``b0`` to the bracket midpoint, run ``t_inner`` inner-loop
    iterations starting from ``pi``, measure the largest violation ``delta``
    of the resulting policy, and move the lower end of the bracket to ``b0``
    when ``delta > 0`` or the upper end otherwise.

    Args:
        k_outer: outer iteration index (not used in the body).
        state_outer: state tuple described above.

    Returns:
        The state tuple for the next iteration, with the policy replaced by
        the one after the final inner-loop update (no best-iterate selection
        is done here), ``i`` / ``j`` updated and the PRNG key replaced by a
        fresh split. ``B_orig`` is returned unmodified.
    """
    # Entries that change between outer iterations (B_orig itself is returned unchanged)
    pi_sa_outer, i_bisection, j_bisection, key_outer, B_orig = state_outer[:5]
    # Entries that stay constant across outer iterations
    U_k_static, C_c_static, Q_i_static, mu_d_static, \
    gamma_d_static, s_d_static, a_d_static, h_h_val_static, \
    max_i_q_static, t_q_static, n_c_val_static, u_val_static, \
    alpha_val_il_static, t_inner_loop_static = state_outer[5:]


    key_outer, key_inner_loop = jax.random.split(key_outer)
    # Candidate threshold: midpoint of the current bracket, written into slot 0 of a copy of B.
    b0_current = (i_bisection + j_bisection) / 2.0
    B_current_loop = B_orig.at[0].set(b0_current)

    # The factory closes over the scalar settings
    _inner_loop_fn = inner_loop_body_fn_factory(
        s_d_static, a_d_static, n_c_val_static, u_val_static,
        gamma_d_static, h_h_val_static, max_i_q_static, t_q_static, alpha_val_il_static
    )
    
    # Carry of the inner loop: policy, key, thresholds and the constant arrays
    initial_inner_state_dynamic = (
        pi_sa_outer, key_inner_loop, B_current_loop,
        U_k_static, C_c_static, Q_i_static, mu_d_static # constant arrays are carried along too
    )
    
    # final_inner_state_dynamic has the same structure as initial_inner_state_dynamic
    final_inner_state_dynamic = jax.lax.fori_loop(0, t_inner_loop_static, _inner_loop_fn, initial_inner_state_dynamic)
    pi_sa_after_inner_loop, _, _, _, _, _, _ = final_inner_state_dynamic # only the policy is needed from here on


    # Delta calculation: re-evaluate J for every (cost, kernel) pair under the
    # policy returned by the inner loop.
    def delta_n_loop_body_factory(pi_sa_arg, mu_arg, U_k_arg, Q_i_arg, C_c_arg):
        def body_fn(n_idx_delta, carry_delta_loop):
            j_n_all_u_delta_array = carry_delta_loop
            # Return J of cost n_idx_delta under every kernel (vmap over the kernel axis only).
            j_n_all_u_single_delta = jax.vmap(compute_J_c_P_jax,
                                         in_axes=(0, None, None, None, None, None, None, None, None, None))(
                U_k_arg, Q_i_arg[n_idx_delta], C_c_arg[n_idx_delta], pi_sa_arg, mu_arg,
                gamma_d_static, s_d_static, a_d_static, max_i_q_static, t_q_static)
            j_n_all_u_delta_array = j_n_all_u_delta_array.at[n_idx_delta].set(j_n_all_u_single_delta)
            return j_n_all_u_delta_array
        return body_fn

    j_delta_accumulator_init = jnp.zeros((n_c_val_static, u_val_static))
    current_delta_n_loop_body = delta_n_loop_body_factory(pi_sa_after_inner_loop, mu_d_static, U_k_static, Q_i_static, C_c_static)
    J_all_n_u_delta = jax.lax.fori_loop(0, n_c_val_static, current_delta_n_loop_body, j_delta_accumulator_init)

    # delta = max over costs of (worst-case return over kernels - threshold)
    J_worst_case_over_U_for_n_delta = jnp.max(J_all_n_u_delta, axis=1)
    violations_n_delta = J_worst_case_over_U_for_n_delta - B_current_loop
    delta_for_bisection = jnp.max(violations_n_delta)

    # delta > 0: raise the lower end to b0; otherwise lower the upper end to b0.
    i_new = jax.lax.cond(delta_for_bisection > 0, lambda _: b0_current, lambda _: i_bisection, None)
    j_new = jax.lax.cond(delta_for_bisection > 0, lambda _: j_bisection, lambda _: b0_current, None)

    # Return the full state for the next outer loop iteration
    return (pi_sa_after_inner_loop, i_new, j_new, key_outer, B_orig) + state_outer[5:]


def run_rcmdp_epigraph_solver(key_init, S_val, A_val, N_constraints_val, U_n_val, H_val, gamma_val,
                               alpha_val_inner, t_in, t_out, max_q_iter, tol_q_val):
    """Set up a random problem instance and run the outer bisection loop.

    Args:
        key_init: JAX PRNG key; split into five sub-keys (policy, initial
            distribution, kernels, costs/thresholds, loop).
        S_val, A_val: number of states / actions.
        N_constraints_val: number of cost functions to generate.
        U_n_val: number of kernels in the uncertainty set.
        H_val: horizon value; upper end of the initial bracket and upper
            bound for the random thresholds.
        gamma_val: discount factor.
        alpha_val_inner: inner-loop step size.
        t_in: number of inner-loop iterations per outer step.
        t_out: number of outer (bisection) steps, run as a Python loop.
        max_q_iter, tol_q_val: limits of the Q fixed-point iteration.

    Returns:
        tuple: ``(policy, i, j)`` — the final policy table and the final
        bisection bracket.
    """
    policy_key, mu_key, kernel_key, cost_key, loop_key = jax.random.split(key_init, 5)

    policy0 = init_pi(policy_key, S_val, A_val)
    mu0 = init_mu(mu_key, S_val)
    kernels = make_U(kernel_key, U_n_val, S_val, A_val)
    costs, q_inits, thresholds = make_C_and_Q_initial_and_B(
        cost_key, N_constraints_val, S_val, A_val, H_val
    )

    # State layout expected by outer_loop_body_fn: five changing entries
    # followed by the constant arrays and scalar settings.
    state = (
        policy0, 0.0, float(H_val), loop_key, thresholds,
        kernels, costs, q_inits, mu0,
        gamma_val, S_val, A_val, H_val,
        max_q_iter, tol_q_val, N_constraints_val, U_n_val,
        alpha_val_inner, t_in,
    )

    print("Starting JAX RCMDP solver...")
    for outer_idx in range(t_out):
        state = outer_loop_body_fn(outer_idx, state)
        lower, upper = state[1], state[2]
        midpoint = (lower + upper) / 2.0
        print(f"Python Outer iter {outer_idx}, current b0_approx: {midpoint:.4f}, i: {lower:.4f}, j: {upper:.4f}")

    print("Solver finished.")
    return state[0], state[1], state[2]


if __name__ == '__main__':
    # Copy the module-level constants into the names passed to the solver.
    S_dim = S
    A_dim = A
    N_funcs = N_constraints # total number of cost functions: objective + constraints (includes the one with threshold b0)
    U_count = U_num
    H_eff = H_horizon
    gamma_val = gamma_discount
    alpha_il = alpha_lr # Inner loop learning rate
    T_inner_loop = T_inner
    T_outer_loop = T_outer
    MAX_ITER_Q_loop = MAX_ITER_Q
    TOL_Q_loop = TOL_Q

    # Root PRNG key; every random quantity in the run is derived from it.
    master_key = jax.random.PRNGKey(42)

    final_policy, final_i_b, final_j_b = run_rcmdp_epigraph_solver(
        master_key, S_dim, A_dim, N_funcs, U_count, H_eff, gamma_val,
        alpha_il, T_inner_loop, T_outer_loop, MAX_ITER_Q_loop, TOL_Q_loop
    )
    
    print("\n--- Results ---")
    print("Final Policy (sample):")
    print(final_policy[:min(3, S_dim), :]) # Display at most the first three rows of the policy
    print(f"Final b0 approximation range: [{final_i_b:.4f}, {final_j_b:.4f}]")
    print(f"Estimated b0: {(final_i_b + final_j_b) / 2.0:.4f}")

Starting JAX RCMDP solver...
Python Outer iter 0, current b0_approx: 15.0000, i: 10.0000, j: 20.0000
Python Outer iter 1, current b0_approx: 17.5000, i: 15.0000, j: 20.0000
Python Outer iter 2, current b0_approx: 18.7500, i: 17.5000, j: 20.0000
Python Outer iter 3, current b0_approx: 19.3750, i: 18.7500, j: 20.0000
Python Outer iter 4, current b0_approx: 19.6875, i: 19.3750, j: 20.0000
Python Outer iter 5, current b0_approx: 19.8438, i: 19.6875, j: 20.0000
Python Outer iter 6, current b0_approx: 19.9219, i: 19.8438, j: 20.0000
Python Outer iter 7, current b0_approx: 19.9609, i: 19.9219, j: 20.0000
Python Outer iter 8, current b0_approx: 19.9805, i: 19.9609, j: 20.0000
Python Outer iter 9, current b0_approx: 19.9902, i: 19.9805, j: 20.0000
Solver finished.

--- Results ---
Final Policy (sample):
[[0.         1.        ]
 [0.09874785 0.90125203]
 [1.         0.        ]]
Final b0 approximation range: [19.9805, 20.0000]
Estimated b0: 19.9902
