## Grid World 
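Based on state-space search. Actions are deterministic: each move shifts the agent exactly one cell, unless the target cell is a wall or lies outside the grid, in which case the agent stays where it is. <br><br>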
In the code below, the grid world looks as follows:

<img src="grid_world_value_iter.png" alt="example grid world"/>

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Global Variables
# Cells are (row, col) tuples; row 0 is the top row (moving 'up' decreases the row).
win_state = (0, 3)    # terminal cell rewarded with +1
lose_state = (1, 3)   # terminal cell rewarded with -1
start_state = (2, 0)  # default cell for a freshly created State
walls = [(1, 1), (0, 1)]  # cells a move can never enter
# Flag copied into each State as ``determine``; moves are always deterministic here.
deterministic = True
# Legacy misspelled name, kept as an alias because existing code reads it.
deteminstic = deterministic
grid_rows = 3
grid_cols = 4
uni_actions = ['up', 'down', 'left', 'right']

In [3]:
class State:
    """The agent's current cell in the grid world.

    Wraps a ``(row, col)`` position and provides its reward, its terminal
    status and the cell reached by a move from it.
    """

    def __init__(self, state=start_state):
        self.state = state
        # Only becomes True once ``is_end_fn`` sees a terminal cell.
        self.is_end = False
        # Copy of the module-level flag; this class never reads it again.
        self.determine = deteminstic

    def give_reward(self):
        """Return 1 on ``win_state``, -1 on ``lose_state`` and 0 anywhere else."""
        if self.state == win_state:
            return 1
        if self.state == lose_state:
            return -1
        return 0

    def is_end_fn(self):
        """Set ``is_end`` to True when standing on the win or lose cell."""
        if self.state in (win_state, lose_state):
            self.is_end = True

    def next_position(self, action):
        """Return the cell reached by taking ``action`` from the current cell.

        Moves are deterministic. An unknown action, a step off the grid, or a
        step into a wall all leave the agent on its current cell.
        """
        offsets = {
            'up': (-1, 0),
            'down': (1, 0),
            'left': (0, -1),
            'right': (0, 1),
        }
        if action not in offsets:
            return self.state
        d_row, d_col = offsets[action]
        row = self.state[0] + d_row
        col = self.state[1] + d_col
        candidate = (row, col)
        inside = 0 <= row < grid_rows and 0 <= col < grid_cols
        if inside and candidate not in walls:
            return candidate
        return self.state


class Agent:
    """Epsilon-greedy agent that learns a value for every grid cell.

    Each episode starts from a fresh ``State()`` and runs until a terminal
    cell is reached. The terminal reward is then backed up through the cells
    visited in that episode, most recent first.
    """

    def __init__(self):
        self.actions = ['up', 'down', 'left', 'right']
        # Probability of taking a uniformly random action instead of the greedy one.
        self.exp_rate = 0.3
        # Step size used when moving a cell's value towards its target.
        self.lr = 0.2
        # Cells visited during the current episode, in visiting order.
        self.visited_states = []
        self.State = State()
        # Learned value of every (row, col) cell, all starting at 0.
        self.state_values = {
            (i, j): 0 for i in range(grid_rows) for j in range(grid_cols)
        }

    def choose_action(self):
        """Pick the next action with an epsilon-greedy policy.

        With probability ``exp_rate`` a uniformly random action is returned.
        Otherwise the action whose resulting cell has the highest value is
        returned; ties go to the action listed last in ``self.actions``.

        Returns:
            One of ``self.actions`` (or ``""`` if that list is empty).
        """
        if np.random.uniform(0, 1) <= self.exp_rate:
            return np.random.choice(self.actions)
        # Start below any possible value so that a greedy action is always
        # selected, even when every reachable cell has a negative value.
        best_action = ""
        best_value = float("-inf")
        for act in self.actions:
            value = self.state_values[self.State.next_position(act)]
            if value >= best_value:
                best_action, best_value = act, value
        return best_action

    def go_to_next_state(self, action):
        """Return a new State for the cell reached by ``action``."""
        new_state = self.State.next_position(action)
        return State(state=new_state)

    def reset(self):
        """Start a new episode from the default cell with an empty history."""
        self.State = State()
        self.visited_states = []

    def learn(self, reward):
        """Back the episode's final reward up through the visited cells.

        Cells are processed from the most recently visited to the first. Each
        cell's value moves ``lr`` of the way towards the current target, and
        the updated (unrounded) value becomes the target for the cell visited
        just before it. Stored values are rounded to 3 decimals.
        """
        target = reward
        for state in reversed(self.visited_states):
            current = self.state_values[state]
            target = current + self.lr * (target - current)
            self.state_values[state] = round(target, 3)

    def play(self, rounds=10):
        """Run ``rounds`` episodes, learning from each one as it ends."""
        for _ in range(rounds):
            while True:
                self.visited_states.append(self.State.state)
                self.State.is_end_fn()
                if self.State.is_end:
                    reward = self.State.give_reward()
                    # Terminal cells take their reward directly as their value.
                    self.state_values[self.State.state] = reward
                    self.learn(reward)
                    self.reset()
                    break
                action = self.choose_action()
                self.State = self.go_to_next_state(action)

    def show_values(self):
        """Print the learned value of every cell as a bordered table."""
        for i in range(0, grid_rows):
            print('----------------------------------')
            out = '| '
            for j in range(0, grid_cols):
                out += str(self.state_values[(i, j)]).ljust(6) + ' | '
            print(out)
        print('----------------------------------')

In [4]:
# Train a fresh agent for 300 episodes, then print the learned cell values.
agent = Agent()
agent.play(rounds=300)
agent.show_values()

----------------------------------
| 0.624  | 0      | 0.998  | 1.0    | 
----------------------------------
| 0.627  | 0      | 0.984  | -1.0   | 
----------------------------------
| 0.786  | 0.837  | 0.894  | 0.586  | 
----------------------------------
