In [1]:
import numpy as np

An MDP example with state space $S=\{s_0,s_1,s_2\}$, an action space $A=\{a_0,a_1\}$. The reward functions $R_a(s,s')$ and transition functions $P_a(s,s')$ could be listed as tables.

![image.png](attachment:image.png)

### the transition and reward functions

In [60]:
P0 = np.array([
    [0.5, 0, 0.5],
    [0.7, 0.1, 0.2],
    [0.4, 0, 0.6]])

R0 = np.array([
    [0, 0, 0],
    [5, 0, 0],
    [0, 0, 0]])

P1 = np.array([
    [0, 0, 1],
    [0, 0.95, 0.05],
    [0.3, 0.3, 0.4]])

R1 = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [-1, 0, 0]])

### 1. Value iteration

In [80]:
# initialize the state value function
Q = np.zeros((3,1))
gamma = 0.9

for i in range(10):
    Q0 = (R0 * P0).sum(1).reshape(-1,1) + \
            gamma * (P0 @ Q.reshape(-1, 1))

    Q1 = (R1 * P1).sum(1).reshape(-1,1) + \
            gamma * (P1 @ Q.reshape(-1, 1))

    Q = np.maximum(Q0, Q1)
    Q -= np.mean(Q)
    print (f"--- iter {i} ---")
    print (f"choice: {((Q0 < Q1).flatten())}")
    print (f"Q: {Q.flatten()}")
    print ()

--- iter 0 ---
choice: [False False False]
Q: [-1.16666667  2.33333333 -1.16666667]

--- iter 1 ---
choice: [False False  True]
Q: [-1.48666667  2.32833333 -0.84166667]

--- iter 2 ---
choice: [ True False  True]
Q: [-1.25356667  2.12538333 -0.87181667]

--- iter 3 ---
choice: [ True False  True]
Q: [-1.31180567  2.21743983 -0.90563417]

--- iter 4 ---
choice: [ True False  True]
Q: [-1.3195841   2.20560452 -0.88602042]

--- iter 5 ---
choice: [ True False  True]
Q: [-1.30759256  2.19750857 -0.88991602]

--- iter 6 ---
choice: [ True False  True]
Q: [-1.31185465  2.20287734 -0.89102268]

--- iter 7 ---
choice: [ True False  True]
Q: [-1.31168508  2.20164178 -0.8899567 ]

--- iter 8 ---
choice: [ True False  True]
Q: [-1.31113997  2.20141502 -0.89027504]

--- iter 9 ---
choice: [ True False  True]
Q: [-1.31140999  2.2016972  -0.89028721]



### 2. Policy iteration

In [79]:
# initialize the policy (all choose a0)
A = np.array([0, 0, 0])
# initialize state value function randomly
Q = np.random.random((3,1))
gamma = 0.9
    
for i in range(5):
    print (f"--- iter {i} ---")
    
    Q0 = (R0 * P0).sum(1).reshape(-1,1) + \
            gamma * (P0 @ Q.reshape(-1, 1))

    Q1 = (R1 * P1).sum(1).reshape(-1,1) + \
            gamma * (P1 @ Q.reshape(-1, 1))

    # policy evaluation
    Q = Q0 * (1 - A).reshape(-1, 1) + Q1 * A.reshape(-1, 1) 
    Q -= Q.mean()
    print (f"cur state valie: {Q.flatten()}")
    
    
    Q0 = (R0 * P0).sum(1).reshape(-1,1) + \
            gamma * (P0 @ Q.reshape(-1, 1))

    Q1 = (R1 * P1).sum(1).reshape(-1,1) + \
            gamma * (P1 @ Q.reshape(-1, 1))

    # policy improvement
    A = (Q0 < Q1).astype("int").flatten()
    print (f"policy: {A}")
    print ()

--- iter 0 ---
cur state valie: [-1.17204968  2.32835092 -1.15630125]
policy: [1 0 1]

--- iter 1 ---
cur state valie: [-1.4801004   2.32359679 -0.84349639]
policy: [1 0 1]

--- iter 2 ---
cur state valie: [-1.25573664  2.12824121 -0.87250457]
policy: [1 0 1]

--- iter 3 ---
cur state valie: [-1.31178654  2.21684438 -0.90505784]
policy: [1 0 1]

--- iter 4 ---
cur state valie: [-1.31927632  2.2054558  -0.88617947]
policy: [1 0 1]

