<a href="https://colab.research.google.com/github/zzmtsvv/RL-with-gym/blob/main/supermario.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install gym-super-mario-bros==7.3.0



In [None]:
'''
Implementation of DDQN (CNN) via arxiv.org/pdf/1509.06461.pdf
on super mario bros environment using pytorch
'''

In [2]:
import torch
from torch import nn
from torchvision import transforms as T
import numpy as np
import random
import datetime
import time
import matplotlib.pyplot as plt
import os
import copy
from pathlib import Path
from collections import deque
import gym
from gym.wrappers import FrameStack, GrayScaleObservation, TransformObservation
from gym.spaces import Box
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from skimage import transform

In [3]:
env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0")

COMPLEX_MOVEMENT = [
    ['NOOP'],
    ['right'],
    ['right', 'A'],
    ['right', 'B'],
    ['right', 'A', 'B'],
    ['A'],
    ['left'],
    ['left', 'A'],
    ['left', 'B'],
    ['left', 'A', 'B'],
    ['down'],
    ['up'],
]

env = JoypadSpace(env, COMPLEX_MOVEMENT)

env.reset()
next_state, reward, done, info = env.step(action=0)
print(f"{next_state.shape},\n {reward},\n {done},\n {info}")

(240, 256, 3),
 0,
 False,
 {'coins': 0, 'flag_get': False, 'life': 2, 'score': 0, 'stage': 1, 'status': 'small', 'time': 400, 'world': 1, 'x_pos': 40, 'x_pos_screen': 40, 'y_pos': 79}


In [4]:
class SkipFrame(gym.Wrapper):
  def __init__(self, env, skip):
    """Return only every `skip`-th frame"""
    super().__init__(env)
    self._skip = skip

  def step(self, action):
    """Repeat action, and sum reward"""
    total_reward = 0.0
    done = False
    for i in range(self._skip):
      # Accumulate reward and repeat the same action
      obs, reward, done, info = self.env.step(action)
      total_reward += reward
      if done:
        break
    return obs, total_reward, done, info


class ResizeObservation(gym.ObservationWrapper):
  def __init__(self, env, shape):
    super().__init__(env)
    if isinstance(shape, int):
      self.shape = (shape, shape)
    else:
      self.shape = tuple(shape)

    obs_shape = self.shape + self.observation_space.shape[2:]
    self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

  def observation(self, observation):
    resize_obs = transform.resize(observation, self.shape)
    # cast float back to uint8
    resize_obs *= 255
    resize_obs = resize_obs.astype(np.uint8)
    return resize_obs


env = SkipFrame(env, skip=4)
env = GrayScaleObservation(env, keep_dim=False)
env = ResizeObservation(env, shape=84)
env = TransformObservation(env, f=lambda x: x / 255.)
env = FrameStack(env, num_stack=4)

In [5]:
class ArseNet(nn.Module):
  def __init__(self, input_dim, output_dim):
    super().__init__()
    c, h, w = input_dim
    if h != 84:
      raise ArithmeticError(f"Expected input height: 84, got: {h}")
    if w != 84:
      raise ArithmeticError(f"Expected input width: 84, got: {w}")

    self.online = nn.Sequential(
        nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(3136, 512),
        nn.ReLU(),
        nn.Linear(512, output_dim)
    )
    self.target = copy.deepcopy(self.online)

    for theta in self.target.parameters():
      theta.requires_grad = False
  
  def forward(self, inputs, model):
    if model == 'online':
      return self.online(inputs)
    elif model == 'target':
      return self.target(inputs)

