In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [134]:
class GridWorld:
    """Rectangular grid world with deterministic moves, for iterative policy evaluation.

    States are ``(row, col)`` index tuples into ``self.values``, which has shape
    ``(h, w)``. Row 0 is the top row as printed, so ``'up'`` decreases the row
    index and ``'right'`` increases the column index. From a non-terminal state,
    every action earns reward -1; a move that would leave the grid keeps the
    agent in place. Terminal states are absorbing with reward 0.
    """

    # (row offset, column offset) for each supported action.
    _MOVES = {
        'up': (-1, 0),
        'down': (1, 0),
        'left': (0, -1),
        'right': (0, 1),
    }

    def __init__(self, w, h, ts):
        """Create a grid ``w`` columns wide and ``h`` rows tall.

        Args:
            w: number of columns.
            h: number of rows.
            ts: collection of terminal states as ``(row, col)`` tuples
                (membership is tested with ``in``).
        """
        self.world_height = h
        self.world_width = w
        # State-value estimates, indexed as values[row, col]; start at zero.
        self.values = np.zeros((h, w))
        self.terminal_states = ts

    def transition_to_next_state(self, state, action, values):
        """Apply ``action`` in ``state`` and look up the successor's value.

        Args:
            state: ``(row, col)`` tuple.
            action: one of ``'up'``, ``'down'``, ``'left'``, ``'right'``.
            values: array of shape ``(h, w)`` to read the successor's value from.

        Returns:
            ``(value_of_next_state, reward, transition_probability)``. The
            transition probability is always 1.0 because moves are deterministic.

        Raises:
            ValueError: if ``state`` is non-terminal and ``action`` is unknown.
        """
        row, col = state[0], state[1]

        if state in self.terminal_states:
            # Absorbing state: the agent stays put and collects no reward.
            return values[row, col], 0.0, 1.0

        if action not in self._MOVES:
            raise ValueError('unknown action: {!r}'.format(action))
        d_row, d_col = self._MOVES[action]
        next_row, next_col = row + d_row, col + d_col

        # Rows are bounded by the height and columns by the width; bumping
        # into the boundary leaves the agent where it was.
        inside = 0 <= next_row < self.world_height and 0 <= next_col < self.world_width
        if not inside:
            next_row, next_col = row, col

        return values[next_row, next_col], -1.0, 1.0

    def policy_evaluation(self, discount, accuracy_threshold, policy, max_iterations=None):
        """Evaluate ``policy`` with synchronous iterative sweeps over all states.

        Each sweep computes every state's new value from a frozen copy of the
        previous sweep's values. ``self.values`` is updated in place, starting
        from whatever it already holds (zeros on a fresh instance). The rounded
        result is printed, and the number of sweeps is returned.

        Args:
            discount: discount factor applied to the successor's value.
            accuracy_threshold: stop after the first sweep whose largest
                per-state change is <= this value. At least one sweep always runs.
            policy: indexable by ``(row, col)`` (e.g. a numpy object array or a
                dict); each entry maps action name -> probability of taking it.
            max_iterations: optional cap on the number of sweeps. ``None`` (the
                default) means no cap. Useful when ``discount == 1`` and the policy
                never reaches a terminal state, in which case values diverge.

        Returns:
            Number of sweeps performed.
        """
        steps = 0
        while max_iterations is None or steps < max_iterations:
            previous_values = np.copy(self.values)
            for state in np.ndindex(previous_values.shape):
                expected_value = 0.0
                for action, probability_of_action in policy[state].items():
                    value_on_next_state, reward, transition_probability = (
                        self.transition_to_next_state(state, action, previous_values)
                    )
                    expected_value += probability_of_action * transition_probability * (
                        reward + discount * value_on_next_state
                    )
                # Write once per state, after all actions have been summed.
                self.values[state] = expected_value
            steps += 1
            # initial=0.0 keeps an empty (0-sized) grid from raising in amax.
            delta = np.amax(np.absolute(previous_values - self.values), initial=0.0)
            if delta <= accuracy_threshold:
                break

        print(np.round(self.values, 1))
        return steps

In [137]:
# 4x4 world with terminal states in the top-left and bottom-right corners,
# evaluated under the equiprobable random policy with no discounting.
# The cell's displayed value is the number of sweeps until convergence.
equiprobable_moves = {'up': 0.25, 'down': 0.25, 'right': 0.25, 'left': 0.25}
gridworld = GridWorld(4, 4, [(0, 0), (3, 3)])
policy = np.full(
    (gridworld.world_height, gridworld.world_width),
    equiprobable_moves,
)
gridworld.policy_evaluation(
    discount=1.0,
    accuracy_threshold=0.001,
    policy=policy,
)

[[  0. -14. -20. -22.]
 [-14. -18. -20. -20.]
 [-20. -20. -18. -14.]
 [-22. -20. -14.   0.]]


131