# Deep Learning - Tabular Reinforcement Learning
## Grid World

A simple demonstration of tabular reinforcement learning.

Import required packages

In [None]:
import numpy as np
import pandas as pd
import math
import random
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
# Resolution (dots per inch) passed to plt.savefig when a figure is written to disk
saving_dpi = 100
# Seed intended for random number generation
# NOTE(review): not applied anywhere in this part of the file - confirm where it is used
random_seed = 1234
# Presumably the exploration rate handed to the agent's epsilon-greedy policy - confirm at the call site
epsilon = 0.1
# Presumably the number of training episodes to run - confirm at the call site
num_episodes = 350

## Modelling Grid Worlds

Cell class

In [None]:
class Cell:
    """One square of a grid world, together with the moves available from it."""

    def __init__(self, parent, label, position, kind = 'empty'):
        """Constructs a new cell object
        Arguments
            parent: the world to which this cell belongs
            label: a string label for this cell
            position: the (row, col) coordinates of this cell
            kind: a string indicating the kind of cell
        """
        self.parent, self.label = parent, label
        self.position, self.kind = position, kind
        # Filled in later by generate_action_state_pairs
        self.actions = {}

    def generate_action_state_pairs(self, rewards):
        """Work out, for each of the four moves, where the agent lands and what it earns
        Arguments
            rewards: a dictionary mapping a cell kind to the reward for entering that kind of cell

        The result is stored in self.actions as
        {action: {'state': cell, 'reward': number}} with the actions in the
        order up, down, left, right. A move that would leave the grid keeps
        the agent in this cell and returns the reward of this cell's own kind.
        """
        self.actions = {}
        row, col = self.position[0], self.position[1]

        def register(action, can_move, target_row, target_col):
            # Record the outcome of one move, either into a neighbour or into the wall
            if can_move:
                destination = self.parent.find(target_row, target_col)
                outcome = {'state': destination, 'reward': rewards[destination.kind]}
            else:
                outcome = {'state': self, 'reward': rewards[self.kind]}
            self.actions[action] = outcome

        register('up', row > 0, row - 1, col)
        register('down', row < (self.parent.num_rows - 1), row + 1, col)
        register('left', col > 0, row, col - 1)
        register('right', col < (self.parent.num_cols - 1), row, col + 1)

    def to_string(self, details = False):
        """Generate a string representation of the cell
        Arguments
            details: when True the kind of the cell and the label of the cell each action leads to are included, otherwise only the label is returned
        """
        if not details:
            return self.label

        pieces = [self.label + "(" + self.kind + ") "]
        for name, outcome in self.actions.items():
            pieces.append(name + "(" + outcome['state'].label + "), ")
        return "".join(pieces)

World class

