In [1]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class QNetwork(nn.Module):
    """Convolutional Q-network for stacked 84x84 frames (Nature-DQN layer layout).

    Maps a batch of stacked frames of shape (batch, in_channels, 84, 84) to one
    Q-value per discrete action, shape (batch, n_actions).

    Args:
        in_channels: Number of stacked input frames / channels. Defaults to 4,
            matching ``FrameStackObservation(stack_size=4)`` in the training cell.
        n_actions: Number of discrete actions (size of the output head).
            Defaults to 6, presumably the Pong action count used for training;
            pass ``env.action_space.n`` for other games.

    Note:
        The flatten size ``64 * 7 * 7`` is only valid for 84x84 spatial input.
    """

    def __init__(self, in_channels: int = 4, n_actions: int = 6):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=8, stride=4),  # C, 84, 84 -> 32, 20, 20
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),  # -> 64, 9, 9
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),  # -> 64, 7, 7
            nn.ReLU(),

            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),  # one Q-value per action, no activation
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal weights and zero biases for every conv / linear layer."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # NOTE(review): this also applies ReLU gain to the final Q head,
                # which is not followed by a ReLU; kept to preserve behaviour.
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return Q-values for a batch of raw pixel observations.

        Args:
            x: Tensor of shape (batch, in_channels, 84, 84) with pixel values in
                [0, 255]. A uint8 tensor is promoted to float by the division.

        Returns:
            Tensor of shape (batch, n_actions).
        """
        # Scale raw pixels to [0, 1] inside the model so callers can feed uint8 frames.
        x = x / 255.0
        return self.net(x)

In [None]:
from trainer import DDQN
import gymnasium as gym
from gymnasium.wrappers import FrameStackObservation, GrayscaleObservation, ResizeObservation
import ale_py

EPISODES = 10000
SEED = 42  # fixed seed so Restart & Run All reproduces the same run

gym.register_envs(ale_py)

# Seed the global RNGs up front. The DDQN trainer's internals are not visible
# here; presumably it draws from random / numpy / torch (e.g. epsilon-greedy,
# replay sampling, weight init), so seed all three.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): frameskip=1 with no max-pooling means the 4 stacked frames are
# 4 consecutive raw frames. Standard DQN on Atari uses a frame skip of 4
# (e.g. gymnasium's AtariPreprocessing). Confirm this is intentional.
env = gym.make("ALE/Pong-v5", render_mode="rgb_array", frameskip=1)
env = GrayscaleObservation(env)                 # (210, 160, 3) -> (210, 160)
env = ResizeObservation(env, shape=(84, 84))    # -> (84, 84)
env = FrameStackObservation(env, stack_size=4)  # -> (4, 84, 84), QNetwork's 4 input channels

# Seed the environment's RNG once; later env.reset() calls without a seed
# continue from this seeded state.
env.reset(seed=SEED)
env.action_space.seed(SEED)

ddqn = DDQN(
    env,
    QNetwork,
    buffer_size=100000,
    batch_size=32,
    gamma=0.99,
    lr=1e-4,
    epsilon_start=1.0,
    epsilon_end=0.01,
    epsilon_decay=0.995,
    target_update_interval=1000,
)
rewards, losses = ddqn.train(episodes=EPISODES, stats_interval=50)

Box(0, 255, (210, 160, 3), uint8)
Box(0, 255, (210, 160), uint8)
