# Silence Warnings

In [None]:
import warnings

# Suppress every warning category for the rest of the process so the
# per-episode training log stays readable.
# NOTE(review): this also hides library deprecation notices and gym's
# environment-checker messages (e.g. an unsupported render_mode); narrow the
# filter by category/module if those ever need to be seen.
warnings.simplefilter("ignore")

# Imports

In [None]:
import os
import numpy as np
import random

import torch
from torch import nn
from torch.serialization import add_safe_globals

import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace

from agent import Agent
from wrappers import apply_wrappers
from utils import *

In [None]:
# Register the project's Agent class with torch's restricted ("weights_only")
# unpickler so a checkpoint containing a pickled Agent may be loaded with it.
# NOTE(review): the checkpoint restore in this notebook calls torch.load with
# weights_only=False, which uses the full pickle loader and ignores this
# allow-list; the call only matters if that load moves to weights_only=True.
torch.serialization.add_safe_globals([Agent])

# Configs

In [None]:
# Open a live render window (render_mode='human') while training.
DISPLAY: bool = True
# Total number of episodes the training loop runs.
NUM_OF_EPISODES: int = 50000
# Write a checkpoint after every CKPT_SAVE_INTERVAL completed episodes.
CKPT_SAVE_INTERVAL: int = 3

## Load Model

In [None]:
# True: resume from the checkpoint at PATH (and keep saving back to it).
# False: build a fresh agent; PATH is then replaced by a new versioned path.
LOAD: bool = False
# Checkpoint file to resume from when LOAD is True.
PATH: str = 'models/model_v1/checkpoint.pt'

In [None]:
if LOAD:
    # Resume: restore the pickled agent plus the environment settings it was
    # trained with. SKIP_FRAME / RESIZE / FRAME_STACK feed apply_wrappers in the
    # training loop, so they must match what the saved network was built for.
    # SECURITY NOTE: weights_only=False runs the full pickle loader, which can
    # execute arbitrary code -- only load checkpoints you produced yourself.
    checkpoint = torch.load(PATH, weights_only=False)

    LEVELS, SKIP_FRAME, RESIZE, FRAME_STACK = (
        checkpoint[key] for key in ('levels', 'skip_frame', 'resize', 'frame_stack')
    )
    agent = checkpoint['agent']

In [None]:
if not LOAD:
    # Fresh run: define environment settings, hyperparameters, the network,
    # and the agent from scratch. levels / skip_frame / resize / frame_stack /
    # agent are exactly what the training loop writes into each checkpoint.

    # Environment Configuration
    LEVELS = ['SuperMarioBros-1-1-v0']  # one level is sampled at random per episode
    SKIP_FRAME = 4   # forwarded to apply_wrappers (presumably a frame-skip count)
    RESIZE = 84      # network input is RESIZE x RESIZE (see the shape probe below)
    FRAME_STACK = 4  # stacked frames per observation == conv input channels

    # Hyperparameter Configuration
    # NOTE(review): these are consumed inside Agent (not visible here); the
    # descriptions are the conventional DQN meanings -- confirm against Agent.
    LR = 0.00025                  # optimizer learning rate
    GAMMA = 0.9                   # discount factor
    EPSILON = 1.0                 # initial exploration rate
    EPS_DECAY = 0.99999975        # multiplicative epsilon decay factor
    EPS_MIN = 0.1                 # lower bound on epsilon
    REPLAY_BUFFER_CAPACITY = 100_000
    BATCH_SIZE = 32
    SYNC_NETWORK_RATE = 10_000    # presumably target-network sync period (in learn steps)

    # Network Architecture Configuration
    # Convolutional feature extractor; same layer layout as the DQN of
    # Mnih et al. (2015).
    conv_layers = nn.Sequential(
        nn.Conv2d(FRAME_STACK, 32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1),
        nn.ReLU(),
    )

    # Flattened conv-feature size, found by pushing one all-zero observation
    # through conv_layers. This is only a shape probe, so run it under no_grad
    # instead of recording an autograd graph that `o` would keep alive.
    with torch.no_grad():
        o = conv_layers(torch.zeros(1, FRAME_STACK, RESIZE, RESIZE))
    conv_out_size = o.numel()  # batch size is 1, so numel() == features per sample

    # Full network: conv features -> 512 hidden units -> one output per
    # RIGHT_ONLY action (presumably Q-values, given the DQN-style agent).
    network = nn.Sequential(
        conv_layers,
        nn.Flatten(),
        nn.Linear(conv_out_size, 512),
        nn.ReLU(),
        nn.Linear(512, len(RIGHT_ONLY))
    )

    # Create Agent -- arguments are positional, in the order Agent's
    # constructor expects.
    agent = Agent(
        network,
        len(RIGHT_ONLY),
        LR,
        GAMMA,
        EPSILON,
        EPS_DECAY,
        EPS_MIN,
        REPLAY_BUFFER_CAPACITY,
        BATCH_SIZE,
        SYNC_NETWORK_RATE
    )