In [None]:
class World:
    """A rectangular grid world made of Cell objects.

    Attributes
        rewards: dictionary mapping a cell kind to the reward for entering it
        allowed_special_cells: the cell kinds accepted in special_cells
        num_rows, num_cols: the size of the grid
        cells: all cells in row-major order
        start_cell, goal_cell: the cells where a trip starts and ends
        numeric_map: array of colour codes per cell, used for the heatmaps
        label_map: array of one-character annotations per cell
    """

    # Heatmap colour code for each kind of cell (unknown kinds are drawn as 0)
    _KIND_VALUES = {'empty': 0, 'trap': 10, 'trap1': 10, 'trap2': 20,
                    'trap3': 30, 'start': 40, 'goal': 50}

    # Colour code for cells visited on a trip
    _VISITED_VALUE = 60

    # One-character annotation for each kind of cell (unknown kinds are shown as '0')
    _KIND_LABELS = {'empty': '', 'trap': 't', 'trap1': 'm', 'trap2': 't',
                    'trap3': 'f', 'start': 'S', 'goal': 'G'}

    # Unicode arrows http://xahlee.info/comp/unicode_arrows.html
    _ACTION_ARROWS = {'left': u'\u2190', 'right': u'\u2192',
                      'up': u'\u2191', 'down': u'\u2193'}

    def __init__(self, num_rows = 10, num_cols = 10, special_cells = None):
        """Constructs a new world
        Arguments
            num_rows: the size of the world in rows
            num_cols: the size of the world in cols
            special_cells: a list of special cells given as ((row, col), kind) pairs for special cells. Kinds allowed are 'start', 'empty', 'trap', 'trap1', 'trap2', 'trap3' and 'goal'. Entries with another kind, or with coordinates outside the grid, are ignored
        Raises
            ValueError: if the world would have no cells
        """
        if num_rows < 1 or num_cols < 1:
            raise ValueError("A world needs at least one row and one column")

        # Set up rewards returned by the world
        self.rewards = {
            'start': 0,
            'empty': -1,
            'trap': -2,
            'trap1': -2,
            'trap2': -10,
            'trap3': -100,
            'goal': 50,
        }
        self.allowed_special_cells = ['start', 'empty', 'trap', 'trap1', 'trap2', 'trap3', 'goal']

        # Attributes to store world and world details
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.cells = list()

        # Generate a grid of cells whose labels have a consistent number of digits
        row_width = len(str(num_rows - 1))
        col_width = len(str(num_cols - 1))
        for r in range(num_rows):
            for c in range(num_cols):
                label = str(r).zfill(row_width) + '-' + str(c).zfill(col_width)
                self.cells.append(Cell(self, label, (r, c)))

        # Apply the special cells requested by the user, remembering whether
        # a start and a goal were among them
        start_added = False
        goal_added = False
        if special_cells is not None:
            for entry in special_cells:
                position, kind = entry[0], entry[1]
                if kind not in self.allowed_special_cells:
                    continue

                cell = self.set_cell_kind(position[0], position[1], kind)
                if cell is None:
                    # Coordinates outside the grid: nothing to mark
                    continue

                if kind == 'start':
                    self.start_cell = cell
                    start_added = True
                if kind == 'goal':
                    self.goal_cell = cell
                    goal_added = True

        # If the start and goal haven't been added then add them as first and last cell respectively
        if not start_added:
            self.cells[0].kind = 'start'
            self.start_cell = self.cells[0]
        if not goal_added:
            self.cells[-1].kind = 'goal'
            self.goal_cell = self.cells[-1]

        # Generate the valid action state pairs for each cell. This has to
        # happen after all kinds are set because rewards depend on the kind
        for c in self.cells:
            c.generate_action_state_pairs(self.rewards)

        # Generate a numeric version and label version of the world map - used for drawing worlds
        self.generate_numeric_map()
        self.generate_label_map()

    def find(self, row, col):
        """Find a cell object based on a pair of coordinates. Returns None if there is no such cell"""
        target = (row, col)
        return next((c for c in self.cells if c.position == target), None)

    def set_cell_kind(self, row, col, kind):
        """Set the kind of a cell given its row and col coords

        Returns the cell that was changed, or None if there is no cell at
        these coordinates. Actions and maps are not regenerated here.
        """
        cell = self.find(row, col)
        if cell is not None:
            cell.kind = kind
        return cell

    def find_label(self, label):
        """Find a cell object based on a label. Returns None if there is no such cell"""
        return next((c for c in self.cells if c.label == label), None)

    def set_goal(self, row, col):
        """Set the coordinates of the goal cell

        Only goal_cell is updated; the kinds of the cells, their actions and
        the maps are left as they are.
        """
        self.goal_cell = self.find(row, col)

    def print_world(self, details = False):
        """Print an ascii version of the world environment
        Arguments
            details: True gives a detailed version of each cell whereas False gives a simple grid version
        """
        print("Goal: ", self.goal_cell.to_string())

        if details:
            for idx, c in enumerate(self.cells):
                print(c.to_string(True))
                if idx % 5 == 0:
                    print("")
        else:
            border = "|" + "---" * self.num_cols + "|"
            for r in range(self.num_rows):
                print(border)
                row_cells = self.cells[r * self.num_cols:(r + 1) * self.num_cols]
                print("|" + "".join(c.label + "|" for c in row_cells))
            print(border)

    def _build_numeric_map(self):
        """Return a new array holding the colour code of every cell"""
        numeric_map = np.zeros((self.num_rows, self.num_cols))
        for c in self.cells:
            numeric_map[c.position[0], c.position[1]] = self._KIND_VALUES.get(c.kind, 0)
        return numeric_map

    def _empty_label_map(self):
        """Return a new array of empty one-character annotations"""
        # The builtin str is used because the np.str alias no longer exists in current NumPy
        return np.empty((self.num_rows, self.num_cols), dtype=str)

    def _draw_heatmap(self, numeric_map, label_map, colors, vmax, filename = None, annot_kws = None, new_figure = True):
        """Draw one annotated heatmap, optionally save it to filename, and return its axes"""
        if new_figure:
            plt.figure(figsize=(6, 6))

        ax = sns.heatmap(numeric_map, annot = label_map, fmt='', linewidths=.5, vmax = vmax, vmin = 0,
                         cmap=sns.xkcd_palette(colors), cbar=False, annot_kws=annot_kws)
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

        if filename is not None:
            plt.savefig(filename, bbox_inches = 'tight', dpi = saving_dpi)

        return ax

    def generate_numeric_map(self):
        """Generate a numeric array version of the world - used for drawing a heatmap based map image"""
        self.numeric_map = self._build_numeric_map()

    def generate_label_map(self):
        """Generate an array of one-character labels for the special cells - used for annotating the map image"""
        self.label_map = self._empty_label_map()
        for c in self.cells:
            self.label_map[c.position[0], c.position[1]] = self._KIND_LABELS.get(c.kind, '0')

    def draw_map(self):
        """Draw the world map colour coding special cells

        The map is drawn into the current figure; no new figure is created.
        """
        # Colours for the codes 0 = 'empty', 10 = 'trap'/'trap1', 20 = 'trap2', 30 = 'trap3', 40 = 'start', 50 = 'goal'
        colors = ["light grey", "dark grey", "dark grey", "dark grey", "black", "black"]
        return self._draw_heatmap(self.numeric_map, self.label_map, colors, 50, new_figure = False)

    def draw_map_visited_states(self, states, filename = None):
        """Draw the world map colour coding special cells and cells visited on a trip
        Arguments
            states: the labels of the cells that were visited
            filename: if given, the figure is also saved to this file
        """
        new_numeric_map = self._build_numeric_map()

        # Mark the visited cells, but keep the start and the goal recognisable
        for c in self.cells:
            if c.label in states and c.label != self.start_cell.label and c.label != self.goal_cell.label:
                new_numeric_map[c.position[0], c.position[1]] = self._VISITED_VALUE

        # Same colours as draw_map plus a final colour for 60 = visited
        colors = ["light grey", "dark grey", "dark grey", "dark grey", "black", "black", "white"]
        return self._draw_heatmap(new_numeric_map, self.label_map, colors, 60, filename)

    def draw_map_visited_states_labels(self, states, actions, filename = None):
        """Draw the world map colour coding special cells and showing the action taken in each cell visited on a trip
        Arguments
            states: the labels of the cells that were visited, in order
            actions: the action taken in each state, actions[i] belonging to states[i]
            filename: if given, the figure is also saved to this file

        If a cell was visited more than once the arrow of the last visit is shown.
        """
        new_numeric_map = self._build_numeric_map()
        new_label_map = self._empty_label_map()

        # The last state has no action, so for all other states find the action and draw it
        last = len(states) - 1
        for idx, c_label in enumerate(states):
            if idx == last:
                continue
            c = self.find_label(c_label)
            # An action that is not one of the four moves leaves the cell blank
            new_label_map[c.position[0], c.position[1]] = self._ACTION_ARROWS.get(actions[idx], '')

        colors = ["light grey", "dark grey", "dark grey", "dark grey", "black", "black"]
        return self._draw_heatmap(new_numeric_map, new_label_map, colors, 60, filename, annot_kws={"size": 20})

    def draw_map_policy(self, agent, filename = None):
        """Draw the world map colour coding special cells and showing the best action from each state
        Arguments
            agent: an object whose choose_action_greedy_for_cell(cell) returns the action to draw for a cell
            filename: if given, the figure is also saved to this file
        """
        new_numeric_map = self._build_numeric_map()
        new_label_map = self._empty_label_map()

        # The goal cell is left blank since no action is taken from it
        for c in self.cells:
            if c != self.goal_cell:
                a = agent.choose_action_greedy_for_cell(c)
                # An action that is not one of the four moves leaves the cell blank
                new_label_map[c.position[0], c.position[1]] = self._ACTION_ARROWS.get(a, '')

        colors = ["light grey", "dark grey", "dark grey", "dark grey", "black", "black"]
        return self._draw_heatmap(new_numeric_map, new_label_map, colors, 60, filename, annot_kws={"size": 20})

