# Silence Warnings

In [None]:
import warnings

# Suppress every warning category for the rest of the session: install a
# catch-all "ignore" filter at the front of the filter list.
warnings.simplefilter("ignore")

# Imports

In [None]:
import os
import numpy as np
import random

import torch
from torch import nn

import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace

from agent import Agent
from wrappers import apply_wrappers
from utils import *

# Configs

## General

In [None]:
DISPLAY = True
NUM_OF_EPISODES = 50_000
CKPT_SAVE_INTERVAL = 5_000

## Environment Configuration

In [None]:
# Candidate levels; one is drawn at random for each training episode.
LEVELS = ["SuperMarioBros-1-1-v3"]
# Observation preprocessing: frame-skip factor (presumably frames per action),
# square resize edge in pixels, and number of stacked frames. RESIZE and
# FRAME_STACK also fix the network's input shape.
SKIP_FRAME = 4
RESIZE = 84
FRAME_STACK = 4

## Hyperparameter Configuration

In [None]:
# Agent hyperparameters, passed positionally to Agent; exact semantics are
# defined by the Agent class (names suggest a DQN-style learner).
LR = 2.5e-4                       # learning rate
GAMMA = 0.9                       # reward discount factor
EPSILON = 1.0                     # initial exploration rate
EPS_DECAY = 0.99999975            # multiplicative epsilon decay
EPS_MIN = 0.1                     # lower bound on epsilon
REPLAY_BUFFER_CAPACITY = 100000   # max stored transitions
BATCH_SIZE = 32                   # transitions per learning update
SYNC_NETWORK_RATE = 10000         # presumably learn steps between target syncs

## Network Architecture Configuration

In [None]:
# Convolutional feature extractor: FRAME_STACK input channels -> 64 feature maps.
conv_layers = nn.Sequential(
    nn.Conv2d(FRAME_STACK, 32, kernel_size=8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
)

# Push one blank FRAME_STACK x RESIZE x RESIZE observation through the
# extractor to discover how many features it emits, so the first Linear layer
# can be sized without hand-computing the conv arithmetic.
o = conv_layers(torch.zeros(1, FRAME_STACK, RESIZE, RESIZE))
conv_out_size = o.numel()  # batch size is 1, so this is the per-sample feature count

# Full model: extractor, flatten, one hidden layer, and one output per action
# in the RIGHT_ONLY action set.
network = nn.Sequential(
    conv_layers,
    nn.Flatten(),
    nn.Linear(conv_out_size, 512), nn.ReLU(),
    nn.Linear(512, len(RIGHT_ONLY)),
)

# Train

In [None]:
# Each run writes its checkpoints into its own timestamped folder under models/.
model_path = os.path.join("models", get_current_date_time_string())
os.makedirs(model_path, exist_ok=True)

# Report which compute device PyTorch can see.
if not torch.cuda.is_available():
    print("CUDA is not available")
else:
    print("Using CUDA device:", torch.cuda.get_device_name(0))

In [None]:
# Build the learning agent; arguments are positional, in the order Agent expects.
agent = Agent(
    network, len(RIGHT_ONLY),             # model and number of actions
    LR, GAMMA,                            # optimisation / discounting
    EPSILON, EPS_DECAY, EPS_MIN,          # exploration schedule
    REPLAY_BUFFER_CAPACITY, BATCH_SIZE,   # replay memory
    SYNC_NETWORK_RATE,                    # network sync cadence
)

In [None]:
# Main training loop: a fresh environment per episode, one agent learning call
# per environment step, and a checkpoint every CKPT_SAVE_INTERVAL episodes.
for i in range(NUM_OF_EPISODES):
    print("Episode:", i)

    # A new env is built each episode so the level can be sampled from LEVELS.
    # 'rgb_array' is the standard non-interactive render mode ('rgb' is not a
    # recognised mode name).
    env = gym_super_mario_bros.make(
        random.choice(LEVELS),
        render_mode='human' if DISPLAY else 'rgb_array',
        apply_api_compatibility=True,
    )

    try:
        # Wrap inside the try so the emulator is still closed if wrapping fails.
        env = JoypadSpace(env, RIGHT_ONLY)
        # Use the configured preprocessing instead of hard-coded numbers so the
        # observations always match the network input (FRAME_STACK x RESIZE x RESIZE).
        env = apply_wrappers(env, skip_frame=SKIP_FRAME, resize=RESIZE, frame_stack=FRAME_STACK)

        state, _ = env.reset()
        done = False
        total_reward = 0
        while not done:
            a = agent.choose_action(state)
            new_state, reward, terminated, truncated, info = env.step(a)
            total_reward += reward

            # Only genuine termination is stored as the terminal flag; a
            # truncated step did not end because of the agent's action.
            # (Presumably Agent uses this flag to stop bootstrapping — confirm.)
            agent.store_in_memory(state, a, reward, new_state, terminated)
            agent.learn()

            state = new_state
            # Gym >= 0.26 step API: an episode ends on termination OR truncation.
            done = terminated or truncated

        print("Total reward:", total_reward, "Epsilon:", agent.epsilon, "Size of replay buffer:", len(agent.replay_buffer), "Learn step counter:", agent.learn_step_counter)

        if (i + 1) % CKPT_SAVE_INTERVAL == 0:
            agent.save_model(os.path.join(model_path, "model_" + str(i + 1) + "_iter.pt"))
    finally:
        env.close()