In [19]:
import numpy as np
from random import randint
import random

class EnvGrid(object):
    """Small 3x3 grid world with one agent position.

    ``grid[y][x]`` holds the value returned by :meth:`step` when the agent
    lands on that cell. The cell holding ``1`` is the terminal cell checked
    by :meth:`is_finished`. States are reported as integers
    ``y * 3 + x + 1``, i.e. 1..9 in row-major order.
    """

    def __init__(self):
        super(EnvGrid, self).__init__()

        # Per-cell values, indexed as grid[y][x].
        self.grid = [[0, 0, 1], [0, -1, 0], [0, 0, 0]]

        # Starting position: bottom-left corner.
        self.y, self.x = 2, 0

        # (dy, dx) offsets indexed by action: 0 up, 1 down, 2 left, 3 right.
        self.actions = [[-1, 0], [1, 0], [0, -1], [0, 1]]

    def _state_id(self):
        """Return the 1-based row-major index of the current cell."""
        return self.y * 3 + self.x + 1

    def reset(self):
        """Move the agent back to the start cell and return its state id."""
        self.y, self.x = 2, 0
        return self._state_id()

    def step(self, action):
        """
            Apply one move and return ``(state_id, cell_value)``.

            Action: 0: UP, 1: DOWN, 2:LEFT, 3:RIGHT

            Moves that would leave the grid are clamped, so the agent stays
            on the border cell.
        """
        dy, dx = self.actions[action][0], self.actions[action][1]
        # Clamp each coordinate into the valid index range 0..2.
        self.y = min(max(self.y + dy, 0), 2)
        self.x = min(max(self.x + dx, 0), 2)

        cell_value = self.grid[self.y][self.x]
        return self._state_id(), cell_value

    def is_finished(self):
        """Return True when the agent stands on the cell whose value is 1."""
        cell_value = self.grid[self.y][self.x]
        return cell_value == 1

def take_action(st, Q, eps):
    """Pick an action index for state ``st`` with an epsilon-greedy rule.

    Args:
        st: Index of the current state's row in ``Q``.
        Q: Table indexed as ``Q[state][action]`` holding action values.
        eps: Exploration probability. A uniform draw from [0, 1] below
            ``eps`` triggers a random action; otherwise the greedy action
            is taken.

    Returns:
        int: The chosen action index, always within ``range(len(Q[st]))``.
        Ties in the greedy branch resolve to the lowest index.
    """
    if random.uniform(0, 1) < eps:
        # Explore: size the draw from the row itself rather than assuming
        # four actions, so the result is always a valid column of Q[st].
        return randint(0, len(Q[st]) - 1)
    # Exploit: convert NumPy's integer to a plain int so both branches
    # return the same type.
    return int(np.argmax(Q[st]))


if __name__ == '__main__':
    # Train a tabular action-value function on EnvGrid, then print it.
    n_episodes = 100
    explore_rate = 0.4   # epsilon used while acting
    learning_rate = 0.1  # step size of each Q update
    discount = 0.9       # weight of the next state's value

    env = EnvGrid()
    st = env.reset()

    # Ten rows of four action values each; the report below prints rows 1..9.
    Q = [[0, 0, 0, 0] for _ in range(10)]

    for _ in range(n_episodes):
        st = env.reset()

        # Run one episode until the environment reports the terminal cell.
        while not env.is_finished():
            at = take_action(st, Q, explore_rate)
            next_st, reward = env.step(at)

            # Update Q function: bootstrap from the greedy action (eps = 0.0)
            # in the next state.
            greedy_next = take_action(next_st, Q, 0.0)
            target = reward + discount * Q[next_st][greedy_next]
            Q[st][at] = Q[st][at] + learning_rate * (target - Q[st][at])

            st = next_st

    for s in range(1, 10):
        print(s, Q[s])



1 [0.8099992664170232, 0.7289997491844794, 0.8099997441178153, 0.899999999999999]
2 [0.8999981682683241, -0.19000189215175234, 0.8099996487165364, 0.9999999999999996]
3 [0, 0, 0, 0]
4 [0.8099999999999987, 0.6560998789040847, 0.7289998912219666, -0.1900000855877098]
5 [0.8999999999131182, 0.5752583465067547, 0.6937665105851085, 0.6927115506010989]
6 [0.94185026299696, 0.06665282821606144, -0.052135043148348076, 0.3042485050887772]
7 [0.7289999999999983, 0.656099955163494, 0.6560999705547628, 0.5904897814072181]
8 [-0.1950466310684683, 0.5051448677757148, 0.6560999911311041, 0.3945162195474043]
9 [0.168809082003429, 0.0827616557752354, 0.5460124827190439, 0.121877643207849]
