In [1]:
%matplotlib
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb

def change_range(values, vmin=0, vmax=1):
    """Linearly rescale ``values`` so that they span roughly ``[vmin, vmax]``.

    The smallest element maps to ``vmin``. A small constant (1e-7) is added to
    the denominator so a constant input does not divide by zero; as a side
    effect the largest element lands marginally below ``vmax``.
    """
    shifted = values - np.min(values)
    unit = shifted / (np.max(shifted) + 1e-7)
    return unit * (vmax - vmin) + vmin

class GridWorld:
    """Grid world with a cliff along the bottom row, drawn with matplotlib.

    The board has ``n_rows`` x ``n_cols`` cells. The player starts in the
    bottom-left corner, the objective is the bottom-right corner and every
    bottom-row cell in between is a cliff. A state is the integer id
    ``row * n_cols + col``.

    Actions: 0 = up, 1 = down, 2 = right, 3 = left. A move that would leave
    the board keeps the player where it is.
    """

    # Board dimensions shared by the state encoding and the drawing code.
    n_rows = 4
    n_cols = 12

    # Cell colours as HSV triples in [0, 1]; converted with hsv_to_rgb when drawn.
    terrain_color = dict(normal=[127/360, 0, 96/100],
                         objective=[26/360, 100/100, 100/100],
                         cliff=[247/360, 92/100, 70/100],
                         player=[344/360, 93/100, 100/100])

    def __init__(self):
        # The player position stays None until reset() is called.
        self.player = None
        self._create_grid()
        self._draw_grid()

    def _create_grid(self, initial_grid=None):
        """Build the HSV colour grid in ``self.grid``.

        ``initial_grid`` is accepted for backward compatibility but is not used.
        """
        self.grid = self.terrain_color['normal'] * np.ones((self.n_rows, self.n_cols, 3))
        self._add_objectives(self.grid)

    def _add_objectives(self, grid):
        """Paint the cliff cells and the objective onto ``grid`` (in place)."""
        # Bottom row: every cell except the two corners is a cliff.
        grid[-1, 1:-1] = self.terrain_color['cliff']
        grid[-1, -1] = self.terrain_color['objective']

    def _draw_grid(self):
        """Create the figure, the board image and one text label per cell."""
        self.fig, self.ax = plt.subplots(figsize=(self.n_cols, self.n_rows))
        self.ax.grid(which='minor')

        # One label per state; matplotlib wants (x, y) = (col, row).
        self.q_texts = []
        for state in range(self.n_rows * self.n_cols):
            row, col = self._id_to_position(state)
            self.q_texts.append(self.ax.text(col, row, '0',
                                             fontsize=11, verticalalignment='center',
                                             horizontalalignment='center'))

        # The image is RGB, so no colormap or value limits are passed.
        self.im = self.ax.imshow(hsv_to_rgb(self.grid), interpolation='nearest')
        self.ax.set_xticks(np.arange(self.n_cols))
        self.ax.set_xticks(np.arange(self.n_cols) - 0.5, minor=True)
        self.ax.set_yticks(np.arange(self.n_rows))
        self.ax.set_yticks(np.arange(self.n_rows) - 0.5, minor=True)

    def reset(self):
        """Place the player on the bottom-left cell and return its state id."""
        self.player = (self.n_rows - 1, 0)
        return self._position_to_id(self.player)

    def step(self, action):
        """Apply ``action`` and return ``(state_id, reward, done)``.

        Rewards: -100 and ``done`` on a cliff cell, 0 and ``done`` on the
        objective, -1 otherwise. Any action outside 0..3 leaves the player in
        place and is scored like a normal move.
        """
        assert self.player is not None, 'You first need to call .reset()'

        # Possible actions (moves off the board are ignored)
        row, col = self.player
        if action == 0 and row > 0:
            row -= 1
        elif action == 1 and row < self.n_rows - 1:
            row += 1
        elif action == 2 and col < self.n_cols - 1:
            col += 1
        elif action == 3 and col > 0:
            col -= 1
        self.player = (row, col)

        # Rules: the terrain type is read back from the cell's colour.
        cell = self.grid[self.player]
        if np.array_equal(cell, self.terrain_color['cliff']):
            reward, done = -100, True
        elif np.array_equal(cell, self.terrain_color['objective']):
            reward, done = 0, True
        else:
            reward, done = -1, False

        return self._position_to_id(self.player), reward, done

    def _position_to_id(self, pos):
        ''' Maps a (row, col) position to a unique integer ID '''
        return pos[0] * self.n_cols + pos[1]

    def _id_to_position(self, idx):
        ''' Maps a state ID back to its (row, col) position '''
        return (idx // self.n_cols), (idx % self.n_cols)

    def render(self, q_values=None, action=None, max_q=False, colorize_q=False):
        """Redraw the board.

        q_values:   optional table indexed ``[state_id]`` giving the action
                    values of each cell; when given, every cell label is updated.
        action:     optional value shown as the figure title.
        max_q:      show only the largest action value per cell instead of
                    one line per action.
        colorize_q: use the rescaled per-state maximum of ``q_values`` as the
                    cell saturation (requires ``q_values``).
        """
        assert self.player is not None, 'You first need to call .reset()'

        if colorize_q:
            assert q_values is not None, 'q_values must not be None for using colorize_q'
            grid = self.terrain_color['normal'] * np.ones((self.n_rows, self.n_cols, 3))
            values = change_range(np.max(q_values, -1)).reshape(self.n_rows, self.n_cols)
            grid[:, :, 1] = values
            # Re-paint cliff/objective, which the saturation overwrite washed out.
            self._add_objectives(grid)
        else:
            grid = self.grid.copy()

        grid[self.player] = self.terrain_color['player']
        self.im.set_data(hsv_to_rgb(grid))

        if q_values is not None:
            action_labels = ['U', 'D', 'R', 'L']
            for i, text in enumerate(self.q_texts):
                if max_q:
                    txt = '{:.2f}'.format(max(q_values[i]))
                else:
                    txt = '\n'.join(['{}: {:.2f}'.format(k, q)
                                     for k, q in zip(action_labels, q_values[i])])
                text.set_text(txt)

        if action is not None:
            self.ax.set_title(action, color='r', weight='bold', fontsize=32)

        # Give the GUI event loop a chance to draw the update.
        plt.pause(0.001)



Using matplotlib backend: Qt5Agg


In [7]:
def egreedy_policy(q_values, state, epsilon = 0.2):
    """Pick an action epsilon-greedily from the Q table.

    A number is drawn from a uniform distribution on [0, 1). If it is greater
    than ``epsilon`` the greedy (highest-valued) action is returned, otherwise
    an action is drawn uniformly at random.

    Parameters:
        q_values: 2-D array of state-action values, indexed [state, action]
        state: an integer describing the current state (row of Q table)
        epsilon: the probability describing degree of encouraged exploration

    Returns:
        An integer describing index of action chosen
    """
    if np.random.uniform() > epsilon:
        # Exploit: highest scoring action for the given state
        # (ties resolve to the first maximum).
        return np.argmax(q_values[state, :])
    # Explore: draw from however many actions the table actually has,
    # rather than assuming exactly four.
    return np.random.randint(q_values.shape[1])

In [3]:
def greedy_policy(q_values, state):
    """Return the index of the highest-valued action for ``state``.

    Parameters:
        q_values: Q table of stored values, indexed [state, action]
        state: an integer describing the current state (row of Q table)

    Returns:
        An integer describing index of action chosen; when several actions
        tie for the maximum, the first one wins.
    """
    action_values = q_values[state, :]
    return action_values.argmax()


In [4]:
def sarsa_episode(env, q_values, state, policy_func, alpha, discount=None):
    """Run one SARSA episode and update ``q_values`` in place.

    The successor action a' chosen for the TD target is also the action that
    is executed on the following step, so ``policy_func`` acts as both the
    behaviour and the target policy (on-policy control).

    Parameters:
        env: environment exposing ``step(action) -> (next_state, reward, done)``
        q_values: state-action value table, indexed [state, action]
        state: the starting state of the environment
        policy_func: callable ``policy_func(q_values, state) -> action``
        alpha: the user specified learning rate
        discount: discount factor; when None the module-level ``gamma`` is
            used, which is what the function read before this parameter existed

    Returns:
        (q_values, rewards): the same, updated table and an array holding the
        reward received at each step of the episode
    """
    if discount is None:
        # Backward compatibility: fall back to the notebook-level `gamma`.
        discount = gamma

    rewards = []
    # Choose the first action once; later actions are carried over from the
    # previous update so the TD target matches what the agent really does.
    action = policy_func(q_values, state)
    done = False
    while not done:
        next_state, reward, done = env.step(action)

        if done:
            # No action follows a terminal state, so nothing to bootstrap from.
            next_action = None
            bootstrap = 0.0
        else:
            # Evaluate the successor action a'
            next_action = policy_func(q_values, next_state)
            bootstrap = q_values[next_state, next_action]

        # Temporal Difference (TD) error and update for the visited pair
        td_error = reward + discount * bootstrap - q_values[state, action]
        q_values[state, action] += alpha * td_error

        # Move on: a' becomes the action executed next
        state, action = next_state, next_action
        rewards.append(reward)

    return (q_values, np.asarray(rewards))


In [5]:
def play_sarsa(env, q_values, state, policy_func, alpha, max_steps=1000):
    """Play one episode with a fixed Q table, rendering after every step.

    No learning happens here: ``q_values`` is only read.

    Parameters:
        env: environment exposing ``step(action)`` and
            ``render(q_values, colorize_q=...)``
        q_values: the state-action value table to act on
        state: the starting state of the environment
        policy_func: callable ``policy_func(q_values, state) -> action``
        alpha: unused; kept so existing calls keep working
        max_steps: stop after this many steps even if the episode has not
            finished (None disables the cap). A greedy policy on a table that
            has not converged can cycle between states forever, which would
            otherwise hang the notebook.
    """
    done = False
    steps_taken = 0
    while not done and (max_steps is None or steps_taken < max_steps):
        action = policy_func(q_values, state)
        state, _, done = env.step(action)
        env.render(q_values, colorize_q = True)
        steps_taken += 1



In [10]:
### SARSA ALGORITHM ###

# Seed NumPy's global RNG so the epsilon-greedy exploration, and therefore
# the learned table and the reward curve, are reproducible across runs.
np.random.seed(0)

env = GridWorld()
num_states = 48
num_actions = 4
sarsa_episode_rewards = []
alpha = 0.5
# Discount factor; sarsa_episode reads this module-level name.
gamma = 1
num_episodes = 500
# Initialize the Q-values table
sarsa_Q_values = np.zeros((num_states, num_actions))

for episode in tqdm(range(num_episodes)):
    # Put the player back on the start cell
    state = env.reset()

    # Run a Sarsa episode
    sarsa_Q_values, r = sarsa_episode(env, sarsa_Q_values, state, egreedy_policy, alpha)
    # Store the episode's total reward
    sarsa_episode_rewards.append(r.sum())

# Close the training figure
plt.close(env.fig)

# Visualize Sarsa results: replay the learned table greedily in a fresh environment
play_env = GridWorld()
state = play_env.reset()
play_env.render(sarsa_Q_values, colorize_q = True)
play_sarsa(play_env, sarsa_Q_values, state, greedy_policy, alpha)



100%|██████████| 500/500 [00:00<00:00, 1012.45it/s]


In [None]:
%matplotlib inline
# Learning curve: total reward collected in each training episode
fig, ax = plt.subplots()
ax.plot(sarsa_episode_rewards)
ax.set(xlabel='Episode', ylabel='Sum of rewards', title='SARSA: reward per training episode')
plt.show()