In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random

In [96]:
class GridWorld:
    """Rectangular grid-world MDP solved by iterative policy evaluation / policy iteration.

    States are ``(row, col)`` tuples that index ``self.values`` (shape ``(h, w)``).
    Every move in a non-terminal state yields reward -1; a move that would leave
    the grid keeps the agent where it is. Terminal states are absorbing with
    reward 0.
    """

    def __init__(self, w, h, ts):
        """
        Args:
            w: number of columns.
            h: number of rows.
            ts: terminal states, as ``(row, col)`` tuples.
        """
        self.world_height = h
        self.world_width = w
        self.values = np.zeros((h, w))
        self.terminal_states = ts

    def transition_to_next_state(self, state, action, values):
        """Deterministic transition model.

        Args:
            state: current ``(row, col)``.
            action: one of ``'up'``, ``'down'``, ``'left'``, ``'right'``; any other
                value leaves the agent in place.
            values: value table indexed ``[row, col]`` to read the successor from.

        Returns:
            ``(value_of_next_state, reward, transition_probability)``.
        """
        row, col = state

        # Terminal states are absorbing and give no further reward.
        if state in self.terminal_states:
            return values[row, col], 0.0, 1.0

        next_row, next_col = row, col
        if action == 'up':
            next_row = row - 1
        elif action == 'down':
            next_row = row + 1
        elif action == 'right':
            next_col = col + 1
        elif action == 'left':
            next_col = col - 1

        # Bumping into a wall leaves the agent on its current cell.
        if not (0 <= next_row < self.world_height and 0 <= next_col < self.world_width):
            next_row, next_col = row, col

        # numpy indexes [row, col]; reading [col, row] here would transpose the grid.
        return values[next_row, next_col], -1.0, 1.0

    def policy_evaluation(self, discount, accuracy_threshold, policy):
        """Evaluate ``policy`` in place into ``self.values``.

        Uses synchronous sweeps: every state in a sweep is backed up from a
        snapshot of the previous sweep's values. At least one sweep always runs.
        Sweeps repeat until the largest absolute change in a sweep is
        ``<= accuracy_threshold``. With a threshold of 0, the loop only stops at an
        exact floating-point fixed point.

        Args:
            discount: discount factor (gamma).
            accuracy_threshold: convergence tolerance on max |V_new - V_old|.
            policy: array shaped like ``self.values`` whose entries map
                action name -> probability.
        """
        if self.values.size == 0:
            return  # nothing to evaluate (np.amax would fail on an empty array)

        while True:
            previous_values = np.copy(self.values)
            for state in np.ndindex(previous_values.shape):
                expected_return = 0.0
                for action, action_probability in policy[state].items():
                    next_value, reward, transition_probability = self.transition_to_next_state(
                        state, action, previous_values)
                    expected_return += action_probability * transition_probability * (
                        reward + discount * next_value)
                self.values[state] = expected_return
            delta = np.amax(np.absolute(previous_values - self.values))
            if delta <= accuracy_threshold:
                break

    def best_policy(self, state, discount):
        """Greedy one-step-lookahead policy for ``state`` w.r.t. ``self.values``.

        Ties go to the first action in the order up, down, left, right.

        Returns:
            dict mapping each action to 1.0 (the chosen action) or 0.0.
        """
        actions = ['up', 'down', 'left', 'right']
        action_values = []
        for action in actions:
            next_value, reward, transition_probability = self.transition_to_next_state(
                state, action, self.values)
            action_values.append(transition_probability * (reward + discount * next_value))
        best_action = actions[action_values.index(max(action_values))]

        new_greedy_policy = {'up': 0.0, 'down': 0.0, 'right': 0.0, 'left': 0.0}
        new_greedy_policy[best_action] = 1.0
        return new_greedy_policy

    def policy_improvment(self, discount, accuracy_threshold, policy):
        """Replace each state's entry in ``policy`` (in place) with the greedy policy.

        Args:
            discount: discount factor (gamma).
            accuracy_threshold: unused; kept for call-signature compatibility.
            policy: array shaped like ``self.values`` of action -> probability dicts.

        Returns:
            ``(policy, policy_stable)``. ``policy_stable`` is True iff no state's
            action distribution changed.
        """
        policy_stable = True
        for state in np.ndindex(self.values.shape):
            improved = self.best_policy(state, discount)
            # Compare before replacing, so no separate copy of the old policy is needed.
            if policy[state] != improved:
                policy_stable = False
            policy[state] = improved
        return policy, policy_stable

    def policy_iteration(self, discount, accuracy_threshold, policy):
        """Alternate policy evaluation and greedy improvement until the policy is stable.

        ``policy`` is updated in place. On return, ``self.values`` holds the value
        of the final policy.
        """
        # NOTE(review): if two actions have nearly equal values, floating-point noise
        # could make the greedy choice flip between iterations. Consider a
        # tolerance-based stability check if this is ever observed not to terminate.
        policy_stable = False
        while not policy_stable:
            self.policy_evaluation(discount, accuracy_threshold, policy)
            policy, policy_stable = self.policy_improvment(discount, accuracy_threshold, policy)


In [91]:
# 4x4 grid with terminal corners (0,0) and (3,3): evaluate the uniform random policy (gamma = 0.9).
gridworld = GridWorld(4,4,[(0,0), (3,3)])
# One dict per cell: np.full(shape, d) would store the *same* dict object in every cell,
# so an in-place edit of one state's distribution would silently change all of them.
policy = np.empty((gridworld.world_height, gridworld.world_width), dtype=object)
for cell in np.ndindex(policy.shape):
    policy[cell] = {'up':0.25, 'down':0.25, 'right':0.25, 'left':0.25}
gridworld.policy_evaluation(0.9, 0.0001, policy)
gridworld.values

array([[ 0.        , -5.27748322, -7.1279106 , -7.64996138],
       [-5.27748322, -6.60585983, -7.18012478, -7.1279106 ],
       [-7.1279106 , -7.18012478, -6.60585983, -5.27748322],
       [-7.64996138, -7.1279106 , -5.27748322,  0.        ]])

In [92]:
# Policy iteration on the same 4x4 grid, starting from the uniform random policy.
gridworld = GridWorld(4,4,[(0,0), (3,3)])
# One dict per cell (np.full would alias a single dict across all states).
policy = np.empty((gridworld.world_height, gridworld.world_width), dtype=object)
for cell in np.ndindex(policy.shape):
    policy[cell] = {'up':0.25, 'down':0.25, 'right':0.25, 'left':0.25}
gridworld.policy_iteration(0.9, 0.0001, policy)
# Show the value function of the converged policy (policy itself was updated in place).
gridworld.values

In [93]:
# Policy iteration on the same 4x4 grid, starting from the deterministic "always up" policy.
gridworld = GridWorld(4,4,[(0,0), (3,3)])
# One dict per cell (np.full would alias a single dict across all states).
policy = np.empty((gridworld.world_height, gridworld.world_width), dtype=object)
for cell in np.ndindex(policy.shape):
    policy[cell] = {'up':1.0, 'down':0.0, 'right':0.0, 'left':0.0}
gridworld.policy_iteration(0.9, 0.0001, policy)
# Show the value function of the converged policy (policy itself was updated in place).
gridworld.values