# Model-Based RL: Policy and Value Iteration Using Dynamic Programming

In [0]:
# Download scripts
# Fetch the GridWorld environment module into the working directory so that
# `from gridworld import GridworldEnv` can be resolved by a later cell.
# NOTE(review): the URL points at the `master` branch and the downloaded file
# is subsequently imported as code -- consider pinning a commit for reproducibility.
import urllib.request
urllib.request.urlretrieve ("https://gitlab.com/jongseokkim/rlplayground/raw/master/env/gridworld.py", "gridworld.py")

('gridworld.py', <http.client.HTTPMessage at 0x7faaba2293c8>)

In [0]:
# Import dependencies
import numpy as np
import sys
from gridworld import GridworldEnv
# Instantiate the environment; the rest of the notebook reads env.nS, env.nA and env.P from it.
env = GridworldEnv()
print("# state:", env.nS, "# action:", env.nA)
# Inspect a single (state, action) entry of the transition table env.P.
state = 1 
action = 3 # 0: Up, 1: Right, 2: Down, 3: Left
# Each entry is a list of tuples that the functions in this notebook unpack as
# (probability, next_state, reward, is_done).
print(env.P[state][action])

# state: 16 # action: 4
[(1.0, 0, -1.0, True)]


### Policy Evaluation

In [0]:
def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """Iteratively evaluate a policy on a known transition model.

    Args:
        policy: Indexable by state; ``policy[s]`` yields the probability of
            taking each action in state ``s`` (action index = position).
        env: Object exposing ``nS`` (number of states) and ``P``, where
            ``P[s][a]`` is a list of ``(prob, next_state, reward, done)``
            tuples.
        discount_factor: Multiplier applied to the value of the next state.
        theta: Sweeps stop once the largest absolute change of any state
            value within a sweep is below this threshold.

    Returns:
        A NumPy array of length ``env.nS`` holding the estimated state values.
    """
    values = np.zeros(env.nS)
    converged = False
    while not converged:
        largest_change = 0
        for state in range(env.nS):
            # Expected one-step return under the policy.  Values are updated
            # in place, so states later in the sweep see the fresh estimates.
            backup = 0
            for action, action_prob in enumerate(policy[state]):
                for trans_prob, next_state, reward, _ in env.P[state][action]:
                    backup += action_prob * trans_prob * (
                        reward + discount_factor * values[next_state]
                    )
            largest_change = max(largest_change, abs(values[state] - backup))
            values[state] = backup
        converged = largest_change < theta
    return values.copy()

In [0]:
# Main
# Evaluate the uniform random policy and show its state values as a 4x4 grid
# (state index = row * 4 + column).
random_policy = np.ones([env.nS, env.nA]) / env.nA
v = policy_eval(random_policy, env)
print("Value Function:")
for row_start in range(0, 16, 4):
    row_cells = ['%5.1f '%v[row_start + offset] for offset in range(4)]
    print(''.join(row_cells))

Value Function:
  0.0 -14.0 -20.0 -22.0 
-14.0 -18.0 -20.0 -20.0 
-20.0 -20.0 -18.0 -14.0 
-22.0 -20.0 -14.0   0.0 


### Policy Iteration

In [0]:
def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):
    """Policy iteration: alternate policy evaluation and greedy improvement.

    Despite its name, this function runs the full policy-iteration loop. It
    starts from the uniform random policy, evaluates it, makes the policy
    greedy with respect to the resulting value function, and repeats until an
    improvement sweep leaves the policy unchanged.

    Args:
        env: Object exposing ``nS``, ``nA`` and ``P``, where ``P[s][a]`` is a
            list of ``(prob, next_state, reward, done)`` tuples.
        policy_eval_fn: Callable invoked as
            ``policy_eval_fn(policy, env, discount_factor)`` that returns the
            state values of ``policy`` (indexable by state).
        discount_factor: Multiplier applied to the value of the next state.

    Returns:
        ``(policy, V)`` where ``policy`` is an ``[nS, nA]`` array with one-hot
        rows and ``V`` is the value function that ``policy_eval_fn`` returned
        for exactly that policy.
    """
    # Initialize policy
    policy = np.ones([env.nS, env.nA]) / env.nA
    # One-hot rows, built once instead of once per state.
    one_hot = np.eye(env.nA)
    # Start iteration
    while True:
        policy_stable = True
        # Evaluate policy
        V = policy_eval_fn(policy, env, discount_factor)
        # Improve policy
        for s in range(env.nS):
            # Calculate Q function
            q = np.zeros(env.nA)
            for a in range(env.nA):
                for prob, n_s, reward, _ in env.P[s][a]:
                    q[a] += prob * (reward + discount_factor * V[n_s])
            greedy_row = one_hot[np.argmax(q)]
            # Compare the whole row, not just its argmax: the argmax of the
            # initial uniform row is always 0, so an argmax comparison would
            # declare the policy stable whenever action 0 is greedy everywhere
            # and return the value function of the random policy.
            if not np.array_equal(policy[s], greedy_row):
                policy_stable = False
            # Update policy
            policy[s] = greedy_row
        if policy_stable:
            break

    return policy, V

