# サンプルベースのCMDPにおけるNatural Policy Gradient (モデルフリー)



---


In [6]:
import numpy as np
from docplex.mp.model import Model

# MDPの構築
from typing import NamedTuple ,Optional
np.random.seed(10)

S = 20 # 状態数
A = 10 # 行動数
S_set = np.arange(S)
A_set = np.arange(A)
gamma = 0.9 # 割引率

rew = np.random.uniform(0,1,size=(S,A)) # 報酬
rew = np.array(rew)

utility = np.random.uniform(0,1,size=(S,A))
utility = np.array(utility)


P = np.random.rand(S,A,S) # 遷移確率
P = P.reshape(S*A,S)
P = P/np.sum(P,axis=1,keepdims=True) # 正規化
P = P.reshape(S,A,S)
np.testing.assert_allclose(P.sum(axis=-1), 1, atol=1e-6)
rho = np.ones(S) /S
b = 3

class CMDP(NamedTuple):
    S_set: np.ndarray
    A_set: np.ndarray
    rew: np.ndarray
    utility: np.ndarray
    P: np.ndarray
    gamma: float
    H: int
    rho : np.ndarray
    b : int


    optimal_V: Optional[np.ndarray] = None

    @property
    def S(self):
        return len(self.S_set)

    @property
    def A(self):
        return len(self.A_set)

H = int (1/(1-gamma) + 100)
cmdp = CMDP(S_set,A_set,rew,utility,P,gamma,H,rho,b)

In [34]:
def theta_to_policy(theta: np.ndarray, cmdp: CMDP) -> np.ndarray:
    """θから方策を計算する"""
    '''[pi_(s1,a1),pi_(s1,a2),pi_(s1,a3),...]'''
    # ここを実装
    s = cmdp.S
    a = cmdp.A
    policy = []
    theta = theta - np.amax(theta)
    for i in range(s):
        norm = np.sum(np.exp(theta[a*i:a*(i+1)]))
        for j in range(a*i,a*(i+1)):
            policy.append(np.exp(theta[j])/norm)


    return np.array(policy)

def get_Pi(prob : np.ndarray,cmdp:CMDP) -> np.ndarray:

    Pi = np.zeros((cmdp.S, cmdp.S * cmdp.A))
    for i in range(cmdp.S):
        Pi[i, i * cmdp.A:(i + 1) * cmdp.A] = prob[i * cmdp.A:(i + 1) * cmdp.A]

    return Pi

def V_from_Q(qvals:np.ndarray,prob:np.ndarray,rho:np.ndarray,cmdp:CMDP) -> np.ndarray:

    V = np.zeros(cmdp.S)
    for i in range(cmdp.S):
        for j in range(cmdp.A):
            V[i] += qvals[i*cmdp.A+j] * prob[i*cmdp.A+j]

    v_rho = np.dot(V,rho)
    return v_rho

#価値関数の推定
def Q_value_estimate(cmdp:CMDP,policy:np.ndarray,uti:np.ndarray,Ksample:int):
    s,a = cmdp.S,cmdp.A
    gamma = cmdp.gamma

    q_estimate = np.zeros((s,a))
    policy = policy.reshape(s,a)

    for _ in range(Ksample):
        qest = np.zeros((s,a))
        for i in range(s):
            for j in range(a):
                qest[i,j] = uti[i,j] #初期位置をs,aに設定してるため
                length = np.random.geometric(p=1-gamma,size=1)
                init_s = np.random.choice(cmdp.S_set,1,p=cmdp.P[i,j,:])
                state = init_s[0]

                for _ in range(length[0] - 1):
                    action = np.random.choice(cmdp.A_set,1,p=policy[state])[0]
                    qest[state,action] += uti[state,action]
                    state = np.random.choice(cmdp.S_set,1,p=cmdp.P[state,action,:])[0]
        q_estimate += qest

    return q_estimate/Ksample


In [35]:
def proj(scalar):
    offset = 100
    if scalar < 0:
        scalar = 0
    
    if scalar > offset:
        scalar = offset
    
    return scalar

最適な価値関数を線形計画法で求める。


In [36]:
model = Model('CMDP')
idx = [(i,j) for i in range(cmdp.S) for j in range(cmdp.A)]
policy = model.continuous_var_dict(idx)

for s in range(cmdp.S):
    for a in range(cmdp.A):
        model.add_constraint(policy[(s,a)] >= 0)
        model.add_constraint(policy[(s,a)] <= 1)

model.add_constraint(model.sum(policy[(s,a)] * cmdp.utility[s,a] / (1-gamma) for s in range(cmdp.S) for a in range(cmdp.A)) >= b)


for s_next in range(cmdp.S):
    model.add_constraint(
        gamma * model.sum(policy[(s,a)] * cmdp.P[s,a,s_next] for s in range(cmdp.S) for a in range(cmdp.A)) 
        + (1 - cmdp.gamma) * cmdp.rho[s_next] == model.sum(policy[(s_next,a_next)] for a_next in range(cmdp.A))
    )


model.maximize(model.sum(policy[(s,a)] * cmdp.rew[s,a]/(1-gamma) for s in range(cmdp.S) for a in range(cmdp.A)))


In [37]:
solution = model.solve()

In [38]:
solution.get_objective_value()

8.552624829749226

