## Importing depedencies

In [16]:
import os
import random
import time
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
import ale_py
from collections import deque
from gym.wrappers import AtariPreprocessing, FrameStack

In [17]:
gym.register_envs(ale_py)

In [18]:
from gymnasium.wrappers import GrayscaleObservation

## DQN

In [20]:

# DQN Model
class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        
        # Flatten layer
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, n_actions)
        
        # Attention Layer
        self.attention_weights = nn.Linear(512, 1)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        
        # Compute attention scores for each state-action pair
        attention_score = torch.sigmoid(self.attention_weights(x))
        
        # Compute Q-values
        q_values = self.fc2(x)
        
        return q_values, attention_score

# DQN Agent
class DQNAgent:
    def __init__(self, env, input_shape, n_actions, gamma=0.99, lr=0.0001, buffer_size=100000, batch_size=32):
        self.env = env
        self.input_shape = input_shape
        self.n_actions = n_actions
        self.gamma = gamma
        self.lr = lr
        self.batch_size = batch_size
        self.policy_net = DQN(input_shape, n_actions).cuda()
        self.target_net = DQN(input_shape, n_actions).cuda()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)
        self.memory = deque(maxlen=buffer_size)
        
        self.update_target_network()
        
    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def select_action(self, state, attention_score):
        # If attention score is high, explore by choosing a random action
        if attention_score > 0.5:
            return random.choice(range(self.n_actions))  # Random action for exploration
        else:
            with torch.no_grad():
                # Select the best action according to the Q-values for exploitation
                q_values, _ = self.policy_net(state)
                return q_values.argmax().item()  # Best action for exploitation

    def store_transition(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train(self):
        if len(self.memory) < self.batch_size:
            return
        
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.stack(states).cuda()
        next_states = torch.stack(next_states).cuda()
        actions = torch.tensor(actions).cuda()
        rewards = torch.tensor(rewards).cuda()
        dones = torch.tensor(dones).cuda()

        # Calculate Q values from the policy network
        q_values, attention_scores = self.policy_net(states)
        
        # Calculate target Q values from the target network
        next_q_values, _ = self.target_net(next_states)
        next_q_values = next_q_values.max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute TD Error
        td_error = target_q_values - q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Supervise the attention mechanism with the TD error
        attention_loss = torch.mean(torch.abs(attention_scores.squeeze() - td_error.detach()))

        # Loss for Q-learning
        q_loss = torch.mean(td_error ** 2)

        # Total loss
        loss = q_loss + attention_loss

        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

# Training function
def train_dqn(env, agent, episodes=1000):
    for episode in range(episodes):
        state = env.reset()
        state = torch.tensor(state).unsqueeze(0).float().cuda()
        done = False
        episode_loss = 0
        
        while not done:
            # Get attention score for the current state
            _, attention_score = agent.policy_net(state)
            
            # Select action based on exploration-exploitation controlled by attention score
            action = agent.select_action(state, attention_score)
            
            next_state, reward, done, _ = env.step(action)
            next_state = torch.tensor(next_state).unsqueeze(0).float().cuda()
            
            agent.store_transition(state, action, reward, next_state, done)
            
            state = next_state
            
            # Train the agent with the collected experience
            loss = agent.train()
            episode_loss += loss
        
        agent.update_target_network()
        print(f'Episode {episode+1}, Loss: {episode_loss}')

# Set up the environment with grayscale preprocessing
env = gym.make('ALE/Breakout-v5')
env = AtariPreprocessing(env, grayscale_obs=True, frame_skip=1)  # Grayscale and frame skip
env = FrameStack(env, 4)  # Stack 4 frames for temporal context

# Get the shape of the environment's observations and the number of actions
input_shape = (1, 84, 84)  # Grayscale, 84x84
n_actions = env.action_space.n

# Initialize the DQN agent
agent = DQNAgent(env, input_shape, n_actions)

# Train the agent
train_dqn(env, agent)


AssertionError: 