## Dependencies

In [1]:
import numpy as np
import gym
import tensorflow as tf
from tensorflow import keras
from collections import deque
import random
import cv2
import wandb
from tqdm import tqdm
from collections import defaultdict

2025-02-13 11:45:38.849741: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-13 11:45:38.879614: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-13 11:45:38.879644: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-13 11:45:38.880493: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-13 11:45:38.885689: I tensorflow/core/platform/cpu_feature_guar

## Initializations and Hyperparameters

In [2]:
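# Reference values only: nothing in this notebook reads this cell.
# tau, lambda_, beta_start, beta_end and T are the keyword defaults of
# ATDQNAgent.__init__ (defined below); gamma is set inside that constructor.
# tau=0.2
# lambda_=1.0
# beta_start=0.4
# beta_end=1.0
# T=1000000
# gamma=0.99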

## Preprocessing of ATARI frames

In [3]:
def preprocess_frame(frame):
    """Convert a raw RGB frame into an 84x84 grayscale float32 image.

    Args:
        frame: RGB image array accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_RGB2GRAY``.

    Returns:
        np.ndarray of shape (84, 84) and dtype float32 holding the pixel
        values divided by 255. (Values land in [0, 1] assuming the input is an
        8-bit image -- TODO confirm against the environment's observation dtype.)
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)  # Grayscale
    resized = cv2.resize(gray, (84, 84))  # Resize to 84x84
    # Cast before dividing: dividing the integer image by 255.0 would promote
    # it to float64, doubling the memory of every frame stack kept in the
    # replay buffer. float32 is also the dtype the Keras network computes in.
    return resized.astype(np.float32) / 255.0  # Normalize

## Building the DQN

In [4]:
def build_q_network(input_shape, action_size):
    """Assemble and compile the convolutional Q-value network.

    Args:
        input_shape: Shape of one input sample, passed to the first Conv2D layer.
        action_size: Number of units in the linear output layer (one per action).

    Returns:
        A compiled ``keras.Sequential`` model: three ReLU convolutions, a
        flatten, a 512-unit ReLU dense layer and a linear output layer, trained
        with Adam (learning rate 0.00025) on mean squared error.
    """
    network = keras.Sequential()

    # The first convolution carries the input shape.
    network.add(
        keras.layers.Conv2D(32, (8, 8), strides=4, activation='relu', input_shape=input_shape)
    )
    # Remaining convolutions: (filters, kernel size, stride).
    for filters, kernel, stride in ((64, (4, 4), 2), (64, (3, 3), 1)):
        network.add(keras.layers.Conv2D(filters, kernel, strides=stride, activation='relu'))

    network.add(keras.layers.Flatten())
    network.add(keras.layers.Dense(512, activation='relu'))
    network.add(keras.layers.Dense(action_size, activation='linear'))

    adam = keras.optimizers.Adam(learning_rate=0.00025)
    network.compile(optimizer=adam, loss='mse')
    return network

## Replay buffer

In [5]:
class ReplayBuffer:
    """Fixed-capacity store of experience tuples with uniform random sampling.

    Once ``capacity`` experiences are held, new ones overwrite existing slots
    in insertion order (ring buffer).
    """

    def __init__(self, capacity=100000):
        """Create an empty buffer that holds at most ``capacity`` experiences."""
        self.capacity = capacity
        self.position = 0  # slot the next overwrite will target
        self.buffer = []

    def add(self, experience):
        """Store one experience, overwriting the slot at ``position`` when full."""
        slot = self.position
        if len(self.buffer) >= self.capacity:
            self.buffer[slot] = experience
        else:
            self.buffer.append(experience)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct experiences uniformly at random.

        Returns:
            Five numpy arrays, one per field of the stored 5-tuples, in the
            order (states, actions, rewards, next_states, dones).
        """
        picked = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states, dones = zip(
            *(self.buffer[idx] for idx in picked)
        )
        columns = (states, actions, rewards, next_states, dones)
        return tuple(np.array(column) for column in columns)

    def size(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

## Main AT-DQN Object

In [6]:
class ATDQNAgent:
    """DQN agent whose exploration and loss weighting use a per-state attention value.

    Each state seen during training gets an attention value ``alpha`` equal to
    the mean absolute TD error recorded for it; states never trained on have
    attention 1.0. The agent acts randomly when ``alpha * lambda_`` exceeds
    ``tau`` and greedily otherwise. During training each sample is weighted by
    ``(1 / (N * alpha)) ** beta``.
    """

    # Attention reported for states that have no recorded TD error yet.
    _DEFAULT_ATTENTION = 1.0
    # Lower bound applied to attention before it is used as a divisor.
    _EPS = 1e-8

    def __init__(self, action_size, state_shape, tau=0.2, lambda_=1.0, beta_start=0.4, beta_end=1.0, T=1000000, gamma=0.99):
        """Build the online/target networks and initialise all agent state.

        Args:
            action_size: Number of discrete actions.
            state_shape: Input shape handed to ``build_q_network``.
            tau: Threshold on ``attention * lambda_`` above which a random
                action is taken.
            lambda_: Scale applied to the attention value in ``act``.
            beta_start: Initial importance-weight exponent.
            beta_end: Upper bound of the importance-weight exponent.
            T: Number of ``anneal_beta`` calls needed to go from
                ``beta_start`` to ``beta_end``.
            gamma: Discount factor used in the bootstrap target.
        """
        self.action_size = action_size
        self.q_network = build_q_network(state_shape, action_size)
        self.target_network = build_q_network(state_shape, action_size)
        self.target_network.set_weights(self.q_network.get_weights())

        self.replay_buffer = ReplayBuffer()
        self.gamma = gamma

        # Attention Mechanism
        self.alpha = defaultdict(lambda: 1.0)  # Default attention weight = 1
        self.td_errors = defaultdict(list)

        # Exploration and importance weighting
        self.tau = tau
        self.lambda_ = lambda_

        # Beta Annealing
        self.beta = beta_start
        self.beta_end = beta_end
        # max(T, 1) keeps T=0 from raising ZeroDivisionError.
        self.delta_beta = (beta_end - beta_start) / max(T, 1)

    @staticmethod
    def _state_key(state):
        """Return the hashable dictionary key used for ``state``."""
        return tuple(np.asarray(state).flatten())

    def get_attention(self, state):
        """Return the attention value of ``state`` (1.0 if never trained on).

        Uses ``dict.get`` so that merely looking a state up does not insert a
        new entry into ``self.alpha``.
        """
        return self.alpha.get(self._state_key(state), self._DEFAULT_ATTENTION)

    def update_attention(self, state, td_error):
        """Record ``|td_error|`` for ``state`` and set its attention to the running mean."""
        state_key = self._state_key(state)
        history = self.td_errors[state_key]
        history.append(float(abs(td_error)))
        self.alpha[state_key] = float(np.mean(history))

    def normalize_attention(self):
        """Divide every stored attention value by the current maximum.

        Does nothing when no attention has been recorded or when the maximum is
        not positive (dividing would raise or produce NaN).
        """
        if not self.alpha:
            return
        max_alpha = max(self.alpha.values())
        if max_alpha <= 0:
            return
        for state_key in self.alpha:
            self.alpha[state_key] /= max_alpha

    def compute_importance_weight(self, state, N):
        """Return ``(1 / (N * alpha(state))) ** beta``.

        ``alpha`` is clamped to a small positive value and ``N`` to at least 1
        so the result is always finite.
        """
        alpha_s = max(self.get_attention(state), self._EPS)
        return (1.0 / (max(N, 1) * alpha_s)) ** self.beta

    def act(self, state):
        """Choose an action for ``state``.

        Returns a uniformly random action when ``attention * lambda_ > tau``;
        otherwise the arg-max action of the online network.
        """
        sigma = self.get_attention(state) * self.lambda_
        if sigma > self.tau:
            return int(np.random.choice(self.action_size))
        q_values = self.q_network.predict(np.expand_dims(state, axis=0), verbose=0)
        return int(np.argmax(q_values[0]))

    def train(self, batch_size=32):
        """Run one importance-weighted Q-learning update on a sampled batch.

        Args:
            batch_size: Number of transitions to sample.

        Returns:
            The mean absolute TD error of the batch, or ``None`` when the
            replay buffer holds fewer than ``batch_size`` transitions and no
            update was performed.
        """
        if batch_size <= 0 or self.replay_buffer.size() < batch_size:
            return None

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        actions = np.asarray(actions, dtype=np.int64)
        rewards = np.asarray(rewards, dtype=np.float32)
        not_done = 1.0 - np.asarray(dones, dtype=np.float32)

        # Bootstrap target from the frozen target network.
        next_q = self.target_network.predict(next_states, verbose=0)
        targets = rewards + not_done * self.gamma * np.max(next_q, axis=1)

        q_values = np.array(self.q_network.predict(states, verbose=0))
        rows = np.arange(batch_size)
        td_errors = targets - q_values[rows, actions]

        # Update attention weights
        for state, td_error in zip(states, td_errors):
            self.update_attention(state, td_error)

        # Importance sampling correction
        N = self.replay_buffer.size()
        importance_weights = np.array(
            [self.compute_importance_weight(state, N) for state in states],
            dtype=np.float32,
        )
        importance_weights /= importance_weights.max()  # Normalize

        # Regress only the taken action's Q-value toward its target; the other
        # outputs keep their current prediction so they contribute zero error.
        # (Fitting on the unmodified predictions would be a no-op update.)
        q_targets = q_values.copy()
        q_targets[rows, actions] = targets

        # sample_weight applies the importance weights inside the compiled loss.
        self.q_network.fit(
            states,
            q_targets,
            sample_weight=importance_weights,
            batch_size=batch_size,
            verbose=0,
        )
        return float(np.mean(np.abs(td_errors)))

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.set_weights(self.q_network.get_weights())

    def anneal_beta(self):
        """Advance ``beta`` by one increment, capped at ``beta_end``."""
        self.beta = min(self.beta + self.delta_beta, self.beta_end)

## Driver Code

In [7]:
# Training driver. Relies on the imports, preprocess_frame and ATDQNAgent
# defined in the cells above.

# Initialize Weights & Biases
wandb.init(project="AT-DQN", name="Pong_ATDQN")

env = gym.make('ALE/Pong-v5')
state_shape = (84, 84, 4)
action_size = env.action_space.n

agent = ATDQNAgent(action_size, state_shape)

num_episodes = 1000
batch_size = 32
max_steps_per_episode = 10000  # hard cap on environment steps per episode

wandb.config.update({
    "num_episodes": num_episodes,
    "batch_size": batch_size
})

try:
    # Use tqdm to track episode progress
    for episode in tqdm(range(num_episodes), desc="Training Progress", unit="episode"):
        first_frame, _ = env.reset()
        state = preprocess_frame(first_frame)
        # Seed the 4-frame stack by repeating the first frame.
        state_stack = np.stack([state] * 4, axis=-1)

        total_reward = 0
        episode_td_errors = []

        for _ in range(max_steps_per_episode):
            action = agent.act(state_stack)
            next_frame, reward, terminated, truncated, _ = env.step(action)
            next_state = preprocess_frame(next_frame)

            # Drop the oldest frame and append the newest one.
            next_state_stack = np.append(state_stack[:, :, 1:], np.expand_dims(next_state, axis=-1), axis=-1)

            # Only true termination is stored as `done`; a time-limit
            # truncation should still bootstrap from the next state.
            agent.replay_buffer.add((state_stack, action, reward, next_state_stack, terminated))

            # ReplayBuffer.add returns nothing, so the TD error has to come
            # from the training step. train() may return None (no update run,
            # or an agent version that does not report it); skip those.
            batch_td_error = agent.train(batch_size)
            if batch_td_error is not None:
                episode_td_errors.append(batch_td_error)

            # The agent's default T=1000000 is a step count, so anneal beta
            # once per environment step rather than once per episode.
            agent.anneal_beta()

            state_stack = next_state_stack
            total_reward += reward

            if terminated or truncated:
                break

        agent.update_target_network()
        agent.normalize_attention()

        # np.mean of an empty list warns and of a list of None raises, so guard both.
        mean_td_error = float(np.mean(episode_td_errors)) if episode_td_errors else float("nan")

        # Log episode metrics
        wandb.log({
            "Episode": episode + 1,
            "Total Reward": total_reward,
            "Mean TD Error": mean_td_error,
            "Beta": agent.beta
        })

        print(f"Episode {episode + 1}: Reward = {total_reward}, β = {agent.beta:.4f}")
finally:
    # Release the emulator and flush the W&B run even when training is
    # interrupted (e.g. KeyboardInterrupt from stopping the cell).
    env.close()
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcasarulez[0m ([33mHarish-Personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
2025-02-13 11:45:47.512335: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
  if not isinstance(terminated, (bool, np.bool8)):
Training Progress:   0%|                          | 0/1000 [01:04<?, ?episode/s]


KeyboardInterrupt: 