# Train

In [None]:
if not LOAD:
    # Fresh run: checkpoints go into a new directory models/model_v<N>, where
    # N is one past the highest version already present (1 if none).
    base = "models"
    os.makedirs(base, exist_ok=True)

    prefix = "model_v"
    existing = [d for d in os.listdir(base) if d.startswith(prefix)]
    # Slice off only the leading prefix (str.replace would strip every
    # occurrence), and require isdecimal() rather than isdigit(): isdigit()
    # accepts characters such as '²' that int() rejects with ValueError.
    nums = [int(d[len(prefix):]) for d in existing if d[len(prefix):].isdecimal()]
    next_version = max(nums) + 1 if nums else 1

    save_dir = os.path.join(base, f"model_v{next_version}")
    os.makedirs(save_dir, exist_ok=True)

    PATH = os.path.join(save_dir, "checkpoint.pt")

In [None]:
for i in range(NUM_OF_EPISODES):
    print("Episode:", i)

    # A fresh environment is built every episode so the level can be sampled
    # at random. 'rgb' is not a valid gym render mode; None disables rendering.
    env = gym_super_mario_bros.make(
        random.choice(LEVELS),
        render_mode='human' if DISPLAY else None,
        apply_api_compatibility=True,
    )
    env = JoypadSpace(env, RIGHT_ONLY)
    env = apply_wrappers(env, SKIP_FRAME, RESIZE, FRAME_STACK)

    try:
        state, _ = env.reset()
        done = False
        total_reward = 0
        while not done:
            a = agent.choose_action(state)
            new_state, reward, terminated, truncated, info = env.step(a)
            # NOTE(review): despite its name, episode_counter is incremented
            # once per environment step, not once per episode.
            agent.episode_counter += 1
            total_reward += reward

            # Store the terminal flag only: a truncated episode did not reach a
            # terminal state, so its next state should still be bootstrapped.
            agent.store_in_memory(state, a, reward, new_state, terminated)
            agent.learn()

            state = new_state
            # Stop on termination OR truncation; ignoring `truncated` would
            # keep stepping an environment whose episode has already ended.
            done = terminated or truncated

        print("Episode Number", agent.episode_counter)
        print("Learn step counter:", agent.learn_step_counter)
        print("Total reward:", total_reward)
        print("Epsilon:", agent.epsilon)
        print("Size of replay buffer:", len(agent.replay_buffer))
        print()

        # Save every CKPT_SAVE_INTERVAL episodes, and always after the final
        # episode so the tail of the run is not lost when NUM_OF_EPISODES is
        # not a multiple of the interval.
        is_last_episode = i + 1 == NUM_OF_EPISODES
        if (i + 1) % CKPT_SAVE_INTERVAL == 0 or is_last_episode:
            torch.save(
                {
                    "agent": agent,
                    "levels": LEVELS,
                    "skip_frame": SKIP_FRAME,
                    "resize": RESIZE,
                    "frame_stack": FRAME_STACK
                },
                PATH
            )

    finally:
        env.close()