In [158]:
import numpy as np

# MDPの構築
from typing import NamedTuple ,Optional
np.random.seed(10)

S = 5# 状態数
A = 3 # 行動数
S_set = np.arange(S)
A_set = np.arange(A)
gamma = 0.9 # 割引率

rew = np.random.uniform(0,1,size=(S,A)) # 報酬
rew = np.array(rew)

P = np.random.rand(S,A,S) # 遷移確率
P = P.reshape(S*A,S)
P = P/np.sum(P,axis=1,keepdims=True) # 正規化
P = P.reshape(S,A,S)
np.testing.assert_allclose(P.sum(axis=-1), 1, atol=1e-6)

class MDP(NamedTuple):
    S_set: np.ndarray
    A_set: np.ndarray
    rew: np.ndarray
    P: np.ndarray
    gamma: float
    H: int
    K:int


    optimal_V: Optional[np.ndarray] = None

    @property
    def S(self):
        return len(self.S_set)

    @property
    def A(self):
        return len(self.A_set)

H = int (1/(1-gamma) + 10)
mdp = MDP(S_set,A_set,rew,P,gamma,H,K=10)

In [159]:
def sampler(mdp: MDP, policy,s,h):
    """状態sで行動aをとり、次の状態と報酬を返す"""
    a = policy[h,s]
    a = int(a)
    next_s = np.random.choice(mdp.S_set, p=mdp.P[s, a])
    rew = mdp.rew[s, a]
    return next_s, rew

In [160]:
def feature_func(n:int,delta:float,mdp:MDP):
    if n == 0:
        return 1
    return np.min([1,np.sqrt(0.52/n * (1.4*np.log(np.log(np.max([np.e,n]))) + np.log(26*mdp.S*mdp.A*(mdp.H+1+mdp.S)/delta)))])

In [161]:
def V_h_max(h:int,mdp:MDP):
    return (mdp.H-h+1)

In [162]:
def V_std(P,V):
    # return np.sqrt(np.sum(P*V**2,axis=-1) - np.sum(P*V,axis=-1)**2)
    # return np.sqrt(np.sum(P*(V-P*V)**2,axis=-1))
    std = np.sqrt(np.sum(P*V**2,axis=-1) - np.sum(2*(P**2) * V**2,axis=-1) + np.sum(P**3 * V**2,axis=-1))
    return std

In [163]:
def ORLC(mdp:MDP,delta:float):
    n_k = np.zeros((mdp.S,mdp.A))
    n_k_p = np.zeros((mdp.S,mdp.A,mdp.S))
    r_k = np.zeros((mdp.S,mdp.A))
    P_k = np.zeros((mdp.S,mdp.A,mdp.S))
    V_lower = np.zeros((mdp.K,mdp.H,mdp.S))
    Q_lower = np.zeros((mdp.K,mdp.H-1,mdp.S,mdp.A))
    V_upper = np.zeros((mdp.K,mdp.H,mdp.S))
    Q_upper = np.zeros((mdp.K,mdp.H-1,mdp.S,mdp.A))
    gzi = np.zeros((mdp.K,mdp.H,mdp.S,mdp.A))
    policy = np.zeros((mdp.K,mdp.H,mdp.S))
    epsilon = np.zeros(mdp.K)
    experience = []
    for k in range(mdp.K):
        if k != 0:
            for h in range(mdp.H):
                s, a, r, s_dash = experience[h]
                n_k[s, a] += 1
                n_k_p[s, a, s_dash] += 1
                r_k[s, a] += r
                P_k[s, a, s_dash] = n_k_p[s, a, s_dash] / n_k[s, a]
        for h in reversed(range(mdp.H)):
            for s in range(mdp.S):
                for a in range(mdp.A):
                    # print(V_std(P_k[s,a],V_upper[k,h]))
                    gzi[k,h-1,s,a] = (1 + np.sqrt(12)*V_std(P_k[s,a],V_upper[k,h,s]) * feature_func(n_k[s,a],delta,mdp)) + 45*mdp.S*mdp.H**2 * feature_func(n_k[s,a],delta,mdp) **2 + 1/mdp.H *np.sum(P_k[s,a]*(V_upper[k,h] - V_lower[k,h]))
                    Q_upper[k,h-1,s,a] = np.max([0,r_k[s,a] + np.sum(P_k[s,a,:]*V_upper[k,h,:]) + gzi[k,h-1,s,a]])
                    Q_upper[k,h-1,s,a] = np.min([V_h_max(h,mdp),Q_upper[k,h-1,s,a]])
                    Q_lower[k,h-1,s,a] = np.max([0,r_k[s,a] + np.sum(P_k[s,a]*V_lower[k,h]) - gzi[k,h-1,s,a]])
                    Q_lower[k,h-1,s,a] = np.min([V_h_max(h,mdp),Q_lower[k,h-1,s,a]])
                    
        for h in range(mdp.H):
            for s in range(mdp.S):
                policy[k,h-1,s] = np.argmax(Q_upper[k,h-1,s])
                a = policy[k,h-1,s]
                a = int(a)
                V_upper[k,h-1,s] = Q_upper[k,h-1,s,a]
                V_lower[k,h-1,s] = Q_lower[k,h-1,s,a]
                
        epsilon[k] = np.abs(V_upper[k,0,0] - V_lower[k,0,0])
        s = 0
        experience = []
        for h in range(mdp.H):
            s_dash,rew = sampler(mdp,policy[k],s,h)
            experience.append([s,int(policy[k,h,s]),rew,s_dash])
            s=s_dash
    return policy,epsilon,Q_upper,Q_lower,V_upper,V_lower