In [6]:
class Agent():
  def __init__(self, state_dim, action_dim, save_dir, checkpoint=False):
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.memory = deque(maxlen=100000)
    self.batch_size = 32

    self.exploration_rate = 1
    self.exploration_rate_decay = 0.9999995
    self.exploration_rate_min = 0.1
    self.gamma = 0.9

    self.curr_step = 0
    self.burnin = 1e5
    self.learn_every = 3
    self.sync_every = 1e4

    self.save_every = 5e5
    self.save_dir = save_dir

    self.gpu = torch.cuda.is_available()

    self.net = ArseNet(self.state_dim, self.action_dim).float()
    if self.gpu:
      self.net = self.net.to(device='cuda')
    if checkpoint:
      self.load(checkpoint)
    
    self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4)
    self.loss_fn = torch.nn.SmoothL1Loss()

  
  def act(self, state):
    '''
    choose epsilon-greedy action given state and update the value of step
    '''
    # exploration
    if np.random.rand() < self.exploration_rate:
      action_idx = np.random.randint(self.action_dim)
    # exploitation
    else:
      state = torch.FloatTensor(state).cuda() if self.gpu else torch.FloatTensor(state)
      state = state.unsqueeze(0)
      action_values = self.net(state, model='online')
      action_idx = torch.argmax(action_values, axis=1).item()
    
    # decrease exploration_rate
    self.exploration_rate *= self.exploration_rate_decay
    if self.exploration_rate < self.exploration_rate_min:
      self.exploration_rate = self.exploration_rate_min
    
    self.curr_step += 1
    return action_idx
  
  def cache(self, state, next_state, action, reward, done):
    '''
    store the experience to self.memory
    '''
    state = torch.FloatTensor(state).cuda() if self.gpu else torch.FloatTensor(state)
    next_state = torch.FloatTensor(next_state).cuda() if self.gpu else torch.FloatTensor(next_state)
    action = torch.LongTensor([action]).cuda() if self.gpu else torch.LongTensor([action])
    reward = torch.DoubleTensor([reward]).cuda() if self.gpu else torch.DoubleTensor([reward])
    done = torch.BoolTensor([done]).cuda() if self.gpu else torch.BoolTensor([done])

    self.memory.append((state, next_state, action, reward, done,))
  
  def recall(self):
    '''
    Take a batch of experience from memory
    '''
    batch = random.sample(self.memory, self.batch_size)
    state, next_state, action, reward, done = map(torch.stack, zip(*batch))
    return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
  
  def td_estimate(self, state, action):
    # the predicted optimal Q_online for a given state
    current_Q = self.net(state, model="online")[np.arange(0, self.batch_size), action]
    return current_Q
  
  @torch.no_grad()
  def td_target(self, reward, next_state, done):
    '''
    TD Target - aggregation of current reward and the estimated Q
    in the next state
    '''
    next_state_Q = self.net(state, model='online')
    best_action = torch.argmax(next_state_Q, axis=1)
    next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
    return (reward + (1 - done.float()) * self.gamma * next_Q).float()
  
  def update_Q_online(self, td_estimate, td_target):
    loss = self.loss_fn(td_estimate, td_target)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()
  
  def sync_Q_target(self):
    self.net.target.load_state_dict(self.net.online.state_dict())
  
  def save(self):
    num = int(self.curr_step // self.save_every)
    save_path = (self.save_dir / f"mario_net_{num}.chkpt")
    torch.save(dict(model=self.net.state_dict(),
                    exploration_rate=self.exploration_rate), save_path)
    print(f"ArseNet saved to {save_path} at step {self.curr_step}")
  
  def learn(self):
    if not self.curr_step % self.sync_every:
      self.sync_Q_target()

    if not self.curr_step % self.save_every:
      self.save()

    if self.curr_step < self.burnin:
      return None, None

    if self.curr_step % self.learn_every:
      return None, None

    
    # sample from memory
    state, next_state, action, reward, done = self.recall()

    # TD Estimate and TD Target
    td_est = self.td_estimate(state, action)
    td_trgt = self.td_target(reward, next_state, done)

    loss = self.update_Q_online(td_est, td_trgt)

    return td_est.mean().item(), loss
  
  def save(self):
    save_path = self.save_dir / f"arse_net_{int(self.curr_step // self.save_every)}.chkpt"
    torch.save(
        dict(
            model=self.net.state_dict(),
            exploration_rate=self.exploration_rate
        ),
        save_path
    )
    print(f"ArseNet saved to {save_path} at step {self.curr_step}")
  
  def load(self, load_path):
    if not load_path.exists():
      raise ValueError(f"{load_path} does not exist")

    ckp = torch.load(load_path, map_location=('cuda' if self.gpu else 'cpu'))
    exploration_rate = ckp.get('exploration_rate')
    state_dict = ckp.get('model')

    print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
    self.net.load_state_dict(state_dict)
    self.exploration_rate = exploration_rate

In [7]:
class MetricLogger():
    def __init__(self, save_dir):
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, added for every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode metric
        self.init_episode()

        # Timing
        self.record_time = time.time()


    def log_step(self, reward, loss, q):
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self):
        "Mark end of episode"
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)

        self.init_episode()

    def init_episode(self):
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)


        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )

        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()

In [8]:
use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")
print()

save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)

mario = Agent(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir)

logger = MetricLogger(save_dir)

episodes = 10

for e in range(episodes):
  state = env.reset()

  while True:
    action = mario.act(state)
    next_state, reward, done, info = env.step(action)

    mario.cache(state, next_state, action, reward, done)

    q, loss = mario.learn()
    logger.log_step(reward, loss, q)
    state = next_state
    
    if done or info['flag_get']:
      break

  logger.log_episode()

  if e % 20 == 0:
    logger.record(
        episode=e,
        epsilon=mario.exploration_rate,
        step=mario.curr_step
        )

Using CUDA: True



  return (self.ram[0x86] - self.ram[0x071c]) % 256


Episode 0 - Step 58 - Epsilon 0.9999710004132487 - Mean Reward 202.0 - Mean Length 58.0 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 1.723 - Time 2021-07-06T17:54:10


<Figure size 432x288 with 0 Axes>

In [9]:
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

In [10]:
from gym.wrappers import Monitor
import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay

In [11]:
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

<pyvirtualdisplay.display.Display at 0x7f0d52b14510>

In [12]:
def show_video():
  mp4list = glob.glob('video/*.mp4')
  if len(mp4list) > 0:
    mp4 = mp4list[0]
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
  else: 
    print("Could not find video")
    

def wrap_env(env):
  env = Monitor(env, './video', force=True)
  return env

In [13]:
env = wrap_env(env)
state = env.reset()
done = False
total_reward = 0

In [None]:
while not done:
    env.render()
    action = mario.act(state)
    state, reward, done, _ = env.step(action)
    total_reward += reward
env.close()
print(total_reward)
show_video()