In [18]:
from numpy import *
import numpy.matlib as matlib
import itertools

In [19]:
# When True, value_iteration prints every term of every backup it performs.
DEBUG = False

# Grid size: GRID_HEIGHT rows (index i) by GRID_WIDTH columns (index j).
GRID_HEIGHT, GRID_WIDTH = 4, 4

class Action:
    """One of the four grid moves, identified by name and by integer index.

    ``index`` (taken from ``value_map``) is the position of the action along
    the action axis of ``p_policy``, ``p_trans`` and ``r``.
    """
    value_map = {'up': 0, 'down': 1, 'left': 2, 'right': 3}

    def __init__(self, value):
        # Names are case-insensitive; an unknown name raises KeyError.
        self.value = value.lower()
        self.index = self.value_map[self.value]

    def __eq__(self, other):
        # Actions are equal iff they share an index. Non-Action operands defer
        # to Python's default comparison instead of crashing on a missing
        # ``.index`` (or spuriously matching any object that has one).
        if not isinstance(other, Action):
            return NotImplemented
        return self.index == other.index

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two equal
        # actions would compare unequal with !=.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Keep hashing consistent with __eq__ (Python 3 makes a class that
        # defines only __eq__ unhashable).
        return hash(self.index)

    def __repr__(self):
        return 'Action({!r})'.format(self.value)

    def __str__(self):
        return self.value
    
    
class State:
    """A grid cell at row ``i`` (0 = top) and column ``j`` (0 = left).

    ``index`` is the cell's row-major position (rows outer, columns inner),
    used to address the state axes of ``p_policy``, ``p_trans`` and ``r``.
    """

    def __init__(self, i, j):
        self.i = i
        self.j = j
        # Row-major flattening strides by the row length, GRID_WIDTH.
        # GRID_HEIGHT only gives the same answer on a square grid; on any
        # other shape it makes distinct cells share an index.
        self.index = self.i * GRID_WIDTH + self.j

    def left_of(self, other):
        """True if this cell is directly left of ``other`` in the same row."""
        return self.i == other.i and self.j - other.j == -1

    def right_of(self, other):
        """True if this cell is directly right of ``other`` in the same row."""
        return self.i == other.i and self.j - other.j == 1

    def above(self, other):
        """True if this cell is directly above ``other`` in the same column."""
        return self.j == other.j and self.i - other.i == -1

    def below(self, other):
        """True if this cell is directly below ``other`` in the same column."""
        return self.j == other.j and self.i - other.i == 1

    # The edge predicates return a real bool (False rather than None) so they
    # are safe to compare, store or count, not just to test for truthiness.
    def on_top_edge(self):
        """True if this cell is in the first row."""
        return self.i == 0

    def on_bottom_edge(self):
        """True if this cell is in the last row."""
        return self.i == GRID_HEIGHT - 1

    def on_left_edge(self):
        """True if this cell is in the first column."""
        return self.j == 0

    def on_right_edge(self):
        """True if this cell is in the last column."""
        return self.j == GRID_WIDTH - 1

    def __eq__(self, other):
        # Cells are equal iff they share an index; non-State operands defer to
        # Python's default comparison rather than raising AttributeError.
        if not isinstance(other, State):
            return NotImplemented
        return self.index == other.index

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Consistent with __eq__, so states can be set members / dict keys.
        return hash(self.index)

    def __repr__(self):
        return 'State({}, {})'.format(self.i, self.j)

    def __str__(self):
        return 's{}{}'.format(self.i, self.j)
    
# Build the action list in index order so actions[k].index == k, i.e. list
# order matches the action axis of p_policy / p_trans / r.
actions = [Action(name) for name in sorted(Action.value_map, key=Action.value_map.get)]
# Every grid cell, rows outer / columns inner (row-major).
states = [State(i, j) for i, j in itertools.product(range(GRID_HEIGHT), range(GRID_WIDTH))]