In [49]:
from tqdm import tqdm
N = 300
theta = np.random.uniform(0,1,size=(cmdp.S*cmdp.A))
dual = 0
gap = []
violation = []
avg_gap = 0
avg_violation = 0
div_number = 0
K_samples = 30
step = 0.2
dual_step = 0.2
S = cmdp.S
A = cmdp.A
vio = 0
for t in tqdm(range(N)):
    policy = theta_to_policy(theta,cmdp)
    Pi = get_Pi(policy,cmdp)
    #逆行列から価値関数を求める。
    P = cmdp.P.reshape(S*A,S)
    mat = np.identity(S*A) - gamma * np.matmul(P,Pi)

    #価値関数の計算
    qr_val = np.matmul(np.linalg.inv(mat),cmdp.rew.reshape(S*A))
    qr_val = qr_val.reshape(S,A)

    qg_val = np.matmul(np.linalg.inv(mat),cmdp.utility.reshape(S*A))
    qg_val = qg_val.reshape(S,A)

    q_est = Q_value_estimate(cmdp,policy,cmdp.rew+dual*cmdp.utility,K_samples)
    v_est = np.zeros(S)
    
    policy = policy.reshape(S,A)
    v_est = (policy * q_est).sum(axis=1)
    v_est = v_est.reshape(20,1)
    v_est = np.tile(v_est,10) 
    #advantage 
    adv = q_est - v_est

    qg_est = Q_value_estimate(cmdp,policy,cmdp.utility,K_samples)
    vg_est = np.zeros(S)
    vg_est = (policy * qg_est).sum(axis=1)
    
    vg_val = np.dot(vg_est.T,cmdp.rho)
    # print(vg_val)
    if vg_val < b:
        vio += 1
    # print(adv.shape)
    theta = theta.reshape(S,A)

    theta += step * adv/(1-cmdp.gamma)
    theta = theta.reshape(S*A)
    dual = proj(dual - dual_step * (vg_val-b))

    if t % 5 == 0:
        print(f'vioaltionしてる数は{vio}')
        q = np.sum((policy * qr_val).sum())
        q_dash = sum(policy[(s,a)] * qr_val[s,a]/(1-gamma) for s in range(cmdp.S) for a in range(cmdp.A))
        gap.append(solution.get_objective_value() - q/(1-gamma))
        print(f'最適な価値との差は{gap[-1]}')
        print(f'最適な価値{solution.get_objective_value() - q_dash}')


  0%|          | 1/300 [00:04<23:20,  4.68s/it]

vioaltionしてる数は0
最適な価値との差は-943.0450359018492
最適な価値-943.0450359018489


  2%|▏         | 6/300 [00:27<22:03,  4.50s/it]

vioaltionしてる数は0
最適な価値との差は-1611.323198949993
最適な価値-1611.3231989499927


  4%|▎         | 11/300 [00:50<22:08,  4.60s/it]

vioaltionしてる数は0
最適な価値との差は-1611.323199177711
最適な価値-1611.323199177711


  5%|▌         | 16/300 [01:12<21:29,  4.54s/it]

vioaltionしてる数は0
最適な価値との差は-1611.323199177711
最適な価値-1611.323199177711


  6%|▌         | 17/300 [01:20<22:17,  4.72s/it]


KeyboardInterrupt: 

In [30]:
v = v.reshape(20,1)
v = np.tile(v,10)

In [31]:
v

array([[0.7778335 , 0.7778335 , 0.7778335 , 0.7778335 , 0.7778335 ,
        0.7778335 , 0.7778335 , 0.7778335 , 0.7778335 , 0.7778335 ],
       [0.73388228, 0.73388228, 0.73388228, 0.73388228, 0.73388228,
        0.73388228, 0.73388228, 0.73388228, 0.73388228, 0.73388228],
       [0.89310588, 0.89310588, 0.89310588, 0.89310588, 0.89310588,
        0.89310588, 0.89310588, 0.89310588, 0.89310588, 0.89310588],
       [0.72414992, 0.72414992, 0.72414992, 0.72414992, 0.72414992,
        0.72414992, 0.72414992, 0.72414992, 0.72414992, 0.72414992],
       [0.00669741, 0.00669741, 0.00669741, 0.00669741, 0.00669741,
        0.00669741, 0.00669741, 0.00669741, 0.00669741, 0.00669741],
       [0.18333458, 0.18333458, 0.18333458, 0.18333458, 0.18333458,
        0.18333458, 0.18333458, 0.18333458, 0.18333458, 0.18333458],
       [0.77063825, 0.77063825, 0.77063825, 0.77063825, 0.77063825,
        0.77063825, 0.77063825, 0.77063825, 0.77063825, 0.77063825],
       [0.89934335, 0.89934335, 0.8993433

In [23]:
a = np.tile(v,10)
a.reshape(20,10)

array([[0.71717967, 0.11696971, 0.35789644, 0.16072117, 0.60168587,
        0.06927592, 0.10500928, 0.78794715, 0.70932358, 0.90797541],
       [0.81023054, 0.09194707, 0.38310912, 0.4388835 , 0.99063364,
        0.03733071, 0.00248129, 0.0309497 , 0.3711929 , 0.8826824 ],
       [0.71717967, 0.11696971, 0.35789644, 0.16072117, 0.60168587,
        0.06927592, 0.10500928, 0.78794715, 0.70932358, 0.90797541],
       [0.81023054, 0.09194707, 0.38310912, 0.4388835 , 0.99063364,
        0.03733071, 0.00248129, 0.0309497 , 0.3711929 , 0.8826824 ],
       [0.71717967, 0.11696971, 0.35789644, 0.16072117, 0.60168587,
        0.06927592, 0.10500928, 0.78794715, 0.70932358, 0.90797541],
       [0.81023054, 0.09194707, 0.38310912, 0.4388835 , 0.99063364,
        0.03733071, 0.00248129, 0.0309497 , 0.3711929 , 0.8826824 ],
       [0.71717967, 0.11696971, 0.35789644, 0.16072117, 0.60168587,
        0.06927592, 0.10500928, 0.78794715, 0.70932358, 0.90797541],
       [0.81023054, 0.09194707, 0.3831091