In [164]:
def ORLC(mdp:MDP,delta:float):
    n_k = np.zeros((mdp.S,mdp.A))
    n_k_p = np.zeros((mdp.S,mdp.A,mdp.S))
    r_hat_k = np.zeros((mdp.S,mdp.A))
    r_k = np.zeros((mdp.S,mdp.A))
    P_k = np.zeros((mdp.S,mdp.A,mdp.S))
    V_lower = np.ones((mdp.H,mdp.S))
    Q_lower = np.ones((mdp.H-1,mdp.S,mdp.A))
    V_upper = np.ones((mdp.H,mdp.S))
    Q_upper = np.ones((mdp.H-1,mdp.S,mdp.A))
    gzi = np.zeros((mdp.H,mdp.S,mdp.A))
    policy = np.zeros((mdp.H,mdp.S))
    epsilon = np.zeros(mdp.K)
    experience = []
    for k in range(mdp.K):
        if k != 0:
            for h in range(mdp.H):
                s, a, r, s_dash = experience[h]
                n_k[s, a] += 1
                n_k_p[s, a, s_dash] += 1
                r_hat_k[s, a] += r
                r_k[s,a] = r_hat_k[s,a] / n_k[s,a]
                P_k[s, a, s_dash] = n_k_p[s, a, s_dash] / n_k[s, a]
        for h in reversed(range(mdp.H)):
            for s in range(mdp.S):
                for a in range(mdp.A):
                    # print(V_std(P_k[s,a],V_upper[k,h]))
                    gzi[h-1,s,a] = (1 + np.sqrt(12)*V_std(P_k[s,a],V_upper[h,s]) * feature_func(n_k[s,a],delta,mdp)) + 45*mdp.S*mdp.H**2 * feature_func(n_k[s,a],delta,mdp) **2 + 1/mdp.H *np.sum(P_k[s,a]*(V_upper[h] - V_lower[h]))
                    # print(feature_func(n_k[s,a],delta,mdp))
                    # gzi[h-1,s,a] = (1 + np.sqrt(12)*V_std(P_k[s,a],V_upper[h,s]) * feature_func(n_k[s,a],delta,mdp))+1/mdp.H *np.sum(P_k[s,a]*(V_upper[h] - V_lower[h]))
                    Q_upper[h-1,s,a] = np.max([0,r_k[s,a] + np.sum(P_k[s,a,:]*V_upper[h,:]) + gzi[h-1,s,a]])
                    print(Q_upper[h-1,s,a])
                    Q_upper[h-1,s,a] = np.min([V_h_max(h,mdp),Q_upper[h-1,s,a]])
                    print(Q_upper[h-1,s,a])
                    print('------')
                    Q_lower[h-1,s,a] = np.max([0,r_k[s,a] + np.sum(P_k[s,a]*V_lower[h]) - gzi[h-1,s,a]])
                    Q_lower[h-1,s,a] = np.min([V_h_max(h,mdp),Q_lower[h-1,s,a]])
                    
        for h in range(mdp.H):
            for s in range(mdp.S):
                policy[h-1,s] = np.argmax(Q_upper[h-1,s],axis=-1)
                a = policy[h-1,s]
                a = int(a)
                V_upper[h-1,s] = Q_upper[h-1,s,a]
                V_lower[h-1,s] = Q_lower[h-1,s,a]
                
        epsilon[k] = np.abs(V_upper[0,0] - V_lower[0,0])
        s = 0
        experience = []
        for h in range(mdp.H):
            s_dash,rew = sampler(mdp,policy,s,h)
            experience.append([s,int(policy[h,s]),rew,s_dash])
            s=s_dash
    return policy,epsilon,Q_upper,Q_lower,V_upper,V_lower,r_k,P_k