Test the world

In [None]:
# Default world: a 10 x 10 grid with the start in the first cell and the goal in the last cell
w = World()
plt.figure(figsize=(5, 5))
w.draw_map()

In [None]:
# 8 x 8 world with default start and goal and a row of traps across columns 1 to 6 of row 4
w = World(8, 8, special_cells=[((4, col), 'trap') for col in range(1, 7)])
plt.figure(figsize=(5, 5))
w.draw_map()

In [None]:
# 7 x 7 world with the start at the top centre, the goal at the bottom centre,
# and in rows 2 to 4 a 'trap1' / 'trap3' column pair on either side of the free centre column
w = World(7, 7, special_cells=[((0, 3), 'start'), ((6, 3), 'goal')] + [
    ((row, col), kind)
    for row in (2, 3, 4)
    for col, kind in ((1, 'trap1'), (2, 'trap3'), (4, 'trap3'), (5, 'trap1'))
])
plt.figure(figsize=(5, 5))
w.draw_map()

## Modelling Agents

RL Agent class definition.

In [None]:
# Utility function to create an empty data frame with column names and types
# From: https://stackoverflow.com/questions/36462257/create-empty-dataframe-in-pandas-specifying-column-types
def df_empty(columns, dtypes, index=None):
    """Create an empty data frame whose columns have the given names and types
    Arguments
        columns: the column names
        dtypes: the type of each column, dtypes[i] belonging to columns[i]
        index: optional index passed on to the data frame
    Returns
        a data frame with one column per name, each built from an empty series of the requested type
    Raises
        ValueError: if columns and dtypes differ in length
    """
    # An explicit check instead of assert, so it still runs under python -O
    if len(columns) != len(dtypes):
        raise ValueError("columns and dtypes must have the same length")

    df = pd.DataFrame(index=index)
    for name, dtype in zip(columns, dtypes):
        df[name] = pd.Series(dtype=dtype)
    return df

