In [1]:
import numpy as np
from functools import partial
from typing import Optional,NamedTuple
import jax.numpy as jnp
from jax.random import PRNGKey
import jax

class MDP(NamedTuple):
    S_array : jnp.ndarray
    A_array : jnp.ndarray
    P: jnp.ndarray
    rew: jnp.ndarray
    gamma : float
    H : int
    optimal_Q : Optional[jnp.ndarray] = None
    robust_optimal_Q : Optional[jnp.ndarray] = None

    @property
    def S(self):
        return len(self.S_array)
    
    @property
    def A(self):
        return  len(self.A_array)
    

S = 5
A = 3

#報酬と遷移確率の行列を作る．
key = PRNGKey(0)
key,_ = jax.random.split(key)
rew = jax.random.uniform(key,shape=(S,A))


P = np.random.rand(S*A, S) #遷移確率
P = P / np.sum(P,axis=-1,keepdims=True)
P = P.reshape(S,A,S)
np.testing.assert_allclose(P.sum(axis=-1),1)
# P = jnp.array(P)



S_array = jnp.arange(S)
A_array = jnp.arange(A)
gamma = 0.99
horizon = 1 / (1-gamma)
horizon = int(horizon)

mdp = MDP(S_array,A_array,P,rew,gamma,horizon)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [7]:
from functools import partial

#割引状態訪問率
@partial(jax.jit,static_argnames=('h'))
def compute_discounted_visitation(mdp:MDP,policy:np.ndarray,h:int):
    S,A = policy.shape
    P = mdp.P.reshape(S*A,S)

    visits_dist = jnp.zeros((S,h))
    visits_dist = visits_dist.at[0,0].set(1.0) #初期状態分布の挿入

    def backup(i,visits_dist):
        visit = visits_dist[i]
        visit_policy = (visit.reshape(S,1) * policy).reshape(S*A)
        next_visit = visits_dist[:,0] + mdp.gamma * visit_policy @ P
        visits_dist = visits_dist.at[i+1].set(next_visit)
        return visits_dist
    visits_dist = jax.lax.fori_loop(0,h-1,backup,visits_dist)

    return visits_dist

@jax.jit
def compute_greedy_policy(Q: np.ndarray):
    greedy_policy = jnp.zeros_like(Q)
    S, A = Q.shape
    greedy_policy = greedy_policy.at[jnp.arange(S), Q.argmax(axis=1)].set(1)
    
    return greedy_policy



@partial(jax.jit, static_argnames=("S", "A"))
def _compute_optimal_Q(mdp: MDP, S: int, A: int):

    def backup(optimal_Q):
        greedy_policy = compute_greedy_policy(optimal_Q)
        max_Q = (greedy_policy * optimal_Q).sum(axis=1)
        next_v = mdp.P @ max_Q
        assert next_v.shape == (S, A)
        return mdp.rew + mdp.gamma * next_v
    
    optimal_Q = jnp.zeros((S, A))
    body_fn = lambda i, Q: backup(Q)
    return jax.lax.fori_loop(0, mdp.horizon, body_fn, optimal_Q)

compute_optimal_Q = lambda mdp: _compute_optimal_Q(mdp, mdp.S, mdp.A)
@jax.jit
def compute_policy_Q(mdp:MDP,policy:np.ndarray):
    S,A = policy.shape
    
    policy_Q = jnp.zeros((S,A))

    def backup(policy_Q):
        V = (policy * policy_Q).sum(axis=-1)
        V = mdp.P @ V

        return mdp.rew + mdp.gamma * V
    
    body_fn = lambda i,Q:backup(Q)
    policy_Q = jax.lax.fori_loop(0,mdp.H,body_fn,policy_Q)


方策を直接パラメータ化する方法と、ソフトマックス関数を使いパラメータ化する方法があるのですが、まずはソフトマックス関数でパラメータ化していきたいと思います。

まず、価値関数の勾配は次のように表せます。

$$
\nabla_\theta V^{\pi_\theta}\left(s_0\right)=\frac{1}{1-\gamma} \mathbb{E}_{s \sim d_{s_0}^{\pi_\theta}} \mathbb{E}_{a \sim \pi_\theta(\cdot \mid s)}\left[\nabla_\theta \log \pi_\theta(a \mid s) A^{\pi_\theta}(s, a)\right] .
$$

In [29]:
key = jax.random.PRNGKey(0)
theta = jax.random.uniform(key,shape=(S,A))

def log_pi(theta:np.ndarray,s:int,a:int):
    return jnp.log(jax.nn.softmax(theta)[s,a])

# def compute_grad_log_pi(theta:np.ndarray,s,a):
#     # s,a = mdp.S_array,mdp.A_array
#     grad_log_pi = jax.vmap(jax.grad(log_pi),in_axes=(None,0,0))(theta,s,a)

#     return grad_log_pi
def compute_grad_log_pi(theta, s,a):
    
    grad_log_pi = jax.vmap(jax.grad(log_pi), in_axes=(None, 0, 0))(theta, s, a)
    return grad_log_pi

S_array = mdp.S_array
S_array = jnp.tile(S_array, mdp.A).reshape(-1, 1)

A_array = mdp.A_array
A_array = jnp.repeat(A_array, mdp.S).reshape(-1, 1)

SA = jnp.hstack([S_array, A_array])

grad = compute_grad_log_pi(theta, S_array,A_array)

TypeError: Gradient only defined for scalar-output functions. Output had shape: (1,).

In [17]:
S_array = mdp.S_array
S_array = jnp.tile(S_array, mdp.A).reshape(-1, 1)

In [18]:
S_array

Array([[0],
       [1],
       [2],
       [3],
       [4],
       [0],
       [1],
       [2],
       [3],
       [4],
       [0],
       [1],
       [2],
       [3],
       [4]], dtype=int32)

In [19]:
A_array = mdp.A_array
A_array = jnp.repeat(A_array, mdp.S).reshape(-1, 1)

In [20]:
A_array

Array([[0],
       [0],
       [0],
       [0],
       [0],
       [1],
       [1],
       [1],
       [1],
       [1],
       [2],
       [2],
       [2],
       [2],
       [2]], dtype=int32)