# ロバストMDPのモデルフリーのオンライン学習について

---

In [1]:
import numpy as np
import jax.numpy as jnp
from jax.random import PRNGKey
import jax
from typing import NamedTuple,Optional

key = PRNGKey(0)

S = 10
A = 3
S_array = jnp.arange(S)
A_array = jnp.arange(A)
gamma = 0.95

key,_ = jax.random.split(key)
rew = jax.random.uniform(key,shape=(S,A))

key,_ = jax.random.split(key)
P = jax.random.uniform(key,shape=(S*A,S))
P = P / jnp.sum(P,axis=1,keepdims=True)
P = P.reshape(S,A,S)

#初期状態分布
key,_ = jax.random.split(key)
mu = jax.random.uniform(key,shape=(S,))
mu = mu/jnp.sum(mu)


class MDP(NamedTuple):
    S_array : jnp.ndarray
    A_array : jnp.ndarray
    gamma : float
    H : int 
    rew : jnp.ndarray
    P : jnp.ndarray
    mu : jnp.ndarray
    optimal_Q : Optional[jnp.ndarray] = None #最適なQ値

    @property
    def S(self):
        return len(self.S_array)
    
    @property
    def A(self):
        return len(self.A_array)
    


H = int(1/(1-gamma))
mdp = MDP(S_array,A_array,gamma,H,rew,P,mu)


In [3]:
import jax 
import jax.numpy as jnp
import numpy as np
from functools import partial
import chex

@jax.jit
def compute_greedy_policy(Q: jnp.ndarray) -> jnp.ndarray:
    greedy_policy = jnp.zeros_like(Q)
    S,A = Q.shape
    greedy_policy = greedy_policy.at[jnp.arange(S),jnp.argmax(Q,axis=1)].set(1)
    return greedy_policy

@partial(jax.jit,static_argnames=('S','A'))
def _compute_optimal_Q(mdp:MDP,S:int,A:int) -> jnp.ndarray:
    optimal_Q = jnp.zeros((S,A))

    def backup(Q):
        next_v = mdp.P @ jnp.max(Q,axis=1)
        return mdp.rew + mdp.gamma * next_v
    
    body_fn = lambda i,Q: backup(Q)
    return jax.lax.fori_loop(0,mdp.H+100,body_fn,optimal_Q)

compute_optimal_Q = lambda mdp: _compute_optimal_Q(mdp,mdp.S,mdp.A)

@jax.jit
def compute_policy_Q(mdp:MDP,policy:jnp.ndarray):
    S,A = policy.shape
    chex.assert_shape(policy,(S,A))

    def backup(Q):
        v = (policy * Q).sum(axis=1)
        next_v = mdp.P @ v
        return mdp.rew + mdp.gamma * next_v
    
    policy_Q = jnp.zeros_like(policy)
    body_fn = lambda i, policy_Q: backup(policy_Q)
    return jax.lax.fori_loop(0,mdp.H+100,body_fn,policy_Q)

@jax.jit
def compute_policy_matrix(policy:jnp.ndarray): #訪問頻度の確率求めたいから，遷移確率行列と内積取れるような形にする
    S,A = policy.shape
    PI = policy.reshape(1,S,A)
    PI = jnp.tile(PI,(S,1,1))
    eyes = jnp.eyes(S).reshape(S,S,1)
    PI = PI * eyes
    PI = PI.reshape(S,S*A)

    return PI


@jax.jit
def compute_policy_visit_sa(mdp:MDP,policy:jnp.ndarray,init_dist:jnp.ndarray):
    S,A = policy.shape
    chex.assert_shape(policy,(S,A))
    PI = compute_policy_matrix(policy)
    PPI = mdp.P.reshape(S*A,S) @ PI

    def backup(visit):
        next_visit = mdp.gamma * visit @ PPI
        return init_dist @ PI + next_visit
    
    body_fn = lambda i,visit: backup(visit)
    visit = jnp.zeros(S*A)
    visit = jax.lax.fori_loop(0,mdp.H+100,body_fn,visit)
    visit = visit.reshape(S,A)

    return visit

@jax.jit
def compute_policy_visit_s(mdp: MDP, policy: jnp.ndarray, init_dist: jnp.ndarray):
    S, A = policy.shape
    chex.assert_shape(policy, (mdp.S, mdp.A))
    Pi = compute_policy_matrix(policy)
    PiP = Pi @ mdp.P.reshape(S*A, S) 

    def backup(visit):
        next_visit = mdp.gamma * visit @ PiP
        return init_dist + next_visit
    
    body_fn = lambda i, visit: backup(visit)
    visit = jnp.zeros(S)
    visit = jax.lax.fori_loop(0, mdp.H + 100, body_fn, visit)
    return visit


optimal_Q_DP = compute_optimal_Q(mdp)
mdp = mdp._replace(optimal_Q=optimal_Q_DP)


---

普通のQ学習とrobustなQ学習との比較のために二つのアルゴリズムを実装

**Q学習**