# Quick check of df_empty: request two int64 columns and show the resulting frame
df = df_empty(['a', 'b'], dtypes=[np.int64, np.int64])
df

In [None]:
class RL_agent:
    
    def __init__(self, world, start, kind = "SARSA", epsilon = 0.1, alpha = 0.2, gamma = 0.9):
        """Set up an agent together with its initial action value function table
        Arguments
            world: the world to which the agent belongs
            start: the cell in which the agent starts
            kind: the learning approach used by perform_RL_episodes ("SARSA" or "Q_Learning")
            epsilon: epsilon parameter for epsilon greedy action selection policy
            alpha: learning rate used when updating action values
            gamma: weight given to the value of the next state-action pair in an update
        """
        # Where the agent lives
        self.world, self.position = world, start

        # Learning configuration
        self.kind = kind
        self.epsilon, self.alpha, self.gamma = epsilon, alpha, gamma

        # Derived structures, filled in by the builders below
        self.action_value_function_table = None
        self.pandas_action_value_function_table = None
        self.heatmaps = dict()

        # Order matters: the pandas table and the heatmaps are both built from the action value function table
        for build in (self.__generate_action_value_function_table,
                      self.__generate_pandas_action_value_function_table,
                      self.__generate_action_value_function_heatmaps):
            build()
        
    # Construct the action value function table. Iterate through all state-action pairs and assign them random values.
    def __generate_action_value_function_table(self):
        """Construct the acttion value funtion table. Iterate thtrough all state-action pairs and assignment them random values."""
        self.action_value_function_table = dict()
    
        # Iterate through states
        for c in self.world.cells:
            
            # Iterate through actions and assign each state-action pair a random values
            actions = dict()
            for a in c.actions.keys():
    
                actions[a] = random.uniform(-1, 1)
            
                # Set the value of any actions from the gaol table  to  0
                if(c.kind == 'goal'):
                    actions[a] = 0
                    
            self.action_value_function_table[c.label] = actions

            # Construct the acttion value funtion table. Iterate thtrough all state-action pairs and assignment them random values.

    def __generate_pandas_action_value_function_table(self):
        """Construct a pandas version of the action value funtion table."""
        self.pandas_action_value_function_table = df_empty(['State', 'Action', 'Value'], dtypes=[np.str, np.str, np.float])
        
        row_count = 0
        for s in self.action_value_function_table.keys():
            for a in self.action_value_function_table[s].keys():
                self.pandas_action_value_function_table.loc[row_count] = [s, a, self.action_value_function_table[s][a]]
                row_count = row_count + 1
            
    # Print a text version of the action value function table
    def print_action_value_function_table(self):
        """Print a text version of the action vbalue function table"""
        for s in self.action_value_function_table.keys():
            for a in self.action_value_function_table[s].keys():
                print(s + " " + a + " " + " " + str(self.action_value_function_table[s][a]))

    # Write a version of the action value table to csv file
    def save_csv_action_value_function_table(self, out_filename):
        """Write a version of the action value table to a csv file"""
        with open(out_filename, 'w') as out_file:
            out_file.write("State,Action,Value\n")
            for s in self.action_value_function_table.keys():
                for a in self.action_value_function_table[s].keys():
                    line = s + "," + a + "," + str(self.action_value_function_table[s][a]) + " \n"
                    out_file.write(line)

    # Write a version of the action value table to a latex file
    def save_latex_action_value_function_table(self, out_filename):
        """Write a version of the action value table to a latex file with FMLPDA formatting"""
        with open(out_filename, 'w') as out_file:
            out_file.write('State & Action & Value \\\\ \n')
            for s in self.action_value_function_table.keys():
                for a in self.action_value_function_table[s].keys():
                    line = "\\rlState{" + s + "} & \\rlAction{" + a + "} & $" + str(round(self.action_value_function_table[s][a], 3)) + "$ \\\\ \n"
                    out_file.write(line)
        
    # Generate a set of numeric heatmap representations, one for each action. Used for drawing things.
    def __generate_action_value_function_heatmaps(self):
        """Generate a set of numeric heatmap represenations for each action. Used for drawing things."""

        # Intialise empty heat mapes 
        self.heatmaps = dict()
        self.heatmaps['left'] = np.zeros((self.world.num_rows, self.world.num_cols))
        self.heatmaps['right'] = np.zeros((self.world.num_rows, self.world.num_cols))
        self.heatmaps['up'] = np.zeros((self.world.num_rows, self.world.num_cols))
        self.heatmaps['down'] = np.zeros((self.world.num_rows, self.world.num_cols))

        # Iterate through the action value function table and fill up heatmaps
        for s in self.action_value_function_table.keys():
            for a in self.action_value_function_table[s].keys():
                c = self.world.find_label(s)
                self.heatmaps[a][c.position[0], c.position[1]] = self.action_value_function_table[s][a]

    # Draw the heatmaps in a nice arrangement with a world map in the middle
    def draw_action_value_function_heatmaps(self, layout = 'grid', color = 'RGB', filename = None):
        """Draw one heatmap of action values per action, together with a map of the world
        Arguments
            layout: 'grid' places the four heatmaps around a central world map, 'line' places the
                world map and the four heatmaps side by side; any other value draws no panels
            color: 'BW' selects the "Greys" palette; 'RGB' or any other value selects "BrBG"
            filename: if not None, the figure is saved to this file before it is shown
        """
        # Bring the heatmap arrays up to date with the action value function table
        self.__generate_action_value_function_heatmaps()

        # Only 'BW' changes the palette; every other value falls back to "BrBG"
        if color == 'BW':
            cmap_used = sns.color_palette("Greys", 12)
        else:
            cmap_used = sns.color_palette("BrBG", 21)

        # Fixed colour scale shared by every heatmap
        max_value = 50
        min_value = -50

        def draw_action_panel(subplot_code, action, **colorbar_options):
            # One heatmap panel, titled with the name of its action
            plt.subplot(subplot_code)
            ax = sns.heatmap(self.heatmaps[action], linewidths=.5, vmin = min_value, vmax= max_value, center= 0, cmap=cmap_used, **colorbar_options)
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
            ax.set_title(action)

        def draw_world_panel(subplot_code):
            # The map of the world itself
            plt.subplot(subplot_code)
            ax = self.world.draw_map()
            ax.set_title('world')

        if(layout == "grid"):
            # 3 x 3 grid: each action sits on the side of the world map that matches its direction
            plt.figure(figsize=(12,12))
            draw_action_panel(332, 'up', cbar=False)
            draw_action_panel(334, 'left', cbar=False)
            draw_world_panel(335)
            draw_action_panel(336, 'right', cbar=True)
            draw_action_panel(338, 'down', cbar=False)

        elif(layout == 'line'):
            # Single row of five panels with the colour bar on its own axes at the right edge
            fig = plt.figure(figsize=(20,4))
            cbar_ax = fig.add_axes([.91, .15, .015, .7])
            draw_world_panel(151)
            draw_action_panel(152, 'up', cbar=False)
            draw_action_panel(153, 'down', cbar=False)
            draw_action_panel(154, 'left', cbar=False)
            draw_action_panel(155, 'right', cbar=True, cbar_ax = cbar_ax)

        if filename is not None:
            plt.savefig(filename, bbox_inches = 'tight', dpi = saving_dpi)

        plt.show()
        
    # Choose an action using the greedy strategy - best available
    def choose_action_greedy(self, verbose = 0):
        """Choose an action using the greedy strategy - best available"""
        potential_actions = list(self.position.actions.keys())
        
        # Choose the best action
        action = potential_actions[0]
        #max_value = self.action_value_function_table[self.position.label][action]['value']
        max_value = self.action_value_function_table[self.position.label][action]

        for a in potential_actions[1:]:
            #value = self.action_value_function_table[self.position.label][a]['value']
            value = self.action_value_function_table[self.position.label][a]
            if value > max_value:
                action = a
                max_value = value
            
        return action
    
    # Choose an action for a given cell using the greedy strategy - best available
    def choose_action_greedy_for_cell(self, cell, verbose = 0):
        """Choose an action using the greedy strategy - best available"""
        potential_actions = list(cell.actions.keys())
        
        # Choose the best action
        action = potential_actions[0]
        #max_value = self.action_value_function_table[self.position.label][action]['value']
        max_value = self.action_value_function_table[cell.label][action]

        for a in potential_actions[1:]:
            #value = self.action_value_function_table[self.position.label][a]['value']
            value = self.action_value_function_table[cell.label][a]
            if value > max_value:
                action = a
                max_value = value
            
        return action
    
    # Choose an action using an epsilon greedy strategy
    def choose_action_e_greedy(self, verbose = 0):
        """Choose an action using an epsilon greedy strategy"""
        potential_actions = list(self.position.actions.keys())
        
        # Generate a random number, if less than epsilon do a random acrtion, otherwise pick the action with the highest value
        r = random.uniform(0, 1)

        if(verbose == 2): print("choose_action_e_greedy: ", r)
        if(r < self.epsilon):
            if(verbose == 2): print("random")
            action = random.choice(potential_actions)

        # If greater than epsilon choose the best action
        else:
            if(verbose == 2): print("greedy")
            action = self.choose_action_greedy()
            
        return action

    # Choose an action following the current policy
    def choose_action_on_policy(self, verbose = 0):
        """Choose an action following the current policy"""
        return self.choose_action_e_greedy(verbose)
    
    # Choose an action off policy - choose the best one
    def choose_action_off_policy(self, verbose = 0):
        """Choose an action following the current policy"""
        return self.choose_action_greedy(verbose)
    
    # Make  a move and return the resulting reward and the new state
    def move(self, action):
        """Make  a move and return the resulting reward and the new state"""
        r = self.position.actions[action]['reward']
        s_prime = self.position.actions[action]['state']
                
        self.position = s_prime
        return r, s_prime
    
    # Perform a number of reinforcement learning episodes 
    def perform_RL_episodes(self, num_episodes = 1, heatmap_layout = 'grid', image_filename_root = None, verbose = 0):
        """Perform a number of reinforcement learning episodes using whicherver approach the agent is setup for
        Arguments
            num_episodes: the number of RL episodes to perform
            verbose: how much debug info to print (0: none, 1: updates every 50 episodes, 2: updates every episode])
        """
        if self.kind == "SARSA":
            return self.perform_RL_episodes_SARSA(num_episodes, heatmap_layout, image_filename_root, verbose)
    
        elif self.kind == "Q_Learning":
            return self.perform_RL_episodes_Q_Learning(num_episodes, heatmap_layout, image_filename_root, verbose)
    

    # Perform a number of reinforcement learning episodes using the SARSA approach
    def perform_RL_episodes_SARSA(self, num_episodes = 1, heatmap_layout = 'grid', image_filename_root = None, verbose = 0):
        """Perform a number of reinforcement learning episodes using the SARSA approach
        Arguments
            num_episodes: the number of RL episodes to perform
            verbose: how much debug info to print (0: none, 1: updates every 50 episodes, 2: updates every episode])
        """
        episodes = list()

        # Iterate through the episodes
        for e in range(0, num_episodes):
            
            total_reward = 0
            moves = list()
            states = list()
            
            # Reset agent to first state
            self.position = self.world.start_cell

            s = self.position
            
            a = self.choose_action_on_policy(verbose)

            # Repeat until the goal position is reached
            idx = 0
            while s.position != self.world.goal_cell.position:

                # Take a move
                r, s_prime = self.move(a)
                total_reward += r
                moves.append(a)
                states.append(s)

                # Choose the next action
                a_prime = self.choose_action_on_policy(verbose)

                # Calcualte the new value for the action value function
                q_s_a = self.action_value_function_table[s.label][a]
                q_s_prime_a_prime = self.action_value_function_table[s_prime.label][a_prime]
                new_value = q_s_a + self.alpha*(r + self.gamma*q_s_prime_a_prime - q_s_a)

                # Update the value table
                self.action_value_function_table[s.label][a] =  new_value

                # Update the current state and ation
                s = s_prime
                a = a_prime
                idx = idx + 1
            
            states.append(s)
            episodes.append({'idx':e,'return':total_reward, 'moves':moves, 'states':states})
            
            # Draw heatmaps for every 50th episode
            if((verbose == 1 and e%50 == 0) or verbose == 2):
                print("Episode:", e)
                print("Moves:",  moves)
                print("Reward: ", total_reward) 
                filename = None
                if(image_filename_root != None):
                    filename = image_filename_root + "_ep_" + str(e) + ".pdf"
                self.draw_action_value_function_heatmaps(heatmap_layout, filename)

                filename = None
                if(image_filename_root != None):
                    filename = image_filename_root + "_ep_" + str(e) + "_policy.pdf"
                self.world.draw_map_policy(self, filename)
                
                if(image_filename_root != None):
                    filename_tex = image_filename_root + "_ep_" + str(e) + "_action_value_table.tex"
                    self.save_latex_action_value_function_table(filename_tex)
                    filename_csv = image_filename_root + "_ep_" + str(e) + "_action_value_table.csv"
                    self.save_csv_action_value_function_table(filename_csv)

                
       
        # Dreaw heatmaps of the final table
        print("Episode:", e)
        print("Moves:",  moves)
        print("Reward: ", total_reward) 
        filename = None
        if(image_filename_root != None):
            filename = image_filename_root + "_ep_" + str(e) + ".pdf"
        self.draw_action_value_function_heatmaps(heatmap_layout, filename)
        filename = None
        if(image_filename_root != None):
            filename = image_filename_root + "_ep_" + str(e) + "_policy.pdf"
        self.world.draw_map_policy(self, filename)
        
        return episodes
    
    # Perform a number of reinforcement learning episodes using the Q Learning approach
    def perform_RL_episodes_Q_Learning(self, num_episodes = 1, heatmap_layout = 'grid', image_filename_root = None, verbose = 0):
        """Perform a number of reinforcement learning episodes using the Q Learning approach
        Arguments
            num_episodes: the number of RL episodes to perform
            verbose: how much debug info to print (0: none, 1: updates every 50 episodes, 2: updates every episode])
        """
        episodes = list()

        # Iterate through the episodes
        for e in range(0, num_episodes):
            
            total_reward = 0
            moves = list()
            states = list()

            # Reset agent to first state
            self.position = self.world.start_cell

            s = self.position
            
            idx = 0
            # Repeat until the goal position is reached
            while s.position != self.world.goal_cell.position:

                a = self.choose_action_on_policy(verbose)
                    
                # Take a move
                r, s_prime = self.move(a)
                total_reward += r
                moves.append(a)
                states.append(s)
                    
                # Choose the next action
                a_prime = self.choose_action_off_policy(verbose)

                # Calcualte the new value for the action value function
                q_s_a = self.action_value_function_table[s.label][a]
                q_s_prime_a_prime = self.action_value_function_table[s_prime.label][a_prime]
                new_value = q_s_a + self.alpha*(r + self.gamma*q_s_prime_a_prime - q_s_a)

                # Update the value table
                self.action_value_function_table[s.label][a] =  new_value

                # Update the current state and ation
                s = s_prime
                
                idx = idx + 1
                
            states.append(s)
            episodes.append({'idx':e,'return':total_reward, 'moves':moves, 'states':states})
            
            # Draw heatmaps for every 50th episode
            if((verbose == 1 and e%50 == 0) or verbose == 2):
                print("Episode:", e)
                print("Moves:",  moves)
                print("Reward: ", total_reward) 
                filename = None
                if(image_filename_root != None):
                    filename = image_filename_root + "_ep_" + str(e) + ".pdf"
                self.draw_action_value_function_heatmaps(heatmap_layout, filename)
                filename = None
                if(image_filename_root != None):
                    filename = image_filename_root + "_ep_" + str(e) + "_policy.pdf"
                self.world.draw_map_policy(self, filename)
                if(image_filename_root != None):
                    filename_tex = image_filename_root + "_ep_" + str(e) + "_action_value_table.tex"
                    self.save_latex_action_value_function_table(filename_tex)
                    filename_csv = image_filename_root + "_ep_" + str(e) + "_action_value_table.csv"
                    self.save_csv_action_value_function_table(filename_csv)

       
        # Dreaw heatmaps of the final table
        print("Episode:", e)
        print("Moves:",  moves)
        print("Reward: ", total_reward) 
        if(image_filename_root != None):
            filename = image_filename_root + "_ep_" + str(e) + ".pdf"
        self.draw_action_value_function_heatmaps(heatmap_layout, filename)
        filename = None
        if(image_filename_root != None):
            filename = image_filename_root + "_ep_" + str(e) + "_policy.pdf"
        self.world.draw_map_policy(self, filename)
        
        return episodes
    
    # Perform the task once, from start to goal, using the trained action value function
    def perform_offline_task(self, image_filename_root = None, verbose = 0):
        """Perform athe task using a trained value fucntion
        Arguments
            verbose: how much debug info to print (0: none, 1: updates every 50 episodes, 2: updates every episode])
        """
            
        total_reward = 0
        moves = list()
        states = list()

        # Reset agent to first state
        self.position = self.world.start_cell

        s = self.position
        states.append(s.label)
            
        idx = 0
        # Repeat until the goal position is reached
        while s.position != self.world.goal_cell.position:

            a = self.choose_action_greedy(verbose)

            # Take a move
            r, s_prime = self.move(a)
            total_reward += r
            moves.append(a)

            # Update the current state and ation
            s = s_prime
            states.append(s.label)
            
            idx = idx + 1
        
        return states, moves

