In [3]:
import gym

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

import optuna
from optuna.samplers import TPESampler

import wandb
from wandb.integration.sb3 import WandbCallback

### Creating Gym envs

In [2]:
env = make_vec_env("LunarLander-v2", n_envs=16)

In [3]:
eval_env = Monitor(gym.make("LunarLander-v2"))

### Hyperparameter Tuning

In [4]:
def run_training(params, verbose=0, save_model=False):
    model = PPO(
        policy="MlpPolicy",
        env=env,
        n_epochs=params["n_epochs"],
        gamma=params["gamma"],
        learning_rate=params["learning_rate"],
        verbose=verbose
    )
    model.learn(total_timesteps=params["total_timesteps"])
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=50)
    if save_model:
        model.save('PPO-LunarLander-v2')
    return model, mean_reward-std_reward

In [5]:
def objective(trial):
    params = {
        "n_epochs": trial.suggest_int("n_epochs", 1, 10),
        "gamma": trial.suggest_float("gamma", 0.99, 0.9999),
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-1),
        "total_timesteps": trial.suggest_int("total_timesteps", 5e5, 2e6)
    }

    config = dict(trial.params)
    config['trial.number'] = trial.number
    config['policy'] = 'MlpPolicy'
    run = wandb.init(
        project="ppo-lunarlander",
        config=config,
        sync_tensorboard=True,
        monitor_gym=True,
        reinit=True
    )

    model, reward = run_training(params, verbose=0)
    wandb.log({"reward": reward})
    run.finish(quiet=True)
    return reward

In [6]:
study = optuna.create_study(sampler=TPESampler(seed=42), study_name="PPO-LunarLander-v2", direction="maximize")
study.optimize(objective, n_trials=1)

[32m[I 2022-09-18 12:25:21,307][0m A new study created in memory with name: PPO-LunarLander-v2[0m
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrram12[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 13:09:56,970][0m Trial 0 finished with value: -735.820986247044 and parameters: {'n_epochs': 4, 'gamma': 0.9994120716334581, 'learning_rate': 0.0732020742417224, 'total_timesteps': 1397988}. Best is trial 0 with value: -735.820986247044.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.03334027926127116, max=1.0)…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 14:49:23,732][0m Trial 1 finished with value: 255.2893621331975 and parameters: {'n_epochs': 2, 'gamma': 0.9915443457513284, 'learning_rate': 0.005817780380698264, 'total_timesteps': 1799265}. Best is trial 1 with value: 255.2893621331975.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.03334172566731771, max=1.0)…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 17:19:50,034][0m Trial 2 finished with value: 260.40295923902914 and parameters: {'n_epochs': 7, 'gamma': 0.9970099185201808, 'learning_rate': 0.002068243584637287, 'total_timesteps': 1954865}. Best is trial 2 with value: 260.40295923902914.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.033339190483093264, max=1.0…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 18:40:35,759][0m Trial 3 finished with value: -1262.700103020085 and parameters: {'n_epochs': 9, 'gamma': 0.9921021571957149, 'learning_rate': 0.01819067847103799, 'total_timesteps': 775106}. Best is trial 2 with value: 260.40295923902914.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0333403746287028, max=1.0))…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 19:01:50,826][0m Trial 4 finished with value: -190.18300451548194 and parameters: {'n_epochs': 4, 'gamma': 0.9951950886731592, 'learning_rate': 0.043200182414025165, 'total_timesteps': 936844}. Best is trial 2 with value: 260.40295923902914.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.03333714803059896, max=1.0)…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 19:27:48,800][0m Trial 5 finished with value: -742.5212213003574 and parameters: {'n_epochs': 7, 'gamma': 0.9913809892204553, 'learning_rate': 0.029221543407036466, 'total_timesteps': 1049543}. Best is trial 2 with value: 260.40295923902914.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.033340589205423994, max=1.0…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 19:48:55,335][0m Trial 6 finished with value: -742.0119162600843 and parameters: {'n_epochs': 5, 'gamma': 0.9977732420177908, 'learning_rate': 0.019975381478014392, 'total_timesteps': 1271352}. Best is trial 2 with value: 260.40295923902914.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.03333758513132731, max=1.0)…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 20:05:45,914][0m Trial 7 finished with value: -802.7476495720744 and parameters: {'n_epochs': 6, 'gamma': 0.9904598590859279, 'learning_rate': 0.06075840974162483, 'total_timesteps': 755786}. Best is trial 2 with value: 260.40295923902914.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.03333817323048909, max=1.0)…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 20:19:05,963][0m Trial 8 finished with value: -1119.5961525674804 and parameters: {'n_epochs': 1, 'gamma': 0.999393966818808, 'learning_rate': 0.0965635469871252, 'total_timesteps': 1712596}. Best is trial 2 with value: 260.40295923902914.[0m


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.03333842754364014, max=1.0)…

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[32m[I 2022-09-18 20:33:08,923][0m Trial 9 finished with value: -166.51402277765695 and parameters: {'n_epochs': 4, 'gamma': 0.9909669539286632, 'learning_rate': 0.06842646032095057, 'total_timesteps': 1160229}. Best is trial 2 with value: 260.40295923902914.[0m


In [7]:
print("Best trial score: ", study.best_trial.values)
print("Best trial params: ", study.best_trial.params)

Best trial score:  [260.40295923902914]
Best trial params:  {'n_epochs': 7, 'gamma': 0.9970099185201808, 'learning_rate': 0.002068243584637287, 'total_timesteps': 1954865}


### Recreating and Saving the Best Model

In [8]:
model, score = run_training(study.best_trial.params, verbose=1, save_model=True)

Using cuda device
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 97.2     |
|    ep_rew_mean     | -172     |
| time/              |          |
|    fps             | 4967     |
|    iterations      | 1        |
|    time_elapsed    | 6        |
|    total_timesteps | 32768    |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 98          |
|    ep_rew_mean          | -133        |
| time/                   |             |
|    fps                  | 1495        |
|    iterations           | 2           |
|    time_elapsed         | 43          |
|    total_timesteps      | 65536       |
| train/                  |             |
|    approx_kl            | 0.009259807 |
|    clip_fraction        | 0.121       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.38       |
|    explained_variance   | -0.00321    |
|    learnin