In [165]:
def ORLC(mdp:MDP,delta:float):
    n_k = np.zeros((mdp.S,mdp.A))
    n_k_p = np.zeros((mdp.S,mdp.A,mdp.S))
    r_hat_k = np.zeros((mdp.S,mdp.A))
    r_k = np.zeros((mdp.S,mdp.A))
    P_k = np.zeros((mdp.S,mdp.A,mdp.S))
    V_lower = np.ones((mdp.H,mdp.S))
    Q_lower = np.ones((mdp.H-1,mdp.S,mdp.A))
    V_upper = np.ones((mdp.H,mdp.S))
    Q_upper = np.ones((mdp.H-1,mdp.S,mdp.A))
    gzi = np.zeros((mdp.H,mdp.S,mdp.A))
    policy = np.zeros((mdp.H,mdp.S))
    epsilon = np.zeros(mdp.K)
    experience = []
    for k in range(mdp.K):
        if k != 0:
            for h in range(mdp.H):
                s, a, r, s_dash = experience[h]
                n_k[s, a] += 1
                n_k_p[s, a, s_dash] += 1
                r_hat_k[s, a] += r
                r_k[s,a] = r_hat_k[s,a] / n_k[s,a]
                P_k[s, a, s_dash] = n_k_p[s, a, s_dash] / n_k[s, a]
        for h in reversed(range(mdp.H)):
            for s in range(mdp.S):
                for a in range(mdp.A):
                    # print(V_std(P_k[s,a],V_upper[k,h]))
                    
                    # print(feature_func(n_k[s,a],delta,mdp))
                    # gzi[h-1,s,a] = (1 + np.sqrt(12)*V_std(P_k[s,a],V_upper[h,s]) * feature_func(n_k[s,a],delta,mdp))+1/mdp.H *np.sum(P_k[s,a]*(V_upper[h] - V_lower[h]))
                    Q_upper[h-1,s,a] = np.max([0,r_k[s,a] + np.sum(P_k[s,a,:]*V_upper[h,:]) + (V_h_max(h,mdp) + 1) * feature_func(n_k[s,a],delta,mdp)])
                    print(Q_upper[h-1,s,a])
                    Q_upper[h-1,s,a] = np.min([V_h_max(h,mdp),Q_upper[h-1,s,a]])
                    print(Q_upper[h-1,s,a])
                    print('------')
                    Q_lower[h-1,s,a] = np.max([0,r_k[s,a] + np.sum(P_k[s,a]*V_lower[h]) - (2*np.sqrt(mdp.S)*V_h_max(h,mdp) + 1) * feature_func(n_k[s,a],delta,mdp)])
                    Q_lower[h-1,s,a] = np.min([V_h_max(h,mdp),Q_lower[h-1,s,a]])
                    
        for h in range(mdp.H):
            for s in range(mdp.S):
                policy[h-1,s] = np.argmax(Q_upper[h-1,s],axis=-1)
                a = policy[h-1,s]
                a = int(a)
                V_upper[h-1,s] = Q_upper[h-1,s,a]
                V_lower[h-1,s] = Q_lower[h-1,s,a]
                
        epsilon[k] = np.abs(V_upper[0,0] - V_lower[0,0])
        s = 0
        experience = []
        for h in range(mdp.H):
            s_dash,rew = sampler(mdp,policy,s,h)
            experience.append([s,int(policy[h,s]),rew,s_dash])
            s=s_dash
    return policy,epsilon,Q_upper,Q_lower,V_upper,V_lower,r_k,P_k

In [166]:
policy,epsilon,Qu,Ql,Vu,Vl,r,P = ORLC(mdp,0.1)

3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
3.0
2.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
4.0
3.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
5.0
4.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
6.0
5.0
------
7.0
6.0
------
7.0
6.0
------
7.0
6.0
------
7.0
6.0
------
7.0
6.0
------
7.0
6.0
------
7.0
6.0
--

In [167]:
epsilon

array([20., 20., 20., 20., 20., 20., 20., 20., 20., 20.])

In [154]:
policy

array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0.