In [35]:
import time
@partial(jax.jit,static_argnames=('T'))
def Q_learning(mdp:MDP,T:int,key:PRNGKey,lr:float=0.1,epsilon:float=0.0):
    S,A = mdp.S,mdp.A

    def body_fn(n,args):
        key,s,Q = args

        #epsilon-greedy
        a = Q[s].argmax()
        key,key1,key2 = jax.random.split(key,3)
        random_A = jax.random.choice(key1,A)
        a = jnp.where(jax.random.uniform(key2) < epsilon,random_A,a)

        #遷移
        key,key1 = jax.random.split(key,2)
        next_s = jax.random.choice(key1,mdp.S_array,p=mdp.P[s,a])

        next_V = Q[next_s].max(axis = -1)
        Q_targ = mdp.rew[s,a] + mdp.gamma * next_V
        Q_targ = (1-lr) * Q[s,a] + lr * Q_targ
        Q = Q.at[s,a].set(Q_targ)

        return key,next_s,Q
    
    Q = jnp.zeros((S,A))
    key,new_key = jax.random.split(key)
    init_s = jax.random.choice(new_key,mdp.S_array,p=mdp.mu)
    args = (key,init_s,Q)
    key,_,Q = jax.lax.fori_loop(0,T,body_fn,args)
    return key,Q

key = jax.random.PRNGKey(0)
start_time = time.time()
key, Q = Q_learning(mdp, 100000, key, lr=0.1, epsilon=0.3)
end_time = time.time()

greedy_policy = compute_greedy_policy(Q)
error = mdp.optimal_Q - compute_policy_Q(mdp, greedy_policy)
print(f'loop回すまでにかかった時間は{end_time - start_time}秒です')

error.max()




loop回すまでにかかった時間は1.1108109951019287秒です


Array(1.04904175e-05, dtype=float32)

**ロバストQ習**

In [36]:
@partial(jax.jit,static_argnames=('T'))
def Robust_Q_learning(mdp:MDP,T:int,key:PRNGKey,R:float,lr:float=0.1,epsilon:float=0.0):
    S,A = mdp.S,mdp.A

    def body_fn(n,args):
        key,s,Q = args

        a = Q[s].argmax()
        key,key1,key2 = jax.random.split(key,3)
        random_a = jax.random.choice(key1,A)
        a = jnp.where(jax.random.uniform(key2) < epsilon,random_a,a)

        #遷移
        key,key1 = jax.random.split(key,2)
        next_s = jax.random.choice(key1,mdp.S_array,p=mdp.P[s,a])

        next_V = Q[next_s].max(axis = -1)
        worst_V = Q.max(axis=1).min()
        Q_targ = mdp.rew[s,a] + mdp.gamma *  (1-R) * next_V + mdp.gamma * R * worst_V
        Q_targ = (1-lr) * Q[s,a] + lr * Q_targ
        Q = Q.at[s,a].set(Q_targ)

        return key,next_s,Q
    Q = jnp.zeros((S,A))
    key,new_key = jax.random.split(key)
    init_s = jax.random.choice(new_key,mdp.S_array,p=mdp.mu)
    args = (key,init_s,Q)
    key,_,Q = jax.lax.fori_loop(0,T,body_fn,args)
    return key,Q

key = jax.random.PRNGKey(0)
start_time = time.time()
key, robust_Q = Robust_Q_learning(mdp, 100000, key, R=0.4, lr=0.1, epsilon=0.3)
end_time = time.time()
robust_greedy_policy = compute_greedy_policy(Q)
error = mdp.optimal_Q - compute_policy_Q(mdp, greedy_policy)
print(f'loop回すまでにかかった時間は{end_time - start_time}秒です')
error.max()

loop回すまでにかかった時間は1.1249217987060547秒です


Array(1.04904175e-05, dtype=float32)

---

実際のロバストな方策が手に入ったのかを確認するために，元のMDPの遷移確率に摂動を入れてどのくらいロバストな方策なのか，確認する．

In [37]:
@jax.jit
def perturb_mdp(key, mdp: MDP, R: float):
    key, _ = jax.random.split(key)
    P = jax.random.uniform(key=key, shape=(S*A, S))
    P = P / jnp.sum(P, axis=-1, keepdims=True)  # 正規化して確率にします
    P = P.reshape(S, A, S)
    perturbed_mdp = mdp._replace(P=(1 - R) * mdp.P + R * P)

    optimal_Q_DP = compute_optimal_Q(perturbed_mdp)
    perturbed_mdp = perturbed_mdp._replace(optimal_Q=optimal_Q_DP)
    return key, perturbed_mdp

key, perturbed_mdp = perturb_mdp(key, mdp, 0.7)
np.testing.assert_allclose(perturbed_mdp.P.sum(axis=-1), 1, atol=1e-6) 

In [38]:

greedy_policy = compute_greedy_policy(Q)
error = perturbed_mdp.optimal_Q - compute_policy_Q(perturbed_mdp, greedy_policy)
print(f'摂動を入れたMDPでのQ学習の誤差は{error.max()}です')

摂動を入れたMDPでのQ学習の誤差は0.10248088836669922です


In [39]:
greedy_policy = compute_greedy_policy(robust_Q)
error = perturbed_mdp.optimal_Q - compute_policy_Q(perturbed_mdp, greedy_policy)
print(f'摂動を入れたMDPでのロバストQ学習の誤差は{error.max()}です')

摂動を入れたMDPでのロバストQ学習の誤差は0.0です


ロバストQ学習で学習された方策の方が違うMDPで試してもうまくいってそうで

す