# In [None]:
# ID3QNE_deepQnet.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class DistributionalDQN(nn.Module):
    """Dueling Q-network: a shared MLP trunk feeding value and advantage heads.

    Despite the class name, the forward pass returns one scalar Q-value per
    action, not a distribution over returns.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        n_actions: Number of discrete actions (size of the output's last dim).
    """

    def __init__(self, state_dim, n_actions):
        super().__init__()
        # Layer creation order is kept stable so seeded initialisation is
        # reproducible: trunk first, then value head, then advantage head.
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
        )
        self.value = nn.Linear(128, 1)
        self.adv = nn.Linear(128, n_actions)

    def forward(self, x):
        """Return Q-values for every action.

        Args:
            x: State tensor whose last dimension is ``state_dim``; any number
                of leading dimensions (including none) is accepted.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``n_actions``.
        """
        h = self.net(x)
        v = self.value(h)                  # [..., 1]
        a = self.adv(h)                    # [..., A]
        # Centre the advantages over the action axis. Using the last axis
        # (rather than a hard-coded dim=1) keeps this correct for unbatched
        # and for higher-rank inputs as well as the usual [B, S] case.
        return v + (a - a.mean(dim=-1, keepdim=True))

class Dist_DQN(object):
    """Q-learning agent with an online network, a target network and Adam.

    The target network starts as a deep copy of the online network and is
    moved towards it by Polyak averaging after every training step.

    Args:
        state_dim: Dimensionality of a single state vector.
        n_actions: Number of discrete actions.
        gamma: Discount factor applied to the bootstrapped next-state value.
        lr: Learning rate for the Adam optimizer.
        seed: Seed applied to the global torch and numpy RNGs before the
            networks are created.
    """

    def __init__(self, state_dim, n_actions, gamma=0.99, lr=1e-5, seed=42):
        # NOTE: this seeds the process-wide RNGs, not an agent-local one.
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.q_net = DistributionalDQN(state_dim, n_actions)
        self.target_net = copy.deepcopy(self.q_net)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.gamma = gamma

    def polyak_target_update(self, tau=0.005):
        """Blend the online weights into the target network.

        Each target parameter becomes ``tau * online + (1 - tau) * target``.

        Args:
            tau: Interpolation weight given to the online network.
        """
        with torch.no_grad():
            for tp, p in zip(self.target_net.parameters(), self.q_net.parameters()):
                tp.copy_(tau * p + (1 - tau) * tp)

    def compute_loss(self, batchs):
        """Compute the mean-squared TD error for a batch of transitions.

        Args:
            batchs: Mapping with array-like entries ``'state'`` and
                ``'next_state'`` (2-D, one row per transition), ``'action'``
                (integer action indices) and ``'reward'``. Actions and rewards
                may be shaped ``[B]`` or ``[B, 1]``. If an optional ``'done'``
                entry is present it is treated as a terminal flag (1/True
                means the episode ended) and disables bootstrapping for that
                transition; when it is absent every transition bootstraps.

        Returns:
            Scalar loss tensor attached to the online network's graph.
        """
        state = torch.as_tensor(batchs['state'], dtype=torch.float32)
        next_state = torch.as_tensor(batchs['next_state'], dtype=torch.float32)
        action = torch.as_tensor(batchs['action'], dtype=torch.long).reshape(-1, 1)
        # Flatten so a [B, 1] reward cannot broadcast against the [B]
        # bootstrap term into a [B, B] target.
        reward = torch.as_tensor(batchs['reward'], dtype=torch.float32).reshape(-1)

        q_vals = self.q_net(state)                      # [B,A]
        q_val = q_vals.gather(1, action).squeeze(1)     # [B]

        with torch.no_grad():
            next_q = self.target_net(next_state)        # [B,A]
            max_next = next_q.max(1)[0]                 # [B]
            if 'done' in batchs:
                done = torch.as_tensor(batchs['done'], dtype=torch.float32).reshape(-1)
                # Terminal transitions have no successor value to bootstrap.
                max_next = max_next * (1.0 - done)
            target = reward + self.gamma * max_next

        return F.mse_loss(q_val, target)

    def train_model(self, batchs, epoch):
        """Run one optimisation step and one soft target update.

        Args:
            batchs: Transition batch, in the format ``compute_loss`` expects.
            epoch: Unused; kept so existing callers keep working.

        Returns:
            The loss for this batch as a Python float.
        """
        self.q_net.train()
        loss = self.compute_loss(batchs)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
        self.optimizer.step()
        self.polyak_target_update()
        return float(loss.item())

    def get_action(self, state_np):
        """Return the greedy action index for a single state.

        Args:
            state_np: Array-like state; a leading batch axis of size one is
                added before it is passed to the online network.

        Returns:
            Index of the highest-valued action as an int.
        """
        with torch.no_grad():
            s = torch.as_tensor(state_np, dtype=torch.float32).unsqueeze(0)
            return int(self.q_net(s).argmax().item())
