# Workshop 5: Model-Based Reinforcement Learning


**[CE716: Reinforcement Learning](https://deeprlcourse.github.io/)**

__Course Instructor__: Dr. Mohammad Hossein Rohban

__Notebook Author__: Ramtin Moslemi

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DeepRLCourse/Workshop-5-Material/blob/main/DynaQ.ipynb)
[![Open In kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/DeepRLCourse/Workshop-5-Material/main/DynaQ.ipynb)

---
## Notebook Objectives

Welcome to this workshop on **Reinforcement Learning (RL)**!
This notebook will help you learn core RL concepts and tools interactively by running code cells. Let's get started!

In [1]:
# @title Installations

! pip install minigrid --quiet

In [2]:
# @title Imports

import gymnasium as gym
from collections import defaultdict
from tqdm.notebook import trange

import logging
import base64
import json
import imageio
import IPython

import matplotlib
from matplotlib import pyplot as plt
import matplotlib.patches as patches
import seaborn as sns

In [3]:
# @title Helper functions

# Raise the root logger's level to ERROR so lower-severity log records are dropped.
logging.getLogger().setLevel(level=logging.ERROR)

# An inline matplotlib backend indicates we are running inside IPython/Jupyter.
is_ipython = matplotlib.get_backend().find('inline') != -1
if is_ipython:
    from IPython import display


def embed_mp4(filename):
    """Return an IPython HTML object that plays the mp4 at ``filename`` inline.

    The whole file is read and inlined as a base64 ``data:`` URI inside a
    640x480 ``<video>`` tag.
    """
    # Use a context manager so the file handle is closed promptly.
    with open(filename, 'rb') as f:
        video = f.read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())
    return IPython.display.HTML(tag)

def embed_gif(filename):
    """Return an IPython HTML object that shows the GIF at ``filename`` inline.

    The whole file is read and inlined as a base64 ``data:`` URI inside a
    640x480 ``<img>`` tag.
    """
    # Use a context manager so the file handle is closed promptly.
    with open(filename, 'rb') as f:
        gif = f.read()
    b64 = base64.b64encode(gif)
    tag = '''
    <img src="data:image/gif;base64,{0}" width="640" height="480">
    '''.format(b64.decode())
    return IPython.display.HTML(tag)

def create_policy_eval_video(env, agent, filename, format='mp4', max_steps=1000, num_episodes=1, fps=8):
    """Roll out ``agent`` in ``env``, record the frames and return them embedded as HTML.

    Args:
        env: environment whose ``render()`` returns an image frame
            (e.g. created with ``render_mode='rgb_array'``).
        agent: object exposing ``act(state) -> action``.
        filename: output path without extension; ``.mp4`` or ``.gif`` is appended.
        format: ``'mp4'`` or ``'gif'``.
        max_steps: maximum number of environment steps per episode.
        num_episodes: number of episodes recorded back to back into one file.
        fps: playback frame rate.

    Returns:
        The HTML object produced by ``embed_mp4`` / ``embed_gif``.

    Raises:
        ValueError: if ``format`` is neither ``'mp4'`` nor ``'gif'``.
    """
    if format == 'mp4':
        filename = filename + ".mp4"
        writer_kwargs = {'fps': fps}
        embed = embed_mp4
    elif format == 'gif':
        filename = filename + ".gif"
        writer_kwargs = {'mode': 'I', 'duration': 1 / fps}
        embed = embed_gif
    else:
        raise ValueError("Unsupported format. Use 'mp4' or 'gif'.")

    with imageio.get_writer(filename, **writer_kwargs) as writer:
        for _ in range(num_episodes):
            _record_episode(env, agent, writer, max_steps)
    return embed(filename)


