In [1]:
import numpy as np


class GridWorld:
    """Rectangular grid world whose state values are found by repeated Bellman sweeps.

    States are ``[row, col]`` cells. Every one of the four moves in
    ``self.actions`` is weighted by the same probability ``prob``. A move that
    would leave the grid keeps the agent in place with reward -1; any other
    ordinary move has reward 0. Two special cells (see ``_SPECIAL_TRANSITIONS``)
    ignore the chosen action and jump to a fixed target with a positive reward.

    NOTE: ``find_v_pi`` indexes ``self.v_pi`` at the jump targets, so a grid that
    contains a special cell but not its target (fewer than 5 rows with at least
    2 columns, for instance) raises ``IndexError`` there.
    """

    # source cell -> (target cell, reward); applies to every action taken from
    # the source cell.
    _SPECIAL_TRANSITIONS = {
        (0, 1): ((4, 1), 10),
        (0, 3): ((2, 3), 5),
    }
    # Reward for an ordinary move that stays on the grid / tries to leave it.
    _STEP_REWARD = 0
    _OFF_GRID_REWARD = -1

    def __init__(self, no_of_rows, no_of_cols, discount, prob):
        """Store the grid dimensions and evaluation parameters.

        Args:
            no_of_rows: Number of grid rows.
            no_of_cols: Number of grid columns.
            discount: Discount rate applied to the value of the next state.
            prob: Probability weight given to each of the four actions.
        """
        self.no_of_rows = no_of_rows
        self.no_of_cols = no_of_cols
        self.discount_rate = discount
        self.action_prob = prob
        # Current value estimate, one entry per cell, starting at zero.
        self.v_pi = np.zeros((no_of_rows, no_of_cols))
        self.no_of_states = no_of_rows * no_of_cols
        # Moves as [row_delta, col_delta]: left, up, right, down when row 0 is
        # drawn at the top.
        self.actions = [[0, -1], [-1, 0], [0, 1], [1, 0]]

    def find_v_pi(self, tol=0.0004, max_iterations=None):
        """Sweep the Bellman update over all cells until the values settle.

        Each sweep builds a fresh table from the previous one. When the summed
        absolute difference between the two tables drops below ``tol``, the
        previous table (still held in ``self.v_pi``) is printed and the method
        returns; otherwise the fresh table replaces ``self.v_pi``.

        Args:
            tol: Convergence threshold on the summed absolute change.
            max_iterations: Optional cap on the number of sweeps. ``None``
                (the default) sweeps until convergence, however long it takes.

        Raises:
            RuntimeError: If ``max_iterations`` sweeps finish without meeting
                the threshold.
        """
        sweeps = 0
        while max_iterations is None or sweeps < max_iterations:
            v_pi_new = np.zeros(self.v_pi.shape)
            for i in range(self.no_of_rows):
                for j in range(self.no_of_cols):
                    for action in self.actions:
                        next_state, reward = self.take_step([i, j], action)
                        # Weighted one-step backup through this action.
                        v_pi_new[i, j] += self.action_prob * (
                            reward
                            + self.discount_rate * self.v_pi[next_state[0], next_state[1]]
                        )
            sweeps += 1
            if np.sum(np.abs(self.v_pi - v_pi_new)) < tol:
                print(self.v_pi)
                return
            self.v_pi = v_pi_new
        raise RuntimeError(
            "find_v_pi did not converge within {} iterations".format(max_iterations)
        )

    def take_step(self, state, action):
        """Return ``(next_state, reward)`` for taking ``action`` in ``state``.

        Args:
            state: Indexable ``(row, col)`` pair.
            action: Indexable ``(row_delta, col_delta)`` pair.

        Returns:
            A two-item tuple: the next state as a new ``[row, col]`` list (never
            the ``state`` object that was passed in) and the numeric reward.
        """
        row, col = state[0], state[1]

        special = self._SPECIAL_TRANSITIONS.get((row, col))
        if special is not None:
            target, reward = special
            return list(target), reward

        new_row = row + action[0]
        new_col = col + action[1]
        if 0 <= new_row < self.no_of_rows and 0 <= new_col < self.no_of_cols:
            return [new_row, new_col], self._STEP_REWARD
        # Off-grid move: stay put, but hand back a copy so the caller's object
        # is not aliased and the return type is always a list.
        return [row, col], self._OFF_GRID_REWARD


In [2]:
# 5x5 grid, discount rate 0.9, each of the four actions weighted 0.25.
gridworld_obj = GridWorld(no_of_rows=5, no_of_cols=5, discount=0.9, prob=0.25)
# Iterates until the value table converges, then prints it.
gridworld_obj.find_v_pi()

[[ 3.30914382  8.78943917  4.42776616  5.32251425  1.49232523]
 [ 1.52173554  2.99246515  2.25028693  1.90771837  0.54754919]
 [ 0.05096995  0.73831787  0.67326024  0.3583329  -0.40299464]
 [-0.97344484 -0.43534815 -0.35473528 -0.58545839 -1.18292857]
 [-1.85755309 -1.34508398 -1.22912027 -1.42277145 -1.97503253]]
