In [1]:
from gymnasium.wrappers import FrameStack, GrayScaleObservation, ResizeObservation, TransformObservation
import retro
import cv2
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms
from collections import deque
import itertools
import numpy as np
import random

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

In [2]:
class QApproximator(nn.Module):
    def __init__(self):
        # Convolutional output dimensions formula (in each depth slice): W_new = (W-F + 2P)/S + 1 where W=input_shape, F=kernel_shape, P=padding_amount, S=stride_amount
        super(QApproximator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=8, stride=4) # Feature map size: 20
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2) # Feature map size: 9
        self.fc1 = nn.Linear(in_features=32*9*9, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=4)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))

        x = x.view(-1, 32*9*9) # Flattening for fully connected layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x
    
    def act(self, state, device):
        state_tensor = torch.tensor(state, device=device, dtype=torch.float32)
        q_values = self(state_tensor.unsqueeze(0))
        max_q_index = torch.argmax(q_values, dim=1)[0]
        action = max_q_index.detach().item()

        return action




In [3]:
# Helper function to plot model performance during training
def plot_durations(show_result=False, episode_durations=[]):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

In [11]:
EPISODES = 100
EPSILON_START = 1.0
EPSILON_END = 0.02
EPSILON_DECAY=15000
BUFFER_SIZE = 50000
BATCH_SIZE = 128
GAMMA = 0.99
LEARNING_RATE=5e-4
TARGET_UPDATE_FREQ = 1000
MIN_REPLAY_SIZE = 2000

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# Initialize experience replay buffer
replay_buffer = deque(maxlen=BUFFER_SIZE)

# Initialize Q and fixed target network
QNet = QApproximator().to(device)
TargNet = QApproximator().to(device).eval() # Set TargNet in eval mode as we never calculate gradients for it
TargNet.load_state_dict(QNet.state_dict())

optimizer = torch.optim.Adam(QNet.parameters(), lr=LEARNING_RATE)

# Actions are: Run right, run right and jump, jump, run left
actions = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1, 0, 0]])

episode_rewards = []

In [5]:
# Custom scenario.json which ends the episode when there are only 2 lives left rather than none
env = retro.RetroEnv(game="SuperMarioBros-Nes", scenario='./scenario.json', render_mode='human')

# Modify observations to be preprocessed.
env = ResizeObservation(env, (84, 84))
env = GrayScaleObservation(env)
env = TransformObservation(env, lambda obs: obs / 255.0)
env = FrameStack(env, num_stack=4)

In [12]:
# Storing the action and reward over 4 actions, as we're making a decision every 4 frames
current_action = 0
action_step = 0
single_experience_buffer = [None, None, None, None, None]

state, info = env.reset()
for _ in range(MIN_REPLAY_SIZE):
    if action_step % 4 != 0 and action_step % 4 != 3:
        new_state, reward, terminated, truncated, info = env.step(actions[current_action])
        action_step += 1
        single_experience_buffer[2] += reward

        if terminated or truncated:
            single_experience_buffer[3] = terminated
            single_experience_buffer[4] = new_state
            replay_buffer.append(single_experience_buffer)
            action_step = 0
            env.reset()

        continue

    elif action_step % 4 == 3:
        new_state, reward, terminated, truncated, info = env.step(actions[current_action])
        single_experience_buffer[2] += reward
        single_experience_buffer[3] = terminated
        single_experience_buffer[4] = new_state
        replay_buffer.append(single_experience_buffer)

        action_step = 0

        if terminated or truncated:
            env.reset()
            acton_step = 0
        continue

    current_action = random.choice([0, 1, 2, 3])
    action_step += 1

    new_state, reward, terminated, truncated, info = env.step(actions[current_action])
    single_experience_buffer = [state, current_action, reward, None, None]

    if truncated or terminated:
        single_experience_buffer = [state, current_action, reward, terminated, new_state]
        replay_buffer.append(single_experience_buffer)
        action_step = 0

        env.reset()

In [13]:
def optimize_model():
    experiences = random.sample(replay_buffer, BATCH_SIZE)

    states = np.asarray([t[0] for t in experiences])
    actions = np.asarray([t[1] for t in experiences])
    rewards = np.asarray([t[2] for t in experiences])
    terminated = np.asarray([t[3] for t in experiences])
    new_states = np.asarray([t[4] for t in experiences])

    states_t = torch.as_tensor(states, dtype=torch.float32, device=device)
    actions_t = torch.as_tensor(actions, dtype=torch.int64, device=device).unsqueeze(-1)
    rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=device).unsqueeze(-1)
    terminated_t = torch.as_tensor(terminated, dtype=torch.float32, device=device).unsqueeze(-1)
    new_states_t = torch.as_tensor(new_states, dtype=torch.float32, device=device)

    # Compute Targets
    target_q_values = TargNet(new_states_t)
    max_target_q_values = target_q_values.max(dim=1, keepdim=True)[0]

    targets = rewards_t + GAMMA * (1 - terminated_t) * max_target_q_values

    q_values = QNet(states_t)
    action_q_values = torch.gather(input=q_values, dim=1, index=actions_t)

    loss = F.mse_loss(action_q_values, targets)

    # Gradient descent
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [14]:
total_steps = 0

