#### CENG501 Final Project - Reset Deep Ensembles
Authors: Ege Uğur Aguş and Atakan Botasun </br>

A Reset Deep Ensemble (RDE) script for the Atari-100k benchmark, using the hyperparameters from the paper's Section 4 / Appendix B.

Specifically for AlienNoFrameskip-v4: </br>
- 100k environment steps total
- Reset interval: 8e4 (80,000)
- Reset depth: "last1"
- Replay buffer size: 1e5
- Min replay size: 1e4
- Batch size: 32
- Target net update period: 1 (i.e., every training update)
- Max gradient norm: 10
- Softmax β = 50
- Possibly do 4 updates per environment step (replay ratio = 4).

#### Imports

In [13]:
import random
import time
from collections import deque
from typing import Union

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import gymnasium as gym
import ale_py
# Register the Atari environments (ALE) and Minigrid
gym.register_envs(ale_py)

from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

#### Hyperparameter Setup

In [14]:
#   Hyperparameters (referenced from the paper)
ENV_ID              = "Freeway-v4"
N_ENSEMBLE          = 2           # # of ensemble agents
N_ENVS              = 1           # Single env to can be used to keep it truly at N time steps
LR                  = 1e-4        # Learning rate
GAMMA               = 0.99        # Discount factor
BATCH_SIZE          = 32          # Batch size
REPLAY_BUFFER_SIZE  = 10_000     # Replay buffer size (paper uses 100 000) - reduced to 10% of original for performance
MIN_REPLAY_SIZE     = 1_000      # Mininum replay buffer size (paper uses 10 000) - reduced to 10% of original for performance
TOTAL_TIMESTEPS     = 100_000     # Total time steps
TRAIN_FREQUENCY     = 1           # train every environment step
UPDATES_PER_STEP    = 1           # replay ratio (1,2,4) => pick 4
TARGET_UPDATE_FREQ  = 1           # update target net every training step

RESET_FREQUENCY     = 40_000      # 8e4 for Alien, Pong, 4e4 for Freeway
RESET_DEPTH         = "last1"     # "last1" for Alien
                                  # "last2" for Pong
SOFTMAX_BETA        = 50          # beta for action selection

EPS_START           = 1.0
EPS_END             = 0.01
EPS_DECAY_FRAC      = 0.2         # decay epsilon over 0.01% of total steps => 10k

MAX_GRAD_NORM       = 10          # paper indicates max grad norm = 10

# View performance during runs
LOG_INTERVAL        = 1_000
RENDER_INTERVAL     = 500

#### Environment Setup

In [15]:
class StackedFrames(gym.Wrapper):
    def __init__(self, env, n_stack=4, channels_last=True):
        super().__init__(env)
        self.n_stack = n_stack
        self.channels_last = channels_last
        self.stack = None
        obs_shape = env.observation_space.shape
        print(obs_shape)
        if channels_last:
            self.observation_space = gym.spaces.Box(
                low=0, high=255,
                shape=(obs_shape[0], obs_shape[1], n_stack),
                dtype=np.uint8
            )
        else:  # Channels first
            self.observation_space = gym.spaces.Box(
                low=0, high=255,
                shape=(obs_shape[0] * n_stack, obs_shape[1], 1),
                dtype=np.uint8
            )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        if self.channels_last:
            self.stack = np.repeat(obs[..., np.newaxis], self.n_stack, axis=-1)
        else:
            self.stack = np.repeat(obs[np.newaxis, ...], self.n_stack, axis=0)
        return self.stack, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        if self.channels_last:
            self.stack = np.roll(self.stack, shift=-obs.shape[-1], axis=-1)
            self.stack[..., -obs.shape[-1]:] = obs
        else:
            self.stack = np.roll(self.stack, shift=-obs.shape[0], axis=0)
            self.stack[-obs.shape[0]:, ...] = obs
        return self.stack, reward, terminated, truncated, info

