#### Imports

In [20]:
# Set up file paths
EXPERIMENT_NAME = "sb3_atari_dqn_ensemble_1"

MODEL_PATH = f"./models/{EXPERIMENT_NAME}"
LOG_PATH = f"./logs/{EXPERIMENT_NAME}"
TENSORBOARD_LOG_PATH = f"./logs/dqn_tensorboard_logs/atari/{EXPERIMENT_NAME}"

# Imports
## numpy
import numpy as np
## pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
## stable-baselines3
from stable_baselines3 import DQN
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList
from stable_baselines3.common.env_checker import check_env
# from stable_baselines3.common.env_util import make_atari_env # seems to be buggy
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor # required for minigrid
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import polyak_update, get_linear_fn, set_random_seed
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv, VecTransposeImage
from stable_baselines3.dqn.policies import CnnPolicy
## gymnasium
import gymnasium as gym
from gymnasium.spaces import Box
import ale_py
gym.register_envs(ale_py)
# from gymnasium.wrappers import FrameStackObservation, ClipReward
## plotly and pyplot
import matplotlib.pyplot as plt

In [21]:
# TensorBoard setup
writer = SummaryWriter(TENSORBOARD_LOG_PATH)
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


#### Variables

In [22]:
# Hyperparameters
N_ENSEMBLE = 2  # Number of ensemble agents
RESET_FREQUENCY = 40000  # Reset frequency in timesteps
BETA = 50  # Action selection coefficient
REPLAY_BUFFER_SIZE = 100000  # Replay buffer size
BATCH_SIZE = 32
GAMMA = 0.99
LEARNING_STARTS = 2000  # Timesteps before training starts
TAU = 0.005  # Polyak update coefficient
TOTAL_TIMESTEPS = int(1e5)  # Total training timesteps
TRAIN_FREQ = 1  # Frequency of training (steps)
GRADIENT_STEPS = 1  # Gradient steps per update
TARGET_UPDATE_INTERVAL = 1  # Update target networks every step
N_STACK = 4  # Number of stacked frames

In [23]:
# eval_freq = 5000 # once every eval_freq timesteps, evaluate the model
# replay_ratio = 4 # run gradient calculations 4 times per step
env_type = "AlienNoFrameskip-v4" # use this emulation from Gymnasium environments

#### Environment and Model Setup

In [24]:
def make_env(env_type, rank=0, frameskip=1, render_mode=None, seed=0):
    def _init():
        env = gym.make(env_type, render_mode=render_mode, frameskip=frameskip) # frameskip is important
        env = AtariWrapper(env)
        env.action_space.seed(seed + rank)
        env.observation_space.seed(seed + rank)
        return env
    return _init

In [25]:
# Create and preprocess Atari environment
env = DummyVecEnv([make_env(env_type=env_type, seed=42)])
env = VecTransposeImage(env)
env = VecFrameStack(env, n_stack=N_STACK, channels_order='first')

eval_env = DummyVecEnv([make_env(env_type=env_type, seed=84)])
eval_env = VecTransposeImage(eval_env)
eval_env = VecFrameStack(eval_env, n_stack=N_STACK, channels_order='first')

# Debug observation space       
print(f"Current Observation Space: {env.observation_space.shape}")
print(f"Current Eval Observation Space: {eval_env.observation_space.shape}")

Current Observation Space: (4, 84, 84)
Current Eval Observation Space: (4, 84, 84)


In [26]:
a = np.zeros((4,84,84))
A = torch.Tensor(a)
A.transpose(0,2).shape # use maybe_transpose?

torch.Size([84, 84, 4])

In [27]:
# # Prepare an feature extractor capable of overseeing n_stack frames
# class CustomCNN(BaseFeaturesExtractor):
#     def __init__(self, observation_space, features_dim=512):
#         super().__init__(observation_space, features_dim)
#         n_input_channels = observation_space.shape[0]
#         self.cnn = nn.Sequential(
#             nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
#             nn.ReLU(),
#             nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
#             nn.ReLU(),
#             nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
#             nn.ReLU(),
#             nn.Flatten(),
#         )
#         with torch.no_grad():
#             n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1]
#         self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

#     def forward(self, observations):
#         return self.linear(self.cnn(observations))

