In [None]:
from numpy import *
import numpy.matlib as matlib
import itertools
import sys

In [None]:
# Debug switch; none of the notebook code visible here reads it.
DEBUG = True

# Grid dimensions: states are created for every (i, j) with
# i in range(GRID_HEIGHT) (rows) and j in range(GRID_WIDTH) (columns).
GRID_HEIGHT = 4
GRID_WIDTH = 4

<img src="/files/4x4%20Grid%20-%20State%20Transistion%20Diagram.png"/>

In [None]:
class Action:
    """One of the four grid moves: 'up', 'down', 'left' or 'right'.

    Attributes:
        value: the lower-cased action name.
        index: the integer assigned to that name by ``value_map``; used as
            the position of the action along the |A| axis of the arrays.
    """
    value_map = {'up':0, 'down':1, 'left':2, 'right':3}

    def __init__(self, value):
        """Create an action from its (case-insensitive) name.

        Raises:
            KeyError: if ``value`` is not one of the names in ``value_map``.
        """
        self.value = value.lower()
        self.index = self.value_map[self.value]

    def __eq__(self, other):
        """Two actions are equal when they map to the same index."""
        # Comparing with a foreign object must not raise (or accidentally
        # match on a shared ``index`` attribute); let Python fall back.
        if not isinstance(other, Action):
            return NotImplemented
        return self.index == other.index

    def __ne__(self, other):
        """Explicit inverse of ``__eq__`` (Python 2 does not derive it)."""
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        """Hash consistently with ``__eq__`` so actions work in sets/dicts."""
        return hash(self.index)

    def __str__(self):
        return self.value
    
    
class State:
    """A grid cell identified by its row ``i`` and column ``j``.

    Row 0 is the top edge and column 0 is the left edge.

    Attributes:
        i: row number.
        j: column number.
        index: row-major flat index of the cell, used to address the
            state axis of the policy / transition / reward arrays.
    """

    def __init__(self, i, j):
        self.i = i
        self.j = j
        # Row-major layout: every row holds GRID_WIDTH cells, so the row
        # stride must be the width (using the height collides on
        # non-square grids).
        self.index = self.i * GRID_WIDTH + self.j

    def left_of(self, other):
        """True if this cell is the immediate left neighbour of ``other``."""
        return self.i == other.i and self.j - other.j == -1

    def right_of(self, other):
        """True if this cell is the immediate right neighbour of ``other``."""
        return self.i == other.i and self.j - other.j == 1

    def above(self, other):
        """True if this cell is directly above ``other``."""
        return self.j == other.j and self.i - other.i == -1

    def below(self, other):
        """True if this cell is directly below ``other``."""
        return self.j == other.j and self.i - other.i == 1

    # The edge predicates always return a real boolean (never None).
    def on_top_edge(self):
        """True if the cell lies in the first row."""
        return self.i == 0

    def on_bottom_edge(self):
        """True if the cell lies in the last row."""
        return self.i == GRID_HEIGHT - 1

    def on_left_edge(self):
        """True if the cell lies in the first column."""
        return self.j == 0

    def on_right_edge(self):
        """True if the cell lies in the last column."""
        return self.j == GRID_WIDTH - 1

    def __eq__(self, other):
        """Two states are equal when they share the same flat index."""
        # Do not raise on (or accidentally match) objects of another type.
        if not isinstance(other, State):
            return NotImplemented
        return self.index == other.index

    def __ne__(self, other):
        """Explicit inverse of ``__eq__`` (Python 2 does not derive it)."""
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        """Hash consistently with ``__eq__`` so states work in sets/dicts."""
        return hash(self.index)

    def __str__(self):
        return 's{}{}'.format(self.i,self.j)

In [None]:
# Listed in the same order as Action.value_map so that a position in this
# list coincides with the action's own ``index``.
actions = [Action('up'), Action('down'), Action('left'), Action('right')]
# Row-major enumeration of every cell of the grid.
states = [State(i,j) for i,j in itertools.product(range(GRID_HEIGHT), range(GRID_WIDTH))]

