# Robust Constrained MDP

---

今回はこの[論文](https://openreview.net/pdf?id=e-ZdxsIwweR)を参考に，Constrained MDP（制約付きMDP）とRobust MDP（ロバストMDP）をうまく組み合わせる方法について，コードも含めて説明していきます．

---

In [35]:
import numpy as np
from itertools import product
from functools import partial
from typing import Optional,NamedTuple
import jax.numpy as jnp
import jax

class MDP(NamedTuple):
    """Tabular discounted MDP bundled with a per-(state, action) constraint table.

    Being a NamedTuple, an instance is a JAX pytree, so it can be handed
    directly to the ``jax.jit``-compiled solvers in this notebook.
    """

    S_array: np.ndarray  # enumeration of the states (state space)
    A_array: np.ndarray  # enumeration of the actions (action space)
    gamma: float  # discount factor
    horizon: int  # horizon; the solvers use it as the number of backups
    rew: np.ndarray  # reward function
    con: np.ndarray  # constraint function
    P: np.ndarray  # transition probabilities
    optimal_Q: Optional[np.ndarray] = None  # optimal Q-function, if already computed

    # Sizes derived from the enumerations above (read-only).
    S = property(lambda self: len(self.S_array), doc="Number of states.")
    A = property(lambda self: len(self.A_array), doc="Number of actions.")


# Build a small random constrained MDP instance (seeded so results are reproducible).
np.random.seed(0)
S = 5  # number of states
A = 3  # number of actions
rew = np.random.rand(S, A)  # reward function r(s, a)

# Constraint function: 1.0 marks the constrained (state, action) pairs, 0.0 elsewhere.
con = np.zeros((S, A))
con[[0, 1, 1, 4], [2, 1, 2, 2]] = 1.0

# Random transition kernel: draw one row per (s, a) and normalise it into a
# probability distribution over next states, then view it as (S, A, S).
P = np.random.rand(S * A, S)
P /= P.sum(axis=-1, keepdims=True)
P = P.reshape(S, A, S)
np.testing.assert_almost_equal(P.sum(axis=-1), 1)

gamma = 0.9  # discount factor
horizon = 2 * int(1 / (1 - gamma))  # twice the effective horizon 1/(1-gamma)
S_array, A_array = np.arange(S), np.arange(A)

mdp = MDP(
    S_array=S_array,
    A_array=A_array,
    gamma=gamma,
    horizon=horizon,
    rew=rew,
    con=con,
    P=P,
)



In [37]:
from tqdm import tqdm

@jax.jit
def compute_greedy_policy(Q: jnp.ndarray):
    """Return the deterministic greedy policy of a 2-D Q table as a one-hot matrix.

    Row s has a single 1.0 at argmax_a Q[s, a]; ties go to the lowest action
    index, as with ``argmax``.
    """
    _, num_actions = Q.shape  # also enforces that Q is 2-D
    best_actions = jnp.argmax(Q, axis=-1)
    return jax.nn.one_hot(best_actions, num_actions)

@partial(jax.jit, static_argnames=('S', 'A'))
def _compute_optimal_Q(mdp: MDP, S: int, A: int):
    """Run ``mdp.horizon`` sweeps of standard (nominal) value iteration from Q = 0.

    Each sweep applies Q <- r + gamma * P V with V(s) = max_a Q(s, a).
    Assumes ``mdp.P`` is (S, A, S) so that ``@ V`` contracts the next-state
    axis. S and A are static so the initial (S, A) array has a concrete shape.
    """
    def sweep(_, Q):
        state_value = jnp.max(Q, axis=-1)
        # Same evaluation order as before: (gamma * P) @ V.
        return mdp.rew + (mdp.gamma * mdp.P) @ state_value

    Q_init = jnp.zeros((S, A))
    return jax.lax.fori_loop(0, mdp.horizon, sweep, Q_init)


def compute_optimal_Q(mdp):
    """Compute the nominal optimal Q of ``mdp``, taking the static sizes from the MDP."""
    return _compute_optimal_Q(mdp, mdp.S, mdp.A)


@partial(jax.jit,static_argnames=('S','A'))
def _compute_optimal_robust_Q(mdp:MDP,S:int,A:int):

    def backup(optimal_Q):
        policy = compute_greedy_policy(optimal_Q)
        pi_Q = (policy * optimal_Q)
        P = mdp.P.reshape(S*A,S)
        P = P.min(axis=-1)
        P = P.reshape(S,A)
        return mdp.rew + mdp.gamma * P * pi_Q
    
    optimal_Q = jnp.zeros((S,A))
    body_fn = lambda i,Q: backup(Q)
    optimal_Q = jax.lax.fori_loop(0,mdp.horizon,body_fn,optimal_Q)
    return optimal_Q



compute_optimal_robust_Q = lambda mdp : _compute_optimal_robust_Q(mdp,mdp.S,mdp.A)


@jax.jit
def compute_greedy_constrained_policy(Q:jnp.ndarray):
    S,A = Q.shape
    greedy_policy = jnp.zeros((S,A))
    Q_con = Q
    # Q_con = Q_con.at[[0,2],[1,1],[4,2]].set(0)
    Q_con = Q_con.at[0,2].set(0)
    Q_con = Q_con.at[1,1].set(0)
    Q_con = Q_con.at[1,2].set(0)
    Q_con = Q_con.at[4,2].set(0)
    
    greedy_policy = greedy_policy.at[jnp.arange(S),Q_con.argmax(axis=-1)].set(1.0)
    return greedy_policy

@partial(jax.jit,static_argnames=('S','A'))
def _compute_optimal_robust_constrained_Q(mdp:MDP,S:int,A:int):

    def backup(optimal_Q):
        policy = compute_greedy_constrained_policy(optimal_Q)
        pi_Q = (policy * optimal_Q)
        P = mdp.P.reshape(S*A,S)
        P = P.min(axis=-1)
        P = P.reshape(S,A)
        return mdp.rew + mdp.gamma * P * pi_Q
    
    optimal_Q = jnp.zeros((S,A))
    body_fn = lambda i,Q: backup(Q)
    optimal_Q = jax.lax.fori_loop(0,mdp.horizon,body_fn,optimal_Q)
    return optimal_Q

compute_optimal_robust_constrained_Q = lambda mdp : _compute_optimal_robust_constrained_Q(mdp,mdp.S,mdp.A)

In [38]:
optimal_Q = compute_optimal_Q(mdp)
optimal_robust_Q = compute_optimal_robust_Q(mdp)
optimal_robust_constrained_Q = compute_optimal_robust_constrained_Q(mdp)

# Sup-norm gaps: nominal vs. robust, then robust vs. robust + constrained.
robust_gap = jnp.max(jnp.abs(optimal_Q - optimal_robust_Q))
constraint_gap = jnp.max(jnp.abs(optimal_robust_Q - optimal_robust_constrained_Q))
print(f'最適なQ値とロバストQ値の差分:{robust_gap}')
print(f'最適なrobust Q値とロバスト制約付きQ値の差分:{constraint_gap}')


最適なQ値とロバストQ値の差分:6.196681499481201
最適なrobust Q値とロバスト制約付きQ値の差分:0.01615595817565918


これは上手くいっているのでしょうか？不確実性集合が名目モデルを含むなら，ロバストQ値は通常の最適Q値以下（より悲観的）になるはずなので，差分だけでなくQ値そのものも確認してみます．

In [39]:
# Show both Q tables: robust + constrained first, then robust only.
for q_table in (optimal_robust_constrained_Q, optimal_robust_Q):
    print(q_table)

[[0.5488135  0.7403013  0.60276335]
 [0.54934484 0.4236548  0.6458941 ]
 [0.4375872  0.891773   1.0493212 ]
 [0.3834415  0.8238071  0.5288949 ]
 [0.56804454 0.95234555 0.07103606]]
[[0.5488135  0.7403013  0.60276335]
 [0.5448832  0.4236548  0.66205007]
 [0.4375872  0.891773   1.0493212 ]
 [0.3834415  0.8238071  0.5288949 ]
 [0.56804454 0.95234555 0.07103606]]