def _record_episode(env, agent, writer, max_steps):
    """Append the initial frame and the frame after each step (at most ``max_steps``) to ``writer``."""
    state, _ = env.reset()
    writer.append_data(env.render())
    steps = 0
    # `steps < max_steps` caps the episode at exactly max_steps environment steps.
    while steps < max_steps:
        action = agent.act(state)
        state, _, terminated, truncated, _ = env.step(action)
        writer.append_data(env.render())
        steps += 1
        if terminated or truncated:
            break


def plot_state_action_values(value, ax=None, show=False):
    """Plot each action's value against the state index, one dashed line per action.

    Args:
        value: 2-D Q-table indexed ``value[state, action]``.
        ax: axes to draw on; a new figure is created when None.
        show: call ``plt.show()`` after drawing.
    """
    if ax is None:
        fig, ax = plt.subplots()

    # Derive sizes from the table instead of hard-coding 54 states x 4 actions.
    n_states, n_actions = value.shape
    for a in range(n_actions):
        ax.plot(range(n_states), value[:, a], marker='o', linestyle='--')
    ax.set(xlabel='States', ylabel='Values')
    # Labels in action-id order: 0 = up, 1 = right, 2 = down, 3 = left.
    ax.legend(['Up', 'Right', 'Down', 'Left'][:n_actions], loc='upper right')
    if show:
        plt.show()


def plot_quiver_max_action(value, ax=None, show=False, dim_x=4, dim_y=4):
    """Draw an arrow in every grid cell pointing along that cell's greedy action.

    Args:
        value: Q-table with ``dim_x * dim_y`` rows; row ``r * dim_x + c`` is the
            cell in grid row ``r`` (row 0 drawn at the top) and column ``c``.
        ax: axes to draw on; a new figure is created when None.
        show: call ``plt.show()`` after drawing.
        dim_x: number of grid columns (default 4).
        dim_y: number of grid rows (default 4).

    Arrow directions follow this notebook's action ids:
    0 = up, 1 = right, 2 = down, 3 = left.
    """
    if ax is None:
        fig, ax = plt.subplots()
    # Cell centres in data coordinates (y grows upward).
    X = np.tile(np.arange(dim_x), [dim_y, 1]) + 0.5
    Y = np.tile(np.arange(dim_y)[:, np.newaxis], [1, dim_x]) + 0.5
    which_max = np.reshape(value.argmax(axis=1), (dim_y, dim_x))
    # Flip rows so grid row 0 is drawn at the top of the plot.
    which_max = which_max[::-1, :]
    U = np.zeros(X.shape)
    V = np.zeros(X.shape)
    V[which_max == 0] = 1    # up
    U[which_max == 1] = 1    # right
    V[which_max == 2] = -1   # down
    U[which_max == 3] = -1   # left

    ax.quiver(X, Y, U, V)
    ax.set(title='Maximum value/probability actions',
           xlim=[-0.5, dim_x + 0.5], ylim=[-0.5, dim_y + 0.5])

    ax.set_xticks(np.linspace(0.5, dim_x - 0.5, num=dim_x))
    ax.set_xticklabels(["%d" % x for x in np.arange(dim_x)])
    ax.set_xticks(np.arange(dim_x + 1), minor=True)
    ax.set_yticks(np.linspace(0.5, dim_y - 0.5, num=dim_y))
    ax.set_yticklabels(["%d" % y for y in np.arange(dim_y)[::-1]])
    ax.set_yticks(np.arange(dim_y + 1), minor=True)
    ax.grid(which='minor', linestyle='-')
    if show:
        plt.show()

