In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class DistributionalDQN(nn.Module):
    """Q-network with separate state-value and advantage heads (dueling form).

    NOTE: despite the class name, this module outputs one scalar Q-value per
    action rather than a return distribution, and ``conv`` is a stack of
    fully connected layers, not convolutions. The attribute names are kept
    so that existing ``state_dict`` checkpoints continue to load.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        n_actions: Number of discrete actions (width of the output).
    """

    def __init__(self, state_dim, n_actions):
        super().__init__()
        # Shared feature extractor: two Linear layers, each followed by ReLU.
        self.conv = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )
        self.value_stream = nn.Linear(128, 1)
        self.advantage_stream = nn.Linear(128, n_actions)

    def forward(self, state):
        """Return Q-values for every action.

        Args:
            state: Float tensor of shape ``(..., state_dim)``. Both batched
                ``(B, state_dim)`` and unbatched ``(state_dim,)`` inputs are
                accepted.

        Returns:
            Tensor of shape ``(..., n_actions)`` computed as
            ``V(s) + (A(s, a) - mean_a A(s, a))``.
        """
        features = self.conv(state)
        value = self.value_stream(features)
        advantage = self.advantage_stream(features)
        # Average over the action axis (always the last one). Using dim=-1
        # instead of dim=1 keeps batched results identical while also
        # supporting inputs without a leading batch dimension.
        centered = advantage - advantage.mean(dim=-1, keepdim=True)
        return value + centered

class Dist_DQN(object):
    """Q-learning agent with an online network and a Polyak-averaged target.

    Holds the online network (``q_net``), a deep copy used to compute
    bootstrap targets (``target_net``), and an Adam optimizer over the
    online network's parameters.

    Args:
        state_dim: Passed through to ``DistributionalDQN``.
        n_actions: Passed through to ``DistributionalDQN``.
        lr: Adam learning rate (default matches the previous fixed value).
        gamma: Discount factor applied to the bootstrap term.
    """

    def __init__(self, state_dim, n_actions, lr=1e-5, gamma=0.99):
        self.q_net = DistributionalDQN(state_dim, n_actions)
        self.target_net = copy.deepcopy(self.q_net)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.gamma = gamma

    @staticmethod
    def _to_tensor(values, dtype):
        """Convert a batch field (tensor, ndarray or sequence) to a tensor."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(tensor) warns; detach + cast gives the same data.
            return values.detach().to(dtype)
        # Going through one contiguous ndarray avoids the slow element-wise
        # path torch.tensor takes for a list of ndarrays.
        return torch.tensor(np.asarray(values), dtype=dtype)

    def polyak_target_update(self, tau=0.005):
        """Soft-update the target network: ``target = tau*online + (1-tau)*target``.

        Args:
            tau: Interpolation weight given to the online parameters.
        """
        # no_grad() replaces direct ``.data`` access so the in-place write
        # is not recorded by autograd.
        with torch.no_grad():
            for tp, p in zip(self.target_net.parameters(), self.q_net.parameters()):
                tp.copy_(tau * p + (1 - tau) * tp)

    def compute_loss(self, batch):
        """Return the mean-squared TD error for a batch of transitions.

        Args:
            batch: Mapping indexed by ``'state'``, ``'next_state'``,
                ``'action'`` and ``'reward'``. If it also provides
                ``'done'`` (1/True for terminal transitions), the bootstrap
                term is zeroed for those rows; without it every transition
                is treated as non-terminal.

        Returns:
            Scalar loss tensor attached to the online network's graph.
        """
        state = self._to_tensor(batch['state'], torch.float32)
        next_state = self._to_tensor(batch['next_state'], torch.float32)
        # Flatten so (B, 1)-shaped columns behave like (B,) vectors; a
        # (B, 1) reward would otherwise broadcast to (B, B) inside mse_loss.
        action = self._to_tensor(batch['action'], torch.long).reshape(-1)
        reward = self._to_tensor(batch['reward'], torch.float32).reshape(-1)

        try:
            done = batch['done']
        except (KeyError, IndexError, ValueError):
            # Field not present in this batch container.
            done = None

        q_vals = self.q_net(state)
        q_val = q_vals.gather(1, action.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            max_next = self.target_net(next_state).max(1)[0]
            if done is not None:
                # Terminal transitions have no successor value to bootstrap.
                not_done = 1.0 - self._to_tensor(done, torch.float32).reshape(-1)
                max_next = max_next * not_done
            target = reward + self.gamma * max_next

        return F.mse_loss(q_val, target)

    def train_model(self, batchs, epoch):
        """Run one optimisation step and one soft target update.

        Args:
            batchs: A single batch in the format accepted by ``compute_loss``.
            epoch: Accepted for interface compatibility; not used here.

        Returns:
            The loss for this step as a Python float.
        """
        self.q_net.train()
        loss = self.compute_loss(batchs)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.polyak_target_update()
        return loss.item()
