# Installing the dependencies

In [73]:
# Install Stable Baselines3 with its optional extras (uncomment to run).
# NOTE(review): prefer `%pip install "stable-baselines3[extra]"` — %pip installs into this kernel's
# environment, and quoting stops shells such as zsh from globbing the brackets. Consider pinning a version.
# !pip install stable-baselines3[extra]

Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch.

In [74]:
# Upgrade gym (uncomment to run).
# NOTE(review): an unpinned `--upgrade` may install a gym release the installed SB3 version does not
# support — consider pinning the version, and prefer %pip over !pip so the kernel's environment is used.
# !pip install gym --upgrade

Gym is a framework that helps to develop and compare reinforcement learning algorithms.

In [4]:
import os
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


# Load Environments

In [10]:
environment_name = 'CartPole-v1'
# The reset()/step() outputs below show the gym>=0.26 API, where the render mode must be
# chosen when the env is constructed: without it, env.render() only logs a warning and draws
# nothing. "human" opens a window so the random-agent episodes can actually be watched.
env = gym.make(environment_name, render_mode="human")

In [12]:
# Start a new episode. With this gym version reset() returns a tuple: (observation, info).
env.reset()

(array([-0.04105214, -0.04845896, -0.02119668,  0.0456057 ], dtype=float32),
 {})

In [13]:
# Take a single step with action 1. The result is a 5-tuple:
# (observation, reward, terminated, truncated, info).
print(env.step(1))

(array([-0.04202132,  0.14696044, -0.02028456, -0.2536889 ], dtype=float32), 1.0, False, False, {})


`env.step` returns five values: the observation, the reward, a `terminated` (done) flag, a `truncated` flag and an `info` dict. The `terminated` flag is intuitive: it is set once the agent reaches a terminal state. The `truncated` flag is set when the episode is cut short without reaching a terminal state, for example when a time limit runs out or an energy budget is exhausted.

In [18]:
# Baseline: play a few episodes with a random agent to see what an untrained policy scores.
episodes = 5
for episode in range(1, episodes + 1):
    # reset() returns (observation, info); unpack it instead of binding the whole tuple.
    state, info = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action = env.action_space.sample()  # random action: no learning involved yet
        state, reward, terminated, truncated, info = env.step(action)
        # The episode is over either on a terminal state or when it is cut short (truncated).
        done = terminated or truncated
        score += reward
    print(f'Episode: {episode} Score: {score}')
env.close()

Episode: 1 Score: 18.0
Episode: 2 Score: 39.0
Episode: 3 Score: 16.0
Episode: 4 Score: 20.0
Episode: 5 Score: 12.0


# Understanding the Environment

In [19]:
# Discrete(2): the agent chooses between two actions, 0 and 1.
env.action_space

Discrete(2)

In [22]:
# Box of shape (4,): lower and upper bounds for each of the four observation values.
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

The four observation values are the cart position, cart velocity, pole angle and pole angular velocity.

In [23]:
# Draw one random valid action.
env.action_space.sample()

1

In [24]:
# Draw one random observation; the effectively unbounded velocity dimensions can yield huge values.
env.observation_space.sample()

array([2.6141179e+00, 1.3114164e+37, 6.8146497e-02, 8.2323025e+37],
      dtype=float32)

# Training RL Model

In [25]:
# Directory for TensorBoard training logs; os.path.join uses the OS-specific path separator.
log_path = os.path.join('Training','Logs')

In [26]:
# Inspect the resulting log directory path.
log_path

'Training\\Logs'

In [27]:
# Wrap a fresh CartPole env in a single-environment DummyVecEnv, the vectorized form SB3 trains on.
# The factory builds its own env rather than closing over `env`: that closure would refer to the
# name this very line rebinds to the VecEnv, so any later call of the factory would return the
# VecEnv itself instead of a gym env.
env = DummyVecEnv([lambda: gym.make(environment_name)])
# 'MlpPolicy' selects a fully connected network; training logs are written under log_path.
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)



Using cpu device


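The first argument, `MlpPolicy`, selects a multilayer perceptron (fully connected neural network) policy. SB3 also provides `CnnPolicy` for image observations and `MultiInputPolicy` for dictionary observations. Recurrent (LSTM) policies such as `MlpLstmPolicy` are available in the separate SB3-Contrib package (`RecurrentPPO`).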

In [29]:
# Train PPO for 20,000 environment steps; logs go to a new PPO_<n> folder under log_path.
model.learn(total_timesteps=20000)

Logging to Training\Logs\PPO_1


  if not isinstance(terminated, (bool, np.bool8)):


-----------------------------
| time/              |      |
|    fps             | 1581 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1101        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008944992 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | -0.0171     |
|    learning_rate        | 0.0003      |
|    loss                 | 9.27        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0195     |
|    value_loss           | 59.8        |
-----------------------------------------
----------------------------------

<stable_baselines3.ppo.ppo.PPO at 0x25a792fe640>

In [30]:
# Location where the trained PPO model is saved and later reloaded from.
PPO_Path = os.path.join('Training','Saved Models','PPO_Model_CartPole')

In [32]:
# Persist the trained model to disk.
model.save(PPO_Path)

In [33]:
# Drop the in-memory model to show it can be restored from the saved file.
del model

In [34]:
# Restore the saved PPO model and attach the vectorized env so it can be evaluated or trained further.
model = PPO.load(PPO_Path, env=env)

# Testing & Evaluation

In [36]:
# Mean and standard deviation of the episode reward over 10 evaluation episodes, rendering each one.
evaluate_policy(model, env, n_eval_episodes = 10, render = True)



(493.1, 18.785366645343924)

In [37]:
# Release the environment's resources after evaluation.
env.close()

In [40]:
# predict() needs an observation. `obs` must be created here: the episode loop further down
# defines it too late, so this cell would raise NameError on a fresh kernel.
# A VecEnv's reset() returns only the batched observation (no info tuple).
obs = env.reset()
action, _ = model.predict(obs)

In [41]:
# The predicted action is batched: one entry per sub-environment of the VecEnv.
action

array([0], dtype=int64)

In [44]:
# Run 5 episodes with the trained PPO agent and print each episode's total reward.
# `env` is a VecEnv: reset() returns only the observation, step() returns 4 values, and
# reward/done are arrays with one entry per sub-environment — hence scores print as e.g. [500.].
# NOTE(review): model.predict is called without deterministic=True, so actions are sampled;
# pass deterministic=True for evaluation-style runs.
episodes = 5
for episode in range(1, episodes + 1):
    obs = env.reset()
    done = False
    score = 0
    
    # `done` becomes a one-element array after the first step; `not done` still works because
    # a single-element array's truth value is that of its element.
    while not done:
        env.render()
        action,_  = model.predict(obs)
        obs, reward, done, info = env.step(action)  
        score += reward
    print(f'Episode: {episode} Score: {score}')
env.close()

Episode: 1 Score: [500.]
Episode: 2 Score: [500.]
Episode: 3 Score: [479.]
Episode: 4 Score: [500.]
Episode: 5 Score: [500.]


In [45]:
# Release the environment's resources.
env.close()

In [46]:
# One more step to show the VecEnv return format: (observations, rewards, dones, infos),
# with time-limit truncation reported in each info dict under 'TimeLimit.truncated'.
# NOTE(review): relies on `action` left over from the episode loop and on an env that was just
# closed — fragile under Restart & Run All.
env.step(action)

(array([[ 0.02975142,  0.14879909,  0.0157673 , -0.32864133]],
       dtype=float32),
 array([1.], dtype=float32),
 array([False]),
 [{'TimeLimit.truncated': False}])

# Viewing Logs in Tensorboard

In [50]:
# Folder holding the TensorBoard logs of the first PPO training run.
training_log_path = os.path.join(log_path,'PPO_1')

In [51]:
# Inspect the log folder path.
training_log_path

'Training\\Logs\\PPO_1'

In [53]:
# Launch TensorBoard on this run's logs (uncomment to run). A `!` shell command blocks the cell
# until the server is stopped; to view it inline instead, use `%load_ext tensorboard` followed by
# `%tensorboard --logdir {training_log_path}`.
# !tensorboard --logdir={training_log_path}

# Adding a callback to the training stage

In [54]:
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

In [55]:
# Directory where the evaluation callback saves the best model found during training.
save_path = os.path.join('Training','Saved Models')

In [57]:
# Stop training once the mean evaluation reward reaches 500, the best CartPole-v1 score seen in
# this notebook (episodes are capped there).
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=500, verbose=1)
# Evaluate on a dedicated environment. Evaluation resets and steps the env it is given, so
# evaluating on the training `env` would disrupt the rollouts PPO is collecting from it.
eval_env = DummyVecEnv([lambda: gym.make(environment_name)])
eval_callback = EvalCallback(eval_env,
                             callback_on_new_best=stop_callback,
                             eval_freq=10000,
                             best_model_save_path=save_path,
                             verbose=1)

In [58]:
# A fresh PPO model for the callback-driven run; its logs go to the next PPO_<n> folder.
model = PPO('MlpPolicy',env,verbose = 1,tensorboard_log = log_path)

Using cpu device


In [59]:
# Train for up to 200,000 steps; eval_callback evaluates every 10,000 steps and can stop
# training early once the reward threshold is reached.
model.learn(total_timesteps = 200000, callback= eval_callback)

Logging to Training\Logs\PPO_2
-----------------------------
| time/              |      |
|    fps             | 1923 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1094        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008767787 |
|    clip_fraction        | 0.105       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | -0.00312    |
|    learning_rate        | 0.0003      |
|    loss                 | 9.12        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0156     |
|    value_loss           | 62.2        |
-----------------------------------------
---

------------------------------------------
| time/                   |              |
|    fps                  | 711          |
|    iterations           | 12           |
|    time_elapsed         | 34           |
|    total_timesteps      | 24576        |
| train/                  |              |
|    approx_kl            | 0.0045158183 |
|    clip_fraction        | 0.0495       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.55        |
|    explained_variance   | 0.919        |
|    learning_rate        | 0.0003       |
|    loss                 | 1            |
|    n_updates            | 110          |
|    policy_gradient_loss | -0.00992     |
|    value_loss           | 12.5         |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 715         |
|    iterations           | 13          |
|    time_elapsed         | 37          |
|    total_times

<stable_baselines3.ppo.ppo.PPO at 0x25a125efe80>

# Changing Policies

In [62]:
# Four hidden layers of 128 units each, specified separately for the policy (pi) and value (vf) networks.
net_arch = dict(pi=[128,128,128,128], vf=[128,128,128,128])

In [63]:
# PPO with the custom `net_arch` passed through to the policy constructor.
model = PPO(
    'MlpPolicy',
    env,
    tensorboard_log=log_path,
    policy_kwargs=dict(net_arch=net_arch),
)

In [65]:
# Use fresh callbacks for this run. EvalCallback keeps state across learn() calls, including the
# best mean reward seen so far and its step counter. If the earlier `eval_callback` were reused,
# this model's best checkpoint would not be saved, and the stop trigger would not fire unless it
# beat the previous run's best. The callback also gets its own evaluation env rather than the
# training `env`.
arch_stop_callback = StopTrainingOnRewardThreshold(reward_threshold=500, verbose=1)
arch_eval_callback = EvalCallback(DummyVecEnv([lambda: gym.make(environment_name)]),
                                  callback_on_new_best=arch_stop_callback,
                                  eval_freq=10000,
                                  best_model_save_path=save_path,
                                  verbose=1)
model.learn(total_timesteps=20000, callback=arch_eval_callback)

Eval num_timesteps=10000, episode_reward=373.60 +/- 104.94
Episode length: 373.60 +/- 104.94
Eval num_timesteps=20000, episode_reward=500.00 +/- 0.00
Episode length: 500.00 +/- 0.00


<stable_baselines3.ppo.ppo.PPO at 0x25a12510b50>

# Using an alternate algorithm (DQN)

In [66]:
from stable_baselines3 import DQN

In [70]:
# DQN with default hyperparameters on the same vectorized env; logs go to a DQN_<n> folder.
model = DQN('MlpPolicy',env,verbose = 1,tensorboard_log = log_path)

Using cpu device


In [71]:
# Train DQN for 20,000 environment steps.
model.learn(total_timesteps = 20000)

Logging to Training\Logs\DQN_2
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.952    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 4249     |
|    time_elapsed     | 0        |
|    total_timesteps  | 102      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.915    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 2224     |
|    time_elapsed     | 0        |
|    total_timesteps  | 179      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.545    |
|    n_updates        | 19       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.883    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.375    |
| time/               |          |
|    episodes         | 80       |
|    fps              | 1318     |
|    time_elapsed     | 0        |
|    total_timesteps  | 1315     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.102    |
|    n_updates        | 303      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.353    |
| time/               |          |
|    episodes         | 84       |
|    fps              | 1299     |
|    time_elapsed     | 1        |
|    total_timesteps  | 1363     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.084    |
|    n_updates        | 315      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 156      |
|    fps              | 1174     |
|    time_elapsed     | 1        |
|    total_timesteps  | 2144     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0182   |
|    n_updates        | 510      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 160      |
|    fps              | 1171     |
|    time_elapsed     | 1        |
|    total_timesteps  | 2184     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0154   |
|    n_updates        | 520      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 232      |
|    fps              | 1092     |
|    time_elapsed     | 2        |
|    total_timesteps  | 2893     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00437  |
|    n_updates        | 698      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 236      |
|    fps              | 1091     |
|    time_elapsed     | 2        |
|    total_timesteps  | 2934     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00325  |
|    n_updates        | 708      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 308      |
|    fps              | 1064     |
|    time_elapsed     | 3        |
|    total_timesteps  | 3679     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00135  |
|    n_updates        | 894      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 312      |
|    fps              | 1063     |
|    time_elapsed     | 3        |
|    total_timesteps  | 3719     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00123  |
|    n_updates        | 904      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 384      |
|    fps              | 1062     |
|    time_elapsed     | 4        |
|    total_timesteps  | 4462     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000571 |
|    n_updates        | 1090     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 388      |
|    fps              | 1063     |
|    time_elapsed     | 4        |
|    total_timesteps  | 4500     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000537 |
|    n_updates        | 1099     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 460      |
|    fps              | 1040     |
|    time_elapsed     | 5        |
|    total_timesteps  | 5286     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000271 |
|    n_updates        | 1296     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 464      |
|    fps              | 1038     |
|    time_elapsed     | 5        |
|    total_timesteps  | 5329     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000228 |
|    n_updates        | 1307     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 536      |
|    fps              | 989      |
|    time_elapsed     | 6        |
|    total_timesteps  | 6126     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00019  |
|    n_updates        | 1506     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 987      |
|    time_elapsed     | 6        |
|    total_timesteps  | 6169     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000146 |
|    n_updates        | 1517     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 612      |
|    fps              | 949      |
|    time_elapsed     | 7        |
|    total_timesteps  | 7002     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00016  |
|    n_updates        | 1725     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 616      |
|    fps              | 946      |
|    time_elapsed     | 7        |
|    total_timesteps  | 7055     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000216 |
|    n_updates        | 1738     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 688      |
|    fps              | 917      |
|    time_elapsed     | 8        |
|    total_timesteps  | 7927     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000299 |
|    n_updates        | 1956     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 692      |
|    fps              | 915      |
|    time_elapsed     | 8        |
|    total_timesteps  | 7975     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000168 |
|    n_updates        | 1968     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 764      |
|    fps              | 894      |
|    time_elapsed     | 9        |
|    total_timesteps  | 8858     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000124 |
|    n_updates        | 2189     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 768      |
|    fps              | 893      |
|    time_elapsed     | 9        |
|    total_timesteps  | 8911     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 6.33e-05 |
|    n_updates        | 2202     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 840      |
|    fps              | 876      |
|    time_elapsed     | 11       |
|    total_timesteps  | 9874     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 5.95e-05 |
|    n_updates        | 2443     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 844      |
|    fps              | 876      |
|    time_elapsed     | 11       |
|    total_timesteps  | 9920     |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000131 |
|    n_updates        | 2454     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 916      |
|    fps              | 862      |
|    time_elapsed     | 12       |
|    total_timesteps  | 10644    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00826  |
|    n_updates        | 2635     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 920      |
|    fps              | 861      |
|    time_elapsed     | 12       |
|    total_timesteps  | 10681    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0485   |
|    n_updates        | 2645     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 992      |
|    fps              | 850      |
|    time_elapsed     | 13       |
|    total_timesteps  | 11406    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00452  |
|    n_updates        | 2826     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 996      |
|    fps              | 849      |
|    time_elapsed     | 13       |
|    total_timesteps  | 11452    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0128   |
|    n_updates        | 2837     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1068     |
|    fps              | 846      |
|    time_elapsed     | 14       |
|    total_timesteps  | 12322    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00765  |
|    n_updates        | 3055     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1072     |
|    fps              | 845      |
|    time_elapsed     | 14       |
|    total_timesteps  | 12370    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0311   |
|    n_updates        | 3067     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1144     |
|    fps              | 840      |
|    time_elapsed     | 16       |
|    total_timesteps  | 13857    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0163   |
|    n_updates        | 3439     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1148     |
|    fps              | 839      |
|    time_elapsed     | 16       |
|    total_timesteps  | 13963    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0295   |
|    n_updates        | 3465     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1220     |
|    fps              | 830      |
|    time_elapsed     | 18       |
|    total_timesteps  | 15470    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0202   |
|    n_updates        | 3842     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1224     |
|    fps              | 830      |
|    time_elapsed     | 18       |
|    total_timesteps  | 15565    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0146   |
|    n_updates        | 3866     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1296     |
|    fps              | 825      |
|    time_elapsed     | 21       |
|    total_timesteps  | 17444    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.024    |
|    n_updates        | 4335     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1300     |
|    fps              | 825      |
|    time_elapsed     | 21       |
|    total_timesteps  | 17523    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0112   |
|    n_updates        | 4355     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1372     |
|    fps              | 826      |
|    time_elapsed     | 23       |
|    total_timesteps  | 19621    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0334   |
|    n_updates        | 4880     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 1376     |
|    fps              | 826      |
|    time_elapsed     | 23       |
|    total_timesteps  | 19713    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0164   |
|    n_updates        | 4903     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rat

<stable_baselines3.dqn.dqn.DQN at 0x25a125f69a0>

In [72]:
# Run 5 episodes with the trained DQN agent and print each episode's total reward.
# NOTE(review): the scores printed below (15–30) are far below the PPO agent's ~500, so this
# default-hyperparameter DQN has not learned the task in 20,000 steps. Consider longer training
# or tuned hyperparameters.
# `env` is a VecEnv: reset() returns only the observation, step() returns 4 values, and
# reward/done are per-sub-environment arrays — hence scores print as e.g. [15.].
episodes = 5
for episode in range(1, episodes + 1):
    obs = env.reset()
    done = False
    score = 0
    
    # `done` becomes a one-element array after the first step; `not done` still works because
    # a single-element array's truth value is that of its element.
    while not done:
        env.render()
        action,_  = model.predict(obs)
        obs, reward, done, info = env.step(action)  
        score += reward
    print(f'Episode: {episode} Score: {score}')
env.close()

Episode: 1 Score: [15.]
Episode: 2 Score: [19.]
Episode: 3 Score: [21.]
Episode: 4 Score: [17.]
Episode: 5 Score: [30.]
