In this lab you will complete an implementation of DQN (Deep Q-Network), a variant of the Q-learning method.

On top of standard Q-learning with a neural network as the function approximator, DQN uses:
* experience replay – used to make more efficient use of samples from the environment and to decorrelate the elements of a batch;
* target network – used to avoid constantly changing targets in the learning process (to avoid "chasing your own tail").

For the algorithm's details recall the lecture and/or follow the [original paper](https://arxiv.org/abs/1312.5602), which is rather self-contained and not hard to understand.
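Concretely, for a transition $(s_j, a_j, r_j, s_{j+1})$ sampled from the replay buffer $\mathcal{D}$, the network $Q_\theta$ is regressed onto a target computed with the target network's parameters $\theta^-$. This is the form the trainer below implements, with `F.mse_loss` as the squared error; the $\theta^-$ notation follows the later Nature version of DQN rather than the arXiv paper, and no gradient flows through $y_j$:

$$
y_j = \begin{cases} r_j & \text{if } s_{j+1} \text{ is terminal,} \\ r_j + \gamma \max_{a'} Q_{\theta^-}(s_{j+1}, a') & \text{otherwise,} \end{cases}
\qquad
\mathcal{L}(\theta) = \mathbb{E}_{(s_j, a_j, r_j, s_{j+1}) \sim \mathcal{D}}\Big[\big(y_j - Q_\theta(s_j, a_j)\big)^2\Big]
$$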

Without changing any hyperparameters, the agent should solve the problem (obtain rewards \~200) after \~500 episodes (\~13 minutes of training on Colab CPU, not any faster on GPU).

## Tasks

1.   Implement missing code #### TODO IMPLEMENT #####
2.   Experiment with the hyperparameters, e.g. gamma (the discount factor for accumulated reward) and epsilon (the exploration-exploitation trade-off in "epsilon-greedy").
3.   Observe weird behaviors of agents, e.g. "forgetting how to play" (reward going down a lot and then "re-learning" again). Why can that happen? What can we do to avoid it?
4.   Change the arguments and observe the trained model's behavior. What do you see?
5.   What can be improved in the training code?

## Imports
(Some extra packages are needed even on Colab. For some reason `box2d` must be installed in a separate step after `swig`; it takes about a minute to build.)

In [3]:
!pip install 'swig==4.4.*'
!pip install 'gymnasium[box2d]==1.2.*'



In [4]:
import operator
import time
import random
import warnings
from abc import ABC
from collections.abc import Sequence
from copy import deepcopy
from functools import reduce
from pathlib import Path
from typing import Literal, NamedTuple, cast

import cv2
import IPython.display as display
import ipywidgets as widgets
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import gymnasium as gym

# Suppress a warning from pygame: https://github.com/pygame/pygame/issues/4313
warnings.filterwarnings("ignore", message="(?s:.)*pkg_resources is deprecated")

%load_ext tensorboard

In [5]:
# Check for CUDA / MPS (Apple) / XPU (Intel) / ... accelerator.
device = (
    torch.accelerator.current_accelerator(check_available=True)
    or torch.device("cpu")  #
)
use_accel = device != torch.device("cpu")
print(use_accel, device)

True cuda


In [6]:
# Caps the episode length to speed up training in Colab and to cut off episodes where the model gets stuck.
# But this can be problematic for our model (why?)
# Must be a positive integer; use None to fall back to the environment's registered default limit.
MAX_EPISODE_STEPS = 500

EXP_NAME = "LunarLander"
LOG_DIR = Path("./runs/") / EXP_NAME
TENSORBOARD_LOG_DIR = LOG_DIR / "tensorboard"
CHECKPOINTS_DIR = LOG_DIR / "checkpoints"
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
TENSORBOARD_LOG_DIR.mkdir(parents=True, exist_ok=True)

## Environment

We will try to solve: https://gymnasium.farama.org/environments/box2d/lunar_lander/

The LunarLander environment is considered solved once an episode scores at least 200 points.

![Lunar Lander example GIF](https://gymnasium.farama.org/_images/lunar_lander.gif)

In [7]:
env = gym.make(
    "LunarLander-v3", render_mode="rgb_array", max_episode_steps=MAX_EPISODE_STEPS
)

# Observations are 8-dimensional vectors: (x, y, vx, vy, theta, vtheta, left_leg, right_leg),
# where left_leg, right_leg are booleans: 1 if leg has contact, 0 otherwise.
OBS_SHAPE = cast(tuple[int, ...], env.observation_space.shape)

# Actions are: 0=do nothing, 1=fire left engine, 2=fire main engine, 3=fire right engine.
assert isinstance(env.action_space, gym.spaces.Discrete)
N_ACTIONS = int(env.action_space.n)

print(f"{OBS_SHAPE=} {N_ACTIONS=}")

OBS_SHAPE=(8,) N_ACTIONS=4



## Scheduler
(for the exploration rate $\epsilon$ in "$\epsilon$-greedy")

Training RL agents requires dealing with the exploration-exploitation trade-off. To handle it we will adopt the most basic, but surprisingly effective, epsilon-greedy strategy. At the beginning our agent will focus on exploration, and over time it will increasingly exploit its knowledge, becoming more and more greedy. To implement this we will use `LinearRateScheduler` or `ExponentialRateScheduler`, which decay $\epsilon$ after every step or after every episode.
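A quick back-of-the-envelope check helps when tuning the exponential schedule. The pure-Python sketch below mirrors `ExponentialRateScheduler._update` (the helper names are hypothetical) and counts how many updates it takes to reach `final`, using the settings from the "Run training" cell.

```python
def updates_until_final(initial: float, final: float, decay: float) -> int:
    """Count updates of eps <- max(eps * decay, final) until eps equals final."""
    eps, n = initial, 0
    while eps > final:
        eps = max(eps * decay, final)
        n += 1
    return n


def eps_after(initial: float, final: float, decay: float, n: int) -> float:
    """Closed form of the exponential schedule after n updates."""
    return max(initial * decay**n, final)


# initial=1.0, final=0.01, decay=0.995, one update per episode.
print(updates_until_final(1.0, 0.01, 0.995))  # 919
print(round(eps_after(1.0, 0.01, 0.995, 500), 3))  # 0.082
```

With these settings and 500 episodes, epsilon only decays to roughly 0.08, so `final=0.01` is never reached during training.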

In [10]:
class ExplorationRateScheduler(ABC):
    def __init__(self, initial: float, on: Literal["step", "episode"]) -> None:
        self._value = initial
        self.on = on

    def value(self) -> float:
        """Returns the cur rent exploration rate (epsilon) without updating it."""
        return self._value

    def _update(self) -> None:
        """Updates the exploration rate (epsilon)."""
        raise NotImplementedError

    def next_step(self) -> float:
        """Updates the exploration rate (epsilon) and returns the new value."""
        if self.on == "step":
            self._update()
        return self.value()

    def next_episode(self) -> float:
        """Updates the exploration rate (epsilon) and returns the new value."""
        if self.on == "episode":
            self._update()
        return self.value()


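class ConstantRateScheduler(ExplorationRateScheduler):
    """Constant scheduler (use epsilon=0 for greedy policy)."""

    def __init__(self, epsilon: float) -> None:
        super().__init__(initial=epsilon, on="step")

    def _update(self) -> None:
        """Epsilon never changes (otherwise next_step() would raise NotImplementedError)."""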


class LinearRateScheduler(ExplorationRateScheduler):
    """Linear scheduler (step() decreases value by `rate`, until final is reached)."""

    def __init__(
        self,
        *,
        initial: float,
        final: float,
        rate: float,
        on: Literal["step", "episode"],
    ) -> None:
        super().__init__(initial=initial, on=on)
        self.final = final
        self.rate = rate

    def _update(self) -> None:
        self._value = max(self._value - self.rate, self.final)


class ExponentialRateScheduler(ExplorationRateScheduler):
    """Exponential scheduler (step() multiplies value by `decay`, until final is reached)."""

    def __init__(
        self,
        *,
        initial: float,
        final: float,
        decay: float,
        on: Literal["step", "episode"],
    ) -> None:
        super().__init__(initial=initial, on=on)
        self.final = final
        self.decay = decay

    def _update(self) -> None:
        self._value = max(self._value * self.decay, self.final)

## Replay buffer

The key trick that makes DQN feasible is the replay buffer. The idea is to store observed transitions, sample them randomly, and perform updates on these samples. This has many advantages; the most significant are:

1.   *Data efficiency* – each transition (env step) can be used in many weight updates.
2.   *Data decorrelation* – consecutive transitions are naturally highly correlated. Randomizing the samples reduces these correlations, thus reducing variance of the updates.

Note that learning from experience replay requires an off-policy method (the current parameters differ from those used to generate the stored samples), which motivates the choice of Q-learning.
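The FIFO eviction that a replay buffer needs can also be expressed with `collections.deque(maxlen=...)`, which drops the oldest item automatically. The sketch below is a hypothetical alternative to the `ReplayBuffer` class, shown with integers standing in for transitions. A deque is only O(1) to index near its ends, so a plain list with a write index (as in the notebook) keeps random access cheap for large buffers.

```python
import random
from collections import deque


class DequeReplayBuffer:
    """Sketch of a FIFO replay buffer backed by deque(maxlen=capacity)."""

    def __init__(self, capacity: int) -> None:
        self._storage: deque = deque(maxlen=capacity)  # oldest entry evicted when full

    def add(self, transition) -> None:
        self._storage.append(transition)

    def sample(self, batch_size: int) -> list:
        # Uniform sampling without replacement; deque counts as a Sequence for random.sample.
        return random.sample(self._storage, batch_size)

    def __len__(self) -> int:
        return len(self._storage)


buf = DequeReplayBuffer(capacity=3)
for i in range(5):
    buf.add(i)
print(list(buf._storage))  # [2, 3, 4]
print(buf.sample(2))  # two distinct items from {2, 3, 4}
```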

In [11]:
class Transition(NamedTuple):
    observation: np.ndarray  # shape OBS_SHAPE
    action: int
    next_observation: np.ndarray  # shape OBS_SHAPE
    reward: float
    is_terminal: bool  # (useful when using target_net for predicting qvalues)


class TransitionBatch(NamedTuple):
    observation: Tensor  # shape (batch_size, *OBS_SHAPE)
    action: Tensor  # shape (batch_size,), dtype=torch.long
    next_observation: Tensor  # shape (batch_size, *OBS_SHAPE)
    reward: Tensor  # shape (batch_size,)
    is_terminal: Tensor  # shape (batch_size,), dtype=torch.bool


class ReplayBuffer:
    def __init__(self, capacity: int) -> None:
        """Create new replay buffer of a given capacity."""
        self.capacity = capacity
        self._storage = list[Transition]()
        self._next_idx = 0

    def add(self, transition: Transition) -> None:
        """Add new experience to the buffer."""
        if len(self._storage) < self.capacity:
            self._storage.append(transition)
        else:
            self._storage[self._next_idx] = transition
        self._next_idx = (self._next_idx + 1) % self.capacity

    def sample(self, batch_size: int) -> list[Transition]:
        """Sample batch of experiences from the buffer."""
        return random.sample(self._storage, batch_size)

    def sample_to_torch(self, batch_size: int, device: torch.device) -> TransitionBatch:
        """Sample batch of experiences from the buffer and convert it to torch tensors."""
        batch = self.sample(batch_size)
        return TransitionBatch(
            torch.tensor(np.array([t.observation for t in batch]), device=device),
            torch.tensor([t.action for t in batch], dtype=torch.long, device=device),
            torch.tensor(np.array([t.next_observation for t in batch]), device=device),
            torch.tensor([t.reward for t in batch], dtype=torch.float32, device=device),
            torch.tensor(
                [t.is_terminal for t in batch], dtype=torch.bool, device=device
            ),
        )
    
    def __len__(self) -> int:
        return len(self._storage)

## MLP Network

For fast iteration we will stick to numerical observations (the original DQN paper works with raw pixel observations). We will use a simple MLP to approximate the Q-values of (state, action) pairs: it maps an observation to one Q-value per action.
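For a sense of scale, the Q-network here is tiny. The hypothetical helper below counts the weights and biases of the stack of `nn.Linear` layers that the `MLP` class builds for 8-dimensional observations and 4 actions; for an actual module the same number is `sum(p.numel() for p in net.parameters())`.

```python
from collections.abc import Sequence


def mlp_param_count(input_dim: int, hidden_dims: Sequence[int], output_dim: int) -> int:
    """Weights + biases of consecutive Linear layers (ReLU has no parameters)."""
    dims = [input_dim, *hidden_dims, output_dim]
    return sum(i * o + o for i, o in zip(dims[:-1], dims[1:]))


print(mlp_param_count(8, (128, 128), 4))  # 18180 (default MLP)
print(mlp_param_count(8, (256, 256), 4))  # 69124 (the random-playthrough MLP)
```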

In [12]:
class MLP(nn.Module):
    def __init__(self, hidden_dims: Sequence[int] = (128, 128)) -> None:
        super().__init__()

        input_dim = reduce(operator.mul, OBS_SHAPE, 1)
        dims = [input_dim, *hidden_dims, N_ACTIONS]

        self.layers = nn.Sequential(nn.Flatten())

        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)):
            self.layers.append(nn.Linear(in_dim, out_dim))
            if i < len(dims) - 2:
                self.layers.append(nn.ReLU())

    def forward(self, state: Tensor) -> Tensor:
        """
        Input: observation tensor of shape (batch_size, *OBS_SHAPE).
        Output: Q-values tensor of shape (batch_size, N_ACTIONS).
        """
        return self.layers(state)

## DQN Agent

In [13]:
class DQNAgent:
    def __init__(
        self,
        eps_scheduler: ExplorationRateScheduler,
        q_value_net: nn.Module,
    ) -> None:
        self.eps_scheduler = eps_scheduler
        self.q_value_net = q_value_net

    def act(self, obs: np.ndarray) -> int:
        """
        Sample action using epsilon-greedy policy derived from q_value_net.

        Args:
            obs: shape OBS_SHAPE.

        With probability epsilon=eps_scheduler.value(): return a random action.
        Otherwise: return argmax_a Q_theta(obs, a)
        """
        epsilon = self.eps_scheduler.value()
        if torch.rand(1).item() < epsilon:
            return int(torch.randint(low=0, high=N_ACTIONS, size=(1,)).item())
        else:
            obs_tensor = torch.tensor(obs, dtype=torch.float32, device=device)
            obs_tensor = obs_tensor.view(1, *OBS_SHAPE)
            with torch.no_grad():
                return int(torch.argmax(self.q_value_net(obs_tensor)).item())

    def save_checkpoint(self, path: Path) -> None:
        """Save q_value_net parameters to a checkpoint file."""
        torch.save(self.q_value_net.state_dict(), path)

    def load_checkpoint(self, path: Path) -> None:
        """Load q_value_net parameters from a checkpoint file."""
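        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine (and vice versa).
        self.q_value_net.load_state_dict(torch.load(path, map_location=device, weights_only=True))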
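The epsilon-greedy rule in `DQNAgent.act` can be checked in isolation on a plain list of Q-values. In this sketch the `rng` argument is an assumption for reproducibility (the agent itself uses PyTorch's global RNG). Note that a "random" action may coincide with the greedy one, so the greedy action is actually chosen with probability $1 - \epsilon + \epsilon / N$ for $N$ actions.

```python
import random
from collections.abc import Sequence


def epsilon_greedy(q_values: Sequence[float], epsilon: float, rng: random.Random) -> int:
    """With probability epsilon pick a uniformly random action, else argmax of q_values."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=q_values.__getitem__)


rng = random.Random(0)
q = [0.1, 2.5, -1.0, 0.3]
print(epsilon_greedy(q, epsilon=0.0, rng=rng))  # 1 (always greedy)

counts = [0] * len(q)
for _ in range(10_000):
    counts[epsilon_greedy(q, epsilon=1.0, rng=rng)] += 1
print(counts)  # roughly uniform, about 2500 each
```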

## Playthrough (random)

In [14]:
def display_episode_playthrough(agent: DQNAgent, env: gym.Env) -> None:
    """Run the agent to play one episode start-to-finish and show a rendering."""
    agent.q_value_net.eval()
    observation, _info = env.reset()
    terminated, truncated = False, False
    total_reward = 0.0

    image_widget = widgets.Image(format="jpeg")
    display.display(image_widget)

    timestep = 0
    while not (terminated or truncated):
        timestep += 1
        # Pick next action, simulate and observe next_state and reward
        action = agent.act(observation)
        observation, reward, terminated, truncated, _info = env.step(action)

        frame = cv2.cvtColor(env.render(), cv2.COLOR_RGB2BGR)  # type: ignore
        image_widget.value = cv2.imencode(".jpeg", frame)[1].tobytes()
        time.sleep(0.01)

        total_reward += float(reward)

    print(f"Episode length: {timestep}, total reward: {total_reward}")
    time.sleep(0.5)  # 0.5 s between episodes, so it's easier to watch.

In [15]:
display_episode_playthrough(
    DQNAgent(ConstantRateScheduler(0.0), MLP(hidden_dims=[256, 256]).to(device)), env
)


Episode length: 137, total reward: -286.6744656081555


## DQN Trainer

In [16]:
class DQNTrainer:
    def __init__(
        self,
        agent: DQNAgent,
        optim: torch.optim.Optimizer,
        env: gym.Env,
        buffer_size: int = 10000,
        gamma: float = 0.99,
        batch_size: int = 128,
        target_update_interval: int = 100,  # in steps
        checkpoints_dir: Path = CHECKPOINTS_DIR,
        checkpoint_save_interval: int = 5000,  # in steps
        tensorboard_log_dir: Path = TENSORBOARD_LOG_DIR,
    ) -> None:
        self.agent = agent
        self.env = env
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_interval = target_update_interval
        self.optim = optim
        self.checkpoints_dir = checkpoints_dir
        self.checkpoint_save_interval = checkpoint_save_interval
        self.tensorboard_log_dir = tensorboard_log_dir

        self.replay_buffer = ReplayBuffer(buffer_size)

        self.q_value_net = agent.q_value_net
        self.target_net = deepcopy(self.q_value_net)

        self._total_steps = 0

    def train(self, n_episodes: int) -> None:
        self.q_value_net.train()
        self.target_net.load_state_dict(self.q_value_net.state_dict())
        self.target_net.eval()

        self._total_steps = 0

        rewards_history = list[float]()
        run_id = len(list(self.tensorboard_log_dir.iterdir()))
        with SummaryWriter(self.tensorboard_log_dir / f"r{run_id}") as writer:
            progress_bar = tqdm(range(n_episodes), desc="Training", unit="episode")
            for episode in progress_bar:
                reward, length = self.train_episode(episode, writer)

                # Logging to tqdm and Tensorboard
                eps = self.agent.eps_scheduler.value()
                rewards_history.append(reward)
                reward_mean10 = np.mean(rewards_history[-10:])
                progress_bar.set_postfix(
                    {
                        "reward_mean10": f"{reward_mean10:5.1f}",
                        "episode_length": f"{length}",
                        "exploration_eps": f"{eps:.1g}",
                    }
                )
                writer.add_scalar("1_Reward/last1", reward, episode)
                writer.add_scalar("1_Reward/last10", reward_mean10, episode)
                writer.add_scalar("2_Other/episode_steps", length, episode)
                writer.add_scalar("2_Other/exploration_eps", eps, episode)
                self.agent.eps_scheduler.next_episode()

    def train_episode(self, episode: int, writer: SummaryWriter) -> tuple[float, int]:
        episode_reward, episode_steps = 0.0, 0

        observation: np.ndarray
        observation, _info = self.env.reset()
        terminated, truncated = False, False

        while not (terminated or truncated):
            action = self.agent.act(observation)
            next_observation: np.ndarray
            next_observation, reward, terminated, truncated, _info = self.env.step(
                action
            )

            # Store the transition in the replay buffer
            # (observation, action, next_observation, reward, is_terminal).
            is_terminal = terminated or truncated
            self.replay_buffer.add(
                Transition(observation, action, next_observation, float(reward), is_terminal)
            )

            observation = next_observation

            loss = self._update_q_value_net()

            if (self._total_steps + 1) % self.target_update_interval == 0:
                self._update_target_net()

            if (self._total_steps + 1) % self.checkpoint_save_interval == 0:
                ckpt_name = f"step{self._total_steps + 1}_ep{episode}.ckpt"
                self.agent.save_checkpoint(self.checkpoints_dir / ckpt_name)

            self._total_steps += 1
            episode_steps += 1
            episode_reward += float(reward)
            self.agent.eps_scheduler.next_step()

            if loss is not None:
                writer.add_scalar("2_Other/loss", loss, self._total_steps)

        return episode_reward, episode_steps

    def _update_target_net(self) -> None:
        """Copy q_value_net parameters into target_net."""
        self.target_net.load_state_dict(self.q_value_net.state_dict())

        # Alternatively, copy parameter by parameter (from q_value_net into target_net):
        # for q, t in zip(self.q_value_net.parameters(), self.target_net.parameters(), strict=True):
        #     t.data.copy_(q.data)

    def _update_q_value_net(self) -> float | None:
        """Perform one round of q_value_net update.

        Sample a random minibatch of transitions (phi(s_t), a_t, r_t, phi(s_{t+1})) from the replay buffer
        and update q_value_net according to the DQN algorithm.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None

        batch = self.replay_buffer.sample_to_torch(self.batch_size, device=device)

        state_action_values = self.get_state_action_values(batch)
        with torch.no_grad():
            targets = self.get_targets(batch)

        loss = F.mse_loss(state_action_values, targets)

        self.optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.q_value_net.parameters(), 100.0)
        self.optim.step()

        return loss.item()

    def get_targets(self, batch: TransitionBatch) -> Tensor:
        """Uses `target_net` and immediate rewards to calculate expected future rewards."""
        batch_size = batch.next_observation.shape[0]

        q_value_of_next_obs, _argmax = torch.max(
            self.target_net(batch.next_observation), dim=1
        )
        q_value_of_next_obs[batch.is_terminal] = 0.0

        assert batch.reward.shape == (batch_size,)
        assert q_value_of_next_obs.shape == (batch_size,)
        # Expected future reward for terminal state is equal to immediate reward
        # For non terminal states expected future reward:
        #   immediate reward + discounted future expectation

        targets = batch.reward + self.gamma * q_value_of_next_obs

        assert targets.shape == (batch_size,)
        return targets

    def get_state_action_values(self, batch: TransitionBatch) -> Tensor:
        """
        Use `q_value_net` to get estimates of future rewards for (obs, action).

        Output shape: (batch_size,)
        """
        batch_size = batch.observation.shape[0]
        preds = self.q_value_net(batch.observation)  # shape (batch_size, n_actions).
        # Select preds[i, action_index[i]] for i in batch.
        return preds[torch.arange(batch_size), batch.action]
        # Alternatively:
        # return preds.gather(dim=1, index=batch.action.unsqueeze(dim=1)).squeeze(dim=1)

## Tensorboard

In [17]:
# Start tensorboard inside Google Colab
# If you can't see anything, run this cell twice.
%tensorboard --logdir $TENSORBOARD_LOG_DIR

Reusing TensorBoard on port 6006 (pid 5256), started 1:25:45 ago. (Use '!kill 5256' to kill it.)

## Run training

In [None]:
n_episodes = 500
estimated_n_steps = n_episodes * 100
agent = DQNAgent(
    q_value_net=MLP().to(device),
    eps_scheduler=ExponentialRateScheduler(
        initial=1.0, final=0.01, decay=0.995, on="episode"
    ),
    # eps_scheduler=LinearRateScheduler(initial=1.0, final=0.01, rate=1 / estimated_n_steps, on="step"),
)
trainer = DQNTrainer(
    agent=agent,
    optim=torch.optim.Adam(agent.q_value_net.parameters(), lr=0.0001),
    env=env,
)
trainer.train(n_episodes)

## Playthrough (after training)

In [19]:
# Load the latest checkpoint.
agent = DQNAgent(q_value_net=MLP().to(device), eps_scheduler=ConstantRateScheduler(0.0))
ckpt = sorted(CHECKPOINTS_DIR.iterdir(), key=lambda p: p.stat().st_mtime)[-1]
agent.load_checkpoint(ckpt)
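Sorting by modification time breaks if checkpoint files are copied or re-downloaded, and sorting by file name is lexicographic (`step5000...` sorts after `step45000...`). The sketch below instead parses the step number from the `step{N}_ep{E}.ckpt` names the trainer writes; the helper is hypothetical. Since every run writes into the same `CHECKPOINTS_DIR`, this only makes sense when the directory holds a single run.

```python
import re
from pathlib import Path

CKPT_RE = re.compile(r"step(\d+)_ep(\d+)\.ckpt$")


def latest_checkpoint(paths: list[Path]) -> Path:
    """Return the checkpoint with the highest step number encoded in its file name."""

    def step_of(p: Path) -> int:
        m = CKPT_RE.search(p.name)
        return int(m.group(1)) if m else -1  # non-matching files sort first

    return max(paths, key=step_of)


names = [Path("step5000_ep40.ckpt"), Path("step45000_ep480.ckpt"), Path("step10000_ep90.ckpt")]
print(latest_checkpoint(names).name)  # step45000_ep480.ckpt
print(sorted(p.name for p in names)[-1])  # step5000_ep40.ckpt (lexicographic pitfall)
```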

In [24]:
display_episode_playthrough(agent, env=env)


Episode length: 310, total reward: 239.11061318981072
