# In [None]:  (notebook cell prompt left over from export)
# ID3QNE_deepQnet.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

# ----- Core PyTorch Dueling DQN Model -----
class DistributionalDQN(nn.Module):
    """Dueling Q-network built from fully connected layers.

    Despite the class name, the network outputs one scalar Q-value per action
    (not a return distribution). A shared two-layer MLP feeds two linear
    heads: a state-value head V(s) and an advantage head A(s, a). They are
    combined as ``Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)``.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        n_actions: Number of discrete actions (width of the output).
        hidden_dim: Width of the two hidden layers. Defaults to 128.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()

        # NOTE: the attribute is called ``conv`` but holds a plain MLP; the
        # name is kept so existing ``state_dict`` checkpoints keep loading.
        self.conv = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, 1)
        )

        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, state):
        """Return Q-values for ``state``.

        Args:
            state: Float tensor whose last dimension is ``state_dim``. Any
                number of leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same leading dimensions as ``state`` and a last
            dimension of size ``n_actions``.
        """
        x = self.conv(state)
        value = self.value_stream(x)
        advantage = self.advantage_stream(x)
        # Centre the advantages over the action axis. That axis is always the
        # last one, so use dim=-1: a single unbatched state then works too,
        # and extra leading dimensions are not averaged by mistake.
        qvals = value + (advantage - advantage.mean(dim=-1, keepdim=True))
        return qvals

# ----- Wrapper Class for Distributional DQN -----
class Dist_DQN(object):
    """Q-learning trainer: an online Q-network plus a Polyak-averaged target.

    Attributes:
        q_net: Online network optimised by ``train_model``.
        target_net: Copy of ``q_net`` used for bootstrapped targets; it is
            only changed by ``polyak_target_update``.
        optimizer: Adam optimiser over ``q_net`` parameters.
        gamma: Discount factor applied to the bootstrapped next-state value.
    """

    def __init__(self, state_dim, n_actions, lr=1e-5, gamma=0.9):
        """Build the online/target networks and the optimiser.

        Args:
            state_dim: Dimensionality of a single state vector.
            n_actions: Number of discrete actions.
            lr: Adam learning rate. Defaults to 1e-5.
            gamma: Discount factor. Defaults to 0.9.
        """
        self.q_net = DistributionalDQN(state_dim, n_actions)
        self.target_net = copy.deepcopy(self.q_net)
        # The target network is never trained by backprop, so it does not
        # need gradient tracking.
        for target_param in self.target_net.parameters():
            target_param.requires_grad_(False)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.gamma = gamma

    @staticmethod
    def _nan_loss():
        """Return the NaN sentinel loss used to flag an unusable batch."""
        return torch.tensor(float('nan'), requires_grad=True)

    def polyak_target_update(self, tau=0.005):
        """Move the target weights towards the online weights.

        Each target parameter becomes ``tau * online + (1 - tau) * target``.

        Args:
            tau: Interpolation factor; 1.0 copies the online weights.
        """
        with torch.no_grad():
            for target_param, param in zip(self.target_net.parameters(), self.q_net.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def compute_loss(self, batch):
        """Compute the one-step TD mean-squared error for ``batch``.

        Args:
            batch: Mapping with the entries ``'state'`` and ``'next_state'``
                (2-D, one row per transition), ``'action'`` (integer action
                indices) and ``'reward'``. Actions and rewards may be shaped
                ``(B,)`` or ``(B, 1)``. If an optional ``'done'`` entry is
                present (1/True for terminal transitions), the bootstrapped
                term is zeroed for those transitions.
                NOTE(review): the ``'done'`` key name is an assumption about
                the caller's batch schema; batches without it behave as before.

        Returns:
            Scalar loss tensor. If NaNs are found in the inputs, the network
            outputs or the targets, a message is printed and a NaN scalar
            tensor is returned instead.
        """
        state = torch.as_tensor(batch['state'], dtype=torch.float32)
        next_state = torch.as_tensor(batch['next_state'], dtype=torch.float32)
        # Flatten per-transition scalars: a (B, 1) reward would otherwise
        # broadcast against the (B,) Q-values into a (B, B) loss matrix.
        action = torch.as_tensor(batch['action'], dtype=torch.long).reshape(-1)
        reward = torch.as_tensor(batch['reward'], dtype=torch.float32).reshape(-1)
        not_done = None
        if 'done' in batch:
            done = torch.as_tensor(batch['done'], dtype=torch.float32).reshape(-1)
            not_done = 1.0 - done

        # Input NaN checks
        if torch.isnan(state).any():
            print("NaNs in input state")
            return self._nan_loss()
        if torch.isnan(next_state).any():
            print("NaNs in input next_state")
            return self._nan_loss()

        q_vals = self.q_net(state)
        if torch.isnan(q_vals).any():
            print("NaN detected in q_vals output from q_net.")
            return self._nan_loss()

        # Q-value of the action actually taken in each transition.
        q_val = q_vals.gather(1, action.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q_vals = self.target_net(next_state)
            if torch.isnan(next_q_vals).any():
                print("NaN detected in next_q_vals output from target_net.")
                return self._nan_loss()
            max_next_q_val = next_q_vals.max(1)[0]
            if not_done is not None:
                # No bootstrapping past the end of an episode.
                max_next_q_val = max_next_q_val * not_done
            expected_q_val = reward + self.gamma * max_next_q_val

        if torch.isnan(q_val).any() or torch.isnan(expected_q_val).any():
            print("NaN detected in Q-values used in loss calculation.")
            return self._nan_loss()

        loss = F.mse_loss(q_val, expected_q_val)
        return loss

    def train_model(self, batchs, epoch):
        """Run one optimisation step followed by a Polyak target update.

        Args:
            batchs: Batch mapping passed to ``compute_loss``.
            epoch: Unused; kept for interface compatibility with callers.

        Returns:
            The loss as a Python float. When the loss is NaN or infinite the
            value is still returned, but no optimiser or target update runs.
        """
        self.q_net.train()
        loss = self.compute_loss(batchs)
        # A non-finite loss must not reach the optimiser: clipping infinite
        # gradients yields NaN gradients, which would overwrite every weight.
        if not torch.isfinite(loss).item():
            print("Non-finite loss; skipping optimizer and target updates.")
            return loss.item()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.polyak_target_update()
        return loss.item()

    def get_action(self, state):
        """Return the greedy action index for a single state.

        Args:
            state: One unbatched state; a batch dimension is added before it
                is passed to the network.

        Returns:
            Integer index of the highest-valued action.
        """
        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            q_vals = self.q_net(state)
            return q_vals.argmax().item()
