## 1. Importing Dependencies

In [3]:
## Install the latest package built for RL (Stable-Baselines3 with its extras)
# !pip install stable-baselines3[extras] 

In [4]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO
import gym
import os

## 2. Load Environment

For more information, see the [CartPole source code](https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py).

In [5]:
# Gym id of the classic-control task used throughout this notebook.
environment_name = 'CartPole-v0'
# Plain (non-vectorized) environment, used below for random rollouts and for inspecting the spaces.
env = gym.make(environment_name)

In [6]:
environment_name

'CartPole-v0'

In [7]:
# Baseline: play a handful of episodes with uniformly sampled actions and
# report the total reward collected in each one.
episodes = 5
for episode in range(1, episodes + 1):
    env.reset()  # start a new episode; the initial observation is not needed here
    score = 0
    done = False

    while True:
        if done:
            break
        env.render()
        random_action = env.action_space.sample()
        _next_obs, reward, done, _step_info = env.step(random_action)
        score += reward
    print('Episode:{} Score:{}'.format(episode, score))
# The render window is closed by the env.close() call in the following cell.

Episode:1 Score:28.0
Episode:2 Score:9.0
Episode:3 Score:20.0
Episode:4 Score:17.0
Episode:5 Score:45.0


In [8]:
env.close()

### Understanding the Environment

In [9]:
env.action_space

Discrete(2)

In [10]:
env.action_space.sample()

0

In [11]:
env.observation_space

Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)

In [12]:
env.observation_space.sample()

array([ 3.5214880e-01, -2.6648153e+38,  1.9150622e-01, -1.9533108e+38],
      dtype=float32)

## 3. Train an RL Model

**Before installing PyTorch, check which CUDA version your NVIDIA GPU supports.**

This step can be skipped.

In [13]:
# pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio===0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

In [14]:
# Relative folder passed to the models below as `tensorboard_log`.
# Original author's note: create this directory before training.
log_path = os.path.join('Training', 'Logs')

In [15]:
log_path

'Training\\Logs'

In [16]:
# Build the vectorized training environment and the PPO agent.
# The factory creates its own CartPole instance rather than closing over the
# global `env`: the original `lambda: env` referred to a name that this very
# cell rebinds to the wrapper, which only worked because the factory happens
# to be called while the wrapper is being constructed.
env = DummyVecEnv([lambda: gym.make(environment_name)])
# 'MlpPolicy' selects the policy class by name; verbose=1 prints training
# tables and tensorboard_log sends TensorBoard event files to `log_path`.
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


Available policy options:
* MlpPolicy 
* CnnPolicy 
* MultiInputPolicy

In [17]:
# Train the PPO agent with a budget of 20000 environment steps. Progress tables
# are printed because the model was created with verbose=1.
model.learn(total_timesteps=20000)

Logging to Training\Logs\PPO_1
-----------------------------
| time/              |      |
|    fps             | 126  |
|    iterations      | 1    |
|    time_elapsed    | 16   |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 192         |
|    iterations           | 2           |
|    time_elapsed         | 21          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.007992342 |
|    clip_fraction        | 0.101       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | 0.00289     |
|    learning_rate        | 0.0003      |
|    loss                 | 6.39        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0161     |
|    value_loss           | 49.9        |
-----------------------------------------
---

<stable_baselines3.ppo.ppo.PPO at 0x229cf4c3d90>

## 4. Save and Reload Model

In [18]:
# File path (no extension given) used below to save and reload the PPO model.
PPO_Path = os.path.join(os.path.join('Training', 'Saved_Model'), 'PPO_Model_Cartpole')

In [19]:
model.save(PPO_Path)

In [20]:
del model

In [21]:
PPO_Path

'Training\\Saved_Model\\PPO_Model_Cartpole'

In [22]:
# Restore the saved agent from PPO_Path and attach the existing vectorized env
# so the reloaded model can be trained and evaluated again.
model = PPO.load(PPO_Path, env=env)

In [23]:
# Continue training the reloaded agent. The log below reports 2048 timesteps
# for this call although 1000 were requested — presumably because PPO collects
# whole rollouts before stopping (TODO confirm against PPO's `n_steps` default).
model.learn(total_timesteps=1000)

Logging to Training\Logs\PPO_2
-----------------------------
| time/              |      |
|    fps             | 689  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x229cf58f640>

## 5. Evaluation

In [24]:
# Run 10 evaluation episodes with rendering enabled. The displayed pair is
# (mean episode reward, standard deviation) per the SB3 evaluate_policy docs.
evaluate_policy(model, env, n_eval_episodes=10, render=True)



(200.0, 0.0)

In [25]:
env.close()

## 6. Test