# r and value_iteration index by list position while p_trans uses .index;
# both conventions must agree.
assert all(act.index == k for k, act in enumerate(actions)), 'actions out of index order'
assert all(st.index == k for k, st in enumerate(states)), 'State.index disagrees with states order'

# terminal states: the top-left and bottom-right corners of the grid
s_term = [State(0, 0), State(GRID_HEIGHT - 1, GRID_WIDTH - 1)]

In [3]:
# Equiprobable random policy: p_policy[s, a] = 1 / |A| for every state and
# action, derived from the action count rather than hard-coded.
# |S| x |A|
p_policy = full(shape=(len(states), len(actions)), fill_value=1.0 / len(actions))

In [20]:
def trans(s, a, s_p):
    """Deterministic grid dynamics: probability of reaching ``s_p`` from ``s`` via ``a``.

    Returns 1.0 when ``s_p`` is the neighbouring cell of ``s`` in the direction
    of ``a``, or when ``s_p`` is ``s`` itself and ``s`` sits on the grid edge
    that ``a`` would cross (the agent stays put). Returns 0.0 otherwise, and
    always 0.0 out of a terminal state (no outgoing transitions).
    """
    if s in s_term:
        return 0.0

    # (action name, "s_p is the neighbour that way", "s is on that edge")
    moves = (
        ('up', State.above, State.on_top_edge),
        ('down', State.below, State.on_bottom_edge),
        ('left', State.left_of, State.on_left_edge),
        ('right', State.right_of, State.on_right_edge),
    )
    for name, is_neighbour, on_edge in moves:
        if a == Action(name):
            reached = is_neighbour(s_p, s) or (s == s_p and on_edge(s))
            return 1.0 if reached else 0.0
    return 0.0
    
# |S| x |A| x |S|
p_trans = zeros(shape=(len(states), len(actions), len(states)))

for s, a, s_p in itertools.product(states, actions, states):
    p_trans[s.index, a.index, s_p.index] = trans(s, a, s_p)

# Invariant: from a non-terminal state every action leads somewhere with total
# probability 1; terminal states have no outgoing mass at all (trans returns 0).
row_sums = p_trans.sum(axis=2)
for st in states:
    expected = 0.0 if st in s_term else 1.0
    assert allclose(row_sums[st.index], expected), \
        'p_trans rows for {} sum to {}, expected {}'.format(st, row_sums[st.index], expected)

In [22]:
# Inspect the dynamics: list every non-zero transition out of s32, per action.
s = State(3, 2)

for a in actions:
    for s_p in states:
        p = p_trans[s.index, a.index, s_p.index]
        if p:
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print('p_trans[{},{},{}] = {}'.format(s, a, s_p, p))

p_trans[s32,up,s22] = 1.0
p_trans[s32,down,s32] = 1.0
p_trans[s32,left,s31] = 1.0
p_trans[s32,right,s33] = 1.0


In [23]:
# Rewards used by reward(): r_term for a transition out of a terminal state,
# r_step for every other transition.
r_term, r_step = 0.0, -1.0

# gamma = 1.0: returns are undiscounted.
gamma = 1.0

In [24]:
def reward(state, action, next_state):
    """Reward for the transition (state, action) -> next_state.

    Only the source state matters: transitions out of a terminal state earn
    ``r_term``, all others earn ``r_step``. ``action`` and ``next_state`` are
    accepted to keep the general r(s, a, s') signature but are not used.
    """
    return r_term if state in s_term else r_step

# |S| x |A| x |S| reward tensor, indexed by position in states / actions.
r = zeros(shape=(len(states), len(actions), len(states)))

for (s, state), (a, action), (s_p, next_state) in itertools.product(
        enumerate(states), enumerate(actions), enumerate(states)):
    r[s, a, s_p] = reward(state, action, next_state)

