In [None]:
###########
# Imports #
###########

# Mario Environment Libraries
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

# Other Libraries
import gym
from gym.wrappers import GrayScaleObservation, ResizeObservation, FrameStack
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from matplotlib import pyplot as plt
from collections import deque
import torch.nn as nn
import torch
import time
import random
import copy
import numpy as np
import torchvision
import os

In [None]:
###############
# Agent Class #
###############

class Agent:
    """Epsilon-greedy playback agent backed by a convolutional Q-network.

    The network consumes a stack of 4 single-channel frames. The hard-coded
    ``nn.Linear(3136, 512)`` layer fixes the expected frame size at 84x84
    (84 -> 20 -> 9 -> 7 through the three convolutions, 64 * 7 * 7 = 3136).

    Attributes:
        num_actions: Number of discrete actions the agent chooses between.
        epsilon: Probability of picking a uniformly random action instead of
            the greedy one. Starts at 0.0 and is overwritten by the value
            stored in the checkpoint when ``model_path`` is given.
        device: Torch device holding the network; observations are moved to
            it before inference.
        q1: The Q-network mapping a float batch of shape (N, 4, 84, 84) to
            action values of shape (N, num_actions).
    """

    def __init__(self, num_actions, model_path=None):
        """Build the Q-network and optionally restore a saved checkpoint.

        Args:
            num_actions: Size of the discrete action space.
            model_path: Optional path to a ``torch.save`` checkpoint holding
                the keys ``'model_state_dict'`` and ``'epsilon'``.
        """
        self.num_actions = num_actions
        self.epsilon = 0.0

        # Use the GPU when one is available, otherwise fall back to the CPU.
        # The device is kept on the instance so get_action() can place its
        # inputs on the same device as the network.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print("Torch device: ", self.device)

        self.q1 = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )
        self.q1.to(self.device)

        # Check if a previously trained model has been provided
        if model_path is not None:
            # map_location lets a checkpoint saved on a GPU load on a CPU-only
            # machine (and vice versa).
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            loaded_model = torch.load(model_path, map_location=self.device)
            self.q1.load_state_dict(loaded_model['model_state_dict'])
            self.q1.eval()
            self.epsilon = loaded_model['epsilon']

    def get_action(self, state):
        """Pick an action for ``state`` with an epsilon-greedy policy.

        Args:
            state: Observation exposing ``__array__()`` (e.g. the frame stack
                returned by the wrapped environment). Presumably 4 stacked
                84x84 uint8 frames; verify against the environment wrappers.

        Returns:
            int: Index of the chosen action in ``[0, num_actions - 1]``.
        """
        # epsilon-greedy policy
        if random.uniform(0, 1) < self.epsilon:
            # Explore randomly
            return random.randint(0, self.num_actions - 1)

        # Follow policy greedily (A_t = argmax_a q1(S_t, a, theta1)).
        # squeeze() drops any size-1 dimensions (e.g. an unnecessary channel axis).
        frames = torch.tensor(state.__array__()).squeeze().to(self.device)
        frames = torchvision.transforms.functional.convert_image_dtype(frames, dtype=torch.float)
        # Add a leading batch dimension to match the minibatch shape the network expects
        batch = frames.unsqueeze(0)

        # Inference only: no gradients are needed to choose an action.
        with torch.no_grad():
            action_values = self.q1(batch)
        return torch.argmax(action_values, dim=1).item()

