In [201]:
from enum import Enum
from tabulate import tabulate
from random import choice


In [221]:
class Action(Enum):
    """The four moves available to the agent, with integer values 0..3."""

    # Tuple unpacking assigns UP=0, DOWN=1, LEFT=2, RIGHT=3, in this order.
    UP, DOWN, LEFT, RIGHT = range(4)

class GridWorld:
    """Square grid world with an agent at one cell.

    The grid is ``size`` x ``size``. The top-left cell ``[0, 0]`` and the
    bottom-right cell ``[size-1, size-1]`` are terminal.
    """

    def __init__(self, size=4, cur_state=(0, 1)):
        """Create a grid world.

        Args:
            size: Side length of the square grid.
            cur_state: Starting ``[row, col]`` position. Any indexable
                2-element sequence is accepted. It is copied, so neither the
                caller's object nor the default is mutated by ``step``.
        """
        self._size = size
        # Copy into a fresh list: step() mutates the position in place (so a
        # shared default or caller-owned list would be corrupted), and
        # _is_final_state() compares against lists (so tuples must be converted).
        self._cur_state = list(cur_state)

    def __str__(self):
        """Render the grid as an org-mode table with 'x' at the agent's cell."""
        table = [[' ' for _ in range(self._size)] for _ in range(self._size)]
        table[self._cur_state[0]][self._cur_state[1]] = 'x'

        return tabulate(table, tablefmt='orgtbl')

    def step(self, action):
        """Move the agent one cell in the direction of ``action``.

        Moves that would leave the grid are clamped, so the agent stays in
        place against a wall. Once the agent is on a terminal cell, further
        calls do not move it.

        Args:
            action: An ``Action`` member. Any other value leaves the position
                unchanged.

        Returns:
            0 if the agent was already on a terminal cell, otherwise -1.
        """
        if self._is_final_state():
            return 0

        if action == Action.UP:
            self._cur_state[0] = max(0, self._cur_state[0] - 1)
        elif action == Action.DOWN:
            self._cur_state[0] = min(self._size - 1, self._cur_state[0] + 1)
        elif action == Action.LEFT:
            self._cur_state[1] = max(0, self._cur_state[1] - 1)
        elif action == Action.RIGHT:
            self._cur_state[1] = min(self._size - 1, self._cur_state[1] + 1)

        return -1

    def get_state(self):
        """Return the current cell as a row-major flat index (row * size + col)."""
        return self._cur_state[0] * self._size + self._cur_state[1]

    def _is_final_state(self):
        """Return True if the agent is on the top-left or bottom-right cell."""
        return self._cur_state == [0, 0] or self._cur_state == [self._size-1, self._size-1]

In [231]:
GRID_SIZE = 4
# One entry per Action member. The number of actions is independent of the
# grid size (deriving it from GRID_SIZE breaks for any size other than 4).
action_space = list(Action)
N_SWEEPS = 10                           # number of in-place evaluation sweeps
POLICY_PROB = 1.0 / len(action_space)   # equiprobable random policy

# Iterative evaluation of the uniform random policy (no discounting).
# V is updated in place, so states visited later in a sweep already see the
# values computed earlier in that same sweep.
V = [0.0 for _ in range(GRID_SIZE ** 2)]
for _ in range(N_SWEEPS):
    for i in range(GRID_SIZE):
        for j in range(GRID_SIZE):

            # Terminal corners keep their initial value of 0.
            if (i == 0 and j == 0) or (i == GRID_SIZE-1 and j == GRID_SIZE-1):
                continue

            v = 0
            for action in action_space:
                gw = GridWorld(size=GRID_SIZE, cur_state=[i, j])
                gw.step(action)
                s_next = gw.get_state()
                v += POLICY_PROB * V[s_next]

            # -1 mirrors the reward GridWorld.step returns from a non-terminal cell.
            V[i * GRID_SIZE + j] = -1 + v

In [232]:
# State values as a flat list indexed by row * GRID_SIZE + col.
print(list(V))

[0.0, -7.82506756017392, -11.121818717405404, -12.227655033372685, -7.82506756017392, -10.420371658008662, -11.769337388334861, -11.8633986227901, -11.121818717405404, -11.769337388334861, -11.053818103924186, -8.81387888655695, -12.227655033372685, -11.8633986227901, -8.81387888655695, 0.0]


In [170]:
# Fresh environment for the manual checks below, starting at row 0, column 1.
gw = GridWorld(cur_state=[0, 1], size=GRID_SIZE)

In [171]:
# Step LEFT and show the flat state index before and after the move.
print(gw.get_state())
gw.step(Action.LEFT)
print(gw.get_state())

1
0


In [168]:
# Look up an Action member by its integer value.
Action(value=2)

<Action.LEFT: 2>

0


In [156]:
# Render the grid; 'x' marks the agent's current cell.
print(str(gw))

| x |  |  |  |
|   |  |  |  |
|   |  |  |  |
|   |  |  |  |
