In [None]:
###########
# Imports #
###########

# Mario Environment Libraries
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

# Other Libraries
import gym
from gym.wrappers import GrayScaleObservation, ResizeObservation, FrameStack
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from matplotlib import pyplot as plt
from collections import deque
import torch.nn as nn
import torch
import time
import random
import copy
import numpy as np
import torchvision
import os

In [None]:
###############
# Agent Class #
###############

class Agent:
    """Deep Q-learning agent with an online network (q1) and a frozen target network (q2).

    The online network selects the best next action and the target network
    evaluates it when building the TD target (see ``perform_updates``).
    Transitions are stored in a bounded replay buffer and sampled uniformly.

    Args:
        num_actions: Number of discrete actions the agent can choose from.
        replay_buffer_size: Maximum number of transitions kept in the replay buffer.
        num_replay_samples: Minibatch size used for each gradient update.
        model_path: Optional path to a checkpoint written by ``save_model``;
            when given, weights, optimizer state, counters and epsilon are restored.
    """

    def __init__(self, num_actions, replay_buffer_size=100000, num_replay_samples=32, model_path=None):
        self.num_actions = num_actions
        self.replay_buffer = deque(maxlen=replay_buffer_size)  # Stores up to the given number of past experiences (called 'D' in slides)
        self.num_replay_samples = num_replay_samples
        self.epsilon = 1.0
        self.epsilon_decay = 0.99999975
        self.epsilon_min = 0.1
        self.gamma = 0.9
        self.loss = torch.nn.HuberLoss()
        self.learning_rate = 0.000025

        # Use the GPU when available, otherwise fall back to the CPU.
        # The device is kept on the instance so every tensor created later
        # lands on the same device as the networks.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print("Torch device: ", self.device)

        # Online action-value network; expects 4 stacked frames as input channels.
        self.q1 = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        ).to(self.device)

        # Offsets added to episode/step counters when training is resumed
        self.previous_episodes = 0
        self.previous_steps = 0

        self.optimizer = torch.optim.Adam(self.q1.parameters(), lr=self.learning_rate)

        # Check if previously trained model has been provided
        if model_path is not None:
            # NOTE: torch.load unpickles the file; only load checkpoints you trust.
            # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
            loaded_model = torch.load(model_path, map_location=self.device)
            self.q1.load_state_dict(loaded_model['model_state_dict'])
            self.q1.train()
            # Keep the recorded learning rate in sync with the restored optimizer
            # so later checkpoints save the rate that is actually in use.
            self.learning_rate = loaded_model['learning_rate']
            self.optimizer = torch.optim.Adam(self.q1.parameters(), lr=self.learning_rate)
            self.optimizer.load_state_dict(loaded_model['optimizer_state_dict'])
            self.previous_episodes = loaded_model['num_episodes']
            self.previous_steps = loaded_model['num_steps']
            self.epsilon = loaded_model['epsilon']

        self.q2 = copy.deepcopy(self.q1)  # Create target action-value network as copy of action-value network q1

        # Prevent weights being updated in target network
        for param in self.q2.parameters():
            param.requires_grad = False

    def decay_epsilon(self):
        """Multiply epsilon by the decay factor, never letting it drop below ``epsilon_min``."""
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def update_logs(self, save_location, log_interval, episode, step, average_episode_loss,
                    average_episode_reward, average_episode_distance, flag_count, death_count, timeout_count):
        """Append one comma-separated row of training statistics to ``<save_location>log.txt``.

        Column order: episode, step, average reward, average distance, average loss,
        flag count, death count, timeout count, epsilon. Episode and step are offset
        by the counters restored from a checkpoint. ``log_interval`` is accepted for
        interface compatibility but is not written to the log.
        """
        print("Updating logs")
        fields = (
            episode + self.previous_episodes,
            step + self.previous_steps,
            average_episode_reward,
            average_episode_distance,
            average_episode_loss,
            flag_count,
            death_count,
            timeout_count,
            self.epsilon,
        )
        # Context manager guarantees the handle is closed even if the write fails
        with open(save_location + 'log.txt', 'a') as log_file:
            log_file.write(",".join(str(field) for field in fields) + "\n")

    def save_model(self, step, episode, save_location):
        """Write a checkpoint named ``<total steps>.pth`` into ``save_location``.

        The checkpoint holds the online network weights, optimizer state, loss
        object, cumulative episode/step counters, epsilon and learning rate.
        """
        total_steps = step + self.previous_steps
        torch.save({
            'model_state_dict': self.q1.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': self.loss,
            'num_episodes': episode + self.previous_episodes,
            'num_steps': total_steps,
            'epsilon': self.epsilon,
            'learning_rate': self.learning_rate
        },
            save_location + str(total_steps) + '.pth')

    def get_action(self, state):
        """Choose an action index for ``state`` with an epsilon-greedy policy.

        With probability epsilon a uniformly random action is returned; otherwise
        the action with the highest q1 value is returned.
        """
        if random.uniform(0, 1) < self.epsilon:
            # Explore randomly
            return random.randint(0, self.num_actions - 1)

        # Follow policy greedily (A_t = argmax_q1(S_t, a, theta1))
        state = torch.tensor(state.__array__()).squeeze().to(self.device)  # Squeeze to drop the unnecessary channel dimension
        state = torchvision.transforms.functional.convert_image_dtype(state, dtype=torch.float)
        state = state.unsqueeze(0)  # Add first dimension to match shape of minibatches being fed into network

        # Pure inference: no autograd graph is needed for action selection
        with torch.no_grad():
            action_values = self.q1(state)
        return torch.argmax(action_values, dim=1).item()

    def add_to_replay_buffer(self, state, chosen_action, reward, next_state, done):
        """Convert one transition to tensors on the agent's device and append it to the replay buffer."""
        state = torch.tensor(state.__array__()).squeeze().to(self.device)
        chosen_action = torch.tensor([chosen_action]).to(self.device)
        reward = torch.tensor([reward]).to(self.device)
        next_state = torch.tensor(next_state.__array__()).squeeze().to(self.device)
        done = torch.tensor([int(done)]).to(self.device)

        self.replay_buffer.append((state, chosen_action, reward, next_state, done))

    def sync_networks(self):
        """Copy the online network's weights into the target network."""
        self.q2.load_state_dict(self.q1.state_dict())

    def perform_updates(self):
        """Run one gradient step on a random minibatch from the replay buffer.

        Returns:
            The scalar loss of the update, or ``None`` when the buffer does not
            yet hold a full minibatch and the update is skipped.
        """
        # Don't perform updates until enough samples in replay buffer
        if len(self.replay_buffer) < self.num_replay_samples:
            return None

        # Randomly select a number of transitions from the replay buffer
        minibatch = random.sample(self.replay_buffer, self.num_replay_samples)
        state, chosen_action, reward, next_state, done = map(torch.stack, zip(*minibatch))  # https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html?highlight=transformer
        # Each of these is stacked as (batch, 1); squeeze only dim 1 so a
        # minibatch of size 1 keeps its batch dimension.
        chosen_action = chosen_action.squeeze(1)
        reward = reward.squeeze(1)
        done = done.squeeze(1)

        # Get TD Value Estimate: Q1(s, a) for the action that was actually taken
        state = torchvision.transforms.functional.convert_image_dtype(state, dtype=torch.float)
        estimated_value = self.q1(state).gather(1, chosen_action.unsqueeze(1)).squeeze(1)

        # Get TD Value Target
        with torch.no_grad():  # Disable gradient calculation: https://pytorch.org/docs/stable/generated/torch.no_grad.html
            next_state = torchvision.transforms.functional.convert_image_dtype(next_state, dtype=torch.float)
            # Online network picks the next action, target network evaluates it
            best_action = torch.argmax(self.q1(next_state), dim=1, keepdim=True)
            next_Q = self.q2(next_state).gather(1, best_action).squeeze(1)

        # Terminal transitions (done == 1) contribute no bootstrapped value
        target_value = (reward + (1 - done.float()) * self.gamma * next_Q).float()

        # Calculate loss using Huber Loss
        loss = self.loss(estimated_value, target_value)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Returned for logging
        return loss.item()