# Build multiple VecEnv with StackedFrames
def make_vec_env_gym(env_id, n_envs=1, seed=0, channels_last=True):
    def _make_env():
        env = gym.make(env_id, render_mode=None) # Important for vectorization
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayscaleObservation(env)
        if isinstance(env.action_space, gym.spaces.Box):
            env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        if not channels_last:
            obs_shape = env.observation_space.shape
            print(obs_shape)
            new_obs_shape = (obs_shape[2], obs_shape[0], obs_shape[1])
            new_observation_space = gym.spaces.Box(low=0, high=1, shape=new_obs_shape, dtype=np.float32)
            env = gym.wrappers.TransformObservation(env, lambda obs: np.transpose(obs, (2, 0, 1)), observation_space=new_observation_space)
            env = gym.wrappers.TransformObservation(env, lambda obs: np.transpose(obs, (2, 0, 1)) if channels_last==False else obs, new_observation_space) # Channels first if needed
        return env

    if n_envs > 1:
      envs = gym.vector.VectorEnv([_make_env for _ in range(n_envs)])
      envs = gym.wrappers.VecFrameStack(envs, n_stack=4)
      return envs

    else:
      env = _make_env()
      env = StackedFrames(env, n_stack=4, channels_last=channels_last)
      return env
    
# Build single VecEnv with Stable-Baselines 3 implementation
def make_vec_env_sb3(env_id=ENV_ID, n_envs=N_ENVS, seed=0):
    # Gray-scaling, 84x84, 4 frame stacks, reward clipping, etc. are typical
    # for "make_atari_env". We assume that or we can provide wrapper_kwargs
    venv = make_atari_env(env_id, n_envs=n_envs, seed=seed)
    # Frame stacking => shape: (n_envs, 84,84,4)
    venv = VecFrameStack(venv, n_stack=4)
    venv.current_obs = None
    return venv

In [16]:
#   CNN Q-Network for Atari (3 conv layers + 1 FC)
class QNetworkAtari(nn.Module):
    def __init__(self, n_actions):
        super().__init__()
        # From paper: 3 conv layers [32,64,64], kernels [8x8,4x4,3x3], strides [4,2,1].
        self.features = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        # 7x7 is typical after [8,4,3] filters/strides
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )
        self.n_actions = n_actions
        self.reset_parameters("full")

    def forward(self, x):
        # x shape: [batch, 4, 84, 84]
        feats = self.features(x)
        feats = feats.contiguous().view(feats.size(0), -1)
        out = self.fc(feats)
        return out

    def reset_parameters(self, reset_depth="full"):
        """Reset some or all of the network parameters."""
        def _init_module(m):
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

        if reset_depth == "full":
            self.apply(_init_module)
        elif reset_depth == "last2":
            # Re-init the last 2 layers in self.fc
            if len(self.fc) == 3:
                _init_module(self.fc[-1])
                _init_module(self.fc[-3])
            else:
                raise ValueError("Unexpected architecture for partial reset.")
        elif reset_depth == "last1":
            # Re-init ONLY the final linear layer in self.fc
            if len(self.fc) == 3:
                _init_module(self.fc[-1])
            else:
                raise ValueError("Unexpected architecture for partial reset.")
        else:
            raise ValueError("Unknown reset depth option.")

class QNetworkMiniGrid(nn.Module):
    def _init_(self, obs_dim, n_actions, hidden_size = 256):
        super()._init_()
        # We'll define a 5-layer MLP (including output).
        # For example: input -> 256 -> 256 -> 256 -> 256 -> output(n_actions)
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

        # If we want partial resets, we can define "reset_parameters" similarly.

    def forward(self, x):
        # x shape: [batch, obs_dim]
        return self.net(x)