In [None]:
# FrameSkipReplicator class modified from the original MaxAndSkipEnv wrapper to display all 4 frames (for nice video playback), but only return max_frame
from stable_baselines3.common.type_aliases import GymStepReturn
class FrameSkipReplicator(MaxAndSkipEnv):
    """Frame-skipping wrapper that still renders every skipped frame.

    Each call to ``step`` repeats the action for ``skip`` emulator frames,
    renders each of them paced to ``fps`` frames per second, and returns a
    single observation: the element-wise max of the last two frames, as
    stored in the base class's ``_obs_buffer``.
    """

    def __init__(self, env, skip=4, fps=60):
        """Wrap ``env``.

        Args:
            env: Environment to wrap.
            skip: Number of frames each action is repeated for.
            fps: Playback rate used to pace rendering, in frames per second.
        """
        super().__init__(env, skip)
        self.fps = fps
        # Public alias of the ``_skip`` attribute that step() relies on.
        self.skip = skip

    def step(self, action: int) -> GymStepReturn:
        """Repeat ``action`` for ``skip`` frames, rendering each one.

        Args:
            action: Action passed unchanged to the wrapped env on every frame.

        Returns:
            ``(max_frame, total_reward, done, info)`` where ``max_frame`` is
            the max over ``_obs_buffer``, ``total_reward`` is the sum of the
            per-frame rewards, and ``done``/``info`` come from the last frame
            that was stepped.
        """
        frame_period = 1 / self.fps
        total_reward = 0.0
        done = None
        info = {}  # stays defined even if the loop body never runs
        last_frame_rendered = time.perf_counter()
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            # Busy-wait until one frame period has elapsed since the previous
            # render so playback runs at the requested frame rate.
            current_time = time.perf_counter()
            while current_time - last_frame_rendered < frame_period:
                current_time = time.perf_counter()
            last_frame_rendered = current_time
            # Render through the wrapped env rather than a notebook-global `env`.
            self.env.render()
            # Only the last two frames feed the max-pooled observation.
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

In [None]:
##################################################
# Playback loop to see trained agent performance #
##################################################

# Seed the RNG for a consistent sequence of random actions
random.seed(1)

# Set up the base environment. Frame skipping is added by FrameSkipReplicator
# below, which still renders every frame so that playback looks normal.
env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')

# Simplest moveset that can complete the level (alternative to SIMPLE_MOVEMENT)
NO_RUNNING = [
    ['right'],
    ['right', 'A']
]

# Specify what controller inputs are available
chosen_inputs = SIMPLE_MOVEMENT
env = JoypadSpace(env, chosen_inputs)

# Wrappers to ensure observations are in the same format as was used in training
num_skip_frames = 4
game_fps = 60
env = FrameSkipReplicator(env, skip=num_skip_frames, fps=game_fps) # Replicate the frameskipping process (one output every n frames), but render all skipped frames
env = GrayScaleObservation(env) # Convert state from RGB to grayscale (1/3 number of values for NN to process)
env = ResizeObservation(env, shape=84) # Scale each frame to 84x84 pixels (roughly 9x fewer pixels for NN to process)
env = FrameStack(env, 4) # Stack the last 4 observations together to give NN temporal awareness

# try/finally guarantees the emulator is closed even if loading the model
# fails or playback is interrupted part-way through.
try:
    # Create a new agent using saved model
    agent = Agent(len(chosen_inputs), model_path='./models/lr0.000025/20220419-122232/10000000.pth')
    #agent.epsilon = 0.05 # Use this if you still want some degree of randomness in playback or to match training value
    print(agent.epsilon)

    # Inference only: no gradients are needed during playback
    with torch.no_grad():
        for episode in range(100):
            state = env.reset()
            start_time = time.perf_counter()
            # Count agent steps per episode so the death message is not cumulative
            steps_since_restart = 0
            while True:
                chosen_action = agent.get_action(state)
                # Perform n frames of the same action to account for the frameskipping used in training
                state, reward, done, info = env.step(chosen_action)
                steps_since_restart += 1

                if done:
                    print("Episode", episode, end=': ')
                    if info['time'] == 0:
                        print("Mario ran out of time")
                    elif info['flag_get']:
                        print("########## Mario reached the FLAG in", time.perf_counter() - start_time, "seconds ##########")
                    else:
                        print("Mario died after", steps_since_restart, "steps at x_position", info['x_pos'])
                    break
finally:
    env.close()

In [None]:
# Manual cleanup cell: run this after interrupting playback to close the
# emulator window. nes_py raises ValueError when close() is called on an
# environment that is already closed (which is the case after a full,
# uninterrupted run of the playback cell), so that case is tolerated here.
try:
    env.close()
except ValueError:
    pass  # environment was already closed by the playback cell