In [336]:
import numpy as np
import pandas as pd
import gym

# 1. Cliff Walking environment

In [337]:
""" Environment information
ref: https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py#L18

Actions
    There are 4 discrete deterministic actions:
    - 0: move up
    - 1: move right
    - 2: move down
    - 3: move left
Observations
    There are 3x12 + 1 possible states. In fact, the agent cannot be at the cliff, nor at the goal
    (as this results in the end of the episode).
    It remains all the positions of the first 3 rows plus the bottom-left cell.
    The observation is simply the current position encoded as [flattened index](https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html).
Reward
    Each time step incurs -1 reward, and stepping into the cliff incurs -100 reward.
"""
env = gym.make("CliffWalking-v0")

# 2. N-step SARSA

In [338]:
class SARSA:
    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.env = env
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

        # initialize the Q-table with zeros
        self.Q = np.zeros((env.observation_space.n, env.action_space.n))
        
        # CUSTOMIZE: the environment shape for CliffWalking-v0 is (4, 12)
        self.env_shape = (4, 12)

    def choose_action(self, state):
        raise NotImplementedError()

    def update_Q(self, states, actions, rewards):
        raise NotImplementedError()

    def generate_sample_episode(self):
        raise NotImplementedError()

    def train(self, episodes):
        raise NotImplementedError()
    
    def evaluate_policy(self, episodes):
        raise NotImplementedError()
    
    def visualize_q_table(self):
        raise NotImplementedError()

In [409]:
class NStepSARSA(SARSA):
    def __init__(self, env, n_step=1, alpha=0.1, gamma=0.99, epsilon=0.1):
        super().__init__(env, alpha, gamma, epsilon)
        self.n_step = n_step

    def choose_action(self, state, always_greedy=False):
        """
        Chooses an action using an epsilon-greedy policy.
        """
        if always_greedy:
            return np.argmax(self.Q[state])
        if np.random.uniform(0, 1) < self.epsilon:
            return self.env.action_space.sample()
        else:
            return np.argmax(self.Q[state])

    def update_Q(self, states, actions, rewards):
        """
        Updates the Q-table using N-Step SARSA.
        """
        G = 0
        for i in range(len(rewards)):
            G += (self.gamma ** i) * rewards[i]

        # Update Q(st,at):= Q(st,at) + lr [[G(t:t+n) + (gamma ** n_step) * Q(s(t+n),a(t+n))] - Q(st,at)]
        state, action = states[0], actions[0]
        self.Q[state, action] += self.alpha * (G + (self.gamma ** self.n_step) * self.Q[states[-1], actions[-1]] - self.Q[state, action])

    # TODO: add display option for visualization of the episode
    def generate_sample_episode(self, display=False, always_greedy=False):
        episode = list()
        off_cliff_info = {"status": False, "prev_state": None}
    
        state = self.env.reset()
        while True:
            timestep = list()
            prev_state = state
            
            action = self.choose_action(state, always_greedy=always_greedy)
            state, reward, done, _ = self.env.step(action)
            
            timestep.append(prev_state)
            timestep.append(action)
            timestep.append(reward)
            episode.append(timestep)

            # CUSTOMIZE: if the agent falls off the cliff, the episode ends
            if reward == -100:
                done = True
                off_cliff_info["status"] = True
                off_cliff_info["prev_state"] = prev_state

            if done:
                # include the terminal state
                episode.append([state, None, None])
                break
        
        return episode, off_cliff_info

    def train(self, episodes):
        """
        Trains the agent using N-Step SARSA.
        """
        for _ in range(episodes):
            state = self.env.reset()

            states = [state]
            actions = [self.choose_action(state)]
            rewards = [0.0]

            t = 0
            T = float("inf")
            while True:
                if t < T:
                    next_state, reward, done, _ = self.env.step(actions[-1])

                    states.append(next_state)
                    rewards.append(reward)

                    if done:
                        T = t + 1
                    else:
                        next_action = self.choose_action(next_state)
                        actions.append(next_action)

                tau = t - self.n_step + 1

                if tau >= 0:
                    self.update_Q(states[tau:tau+self.n_step+1], actions[tau:tau+self.n_step], rewards[tau:tau+self.n_step])

                if tau == T - 1:
                    break

                t += 1

    # CUSTOMIZE: for cliffwalking environment
    def evaluate_policy(self, episodes, display=False):
        wins = 0
        for _ in range(episodes):
            episode, off_cliff_info = self.generate_sample_episode(display=False)
            if not off_cliff_info["status"]:
                wins += 1
        return wins / episodes
    
    def visualize_q_table(self, colored_states=[]):
        q_table = np.zeros(self.env_shape + (env.action_space.n,))
    
        for state in range(env.observation_space.n):
            row = state // self.env_shape[1]
            col = state % self.env_shape[1]
            q_table[row, col] = self.Q[state]

        # calculate the average q-values for each state
        avg_q_values = np.mean(q_table, axis=2)
        df = pd.DataFrame(avg_q_values)
        
        return df.style.apply(self._style_trajectory_cells, axis=None, colored_states=colored_states)
    
    def _get_row_col(self, state):
        row = state // self.env_shape[1]
        col = state % self.env_shape[1]
        return row, col

    # CUSTOMIZE: for cliffwalking environment
    def _style_trajectory_cells(self, x, colored_states):
        df_colored = pd.DataFrame('background-color: black', index=x.index, columns=x.columns)
        
        df_colored.iloc[3, 0] = 'background-color: blue'
        df_colored.iloc[3, 11] = 'background-color: darkgreen'
        df_colored.iloc[3, 1:11] = 'background-color: brown'
        
        for state in colored_states:
            row, col = self._get_row_col(state)
            df_colored.iloc[row, col] = 'background-color: darkblue'
        
        return df_colored