def plot_heatmap(value, color_terminal_states=True, ax=None, show=False):
    """Heatmap of max_a Q(s, a) over the 6x9 Dyna Maze, annotated with greedy actions.

    Each cell is labelled with an arrow for its argmax action, 'W' for walls and
    'G' for the goal. When ``color_terminal_states`` is True, wall cells are
    painted black and the goal cell purple on top of the heatmap.

    Args:
        value: Q-table with 54 rows (row-major over 6 rows x 9 columns) and one
            column per action.
        color_terminal_states: overlay black/purple patches on walls/goal.
        ax: axes to draw on; a new figure is created when None.
        show: call ``plt.show()`` after drawing.

    Returns:
        The Axes returned by ``sns.heatmap``.
    """
    if ax is None:
        fig, ax = plt.subplots()
    dim_x, dim_y = 9, 6
    action_max = value.argmax(axis=1)
    value_max = value.max(axis=1).reshape(dim_y, dim_x)
    # Arrow for each action id: 0 = up, 1 = right, 2 = down, 3 = left.
    act_dict = {0: '↑', 1: '→', 2: '↓', 3: '←'}
    # (row, col) of the interior walls and of the goal cell.
    walls = {(1, 2), (2, 2), (3, 2), (4, 5), (0, 7), (1, 7), (2, 7)}
    goal = (0, 8)
    labels = np.array([act_dict.get(action, '') for action in action_max])
    for i in range(dim_x * dim_y):
        cell = divmod(i, dim_x)  # (row, col)
        if cell in walls:
            labels[i] = 'W'
        if cell == goal:
            labels[i] = 'G'
    labels = labels.reshape(dim_y, dim_x)
    # Pass ax explicitly; otherwise seaborn draws on plt.gca(), which need not be `ax`.
    im = sns.heatmap(value_max, cmap="RdYlGn", annot=labels,
                     annot_kws={'fontsize': 16}, fmt='s', ax=ax)

    if color_terminal_states:
        # Rectangle anchors are (x, y) = (col, row) in heatmap cell units.
        for row, col in walls:
            ax.add_patch(plt.Rectangle((col, row), 1, 1, color='black'))
        ax.add_patch(plt.Rectangle((goal[1], goal[0]), 1, 1, color='purple'))

    ax.set(title='Maximum value per state')
    ax.set_xticks(np.linspace(0.5, dim_x - 0.5, num=dim_x))
    ax.set_xticklabels(["%d" % x for x in np.arange(dim_x)])
    ax.set_yticks(np.linspace(0.5, dim_y - 0.5, num=dim_y))
    ax.set_yticklabels(["%d" % y for y in np.arange(dim_y)], rotation='horizontal')
    if show:
        plt.show()
    return im


def plot_rewards(rewards, ax=None, show=False):
    """Draw the total reward of each episode as a dashed line with circle markers.

    Args:
        rewards: sequence of per-episode returns.
        ax: axes to draw on; a fresh figure is created when None.
        show: call ``plt.show()`` once drawing is done.
    """
    target = ax if ax is not None else plt.subplots()[1]
    target.plot(rewards, marker='o', linestyle='--')
    target.set(xlabel='Episodes', ylabel='Total reward')
    if show:
        plt.show()


def get_color_for_value(value, vmin, vmax):
    """Map ``value`` from the range [vmin, vmax] to an RGBA colour of the RdYlGn colormap.

    A degenerate range (vmin == vmax, e.g. an all-zero Q-table) maps to 0,
    which mirrors ``matplotlib.colors.Normalize``, instead of dividing by zero.
    """
    span = vmax - vmin
    # Normalize the value between 0 and 1
    norm_value = (value - vmin) / span if span != 0 else 0.0
    colormap = plt.cm.RdYlGn
    return colormap(norm_value)


