In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [39]:
import os
import os.path as osp

import gym
from stable_baselines3 import DQN, PPO
from stable_baselines3.common.callbacks import (
    EvalCallback,
    StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

<IPython.core.display.Javascript object>

# Load environment

In [3]:
environment_name = "CartPole-v1"
env = gym.make(environment_name)

<IPython.core.display.Javascript object>

In [4]:
env.reset()  # Initial set of observations

array([ 0.01902798, -0.04283102,  0.04000649, -0.00608882], dtype=float32)

<IPython.core.display.Javascript object>

In [5]:
print(env.action_space)
print(env.action_space.sample())

Discrete(2)
1


<IPython.core.display.Javascript object>

In [6]:
print(env.observation_space)

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)


<IPython.core.display.Javascript object>

In [7]:
print(env.observation_space.sample())

[ 2.3039937e+00 -3.3043992e+38  2.0541702e-01  3.1132560e+38]


<IPython.core.display.Javascript object>

In [8]:
print(env.step(1))

(array([ 0.01817136,  0.15169501,  0.03988472, -0.28588563], dtype=float32), 1.0, False, {})


<IPython.core.display.Javascript object>

In [9]:
episodes = 5
for episode in range(1, episodes + 1):
    state = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score += reward

    print(f"Episode: {episode}, Score: {score}")
env.close()

Episode: 1, Score: 18.0
Episode: 2, Score: 21.0
Episode: 3, Score: 36.0
Episode: 4, Score: 11.0
Episode: 5, Score: 18.0


<IPython.core.display.Javascript object>

# Training

In [10]:
log_path = osp.join("Training", "Logs")

<IPython.core.display.Javascript object>

In [11]:
env = gym.make(environment_name)
env = DummyVecEnv([lambda: env])
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=log_path)

Using cpu device


<IPython.core.display.Javascript object>

In [12]:
model.learn(total_timesteps=20_000)

Logging to Training\Logs\PPO_4
-----------------------------
| time/              |      |
|    fps             | 1842 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1290        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009169854 |
|    clip_fraction        | 0.106       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | 0.0031      |
|    learning_rate        | 0.0003      |
|    loss                 | 9.98        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0165     |
|    value_loss           | 67.9        |
-----------------------------------------
---

<stable_baselines3.ppo.ppo.PPO at 0x2202a5902e0>

<IPython.core.display.Javascript object>

# Save and reload model

In [13]:
PPO_path = osp.join("Training", "SavedModels", "PPO_Model_Cartpole")

model.save(PPO_path)

<IPython.core.display.Javascript object>

In [14]:
del model

model = PPO.load(PPO_path, env=env)

<IPython.core.display.Javascript object>

# Evaluation

In [15]:
evaluate_policy(model=model, env=env, n_eval_episodes=1, render=True)



(500.0, 0.0)

<IPython.core.display.Javascript object>

In [16]:
env.close()

<IPython.core.display.Javascript object>

# Test model

In [38]:
episodes = 5
for episode in range(1, episodes + 1):
    obs = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action, next_state = model.predict(obs)  # Now using model to make predictions
        obs, reward, done, info = env.step(action)
        score += reward

    print(f"Episode: {episode}, Score: {score}")
env.close()

Episode: 1, Score: [28.]
Episode: 2, Score: [22.]
Episode: 3, Score: [177.]
Episode: 4, Score: [71.]
Episode: 5, Score: [236.]


<IPython.core.display.Javascript object>

In [23]:
action, _ = model.predict(obs)
observations, reward, done, info = env.step(action)

array([1.], dtype=float32)

<IPython.core.display.Javascript object>

# Adding a callback to the training stage

In [28]:
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=190, verbose=1)
eval_callback = EvalCallback(
    env,
    callback_on_new_best=stop_callback,
    eval_freq=1000,
    verbose=1,
)

<IPython.core.display.Javascript object>

In [30]:
model.learn(total_timesteps=20_000, callback=eval_callback)

Logging to Training\Logs\PPO_5




Eval num_timesteps=1000, episode_reward=457.40 +/- 45.54
Episode length: 457.40 +/- 45.54
---------------------------------
| eval/              |          |
|    mean_ep_length  | 457      |
|    mean_reward     | 457      |
| time/              |          |
|    total_timesteps | 1000     |
---------------------------------
New best mean reward!
Stopping training because the mean reward 457.40  is above the threshold 190


<stable_baselines3.ppo.ppo.PPO at 0x220435451b0>

<IPython.core.display.Javascript object>

# Changing policy