In [410]:
nstep_sarsa_agent = NStepSARSA(env, n_step=4)
nstep_sarsa_agent.train(episodes=1000)
nstep_sarsa_agent.visualize_q_table()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,-20.577396,-19.105447,-17.74151,-16.874523,-15.71391,-14.523301,-13.184923,-12.18704,-10.863714,-8.83143,-6.912758,-6.325851
1,-20.827237,-20.789518,-26.68929,-20.301294,-17.179361,-16.41113,-14.283513,-13.173182,-9.092132,-10.344381,-6.235996,-5.532813
2,-23.886794,-34.940476,-26.80337,-19.624515,-18.059058,-23.349164,-13.424404,-14.690127,-9.431618,-16.262744,-9.815396,-3.450302
3,-65.367983,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [411]:
# test the trained agent
test_episodes = 100
win_rate = nstep_sarsa_agent.evaluate_policy(episodes=test_episodes)
print(f'agent win rate: {win_rate*100 :.2f}% from {test_episodes} test episodes')

agent win rate: 99.00% from 100 test episodes


# 3. SARSAmax (Q-Learning)

In [412]:
class QLearning(NStepSARSA):
    n_step = 1

    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        super().__init__(env, self.n_step, alpha, gamma, epsilon)

    def update_Q(self, states, actions, rewards):
        """
        Updates the Q-table using Q-Learning.
        """
        state, action, reward = states[0], actions[0], rewards[0]
        next_state = states[-1]
        
        # Update Q(s,a):= Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)]
        max_next_Q = np.max(self.Q[next_state])
        target = reward + self.gamma * max_next_Q
        self.Q[state, action] += self.alpha * (target - self.Q[state, action])
        
    def visualize_q_table(self, colored_states=[]):
        q_table = np.zeros(self.env_shape + (env.action_space.n,))
    
        for state in range(env.observation_space.n):
            row = state // self.env_shape[1]
            col = state % self.env_shape[1]
            q_table[row, col] = self.Q[state]

        # select highest Q-value from each state
        max_q_values = np.max(q_table, axis=2)
        df = pd.DataFrame(max_q_values)

        return df.style.apply(self._style_trajectory_cells, axis=None, colored_states=colored_states)

In [413]:
qlearning_agent = QLearning(env)
qlearning_agent.train(episodes=1000)
qlearning_agent.visualize_q_table()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,-10.886496,-10.4535,-9.843256,-9.197979,-8.488823,-7.725592,-6.968166,-6.181586,-5.376702,-4.555882,-3.736025,-2.931566
1,-11.278456,-10.743044,-10.06321,-9.327267,-8.487419,-7.626441,-6.74054,-5.82364,-4.886025,-3.934385,-2.968254,-1.989997
2,-11.361513,-10.466175,-9.561792,-8.648275,-7.725531,-6.793465,-5.851985,-4.900995,-3.940399,-2.9701,-1.99,-1.0
3,-24.300563,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [414]:
# test the trained agent
test_episodes = 100
win_rate = qlearning_agent.evaluate_policy(episodes=test_episodes)
print(f'agent win rate: {win_rate*100 :.2f}% from {test_episodes} test episodes')

agent win rate: 72.00% from 100 test episodes


# 4. Expected SARSA

