# DQN Atari Paper Upgrade using Rainbow
Here we will upgrade the DQN Atari paper using the Rainbow algorithm.
From the collection of improvements in the Rainbow algorithm, we will implement the following:
- Dueling Network Architecture
- Prioritized Experience Replay
- N-Step Returns
- Noisy Networks

In [46]:
# ! pip install gymnasium[atari,accept-rom-license] torch numpy opencv-python matplotlib

In [47]:
import gymnasium as gym
import numpy as np
import torch
import torch.optim as optim
import random
from collections import deque
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy
import torch.nn as nn

In [48]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

## Hyperparameters
As per the paper, we use certain hyperparameters that were tuned across various Atari games.

In [49]:
LEARNING_RATE = 0.00025  # Paper used a similar learning rate
DISCOUNT_FACTOR = 0.99  # The γ discount factor as mentioned in the paper
REPLAY_MEMORY_SIZE = 150_000  # Large replay buffer as described, but not too large
BATCH_SIZE = 32  # Minibatch size for training
TARGET_UPDATE_FREQ = 1_250  # C steps for target network update
FRAME_SKIP = 4  # Number of frames skipped
MIN_EPSILON = 0.1  # Minimum value of epsilon (for more exploitation)
MAX_EPSILON = 1.0  # Starting value of epsilon (for exploration)
EPSILON_PHASE = 0.1  # Percentage of steps for epsilon to reach MIN_EPSILON
MAX_STEPS = 100_000  # Total training episodes
REPLAY_START_SIZE = 75_000  # Size of replay memory before starting training
SAVE_FREQUENCY = 50_000  # Save model every 50k steps

# Prioritized Experience Replay parameters
ALPHA = 0.6  # Prioritization exponent
BETA_START = 0.4  # Initial value of beta for importance sampling weights
BETA_FRAMES = MAX_STEPS  # Number of frames over which beta increases to 1

# N-step returns
N_STEP = 3  # Number of steps for N-step returns

# Reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x13390b3b0>

## Prioritized Replay Buffer
The paper introduces Prioritized Replay Buffer to sample important transitions more frequently. We use a SumTree data structure to store priorities and sample transitions based on the priorities.

In [50]:
class PrioritizedReplayBuffer:
    def __init__(self, size, obs_shape, action_shape, alpha):
        self.size = size
        self.alpha = alpha
        self.obs_shape = obs_shape
        self.action_shape = action_shape

        self.t_obs = np.empty((size, *obs_shape), dtype=np.uint8)
        self.t1_obs = np.empty((size, *obs_shape), dtype=np.uint8)
        self.actions = np.empty((size, *action_shape), dtype=np.uint8)
        self.rewards = np.empty(size, dtype=np.float16)
        self.dones = np.empty(size, dtype=np.bool_)

        self.priorities = np.zeros((size,), dtype=np.float32)
        self.max_priority = 1.0

        self.idx = 0
        self.current_size = 0

    def append(self, t_obs, t1_obs, actions, reward, done):
        self.t_obs[self.idx] = t_obs
        self.t1_obs[self.idx] = t1_obs
        self.actions[self.idx] = actions
        self.rewards[self.idx] = reward
        self.dones[self.idx] = done

        self.priorities[self.idx] = self.max_priority

        self.current_size = min(self.current_size + 1, self.size)
        self.idx = (self.idx + 1) % self.size

    def sample(self, batch_size, beta):
        if self.current_size == self.size:
            priorities = self.priorities
        else:
            priorities = self.priorities[: self.current_size]
        probabilities = priorities**self.alpha
        probabilities /= probabilities.sum()

        indices = np.random.choice(self.current_size, batch_size, p=probabilities)
        weights = (self.current_size * probabilities[indices]) ** (-beta)
        weights /= weights.max()  # Normalize for stability

        batch = (
            self.t_obs[indices],
            self.t1_obs[indices],
            self.actions[indices],
            self.rewards[indices],
            self.dones[indices],
        )

        return (
            tuple(
                torch.as_tensor(item, dtype=torch.float32).to(device) for item in batch
            ),
            torch.as_tensor(indices).to(device),
            torch.as_tensor(weights, dtype=torch.float32).to(device),
        )

    def update_priorities(self, indices, priorities):
        self.priorities[indices] = priorities
        self.max_priority = max(self.max_priority, priorities.max())

    def __len__(self):
        return self.current_size

