In [1]:
import gym

#Import game
import gym_super_mario_bros

#Import Joypad
from nes_py.wrappers import JoypadSpace

#Import Controls
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

#Grayscaling (frame stacking is provided by VecFrameStack below)
from gym.wrappers import GrayScaleObservation

#Vectorization
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv

#Graphs
from matplotlib import pyplot as plt

import os

#Proximal Policy Optimization
from stable_baselines3 import PPO

#To save models
from stable_baselines3.common.callbacks import BaseCallback

import torch as th

In [2]:
# Build the training environment in stages; each stage wraps the previous one
# and only the final result is bound to `env`.

# [1] - Base game environment. Its raw action space prints as Discrete(256).
raw_env = gym_super_mario_bros.make(
    "SuperMarioBros-v0",
    render_mode='human',
    apply_api_compatibility=True,
)
print(raw_env.action_space)

# [2] - Restrict the controls to SIMPLE_MOVEMENT, which prints as Discrete(7).
joypad_env = JoypadSpace(raw_env, SIMPLE_MOVEMENT)
print(joypad_env.action_space)

# [3] - Convert observations to grayscale, keeping the channel dimension.
gray_env = GrayScaleObservation(joypad_env, keep_dim=True)

# [4] - Wrap the single environment inside a dummy vectorized environment.
vec_env = DummyVecEnv([lambda: gray_env])

# [5] - Stack four consecutive frames along the last (channel) axis.
env = VecFrameStack(vec_env, 4, channels_order='last')

# JoypadSpace doesn't correctly override the reset() method of Wrapper, so
# patch it at class level to forward every keyword argument to the wrapped env.
JoypadSpace.reset = lambda self, **kwargs: self.env.reset(**kwargs)

Discrete(256)
Discrete(7)


  logger.warn(
  logger.warn(


In [3]:
class TrainAndLoggingCallback(BaseCallback):
    """Periodically write a checkpoint of the model being trained.

    Every ``check_freq`` calls to ``_on_step`` the current model is saved in
    ``save_path`` as ``best_model_<n_calls>``. Despite the file name, no
    comparison between checkpoints is made here: every snapshot is kept.

    Args:
        check_freq: Number of callback steps between two checkpoints.
        save_path: Directory that receives the checkpoints, or ``None`` to
            disable checkpointing altogether.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self) -> None:
        """Create the checkpoint directory if one was configured."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        """Save a checkpoint on every ``check_freq``-th call.

        ``n_calls`` and ``model`` are not set in this class; they are
        attributes inherited from ``BaseCallback``.

        Returns:
            Always ``True``, so this callback never asks for training to stop.
        """
        # _init_callback treats save_path=None as "no checkpoint directory";
        # honour that here too instead of failing in os.path.join(None, ...).
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [4]:
# Output locations: model checkpoints and training logs.
checkpoint_directory = './train/'
log_dir = './logs/'

# Checkpoint the model every 100,000 callback steps.
callback = TrainAndLoggingCallback(
    save_path=checkpoint_directory,
    check_freq=100_000,
)

In [5]:
# Policy (pi) and value (vf) networks each get their own four-layer stack
# of sizes 32-64-64-64 with ReLU activations.
policy_kwargs = dict(
    net_arch=dict(
        pi=[32, 64, 64, 64],
        vf=[32, 64, 64, 64],
    ),
    activation_fn=th.nn.ReLU,
)

# PPO with the CNN policy on the stacked-frame environment; training metrics
# are written for TensorBoard under log_dir.
model = PPO(
    'CnnPolicy',
    env,
    policy_kwargs=policy_kwargs,
    learning_rate=0.0005,
    n_steps=4096,
    tensorboard_log=log_dir,
    verbose=1,
)

# Last expression of the cell: display the resulting network architecture.
model.policy

Using cuda device
Wrapping the env in a VecTransposeImage.


ActorCriticCnnPolicy(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=46592, out_features=512, bias=True)
      (1): ReLU()
    )
  )
  (pi_features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=46592, out_features=512, bias=True)
      (1): ReLU()
    )
  )
  (vf_features_extractor): NatureCNN(
    (cnn)

In [6]:
# Train for 3.5M timesteps; `callback` writes periodic checkpoints on the way.
# learn() returns the model, which the notebook echoes as the cell output.
model.learn(
    callback=callback,
    total_timesteps=3_500_000,
)

Logging to ./logs/PPO_1


  if not isinstance(terminated, (bool, np.bool8)):


-----------------------------
| time/              |      |
|    fps             | 168  |
|    iterations      | 1    |
|    time_elapsed    | 24   |
|    total_timesteps | 4096 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 120         |
|    iterations           | 2           |
|    time_elapsed         | 67          |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.013708018 |
|    clip_fraction        | 0.137       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.93       |
|    explained_variance   | -0.00107    |
|    learning_rate        | 0.0005      |
|    loss                 | 0.686       |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.00338    |
|    value_loss           | 4.02        |
-----------------------------------------
----------------------------------

<stable_baselines3.ppo.ppo.PPO at 0x14553f96110>

: 