# Q-learning

Create a world

In [None]:
# Earlier layout, kept for reference: start and goal in column 3 and an extra 'trap1' cell at (3, 3)
#demo_w = World(7, 7, special_cells=[((0, 3), 'start'), ((6, 3), 'goal'),
#                               ((2, 2), 'trap3'), ((2, 4), 'trap3'), 
#                               ((3, 2), 'trap3'), ((3, 3), 'trap1'), ((3, 4), 'trap3'), 
#                               ((4, 2), 'trap3'), ((4, 4), 'trap3')])
# Demo world used by the Q-learning and SARSA runs below: start at (0, 2), goal at (6, 4)
# and 'trap3' cells in columns 2 and 4 of rows 2 to 4
demo_w = World(7, 7, special_cells=[((0, 2), 'start'), ((6, 4), 'goal'),
                               ((2, 2), 'trap3'), ((2, 4), 'trap3'), 
                               ((3, 2), 'trap3'), ((3, 4), 'trap3'), 
                               ((4, 2), 'trap3'), ((4, 4), 'trap3')])
plt.figure(figsize=(5, 5))
demo_w.draw_map()
plt.show()

Create an agent and perform a few episodes

In [None]:
# Seed Python's random module so the agent's random initial values and exploration are repeatable
random.seed(random_seed) 
a = RL_agent(demo_w, demo_w.start_cell, kind = "Q_Learning", epsilon = epsilon)
# Show the freshly initialised action value table
display(a.pandas_action_value_function_table)
# Two Q-learning episodes, with a report on episode 0 (verbose = 1 reports every 50th episode)
ql_eps = a.perform_RL_episodes(2, verbose = 1, heatmap_layout = 'line', image_filename_root = 'rl_grid_world_q_learning')
# Plot the return of each episode
df = pd.DataFrame(ql_eps)
ax = df.loc[0:100, 'return'].plot(color = 'black')
ax.set_xlabel("Episode")
ax.set_ylabel("Return")

