# SARSA & Q-learning


In [None]:
"""
Multi-Agent Systems - Assignment 6
Tiddo Loos
Date: 08-01-2020
"""

import matplotlib.pylab as plt
import numpy as np;np.random.seed(0)
import seaborn as sns;sns.set_theme()
import random
import operator


class State:
    """One cell of the grid world, used by the random-policy value sweep.

    Attributes:
        y, x: Row and column index of this cell inside ``grid.states``.
        grid: Owning grid; this class reads its ``states``, ``size``, ``cost``,
            ``reward_treasure`` and ``reward_snakepit`` attributes.
        neighbours: Successor cells; holds exactly four entries once
            ``get_neighbours`` has run.
        value: Current value estimate of the cell (starts at 0).
        wall, treasure, snakepit: Flags describing the cell type.
    """

    def __init__(self, y, x, grid, wall=False, treasure=False, snakepit=False):
        self.y = y
        self.x = x
        self.grid = grid
        self.neighbours = []
        self.value = 0
        self.wall = wall
        self.treasure = treasure
        self.snakepit = snakepit

    def get_neighbours(self, state=None):
        """Rebuild ``self.neighbours`` with the four successors of this cell.

        Moves that would leave the grid or enter a wall keep the agent in
        place, so every blocked direction is represented by this cell itself.

        Args:
            state: Ignored; kept so existing ``state.get_neighbours(state)``
                calls keep working. The cell's own coordinates are used.
        """
        rows = self.grid.states
        last_index = self.grid.size - 1
        y, x = self.y, self.x

        candidates = []
        if y >= 1:
            candidates.append(rows[y - 1][x])
        if y < last_index:
            candidates.append(rows[y + 1][x])
        if x >= 1:
            candidates.append(rows[y][x - 1])
        if x < last_index:
            candidates.append(rows[y][x + 1])

        # Filter into a new list: removing items from a list while iterating
        # over it skips the element that follows each removal.
        reachable = [cell for cell in candidates if not cell.wall]
        # Pad with the current cell, one entry per blocked direction.
        reachable.extend([rows[y][x]] * (4 - len(reachable)))
        # Assign rather than append so repeated calls do not accumulate.
        self.neighbours = reachable

    def update_state_value(self):
        """Recompute ``self.value`` from the current neighbour values.

        Walls get 0, treasure and snakepit cells get their fixed reward, and
        every other cell gets the average (weight 1/4 each) of the rewards of
        its neighbours.
        """
        if self.wall:
            self.value = 0
        elif self.treasure:
            self.value = self.grid.reward_treasure
        elif self.snakepit:
            self.value = self.grid.reward_snakepit
        else:
            new_value = 0
            for neighbour in self.neighbours:
                new_value += 1/4 * self.get_reward_state(neighbour)
            self.value = new_value

    def get_reward_state(self, neighbour):
        """Return the reward for moving from this cell into ``neighbour``.

        Entering a treasure or snakepit cell yields that cell's fixed reward;
        any other move yields the step cost plus the neighbour's value.
        """
        # Use the cell flags: an identity test (`is`) between a fresh tuple
        # and the grid's position list can never be true.
        if neighbour.treasure:
            return self.grid.reward_treasure
        if neighbour.snakepit:
            return self.grid.reward_snakepit
        return self.grid.cost + neighbour.value