## Noisy Linear Layers
The paper introduces Noisy Linear Layers to add noise to the weights of the linear layers. We use a NoisyLinear layer to add noise to the weights of the linear layers.

In [51]:
class NoisyLinear(nn.Module):
    def __init__(self, in_features, out_features, sigma_init=0.017):
        super(NoisyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Learnable parameters
        self.weight_mu = nn.Parameter(torch.FloatTensor(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.FloatTensor(out_features, in_features))
        self.register_buffer(
            "weight_epsilon", torch.FloatTensor(out_features, in_features)
        )
        self.bias_mu = nn.Parameter(torch.FloatTensor(out_features))
        self.bias_sigma = nn.Parameter(torch.FloatTensor(out_features))
        self.register_buffer("bias_epsilon", torch.FloatTensor(out_features))
        self.sigma_init = sigma_init
        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        # Initialize parameters
        mu_range = 1 / self.in_features**0.5
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.sigma_init)
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(self.sigma_init)

    def reset_noise(self):
        # Sample new noise
        self.weight_epsilon.normal_()
        self.bias_epsilon.normal_()

    def forward(self, input):
        if self.training:
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            weight = self.weight_mu
            bias = self.bias_mu
        return nn.functional.linear(input, weight, bias)

## Deep Q-Network Architecture
Dueling Network Architecture is used in the Rainbow algorithm. The architecture consists of two streams, one for the state value and the other for the advantage values. The two streams are combined to produce the Q-values.

We will also use the previously implemented NoisyLinear layer to add noise to the weights of the linear layers.

In [52]:
class DuelingDQN(torch.nn.Module):
    def __init__(self, n_actions):
        super(DuelingDQN, self).__init__()
        self.conv = torch.nn.Sequential(
            # Convolution layers (as per paper), input: 84x84x4 image
            torch.nn.Conv2d(4, 32, kernel_size=8, stride=4),
            # Using ReLU activations as specified in the paper
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, kernel_size=4, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
        )
        # Input features after flattening
        conv_out_size = 7 * 7 * 64

        # Value stream
        self.value_stream = torch.nn.Sequential(
            NoisyLinear(conv_out_size, 512), torch.nn.ReLU(), NoisyLinear(512, 1)
        )

        # Advantage stream
        self.advantage_stream = torch.nn.Sequential(
            NoisyLinear(conv_out_size, 512),
            torch.nn.ReLU(),
            NoisyLinear(512, n_actions),
        )

    def forward(self, x):
        x = self.conv(x / 255.0)
        value = self.value_stream(x)
        advantage = self.advantage_stream(x)
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_values

    def reset_noise(self):
        # Reset noise in NoisyLinear layers
        for m in self.modules():
            if isinstance(m, NoisyLinear):
                m.reset_noise()

## Action Selection
As in the paper, we use an epsilon-greedy policy to select actions during training. We start with a high epsilon value and decay it over time. In addition to the NoisyLinear layers.

In [53]:
def select_action(state, policy_net, epsilon, env):
    if np.random.rand(1) < epsilon:  # Explore
        return env.action_space.sample()
    else:  # Exploit
        with torch.no_grad():
            policy_net.reset_noise()
            state = (
                torch.tensor(np.array(state), dtype=torch.float32)
                .to(device)
                .unsqueeze(0)
            )
            return policy_net(state).argmax(dim=1).item()

## Gymnasium Environment Setup

Here we set up the gym environment, by selecting the Breakout game. We specify RMSProp as the optimizer just like in the paper training details.

In [54]:
def make_env(env_id, render_mode=None, frame_skip=4):
    """Create environment with preprocessing wrappers."""
    env = gym.make(env_id, render_mode=render_mode, frameskip=1)
    env = gym.wrappers.AtariPreprocessing(env, frame_skip=4)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = gym.wrappers.FrameStack(env, 4)
    return env


env = make_env("ALE/Breakout-v5")
action_space = [i for i in range(env.action_space.n)]

# Initialize Dueling Noisy Networks
policy_net = DuelingDQN(env.action_space.n).to(device)
target_net = DuelingDQN(env.action_space.n).to(device)

# Use Adam instead to speed up training
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
replay_buffer = PrioritizedReplayBuffer(
    REPLAY_MEMORY_SIZE, (4, 84, 84), (1,), alpha=ALPHA
)

epsilon = MAX_EPSILON  # Starting value of epsilon

In [55]:
input_names = ["Image Sequence"]
output_names = ["Q-Values"]
torch.onnx.export(
    policy_net.to("cpu"),
    torch.rand((1, 4, 84, 84)),
    "rainbow.onnx",
    input_names=input_names,
    output_names=output_names,
)

In [56]:
import torch

def count_trainable_parameters(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params

print(f"Trainable parameters: {count_trainable_parameters(policy_net):,}")

Trainable parameters: 6,507,690


In [12]:
policy_net.to(device)

DuelingDQN(
  (conv): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (value_stream): Sequential(
    (0): NoisyLinear()
    (1): ReLU()
    (2): NoisyLinear()
  )
  (advantage_stream): Sequential(
    (0): NoisyLinear()
    (1): ReLU()
    (2): NoisyLinear()
  )
)

## Training Loop
Here, we follow the methodology from the paper, by putting all the above components together to train the DQN agent on the Atari game.

In [13]:
total_steps = 0

plot_infos = {
    "total_steps": [],
    "total_reward": [],
    "epsilon": [],
    "total_q_values": [],
    "total_loss": [],
}

progress_bar = tqdm(total=MAX_STEPS, desc="Training Progress")

n_step_buffer = deque(maxlen=N_STEP)
gamma = DISCOUNT_FACTOR

while total_steps < MAX_STEPS:
    state, info = env.reset()  # Reset environment to initial state
    n_step_buffer.clear()

    # Decay epsilon for exploration-exploitation tradeoff
    epsilon = max(
        MIN_EPSILON,
        MAX_EPSILON
        - (total_steps * (MAX_EPSILON - MIN_EPSILON) / (MAX_STEPS * EPSILON_PHASE)),
    )

    total_reward = 0
    total_q_values = 0
    total_loss = 0
    episode_steps = 0

    while True:
        action = select_action(state, policy_net, epsilon, env)

        # Step the environment
        next_state, reward, done, truncated, info = env.step(action)

        # Clip the reward to be in the range [-1, 1] as mentioned in the paper
        total_reward += reward
        reward = np.sign(reward)

        # Store the transition in N-step buffer
        n_step_buffer.append((state, action, reward, next_state, done))

        if len(n_step_buffer) == N_STEP:
            # Get N-step transition and add to replay buffer
            n_state, n_action, n_reward, n_next_state, n_done = (
                n_step_buffer[0][0],
                n_step_buffer[0][1],
                0,
                next_state,
                done,
            )
            for idx, (s, a, r, s_, d) in enumerate(n_step_buffer):
                n_reward += (gamma**idx) * n_step_buffer[idx][2]
                if n_step_buffer[idx][4]:
                    n_done = True
                    n_next_state = n_step_buffer[idx][3]
                    break
            replay_buffer.append(
                n_state, n_next_state, np.array([n_action]), n_reward, n_done
            )

        total_steps += 1
        episode_steps += 1

        # Only start training when replay memory has enough samples
        if len(replay_buffer) >= REPLAY_START_SIZE:
            if total_steps % 4 == 0:  # Update every 4 steps like in the paper
                beta = min(
                    1.0, BETA_START + total_steps * (1.0 - BETA_START) / BETA_FRAMES
                )

                # Sample minibatch from replay buffer
                (
                    (
                        batch_state,
                        batch_next_state,
                        batch_action,
                        batch_reward,
                        batch_done,
                    ),
                    indices,
                    weights,
                ) = replay_buffer.sample(BATCH_SIZE, beta)

                # Compute Q targets using Double DQN
                with torch.no_grad():
                    not_done = ~batch_done.bool()
                    policy_net.reset_noise()
                    target_net.reset_noise()
                    # Double DQN target computation
                    next_q_actions = policy_net(batch_next_state).argmax(
                        dim=1, keepdim=True
                    )
                    next_q_values = (
                        target_net(batch_next_state)
                        .gather(1, next_q_actions)
                        .squeeze(1)
                    )
                    target_q_values = (
                        batch_reward
                        + (DISCOUNT_FACTOR**N_STEP) * next_q_values * not_done
                    )

                optimizer.zero_grad()

                # Get current Q estimates
                policy_net.reset_noise()
                q_values = policy_net(batch_state)
                batch_action = batch_action.long()
                idx = torch.arange(batch_action.size(0)).to(device).long()
                values = q_values[idx, batch_action.squeeze(1)]

                # Compute loss
                td_errors = target_q_values - values
                loss = torch.nn.functional.huber_loss(
                    values, target_q_values, reduction="none"
                )
                loss = (weights * loss).mean()

                # Backpropagate and update the network
                loss.backward()
                optimizer.step()

                total_q_values += q_values.mean().item()
                total_loss += loss.item()

                # Update priorities
                new_priorities = td_errors.abs().detach().cpu().numpy() + 1e-6
                replay_buffer.update_priorities(indices.cpu().numpy(), new_priorities)

            # Update target network periodically
            if total_steps % TARGET_UPDATE_FREQ == 0:
                target_net = deepcopy(policy_net)

            # Save checkpoints every SAVE_FREQUENCY steps
            if total_steps % SAVE_FREQUENCY == 0:
                torch.save(
                    policy_net.state_dict(),
                    f"../checkpoints/rainbow/checkpoint_{total_steps}.pth",
                )
        state = next_state

        if done or truncated:
            # Process remaining transitions in N-step buffer
            while len(n_step_buffer) > 0:
                n_state, n_action, n_reward, n_next_state, n_done = (
                    n_step_buffer[0][0],
                    n_step_buffer[0][1],
                    0,
                    next_state,
                    done,
                )
                for idx, (s, a, r, s_, d) in enumerate(n_step_buffer):
                    n_reward += (gamma**idx) * n_step_buffer[idx][2]
                    if n_step_buffer[idx][4]:
                        n_done = True
                        n_next_state = n_step_buffer[idx][3]
                        break
                replay_buffer.append(
                    n_state, n_next_state, np.array([n_action]), n_reward, n_done
                )
                n_step_buffer.popleft()
            break

    # Append the total reward for tracking
    plot_infos["total_reward"].append(total_reward)
    plot_infos["epsilon"].append(epsilon)
    plot_infos["total_steps"].append(total_steps)
    plot_infos["total_q_values"].append(total_q_values / max(1, episode_steps))
    plot_infos["total_loss"].append(total_loss / max(1, episode_steps))

    progress_bar.set_description(
        f"R: {plot_infos['total_reward'][-1]}, ε: {plot_infos['epsilon'][-1]:.5f}, RSize: {len(replay_buffer)} Q-value: {plot_infos['total_q_values'][-1]:.5f}, Loss: {plot_infos['total_loss'][-1]:.5f}"
    )
    progress_bar.update(episode_steps)

progress_bar.close()
env.close()

Training Progress:   0%|          | 0/100000 [00:00<?, ?it/s]

R: 2.0, ε: 0.10000, RSize: 100442 Q-value: 0.08875, Loss: 0.00020: : 100101it [05:53, 283.34it/s]                         


*The trainings were run on Kaggle, so the training logs are not included in this notebook.*

In [14]:
import pandas as pd

# Convert the dictionary to a DataFrame
df_plot_infos = pd.DataFrame(plot_infos)

# Save the DataFrame to a CSV file
df_plot_infos.to_csv("../data/rainbow_dqn_plot_infos.csv", index=False)

## Result
*The tests were run on Kaggle, so the tests logs are not included in this notebook.*

In [17]:
# Function to load model weights from checkpoint file
def load_checkpoint(model, checkpoint_file):
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))
    model.eval()  # Set the model to evaluation mode (important for inference)