# Storing the action and reward over 4 actions, as we're making a decision every 4 frames
current_action = 0
action_step = 0
single_experience_buffer = [None, None, None, None, None]

for episodes in range(EPISODES):
    episode_reward = 0
    state, info = env.reset()
    for t in itertools.count():
        # Update target network
        if total_steps % TARGET_UPDATE_FREQ == 0:
            TargNet.load_state_dict(QNet.state_dict())

        # Handling the case where we don't make decisions
        if action_step % 4 != 0 and action_step % 4 != 3:
            new_state, reward, terminated, truncated, new_info = env.step(actions[current_action])
            action_step += 1
            single_experience_buffer[2] += reward

            if terminated or truncated:
                single_experience_buffer[3] = terminated
                single_experience_buffer[4] = new_state
                replay_buffer.append(single_experience_buffer)
                action_step = 0

                # Optimize model
                QNet.train()
                optimize_model()
                total_steps += 1

                episode_reward += single_experience_buffer[2]
                episode_rewards.append(episode_reward)
                plot_durations(show_result=False, episode_durations=episode_rewards)
                break

            state, info = new_state, new_info
            continue

        elif action_step % 4 == 3:
            new_state, reward, terminated, truncated, new_info = env.step(actions[current_action])
            single_experience_buffer[2] += reward
            single_experience_buffer[3] = terminated
            single_experience_buffer[4] = new_state
            replay_buffer.append(single_experience_buffer)

            QNet.train()
            optimize_model()
            total_steps += 1

            action_step = 0

            episode_reward += single_experience_buffer[2]
            if terminated or truncated:
                acton_step = 0
                episode_rewards.append(episode_reward)
                plot_durations(show_result=False, episode_durations=episode_rewards)
                break

            state, info = new_state, new_info
            continue


        # Select action
        epsilon = np.interp(total_steps, [0, EPSILON_DECAY], [EPSILON_START, EPSILON_END])
        rand_sample = random.random()
        
        if rand_sample <= epsilon:
            action = random.choice([0, 1, 2, 3])
        else:
            QNet.eval()
            current_action = QNet.act(state=state, device=device)

        # Take action
        new_state, reward, terminated, truncated, new_info = env.step(actions[current_action])
        action_step += 1

        single_experience_buffer = [state, action, reward, None, None]

        # Append to experience buffer immediately if terminated or truncated
        if terminated or truncated:
            single_experience_buffer = [state, action, reward, terminated, new_state]
            replay_buffer.append(single_experience_buffer)
            action_step = 0

            # Optimize model
            QNet.train()
            optimize_model()
            total_steps += 1

            episode_reward += single_experience_buffer[2]
            episode_rewards.append(episode_reward)
            plot_durations(show_result=False, episode_durations=episode_rewards)

            break

        state, info = new_state, new_info

    if episodes % 5 == 0 and episodes != 0:
        checkpoint = {
            'model_state_dict': QNet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'total_steps': total_steps
        }
        torch.save(checkpoint, f'./model_checkpoint_episode_{episodes}.pt')


torch.save(QNet.state_dict(), './mario-dqn-model.pt')
torch.save(optimizer.state_dict(), './mario-dqn-optimizer.pt')

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

In [16]:
env = retro.RetroEnv(game="SuperMarioBros-Nes", scenario='./scenario.json', render_mode='human')

# Modify observations to be preprocessed.
env = ResizeObservation(env, (84, 84))
env = GrayScaleObservation(env)
env = TransformObservation(env, lambda obs: obs / 255.0)
env = FrameStack(env, num_stack=4)

loaded_model = QApproximator().to(device)
checkpoint = torch.load('./model_checkpoint_episode_25.pt')
loaded_model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [15]:
env.close()

In [17]:
# Storing the action and reward over 4 actions, as we're making a decision every 4 frames
current_action = 0
action_step = 0

state, info = env.reset()
while(True):
    if action_step % 4 != 0 and action_step % 4 != 3:
        state, reward, terminated, truncated, info = env.step(actions[current_action])
        action_step += 1

        if terminated or truncated:
            action_step = 0
            state, info = env.reset()

        continue

    elif action_step % 4 == 3:
        state, reward, terminated, truncated, info = env.step(actions[current_action])
        action_step = 0

        if terminated or truncated:
            state, info = env.reset()
            acton_step = 0
        continue

    current_action = loaded_model.act(state, device=device)
    action_step += 1

    state, reward, terminated, truncated, info = env.step(actions[current_action])

    if truncated or terminated:
        action_step = 0
        state, info = env.reset()

KeyboardInterrupt: 