class Agent:
    """Agent that walks one episode over the grid and updates its Q-values.

    Attributes:
        pos_y, pos_x: Current row and column of the agent.
        agent_reward: Sum of the rewards collected during the episode.
        reward_list: Initialised empty; not written by this class.
        grid: Grid the agent lives on; this class reads its ``q_value``,
            ``walls``, ``size``, ``treasure``, ``snakepit`` and ``sarsa``
            attributes and calls ``grid_reward`` and ``update_q_values``.
    """

    # (row delta, column delta) per action; row 0 is the top row, so moving
    # north decreases the row index.
    _MOVES = {
        'north': (-1, 0),
        'south': (1, 0),
        'east': (0, 1),
        'west': (0, -1),
    }

    def __init__(self, pos_y, pos_x, grid):
        self.pos_y = pos_y
        self.pos_x = pos_x
        self.agent_reward = 0
        self.reward_list = []
        self.grid = grid

    def _at_terminal(self):
        """Return True when the agent stands on any treasure or snakepit cell."""
        position = (self.pos_y, self.pos_x)
        return position in self.grid.treasure or position in self.grid.snakepit

    def walk_agent(self):
        """Run one episode until a terminal cell is reached.

        After each step the grid's Q-table is updated for the action taken.
        When ``grid.sarsa`` is true the update uses the action that will be
        executed next; otherwise it uses a separately selected action.
        """
        grid = self.grid
        action_one = self.e_greedy(grid.q_value)
        # Membership test instead of comparing against a one-element list, so
        # grids with several treasures or snakepits also terminate.
        while not self._at_terminal():
            current_y, current_x = self.pos_y, self.pos_x
            self.action(current_y, current_x, action_one, grid.walls, grid.size)
            reward = grid.grid_reward(self.pos_y, self.pos_x)
            self.agent_reward += reward
            new_action = self.e_greedy(grid.q_value)
            # sarsa or q learning
            if grid.sarsa:
                update_action = new_action
            else:
                update_action = self.e_greedy(grid.q_value)
            grid.update_q_values(current_y, current_x, action_one, reward,
                                 self.pos_y, self.pos_x, update_action)
            action_one = new_action

    def e_greedy(self, q_value):
        """Choose an action for the current position from ``q_value``.

        With probability epsilon a random action name is returned, otherwise
        the action with the highest Q-value (first one on ties). Epsilon is
        0.1 when ``grid.sarsa`` is true and 0 otherwise.
        """
        epsilon = 0.1 if self.grid.sarsa else 0
        action_values = q_value[self.pos_y][self.pos_x]
        if np.random.uniform(0, 1) < epsilon:
            return np.random.choice(list(action_values.keys()))
        return max(action_values, key=action_values.get)

    def action(self, y, x, action, walls, size):
        """Try to move one step from ``(y, x)`` in direction ``action``.

        Args:
            y, x: Position the move starts from.
            action: One of 'north', 'south', 'east', 'west'.
            walls: Collection of blocked ``(y, x)`` positions.
            size: Side length of the grid.

        Raises:
            ValueError: If ``action`` is not a known direction.
        """
        try:
            delta_y, delta_x = self._MOVES[action]
        except KeyError:
            raise ValueError(f"unknown action: {action!r}") from None
        self.update_position([y, x], (y + delta_y, x + delta_x), walls, size)

    def update_position(self, current_state, next_state, walls, size):
        """Move to ``next_state`` unless it is a wall or outside the grid."""
        row, col = next_state[0], next_state[1]
        blocked = (
            next_state in walls
            or not (0 <= row < size)
            or not (0 <= col < size)
        )
        new_pos = current_state if blocked else next_state
        self.pos_y = new_pos[0]
        self.pos_x = new_pos[1]