#   Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)

    def add(self, obs, action, reward, next_obs, done):
        self.buffer.append((obs, action, reward, next_obs, done))

    def sample(self, batch_size=32):
        batch = random.sample(self.buffer, batch_size)
        obs, acts, rews, next_obs, dones = zip(*batch)

        obs       = np.stack(obs)        # shape: [B, 4,84,84]
        acts      = np.array(acts, dtype=np.int64)
        rews      = np.array(rews, dtype=np.float32)
        next_obs  = np.stack(next_obs)
        dones     = np.array(dones, dtype=np.float32)
        return obs, acts, rews, next_obs, dones

    def __len__(self):
        return len(self.buffer)

# class ReplayBuffer:
#     def __init__(self, capacity, obs_shape, action_dim, device="cpu"):
#         self.capacity = capacity
#         self.device = device
#         self.obs = torch.zeros((capacity,) + obs_shape, dtype=torch.float32, device=self.device)
#         self.next_obs = torch.zeros((capacity,) + obs_shape, dtype=torch.float32, device=self.device)
#         self.actions = torch.zeros(capacity, dtype=torch.int64, device=self.device)
#         self.rewards = torch.zeros(capacity, dtype=torch.float32, device=self.device)
#         self.dones = torch.zeros(capacity, dtype=torch.float32, device=self.device)  # Store as float for easier use in loss functions
#         self.ptr = 0
#         self.size = 0

#     def add(self, obs, action, reward, next_obs, done):
#         self.obs[self.ptr] = torch.tensor(obs, dtype=torch.float32, device=self.device) if not torch.is_tensor(obs) else obs.to(self.device)
#         self.next_obs[self.ptr] = torch.tensor(next_obs, dtype=torch.float32, device=self.device) if not torch.is_tensor(next_obs) else next_obs.to(self.device)
#         self.actions[self.ptr] = action if torch.is_tensor(action) else torch.tensor(action, dtype=torch.int64, device=self.device)
#         self.rewards[self.ptr] = reward if torch.is_tensor(reward) else torch.tensor(reward, dtype=torch.float32, device=self.device)
#         self.dones[self.ptr] = float(done) if not torch.is_tensor(done) else done.to(self.device) # Ensure it's a float

#         self.ptr = (self.ptr + 1) % self.capacity
#         self.size = min(self.size + 1, self.capacity)
#         return self.obs, self.actions, self.rewards, self.next_obs, self.dones

#     def sample(self, batch_size=32):
#         indices = torch.randint(0, self.size, (batch_size,), device=self.device)
#         return (
#             self.obs[indices],
#             self.actions[indices],
#             self.rewards[indices].unsqueeze(1),  # Add dimension for rewards
#             self.next_obs[indices],
#             self.dones[indices].unsqueeze(1),    # Add dimension for dones
#         )

#     def __len__(self):
#         return self.size