def plot_q_values_grid(q_values, square_size=1, color_terminal_states=True, ax=None, show=False):
    """Draw the 6x9 Dyna Maze with every cell split into four triangles coloured by Q(s, a).

    The triangle touching a cell's top/right/bottom/left edge shows the value of
    action 0/1/2/3 (up/right/down/left). The y-axis is inverted so maze row 0 is
    at the top. When ``color_terminal_states`` is True, wall cells are filled
    black and the goal cell purple. A colorbar shows the shared Q-value scale.

    Args:
        q_values: array with 6 * 9 * 4 entries, reshaped to (rows, cols, actions).
        square_size: side length of one cell in data units.
        color_terminal_states: paint walls/goal instead of their Q-values.
        ax: axes to draw on; a new figure is created when None.
        show: call ``plt.show()`` after drawing.
    """
    rows, cols = 6, 9
    q_values = q_values.reshape(rows, cols, 4)
    # (row, col) of the interior walls and of the goal cell.
    walls = {(1, 2), (2, 2), (3, 2), (4, 5), (0, 7), (1, 7), (2, 7)}
    goal = (0, 8)

    if ax is None:
        _, ax = plt.subplots()

    # One normaliser for both the triangles and the colorbar. plt.Normalize maps
    # every value to 0 when vmin == vmax (e.g. an untrained all-zero table), so
    # there is no division by zero.
    vmin = np.min(q_values)
    vmax = np.max(q_values)
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    cmap = plt.cm.RdYlGn

    half = square_size / 2
    for i in range(rows):
        for j in range(cols):
            x = j * square_size
            y = i * square_size
            # Because the y-axis is inverted below, y is the top edge on screen
            # and y + square_size the bottom edge.
            top_left, top_right = (x, y), (x + square_size, y)
            bottom_left = (x, y + square_size)
            bottom_right = (x + square_size, y + square_size)
            center = (x + half, y + half)
            # One triangle per action id: up, right, down, left.
            triangles = [
                [top_left, top_right, center],
                [top_right, bottom_right, center],
                [bottom_left, bottom_right, center],
                [top_left, bottom_left, center],
            ]

            if color_terminal_states and (i, j) in walls:
                edge_color, face_colors = None, ['black'] * 4
            elif color_terminal_states and (i, j) == goal:
                edge_color, face_colors = None, ['purple'] * 4
            else:
                edge_color = 'black'
                face_colors = [cmap(norm(q)) for q in q_values[i, j]]

            for vertices, face_color in zip(triangles, face_colors):
                ax.add_patch(patches.Polygon(vertices, closed=True,
                                             edgecolor=edge_color, facecolor=face_color))

    ax.set_xlim(0, cols * square_size)
    ax.set_ylim(0, rows * square_size)
    ax.set_aspect('equal')
    ax.set_xticks([])
    ax.set_yticks([])
    # Put maze row 0 at the top.
    ax.invert_yaxis()
    plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='Q-value')
    if show:
        plt.show()


def plot_performance(value, reward_sums):
    """Show a 2x2 figure summarising a Q-table and its training curve.

    Top-left: per-state action values; top-right: Q-value triangles per cell;
    bottom-right: greedy-value heatmap; bottom-left: total reward per episode.
    """
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(24, 12), dpi=100)
    plot_state_action_values(value, ax=axes[0, 0])
    plot_q_values_grid(value, ax=axes[0, 1])
    plot_heatmap(value, ax=axes[1, 1])
    plot_rewards(reward_sums, ax=axes[1, 0])
    # pyplot.show() takes no figure argument; it displays all open figures.
    plt.show()

In [4]:
# @title Simple Dyna Maze Environment

from typing import Any, SupportsFloat
import numpy as np
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal
from minigrid.minigrid_env import MiniGridEnv
from gymnasium.core import ObsType, ActType
from gymnasium import spaces

# Size of the playable maze interior in cells; DynaMaze surrounds it with a
# one-cell wall border.
WORLD_HEIGHT = 6
WORLD_WIDTH = 9

# Start and goal cells in MiniGrid (x, y) coordinates. The interior spans
# x in [1, WORLD_WIDTH] and y in [1, WORLD_HEIGHT].
START = (1, 3)
GOAL = (9, 1)

# Reward returned by every step ("living reward").
REWARD = -1

# Action ids understood by DynaMaze.step.
ACTION_UP = 0
ACTION_RIGHT = 1
ACTION_DOWN = 2
ACTION_LEFT = 3


