In [9]:
import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder

env = make_vec_env(
    "HalfCheetah-v4",
    rng=np.random.default_rng(),
    n_envs=1,
    env_make_kwargs={"render_mode": "rgb_array"},
)

In [10]:
expert = load_policy(
    "sac-huggingface",
    organization="sb3",
    env_name="HalfCheetah-v3",
    venv=env,
)

sac-HalfCheetah-v3.zip:   0%|          | 0.00/3.24M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [11]:
import imageio
images = []
obs = env.reset()
print(env.render_mode)
done = False
img = env.render()
while not done:
    images.append(img)
    action, _ = expert.predict(obs)
    obs, reward, done, info = env.step(action)
    img = env.render()
    
print(len(images))


rgb_array
1000


In [12]:
from datetime import datetime
imageio.mimsave(f'src/videos/imitation_render_{datetime.now().strftime("%d_%m_%Y_%H_%M")}.gif', images)

In [15]:
import tempfile

from imitation.algorithms import bc
from imitation.algorithms.dagger import SimpleDAggerTrainer

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    rng=np.random.default_rng(),
)

with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
    print(tmpdir)
    dagger_trainer = SimpleDAggerTrainer(
        venv=env,
        scratch_dir=tmpdir,
        expert_policy=expert,
        bc_trainer=bc_trainer,
        rng=np.random.default_rng(),
    )

    dagger_trainer.train(10000)

C:\Users\xavier\AppData\Local\Temp\dagger_example_futb_a31


Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