# # Prepare custom policy definition to fit into ensemble
# class CustomQPolicy(BasePolicy):
#     def __init__(
#         self,
#         observation_space,
#         action_space,
#         lr_schedule,
#         net_arch=None,
#         features_dim=512,
#         activation_fn=nn.ReLU,
#         ortho_init=True,
#         device="auto",
#     ):
#         super().__init__(
#             observation_space,
#             action_space,
#             lr_schedule,
#             ortho_init,
#             device,
#         )

#         self.features_extractor = CustomCNN(self.observation_space, features_dim=512)
#         self.q_net = nn.Linear(self.features_extractor.features_dim, action_space.n)
#         self.q_net = self.q_net.to(self.device)

#     def forward(self, obs, deterministic=True):
#       return self.q_net(self.extract_features(obs))

#     def _predict(self, observation, deterministic=True):
#         return self(observation)

#     def extract_features(self, obs):
#         return self.features_extractor(obs)

In [80]:
# Replay buffer
replay_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE, env.observation_space, env.action_space, device=device, n_envs=1)

# Ensemble setup

# ensemble_agents = [
#     CustomQPolicy(
#         observation_space=env.observation_space,
#         action_space=env.action_space,
#         lr_schedule=get_linear_fn(1e-4, 1e-5, 1.0),
#         features_dim=336,
#     ).to(device)
#     for _ in range(N_ENSEMBLE)
# ]
# target_networks = [
#     CustomQPolicy(
#         observation_space=env.observation_space,
#         action_space=env.action_space,
#         lr_schedule=get_linear_fn(1e-4, 1e-5, 1.0),
#         features_dim=336,
#     ).to(device)
#     for _ in range(N_ENSEMBLE)
# ]

ensemble_agents = [
    CnnPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        # feature_extractor_class=CustomCNN(observation_space=env.observation_space, n_stack=N_STACK),
        lr_schedule=get_linear_fn(1e-4, 1e-5, 1.0),
        net_arch=[256, 256], # do not use along with custom CNN definition
    ).to(device)
    for _ in range(N_ENSEMBLE)
]
target_networks = [
    CnnPolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        # feature_extractor_class=CustomCNN(observation_space=env.observation_space, n_stack=N_STACK),
        lr_schedule=get_linear_fn(1e-4, 1e-5, 1.0),
        net_arch=[256, 256], # do not use along with custom CNN definition
    ).to(device)
    for _ in range(N_ENSEMBLE)
]
optimizers = [optim.Adam(agent.parameters(), lr=1e-4) for agent in ensemble_agents]

#### Callback Setup

In [84]:
def reset_agent(agent):
    for layer in agent.q_net.modules():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
def adaptive_action_selection(q_values_stack, beta):
    q_values_normalized = q_values_stack / (q_values_stack.max(dim=-1, keepdim=True)[0] + 1e-8)
    scaled_q_values = beta * torch.max(q_values_normalized, torch.zeros_like(q_values_normalized))
    # sum individual normalized and scaled Q-values from all DQNs in ensemble
    summed_q_values = scaled_q_values.sum(dim=0) 
    action_distributions = torch.softmax(summed_q_values, dim=-1)
    action = action_distributions.argmax()
    return action

#### Train and Run Model

In [30]:
# # Using DQN implementation of Stable-Baselines3 with modified callbacks

# # callback frequencies are scaled to stack counts to match the given actual game timestep
# eval_callback = EvalCallback(env, best_model_save_path=LOG_PATH, log_path=LOG_PATH,
#                              eval_freq=max(eval_freq // n_stack, 1), deterministic=True,
#                              render=True)
# # Create and attach the callback
# reset_callback = ResetWeightsCallback(reset_interval=max(reset_interval // n_stack, 1), verbose=1)

# callback_list = CallbackList([eval_callback, reset_callback])

# model = DQN(
#     policy= "CnnPolicy", 
#     env= env, 
#     verbose= 1, 
#     buffer_size= timesteps,
#     learning_starts= 2000,
#     tau= 0.005,
#     train_freq= (1, "step"),
#     gradient_steps= replay_ratio,
#     target_update_interval= 1,
#     policy_kwargs= policy_kwargs,
#     tensorboard_log="./dqn_tensorboard_logs/atari",
#     )

In [87]:
# Training loop
state = env.reset()
print(f"state shape: {state.shape}")
current_agent_index = 0
step_count = 0