class DynaMaze(MiniGridEnv):
    """Dyna Maze (Sutton & Barto, Example 8.1) built on MiniGrid.

    The interior is WORLD_WIDTH x WORLD_HEIGHT cells surrounded by a one-cell
    wall border, so interior cells have MiniGrid coordinates
    x in [1, WORLD_WIDTH], y in [1, WORLD_HEIGHT]. Observations are the integer
    cell index ``(x - 1) + (y - 1) * WORLD_WIDTH``. Every step returns REWARD,
    and the episode terminates when the agent stands on GOAL.

    Actions: ACTION_UP (0), ACTION_RIGHT (1), ACTION_DOWN (2), ACTION_LEFT (3).
    Each action turns the agent to face that direction and moves it one cell
    forward if that cell is inside the grid and empty or overlappable.
    """

    metadata = {'render_fps': 4}

    def __init__(self, **kwargs):
        """Build the maze; extra keyword arguments (e.g. ``render_mode``) go to MiniGridEnv."""
        self.agent_start_pos = START
        self.agent_start_dir = 0  # MiniGrid direction 0 = facing right
        self.goal = GOAL
        self.reward = REWARD
        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            # Interior plus a one-cell wall border on every side.
            width=WORLD_WIDTH + 2,
            height=WORLD_HEIGHT + 2,
            # Set this to True for maximum speed
            see_through_walls=True,
            highlight=False,
            **kwargs,
        )

        # Override the action space with the four movement actions.
        self.action_space = spaces.Discrete(4)

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None,
              ) -> tuple[ObsType, dict[str, Any]]:
        """Reset the underlying MiniGrid episode and return ``(cell_index, {})``."""
        # Forward seed/options so callers can seed the environment through this override.
        super().reset(seed=seed, options=options)
        return self.observe(), {}

    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Face the direction given by ``action`` and try to move one cell forward.

        Returns ``(cell_index, REWARD, terminated, False, {})``, where
        ``terminated`` is True once the agent is on GOAL. This override does not
        call ``MiniGridEnv.step``, so no step-limit truncation from the parent
        is applied and ``truncated`` is always False.

        Raises:
            ValueError: if ``action`` is not one of the four ACTION_* ids.
        """
        self.step_count += 1
        terminated = False
        truncated = False

        # MiniGrid directions: 0 = right, 1 = down, 2 = left, 3 = up.
        if action == ACTION_UP:
            self.agent_dir = 3
        elif action == ACTION_RIGHT:
            self.agent_dir = 0
        elif action == ACTION_DOWN:
            self.agent_dir = 1
        elif action == ACTION_LEFT:
            self.agent_dir = 2
        else:
            raise ValueError(f"Invalid action: {action!r}")

        # Get the position in front of the agent and move there if it is passable.
        i, j = self.front_pos

        if 0 <= i < self.width and 0 <= j < self.height:
            fwd_cell = self.grid.get(i, j)
            if fwd_cell is None or fwd_cell.can_overlap():
                self.agent_pos = (i, j)

        if self.agent_pos == self.goal:
            terminated = True

        if self.render_mode == "human":
            self.render()

        return self.observe(), self.reward, terminated, truncated, {}

    def observe(self) -> int:
        """Return the agent's interior cell index, row-major with (1, 1) -> 0."""
        return self.agent_pos[0] - 1 + (self.agent_pos[1] - 1) * (self.width - 2)

    @staticmethod
    def _gen_mission():
        return "get to the green goal square"

    def _gen_grid(self, width=WORLD_WIDTH, height=WORLD_HEIGHT):
        """Build the walls, goal and agent start for a new episode."""
        # Create an empty grid
        self.grid = Grid(width, height)

        # Surrounding walls: left, top, right and bottom border.
        self.grid.wall_rect(0, 0, 1, WORLD_HEIGHT + 2)
        self.grid.wall_rect(0, 0, WORLD_WIDTH + 2, 1)
        self.grid.wall_rect(WORLD_WIDTH + 1, 0, 1, WORLD_HEIGHT + 2)
        self.grid.wall_rect(0, WORLD_HEIGHT + 1, WORLD_WIDTH + 2, 1)
        # Interior walls of the Dyna Maze.
        self.grid.wall_rect(3, 2, 1, 3)
        self.grid.wall_rect(8, 1, 1, 3)
        self.grid.wall_rect(6, 5, 1, 1)

        # Place the goal square at GOAL (top-right cell of the interior).
        self.put_obj(Goal(), *self.goal)

        # Place the agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()



