# Setup Libraries

In [1]:
!pip install stable-baselines3[extra]

^C


In [12]:
!pip install pyglet

Collecting pyglet
  Downloading pyglet-1.5.21-py3-none-any.whl (1.1 MB)
Installing collected packages: pyglet
Successfully installed pyglet-1.5.21


You should consider upgrading via the 'c:\users\gopal\appdata\local\programs\python\python38\python.exe -m pip install --upgrade pip' command.


In [1]:
import gym
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold


# Setup Environment

In [2]:
env_name = 'CartPole-v0'

In [3]:
env = gym.make(env_name)

In [4]:
episodes = 5
for episode in range(1, episodes+1):
    state = env.reset()
    done = False
    score = 0 
    
    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score+=reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()


Episode:1 Score:43.0
Episode:2 Score:18.0
Episode:3 Score:17.0
Episode:4 Score:18.0
Episode:5 Score:22.0


# Train RL Model

In [5]:
env = gym.make(env_name)
env = DummyVecEnv([lambda: env])

In [6]:
model_path = os.path.join('Saved Models', 'best_model')
save_path = os.path.join('Saved Models')
log_path = os.path.join('Logs')

In [7]:
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=190, verbose=1)
eval_callback = EvalCallback(env, 
                             callback_on_new_best=stop_callback, 
                             eval_freq=10000, 
                             best_model_save_path=save_path, 
                             verbose=1)


In [8]:
model = PPO('MlpPolicy', env, verbose = 1, tensorboard_log = log_path)


Using cpu device


In [9]:
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Logs\PPO_2
-----------------------------
| time/              |      |
|    fps             | 250  |
|    iterations      | 1    |
|    time_elapsed    | 8    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 310          |
|    iterations           | 2            |
|    time_elapsed         | 13           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0074364506 |
|    clip_fraction        | 0.0774       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.687       |
|    explained_variance   | -0.00213     |
|    learning_rate        | 0.0003       |
|    loss                 | 7.33         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.012       |
|    value_loss           | 53.1         |
-------------------------------------



Eval num_timesteps=10000, episode_reward=200.00 +/- 0.00
Episode length: 200.00 +/- 0.00
------------------------------------------
| eval/                   |              |
|    mean_ep_length       | 200          |
|    mean_reward          | 200          |
| time/                   |              |
|    total_timesteps      | 10000        |
| train/                  |              |
|    approx_kl            | 0.0073000733 |
|    clip_fraction        | 0.0683       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.608       |
|    explained_variance   | 0.214        |
|    learning_rate        | 0.0003       |
|    loss                 | 18.8         |
|    n_updates            | 40           |
|    policy_gradient_loss | -0.0162      |
|    value_loss           | 60           |
------------------------------------------
New best mean reward!
Stopping training because the mean reward 200.00  is above the threshold 190


<stable_baselines3.ppo.ppo.PPO at 0x1925fc05880>

In [10]:
model = PPO.load(model_path, env=env)

In [11]:
evaluate_policy(model, env, n_eval_episodes=5, render=True)

(200.0, 0.0)

In [12]:
env.close()