[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/3_train_gail.ipynb)
# Train an Agent using Generative Adversarial Imitation Learning

The idea of generative adversarial imitation learning (GAIL) is to train a discriminator network to distinguish between expert trajectories and learner trajectories.
The learner is trained with a standard reinforcement learning algorithm such as PPO and is rewarded for producing trajectories that the discriminator mistakes for expert trajectories.
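
As a rough illustration of these two objectives, here is a minimal NumPy sketch. It assumes the discriminator outputs a logit that is high for expert-like (state, action) pairs, trains it with binary cross-entropy, and rewards the learner with `-log(1 - D(s, a))`. This is one common GAIL formulation, not necessarily the exact loss or reward that `imitation` computes internally, and the function names are hypothetical.

```python
import numpy as np


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)


def discriminator_loss(expert_logits, learner_logits):
    """Binary cross-entropy: expert samples are labelled 1, learner samples 0."""
    expert_term = -log_sigmoid(expert_logits)     # -log D(s, a) on expert data
    learner_term = -log_sigmoid(-learner_logits)  # -log(1 - D(s, a)) on learner data
    return np.mean(np.concatenate([expert_term, learner_term]))


def learner_reward(learner_logits):
    # -log(1 - D(s, a)) equals softplus(logit): larger when D thinks the pair is expert-like.
    return -log_sigmoid(-learner_logits)


undecided = np.zeros(4)  # a discriminator that outputs logit 0 everywhere
print(discriminator_loss(undecided, undecided))    # approx. 0.6931 (log 2)
print(learner_reward(np.array([-2.0, 0.0, 2.0])))  # increasing rewards
```

A discriminator that cannot tell the two sources apart has a loss of log 2 ≈ 0.693, which is roughly where `disc/disc_loss` (0.697) sits in the first round of the training log.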

As usual, we first need an expert. Again, we download one from the HuggingFace model hub for convenience.

Note that we use a variant of the CartPole environment from the seals package, which has fixed episode durations. Read more about why we do this [here](https://imitation.readthedocs.io/en/latest/main-concepts/variable_horizon.html).

```python
import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

SEED = 42

env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals/CartPole-v0",
    venv=env,
)
```

We generate some expert trajectories, which the discriminator must learn to distinguish from the learner's trajectories.
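
The discriminator scores individual transitions rather than whole episodes. The sketch below shows how episodes can be flattened into (observation, action) pairs and sampled in batches. The `Trajectory` container and the `flatten` and `sample_batch` helpers are hypothetical stand-ins (`imitation` has its own trajectory types and a `rollout.flatten_trajectories` helper), and the data is random placeholder data. Because seals CartPole episodes have a fixed length of 500 steps (`ep_len_mean` is 500 throughout the training log), 120 episodes yield at least 60,000 expert transitions.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Trajectory:
    """Hypothetical episode container; obs has one more row than acts."""
    obs: np.ndarray   # shape (T + 1, obs_dim), includes the final observation
    acts: np.ndarray  # shape (T,)


def flatten(trajectories):
    """Stack every (obs, act) pair from every episode into flat arrays."""
    obs = np.concatenate([t.obs[:-1] for t in trajectories])  # drop final obs
    acts = np.concatenate([t.acts for t in trajectories])
    return obs, acts


def sample_batch(obs, acts, batch_size, rng):
    """Draw a random minibatch of transitions without replacement."""
    idx = rng.choice(len(acts), size=batch_size, replace=False)
    return obs[idx], acts[idx]


# Random placeholder data shaped like 120 fixed-length CartPole episodes.
rng = np.random.default_rng(42)
horizon, n_episodes = 500, 120
trajs = [
    Trajectory(
        obs=rng.normal(size=(horizon + 1, 4)),
        acts=rng.integers(0, 2, size=horizon),
    )
    for _ in range(n_episodes)
]
obs, acts = flatten(trajs)
print(obs.shape, acts.shape)  # (60000, 4) (60000,)
batch_obs, batch_acts = sample_batch(obs, acts, 1024, rng)
```

The batch size of 1024 here matches the `demo_batch_size` passed to the GAIL trainer below.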

```python
from imitation.data import rollout

rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=120),
    rng=np.random.default_rng(SEED),
)
```

Now we are ready to set up our GAIL trainer.
Note that the `reward_net` is actually the discriminator network.
We evaluate the learner before and after training so we can see whether it made any progress.
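
Since the `reward_net` is the discriminator, it helps to picture what such a network computes. The NumPy sketch below loosely mirrors the configuration used next (`BasicRewardNet` with `normalize_input_layer=RunningNorm`): it one-hot encodes the discrete action, concatenates it with the observation, standardizes the result with running statistics, and maps it through a small ReLU MLP to one logit per (observation, action) pair. The class names, layer sizes, initialization and normalization update rule are assumptions for illustration, not the library's implementation, and the sketch omits training.

```python
import numpy as np


class RunningNormSketch:
    """Standardizes inputs with a running mean/variance (Chan et al. batch update)."""

    def __init__(self, dim, eps=1e-5):
        self.mean = np.zeros(dim)
        self.var = np.ones(dim)
        self.count = 0
        self.eps = eps

    def update(self, batch):
        b_mean, b_var, b_count = batch.mean(axis=0), batch.var(axis=0), len(batch)
        total = self.count + b_count
        delta = b_mean - self.mean
        # Combine the sums of squared deviations of the old data and the new batch.
        m2 = (self.var * self.count + b_var * b_count
              + delta**2 * self.count * b_count / total)
        self.mean = self.mean + delta * b_count / total
        self.var = m2 / total
        self.count = total

    def __call__(self, x):
        return (x - self.mean) / np.sqrt(self.var + self.eps)


class DiscriminatorSketch:
    """Maps (observation, action) pairs to one logit each; high means expert-like."""

    def __init__(self, obs_dim, n_actions, hidden=32, seed=0):
        rng = np.random.default_rng(seed)
        in_dim = obs_dim + n_actions
        self.n_actions = n_actions
        self.norm = RunningNormSketch(in_dim)
        self.w1 = rng.normal(scale=1 / np.sqrt(in_dim), size=(in_dim, hidden))
        self.b1 = np.zeros(hidden)
        self.w2 = rng.normal(scale=1 / np.sqrt(hidden), size=(hidden, 1))
        self.b2 = np.zeros(1)

    def features(self, obs, acts):
        # Concatenate observations with one-hot encoded discrete actions.
        return np.concatenate([obs, np.eye(self.n_actions)[acts]], axis=1)

    def logits(self, obs, acts, update_norm=False):
        x = self.features(obs, acts)
        if update_norm:  # refresh running statistics before normalizing
            self.norm.update(x)
        h = np.maximum(0.0, self.norm(x) @ self.w1 + self.b1)  # ReLU hidden layer
        return (h @ self.w2 + self.b2)[:, 0]


rng = np.random.default_rng(0)
obs = rng.normal(size=(256, 4))      # CartPole observations are 4-dimensional
acts = rng.integers(0, 2, size=256)  # CartPole has two discrete actions
disc = DiscriminatorSketch(obs_dim=4, n_actions=2)
print(disc.logits(obs, acts, update_norm=True).shape)  # (256,)
```

Normalizing inputs with running statistics keeps the discriminator's inputs on a comparable scale even though the learner's state distribution shifts as it improves.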

First we construct a GAIL trainer ...

```python
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

# The learner (generator) is an ordinary PPO agent; during GAIL training it
# optimizes a reward derived from the discriminator instead of the env reward.
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)
# The discriminator: its inputs are normalized with running statistics.
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,  # expert transitions per discriminator batch
    gen_replay_buffer_capacity=512,  # learner transitions kept for the discriminator
    n_disc_updates_per_round=8,  # discriminator updates per training round
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)
```

... then we evaluate the learner before training ...

```python
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)
```

... and train it ...

```python
gail_trainer.train(500_000)
```

round:   0%|          | 0/30 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 500      |
|    gen/rollout/ep_rew_mean  | 36.7     |
|    gen/time/fps             | 13495    |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 1        |
|    gen/time/total_timesteps | 16384    |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.497    |
|    disc/disc_acc_expert             | 0        |
|    disc/disc_acc_gen                | 0.993    |
|    disc/disc_entropy                | 0.69     |
|    disc/disc_loss                   | 0.697    |
|    disc/disc_proportion_expert_pred | 0.00342  |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
--------------------------------------------------

round:   3%|▎         | 1/30 [00:02<01:26,  2.97s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 500        |
|    gen/rollout/ep_rew_mean         | 36.2       |
|    gen/rollout/ep_rew_wrapped_mean | 282        |
|    gen/time/fps                    | 13930      |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 1          |
|    gen/time/total_timesteps        | 32768      |
|    gen/train/approx_kl             | 0.00884708 |
|    gen/train/clip_fraction         | 0.0362     |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -0.686     |
|    gen/train/explained_variance    | -0.0494    |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | 0.0889     |
|    gen/train/n_updates             | 5          |
|    gen/train/policy_gradient_loss  | -0.0012    |
|    gen/train/value_loss            | 4.92       |
---------------------------------------------------

round:   7%|▋         | 2/30 [00:05<01:22,  2.93s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 35.1         |
|    gen/rollout/ep_rew_wrapped_mean | 286          |
|    gen/time/fps                    | 13914        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 1            |
|    gen/time/total_timesteps        | 49152        |
|    gen/train/approx_kl             | 0.0052242763 |
|    gen/train/clip_fraction         | 0.0196       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.685       |
|    gen/train/explained_variance    | 0.749        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.0923       |
|    gen/train/n_updates             | 10           |
|    gen/train/policy_gradient_loss  | -0.00126     |
|    gen/train/value_loss   

round:  10%|█         | 3/30 [00:08<01:19,  2.93s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 35.4        |
|    gen/rollout/ep_rew_wrapped_mean | 283         |
|    gen/time/fps                    | 13899       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 65536       |
|    gen/train/approx_kl             | 0.009301297 |
|    gen/train/clip_fraction         | 0.103       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.678      |
|    gen/train/explained_variance    | 0.656       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.049       |
|    gen/train/n_updates             | 15          |
|    gen/train/policy_gradient_loss  | -0.00459    |
|    gen/train/value_loss            | 0.046  

round:  13%|█▎        | 4/30 [00:11<01:16,  2.95s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 38          |
|    gen/rollout/ep_rew_wrapped_mean | 276         |
|    gen/time/fps                    | 13921       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 81920       |
|    gen/train/approx_kl             | 0.013422859 |
|    gen/train/clip_fraction         | 0.175       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.669      |
|    gen/train/explained_variance    | 0.843       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.00533    |
|    gen/train/n_updates             | 20          |
|    gen/train/policy_gradient_loss  | -0.0091     |
|    gen/train/value_loss            | 0.019  

round:  17%|█▋        | 5/30 [00:14<01:13,  2.94s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 40.6        |
|    gen/rollout/ep_rew_wrapped_mean | 269         |
|    gen/time/fps                    | 14158       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 98304       |
|    gen/train/approx_kl             | 0.012339175 |
|    gen/train/clip_fraction         | 0.142       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.656      |
|    gen/train/explained_variance    | 0.874       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.00391    |
|    gen/train/n_updates             | 25          |
|    gen/train/policy_gradient_loss  | -0.0119     |
|    gen/train/value_loss            | 0.0215 

round:  20%|██        | 6/30 [00:17<01:10,  2.93s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 44.1        |
|    gen/rollout/ep_rew_wrapped_mean | 264         |
|    gen/time/fps                    | 14026       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 114688      |
|    gen/train/approx_kl             | 0.010689682 |
|    gen/train/clip_fraction         | 0.121       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.632      |
|    gen/train/explained_variance    | 0.919       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0241     |
|    gen/train/n_updates             | 30          |
|    gen/train/policy_gradient_loss  | -0.00919    |
|    gen/train/value_loss            | 0.0221 

round:  23%|██▎       | 7/30 [00:20<01:07,  2.92s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 43.9        |
|    gen/rollout/ep_rew_wrapped_mean | 260         |
|    gen/time/fps                    | 13843       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 131072      |
|    gen/train/approx_kl             | 0.009313423 |
|    gen/train/clip_fraction         | 0.0965      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.606      |
|    gen/train/explained_variance    | 0.931       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0156     |
|    gen/train/n_updates             | 35          |
|    gen/train/policy_gradient_loss  | -0.00512    |
|    gen/train/value_loss            | 0.0248 

round:  27%|██▋       | 8/30 [00:23<01:04,  2.92s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 46.1        |
|    gen/rollout/ep_rew_wrapped_mean | 251         |
|    gen/time/fps                    | 13929       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 147456      |
|    gen/train/approx_kl             | 0.009131025 |
|    gen/train/clip_fraction         | 0.102       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.569      |
|    gen/train/explained_variance    | 0.95        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.02       |
|    gen/train/n_updates             | 40          |
|    gen/train/policy_gradient_loss  | -0.00622    |
|    gen/train/value_loss            | 0.0279 

round:  30%|███       | 9/30 [00:26<01:01,  2.92s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 49.2        |
|    gen/rollout/ep_rew_wrapped_mean | 237         |
|    gen/time/fps                    | 14011       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 163840      |
|    gen/train/approx_kl             | 0.008577997 |
|    gen/train/clip_fraction         | 0.0753      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.574      |
|    gen/train/explained_variance    | 0.962       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.00466    |
|    gen/train/n_updates             | 45          |
|    gen/train/policy_gradient_loss  | -0.0035     |
|    gen/train/value_loss            | 0.027  

round:  33%|███▎      | 10/30 [00:29<00:58,  2.92s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 500        |
|    gen/rollout/ep_rew_mean         | 59.2       |
|    gen/rollout/ep_rew_wrapped_mean | 226        |
|    gen/time/fps                    | 13738      |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 1          |
|    gen/time/total_timesteps        | 180224     |
|    gen/train/approx_kl             | 0.00953524 |
|    gen/train/clip_fraction         | 0.118      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -0.587     |
|    gen/train/explained_variance    | 0.97       |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | -0.00182   |
|    gen/train/n_updates             | 50         |
|    gen/train/policy_gradient_loss  | -0.00683   |
|    gen/train/value_loss            | 0.0295     |
---------------------------------------------------

round:  37%|███▋      | 11/30 [00:32<00:55,  2.92s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 68.6        |
|    gen/rollout/ep_rew_wrapped_mean | 225         |
|    gen/time/fps                    | 13214       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 196608      |
|    gen/train/approx_kl             | 0.011665577 |
|    gen/train/clip_fraction         | 0.123       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.587      |
|    gen/train/explained_variance    | 0.972       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0138     |
|    gen/train/n_updates             | 55          |
|    gen/train/policy_gradient_loss  | -0.00691    |
|    gen/train/value_loss            | 0.0413 

round:  40%|████      | 12/30 [00:35<00:52,  2.94s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 84.6        |
|    gen/rollout/ep_rew_wrapped_mean | 226         |
|    gen/time/fps                    | 13874       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 212992      |
|    gen/train/approx_kl             | 0.010306148 |
|    gen/train/clip_fraction         | 0.136       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.591      |
|    gen/train/explained_variance    | 0.98        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0235     |
|    gen/train/n_updates             | 60          |
|    gen/train/policy_gradient_loss  | -0.0105     |
|    gen/train/value_loss            | 0.053  

round:  43%|████▎     | 13/30 [00:38<00:49,  2.94s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 97.9         |
|    gen/rollout/ep_rew_wrapped_mean | 223          |
|    gen/time/fps                    | 13967        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 1            |
|    gen/time/total_timesteps        | 229376       |
|    gen/train/approx_kl             | 0.0134653095 |
|    gen/train/clip_fraction         | 0.189        |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.578       |
|    gen/train/explained_variance    | 0.983        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.0122       |
|    gen/train/n_updates             | 65           |
|    gen/train/policy_gradient_loss  | -0.0182      |
|    gen/train/value_loss   

round:  47%|████▋     | 14/30 [00:41<00:46,  2.92s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 111         |
|    gen/rollout/ep_rew_wrapped_mean | 219         |
|    gen/time/fps                    | 13950       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 245760      |
|    gen/train/approx_kl             | 0.013135181 |
|    gen/train/clip_fraction         | 0.184       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.565      |
|    gen/train/explained_variance    | 0.987       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.00456     |
|    gen/train/n_updates             | 70          |
|    gen/train/policy_gradient_loss  | -0.0165     |
|    gen/train/value_loss            | 0.0704 

round:  50%|█████     | 15/30 [00:43<00:43,  2.92s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 121         |
|    gen/rollout/ep_rew_wrapped_mean | 219         |
|    gen/time/fps                    | 13556       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 262144      |
|    gen/train/approx_kl             | 0.010006411 |
|    gen/train/clip_fraction         | 0.116       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.547      |
|    gen/train/explained_variance    | 0.987       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.00337    |
|    gen/train/n_updates             | 75          |
|    gen/train/policy_gradient_loss  | -0.0104     |
|    gen/train/value_loss            | 0.0583 

round:  53%|█████▎    | 16/30 [00:46<00:41,  2.93s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 132         |
|    gen/rollout/ep_rew_wrapped_mean | 229         |
|    gen/time/fps                    | 13761       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 278528      |
|    gen/train/approx_kl             | 0.007364448 |
|    gen/train/clip_fraction         | 0.0785      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.531      |
|    gen/train/explained_variance    | 0.986       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0284      |
|    gen/train/n_updates             | 80          |
|    gen/train/policy_gradient_loss  | -0.00645    |
|    gen/train/value_loss            | 0.0524 

round:  57%|█████▋    | 17/30 [00:49<00:38,  2.93s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 143          |
|    gen/rollout/ep_rew_wrapped_mean | 237          |
|    gen/time/fps                    | 13778        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 1            |
|    gen/time/total_timesteps        | 294912       |
|    gen/train/approx_kl             | 0.0077552325 |
|    gen/train/clip_fraction         | 0.0792       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.515       |
|    gen/train/explained_variance    | 0.984        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.0539       |
|    gen/train/n_updates             | 85           |
|    gen/train/policy_gradient_loss  | -0.00724     |
|    gen/train/value_loss   

round:  60%|██████    | 18/30 [00:52<00:35,  2.95s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 156          |
|    gen/rollout/ep_rew_wrapped_mean | 244          |
|    gen/time/fps                    | 13909        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 1            |
|    gen/time/total_timesteps        | 311296       |
|    gen/train/approx_kl             | 0.0066608293 |
|    gen/train/clip_fraction         | 0.0731       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.5         |
|    gen/train/explained_variance    | 0.988        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.0133       |
|    gen/train/n_updates             | 90           |
|    gen/train/policy_gradient_loss  | -0.00436     |
|    gen/train/value_loss   

round:  63%|██████▎   | 19/30 [00:55<00:32,  2.95s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 172         |
|    gen/rollout/ep_rew_wrapped_mean | 246         |
|    gen/time/fps                    | 13877       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 327680      |
|    gen/train/approx_kl             | 0.007317047 |
|    gen/train/clip_fraction         | 0.0761      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.477      |
|    gen/train/explained_variance    | 0.988       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0328      |
|    gen/train/n_updates             | 95          |
|    gen/train/policy_gradient_loss  | -0.0047     |
|    gen/train/value_loss            | 0.06   

round:  67%|██████▋   | 20/30 [00:58<00:29,  2.94s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 199         |
|    gen/rollout/ep_rew_wrapped_mean | 251         |
|    gen/time/fps                    | 13728       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 344064      |
|    gen/train/approx_kl             | 0.008654136 |
|    gen/train/clip_fraction         | 0.0933      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.455      |
|    gen/train/explained_variance    | 0.992       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0378      |
|    gen/train/n_updates             | 100         |
|    gen/train/policy_gradient_loss  | -0.00585    |
|    gen/train/value_loss            | 0.0675 

round:  70%|███████   | 21/30 [01:01<00:26,  2.93s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 218         |
|    gen/rollout/ep_rew_wrapped_mean | 256         |
|    gen/time/fps                    | 13539       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 360448      |
|    gen/train/approx_kl             | 0.008946661 |
|    gen/train/clip_fraction         | 0.103       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.424      |
|    gen/train/explained_variance    | 0.989       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0222      |
|    gen/train/n_updates             | 105         |
|    gen/train/policy_gradient_loss  | -0.00514    |
|    gen/train/value_loss            | 0.071  

round:  73%|███████▎  | 22/30 [01:04<00:23,  2.94s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 228         |
|    gen/rollout/ep_rew_wrapped_mean | 252         |
|    gen/time/fps                    | 13895       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 376832      |
|    gen/train/approx_kl             | 0.007514088 |
|    gen/train/clip_fraction         | 0.0906      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.416      |
|    gen/train/explained_variance    | 0.987       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0118      |
|    gen/train/n_updates             | 110         |
|    gen/train/policy_gradient_loss  | -0.0034     |
|    gen/train/value_loss            | 0.0659 

round:  77%|███████▋  | 23/30 [01:07<00:20,  2.93s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 245         |
|    gen/rollout/ep_rew_wrapped_mean | 235         |
|    gen/time/fps                    | 13565       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 393216      |
|    gen/train/approx_kl             | 0.012193836 |
|    gen/train/clip_fraction         | 0.113       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.42       |
|    gen/train/explained_variance    | 0.98        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0178      |
|    gen/train/n_updates             | 115         |
|    gen/train/policy_gradient_loss  | -0.00309    |
|    gen/train/value_loss            | 0.0749 

round:  80%|████████  | 24/30 [01:10<00:17,  2.93s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 260         |
|    gen/rollout/ep_rew_wrapped_mean | 224         |
|    gen/time/fps                    | 13797       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 409600      |
|    gen/train/approx_kl             | 0.010835385 |
|    gen/train/clip_fraction         | 0.106       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.405      |
|    gen/train/explained_variance    | 0.991       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.036       |
|    gen/train/n_updates             | 120         |
|    gen/train/policy_gradient_loss  | -0.00554    |
|    gen/train/value_loss            | 0.0559 

round:  83%|████████▎ | 25/30 [01:13<00:14,  2.92s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 288          |
|    gen/rollout/ep_rew_wrapped_mean | 229          |
|    gen/time/fps                    | 13888        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 1            |
|    gen/time/total_timesteps        | 425984       |
|    gen/train/approx_kl             | 0.0094377035 |
|    gen/train/clip_fraction         | 0.0903       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.407       |
|    gen/train/explained_variance    | 0.988        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.00805      |
|    gen/train/n_updates             | 125          |
|    gen/train/policy_gradient_loss  | -0.00168     |
|    gen/train/value_loss   

round:  87%|████████▋ | 26/30 [01:16<00:11,  2.92s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 304         |
|    gen/rollout/ep_rew_wrapped_mean | 258         |
|    gen/time/fps                    | 13384       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 442368      |
|    gen/train/approx_kl             | 0.008577564 |
|    gen/train/clip_fraction         | 0.0821      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.392      |
|    gen/train/explained_variance    | 0.989       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0139      |
|    gen/train/n_updates             | 130         |
|    gen/train/policy_gradient_loss  | -0.00326    |
|    gen/train/value_loss            | 0.0671 

round:  90%|█████████ | 27/30 [01:19<00:08,  2.93s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 329         |
|    gen/rollout/ep_rew_wrapped_mean | 277         |
|    gen/time/fps                    | 13905       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 1           |
|    gen/time/total_timesteps        | 458752      |
|    gen/train/approx_kl             | 0.009542388 |
|    gen/train/clip_fraction         | 0.105       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.374      |
|    gen/train/explained_variance    | 0.994       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0109     |
|    gen/train/n_updates             | 135         |
|    gen/train/policy_gradient_loss  | -0.00766    |
|    gen/train/value_loss            | 0.091  

round:  93%|█████████▎| 28/30 [01:22<00:05,  2.92s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 362          |
|    gen/rollout/ep_rew_wrapped_mean | 300          |
|    gen/time/fps                    | 14072        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 1            |
|    gen/time/total_timesteps        | 475136       |
|    gen/train/approx_kl             | 0.0100263115 |
|    gen/train/clip_fraction         | 0.12         |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.352       |
|    gen/train/explained_variance    | 0.994        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.0199       |
|    gen/train/n_updates             | 140          |
|    gen/train/policy_gradient_loss  | -0.00725     |
|    gen/train/value_loss   

round:  97%|█████████▋| 29/30 [01:24<00:02,  2.90s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 413          |
|    gen/rollout/ep_rew_wrapped_mean | 315          |
|    gen/time/fps                    | 13692        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 1            |
|    gen/time/total_timesteps        | 491520       |
|    gen/train/approx_kl             | 0.0072787697 |
|    gen/train/clip_fraction         | 0.095        |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.318       |
|    gen/train/explained_variance    | 0.992        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.0271       |
|    gen/train/n_updates             | 145          |
|    gen/train/policy_gradient_loss  | -0.00634     |
|    gen/train/value_loss   

round: 100%|██████████| 30/30 [01:27<00:00,  2.93s/it]
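
The progress bar counts 30 rounds, and `gen/time/total_timesteps` grows by 16,384 per round. Here is a back-of-the-envelope check of those numbers. It assumes that each round collects `n_steps * n_envs` learner timesteps, where `n_steps=2048` is Stable-Baselines3's PPO default (the notebook does not set it), and that `train` runs `total_timesteps // timesteps_per_round` whole rounds; both assumptions are inferred from the log above rather than taken from the library source.

```python
n_steps = 2048             # SB3 PPO default; the notebook does not override it
n_envs = 8                 # make_vec_env(..., n_envs=8)
total_timesteps = 500_000  # gail_trainer.train(500_000)

timesteps_per_round = n_steps * n_envs
n_rounds = total_timesteps // timesteps_per_round
print(timesteps_per_round, n_rounds, n_rounds * timesteps_per_round)  # 16384 30 491520
```

Likewise, `gen/train/n_updates` grows by 5 per round, matching the `n_epochs=5` passed to PPO.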


... and finally evaluate it again.

```python
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)
```

We can see that the untrained policy performs poorly, while after GAIL training the learner comes close to the expert's return of 500:

```python
print(
    "Rewards before training:",
    np.mean(learner_rewards_before_training),
    "+/-",
    np.std(learner_rewards_before_training),
)
print(
    "Rewards after training:",
    np.mean(learner_rewards_after_training),
    "+/-",
    np.std(learner_rewards_after_training),
)
```

```text
Rewards before training: 102.6 +/- 24.11514047232568
Rewards after training: 487.65 +/- 24.725442362069074
```
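
To put the two results on a common scale, one option is a normalized score that maps the untrained learner's mean return to 0 and the expert's return of 500 to 1. The sketch below is an illustration using the printed means; the `normalized_score` helper is hypothetical and not something the notebook or `imitation` computes.

```python
import numpy as np


def normalized_score(returns, baseline, expert):
    """Map the mean return onto a scale where baseline -> 0 and expert -> 1."""
    return (np.mean(returns) - baseline) / (expert - baseline)


baseline_mean = 102.6  # "Rewards before training" printed above
expert_return = 500.0  # expert return quoted in the text above
after_mean = 487.65    # "Rewards after training" printed above

print(round(normalized_score([after_mean], baseline_mean, expert_return), 3))  # 0.969
```

With the notebook's variables in scope, `normalized_score(learner_rewards_after_training, np.mean(learner_rewards_before_training), 500.0)` computes the same quantity from the raw episode returns.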
