In [None]:
#misc
import numpy as np
from random import sample
# env
import gymnasium as gym
from vizdoom import gymnasium_wrapper
# learning
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
from stable_baselines3 import repl

# Game Loop

In [None]:
# Smoke test: drive the environment with random actions for a fixed number
# of steps to confirm that ViZDoom starts, renders and steps correctly.
env = gym.make("VizdoomBasic-v0", render_mode="human")

try:
    # reset(seed=...) seeds the environment only; the action space keeps its
    # own random generator, so seed it as well to make the rollout repeatable.
    env.action_space.seed(42)
    obs, info = env.reset(seed=42)

    for _ in range(100):
        action = env.action_space.sample()

        obs, rew, term, trun, info = env.step(action)

        # Begin a fresh episode once the current one terminated or was truncated.
        if term or trun:
            obs, info = env.reset()
finally:
    # Always release the game window/process, even when a step raises;
    # otherwise the stale instance lingers for the lifetime of the kernel.
    env.close()

# Setup for DQN

We first need a **Transition** class that represents a single `(state, action) -> (state', reward)` datapoint: taking `action` in `state` leads to the next state `state'` and yields `reward`.

Then we need a **ReplayMemory** class to store these transitions and to draw random batches of them for training.

In [None]:
class Transition():
    """A single experience: ``(state, action) -> (next_state, reward)``.

    Attributes:
        state: Observation the agent acted on.
        action: Action taken in ``state``.
        next_state: Observation that followed the action.
        reward: Reward received for the action.
    """

    def __init__(self, state=None, action=None, next_state=None, reward=None) -> None:
        """Store the four components of the transition.

        All parameters default to ``None`` so that the original zero-argument
        call ``Transition()`` remains valid; the fields can then be assigned
        afterwards.
        """
        # The fields must be assigned here: merely reading ``self.state``
        # on a fresh instance raises AttributeError.
        self.state = state
        self.action = action
        self.next_state = next_state
        self.reward = reward

    def __repr__(self) -> str:
        """Return a readable representation for notebook display and debugging."""
        return (
            f"Transition(state={self.state!r}, action={self.action!r}, "
            f"next_state={self.next_state!r}, reward={self.reward!r})"
        )

In [None]:
class ReplayMemory(object):
    """Buffer of transitions that can be sampled uniformly at random.

    Args:
        capacity: Optional maximum number of stored transitions. ``None``
            (the default) keeps every transition. When a capacity is given
            and the buffer is full, each new transition overwrites the
            oldest one.

    Raises:
        ValueError: If ``capacity`` is given but is smaller than 1.
    """

    def __init__(self, capacity=None):
        if capacity is not None and capacity < 1:
            raise ValueError(f"capacity must be at least 1, got {capacity}")
        self.capacity = capacity
        # A plain list: it supports append(), len() and random.sample().
        # (np.array() cannot be called without arguments, and ndarrays have
        # no append method.)
        self.memory = []
        # Index of the oldest entry; only advanced once the buffer is full.
        self._position = 0

    def push(self, *args):
        """Save a transition built from ``Transition(*args)``."""
        transition = Transition(*args)
        if self.capacity is None or len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            # Buffer is full: overwrite the oldest entry, ring-buffer style.
            self.memory[self._position] = transition
            self._position = (self._position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        # ``sample`` here is the module-level ``random.sample``; a method
        # name does not shadow globals inside the method body.
        return sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

# DQN Class

In [None]:
class DQN(nn.Module):
    """Dueling deep Q-network over single-channel image observations.

    Four convolutional stages extract features, which are flattened and split
    in half. One half feeds the state-value head and the other feeds the
    advantage head. The Q-values are ``V(s) + (A(s, a) - mean_a A(s, a))``.

    The heads are sized for exactly 192 flattened features per sample, so the
    input resolution must make the last convolution emit 16 * H * W = 192
    values (for example a 30x45 frame).

    Args:
        n_actions: Number of discrete actions, i.e. Q-values per sample.
    """

    # Flattened size of the conv4 output that the heads are built for.
    _FLAT_FEATURES = 192
    # Each head consumes one half of the flattened features.
    _HEAD_FEATURES = 96

    def __init__(self, n_actions) -> None:
        super().__init__()

        # Conv layers omit the bias because the following BatchNorm layer
        # has its own learnable shift.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, stride=2, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU()
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, stride=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )

        self.conv4 = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=3, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

        # Scalar state value V(s).
        self.state_fc = nn.Sequential(
            nn.Linear(self._HEAD_FEATURES, 64), nn.ReLU(), nn.Linear(64, 1)
        )

        # Per-action advantages A(s, a).
        self.advantage_fc = nn.Sequential(
            nn.Linear(self._HEAD_FEATURES, 64), nn.ReLU(), nn.Linear(64, n_actions)
        )

    def forward(self, x):
        """Compute Q-values for a batch of observations.

        Args:
            x: Tensor of shape ``(batch, 1, height, width)``.

        Returns:
            Tensor of shape ``(batch, n_actions)``.

        Raises:
            ValueError: If the convolutional stack does not produce exactly
                192 features per sample for the given resolution.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        # Flatten per sample instead of view(-1, 192): with the wrong input
        # resolution, view(-1, 192) can silently fold the extra features into
        # the batch dimension and return Q-values for rows that do not exist.
        x = x.flatten(start_dim=1)
        if x.shape[1] != self._FLAT_FEATURES:
            raise ValueError(
                f"expected {self._FLAT_FEATURES} features per sample after the "
                f"convolutions, got {x.shape[1]}; check the input resolution"
            )

        x1 = x[:, :self._HEAD_FEATURES]  # input for the net to calculate the state value
        x2 = x[:, self._HEAD_FEATURES:]  # relative advantage of actions in the state
        state_value = self.state_fc(x1)  # already (batch, 1)
        advantage_values = self.advantage_fc(x2)

        # Centre the advantages so the value/advantage split is identifiable.
        x = state_value + (
            advantage_values - advantage_values.mean(dim=1, keepdim=True)
        )

        return x

# Training