In [25]:
def value_iteration(vk, vk_new):
    """One synchronous sweep of iterative policy evaluation.

    Despite the name this is not value iteration (which takes a max over
    actions). It applies the Bellman expectation backup for the fixed policy
    ``p_policy``:

        vk_new[s] = sum_a p_policy[s, a] * sum_s' p_trans[s, a, s']
                    * (r[s, a, s'] + gamma * vk[s'])

    It reads the module-level ``states``, ``actions``, ``p_policy``,
    ``p_trans``, ``r``, ``gamma`` and ``DEBUG``.

    Args:
        vk: current value estimate, indexed by position in ``states``; only read.
        vk_new: output buffer indexed the same way; every entry is overwritten.

    Returns:
        None. The result is written into ``vk_new``.
    """
    for s, state in enumerate(states):
        vk_new[s] = 0
        for a, action in enumerate(actions):
            for s_p, next_state in enumerate(states):
                # Probability of choosing a in s and then landing in s_p.
                weight = p_policy[s, a] * p_trans[s, a, s_p]
                vk_new[s] += weight * (r[s, a, s_p] + gamma * vk[s_p])

                if DEBUG:
                    # Single-argument print(...) works on Python 2 and 3.
                    print('s, a, s_p = ({}, {}, {})'.format(state, action, next_state))
                    print('p = {}'.format(weight))
                    print('r[s, a, s_p] = {}'.format(r[s, a, s_p]))
                    # The backup uses the successor's value, so show vk[s_p].
                    print('vk[s_p] = {}'.format(vk[s_p]))
                    print('vk_new[s] = {}'.format(vk_new[s]))
                    print('**********************************')
              

In [27]:
# Evaluate the current policy with repeated synchronous sweeps: at most
# EVAL_MAX_SWEEPS, stopping early once no state value moves by more than
# EVAL_TOL in a sweep.
EVAL_MAX_SWEEPS = 200
EVAL_TOL = 1e-10

vk = zeros(shape=(len(states)))
vk_new = zeros(shape=(len(states)))

for _sweep in range(EVAL_MAX_SWEEPS):
    value_iteration(vk, vk_new)
    delta = abs(vk_new - vk).max()
    # value_iteration reuses vk_new as its output buffer, so keep a copy.
    vk = copy(vk_new)
    if delta < EVAL_TOL:
        break

for s in states:
    print('{} = {}'.format(s, vk[s.index]))

s00 = 0.0
s01 = -13.9997574073
s02 = -19.9996405213
s03 = -21.9995977221
s10 = -13.9997574073
s11 = -17.9996833205
s12 = -19.999642926
s13 = -19.9996405213
s20 = -19.9996405213
s21 = -19.999642926
s22 = -17.9996833205
s23 = -13.9997574073
s30 = -21.9995977221
s31 = -19.9996405213
s32 = -13.9997574073
s33 = 0.0


In [None]:
def improve_policy(p_policy, v=None):
    """Greedy policy-improvement step, applied to ``p_policy`` in place.

    Action values are computed from the module-level ``p_trans``, ``r`` and
    ``gamma``:

        q[s, a] = sum_s' p_trans[s, a, s'] * (r[s, a, s'] + gamma * v[s'])

    Each row of ``p_policy`` is then replaced by a uniform distribution over
    the action(s) maximising ``q[s, :]``. Actions within ``isclose``
    tolerance of the maximum count as tied and share the probability mass.

    Args:
        p_policy: |S| x |A| array of action probabilities; overwritten in place.
        v: state values indexed like ``states``. Defaults to the module-level
            ``vk`` produced by the policy-evaluation loop.

    Returns:
        True if the policy changed, False if it was already greedy with
        respect to ``v``.
    """
    if v is None:
        v = vk
    # r + gamma * v broadcasts v along the last (next-state) axis.
    q = (p_trans * (r + gamma * v)).sum(axis=2)
    greedy = isclose(q, q.max(axis=1, keepdims=True))
    # astype(float) keeps the division true division on Python 2 as well.
    new_policy = greedy.astype(float) / greedy.sum(axis=1, keepdims=True)
    changed = not allclose(new_policy, p_policy)
    p_policy[...] = new_policy
    return changed