# Function to play a single episode and return the total reward
def play_episode(env, model):
    obs, info = env.reset()
    state = (
        torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
    )  # Move to correct device
    total_reward = 0

    done = False
    while not done:
        if np.random.rand() < 0.05:
            action = env.action_space.sample()  # Random action with 5% probability
        else:
            with torch.no_grad():
                action = (
                    model(state).argmax(dim=1).item()
                )  # Choose action with highest Q-value
        next_obs, reward, done, truncated, info = env.step(action)
        next_state = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0).to(device)

        total_reward += reward
        state = next_state

        if done or truncated:
            break

    return total_reward


# Function to evaluate the model by playing 50 games
def evaluate_model(checkpoint_file, num_games=50):
    # Create the environment
    env = make_env("ALE/Breakout-v5", frame_skip=4)

    # Initialize model
    action_space = env.action_space.n
    model = DuelingDQN(action_space).to(device)

    # Load the best checkpoint
    load_checkpoint(model, checkpoint_file)

    total_rewards = []
    for game in range(num_games):
        total_reward = play_episode(env, model)
        total_rewards.append(total_reward)
        print(f"Game {game + 1}, Reward: {total_reward}")

    # Calculate average reward
    avg_reward = np.mean(total_rewards)
    std_reward = np.std(total_rewards)
    max_reward = np.max(total_rewards)
    min_reward = np.min(total_rewards)

    print(f"Average Reward: {avg_reward}")
    print(f"Standard Deviation: {std_reward}")
    print(f"Max Reward: {max_reward}")
    print(f"Min Reward: {min_reward}")

    env.close()