Run another episode

In [None]:
# Re-seed and build a fresh agent, run a single episode and print the resulting action value table
random.seed(random_seed) 
a = RL_agent(demo_w, demo_w.start_cell, kind = "Q_Learning", epsilon = epsilon)
ql_eps = a.perform_RL_episodes(1, verbose = 1, heatmap_layout = 'line')
a.print_action_value_function_table()


Run the agent for 350 episodes

In [None]:
# Re-seed and build a fresh agent, then train it for num_episodes episodes with a report every 50th episode
random.seed(random_seed) 
a = RL_agent(demo_w, demo_w.start_cell, kind = "Q_Learning", epsilon = epsilon)
ql_eps = a.perform_RL_episodes(num_episodes, heatmap_layout = 'line',image_filename_root = "rl_grid_world_q_learning_", verbose = 1)

Plot the evolution of the reward through the learning episodes

In [None]:
# Per-episode return in light grey, then its 10-episode rolling mean in black
df = pd.DataFrame(ql_eps)
ax = df.loc[0:num_episodes, 'return'].plot(color = 'lightgray', xlim = [-5, num_episodes])
df['return'].rolling(10).mean().plot(color = 'black', xlim = [-5, num_episodes])
ax.set_xlabel("Episode")
plt.ylabel("Rolling Mean (10) Cumulative Return")
plt.show()