In [26]:
# Watch the trained agent play a few episodes and report each episode's return.
#
# `env` is a DummyVecEnv at this point, so reset()/step() return batched values
# (length-1 arrays for the single wrapped environment). Element 0 is unwrapped
# so that `done` is a plain bool and `score` a plain float; this prints e.g.
# "Score:200.0", matching the random-action loop, instead of "Score:[200.]".
episodes = 5
for episode in range(1, episodes+1):
    obs = env.reset()
    done = False
    score = 0.0

    while not done:
        env.render()
        # predict() also returns a recurrent state, unused for this policy.
        action, _states = model.predict(obs)
        obs, rewards, dones, infos = env.step(action)
        score += float(rewards[0])
        done = bool(dones[0])
    print('Episode:{} Score:{}'.format(episode, score))
# The render window is closed by the env.close() call in the following cell.

Episode:1 Score:[200.]
Episode:2 Score:[200.]
Episode:3 Score:[200.]
Episode:4 Score:[200.]
Episode:5 Score:[200.]


In [27]:
env.close()

In [28]:
obs = env.reset()

In [35]:
obs

array([[ 0.01522005, -0.04554836, -0.01899192,  0.03984968]],
      dtype=float32)

In [36]:
action, _ = model.predict(obs)

In [34]:
env.action_space.sample()

1

In [37]:
# One step on the vectorized env. As the output shows, every returned element
# is batched: observations, rewards, done flags, and a list of info dicts.
env.step(action)

(array([[ 0.01430908,  0.14984071, -0.01819493, -0.25876436]],
       dtype=float32),
 array([1.], dtype=float32),
 array([False]),
 [{}])

## 7. Viewing Logs in Tensorboard

In [38]:
# Sub-folder of the first PPO run (the first training output above reports "Logging to Training\Logs\PPO_1").
training_log_path = os.path.join(log_path, 'PPO_1')

In [39]:
training_log_path

'Training\\Logs\\PPO_1'

In [40]:
!tensorboard --logdir={training_log_path}

^C


Running the `tensorboard` command from a Jupyter cell prints nothing in the notebook,

but it still serves the dashboard on port 6006, just as it would when started from a terminal.

To stop it, press Ctrl+C or interrupt the kernel.

## 8. Adding a Callback to the Training Stage

In [42]:
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

In [43]:
# Folder handed to EvalCallback below as `best_model_save_path`.
save_path = os.path.join('Training', 'Saved_Model')

In [44]:
# Evaluate on a dedicated environment. The original passed the training `env`
# itself, so every evaluation reset and stepped the very environment the
# learner was in the middle of collecting a rollout from.
eval_env = DummyVecEnv([lambda: gym.make(environment_name)])

# Ends training once an evaluation's mean reward reaches the threshold (200).
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)
# Every `eval_freq` callback steps: evaluate the current policy, save a new
# best model under `save_path`, and consult `stop_callback` on each new best.
eval_callback = EvalCallback(eval_env,
                             callback_on_new_best=stop_callback,
                             eval_freq=10000,
                             best_model_save_path=save_path,
                             verbose=1)

In [45]:
# Start from a fresh, untrained PPO agent for the callback-driven training run.
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [46]:
# Train with the evaluation callback attached. It can end training before the
# 20000-step budget is used (see the "Stopping training" message in the output).
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\PPO_3
-----------------------------
| time/              |      |
|    fps             | 715  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 537         |
|    iterations           | 2           |
|    time_elapsed         | 7           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008320259 |
|    clip_fraction        | 0.0882      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | 0.00205     |
|    learning_rate        | 0.0003      |
|    loss                 | 5.69        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0123     |
|    value_loss           | 45.8        |
-----------------------------------------
---



Eval num_timesteps=10000, episode_reward=200.00 +/- 0.00
Episode length: 200.00 +/- 0.00
------------------------------------------
| eval/                   |              |
|    mean_ep_length       | 200          |
|    mean_reward          | 200          |
| time/                   |              |
|    total timesteps      | 10000        |
| train/                  |              |
|    approx_kl            | 0.0068070525 |
|    clip_fraction        | 0.0529       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.614       |
|    explained_variance   | 0.22         |
|    learning_rate        | 0.0003       |
|    loss                 | 25.1         |
|    n_updates            | 40           |
|    policy_gradient_loss | -0.0148      |
|    value_loss           | 67.3         |
------------------------------------------
New best mean reward!
Stopping training because the mean reward 200.00  is above the threshold 200


<stable_baselines3.ppo.ppo.PPO at 0x22a681490d0>

## 9. Change the Policies

In [48]:
# Custom network layout passed to PPO through policy_kwargs: per SB3's
# net_arch convention, `pi` sizes the policy network and `vf` the value
# network. Each gets four hidden layers of 128 units.
net_arch = [{'pi': [128] * 4, 'vf': [128] * 4}]