#   Ensemble DQN Agent
class EnsembleDQNAgent:
    def __init__(self, n_actions,
                 n_ensemble=N_ENSEMBLE, lr=LR, gamma=GAMMA,
                 reset_freq=RESET_FREQUENCY, reset_depth=RESET_DEPTH,
                 softmax_beta=SOFTMAX_BETA, feature_type="atari"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using device:", self.device)

        self.n_ensemble = n_ensemble
        self.gamma = gamma
        self.reset_freq = reset_freq
        self.reset_depth = reset_depth
        self.softmax_beta = softmax_beta

        self.q_networks = []
        self.target_networks = []
        self.optimizers = []

        for _ in range(n_ensemble):
            qnet = QNetworkAtari(n_actions).to(self.device)
            tnet = QNetworkAtari(n_actions).to(self.device)
            tnet.load_state_dict(qnet.state_dict())

            optimizer = optim.Adam(qnet.parameters(), lr=lr)
            self.q_networks.append(qnet)
            self.target_networks.append(tnet)
            self.optimizers.append(optimizer)

        self.global_step = 0

        # Round-robin
        self.last_reset_idx = 0
        # The 'oldest' agent is next after the just-reset agent
        self.oldest_agent_idx = (self.last_reset_idx + 1) % self.n_ensemble

        self.n_actions = n_actions

    def select_action(self, obs_np, epsilon=0.05):
        # Epsilon-greedy
        if random.random() < epsilon:
            return random.randint(0, self.n_actions - 1)

        # obs_np shape: (84,84,4) => transpose to (4,84,84)
        obs_ch_first = np.transpose(obs_np, (2,0,1))
        obs_t = torch.from_numpy(obs_ch_first).unsqueeze(0).float().to(self.device)

        # Each agent picks argmax
        candidate_actions = []
        with torch.no_grad():
            for qnet in self.q_networks:
                qvals = qnet(obs_t)
                a = qvals.argmax(dim=1).item()
                candidate_actions.append(a)

        # Use oldest agent's Q-values to compute softmax distribution
        oldest = self.oldest_agent_idx
        with torch.no_grad():
            qvals_oldest = self.q_networks[oldest](obs_t).squeeze(0)

        # For each agent's chosen action, get Q_oldest(s, a_i)
        r_values = []
        for act in candidate_actions:
            r_values.append(qvals_oldest[act].item())

        # Scale for stable softmax
        max_r = max(abs(v) for v in r_values) if r_values else 1.0
        if max_r == 0:
            max_r = 1.0
        scaled_r = [(val / max_r) * self.softmax_beta for val in r_values]
        exp_r = np.exp(scaled_r)
        sum_exp = np.sum(exp_r)
        if sum_exp <= 1e-9:
            probs = np.ones(self.n_ensemble) / self.n_ensemble
        else:
            probs = exp_r / sum_exp

        chosen_agent_idx = np.random.choice(self.n_ensemble, p=probs)
        return candidate_actions[chosen_agent_idx]

    def reset_agent(self, idx):
        """Reset the parameters for agent idx."""
        self.q_networks[idx].reset_parameters(self.reset_depth)
        self.target_networks[idx].load_state_dict(
            self.q_networks[idx].state_dict()
        )

    def step_env(self, vec_env, replay_buffer, obs_np=None, epsilon=0.05):
        """
        Interact with the single-env (n_envs=1) or multi-env. We store transitions.
        """
        if isinstance(vec_env, VecFrameStack):
            obs_np = vec_env.current_obs  # shape: (N_ENVS, 84,84,4)

        n_envs = obs_np.shape[0]

        # Get actions from the ensemble
        actions = []
        for i in range(n_envs):
            single_obs = obs_np[i]  # (84,84,4)
            a = self.select_action(single_obs, epsilon=epsilon)
            actions.append(a)

        next_obs, rewards, dones, infos = vec_env.step(actions)
        self.global_step += n_envs

        # Store transitions
        for i in range(n_envs):
            single_obs    = obs_np[i]
            single_next   = next_obs[i]
            single_reward = rewards[i]
            done_bool     = bool(dones[i])

            # Transpose to channels-first before storing
            obs_ch_first     = np.transpose(single_obs,  (2,0,1))
            next_ch_first    = np.transpose(single_next, (2,0,1))

            replay_buffer.add(obs_ch_first, actions[i],
                              single_reward, next_ch_first,
                              done_bool)

        # Round-robin reset
        if (self.global_step % self.reset_freq) == 0:
            print(f"[INFO] Resetting agent {self.last_reset_idx} at step {self.global_step}")
            self.reset_agent(self.last_reset_idx)
            self.oldest_agent_idx = (self.last_reset_idx + 1) % self.n_ensemble
            self.last_reset_idx = (self.last_reset_idx + 1) % self.n_ensemble

        # Return aggregated reward, done if any env ended
        return float(np.mean(rewards)), any(dones), next_obs, actions

    def train_on_batch(self, replay_buffer):
        if len(replay_buffer) < MIN_REPLAY_SIZE:
            return
        obs, acts, rews, next_obs, dones = replay_buffer.sample(BATCH_SIZE)

        obs_t      = torch.FloatTensor(obs).to(self.device)
        acts_t     = torch.LongTensor(acts).unsqueeze(1).to(self.device)
        rews_t     = torch.FloatTensor(rews).unsqueeze(1).to(self.device)
        next_obs_t = torch.FloatTensor(next_obs).to(self.device)
        dones_t    = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        for i in range(self.n_ensemble):
            with torch.no_grad():
                q_next = self.target_networks[i](next_obs_t)
                max_q_next = q_next.max(dim=1, keepdim=True)[0]
                target = rews_t + self.gamma * (1 - dones_t) * max_q_next

            current_q = self.q_networks[i](obs_t).gather(1, acts_t)
            # loss = nn.SmoothL1Loss()(current_q, target)
            loss = nn.HuberLoss()(current_q, target)

            self.optimizers[i].zero_grad()
            loss.backward()

            # Clip gradients per the paper (max grad norm=10)
            nn.utils.clip_grad_norm_(self.q_networks[i].parameters(), MAX_GRAD_NORM)

            self.optimizers[i].step()

    # Tensor version
    # def train_on_batch(self, replay_buffer):
    #     if len(replay_buffer) < MIN_REPLAY_SIZE:
    #         return

    #     obs_t, acts_t, rews_t, next_obs_t, dones_t = replay_buffer.sample(BATCH_SIZE)

    #     for i in range(self.n_ensemble):
    #         with torch.no_grad():
    #             q_next = self.target_networks[i](next_obs_t)
    #             max_q_next = q_next.max(dim=1, keepdim=True)[0]
    #             target = rews_t + self.gamma * (1 - dones_t) * max_q_next

    #         current_q = self.q_networks[i](obs_t).gather(1, acts_t)
    #         loss = nn.HuberLoss()(current_q, target)

    #         self.optimizers[i].zero_grad()
    #         loss.backward()

    #         nn.utils.clip_grad_norm_(self.q_networks[i].parameters(), MAX_GRAD_NORM)
    #         self.optimizers[i].step()
            
    def update_target_nets(self):
        """Here, we copy the online Q-net parameters into each target net."""
        for i in range(self.n_ensemble):
            self.target_networks[i].load_state_dict(
                self.q_networks[i].state_dict()
            )


In [17]:
# Main Training Loop
def run_training(env: Union[StackedFrames, VecFrameStack], render_env: Union[StackedFrames, VecFrameStack], verbose: int=0):
    start_time = time.monotonic()

    n_actions = env.action_space.n

    agent = EnsembleDQNAgent(n_actions=n_actions)

    obs_np = env.reset()  # shape: (N_ENVS, 84,84,4)
    render_obs = render_env.reset()
    env.current_obs = obs_np

    if verbose > 0:
        print(f"Observation space size: {env.observation_space.shape}")

    # replay_buffer = ReplayBuffer(obs_shape=env.observation_space.shape[::-1], action_dim=n_actions,capacity=REPLAY_BUFFER_SIZE, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    replay_buffer = ReplayBuffer(capacity=REPLAY_BUFFER_SIZE)

    episode_reward = 0.0
    episode_count = 0
    rewards_history = []

    train_steps = 0
    # We'll decay epsilon over the first 10% (EPS_DECAY_FRAC=0.1) of total timesteps => 10k steps
    eps_decay_steps = int(EPS_DECAY_FRAC * TOTAL_TIMESTEPS)

    print(f"Starting training for {TOTAL_TIMESTEPS} timesteps on {ENV_ID} with RDE...")

    # Timers
    iteration_start = time.monotonic()
    action_log = []

    while agent.global_step < TOTAL_TIMESTEPS:
        
        fraction = min(1.0, agent.global_step / eps_decay_steps)  # in [0,1]
        epsilon = EPS_START + fraction * (EPS_END - EPS_START)
        epsilon = max(EPS_END, epsilon)  # clamp

        r, done, obs_np, actions = agent.step_env(env, replay_buffer, obs_np=obs_np, epsilon=epsilon)
               
        episode_reward += r

        if done:
            episode_count += 1
            rewards_history.append(episode_reward)
            obs = env.reset()
            render_obs = render_env.reset()
            env.current_obs = obs
            episode_reward = 0.0

        # train multiple times => replay ratio=4
        if (agent.global_step % TRAIN_FREQUENCY) == 0:
            for _ in range(UPDATES_PER_STEP):
                agent.train_on_batch(replay_buffer)
                train_steps += 1

        # Because the table says "target update period=1",
        # we do it every training step
        if train_steps > 0 and (train_steps % TARGET_UPDATE_FREQ) == 0:
            agent.update_target_nets()

        # Rendering (just sample one of the actions from all environments)
        _ = render_env.step(actions[0])
        if (agent.global_step % RENDER_INTERVAL) == 0 and (agent.global_step > TOTAL_TIMESTEPS // 2):
            render_env.render()

        action_log.append(actions[0])

        # Logging
        if (agent.global_step % LOG_INTERVAL) == 0 and agent.global_step > 0:
            iteration_end = time.monotonic()
            last_10 = np.mean(rewards_history[-10:]) if len(rewards_history)>=10 else np.mean(rewards_history)
            elapsed = iteration_start-iteration_end
            print(f"Step={agent.global_step} | Episodes={episode_count} | "
                  f"AvgRew(last10)={last_10:.2f} | Eps={epsilon:.3f} | Real Time Elapsed={(time.time()-elapsed):.2f}s")
            print(f"Last taken actions: {action_log[-10:]}")

            iteration_start = time.monotonic()
        
    final_10 = np.mean(rewards_history[-10:]) if len(rewards_history)>=10 else 0.0
    total_elapsed = time.monotonic() - start_time
    print(f"Finished. Total elapsed time: {total_elapsed:.2f}s")

    env.close()
    render_env.close()
    return final_10

#### Run the model

In [18]:
env = make_vec_env_sb3(ENV_ID, N_ENVS)
render_env = gym.make(ENV_ID, render_mode="human")
final_10 = run_training(env, render_env, verbose=1)
print("Training done. Final 10-episode average reward:", final_10)

Using device: cuda
Observation space size: (84, 84, 4)
Starting training for 100000 timesteps on Freeway-v4 with RDE...
Step=1000 | Episodes=1 | AvgRew(last10)=0.00 | Eps=0.975 | Real Time Elapsed=1736538661.11s
Last taken actions: [1, 0, 1, 1, 1, 2, 2, 0, 1, 0]
Step=2000 | Episodes=2 | AvgRew(last10)=0.50 | Eps=0.951 | Real Time Elapsed=1736538710.68s
Last taken actions: [0, 1, 2, 1, 1, 0, 1, 1, 0, 0]
Step=3000 | Episodes=4 | AvgRew(last10)=0.25 | Eps=0.926 | Real Time Elapsed=1736538760.58s
Last taken actions: [1, 2, 0, 1, 2, 0, 2, 0, 1, 1]
Step=4000 | Episodes=5 | AvgRew(last10)=0.20 | Eps=0.901 | Real Time Elapsed=1736538811.21s
Last taken actions: [2, 0, 0, 2, 0, 1, 0, 2, 0, 2]
Step=5000 | Episodes=7 | AvgRew(last10)=0.14 | Eps=0.876 | Real Time Elapsed=1736538861.02s
Last taken actions: [0, 2, 0, 1, 0, 2, 1, 0, 1, 2]
Step=6000 | Episodes=8 | AvgRew(last10)=0.12 | Eps=0.852 | Real Time Elapsed=1736538911.00s
Last taken actions: [0, 0, 0, 1, 2, 1, 2, 2, 0, 2]
Step=7000 | Episodes=1

In [None]:
# Use this line if you need to prematurely kill the render_env
render_env.close()

: 