# In [None]:
# ID3QNE_deepQnet.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class DuelingNet(nn.Module):
    """Dueling Q-network: shared MLP trunk with separate value and advantage heads.

    Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a')); the advantages are
    mean-centred over the action axis so V and A are identifiable.

    Args:
        state_dim: Number of input features per state.
        num_actions: Number of discrete actions (width of the Q output).
        hidden_dim: Width of both hidden layers (default 128).
    """

    def __init__(self, state_dim, num_actions, hidden_dim=128):
        super().__init__()
        # state_dict keys derive from these attribute names and the layer
        # order inside the Sequential, so keep them unchanged.
        self.backbone = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.V = nn.Linear(hidden_dim, 1)
        self.A = nn.Linear(hidden_dim, num_actions)

    def forward(self, x):
        """Map states ``[..., state_dim]`` to Q-values ``[..., num_actions]``."""
        z = self.backbone(x)
        v = self.V(z)                       # [..., 1]
        a = self.A(z)                       # [..., num_actions]
        # Centre over the last (action) axis so that unbatched inputs and
        # batches with extra leading dims work, not only [N, state_dim].
        return v + (a - a.mean(dim=-1, keepdim=True))

class Dist_DQN:
    """Dueling deep Q-learning agent with a soft-updated (Polyak) target network.

    Minimal WD3QN-style wrapper expected by our runner:
     - train(batch8, epoch) or train(batch8)
     - get_action(state_1D)
    batch8 is (state, next_state, action, next_action, reward, done, bloc_num, SOFAS)

    Only state, next_state, action, reward and done enter the loss; the
    remaining batch entries are accepted and ignored.

    Args:
        state_dim: Number of input features per state.
        num_actions: Number of discrete actions.
        gamma: Discount factor of the bootstrap target.
        lr: Adam learning rate.
        seed: Seeds torch's and numpy's *global* RNGs (process-wide side effect).
        device: Torch device; defaults to "cuda" when available, else "cpu".
        double_q: If True, use the Double-DQN target (online net selects the
            next action, target net evaluates it). Default False keeps the
            plain max over the target network.
    """

    def __init__(self, state_dim, num_actions, gamma=0.99, lr=1e-5, seed=42,
                 device=None, double_q=False):
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.double_q = double_q

        self.q_net = DuelingNet(state_dim, num_actions).to(self.device)
        # Target starts as an exact copy of the online network.
        self.target_net = copy.deepcopy(self.q_net).to(self.device)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

    @torch.no_grad()
    def _target_update(self, tau=0.005):
        """Soft update: target <- tau * online + (1 - tau) * target."""
        for tp, p in zip(self.target_net.parameters(), self.q_net.parameters()):
            tp.data.copy_(tau * p.data + (1 - tau) * tp.data)

    def _loss(self, batch):
        """Return the mean-squared TD error for one batch.

        Args:
            batch: 8-tuple (state, next_state, action, next_action, reward,
                done, bloc_num, SOFAS) of arrays, lists or tensors.

        Returns:
            Scalar loss tensor attached to ``q_net``'s autograd graph.

        Raises:
            ValueError: if reward/done do not line up with the batch size.
        """
        S, Snext, A, _A_next, R, Done, _bloc, _sofas = batch
        # as_tensor skips the copy (and torch's copy-construct warning) when
        # the runner already passes tensors.
        S = torch.as_tensor(S, dtype=torch.float32, device=self.device)
        Snext = torch.as_tensor(Snext, dtype=torch.float32, device=self.device)
        # Flatten per-sample quantities to [N]: an [N, 1] reward/done column
        # would otherwise broadcast against the [N] next-state value into an
        # [N, N] target and silently corrupt the loss.
        A = torch.as_tensor(A, dtype=torch.long, device=self.device).reshape(-1)
        R = torch.as_tensor(R, dtype=torch.float32, device=self.device).reshape(-1)
        Done = torch.as_tensor(Done, dtype=torch.float32, device=self.device).reshape(-1)

        q_all = self.q_net(S)                               # [N, num_actions]
        q_sa = q_all.gather(1, A.unsqueeze(1)).squeeze(1)   # [N]

        with torch.no_grad():
            q_next = self.target_net(Snext)                 # [N, num_actions]
            if self.double_q:
                # Double DQN: pick a' with the online net, evaluate with target.
                a_star = self.q_net(Snext).argmax(dim=1, keepdim=True)
                next_value = q_next.gather(1, a_star).squeeze(1)
            else:
                next_value = q_next.max(dim=1)[0]           # [N]
            y = R + (1.0 - Done) * self.gamma * next_value

        if y.shape != q_sa.shape:
            raise ValueError(
                f"TD target shape {tuple(y.shape)} does not match "
                f"Q(s,a) shape {tuple(q_sa.shape)}"
            )
        return F.mse_loss(q_sa, y)

    def train(self, batchs, epoch=None):
        """Run one optimisation step on ``batchs`` and soft-update the target.

        Args:
            batchs: 8-tuple batch as described in ``_loss``.
            epoch: Accepted for runner compatibility; unused.

        Returns:
            The pre-step loss as a Python float.
        """
        loss = self._loss(batchs)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 to damp large TD-error updates.
        nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
        self.optimizer.step()
        self._target_update()
        return loss.item()

    @torch.no_grad()
    def get_action(self, state_1d):
        """Return the greedy action index (int) for one unbatched state."""
        s = torch.as_tensor(state_1d, dtype=torch.float32, device=self.device).unsqueeze(0)
        q = self.q_net(s)
        return int(q.argmax(dim=1).item())
