# Necessary imports

In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Dueling DQN

In [None]:
class DuelingDQN(nn.Module):
    """Dueling Q-network with separate state-value and advantage streams.

    Both streams are independent 3-layer MLPs fed directly with the
    observation.  Their outputs are combined into Q-values as
    ``Q(s, a) = V(s) + A(s, a) - baseline(A(s, .))`` where the baseline is
    computed per state over the action dimension.

    Args:
        input_shape: Sequence whose first element is the observation size.
        n_actions: Number of discrete actions (width of the advantage head).
        type: Aggregation mode; ``1`` subtracts the mean advantage, any other
            value subtracts the max advantage.  (The name shadows the
            builtin but is kept so existing callers keep working.)
    """

    def __init__(self, input_shape: tuple, n_actions: int, type: int):
        super().__init__()

        # State-value stream: observation -> scalar V(s).
        self.fc_val = nn.Sequential(
            nn.Linear(input_shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        # Advantage stream: observation -> one A(s, a) per action.
        self.fc_adv = nn.Sequential(
            nn.Linear(input_shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

        self.type = type

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values for ``x``; the last dimension indexes actions."""
        val = self.fc_val(x)
        adv = self.fc_adv(x)

        # Reduce over the action dimension only (keepdim for broadcasting).
        # A global reduction would mix advantages from different samples of
        # the batch, making each state's Q-values depend on the others.
        if self.type == 1:
            baseline = adv.mean(dim=-1, keepdim=True)
        else:
            baseline = adv.max(dim=-1, keepdim=True).values

        return val + adv - baseline

# Acrobot-v1

In [3]:
# Instantiate the Gymnasium "Acrobot-v1" environment; its action and
# observation spaces are read below to size the network.
env = gym.make("Acrobot-v1")

In [None]:
# Training setup for the Dueling DQN agent.

# Hyperparameters.
BATCH_SIZE = 32
GAMMA = 0.99
EPSILON_START = 1.0
EPSILON_DECAY = 0.0001
EPSILON_FINAL = 0.01

# Problem dimensions taken from the environment.
n_actions = env.action_space.n
input_shape = (env.observation_space.shape[0],)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Online network and a target network initialised with identical weights.
net = DuelingDQN(input_shape, n_actions, 1).to(device)
target_net = DuelingDQN(input_shape, n_actions, 1).to(device)
target_net.load_state_dict(net.state_dict())

# Only the online network is optimised.
optimizer = optim.Adam(net.parameters(), lr=1e-3)
criterion = nn.MSELoss()

def train(net, target_net, optimizer, criterion, batch, gamma) -> float:
    """Run one optimisation step of the online network on a transition batch.

    Args:
        net: Online Q-network being optimised.
        target_net: Network used to evaluate next states for the bootstrap
            target; it is only evaluated, never updated here.
        optimizer: Optimiser holding ``net``'s parameters.
        criterion: Loss taking ``(predicted, target)`` tensors.
        batch: Tuple ``(states, actions, rewards, dones, next_states)`` of
            array-likes with matching first dimension.
        gamma: Discount factor applied to the next-state value.

    Returns:
        The scalar loss value for this step as a Python float.
    """
    states, actions, rewards, dones, next_states = batch

    # Use the online network's own device/dtype instead of a module-level
    # global, and cast explicitly so NumPy's float64 defaults cannot cause a
    # Float/Double mismatch in the loss.
    ref_param = next(net.parameters())
    dev, float_dtype = ref_param.device, ref_param.dtype

    states_v = torch.as_tensor(states, dtype=float_dtype, device=dev)
    next_states_v = torch.as_tensor(next_states, dtype=float_dtype, device=dev)
    # gather() requires int64 indices.
    actions_v = torch.as_tensor(actions, dtype=torch.long, device=dev)
    rewards_v = torch.as_tensor(rewards, dtype=float_dtype, device=dev)
    done_mask = torch.as_tensor(dones, dtype=torch.bool, device=dev)

    # Q(s, a) of the online network for the actions actually taken.
    state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

    # The bootstrap target is a constant w.r.t. the optimisation, so no
    # autograd graph is built through the target network.
    with torch.no_grad():
        next_state_values = target_net(next_states_v).max(1)[0]
        # Terminal transitions have no future return.
        next_state_values = next_state_values.masked_fill(done_mask, 0.0)
        expected_state_action_values = next_state_values * gamma + rewards_v

    loss = criterion(state_action_values, expected_state_action_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()