# In [3]:
import numpy as np
from random import randint
import random

class EnvGrid(object):
    """Small deterministic grid world for tabular reinforcement learning.

    Each cell of ``grid`` holds the reward that ``step`` returns when the
    agent ends a move on that cell. The episode is over once the agent
    stands on a cell whose value is ``1`` (see ``is_finished``); any other
    value, such as the ``-1`` cell of the default layout, only affects the
    reward and can be walked through.

    Positions are exposed as integer states ``y * n_cols + x + 1``, so the
    default 3x3 grid yields states 1..9 (0 is never produced).

    Args:
        grid: Rectangular list of rows of rewards. Defaults to the original
            3x3 layout. Rows are copied, so later changes to the caller's
            list do not affect the environment.
        start: ``(row, col)`` the agent is placed on by ``reset``. Defaults
            to ``(2, 0)``, the bottom-left corner of the default grid.

    Raises:
        ValueError: if ``grid`` is empty or ragged, or ``start`` lies
            outside it.
    """

    def __init__(self, grid=None, start=(2, 0)):
        super(EnvGrid, self).__init__()

        if grid is None:
            grid = [
                [0, 0, 1],
                [0, -1, 0],
                [0, 0, 0]
            ]
        # Copy row by row so environments never share (or mutate) a caller's list.
        self.grid = [list(row) for row in grid]
        if not self.grid or not self.grid[0]:
            raise ValueError("grid must have at least one row and one column")
        if any(len(row) != len(self.grid[0]) for row in self.grid):
            raise ValueError("grid must be rectangular")

        start_y, start_x = start
        if not (0 <= start_y < len(self.grid)
                and 0 <= start_x < len(self.grid[0])):
            raise ValueError("start must lie inside the grid")
        self.start = (start_y, start_x)

        # Starting position
        self.y, self.x = self.start

        # (row offset, column offset) for each action id.
        self.actions = [
            [-1, 0],  # Up
            [1, 0],   # Down
            [0, -1],  # Left
            [0, 1],   # Right
        ]

    def _state(self):
        """Encode the current (y, x) position as a 1-based integer state."""
        return self.y * len(self.grid[0]) + self.x + 1

    def reset(self):
        """Put the agent back on the start cell and return its state."""
        self.y, self.x = self.start
        return self._state()

    def step(self, action):
        """Move the agent and return ``(next_state, reward)``.

        Action: 0: UP, 1: DOWN, 2: LEFT, 3: RIGHT. A move that would leave
        the grid is clipped, so the agent stays on the border cell and
        receives that cell's value again.
        """
        dy, dx = self.actions[action]
        # Clip to the grid's actual bounds instead of a fixed 3x3 size.
        self.y = max(0, min(self.y + dy, len(self.grid) - 1))
        self.x = max(0, min(self.x + dx, len(self.grid[0]) - 1))

        return self._state(), self.grid[self.y][self.x]

    def is_finished(self):
        """Return True when the agent stands on a goal cell (value ``1``)."""
        return self.grid[self.y][self.x] == 1

def take_action(st, Q, eps):
    """Choose an action for state ``st`` with an epsilon-greedy policy.

    Args:
        st: State index used to select the row ``Q[st]``.
        Q: Table of action values; ``Q[st]`` is a sequence with one value
            per action.
        eps: Probability of taking a uniformly random action instead of a
            greedy one. ``0.0`` means purely greedy.

    Returns:
        An action index in ``range(len(Q[st]))``.
    """
    q_values = Q[st]
    n_actions = len(q_values)

    if random.uniform(0, 1) < eps:
        # Explore: range derived from the table instead of a hard-coded 4 actions.
        return randint(0, n_actions - 1)

    # Exploit: break ties at random. np.argmax would always return the first
    # maximum, so an all-zero row would always yield action 0 (Up).
    best = max(q_values)
    candidates = [a for a, q in enumerate(q_values) if q == best]
    return random.choice(candidates)


if __name__ == '__main__':
    env = EnvGrid()

    # Hyper-parameters. Learning rate (alpha) and exploration rate (eps)
    # restart at 1 every episode and decay geometrically with each step.
    n_episodes = 100
    gamma = 0.9        # discount factor
    decay = 0.99       # per-step multiplicative decay of alpha and eps
    max_steps = 10000  # safety cap so a stuck greedy policy cannot loop forever

    # Tabular Q function: one row per state index returned by the environment
    # (row 0 stays unused because states start at 1), one column per action.
    n_states = len(env.grid) * len(env.grid[0]) + 1
    n_actions = len(env.actions)
    Q = [[0] * n_actions for _ in range(n_states)]

    for _ in range(n_episodes):
        alpha = 1
        eps = 1
        st = env.reset()
        steps = 0
        while not env.is_finished() and steps < max_steps:
            at = take_action(st, Q, eps)
            stp1, r = env.step(at)

            # Q-learning target: bootstrap from the greedy value of the next
            # state, except when that state ends the episode.
            if env.is_finished():
                target = r
            else:
                target = r + gamma * max(Q[stp1])
            Q[st][at] = Q[st][at] + alpha * (target - Q[st][at])

            st = stp1
            alpha = decay * alpha
            eps = decay * eps
            steps += 1

    for s in range(1, n_states):
        print(s, Q[s])



# Output of the cell above (state index, then its Q-values for Up, Down, Left, Right):
# 1 [0.81, 0.7290000000000001, 0.81, 0.9]
# 2 [0.9, -0.18999999999999995, 0.81, 1.0]
# 3 [0, 0, 0, 0]
# 4 [0.81, 0.6561000000000001, 0.7290000000000001, -0.18999999999999995]
# 5 [0.9, 0.7290000000000001, 0.7290000000000001, 0.9]
# 6 [1.0, 0.81, -0.18999999999999995, 0.9]
# 7 [0.7290000000000001, 0.6561000000000001, 0.6561000000000001, 0.7290000000000001]
# 8 [-0.18999999999999995, 0.7290000000000001, 0.6561000000000001, 0.81]
# 9 [0.9, 0.81, 0.7290000000000001, 0.81]