In [415]:
class ExpectedSARSA(NStepSARSA):
    n_step = 1

    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        super().__init__(env, self.n_step, alpha, gamma, epsilon)

    def update_Q(self, states, actions, rewards):
        """
        Updates the Q-table using Expected SARSA.
        """
        state, action, reward = states[0], actions[0], rewards[0]
        next_state = states[-1]

        # calculate action probabilities based on  epsilon-greedy policy
        action_probs = np.ones(self.env.action_space.n) * self.epsilon / self.env.action_space.n
        action_probs[np.argmax(self.Q[next_state])] += 1 - self.epsilon

        # calculate the expected Q-value of the next state-action pair (apply epsilon-greedy adjustment on the action probabilities)
        expected_next_Q = np.sum(self.Q[next_state] * action_probs)

        # update Q-Table
        target = reward + self.gamma * expected_next_Q
        self.Q[state, action] += self.alpha * (target - self.Q[state, action])

    def visualize_q_table(self, colored_states=[]):
        q_table = np.zeros(self.env_shape + (env.action_space.n,))
    
        # calculate weighted average Q-value from each state and assign to the new Q-table for visualization
        for state in range(env.observation_space.n):
            row = state // self.env_shape[1]
            col = state % self.env_shape[1]
            
            # calculate action probabilities based on  epsilon-greedy policy
            action_probs = np.ones(self.env.action_space.n) * self.epsilon / self.env.action_space.n
            action_probs[np.argmax(self.Q[state])] += 1 - self.epsilon
            weighted_q_value = np.sum(self.Q[state] * action_probs)
            q_table[row, col] = [weighted_q_value] * self.env.action_space.n

        # reshape the Q-table to 2D for visualization
        weighted_q_values = np.mean(q_table, axis=2)
        df = pd.DataFrame(weighted_q_values)

        return df.style.apply(self._style_trajectory_cells, axis=None, colored_states=colored_states)

In [432]:
expected_sarsa_agent = ExpectedSARSA(env)
expected_sarsa_agent.train(episodes=1000)
expected_sarsa_agent.visualize_q_table()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,-12.582266,-11.951465,-11.260097,-10.483311,-9.645045,-8.782875,-7.900686,-7.004885,-6.0911,-5.150369,-4.18981,-3.213486
1,-13.1458,-12.235151,-11.299332,-10.338089,-9.351899,-8.376383,-7.388588,-6.384692,-5.37408,-4.332993,-3.289128,-2.229553
2,-14.264012,-13.178698,-12.084342,-11.002192,-9.855299,-8.569511,-7.55981,-6.342542,-5.093924,-3.842243,-2.501901,-1.135267
3,-22.950457,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [433]:
# test the trained agent
test_episodes = 100
win_rate = expected_sarsa_agent.evaluate_policy(episodes=test_episodes)
print(f'agent win rate: {win_rate*100 :.2f}% from {test_episodes} test episodes')

agent win rate: 94.00% from 100 test episodes


# 5. Algorithm visualization and comparision

In [429]:
# N-step SARSA
nstep_sarsa_episode, offcliff_info = nstep_sarsa_agent.generate_sample_episode(always_greedy=True)
states, _, _ = zip(*nstep_sarsa_episode)
df_nstep_sarsa = nstep_sarsa_agent.visualize_q_table(colored_states=states)
df_nstep_sarsa

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,-20.577396,-19.105447,-17.74151,-16.874523,-15.71391,-14.523301,-13.184923,-12.18704,-10.863714,-8.83143,-6.912758,-6.325851
1,-20.827237,-20.789518,-26.68929,-20.301294,-17.179361,-16.41113,-14.283513,-13.173182,-9.092132,-10.344381,-6.235996,-5.532813
2,-23.886794,-34.940476,-26.80337,-19.624515,-18.059058,-23.349164,-13.424404,-14.690127,-9.431618,-16.262744,-9.815396,-3.450302
3,-65.367983,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [430]:
# Q-Learning
qlearning_episode, offcliff_info = qlearning_agent.generate_sample_episode(always_greedy=True)
states, _, _ = zip(*qlearning_episode)
df_qlearning = qlearning_agent.visualize_q_table(colored_states=states)
df_qlearning

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,-10.886496,-10.4535,-9.843256,-9.197979,-8.488823,-7.725592,-6.968166,-6.181586,-5.376702,-4.555882,-3.736025,-2.931566
1,-11.278456,-10.743044,-10.06321,-9.327267,-8.487419,-7.626441,-6.74054,-5.82364,-4.886025,-3.934385,-2.968254,-1.989997
2,-11.361513,-10.466175,-9.561792,-8.648275,-7.725531,-6.793465,-5.851985,-4.900995,-3.940399,-2.9701,-1.99,-1.0
3,-24.300563,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [434]:
# Expected SARSA
expected_sarsa_episode, offcliff_info = expected_sarsa_agent.generate_sample_episode(always_greedy=True)
states, _, _ = zip(*expected_sarsa_episode)
df_expected_sarsa = expected_sarsa_agent.visualize_q_table(colored_states=states)
df_expected_sarsa

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,-12.582266,-11.951465,-11.260097,-10.483311,-9.645045,-8.782875,-7.900686,-7.004885,-6.0911,-5.150369,-4.18981,-3.213486
1,-13.1458,-12.235151,-11.299332,-10.338089,-9.351899,-8.376383,-7.388588,-6.384692,-5.37408,-4.332993,-3.289128,-2.229553
2,-14.264012,-13.178698,-12.084342,-11.002192,-9.855299,-8.569511,-7.55981,-6.342542,-5.093924,-3.842243,-2.501901,-1.135267
3,-22.950457,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