# Call the function to evaluate the model
best_checkpoint_path = "../checkpoints/rainbow/checkpoint_400000.pth"
evaluate_model(best_checkpoint_path)

  model.load_state_dict(torch.load(checkpoint_file, map_location=device))
  torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)


Game 1, Reward: 22.0
Game 2, Reward: 13.0
Game 3, Reward: 23.0
Game 4, Reward: 11.0
Game 5, Reward: 15.0
Game 6, Reward: 17.0
Game 7, Reward: 15.0
Game 8, Reward: 23.0
Game 9, Reward: 23.0
Game 10, Reward: 20.0
Game 11, Reward: 15.0
Game 12, Reward: 12.0
Game 13, Reward: 25.0
Game 14, Reward: 22.0
Game 15, Reward: 23.0
Game 16, Reward: 38.0
Game 17, Reward: 30.0
Game 18, Reward: 30.0
Game 19, Reward: 37.0
Game 20, Reward: 21.0
Game 21, Reward: 21.0
Game 22, Reward: 19.0
Game 23, Reward: 12.0
Game 24, Reward: 33.0
Game 25, Reward: 27.0
Game 26, Reward: 11.0
Game 27, Reward: 12.0
Game 28, Reward: 16.0
Game 29, Reward: 22.0
Game 30, Reward: 10.0
Game 31, Reward: 35.0
Game 32, Reward: 16.0
Game 33, Reward: 15.0
Game 34, Reward: 29.0
Game 35, Reward: 25.0
Game 36, Reward: 31.0
Game 37, Reward: 9.0
Game 38, Reward: 19.0
Game 39, Reward: 17.0
Game 40, Reward: 12.0
Game 41, Reward: 23.0
Game 42, Reward: 11.0
Game 43, Reward: 21.0
Game 44, Reward: 23.0
Game 45, Reward: 32.0
Game 46, Reward: 16.