In [0]:
# Main
# Run policy iteration and show the greedy action and the value of every
# state as 4x4 grids (state index = row * 4 + column).
policy, v = policy_improvement(env)
# Print
print("Policy (0=up, 1=right, 2=down, 3=left):")
for row_start in range(0, 16, 4):
    action_cells = ['%5d '%np.argmax(policy[row_start + offset]) for offset in range(4)]
    print(''.join(action_cells))

print("Value Function:")
for row_start in range(0, 16, 4):
    value_cells = ['%5.1f '%v[row_start + offset] for offset in range(4)]
    print(''.join(value_cells))

Policy (0=up, 1=right, 2=down, 3=left):
    0     3     3     2 
    0     0     0     2 
    0     0     1     2 
    0     1     1     0 
Value Function:
  0.0  -1.0  -2.0  -3.0 
 -1.0  -2.0  -3.0  -2.0 
 -2.0  -3.0  -2.0  -1.0 
 -3.0  -2.0  -1.0   0.0 


### Value Iteration

In [0]:
def value_iteration(env, theta=0.0001, discount_factor=1.0):
    """Compute a value function and a greedy policy by value iteration.

    Args:
        env: Object exposing ``nS``, ``nA`` and ``P``, where ``P[s][a]`` is a
            list of ``(prob, next_state, reward, done)`` tuples.
        theta: Sweeps stop once the largest absolute change of any state
            value within a sweep is below this threshold.
        discount_factor: Multiplier applied to the value of the next state.

    Returns:
        ``(policy, V)`` where ``policy`` is an ``[nS, nA]`` array with one-hot
        rows selecting the greedy action and ``V`` is the array of state
        values the sweeps converged to.
    """

    def action_values(state, values):
        """One-step lookahead: expected return of each action in ``state``."""
        lookahead = np.zeros(env.nA)
        for action in range(env.nA):
            for trans_prob, next_state, reward, _ in env.P[state][action]:
                lookahead[action] += trans_prob * (
                    reward + discount_factor * values[next_state]
                )
        return lookahead

    values = np.zeros(env.nS)
    converged = False
    while not converged:
        largest_change = 0
        for state in range(env.nS):
            # Back up the best action value; updates are applied in place.
            best_value = np.max(action_values(state, values))
            largest_change = max(largest_change, abs(values[state] - best_value))
            values[state] = best_value
        converged = largest_change < theta

    # Derive the greedy policy from the converged values.
    policy = np.zeros([env.nS, env.nA])
    for state in range(env.nS):
        greedy_action = np.argmax(action_values(state, values))
        policy[state, greedy_action] = 1.0

    return policy, values

In [0]:
# Main
# Run value iteration and show the greedy action and the value of every
# state as 4x4 grids (state index = row * 4 + column).
policy, v = value_iteration(env)
# Print
print("Policy (0=up, 1=right, 2=down, 3=left):")
for row_start in range(0, 16, 4):
    action_cells = ['%5d '%np.argmax(policy[row_start + offset]) for offset in range(4)]
    print(''.join(action_cells))

print("Value Function:")
for row_start in range(0, 16, 4):
    value_cells = ['%5.1f '%v[row_start + offset] for offset in range(4)]
    print(''.join(value_cells))

Policy (0=up, 1=right, 2=down, 3=left):
    0     3     3     2 
    0     0     0     2 
    0     0     1     2 
    0     1     1     0 
Value Function:
  0.0  -1.0  -2.0  -3.0 
 -1.0  -2.0  -3.0  -2.0 
 -2.0  -3.0  -2.0  -1.0 
 -3.0  -2.0  -1.0   0.0 