class Grid:
    """Square grid world holding state values, a Q-table and plotting helpers.

    Constructing a grid does all the work: it builds the cells, runs the
    value sweeps (``init_grid``), then trains the Q-table by walking agents
    over the grid (``create_agent_on_grid``).

    Args:
        walls: List of blocked ``(y, x)`` positions.
        treasure: List of ``(y, x)`` positions that end an episode with
            ``reward_treasure``.
        snakepit: List of ``(y, x)`` positions that end an episode with
            ``reward_snakepit``.
        size: Side length of the grid.
        reward_treasure: Reward for entering a treasure cell.
        reward_snakepit: Reward for entering a snakepit cell.
        cost: Reward for every other step.
        alfa: Learning rate used in ``update_q_values``.
        gamma: Discount factor used in ``update_q_values``.
        sarsa: True for SARSA; False (default) for Q-learning.
    """

    def __init__(self, walls, treasure, snakepit, size, reward_treasure, reward_snakepit, cost, alfa, gamma, sarsa=False):
        self.states = []
        self.walls = walls
        self.treasure = treasure
        self.snakepit = snakepit
        self.size = size
        self.reward_snakepit = reward_snakepit
        self.reward_treasure = reward_treasure
        self.cost = cost
        self.values_list = []
        self.init_grid()
        # assignment 2: sarsa/q_learning
        self.alfa = alfa
        self.gamma = gamma
        self.sarsa = sarsa
        # One dict of action -> Q-value per cell; the key order decides which
        # action wins ties in max().
        self.q_value = [
            [{'north': 0., 'south': 0., 'east': 0., 'west': 0.} for _ in range(self.size)]
            for _ in range(self.size)
        ]
        self.pos_y = []
        self.pos_x = []
        self.policy_y = []
        self.policy_sarsa_x = []
        self.create_agent_on_grid()

    def init_grid(self):
        """Create all cells, link their neighbours and run 500 value sweeps."""
        self.fill_grid()
        for row in self.states:
            for state in row:
                state.get_neighbours(state)
        for _ in range(500):
            self.get_grid_values()

    def fill_grid(self):
        """Populate ``self.states`` with one ``State`` per grid position."""
        # Start from an empty list so a repeated call does not stack rows.
        self.states = []
        for y in range(self.size):
            row = []
            for x in range(self.size):
                position = (y, x)
                if position in self.walls:
                    row.append(State(y, x, self, wall=True))
                elif position in self.treasure:
                    row.append(State(y, x, self, treasure=True))
                elif position in self.snakepit:
                    row.append(State(y, x, self, snakepit=True))
                else:
                    row.append(State(y, x, self))
            self.states.append(row)

    def get_grid_values(self):
        """Run one in-place sweep that updates the value of every cell."""
        for row in self.states:
            for state in row:
                state.update_state_value()

    def append_value_list(self):
        """Store the current cell values, row by row, in ``self.values_list``."""
        # Rebuild instead of appending so that calling this (or show_heatmap)
        # more than once does not duplicate rows.
        self.values_list = [[state.value for state in row] for row in self.states]

    def show_heatmap(self):
        """Plot the cell values as an annotated seaborn heatmap."""
        self.append_value_list()
        sns.heatmap(self.values_list, annot=True, fmt='0.1f')
        plt.show()

    def create_agent_on_grid(self):
        """Train the Q-table with 100 episodes and store the resulting policy.

        Each episode starts from a random non-wall position. The module-level
        ``count`` global is reset and incremented once per episode.
        """
        global count
        count = 0
        for _ in range(100):
            y, x = self.random_position()
            Agent(y, x, self).walk_agent()
            count += 1
        # Clear in place so a repeated call replaces the stored policy
        # instead of appending a second copy to it.
        self.policy_y.clear()
        self.policy_sarsa_x.clear()
        self.save_policy_data(self.policy_y, self.policy_sarsa_x)

    def random_position(self):
        """Return a uniformly drawn ``(y, x)`` position that is not a wall."""
        # Draw first, then test: this also works when ``walls`` is empty,
        # where seeding the loop from walls[0] would raise IndexError.
        while True:
            y = random.randint(0, self.size - 1)
            x = random.randint(0, self.size - 1)
            if (y, x) not in self.walls:
                return (y, x)

    def grid_reward(self, y, x):
        """Return the reward for arriving at position ``(y, x)``."""
        current_pos = (y, x)
        # Membership tests support any number of snakepits and treasures.
        if current_pos in self.snakepit:
            return self.reward_snakepit
        if current_pos in self.treasure:
            return self.reward_treasure
        return self.cost

    def update_q_values(self, current_y, current_x, action, reward, new_y, new_x, update_action):
        """Apply one temporal-difference update to ``Q(current, action)``.

        The target is ``reward + gamma * Q(new, update_action)`` and the
        entry moves towards it by the fraction ``alfa``.
        """
        current_q = self.q_value[current_y][current_x]
        target = reward + self.gamma * self.q_value[new_y][new_x][update_action]
        current_q[action] += self.alfa * (target - current_q[action])

    def save_policy_data(self, policy_y, policy_x):
        """Append one arrow component pair per cell to the given lists.

        For every cell, in row-major order, the action with the highest
        Q-value is turned into an arrow of length 0.3. Cells whose best
        Q-value is exactly 0 get a zero-length arrow.

        Args:
            policy_y: List receiving the vertical arrow components.
            policy_x: List receiving the horizontal arrow components.
        """
        # (vertical, horizontal) arrow components; any action not listed
        # here (i.e. 'south') points down.
        arrows = {
            'north': (0.3, 0),
            'east': (0, 0.3),
            'west': (0, -0.3),
        }
        for row in self.q_value:
            for action_values in row:
                best_action, best_value = max(action_values.items(), key=operator.itemgetter(1))
                if best_value == 0.0:
                    arrow_y, arrow_x = 0, 0
                else:
                    arrow_y, arrow_x = arrows.get(best_action, (-0.3, 0))
                policy_y.append(arrow_y)
                policy_x.append(arrow_x)

    def show_quiver_map(self):
        """Plot the stored policy arrows as a matplotlib quiver plot."""
        pos_y = []
        pos_x = []
        for y, row in enumerate(self.q_value):
            for x in range(len(row)):
                pos_y.append(y)
                pos_x.append(x)
        # Reverse the row coordinates so that row 0 is drawn at the top.
        pos_y_r = pos_y[::-1]
        fig, ax = plt.subplots()
        if self.sarsa:
            title = 'SARSA policies (e=0.1)'
        else:
            title = 'Q-learning policies (e=0)'
        plt.title(title)
        ax.quiver(pos_x, pos_y_r, self.policy_sarsa_x, self.policy_y, scale=5)
        plt.show()


In [None]:
def main():
    """Train a SARSA grid and a Q-learning grid on the same 9x9 world and plot them.

    The SARSA grid shows its value heatmap and its policy quiver plot; the
    Q-learning grid shows only the quiver plot.
    """
    world = {
        'walls': [(7, 1), (7, 2), (7, 3), (7, 4), (5, 6), (4, 6), (3, 6), (2, 6),
                  (1, 6), (1, 5), (1, 4), (1, 3), (1, 2)],
        'treasure': [(8, 8)],
        'snakepit': [(6, 5)],
        'size': 9,
        'reward_treasure': 50,
        'reward_snakepit': -50,
        'cost': -1,
    }
    learning = {'alfa': 0.5, 'gamma': 1}

    # Passing sarsa=True selects the SARSA algorithm.
    sarsa_grid = Grid(sarsa=True, **world, **learning)
    sarsa_grid.show_heatmap()

    sarsa_grid.show_quiver_map()

    # A new grid starts again from zeroed Q-values; without the sarsa flag
    # it uses Q-learning.
    q_learning_grid = Grid(**world, **learning)
    q_learning_grid.show_quiver_map()


# Run both demos (SARSA first, then Q-learning) when this cell is executed.
main()