for step in range(TOTAL_TIMESTEPS):
    state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
    print(f"state tensor shape: {state_tensor.shape}")
    q_values_stack = torch.stack([agent.q_net(state_tensor) for agent in ensemble_agents])
    action = adaptive_action_selection(q_values_stack, BETA) 
    print(f"Action value: {action}")
    next_state, reward, done, info = env.step([action.item()])

    replay_buffer.add(state, next_state, action, reward, done, info)
    state = next_state

    if replay_buffer.size() > BATCH_SIZE and step > LEARNING_STARTS:
        batch = replay_buffer.sample(BATCH_SIZE)
        observations = batch.observations.to(device)
        next_observations = batch.next_observations.to(device)

        for i, agent in enumerate(ensemble_agents):
            q_values = agent.q_net(observations).gather(1, batch.actions.to(device).long())
            with torch.no_grad():
                target_q_values = target_networks[i].q_net(next_observations).max(1, keepdim=True)[0]
                target = batch.rewards.to(device) + GAMMA * (1 - batch.dones.to(device)) * target_q_values

            loss = F.smooth_l1_loss(q_values, target)
            optimizers[i].zero_grad()
            loss.backward()
            optimizers[i].step()

    if step % TARGET_UPDATE_INTERVAL == 0:
        for i in range(N_ENSEMBLE):
            polyak_update(ensemble_agents[i].q_net.parameters(), target_networks[i].q_net.parameters(), TAU)

    if step % RESET_FREQUENCY == 0:
        reset_agent(ensemble_agents[current_agent_index])
        current_agent_index = (current_agent_index + 1) % N_ENSEMBLE

    if np.any(done):
        state = env.reset()

    if step % 1000 == 0:
        print(f"Step: {step}, Average Reward: {np.mean(reward)}")

writer.close()

state shape: (1, 4, 84, 84)
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 7
Step: 0, Average Reward: 1.0
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 9
state tensor shape: torch.Size([1, 4, 84, 84])
Action value: 

KeyboardInterrupt: 

55.125

#### Evaluation

In [None]:
# Evaluation loop
def evaluate(agents, env, n_eval_episodes=10):
    total_rewards = []
    for _ in range(n_eval_episodes):
        state = env.reset()
        episode_reward = np.zeros(env.num_envs)
        done = np.zeros(env.num_envs, dtype=bool)

        while not np.all(done):
            state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
            q_values = torch.stack([agent.q_net(state_tensor) for agent in agents])
            action = torch.argmax(q_values.mean(dim=0), dim=1).cpu().numpy()

            state, reward, done, _ = env.step(action)
            episode_reward += reward

        total_rewards.append(np.mean(episode_reward))

    avg_reward = np.mean(total_rewards)
    print(f"Evaluation Results: Mean Reward = {avg_reward}")
    return avg_reward

# Perform evaluation
evaluate(ensemble_agents, eval_env)

# Close environments
env.close()
eval_env.close()

#### Code Scraps
This section contains intermediate examples and tests and aren't used by the model

##### Adaptive Action Selection

In [None]:
state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
print(f"state tensor shape: {state_tensor.shape}")
ensemble_agents[0].q_net(state_tensor)
q_values_stack = torch.stack([agent.q_net(state_tensor) for agent in ensemble_agents])
print(f"q_values_stack shape: {q_values_stack.shape}")

state tensor shape: torch.Size([1, 4, 84, 84])
q_values_stack shape: torch.Size([2, 1, 18])


In [None]:
q_values_stack = torch.stack([agent.q_net(state_tensor) for agent in ensemble_agents])

In [None]:
s = [agent.q_net(state_tensor) for agent in ensemble_agents][0]

In [None]:
print(s)

tensor([[-0.0011, -0.0059,  0.0130,  0.0273,  0.0272, -0.0241,  0.0299, -0.0614,
          0.0760, -0.0496, -0.0033, -0.0372,  0.0053, -0.0052, -0.0411, -0.0272,
         -0.0313,  0.0320]], device='cuda:0', grad_fn=<AddmmBackward0>)


In [None]:
q_values_normalized = s / (s.max(dim=-1, keepdim=True)[0] + 1e-8)
print(q_values_normalized)

tensor([[-0.0150, -0.0771,  0.1705,  0.3590,  0.3585, -0.3171,  0.3943, -0.8082,
          1.0000, -0.6535, -0.0439, -0.4903,  0.0695, -0.0689, -0.5408, -0.3583,
         -0.4118,  0.4212]], device='cuda:0', grad_fn=<DivBackward0>)