# Dyna: Integrated Planning, Acting, and Learning

## The Environment

Here we learn how adding a model and planning can help Reinforcement Learning.
The environment we are going to use is based on the `Example 8.1: Dyna Maze` taken from [Reinforcement Learning: An Introduction](http://incompleteideas.net/book/the-book.html) by [Richard S. Sutton](http://incompleteideas.net/index.html)
and [Andrew G. Barto](http://www-anw.cs.umass.edu/%7Ebarto/) with a few minor changes.
The code for the environment is implemented using [MiniGrid](https://minigrid.farama.org/) and you can see how a random agent does in this environment by running the next cell.

In [5]:
# @title Random Policy

class RandomAgent:
    """Baseline agent that ignores the state and picks one of four actions uniformly."""

    def act(self, state, greedy=None):
        """Return a uniformly random action id in {0, 1, 2, 3}; ``state`` and ``greedy`` are unused."""
        return np.random.randint(0, 4)


# Roll out the random policy in a renderable maze and embed the result as a video.
random_agent = RandomAgent()
env = DynaMaze(render_mode='rgb_array')
create_policy_eval_video(env, random_agent, 'random_policy', fps=32)

The agent receives a reward of `-1` at every step until it reaches the goal, a terminal state that ends the episode.
Although this environment is *deterministic*, we write our solution as if it were *stochastic* so that you can see a more general implementation.
Feel free to implement the deterministic version as an exercise.

## Dyna-Q

In [6]:
# @title The ɛ-greedy policy

def epsilon_greedy_policy(state: int, q_values: np.ndarray, epsilon: float) -> int:
    """Choose an action for ``state`` epsilon-greedily from the table ``q_values``.

    One uniform draw u in [0, 1) is made. If u > epsilon, the greedy action
    argmax_a q_values[state, a] is returned (lowest index on ties). Otherwise an
    action is sampled uniformly from range(q_values.shape[1]).
    """
    if np.random.rand() > epsilon:
        return np.argmax(q_values[state])
    n_actions = q_values.shape[1]
    return np.random.choice(n_actions)

The planning procedure is as follows:


1.   Select a state, $S \in \mathcal{S}$, and an action, $A \in \mathcal{A}(S)$, at random
2.   Send $S, A$ to a sample model, and obtain a sample next reward, $R$, and a sample next state, $S'$
3.  Apply one-step tabular Q-learning to $S, A, R, S'$:
$$Q(S,A) \leftarrow Q(S,A) + \alpha [R + \gamma \max_a Q(S', a) - Q(S,A)]$$

The number of planning steps is controlled by `n`: after every real step, the steps above are repeated `n` times.


In [7]:
def q_planning(model: dict[int, dict[int, list[tuple[float, int, int]]]], q: np.ndarray,
               alpha: float, gamma: float, n: int) -> np.ndarray:
    """Perform ``n`` one-step Q-learning updates on transitions simulated from ``model``.

    Args:
        model: count-based sample model as built by ``update_model``;
            ``model[state][action]`` is a list of ``(reward, next_state, count)``.
        q: Q-table indexed ``q[state, action]``; updated in place.
        alpha: step size.
        gamma: discount factor.
        n: number of planning updates.

    Returns:
        ``q``, the same array, after the updates.
    """
    # With no recorded experience there is nothing to plan with
    # (np.random.choice would raise on an empty list).
    if not model:
        return q

    for _ in range(n):
        # Sample a state present in the model, then an action recorded for it.
        state = np.random.choice(list(model.keys()))
        action = np.random.choice(list(model[state].keys()))

        # Sample (reward, next_state) in proportion to how often it was observed.
        transitions = model[state][action]
        rewards, next_states, counts = zip(*transitions)
        total_counts = sum(counts)
        probabilities = [count / total_counts for count in counts]
        index = np.random.choice(len(transitions), p=probabilities)
        reward = rewards[index]
        next_state = next_states[index]

        # Update the Q-value using the simulated transition
        q[state, action] += alpha * (reward + gamma * np.max(q[next_state]) - q[state, action])

    return q

But we also need to learn the model itself!
To do this, we simply store every transition we encounter during real experience.
As some of these transitions will be identical (taking action $A$ at state $S$ generates reward $R$ and takes us to the next state $S'$), we can assign a counter for each transition and increment it based on real experience observations.

In [8]:
def update_model(model, state: int, action: int, reward: float, next_state: int):
    """Record one observed transition in the count-based model and return ``model``.

    ``model[state][action]`` is a list of ``(reward, next_state, count)``. If an
    entry with the same next state and reward exists, its count is incremented.
    Otherwise a new entry with count 1 is appended. The model is mutated in place.
    """
    transitions = model[state][action]
    match = next((idx for idx, (r, ns, _) in enumerate(transitions)
                  if ns == next_state and r == reward), None)
    if match is None:
        transitions.append((reward, next_state, 1))
    else:
        r, ns, count = transitions[match]
        transitions[match] = (r, ns, count + 1)
    return model

Finally, we can use the `q_planning` and `update_model` functions to implement the **Dyna-Q** algorithm!
This is quite similar to the classic **Q-Learning** algorithm, but after each action and the resulting observation, we also update our model of the environment and perform `n` steps of planning!

In [9]:
def dyna_q(n_episodes: int, env: gym.Env, epsilon: float, alpha: float,
           gamma: float, n: int, n_states: int = 6 * 9) -> tuple[np.ndarray, np.ndarray]:
    """Run tabular Dyna-Q and return the learned Q-table and per-episode returns.

    After every real step the agent does three things:
    (1) it applies a one-step Q-learning update,
    (2) it records the transition in a count-based model via ``update_model``, and
    (3) it performs ``n`` simulated updates with ``q_planning``.
    With ``n=0`` this reduces to plain Q-learning.

    Args:
        n_episodes: number of episodes to run.
        env: environment whose observations are integer state indices in
            ``range(n_states)`` and whose ``action_space`` has an ``n`` attribute.
        epsilon: exploration rate of the epsilon-greedy behaviour policy.
        alpha: step size.
        gamma: discount factor.
        n: planning updates per real step.
        n_states: size of the state space (default: the 6 x 9 Dyna Maze).

    Returns:
        ``(q, reward_sums)``: the (n_states, n_actions) Q-table and the
        undiscounted total reward of each episode.
    """
    # Store rewards per episode
    reward_sums = np.empty(n_episodes)

    # Start with a uniform (all-zero) value function
    q = np.zeros((n_states, env.action_space.n))

    # Count-based model: model[state][action] -> list of (reward, next_state, count)
    model = defaultdict(lambda: defaultdict(list))

    for episode_i in (pbar := trange(n_episodes, leave=False)):
        state, _ = env.reset()
        reward_sum, terminal = 0, False

        # Run until the environment reports terminated or truncated.
        while not terminal:
            action = epsilon_greedy_policy(state, q, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)

            # Direct RL update. No terminal special case is needed: the episode
            # ends on a terminal state, so it is never updated (here or in
            # planning) and its row of q stays 0.
            q[state, action] += alpha * (reward + gamma * np.max(q[next_state]) - q[state, action])

            # Model learning, then planning with simulated experience.
            model = update_model(model, state, action, reward, next_state)
            q = q_planning(model, q, alpha, gamma, n)

            state, terminal = next_state, terminated or truncated
            reward_sum += reward

        pbar.set_description(f'Episode Reward {int(reward_sum)}')
        reward_sums[episode_i] = reward_sum

    return q, reward_sums

## Results

You can do your own experimentation by running the next cell.
Setting `n=0` will give you the **Q-Learning** algorithm, but keep in mind you will need to train for more episodes to get good results.

In [10]:
# @markdown Experimentation Widget

import ipywidgets as widgets
from IPython.display import display

class GreedyAgent(object):
    """Agent that always takes the highest-valued action of a fixed Q-table."""

    def __init__(self, q):
        # Q-table indexed as q[state, action].
        self.q = q

    def act(self, state, greedy=None):
        """Return argmax_a q[state, a] (lowest index on ties); ``greedy`` is ignored."""
        row = self.q[state]
        best = np.argmax(row)
        return best

def run_experiment(n_episodes: int, epsilon: float, alpha: float,
                   gamma: float, n: int):
    """Train Dyna-Q on a fresh DynaMaze, plot diagnostics and return a greedy-policy video.

    The global NumPy RNG is reseeded with 716 first so that runs with the same
    hyper-parameters repeat. The video is limited to 100 steps.
    """
    np.random.seed(716)

    q_values, rewards = dyna_q(n_episodes, DynaMaze(), epsilon, alpha, gamma, n)
    plot_performance(q_values, rewards)

    video_name = f'Dyna-Q t={n_episodes} ɛ={epsilon} α={alpha} γ={gamma} n={n}'
    return create_policy_eval_video(DynaMaze(render_mode='rgb_array'), GreedyAgent(q_values),
                                    video_name, max_steps=100)

t = widgets.IntSlider(value=100, min=10, max=1000, step=10,
                      style={'description_width': 'initial'},
                      layout={'width': '500px'},
                      description='Number of Episodes:')

e = widgets.FloatSlider(value=0.1, min=0, max=1, step=0.05,
                        style={'description_width': 'initial'},
                        layout={'width': '500px'},
                        description='Exploration Rate ɛ:')

a = widgets.FloatSlider(value=0.1, min=0, max=1, step=0.05,
                        style={'description_width': 'initial'},
                        layout={'width': '500px'},
                        description='Step Size α:')

g = widgets.FloatSlider(value=0.95, min=0, max=1, step=0.05,
                        style={'description_width': 'initial'},
                        layout={'width': '500px'},
                        description='Discount Factor γ:')

n = widgets.IntSlider(value=30, min=0, max=100, step=5,
                      style={'description_width': 'initial'},
                      layout={'width': '500px'},
                      description='Planning Steps n:')

button = widgets.Button(description="Run Experiment!")
output = widgets.Output()

display(t, e, a, g, n, button, output)

def on_button_clicked(b):
    """Run one experiment with the current slider values, rendering into ``output``."""
    with output:
        plt.close()
        print(f'\n{t.value} Episodes ɛ={e.value}, α={a.value}, γ={g.value}, n={n.value}')
        # A callback's return value is not rendered automatically, so display the
        # policy video returned by run_experiment explicitly inside the Output widget.
        display(run_experiment(t.value, e.value, a.value, g.value, n.value))

button.on_click(on_button_clicked)

IntSlider(value=100, description='Number of Episodes:', layout=Layout(width='500px'), max=1000, min=10, step=1…

FloatSlider(value=0.1, description='Exploration Rate ɛ:', layout=Layout(width='500px'), max=1.0, step=0.05, st…

FloatSlider(value=0.1, description='Step Size α:', layout=Layout(width='500px'), max=1.0, step=0.05, style=Sli…

FloatSlider(value=0.95, description='Discount Factor γ:', layout=Layout(width='500px'), max=1.0, step=0.05, st…

IntSlider(value=30, description='Planning Steps n:', layout=Layout(width='500px'), step=5, style=SliderStyle(d…

Button(description='Run Experiment!', style=ButtonStyle())

Output()