# Imports

In [None]:
import torch
from torch.serialization import add_safe_globals

from agent import Agent
from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace

import gym_super_mario_bros
from wrappers import apply_wrappers

In [None]:
# Register the project's Agent class with torch.serialization's safe-globals list.
# NOTE(review): the torch.load call later in this notebook passes weights_only=False;
# presumably this allowlist is only consulted for weights_only=True loads, in which
# case this registration currently has no effect — confirm.
add_safe_globals([Agent])

# Configuration

In [None]:
# Whether to open a window and render the game while evaluating.
DISPLAY = True

# How many evaluation episodes to run on each level.
EPISODES_PER_LEVEL = 10

# Location of the checkpoint file to evaluate.
PATH = "models/model_v1/checkpoint.pt"

# Load

In [None]:
# Read the checkpoint from disk.
# NOTE(review): weights_only=False makes torch.load fall back to full pickle
# deserialization, so only load checkpoint files from a trusted source.
checkpoint = torch.load(PATH, weights_only=False)

# Pull the stored settings and the stored agent out of the checkpoint in one
# step; the keys are listed in the same order as the names they are bound to.
LEVELS, SKIP_FRAME, RESIZE, FRAME_STACK, agent = (
    checkpoint[key]
    for key in ('levels', 'skip_frame', 'resize', 'frame_stack', 'agent')
)

# Zero out the agent's epsilon for evaluation
# (presumably this disables random exploration — confirm against Agent).
agent.epsilon = 0.0

In [None]:
# Enumerate every stage id: worlds 1-8, stages 1-4 within each world.
possible_levels = [
    f"SuperMarioBros-{world_no}-{stage_no}-v0"
    for world_no in range(1, 9)
    for stage_no in range(1, 5)
]

# Keep only the stages that do not appear in the checkpoint's LEVELS entry.
test_levels = list(filter(lambda name: name not in LEVELS, possible_levels))

# NOTE(review): this assignment replaces the filtered list above with a single
# hard-coded stage, so the computed list is never used — confirm this is intended.
test_levels = ['SuperMarioBros-1-1-v0']

# Test

In [None]:
# Roll out the loaded agent on every level in `test_levels`, running
# EPISODES_PER_LEVEL episodes per level and printing each episode's total reward.
for level in test_levels:
    print(f"=== Evaluating on {level} ===")

    for ep in range(EPISODES_PER_LEVEL):
        print(f"Episode {ep + 1}/{EPISODES_PER_LEVEL}")

        # A fresh environment is built for every episode and closed in `finally`.
        # Gym's off-screen render mode is named 'rgb_array'; 'rgb' is not a
        # recognised mode.
        env = gym_super_mario_bros.make(
            level,
            render_mode='human' if DISPLAY else 'rgb_array',
            apply_api_compatibility=True,
        )

        try:
            # Wrap inside the try block so the underlying environment is still
            # closed if one of the wrappers raises during construction.
            env = JoypadSpace(env, RIGHT_ONLY)
            env = apply_wrappers(env, SKIP_FRAME, RESIZE, FRAME_STACK)

            state, _ = env.reset()
            done = False
            total_reward = 0

            while not done:
                # No gradients are needed while only choosing actions.
                with torch.no_grad():
                    action = agent.choose_action(state)

                state, reward, terminated, truncated, info = env.step(action)
                total_reward += reward

                # The five-value step API reports the end of an episode through
                # two separate flags; stop on either one so a truncated episode
                # does not keep stepping a finished environment.
                done = terminated or truncated

            print(f"Total reward: {total_reward}")

        finally:
            env.close()

    print()