In [35]:
new_arch = {"pi": [182, 128, 128, 128], "vf": [128, 128, 128, 128]}
model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    tensorboard_log=log_path,
    policy_kwargs={"net_arch": new_arch},
)

Using cpu device


<IPython.core.display.Javascript object>

In [37]:
model.learn(total_timesteps=20_000, callback=eval_callback)

Logging to Training\Logs\PPO_6
Eval num_timesteps=1000, episode_reward=30.00 +/- 5.55
Episode length: 30.00 +/- 5.55
---------------------------------
| eval/              |          |
|    mean_ep_length  | 30       |
|    mean_reward     | 30       |
| time/              |          |
|    total_timesteps | 1000     |
---------------------------------




Eval num_timesteps=2000, episode_reward=27.60 +/- 5.75
Episode length: 27.60 +/- 5.75
---------------------------------
| eval/              |          |
|    mean_ep_length  | 27.6     |
|    mean_reward     | 27.6     |
| time/              |          |
|    total_timesteps | 2000     |
---------------------------------
-----------------------------
| time/              |      |
|    fps             | 1310 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
Eval num_timesteps=3000, episode_reward=302.20 +/- 167.19
Episode length: 302.20 +/- 167.19
-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 302         |
|    mean_reward          | 302         |
| time/                   |             |
|    total_timesteps      | 3000        |
| train/                  |             |
|    approx_kl            | 0.012873263 |
|    clip_fraction        | 0.168     

<stable_baselines3.ppo.ppo.PPO at 0x2204da07d90>

<IPython.core.display.Javascript object>

# Using an alternate algorithm

In [40]:
model = DQN("MlpPolicy", env, verbose=1, tensorboard_log=log_path)

Using cpu device


<IPython.core.display.Javascript object>

In [41]:
model.learn(total_timesteps=20_000)

Logging to Training\Logs\DQN_1
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.967    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 5625     |
|    time_elapsed     | 0        |
|    total_timesteps  | 70       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.91     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 7127     |
|    time_elapsed     | 0        |
|    total_timesteps  | 189      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.856    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 7742     |
|    time_elapsed     | 0        |
|    total_timesteps  | 303      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 7396     |
|    time_elapsed     | 0        |
|    total_timesteps  | 2680     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 7387     |
|    time_elapsed     | 0        |
|    total_timesteps  | 2765     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 7437     |
|    time_elapsed     | 0        |
|    total_timesteps  | 2918     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 7425     |
|    time_elapsed     | 0        |
|    total_timesteps  | 5276     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 7431     |
|    time_elapsed     | 0        |
|    total_timesteps  | 5362     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 7445     |
|    time_elapsed     | 0        |
|    total_timesteps  | 5458     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 7437     |
|    time_elapsed     | 1        |
|    total_timesteps  | 7580     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 7409     |
|    time_elapsed     | 1        |
|    total_timesteps  | 7636     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 7400     |
|    time_elapsed     | 1        |
|    total_timesteps  | 7719     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 7433     |
|    time_elapsed     | 1        |
|    total_timesteps  | 9933     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 7429     |
|    time_elapsed     | 1        |
|    total_timesteps  | 10029    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 440      |
|    fps              | 7419     |
|    time_elapsed     | 1        |
|    total_timesteps  | 10078    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 7449     |
|    time_elapsed     | 1        |
|    total_timesteps  | 12468    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 544      |
|    fps              | 7450     |
|    time_elapsed     | 1        |
|    total_timesteps  | 12550    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 548      |
|    fps              | 7459     |
|    time_elapsed     | 1        |
|    total_timesteps  | 12675    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 7500     |
|    time_elapsed     | 1        |
|    total_timesteps  | 14877    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 652      |
|    fps              | 7504     |
|    time_elapsed     | 1        |
|    total_timesteps  | 14982    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 656      |
|    fps              | 7508     |
|    time_elapsed     | 2        |
|    total_timesteps  | 15072    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 7535     |
|    time_elapsed     | 2        |
|    total_timesteps  | 17294    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 760      |
|    fps              | 7533     |
|    time_elapsed     | 2        |
|    total_timesteps  | 17376    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 764      |
|    fps              | 7531     |
|    time_elapsed     | 2        |
|    total_timesteps  | 17444    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 864      |
|    fps              | 7516     |
|    time_elapsed     | 2        |
|    total_timesteps  | 19635    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 868      |
|    fps              | 7512     |
|    time_elapsed     | 2        |
|    total_timesteps  | 19708    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 872      |
|    fps              | 7510     |
|    time_elapsed     | 2        |
|    total_timesteps  | 19784    |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x2204dfb7460>

<IPython.core.display.Javascript object>