# 1. Import Dependencies

In [1]:
import gym 
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env
import os

# 2. Test Environment

In [2]:
environment_name = "Breakout-v0"

In [3]:
env = gym.make(environment_name)

In [20]:
episodes = 5
for episode in range(1, episodes+1):
    state = env.reset()
    done = False
    score = 0 
    
    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score+=reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

Episode:1 Score:0.0
Episode:2 Score:2.0
Episode:3 Score:0.0
Episode:4 Score:0.0
Episode:5 Score:0.0


In [17]:
env.action_space.sample()

1

In [19]:
env.observation_space.sample()

array([[[255,  87, 191],
        [191, 150, 140],
        [ 90, 232,  25],
        ...,
        [ 33, 143,   6],
        [231,  97, 197],
        [211, 207,  43]],

       [[ 55, 241,  12],
        [ 93, 137, 233],
        [100, 176,   0],
        ...,
        [174, 148, 218],
        [244, 226, 197],
        [235, 184, 157]],

       [[132,   9, 171],
        [140,  23, 204],
        [243, 226,  55],
        ...,
        [ 40,  31, 224],
        [236,  46,  56],
        [199, 219, 192]],

       ...,

       [[131, 111, 180],
        [107, 156, 231],
        [155, 244,  50],
        ...,
        [115, 160, 253],
        [100,  87,  14],
        [  1,  29, 168]],

       [[ 66, 243, 200],
        [ 52,  23, 251],
        [173,   1,  63],
        ...,
        [205, 200,  18],
        [196,  49, 123],
        [245, 193, 207]],

       [[181, 206,  90],
        [ 77, 189, 206],
        [ 95, 219,  48],
        ...,
        [163,  61, 169],
        [ 62, 248, 149],
        [117, 118,  90]]

# 3. Vectorise Environment and Train Model

In [21]:
env = make_atari_env('Breakout-v0', n_envs=4, seed=0)

In [22]:
env = VecFrameStack(env, n_stack=4)

In [23]:
log_path = os.path.join('Training', 'Logs')

In [24]:
model = A2C("CnnPolicy", env, verbose=1, tensorboard_log=log_path)

Using cpu device
Wrapping the env in a VecTransposeImage.


In [25]:
model.learn(total_timesteps=400000)

Logging to Training/Logs/A2C_1
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 284      |
|    ep_rew_mean        | 1.63     |
| time/                 |          |
|    fps                | 238      |
|    iterations         | 100      |
|    time_elapsed       | 8        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.39    |
|    explained_variance | 0.057    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 0.389    |
|    value_loss         | 0.27     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 292      |
|    ep_rew_mean        | 1.73     |
| time/                 |          |
|    fps                | 255      |
|    iterations         | 200      |
|    time_elapsed       | 15       |
|    total_timesteps    | 4000     |
| train

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 334      |
|    ep_rew_mean        | 2.69     |
| time/                 |          |
|    fps                | 270      |
|    iterations         | 1400     |
|    time_elapsed       | 103      |
|    total_timesteps    | 28000    |
| train/                |          |
|    entropy_loss       | -1.02    |
|    explained_variance | 0.933    |
|    learning_rate      | 0.0007   |
|    n_updates          | 1399     |
|    policy_loss        | -0.251   |
|    value_loss         | 0.0631   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 349      |
|    ep_rew_mean        | 3.14     |
| time/                 |          |
|    fps                | 271      |
|    iterations         | 1500     |
|    time_elapsed       | 110      |
|    total_timesteps    | 30000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 451      |
|    ep_rew_mean        | 5.2      |
| time/                 |          |
|    fps                | 272      |
|    iterations         | 2700     |
|    time_elapsed       | 198      |
|    total_timesteps    | 54000    |
| train/                |          |
|    entropy_loss       | -0.353   |
|    explained_variance | 0.453    |
|    learning_rate      | 0.0007   |
|    n_updates          | 2699     |
|    policy_loss        | -0.123   |
|    value_loss         | 0.103    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 447      |
|    ep_rew_mean        | 5.08     |
| time/                 |          |
|    fps                | 272      |
|    iterations         | 2800     |
|    time_elapsed       | 205      |
|    total_timesteps    | 56000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 496      |
|    ep_rew_mean        | 5.91     |
| time/                 |          |
|    fps                | 281      |
|    iterations         | 4100     |
|    time_elapsed       | 291      |
|    total_timesteps    | 82000    |
| train/                |          |
|    entropy_loss       | -0.163   |
|    explained_variance | 0.887    |
|    learning_rate      | 0.0007   |
|    n_updates          | 4099     |
|    policy_loss        | -0.0355  |
|    value_loss         | 0.0923   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 507      |
|    ep_rew_mean        | 6.13     |
| time/                 |          |
|    fps                | 282      |
|    iterations         | 4200     |
|    time_elapsed       | 297      |
|    total_timesteps    | 84000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 543      |
|    ep_rew_mean        | 7.03     |
| time/                 |          |
|    fps                | 293      |
|    iterations         | 5400     |
|    time_elapsed       | 368      |
|    total_timesteps    | 108000   |
| train/                |          |
|    entropy_loss       | -0.0581  |
|    explained_variance | 0.92     |
|    learning_rate      | 0.0007   |
|    n_updates          | 5399     |
|    policy_loss        | -0.00145 |
|    value_loss         | 0.0526   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 533      |
|    ep_rew_mean        | 6.94     |
| time/                 |          |
|    fps                | 293      |
|    iterations         | 5500     |
|    time_elapsed       | 374      |
|    total_timesteps    | 110000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 563      |
|    ep_rew_mean        | 7.49     |
| time/                 |          |
|    fps                | 300      |
|    iterations         | 6700     |
|    time_elapsed       | 445      |
|    total_timesteps    | 134000   |
| train/                |          |
|    entropy_loss       | -0.0769  |
|    explained_variance | 0.863    |
|    learning_rate      | 0.0007   |
|    n_updates          | 6699     |
|    policy_loss        | -0.0108  |
|    value_loss         | 0.15     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 565      |
|    ep_rew_mean        | 7.54     |
| time/                 |          |
|    fps                | 301      |
|    iterations         | 6800     |
|    time_elapsed       | 451      |
|    total_timesteps    | 136000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 575      |
|    ep_rew_mean        | 7.84     |
| time/                 |          |
|    fps                | 306      |
|    iterations         | 8100     |
|    time_elapsed       | 529      |
|    total_timesteps    | 162000   |
| train/                |          |
|    entropy_loss       | -0.031   |
|    explained_variance | 0.88     |
|    learning_rate      | 0.0007   |
|    n_updates          | 8099     |
|    policy_loss        | -0.00213 |
|    value_loss         | 0.119    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 582      |
|    ep_rew_mean        | 8.04     |
| time/                 |          |
|    fps                | 306      |
|    iterations         | 8200     |
|    time_elapsed       | 535      |
|    total_timesteps    | 164000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 562      |
|    ep_rew_mean        | 7.53     |
| time/                 |          |
|    fps                | 309      |
|    iterations         | 9400     |
|    time_elapsed       | 607      |
|    total_timesteps    | 188000   |
| train/                |          |
|    entropy_loss       | -0.0752  |
|    explained_variance | 0.816    |
|    learning_rate      | 0.0007   |
|    n_updates          | 9399     |
|    policy_loss        | 0.121    |
|    value_loss         | 0.144    |
------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 561       |
|    ep_rew_mean        | 7.47      |
| time/                 |           |
|    fps                | 309       |
|    iterations         | 9500      |
|    time_elapsed       | 613       |
|    total_timesteps    | 190000    |
| train/                |    

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 616      |
|    ep_rew_mean        | 8.76     |
| time/                 |          |
|    fps                | 312      |
|    iterations         | 10700    |
|    time_elapsed       | 685      |
|    total_timesteps    | 214000   |
| train/                |          |
|    entropy_loss       | -0.0352  |
|    explained_variance | 0.918    |
|    learning_rate      | 0.0007   |
|    n_updates          | 10699    |
|    policy_loss        | 0.00813  |
|    value_loss         | 0.0837   |
------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 626       |
|    ep_rew_mean        | 8.89      |
| time/                 |           |
|    fps                | 312       |
|    iterations         | 10800     |
|    time_elapsed       | 691       |
|    total_timesteps    | 216000    |
| train/                |    

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 660      |
|    ep_rew_mean        | 10.2     |
| time/                 |          |
|    fps                | 314      |
|    iterations         | 12000    |
|    time_elapsed       | 762      |
|    total_timesteps    | 240000   |
| train/                |          |
|    entropy_loss       | -0.0362  |
|    explained_variance | 0.929    |
|    learning_rate      | 0.0007   |
|    n_updates          | 11999    |
|    policy_loss        | -0.0149  |
|    value_loss         | 0.0417   |
------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 653       |
|    ep_rew_mean        | 9.96      |
| time/                 |           |
|    fps                | 314       |
|    iterations         | 12100     |
|    time_elapsed       | 768       |
|    total_timesteps    | 242000    |
| train/                |    

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 636      |
|    ep_rew_mean        | 9.47     |
| time/                 |          |
|    fps                | 316      |
|    iterations         | 13300    |
|    time_elapsed       | 840      |
|    total_timesteps    | 266000   |
| train/                |          |
|    entropy_loss       | -0.216   |
|    explained_variance | 0.847    |
|    learning_rate      | 0.0007   |
|    n_updates          | 13299    |
|    policy_loss        | -0.0784  |
|    value_loss         | 0.107    |
------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 615       |
|    ep_rew_mean        | 8.75      |
| time/                 |           |
|    fps                | 316       |
|    iterations         | 13400     |
|    time_elapsed       | 846       |
|    total_timesteps    | 268000    |
| train/                |    

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 684      |
|    ep_rew_mean        | 10.4     |
| time/                 |          |
|    fps                | 317      |
|    iterations         | 14600    |
|    time_elapsed       | 918      |
|    total_timesteps    | 292000   |
| train/                |          |
|    entropy_loss       | -0.0301  |
|    explained_variance | 0.92     |
|    learning_rate      | 0.0007   |
|    n_updates          | 14599    |
|    policy_loss        | 0.000889 |
|    value_loss         | 0.0586   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 669      |
|    ep_rew_mean        | 10.2     |
| time/                 |          |
|    fps                | 318      |
|    iterations         | 14700    |
|    time_elapsed       | 924      |
|    total_timesteps    | 294000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 684      |
|    ep_rew_mean        | 10.6     |
| time/                 |          |
|    fps                | 319      |
|    iterations         | 15900    |
|    time_elapsed       | 995      |
|    total_timesteps    | 318000   |
| train/                |          |
|    entropy_loss       | -0.0753  |
|    explained_variance | 0.405    |
|    learning_rate      | 0.0007   |
|    n_updates          | 15899    |
|    policy_loss        | 0.0378   |
|    value_loss         | 0.0975   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 675      |
|    ep_rew_mean        | 10.7     |
| time/                 |          |
|    fps                | 319      |
|    iterations         | 16000    |
|    time_elapsed       | 1001     |
|    total_timesteps    | 320000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 691      |
|    ep_rew_mean        | 10.3     |
| time/                 |          |
|    fps                | 317      |
|    iterations         | 17200    |
|    time_elapsed       | 1083     |
|    total_timesteps    | 344000   |
| train/                |          |
|    entropy_loss       | -0.0619  |
|    explained_variance | 0.252    |
|    learning_rate      | 0.0007   |
|    n_updates          | 17199    |
|    policy_loss        | 0.00588  |
|    value_loss         | 0.252    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 697      |
|    ep_rew_mean        | 10.5     |
| time/                 |          |
|    fps                | 317      |
|    iterations         | 17300    |
|    time_elapsed       | 1090     |
|    total_timesteps    | 346000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 743      |
|    ep_rew_mean        | 12.2     |
| time/                 |          |
|    fps                | 314      |
|    iterations         | 18500    |
|    time_elapsed       | 1175     |
|    total_timesteps    | 370000   |
| train/                |          |
|    entropy_loss       | -0.0649  |
|    explained_variance | 0.946    |
|    learning_rate      | 0.0007   |
|    n_updates          | 18499    |
|    policy_loss        | 0.000663 |
|    value_loss         | 0.0742   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 735      |
|    ep_rew_mean        | 12.3     |
| time/                 |          |
|    fps                | 314      |
|    iterations         | 18600    |
|    time_elapsed       | 1182     |
|    total_timesteps    | 372000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 697      |
|    ep_rew_mean        | 11.1     |
| time/                 |          |
|    fps                | 314      |
|    iterations         | 19800    |
|    time_elapsed       | 1259     |
|    total_timesteps    | 396000   |
| train/                |          |
|    entropy_loss       | -0.0939  |
|    explained_variance | 0.653    |
|    learning_rate      | 0.0007   |
|    n_updates          | 19799    |
|    policy_loss        | 0.00132  |
|    value_loss         | 0.0758   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 698      |
|    ep_rew_mean        | 11.3     |
| time/                 |          |
|    fps                | 314      |
|    iterations         | 19900    |
|    time_elapsed       | 1266     |
|    total_timesteps    | 398000   |
| train/                |          |
|

<stable_baselines3.a2c.a2c.A2C at 0x7f2affcf7b80>

# 4. Save and Reload Model

In [26]:
a2c_path = os.path.join('Training', 'Saved Models', 'A2C_model')

In [27]:
model.save(a2c_path)

In [28]:
del model

In [29]:
env = make_atari_env('Breakout-v0', n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

In [30]:
model = A2C.load(a2c_path, env)

Wrapping the env in a VecTransposeImage.


# 5. Evaluate and Test

In [35]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(12.9, 3.419064199455752)

In [32]:
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

KeyboardInterrupt: 

In [36]:
env.close()