# terminal states: the top-left and bottom-right corners, derived from the
# grid constants so they stay correct if the grid size is changed.
s_term = [State(0, 0), State(GRID_HEIGHT - 1, GRID_WIDTH - 1)]

In [None]:
# |S| x |A|
# Every action is equally likely in every state; the probability is derived
# from the number of actions so each row always sums to 1.
uni_random_policy = full(shape=(len(states), len(actions)), fill_value=1.0 / len(actions))

In [None]:
def trans(s, a, s_p):
    """Return the probability of landing in ``s_p`` after taking ``a`` in ``s``.

    The result is 1.0 when ``s_p`` is the neighbouring cell in the direction
    of ``a``, or when ``s_p`` is ``s`` itself and ``s`` sits on the edge the
    action points at. It is 0.0 in every other case, and always 0.0 when
    ``s`` is one of the terminal states in ``s_term``.
    """
    if s in s_term:
        return 0.0

    # (action name, "s_p is the neighbour in that direction", "s is on that edge")
    rules = (
        ('up', s_p.above, s.on_top_edge),
        ('down', s_p.below, s.on_bottom_edge),
        ('left', s_p.left_of, s.on_left_edge),
        ('right', s_p.right_of, s.on_right_edge),
    )
    for name, lands_next_to, blocked_by_edge in rules:
        if a == Action(name):
            if lands_next_to(s) or (s == s_p and blocked_by_edge()):
                return 1.0
            # Only one rule can match an action; nothing else to check.
            break

    return 0.0
    
# Transition tensor, |S| x |A| x |S|: entry [s, a, s'] is trans(s, a, s').
p_trans = zeros(shape=(len(states), len(actions), len(states)))

for s in states:
    for a in actions:
        for s_p in states:
            p_trans[s.index, a.index, s_p.index] = trans(s, a, s_p)

In [None]:
r_term = 0.0  # Reward returned by reward() when the current state is terminal
r_step = -1.0 # Reward returned by reward() for every non-terminal state

gamma = 1.0  # Discount factor applied to the successor state's value

In [None]:
def reward(state, action, next_state):
    """Return the reward for a transition.

    Only the current ``state`` matters: ``r_term`` if it is one of the
    terminal states in ``s_term``, ``r_step`` otherwise. ``action`` and
    ``next_state`` are accepted but not used.
    """
    return r_term if state in s_term else r_step

# Reward tensor, |S| x |A| x |S|: entry [s, a, s'] is reward(s, a, s'),
# addressed by each element's position in `states` / `actions`.
r = zeros(shape=(len(states),len(actions),len(states)))

for (s, state), (a, action), (s_p, next_state) in itertools.product(
        enumerate(states), enumerate(actions), enumerate(states)):
    r[s,a,s_p] = reward(state,action,next_state)

In [None]:
def policy_evaluation(policy, vk):
    """Perform one sweep of iterative policy evaluation.

    For every state s the new value is the sum over actions a and successor
    states s' of ``policy[s, a] * p_trans[s, a, s'] * (r[s, a, s'] + gamma * vk[s'])``.

    Args:
        policy: |S| x |A| array of action probabilities.
        vk: length-|S| array holding the current value estimate.

    Returns:
        A new length-|S| array with the updated estimate; ``vk`` is not modified.
    """
    n_states = len(states)
    n_actions = len(actions)
    vk_new = zeros(shape=(n_states))
    for cur in range(n_states):
        for act in range(n_actions):
            for nxt in range(n_states):
                backup = r[cur, act, nxt] + gamma * vk[nxt]
                vk_new[cur] += policy[cur, act] * p_trans[cur, act, nxt] * backup
    return vk_new

Evaluating the Uniform Random Policy

In [None]:
# Start from an all-zero value estimate and apply a fixed number of
# policy-evaluation sweeps under the uniform random policy.
vk = zeros(shape=(len(states)))