In [49]:
# New PPO agent whose networks are built from the custom `net_arch`, supplied via policy_kwargs.
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path, policy_kwargs={'net_arch':net_arch})

Using cuda device


In [50]:
# Build fresh callbacks for this run instead of reusing `eval_callback`.
# That instance already recorded a best mean reward of 200 in the previous
# run (and keeps its call counter), so no later evaluation could register as a
# "new best": the reward-threshold stop would never be consulted and no best
# model would be saved for this architecture.
eval_callback_custom = EvalCallback(
    # Dedicated evaluation env, so evaluation does not disturb the training env.
    DummyVecEnv([lambda: gym.make(environment_name)]),
    callback_on_new_best=StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1),
    eval_freq=10000,
    best_model_save_path=save_path,
    verbose=1,
)
# Same 20000-step budget as before; training may stop early at the threshold.
model.learn(total_timesteps=20000, callback=eval_callback_custom)

Logging to Training\Logs\PPO_4
-----------------------------
| time/              |      |
|    fps             | 590  |
|    iterations      | 1    |
|    time_elapsed    | 3    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 454         |
|    iterations           | 2           |
|    time_elapsed         | 9           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.015335293 |
|    clip_fraction        | 0.228       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.681      |
|    explained_variance   | 0.00278     |
|    learning_rate        | 0.0003      |
|    loss                 | 4.06        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0255     |
|    value_loss           | 21.8        |
-----------------------------------------
---

<stable_baselines3.ppo.ppo.PPO at 0x22a6798a3d0>

## 10. Using an Alternate Algorithm

In [51]:
from stable_baselines3 import DQN

In [52]:
# Swap in DQN, keeping the same policy name, environment, and log folder as the PPO runs.
model = DQN('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [53]:
# Train DQN with the same 20000-step budget; no evaluation callback is attached here.
model.learn(total_timesteps=20000)

Logging to Training\Logs\DQN_1
----------------------------------
| rollout/            |          |
|    exploration rate | 0.946    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 306      |
|    time_elapsed     | 0        |
|    total timesteps  | 114      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.898    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 521      |
|    time_elapsed     | 0        |
|    total timesteps  | 215      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.863    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 635      |
|    time_elapsed     | 0        |
|    total timesteps  | 289      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 2084     |
|    time_elapsed     | 1        |
|    total timesteps  | 2504     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 2065     |
|    time_elapsed     | 1        |
|    total timesteps  | 2559     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 2069     |
|    time_elapsed     | 1        |
|    total timesteps  | 2620     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 2368     |
|    time_elapsed     | 1        |
|    total timesteps  | 4556     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 2399     |
|    time_elapsed     | 1        |
|    total timesteps  | 4689     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 1900     |
|    time_elapsed     | 2        |
|    total timesteps  | 4824     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 2245     |
|    time_elapsed     | 3        |
|    total timesteps  | 7315     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 2253     |
|    time_elapsed     | 3        |
|    total timesteps  | 7398     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 2261     |
|    time_elapsed     | 3        |
|    total timesteps  | 7493     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 2454     |
|    time_elapsed     | 4        |
|    total timesteps  | 9829     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 2453     |
|    time_elapsed     | 4        |
|    total timesteps  | 9890     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 440      |
|    fps              | 2459     |
|    time_elapsed     | 4        |
|    total timesteps  | 9983     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 2579     |
|    time_elapsed     | 4        |
|    total timesteps  | 12176    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 544      |
|    fps              | 2589     |
|    time_elapsed     | 4        |
|    total timesteps  | 12298    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 548      |
|    fps              | 2590     |
|    time_elapsed     | 4        |
|    total timesteps  | 12365    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 2663     |
|    time_elapsed     | 5        |
|    total timesteps  | 14556    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 652      |
|    fps              | 2661     |
|    time_elapsed     | 5        |
|    total timesteps  | 14618    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 656      |
|    fps              | 2665     |
|    time_elapsed     | 5        |
|    total timesteps  | 14684    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 2315     |
|    time_elapsed     | 7        |
|    total timesteps  | 16794    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 760      |
|    fps              | 2315     |
|    time_elapsed     | 7        |
|    total timesteps  | 16853    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 764      |
|    fps              | 2322     |
|    time_elapsed     | 7        |
|    total timesteps  | 16951    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 864      |
|    fps              | 2422     |
|    time_elapsed     | 7        |
|    total timesteps  | 19339    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 868      |
|    fps              | 2426     |
|    time_elapsed     | 8        |
|    total timesteps  | 19434    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 872      |
|    fps              | 2433     |
|    time_elapsed     | 8        |
|    total timesteps  | 19565    |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x22a67995df0>