In [61]:
from enum import Enum
from tabulate import tabulate
from random import choice
from copy import deepcopy
import math

In [62]:
class Action(Enum):
    """The four moves an agent can make on the grid."""

    # Unpacking range(4) yields the members in order with the values
    # UP=0, DOWN=1, LEFT=2, RIGHT=3.
    UP, DOWN, LEFT, RIGHT = range(4)

class GridWorld:
    """Square grid in which an agent moves one cell at a time.

    The position is kept as a ``[row, column]`` list. The top-left cell
    ``[0, 0]`` and the bottom-right cell ``[size-1, size-1]`` are final
    states: once the agent is there, ``step`` no longer moves it.
    """

    def __init__(self, size=4, cur_state=(0, 1)):
        """Create a ``size`` x ``size`` grid with the agent at ``cur_state``.

        Args:
            size: Number of rows (and columns) of the grid.
            cur_state: Starting ``(row, column)`` pair; any two-item sequence.
        """
        self._size = size
        # step() mutates the position in place, so keep a private copy.
        # Storing the argument itself would let every default-constructed
        # world share one position and would modify the caller's list.
        self._cur_state = list(cur_state)

    def __str__(self):
        """Render the grid as a table with an 'x' in the agent's cell."""
        table = [[' ' for _ in range(self._size)] for _ in range(self._size)]
        table[self._cur_state[0]][self._cur_state[1]] = 'x'

        return tabulate(table, tablefmt='orgtbl')

    def step(self, action):
        """Apply ``action`` to the agent and return the reward.

        Moves that would leave the grid are clamped, so the agent stays on
        the border cell. A value that is not one of the four ``Action``
        members leaves the position unchanged.

        Returns:
            0 if the agent was already in a final state (nothing happens),
            otherwise -1.
        """
        if self._is_final_state():
            return 0

        last = self._size - 1  # highest valid row/column index

        if action == Action.UP:
            self._cur_state[0] = max(0, self._cur_state[0] - 1)
        elif action == Action.DOWN:
            self._cur_state[0] = min(last, self._cur_state[0] + 1)
        elif action == Action.LEFT:
            self._cur_state[1] = max(0, self._cur_state[1] - 1)
        elif action == Action.RIGHT:
            self._cur_state[1] = min(last, self._cur_state[1] + 1)

        return -1

    def get_state(self):
        """Return the position as a single row-major index: row * size + column."""
        row, col = self._cur_state
        return row * self._size + col

    def _is_final_state(self):
        """Tell whether the agent is in the top-left or bottom-right corner."""
        last = self._size - 1
        return self._cur_state == [0, 0] or self._cur_state == [last, last]

In [64]:
# Iterative policy evaluation on the grid world: every action is weighted
# equally and every move out of a non-final cell costs -1.
GRID_SIZE = 4
N_SWEEPS = 1000                          # number of full sweeps over the grid
REPORT_SWEEPS = (0, 1, 2, 3, 10, 999)    # sweeps after which the table is printed

# Take the actions from the enum itself. Sizing this list with GRID_SIZE was
# only correct because a 4x4 grid happens to have as many rows as there are
# actions; a smaller grid silently dropped actions and a larger one raised
# ValueError.
action_space = list(Action)
action_prob = 1 / len(action_space)      # each action is equally likely


V = [0.0 for _ in range(GRID_SIZE ** 2)]
V_save = list(V)  # values from the previous sweep; V is a flat list of floats
for k in range(N_SWEEPS):
    for i in range(GRID_SIZE):
        for j in range(GRID_SIZE):

            # The two final corners keep their value of 0.
            if (i == 0 and j == 0) or (i == GRID_SIZE-1 and j == GRID_SIZE-1):
                continue

            # Expected value of the successor cell, read from the previous
            # sweep so that updates within a sweep do not affect each other.
            v = 0
            for action in action_space:
                gw = GridWorld(size=GRID_SIZE, cur_state=[i, j])
                gw.step(action)
                v += action_prob * V_save[gw.get_state()]

            V[i * GRID_SIZE + j] = -1 + v
    V_save = list(V)

    if k in REPORT_SWEEPS:
        print('k = ', k)
        # Reshape the flat value list into rows and round for display only.
        table = [V[row * GRID_SIZE:(row + 1) * GRID_SIZE] for row in range(GRID_SIZE)]
        table = [[round(value, 1) for value in line] for line in table]
        print(tabulate(table, tablefmt='orgtbl'))

k =  0
|  0 | -1 | -1 | -1 |
| -1 | -1 | -1 | -1 |
| -1 | -1 | -1 | -1 |
| -1 | -1 | -1 |  0 |
k =  1
|  0   | -1.8 | -2   | -2   |
| -1.8 | -2   | -2   | -2   |
| -2   | -2   | -2   | -1.8 |
| -2   | -2   | -1.8 |  0   |
k =  2
|  0   | -2.4 | -2.9 | -3   |
| -2.4 | -2.9 | -3   | -2.9 |
| -2.9 | -3   | -2.9 | -2.4 |
| -3   | -2.9 | -2.4 |  0   |
k =  3
|  0   | -3.1 | -3.8 | -4   |
| -3.1 | -3.7 | -3.9 | -3.8 |
| -3.8 | -3.9 | -3.7 | -3.1 |
| -4   | -3.8 | -3.1 |  0   |
k =  10
|  0   | -6.6 | -9   | -9.7 |
| -6.6 | -8.3 | -9   | -9   |
| -9   | -9   | -8.3 | -6.6 |
| -9.7 | -9   | -6.6 |  0   |
k =  999
|   0 | -14 | -20 | -22 |
| -14 | -18 | -20 | -20 |
| -20 | -20 | -18 | -14 |
| -22 | -20 | -14 |   0 |


In [None]:
# A few values in the tables above differ from the book's because of decimal rounding.