In [None]:
scaled_q_values = BETA * torch.max(q_values_normalized, torch.zeros_like(q_values_normalized))
print(scaled_q_values)
print(scaled_q_values.shape)

tensor([[ 0.0000,  0.0000,  8.5257, 17.9475, 17.9251,  0.0000, 19.7133,  0.0000,
         50.0000,  0.0000,  0.0000,  0.0000,  3.4735,  0.0000,  0.0000,  0.0000,
          0.0000, 21.0619]], device='cuda:0', grad_fn=<MulBackward0>)
torch.Size([1, 18])


In [None]:
summed_q_values = scaled_q_values.sum(dim=0)
print(summed_q_values)
print(summed_q_values.shape)

tensor([ 0.0000,  0.0000,  8.5257, 17.9475, 17.9251,  0.0000, 19.7133,  0.0000,
        50.0000,  0.0000,  0.0000,  0.0000,  3.4735,  0.0000,  0.0000,  0.0000,
         0.0000, 21.0619], device='cuda:0', grad_fn=<SumBackward1>)
torch.Size([18])


In [None]:
action_distributions = torch.softmax(summed_q_values, dim=-1)
print(action_distributions)
print(action_distributions.shape)

tensor([1.9288e-22, 1.9288e-22, 9.7262e-19, 1.2017e-14, 1.1750e-14, 1.9288e-22,
        7.0253e-14, 1.9288e-22, 1.0000e+00, 1.9288e-22, 1.9288e-22, 1.9288e-22,
        6.2202e-21, 1.9288e-22, 1.9288e-22, 1.9288e-22, 1.9288e-22, 2.7061e-13],
       device='cuda:0', grad_fn=<SoftmaxBackward0>)
torch.Size([18])


In [None]:
action = action_distributions.argmax()
print(action)

tensor(8, device='cuda:0')


In [None]:
q_values_stack.shape

torch.Size([2, 1, 18])

In [72]:
# Procedure to figure out the adaptive action selection method

state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
print(f"state tensor shape: {state_tensor.shape}")
ensemble_agents[0].q_net(state_tensor)
q_values_stack = torch.stack([agent.q_net(state_tensor) for agent in ensemble_agents])
print(f"q_values_stack shape: {q_values_stack.shape}")

q_values_normalized = q_values_stack / (q_values_stack.max(dim=-1, keepdim=True)[0] + 1e-8)
print(q_values_normalized)

scaled_q_values = BETA * torch.max(q_values_normalized, torch.zeros_like(q_values_normalized))
print(scaled_q_values)
print(scaled_q_values.shape)

summed_q_values = scaled_q_values.sum(dim=0)
print(summed_q_values)
print(summed_q_values.shape)

action_distributions = torch.softmax(summed_q_values, dim=-1)
print(action_distributions)
print(action_distributions.shape)

action = action_distributions.argmax()
print(action)

state tensor shape: torch.Size([1, 4, 84, 84])
q_values_stack shape: torch.Size([2, 1, 18])
tensor([[[-0.0150, -0.0771,  0.1705,  0.3590,  0.3585, -0.3171,  0.3943,
          -0.8082,  1.0000, -0.6535, -0.0439, -0.4903,  0.0695, -0.0689,
          -0.5408, -0.3583, -0.4118,  0.4212]],

        [[ 0.8975, -1.8960, -0.9367,  0.0646,  0.7456, -0.6327, -1.7249,
           1.0000,  0.1268,  0.2321,  0.5314,  0.7750, -0.5035, -0.2335,
          -0.2739, -0.6958, -0.0054, -0.2948]]], device='cuda:0',
       grad_fn=<DivBackward0>)
tensor([[[ 0.0000,  0.0000,  8.5257, 17.9475, 17.9251,  0.0000, 19.7133,
           0.0000, 50.0000,  0.0000,  0.0000,  0.0000,  3.4735,  0.0000,
           0.0000,  0.0000,  0.0000, 21.0619]],

        [[44.8770,  0.0000,  0.0000,  3.2288, 37.2805,  0.0000,  0.0000,
          50.0000,  6.3413, 11.6057, 26.5723, 38.7501,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000]]], device='cuda:0',
       grad_fn=<MulBackward0>)
torch.Size([2, 1, 18])
tensor([[