NUM_ITERS = 100
for k in range(NUM_ITERS):
    vk = policy_evaluation(uni_random_policy, vk)

# Keep a separate copy: `vk` is reused by later cells.
vk_uni = copy(vk)
for s in states:
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print('{} = {:.0f}'.format(s, vk_uni[s.index]))

In [None]:
def policy_improvement(vk):
    """Return the deterministic policy that is greedy with respect to ``vk``.

    For every state the one-step look-ahead value
    ``sum_s' p_trans[s, a, s'] * (r[s, a, s'] + gamma * vk[s'])`` is computed
    for each action, and the action with the largest value gets probability
    1.0. Ties are resolved in favour of the action that comes first in
    ``actions``.

    Args:
        vk: length-|S| array holding a state-value estimate.

    Returns:
        |S| x |A| array with exactly one 1.0 per row.
    """
    new_policy = zeros(shape=(len(states), len(actions)))
    for s, state in enumerate(states):
        max_a = None
        # A true lower bound (sys.maxint is Python-2 only and still finite).
        max_vk = float('-inf')
        for a, action in enumerate(actions):
            vk_cand = 0.0
            for s_p, next_state in enumerate(states):
                vk_cand += p_trans[s, a, s_p] * (r[s, a, s_p] + gamma * vk[s_p])

            # Always accept the first action so that max_a is never left as
            # None (indexing with None would set the whole row to 1.0).
            if max_a is None or vk_cand > max_vk:
                max_vk = vk_cand
                max_a = a

        new_policy[s, max_a] = 1.0

    return new_policy

Determine the optimal policy by acting greedily with respect to the state-value function of the uniform random policy.

In [None]:
# Greedy policy with respect to vk_uni, the value estimate of the random policy.
optimal_policy = policy_improvement(vk_uni)

In [None]:
# Display the |S| x |A| policy matrix; a 1.0 marks the action chosen in each state.
optimal_policy

Evaluating the State-Value Function for the Optimal Policy

In [None]:
# Evaluate the greedy policy from an all-zero value estimate using a fixed
# number of policy-evaluation sweeps.
vk_star = zeros(shape=(len(states)))

NUM_ITERS = 20
for k in range(NUM_ITERS):
    vk_star = policy_evaluation(optimal_policy, vk_star)

for s in states:
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print('{} = {}'.format(s, vk_star[s.index]))

In [None]:
def policy_iteration(policy, vk, k_iters=1, epsilon=1e-4):
    """Alternate policy evaluation and greedy improvement until the policy settles.

    Each round runs ``k_iters`` evaluation sweeps of the current policy, then
    computes the greedy policy for the resulting values. The loop ends after
    the first round in which no entry of the policy changes by more than
    ``epsilon``.

    Args:
        policy: |S| x |A| array with the starting policy.
        vk: starting state-value estimate.
        k_iters: number of evaluation sweeps per round.
        epsilon: largest per-entry difference still treated as "unchanged".

    Returns:
        Tuple ``(policy, vk)``: the last improved policy and the value
        estimate from the final round of evaluation sweeps.
    """
    while True:
        for _ in range(k_iters):
            vk = policy_evaluation(policy, vk)

        improved = policy_improvement(vk)

        # Look for any (state, action) entry that moved by more than epsilon.
        changed = False
        for s, a in itertools.product(states, actions):
            if abs(policy[s.index, a.index] - improved[s.index, a.index]) > epsilon:
                changed = True
                break

        policy = copy(improved)
        if not changed:
            return policy, vk

In [None]:
# Run policy iteration from the uniform random policy and an all-zero value
# estimate; k_iters=1 means a single evaluation sweep between improvements.
vk = zeros(shape=(len(states)))
policy_star, vk_star = policy_iteration(uni_random_policy, vk, k_iters=1)

In [None]:
# Display the value estimate returned by policy_iteration.
vk_star

In [None]:
# Display the |S| x |A| policy returned by policy_iteration.
policy_star