## Importing depedencies

In [1]:
import os
import random
import time
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
import ale_py
from collections import deque
from gym.wrappers import AtariPreprocessing, FrameStack
from collections import defaultdict

In [2]:
gym.register_envs(ale_py)

In [3]:
from gymnasium.wrappers import GrayscaleObservation

## DQN

In [4]:
env_id = "ALE/Breakout-v5"      # The environment to train on
seed = 1                               # Random seed for reproducibility
total_timesteps = 10000000             # Total timesteps to run the training
learning_rate = 1e-4                   # Learning rate for the optimizer
num_envs = 1                           # Number of parallel environments
buffer_size = 1000000                  # Size of the replay buffer
gamma = 0.99                           # Discount factor
tau = 1.0                              # Target network update rate
target_network_frequency = 1000        # Timesteps to update target network
batch_size = 32                        # Batch size for training
start_e = 1                            # Starting epsilon for exploration
end_e = 0.01                           # Ending epsilon for exploration
learning_starts = 80000                # Timesteps before starting training
train_frequency = 4                    # Frequency of training
exploration_fraction = 0.1  # Example value

In [5]:
# ---- Environment Setup ---- #
def make_env(env_id, seed, idx, capture_video=False, run_name=""):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)

        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)

        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)

        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayscaleObservation(env)
        env = gym.wrappers.FrameStackObservation(env, 4)
        env.action_space.seed(seed)

        return env

    return thunk

In [6]:
# ---- Q-Network with Attention ---- #
class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, env.single_action_space.n),
        )

    def forward(self, x):
        return self.network(x / 255.0)

In [7]:
class TargetNetwork(QNetwork):
    def __init__(self, env):
        super(TargetNetwork, self).__init__(env)

In [8]:
class ReplayBuffer:
    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.buffer = []
    
    def add(self, experience):
        if len(self.buffer) >= self.buffer_size:
            self.buffer.pop(0)  # Remove the oldest experience if the buffer is full
        self.buffer.append(experience)
    
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

In [9]:
class AttentionDQN:
    def __init__(self, env, q_network, target_network, learning_rate, gamma, tau, exploration_fraction):
        self.env = env
        self.q_network = q_network
        self.target_network = target_network
        self.optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.tau = tau
        self.exploration_fraction = exploration_fraction
        self.attention_weights = {}  # Dict to store attention weights for states
        self.td_error_history = {}   # Dict to store TD error history for states per episode

    def select_action(self, state, epsilon=0.0):
        # Exploration based on attention weight
        sigma = self.attention_weights.get(state, 1.0) * self.exploration_fraction
        if random.random() < sigma:
            return self.env.action_space.sample()  # Explore: random action
        else:
            # Exploit: choose action with max Q-value
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
            q_values = self.q_network(state_tensor)
            return torch.argmax(q_values, dim=1).item()

    def update_attention_weights(self, state, td_error):
        # Accumulate TD error history per state for the episode
        if state not in self.td_error_history:
            self.td_error_history[state] = []
        self.td_error_history[state].append(abs(td_error))

        # Update attention weights based on cumulative TD error history
        cumulative_td_error = np.mean(self.td_error_history[state])  # Average TD error for the state
        self.attention_weights[state] = cumulative_td_error

    def train(self, total_timesteps, train_frequency, learning_starts, target_network_frequency):
        start_time = time.time()
        obs, _ = self.env.reset(seed=seed)

        # ---- Training Loop ---- #
        for global_step in range(total_timesteps):
            state = obs
            episodic_return = 0
            done = False

            # Reset TD error history for new episode
            episode_td_errors = []

            while not done:
                action = self.select_action(state)
                next_state, reward, done, truncated, info = self.env.step(action)

                # Compute TD error for this step
                with torch.no_grad():
                    q_values_next = self.target_network(torch.tensor(next_state, dtype=torch.float32).to(device))
                    target_value = reward + self.gamma * q_values_next.max(dim=1)[0]
                state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
                q_value = self.q_network(state_tensor)[0, action]
                td_error = target_value - q_value.item()

                # Accumulate TD error history for this episode
                episode_td_errors.append(td_error)

                # Store experience in buffer (optional)
                # Replay buffer can be implemented here if needed

                # Update Q-network
                loss = F.mse_loss(torch.tensor(target_value), torch.tensor(q_value))
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                state = next_state
                episodic_return += reward

            # After the episode, update attention weights for all states encountered in the episode
            for state, td_error in zip([state] * len(episode_td_errors), episode_td_errors):
                self.update_attention_weights(state, td_error)

            # Update target network
            if global_step % target_network_frequency == 0:
                for target_param, param in zip(self.target_network.parameters(), self.q_network.parameters()):
                    target_param.data.copy_(
                        self.tau * param.data + (1.0 - self.tau) * target_param.data
                    )

            # Print and log metrics, etc.
            if global_step % train_frequency == 0:
                print(f"Step {global_step}, Episodic Return: {episodic_return}")
                # Add any other logging (e.g., WandB, TensorBoard) here

In [10]:
import torch
import wandb
import gym

# ---- WandB Setup ---- #
run_name = "ADQN_Breakout"

wandb.init(
    project="ADQN-RL",
    name=run_name,
    config={
        "env_id": env_id,
        "seed": seed,
        "learning_rate": learning_rate,
        "buffer_size": buffer_size,
        "batch_size": batch_size,
        "gamma": gamma,
        "tau": tau,
        "train_frequency": train_frequency,
        "start_e": start_e,
        "end_e": end_e,
        "total_timesteps": total_timesteps,
    }
)

# ---- Environment Setup ---- #
envs = gym.vector.SyncVectorEnv(
    [make_env(env_id, seed + i, i, capture_video=False, run_name=run_name) for i in range(num_envs)]
)

# ---- Initialize Q-Network and Target Network ---- #
q_network = QNetwork()  # Initialize the Q-network (ensure this is defined correctly)
target_network = TargetNetwork()  # Initialize the target network (ensure this is defined correctly)

# ---- Initialize AttentionDQN Agent ---- #
adqn = AttentionDQN(
    envs=envs,
    q_network=q_network,
    target_network=target_network,
    learning_rate=learning_rate,
    gamma=gamma,
    tau=tau,
    exploration_fraction=exploration_fraction
)

# ---- Training ---- #
adqn.train()

# ---- Save Model ---- #
model_path = f"runs/{run_name}/adqn_model.pth"
torch.save(adqn.q_network.state_dict(), model_path)
wandb.save(model_path)
print(f"Model saved to {model_path} and uploaded to WandB")

envs.close()
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcasarulez[0m ([33mHarish-Personal[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


NamespaceNotFound: Namespace ALE not found. Have you installed the proper package for ALE?