In [1]:
import gymnasium as gym
import os
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    VecFrameStack,
    VecNormalize,
    VecMonitor,
)
from stable_baselines3.dqn.policies import CnnPolicy

import ale_py
import json
from stable_baselines3.common.logger import configure
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.callbacks import (
    BaseCallback,
    CheckpointCallback,
    CallbackList,
)
from stable_baselines3.dqn.policies import DQNPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from gymnasium import spaces

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

2025-10-20 00:19:18.487617: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
class CustomCNN(BaseFeaturesExtractor):
    """Nature-DQN style convolutional feature extractor.

    Three conv layers (32@8x8 stride 4, 64@4x4 stride 2, 64@3x3 stride 1),
    each followed by ReLU, then flattened and projected to ``features_dim``
    by a linear layer + ReLU.

    :param observation_space: image observation space. The first axis of its
        shape is taken as the number of input channels, so it is assumed to be
        channel-first (e.g. after ``VecTransposeImage``) — TODO confirm for
        any new env.
    :param features_dim: size of the output feature vector.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 512):
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]

        conv_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Run one dummy observation through the conv stack to discover the
        # flattened size, so the linear head fits any input resolution.
        with torch.no_grad():
            dummy_batch = torch.as_tensor(observation_space.sample()[None]).float()
            flat_size = self.cnn(dummy_batch).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_size, features_dim), nn.ReLU())

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to ``features_dim`` features.

        NOTE(review): scaling of ``obs`` (e.g. uint8 -> [0, 1]) is expected to
        happen in the policy's preprocessing before this is called — confirm.
        """
        conv_features = self.cnn(obs)
        return self.linear(conv_features)


class CustomCnnPolicy(DQNPolicy):
    """DQN policy that uses ``CustomCNN`` as its feature extractor.

    Defaults:
      * ``features_extractor_class=CustomCNN``
      * ``features_extractor_kwargs={"features_dim": 512}``
      * ``net_arch=[]`` — no hidden layers, so the Q-head is a single linear
        layer on top of the CNN features.

    All three are *defaults*: they can be overridden by keyword (e.g. via
    ``policy_kwargs``). They are filled in with ``setdefault`` rather than
    passed alongside ``**kwargs``. Passing them alongside would raise
    ``TypeError: got multiple values for keyword argument`` whenever a caller
    also supplies them. SB3's policy reloading does supply them, because it
    re-passes the saved constructor parameters, which include these keys.
    """

    def __init__(self, *args, **kwargs):
        # Fresh objects per call, so instances never share a mutable default.
        kwargs.setdefault("features_extractor_class", CustomCNN)
        kwargs.setdefault("features_extractor_kwargs", dict(features_dim=512))
        kwargs.setdefault("net_arch", [])
        super().__init__(*args, **kwargs)

In [3]:
# Run configuration. The seed is applied to the vectorised Atari envs and to
# the DQN model (torch / numpy / random / action space), so a fresh-kernel run
# is reproducible.
config = {
    "env_name": "PongNoFrameskip-v4",
    "num_envs": 2,
    "seed": 100,
}

# Atari envs with SB3's standard preprocessing; stacking 4 consecutive frames
# lets the agent infer motion (e.g. ball direction) from a single observation.
env = make_atari_env(config["env_name"], n_envs=config["num_envs"], seed=config["seed"])
env = VecFrameStack(env, n_stack=4)

# Previously trained model (presumably for comparison/evaluation — confirm).
# NOTE(review): this raises if the checkpoint is missing, and it binds the
# loaded model to the *same* VecEnv instance as `model` below — do not step or
# train both against `env` at the same time.
model_path = "./sb3_pong_models_dqn/pong_dqn_final_model.zip"
model2 = DQN.load(model_path, env=env)

# Fresh DQN agent using the custom CNN feature extractor.
model = DQN(
    CustomCnnPolicy,
    env,
    learning_rate=1e-4,
    buffer_size=100_000,  # replay buffer capacity (transitions)
    learning_starts=100_000,  # transitions collected before gradient updates begin
    batch_size=32,
    gamma=0.99,
    train_freq=4,  # collection steps between training rounds
    gradient_steps=1,
    target_update_interval=1000,  # timesteps between target-network syncs
    exploration_fraction=0.1,  # epsilon decays over the first 10% of training
    exploration_final_eps=0.01,
    seed=config["seed"],  # previously unseeded: weights/exploration were not reproducible
    verbose=1,
)


A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)
[Powered by Stella]


Wrapping the env in a VecTransposeImage.
Using cuda device
Wrapping the env in a VecTransposeImage.