0batch [00:00, ?batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 0        |
|    ent_loss       | -0.00851 |
|    entropy        | 8.51     |
|    epoch          | 0        |
|    l2_loss        | 0        |
|    l2_norm        | 98.5     |
|    loss           | 7.43     |
|    neglogp        | 7.44     |
|    prob_true_act  | 0.000627 |
|    samples_so_far | 32       |
| rollout/          |          |
|    return_max     | -244     |
|    return_mean    | -359     |
|    return_min     | -437     |
|    return_std     | 64.1     |
--------------------------------


79batch [00:04, 40.43batch/s]
176batch [00:05, 129.96batch/s]A
269batch [00:05, 194.34batch/s][A
366batch [00:05, 226.69batch/s][A
372batch [00:06, 61.81batch/s] [A


Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

0batch [00:00, ?batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 0        |
|    ent_loss       | -0.00599 |
|    entropy        | 5.99     |
|    epoch          | 0        |
|    l2_loss        | 0        |
|    l2_norm        | 110      |
|    loss           | 3.75     |
|    neglogp        | 3.76     |
|    prob_true_act  | 0.0325   |
|    samples_so_far | 32       |
| rollout/          |          |
|    return_max     | -236     |
|    return_mean    | -296     |
|    return_min     | -342     |
|    return_std     | 39.6     |
--------------------------------


168batch [00:03, 131.08batch/s]
357batch [00:04, 224.99batch/s][A
479batch [00:05, 228.57batch/s][A

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 500      |
|    ent_loss       | -0.0034  |
|    entropy        | 3.4      |
|    epoch          | 2        |
|    l2_loss        | 0        |
|    l2_norm        | 124      |
|    loss           | 1.25     |
|    neglogp        | 1.25     |
|    prob_true_act  | 0.396    |
|    samples_so_far | 16032    |
| rollout/          |          |
|    return_max     | 84.4     |
|    return_mean    | -140     |
|    return_min     | -401     |
|    return_std     | 166      |
--------------------------------


551batch [00:08, 44.56batch/s] 
745batch [00:09, 180.80batch/s][A
748batch [00:09, 80.70batch/s] [A


Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

0batch [00:00, ?batch/s]

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 0        |
|    ent_loss       | -0.00235 |
|    entropy        | 2.35     |
|    epoch          | 0        |
|    l2_loss        | 0        |
|    l2_norm        | 130      |
|    loss           | 1.15     |
|    neglogp        | 1.15     |
|    prob_true_act  | 1        |
|    samples_so_far | 32       |
| rollout/          |          |
|    return_max     | 619      |
|    return_mean    | 175      |
|    return_min     | -252     |
|    return_std     | 297      |
--------------------------------


264batch [00:04, 200.42batch/s]
480batch [00:05, 228.70batch/s][A

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 500      |
|    ent_loss       | -0.00128 |
|    entropy        | 1.28     |
|    epoch          | 1        |
|    l2_loss        | 0        |
|    l2_norm        | 139      |
|    loss           | 0.432    |
|    neglogp        | 0.434    |
|    prob_true_act  | 2.51     |
|    samples_so_far | 16032    |
| rollout/          |          |
|    return_max     | 591      |
|    return_mean    | 217      |
|    return_min     | -229     |
|    return_std     | 283      |
--------------------------------


543batch [00:08, 36.15batch/s] 
843batch [00:10, 191.37batch/s][A
982batch [00:11, 187.67batch/s][A

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 1000      |
|    ent_loss       | -0.000847 |
|    entropy        | 0.847     |
|    epoch          | 3         |
|    l2_loss        | 0         |
|    l2_norm        | 147       |
|    loss           | -0.6      |
|    neglogp        | -0.599    |
|    prob_true_act  | 4.34      |
|    samples_so_far | 32032     |
| rollout/          |           |
|    return_max     | 495       |
|    return_mean    | 197       |
|    return_min     | -83.2     |
|    return_std     | 186       |
---------------------------------


1107batch [00:15, 75.54batch/s]
1124batch [00:15, 73.36batch/s][A


Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1 [00:00<?, ? examples/s]

0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000732 |
|    entropy        | 0.732     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 149       |
|    loss           | -0.861    |
|    neglogp        | -0.86     |
|    prob_true_act  | 4.73      |
|    samples_so_far | 32        |
| rollout/          |           |
|    return_max     | 764       |
|    return_mean    | 449       |
|    return_min     | -32.3     |
|    return_std     | 268       |
---------------------------------


360batch [00:05, 184.21batch/s]
493batch [00:06, 183.96batch/s][A

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 500       |
|    ent_loss       | -0.000849 |
|    entropy        | 0.849     |
|    epoch          | 1         |
|    l2_loss        | 0         |
|    l2_norm        | 154       |
|    loss           | -0.216    |
|    neglogp        | -0.215    |
|    prob_true_act  | 4.12      |
|    samples_so_far | 16032     |
| rollout/          |           |
|    return_max     | 867       |
|    return_mean    | 452       |
|    return_min     | 164       |
|    return_std     | 234       |
---------------------------------


734batch [00:10, 168.12batch/s]
981batch [00:11, 197.27batch/s][A

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 1000     |
|    ent_loss       | -0.00086 |
|    entropy        | 0.86     |
|    epoch          | 2        |
|    l2_loss        | 0        |
|    l2_norm        | 158      |
|    loss           | 1.01     |
|    neglogp        | 1.01     |
|    prob_true_act  | 3.67     |
|    samples_so_far | 32032    |
| rollout/          |          |
|    return_max     | 919      |
|    return_mean    | 376      |
|    return_min     | -83.6    |
|    return_std     | 333      |
--------------------------------


1119batch [00:15, 90.93batch/s]
1497batch [00:17, 192.63batch/s][A
1500batch [00:17, 84.41batch/s] [A


In [16]:
from stable_baselines3.common.evaluation import evaluate_policy

reward, _ = evaluate_policy(dagger_trainer.policy, env, 20)
print(reward)

152.50439255


In [18]:
import imageio
images_trainer = []
obs = env.reset()
print(env.render_mode)
done = False
img = env.render()
while not done:
    images_trainer.append(img)
    action, _ = dagger_trainer.policy.predict(obs)
    obs, reward, done, info = env.step(action)
    img = env.render()
    
print(len(images_trainer))

rgb_array
1000


In [19]:
from datetime import datetime
imageio.mimsave(f'src/videos/dagger_trainer_{datetime.now().strftime("%d_%m_%Y_%H_%M")}.gif', images_trainer)