In [3]:
import gym
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

env = gym.make("CartPole-v1")
expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
expert.learn(1000)  # Note: set to 100000 to train a proficient expert

<stable_baselines3.ppo.ppo.PPO at 0x1ff3da553c0>

In [4]:
import tempfile
import gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms import bc
from imitation.algorithms.dagger import SimpleDAggerTrainer

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])


bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    rng=np.random.default_rng(),
)

with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
    print(tmpdir)
    dagger_trainer = SimpleDAggerTrainer(
        venv=venv,
        scratch_dir=tmpdir,
        expert_policy=expert,
        bc_trainer=bc_trainer,
        rng=np.random.default_rng(),
    )

    dagger_trainer.train(2000)

C:\Users\mnl\AppData\Local\Temp\dagger_example_14b2473b


0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000693 |
|    entropy        | 0.693     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 36.5      |
|    loss           | 0.693     |
|    neglogp        | 0.694     |
|    prob_true_act  | 0.5       |
|    samples_so_far | 32        |
| rollout/          |           |
|    return_max     | 43        |
|    return_mean    | 20.6      |
|    return_min     | 9         |
|    return_std     | 11.8      |
---------------------------------


64batch [00:00, 157.00batch/s]
0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000631 |
|    entropy        | 0.631     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 41        |
|    loss           | 0.503     |
|    neglogp        | 0.504     |
|    prob_true_act  | 0.618     |
|    samples_so_far | 32        |
| rollout/          |           |
|    return_max     | 149       |
|    return_mean    | 70        |
|    return_min     | 38        |
|    return_std     | 40.7      |
---------------------------------


128batch [00:00, 206.32batch/s]
0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000186 |
|    entropy        | 0.186     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 57.4      |
|    loss           | 0.0982    |
|    neglogp        | 0.0984    |
|    prob_true_act  | 0.919     |
|    samples_so_far | 32        |
| rollout/          |           |
|    return_max     | 71        |
|    return_mean    | 61.6      |
|    return_min     | 45        |
|    return_std     | 9.2       |
---------------------------------


192batch [00:00, 234.79batch/s]
0batch [00:00, ?batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 0        |
|    ent_loss       | -8.8e-05 |
|    entropy        | 0.088    |
|    epoch          | 0        |
|    l2_loss        | 0        |
|    l2_norm        | 72.4     |
|    loss           | 0.0634   |
|    neglogp        | 0.0635   |
|    prob_true_act  | 0.951    |
|    samples_so_far | 32       |
| rollout/          |          |
|    return_max     | 107      |
|    return_mean    | 70.4     |
|    return_min     | 42       |
|    return_std     | 23.9     |
--------------------------------


260batch [00:00, 270.64batch/s]


In [5]:
from stable_baselines3.common.evaluation import evaluate_policy

reward, _ = evaluate_policy(dagger_trainer.policy, env, 10)
print(reward)



63.3
