# 2. 어떻게 가치를 평가할 것인가?
## Policy Iteration
### Example : Grid World
#### Source
1. 예제 및 Pseudo Code: Sutton & Barto, *Reinforcement Learning: An Introduction*
2. Original Code: [dennybritz/reinforcement-learning](https://github.com/dennybritz/reinforcement-learning)

<img src="img/gird_ex.png" width="50%" height="50%" title="grid_world" alt="grid_world"></img>

In [1]:
import numpy as np
import sys
import random
from utils import GridworldEnv

# Seed the standard-library and NumPy generators with the same value so that
# repeated runs of the notebook draw the same random numbers.
SEED = 100
random.seed(SEED)
np.random.seed(SEED)

In [2]:
def one_step_lookahead(state, discount_factor, V, env=None):
    """Compute the expected return of every action available in ``state``.

    For each action ``a`` this sums ``prob * (reward + discount_factor * V[s'])``
    over the transitions listed in ``env.P[state][a]``, i.e. it looks one step
    ahead to the values ``V[s']`` of the states reachable from ``state``.

    Args:
        state: Index of the state to look ahead from.
        discount_factor: Discount applied to the value of the next state.
        V: Value estimates, indexable by state.
        env: Environment exposing ``nA`` (number of actions) and ``P[state][a]``,
            an iterable of ``(prob, next_state, reward, done)`` transitions.
            When omitted, the module-level ``env`` is used, which keeps the
            original three-argument call working.

    Returns:
        NumPy array of length ``env.nA`` holding one expected return per action.

    Raises:
        NameError: If ``env`` is omitted and no module-level ``env`` exists.
    """
    if env is None:
        # Backward compatibility: earlier callers relied on a global ``env``.
        try:
            env = globals()["env"]
        except KeyError:
            raise NameError("name 'env' is not defined") from None

    A = np.zeros(env.nA)
    for a in range(env.nA):
        # The ``done`` flag is not needed for the expectation.
        for prob, next_state, reward, _done in env.P[state][a]:
            A[a] += prob * (reward + discount_factor * V[next_state])
    return A

<img src="img/pi.png" width="50%" height="50%" title="policy iteration" alt="policy iteration"></img>

In [3]:
def policy_eval(policy, env, discount_factor=1.0, theta=0.00001, cnt_eval=None, verbose=True):
    """Evaluate a policy with repeated in-place sweeps over all states.

    Each sweep replaces ``V[s]`` with the expectation, under ``policy`` and the
    known dynamics ``env.P``, of ``reward + discount_factor * V[next_state]``.
    Updates are applied in place, so later states in a sweep already see the
    new values of earlier states.

    Args:
        policy: Indexable as ``policy[s]``, yielding one probability per action.
        env: Environment exposing ``nS`` (number of states), ``shape`` (used
            only to reshape the printed value table) and ``P[s][a]``, an
            iterable of ``(prob, next_state, reward, done)`` transitions.
        discount_factor: Discount applied to the value of the next state.
        theta: Sweeping stops once the largest per-state change within a sweep
            is smaller than this threshold.
        cnt_eval: Optional index of the outer policy-iteration round; when
            given, it is printed as a header line before each value table.
        verbose: When True (the default) the value table is printed after every
            sweep and a separator line is printed on convergence. Pass False
            to suppress all output.

    Returns:
        NumPy array of length ``env.nS`` with the estimated state values.
    """
    V = np.zeros(env.nS)  # value table starts at zero for every state
    cnt_iter = 0

    while True:
        delta = 0  # largest change seen in this sweep; compared with theta below

        for s in range(env.nS):
            v = 0
            for a, action_prob in enumerate(policy[s]):
                # The MDP dynamics are known, so the expectation is exact.
                for prob, next_state, reward, _done in env.P[s][a]:
                    v += action_prob * prob * (reward + discount_factor * V[next_state])

            delta = max(delta, np.abs(v - V[s]))
            V[s] = v

        if verbose:
            if cnt_eval is not None:
                print(f"{cnt_eval}th Policy Iteration {cnt_iter}: ")
            print(V.reshape(env.shape))
            print("\n")
        cnt_iter += 1

        if delta < theta:
            if verbose:
                print('='*50)
            break

    return V

In [4]:
def policy_iteration(env, policy_eval_fn=policy_eval, discount_factor=1.0):
    """Alternate policy evaluation and greedy improvement until nothing changes.

    Starts from the uniform policy over ``env.nA`` actions. Each round
    evaluates the current policy with ``policy_eval_fn`` and then makes every
    state greedy with respect to the one-step lookahead on the resulting values.

    Args:
        env: Environment exposing ``nS`` and ``nA``; it is also handed to
            ``policy_eval_fn``.
        policy_eval_fn: Callable invoked as
            ``policy_eval_fn(policy, env, discount_factor, cnt_eval=round)``
            that returns the state values of ``policy``.
        discount_factor: Discount passed to both evaluation and lookahead.

    Returns:
        Tuple ``(policy, V)``: the final policy as an ``[nS, nA]`` array and
        the values from the last evaluation. After the first round every row
        of the policy is one-hot.
    """
    n_states, n_actions = env.nS, env.nA
    policy = np.ones([n_states, n_actions]) / n_actions  # uniform random policy
    round_no = 0

    while True:
        round_no += 1
        V = policy_eval_fn(policy, env, discount_factor, cnt_eval=round_no)

        changed = False
        for s in range(n_states):
            # Action preferred by the policy before this improvement step.
            previous_a = np.argmax(policy[s])
            # Action preferred under the freshly evaluated values.
            greedy_a = np.argmax(one_step_lookahead(s, discount_factor, V))

            if greedy_a != previous_a:
                changed = True
            # Make the policy deterministic and greedy in this state.
            policy[s] = np.eye(n_actions)[greedy_a]

        # No state changed its action, so the policy is stable.
        if not changed:
            return policy, V

In [5]:
# Build the grid world and run policy iteration on it; `policy` and `v` are
# reused by the reporting cells that follow.
env = GridworldEnv()
policy, v = policy_iteration(env=env)

1th Policy Iteration 0: 
[[ 0.        -1.        -1.25      -1.3125   ]
 [-1.        -1.5       -1.6875    -1.75     ]
 [-1.25      -1.6875    -1.84375   -1.8984375]
 [-1.3125    -1.75      -1.8984375  0.       ]]


1th Policy Iteration 1: 
[[ 0.         -1.9375     -2.546875   -2.73046875]
 [-1.9375     -2.8125     -3.23828125 -3.40429688]
 [-2.546875   -3.23828125 -3.56835938 -3.21777344]
 [-2.73046875 -3.40429688 -3.21777344  0.        ]]


1th Policy Iteration 2: 
[[ 0.         -2.82421875 -3.83496094 -4.17504883]
 [-2.82421875 -4.03125    -4.7097168  -4.87670898]
 [-3.83496094 -4.7097168  -4.96374512 -4.26455688]
 [-4.17504883 -4.87670898 -4.26455688  0.        ]]


1th Policy Iteration 3: 
[[ 0.         -3.67260742 -5.0980835  -5.58122253]
 [-3.67260742 -5.19116211 -6.03242493 -6.18872833]
 [-5.0980835  -6.03242493 -6.14849091 -5.15044403]
 [-5.58122253 -6.18872833 -5.15044403  0.        ]]


1th Policy Iteration 4: 
[[ 0.         -4.49046326 -6.30054855 -6.91293049]
 [-4.4904632

In [6]:
# Show the policy as one row of action probabilities per state, followed by
# an empty line.
print("Policy Probability Distribution:", policy, "", sep="\n")

Policy Probability Distribution:
[[1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]]



In [7]:
# Show the most probable action of each state, laid out in the grid's shape.
print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
print(np.argmax(policy, axis=1).reshape(env.shape))
print("")

Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):
[[0 3 3 2]
 [0 0 0 2]
 [0 0 1 2]
 [0 1 1 0]]



In [8]:
# Show the value estimates as a flat vector, followed by an empty line.
print("Value Function:", v, "", sep="\n")


Value Function:
[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.]



In [9]:
# Show the value estimates laid out in the grid's shape, followed by an
# empty line.
print("Reshaped Grid Value Function:", v.reshape(env.shape), "", sep="\n")

Reshaped Grid Value Function:
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]

