## Algorithm

<div style="text-align: center;">
    <img src="algor.png" alt="AT-DQN algorithm" width="600">
</div>

## Dependencies

In [None]:
import numpy as np
import gym
import tensorflow as tf
from tensorflow import keras
from collections import deque
import random
import cv2
import wandb
from tqdm import tqdm

## Initializations and Hyperparameters

In [None]:
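# Reference only: nothing is assigned in this cell. These are the values
# that ATDQNAgent uses by default.
# tau=0.2
# lambda_=1.0
# beta_start=0.4
# beta_end=1.0
# T=1000000
# gamma=0.99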

## Preprocessing of ATARI frames

In [None]:
def preprocess_frame(frame):
    """Turn a raw RGB frame into a normalized 84x84 grayscale image.

    The frame is converted with ``cv2.COLOR_RGB2GRAY``, resized to 84x84
    and divided by 255.0.

    Args:
        frame: RGB image array as accepted by ``cv2.cvtColor``.

    Returns:
        The resized grayscale image divided by 255.0.
    """
    grayscale = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(grayscale, (84, 84))
    return resized / 255.0

## Building the DQN

In [None]:
def build_q_network(input_shape, action_size):
    """Build and compile the convolutional Q-network.

    Architecture: three ReLU convolutions (32 filters 8x8 stride 4,
    64 filters 4x4 stride 2, 64 filters 3x3 stride 1), a flatten, a
    512-unit ReLU dense layer and a linear output with one unit per action.
    The model is compiled with Adam (learning rate 0.00025) and MSE loss.

    Args:
        input_shape: Shape of a single input state, passed to the first
            convolution as ``input_shape``.
        action_size: Number of output units (one Q-value per action).

    Returns:
        The compiled ``keras.Sequential`` model.
    """
    layers = keras.layers
    network = keras.Sequential([
        layers.Conv2D(32, (8, 8), strides=4, activation='relu', input_shape=input_shape),
        layers.Conv2D(64, (4, 4), strides=2, activation='relu'),
        layers.Conv2D(64, (3, 3), strides=1, activation='relu'),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dense(action_size, activation='linear'),
    ])
    optimizer = keras.optimizers.Adam(learning_rate=0.00025)
    network.compile(optimizer=optimizer, loss='mse')
    return network

## Replay buffer

In [None]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Experiences are kept in a ``deque`` with ``maxlen=max_size``, so once
    the buffer is full each new experience evicts the oldest one.
    """

    def __init__(self, max_size=100000):
        """Create an empty buffer holding at most ``max_size`` experiences."""
        self.buffer = deque(maxlen=max_size)

    def add(self, experience):
        """Append one experience tuple; returns nothing."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct experiences uniformly at random.

        Each stored experience must be a 5-tuple
        ``(state, action, reward, next_state, done)``.

        Returns:
            Five numpy arrays (states, actions, rewards, next_states, dones),
            each with the batch as its leading axis.
        """
        drawn = random.sample(self.buffer, batch_size)
        # Transpose the list of experiences into per-field columns; the
        # five-way unpacking enforces the expected tuple layout.
        states, actions, rewards, next_states, dones = map(np.array, zip(*drawn))
        return states, actions, rewards, next_states, dones

    def size(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)


## Main AT-DQN Object

In [None]:
class ATDQNAgent:
    """DQN agent whose exploration is gated by per-state attention weights.

    The agent keeps an online Q-network and a target network (both built by
    ``build_q_network``), a ``ReplayBuffer``, and two dictionaries keyed by
    the flattened state:

    * ``td_errors`` - list of absolute TD errors observed for that state;
    * ``alpha`` - attention weight of that state (mean of its TD errors,
      optionally rescaled by ``normalize_attention``).

    A state whose ``attention * lambda_`` exceeds ``tau`` triggers a random
    action; otherwise the greedy action of the online network is taken.

    NOTE: ``beta`` is annealed by ``anneal_beta`` but is not read anywhere
    else in this class.
    """

    def __init__(
        self,
        action_size,
        state_shape,
        tau=0.2,
        lambda_=1.0,
        beta_start=0.4,
        beta_end=1.0,
        T=1000000,
        gamma=0.99,
    ):
        """Build the networks, the replay buffer and the exploration state.

        Args:
            action_size: Number of discrete actions.
            state_shape: Shape of one state, forwarded to ``build_q_network``.
            tau: Exploration threshold compared against the attention score.
            lambda_: Multiplier applied to the attention weight in ``act``.
            beta_start: Initial value of ``beta``.
            beta_end: Upper bound of ``beta``.
            T: Number of ``anneal_beta`` calls needed to go from
                ``beta_start`` to ``beta_end``.
            gamma: Discount factor used for the TD targets.
        """
        self.action_size = action_size
        self.q_network = build_q_network(state_shape, action_size)
        self.target_network = build_q_network(state_shape, action_size)
        self.target_network.set_weights(self.q_network.get_weights())

        self.replay_buffer = ReplayBuffer()
        self.gamma = gamma  # Discount factor

        # Attention mechanism
        self.alpha = {}  # Attention weights per state
        self.td_errors = {}  # TD error history

        # Exploration parameters
        self.tau = tau  # Threshold
        self.lambda_ = lambda_  # Score adjustment factor

        # Beta annealing
        self.beta = beta_start
        self.beta_end = beta_end
        self.delta_beta = (beta_end - beta_start) / T  # Annealing rate

    @staticmethod
    def _state_key(state):
        """Return the hashable dictionary key used for ``state``."""
        return tuple(state.flatten())

    def get_attention(self, state):
        """Return the attention weight of ``state`` (1.0 if never seen)."""
        return self.alpha.get(self._state_key(state), 1.0)

    def update_attention(self, state, td_error):
        """Record ``|td_error|`` for ``state`` and refresh its attention.

        The attention weight becomes the mean of all absolute TD errors
        recorded so far for that state.
        """
        state_key = self._state_key(state)
        history = self.td_errors.setdefault(state_key, [])
        history.append(abs(td_error))
        self.alpha[state_key] = np.mean(history)

    def normalize_attention(self):
        """Rescale all attention weights by the current maximum weight.

        Does nothing when the maximum is not strictly positive (for example
        when every recorded TD error is zero), because dividing by it would
        turn the weights into NaN.
        """
        max_alpha = max(self.alpha.values(), default=1)
        if not max_alpha > 0:
            return
        for state_key in self.alpha:
            self.alpha[state_key] /= max_alpha

    def act(self, state):
        """Choose an action with the attention-gated exploration rule.

        Returns a uniformly random action when
        ``get_attention(state) * lambda_ > tau``; otherwise the argmax of
        the online network's Q-values for ``state``.
        """
        sigma = self.get_attention(state) * self.lambda_
        if sigma > self.tau:
            return np.random.choice(self.action_size)  # Explore

        q_values = self.q_network.predict(np.expand_dims(state, axis=0), verbose=0)
        return np.argmax(q_values[0])  # Exploit

    def train(self, batch_size=32):
        """Run one training step on a batch sampled from the replay buffer.

        Does nothing until the buffer holds at least ``batch_size``
        transitions. Otherwise it computes one-step TD targets from the
        target network, records the TD errors in the attention tables, and
        fits the online network towards the targets.
        """
        if self.replay_buffer.size() < batch_size:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)

        target_q_values = self.target_network.predict(next_states, verbose=0)
        max_next_q = np.max(target_q_values, axis=1)
        # No bootstrapping from the next state when the transition is terminal.
        not_done = 1.0 - dones.astype(np.float32)
        targets = rewards + not_done * self.gamma * max_next_q

        q_values = self.q_network.predict(states, verbose=0)
        batch_index = np.arange(batch_size)
        td_errors = targets - q_values[batch_index, actions]

        for state, td_error in zip(states, td_errors):
            self.update_attention(state, td_error)

        # Regress only the Q-value of the action actually taken towards its
        # TD target; the other outputs keep their current predictions so
        # they contribute zero loss. Without this overwrite the network
        # would be fitted to its own output and never learn.
        fit_targets = q_values.copy()
        fit_targets[batch_index, actions] = targets
        self.q_network.fit(states, fit_targets, batch_size=batch_size, verbose=0)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.set_weights(self.q_network.get_weights())

    def anneal_beta(self):
        """Increase ``beta`` by ``delta_beta``, capped at ``beta_end``."""
        self.beta = min(self.beta + self.delta_beta, self.beta_end)

## Driver Code

In [None]:
# Initialize Weights & Biases
wandb.init(project="AT-DQN", name="Pong_ATDQN")

env = gym.make('ALE/Pong-v5')
state_shape = (84, 84, 4)
action_size = env.action_space.n

agent = ATDQNAgent(action_size, state_shape)

num_episodes = 1000
batch_size = 32
max_steps_per_episode = 10000

# Logging Metrics
wandb.config.update({
    "num_episodes": num_episodes,
    "batch_size": batch_size
})


def transition_td_error(agent, state, action, reward, next_state, done):
    """Return the one-step TD error of a single transition.

    Computes ``reward + gamma * max_a Q_target(next_state, a) - Q(state, action)``,
    dropping the bootstrap term when ``done`` is true.

    NOTE: this costs up to two extra ``predict`` calls per environment step;
    it is used only for the logged metric.
    """
    q_sa = agent.q_network.predict(np.expand_dims(state, axis=0), verbose=0)[0][action]
    if done:
        bootstrap = 0.0
    else:
        next_q = agent.target_network.predict(np.expand_dims(next_state, axis=0), verbose=0)
        bootstrap = agent.gamma * np.max(next_q[0])
    return float(reward + bootstrap - q_sa)


try:
    for episode in tqdm(range(num_episodes), desc="Training Progress"):
        state = preprocess_frame(env.reset()[0])
        # The initial stack repeats the first frame four times.
        state_stack = np.stack([state] * 4, axis=-1)

        total_reward = 0
        episode_td_errors = []

        for step in range(max_steps_per_episode):
            action = agent.act(state_stack)
            next_frame, reward, terminated, truncated, _ = env.step(action)
            next_state = preprocess_frame(next_frame)

            # Drop the oldest frame and append the newest one.
            next_state_stack = np.concatenate(
                [state_stack[:, :, 1:], np.expand_dims(next_state, axis=-1)],
                axis=-1,
            )

            # Store the transition. Only true termination is stored as "done",
            # so truncated episodes still bootstrap from the next state.
            agent.replay_buffer.add((state_stack, action, reward, next_state_stack, terminated))

            # ReplayBuffer.add returns nothing, so the TD error of the
            # transition is computed explicitly; its magnitude is logged.
            td_error = transition_td_error(
                agent, state_stack, action, reward, next_state_stack, terminated
            )
            episode_td_errors.append(abs(td_error))

            agent.train(batch_size)
            state_stack = next_state_stack
            total_reward += reward

            # Stop on either end-of-episode signal; stepping a finished
            # environment without a reset is not valid.
            if terminated or truncated:
                break

        agent.update_target_network()
        agent.anneal_beta()
        agent.normalize_attention()

        # Log episode metrics
        wandb.log({
            "Episode": episode + 1,
            "Total Reward": total_reward,
            "Mean TD Error": np.mean(episode_td_errors),
            "Beta": agent.beta
        })

        print(f"Episode {episode + 1}: Reward = {total_reward}, β = {agent.beta:.4f}")
finally:
    # Release the emulator and flush the wandb run even if training fails.
    env.close()
    wandb.finish()