In [None]:
##############################
# Training Environment Setup #
##############################

# Action set and frame-skip used by the rest of the notebook
chosen_inputs = SIMPLE_MOVEMENT
num_skip_frames = 4

# World 1-1, restricted to the chosen set of button combinations
env = JoypadSpace(gym_super_mario_bros.make('SuperMarioBros-1-1-v0'), chosen_inputs)

# Environment Preprocessing - wrappers that shrink the input the NN has to process.
# Applied innermost-first:
#   1. MaxAndSkipEnv        - skip the given number of frames
#   2. GrayScaleObservation - RGB -> grayscale (1/3 the number of values per frame)
env = GrayScaleObservation(MaxAndSkipEnv(env, skip=num_skip_frames))
#   3. ResizeObservation    - scale each frame to 84x84 pixels
#   4. FrameStack           - stack the last 4 observations to give the NN temporal awareness
env = FrameStack(ResizeObservation(env, shape=84), 4)
# Final observation shape: (        4         ,   84   ,   84    ,      1      )
#                           num_stacked_frames  height    width    num_channels

In [None]:
#######################
# Agent Training Loop #
#######################

# Set resuming_training to True to continue from saved_model_path,
# or leave it False to train a new agent from scratch.
resuming_training = False
saved_model_path = './models/20220418-175357/30000.pth'