Get the agent to navigate. 

In [None]:
# Follow the greedy policy of the trained agent from start to goal:
# s holds the labels of the visited cells, m the moves taken
s, m = a.perform_offline_task()
print(s)
print(m)
demo_w.draw_map_visited_states_labels(s, m, "rl_grid_world_q_learning__offline_path_arrows.pdf")

# SARSA

Create an agent and perform some episodes

In [None]:
# Same seed and world as the Q-learning run, this time with a SARSA agent
random.seed(random_seed) 
a = RL_agent(demo_w, demo_w.start_cell, kind = "SARSA", epsilon = epsilon)
sarsa_eps = a.perform_RL_episodes(num_episodes, verbose = 1, heatmap_layout = 'line', image_filename_root = "rl_grid_world_sarsa_")

Plot the reward

In [None]:
# Per-episode return in light grey, then its 10-episode rolling mean in black
df = pd.DataFrame(sarsa_eps)
ax = df.loc[0:num_episodes, 'return'].plot(color = 'lightgray', xlim = [-5, num_episodes])
df['return'].rolling(10).mean().plot(color = 'black', xlim = [-5, num_episodes])
ax.set_xlabel("Episode")
plt.ylabel("Rolling Mean (10) Cumulative Return")
plt.show()

In [None]:
# Follow the greedy policy of the trained SARSA agent from start to goal:
# s holds the labels of the visited cells, m the moves taken
s, m = a.perform_offline_task()
print(s)
print(m)
demo_w.draw_map_visited_states_labels(s, m, "rl_grid_world_sarsa__offline_path_arrows.pdf")