if resuming_training:
    agent = Agent(len(chosen_inputs), model_path=saved_model_path) # Continue training existing model
else:
    agent = Agent(len(chosen_inputs)) # Train NEW agent

num_episodes = 1000000
sync_interval = 10000 #(steps)
save_interval = 10000 #(steps)
log_interval = 20 #(episodes)
timestamp = time.strftime("%Y%m%d-%H%M%S")
save_location = './models/lr0.000025/' + timestamp + '/'
# The save path is nested several levels deep; makedirs creates any missing
# parent directories, whereas os.mkdir would raise if './models/lr0.000025' did not exist.
os.makedirs(save_location, exist_ok=True)

step = 1 # Total steps over all episodes (1-based)
# Per-episode statistics accumulated between two log writes
rewards_per_episode = []
distance_per_episode = []
average_loss_per_episode = []
death_count = 0
timeout_count = 0
flag_count = 0
for episode in range(num_episodes):
    episode_loss = 0.0
    num_losses = 0
    average_episode_loss = 0.0
    episode_reward = 0.0
    num_episode_steps = 1
    state = env.reset()

    while True:

        # Choose and execute an action following epsilon-greedy policy
        chosen_action = agent.get_action(state)
        next_state, reward, done, info = env.step(chosen_action)
        episode_reward += reward

        # Add the transition to the replay buffer
        agent.add_to_replay_buffer(state, chosen_action, reward, next_state, done)

        # Update random sample of transitions from the replay buffer
        # (returns None until the buffer holds a full minibatch)
        loss = agent.perform_updates()
        if loss is not None:
            episode_loss += loss
            num_losses += 1
            average_episode_loss = episode_loss / num_losses

        agent.decay_epsilon()

        if step % 10000 == 0:
            print("Done", step, "steps,", episode, "completed episodes")

        # Sync parameters of target network every so often
        if step % sync_interval == 0:
            agent.sync_networks()

        # Save model every so often
        if step % save_interval == 0:
            agent.save_model(step, episode, save_location)

        #env.render()

        state = next_state
        step += 1 # Total steps over all episodes
        num_episode_steps += 1 # Steps in this episode

        if done:
            rewards_per_episode.append(episode_reward)
            distance_per_episode.append(info['x_pos'])
            average_loss_per_episode.append(average_episode_loss)

            # Classify how the episode ended: clock ran out, flag reached, or otherwise a death
            if info['time'] == 0:
                timeout_count += 1
            elif info['flag_get']:
                flag_count += 1
            else:
                death_count += 1

            # Log episode stats
            # NOTE: episode 0 is skipped here, so the first log window covers
            # log_interval + 1 episodes; every later window covers log_interval.
            if episode % log_interval == 0 and episode != 0:
                average_episode_reward = sum(rewards_per_episode) / len(rewards_per_episode)
                average_episode_distance = sum(distance_per_episode) / len(distance_per_episode)
                average_loss_across_episodes = sum(average_loss_per_episode) / len(average_loss_per_episode)
                agent.update_logs(save_location, log_interval, episode, step, average_loss_across_episodes,
                                  average_episode_reward, average_episode_distance, flag_count, death_count, timeout_count)
                # Start a fresh window of statistics
                rewards_per_episode = []
                distance_per_episode = []
                average_loss_per_episode = []
                timeout_count = 0
                flag_count = 0
                death_count = 0

            break

In [None]:
# Close the environment created in the setup cell once training is finished or interrupted
env.close()