In [9]:
import os
import sys

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gymnasium import spaces
from rl_zoo3.train import train
from stable_baselines3 import PPO
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import minari
from minari import DataCollector


torch.manual_seed(42)

<torch._C.Generator at 0x10fc754f0>

In [10]:
# sys.argv = ["python", "--algo", "ppo", "--env", "CartPole-v1"]
# train()

In [11]:
# env = DataCollector(gym.make('CartPole-v1'))
# path = os.path.abspath('') + '/logs/ppo/CartPole-v1_1/best_model'
# agent = PPO.load(path)

# total_episodes = 1_000
# # total_episodes = 1

# for i in tqdm(range(total_episodes)):
#     obs, _ = env.reset(seed=42)
#     while True:
#         action, _ = agent.predict(obs)
#         obs, rew, terminated, truncated, info = env.step(action)

#         if terminated or truncated:
#             break

In [12]:
# dataset = env.create_dataset(
#     dataset_id="cartpole/expert-v0",
#     algorithm_name="ExpertPolicy",
#     code_permalink="https://minari.farama.org/tutorials/behavioral_cloning",
#     author="Farama",
#     author_email="contact@farama.org"
# )

In [13]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [14]:
def collate_fn(batch):
    return {
        "id": torch.Tensor([x.id for x in batch]),
        "observations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.observations) for x in batch],
            batch_first=True
        ),
        "actions": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.actions) for x in batch],
            batch_first=True
        ),
        "rewards": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.rewards) for x in batch],
            batch_first=True
        ),
        "terminations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.terminations) for x in batch],
            batch_first=True
        ),
        "truncations": torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(x.truncations) for x in batch],
            batch_first=True
        )
    }

In [None]:
import datetime

NUM_EPOCHS = 64
EVAL_EVERY_N_EPOCHS = 5
now = datetime.datetime.now()

minari_dataset = minari.load_dataset("cartpole/expert-v0")
dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)

env = minari_dataset.recover_environment()
observation_space = env.observation_space
action_space = env.action_space
assert isinstance(observation_space, spaces.Box)
assert isinstance(action_space, spaces.Discrete)

policy_net = PolicyNetwork(np.prod(observation_space.shape), action_space.n)
optimizer = torch.optim.Adam(policy_net.parameters())
loss_fn = nn.CrossEntropyLoss()

# Set up eval environment for periodic testing
eval_env = minari_dataset.recover_environment()

for epoch in range(NUM_EPOCHS):
    for batch in dataloader:
        a_pred = policy_net(batch['observations'][:, :-1])
        a_hat = F.one_hot(batch["actions"].type(torch.int64))
        loss = loss_fn(a_pred, a_hat.type(torch.float32))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch: {epoch}/{NUM_EPOCHS}, Loss: {loss.item()}")

    # Periodic evaluation
    if (epoch + 1) % EVAL_EVERY_N_EPOCHS == 0:
        obs, _ = eval_env.reset(seed=42)
        terminated, truncated = False, False
        accumulated_rew = 0
        while not (terminated or truncated):
            action = policy_net(torch.Tensor(obs)).argmax().item()
            obs, reward, terminated, truncated, _ = eval_env.step(action)
            accumulated_rew += reward
        print(f"Eval after epoch {epoch+1}: reward = {accumulated_rew}")


Epoch: 0/64, Loss: 1542.4703369140625
Epoch: 1/64, Loss: 1530.2423095703125
Epoch: 2/64, Loss: 1517.13037109375
Epoch: 3/64, Loss: 1503.80517578125
Epoch: 4/64, Loss: 1491.7354736328125
Eval after epoch 5: reward = 29.0
Epoch: 5/64, Loss: 1482.278076171875
Epoch: 6/64, Loss: 1475.798095703125
Epoch: 7/64, Loss: 1472.278564453125
Epoch: 8/64, Loss: 1469.7216796875
Epoch: 9/64, Loss: 1468.1065673828125
Eval after epoch 10: reward = 56.0
Epoch: 10/64, Loss: 1467.03759765625
Epoch: 11/64, Loss: 1465.8514404296875
Epoch: 12/64, Loss: 1464.6651611328125
Epoch: 13/64, Loss: 1463.3984375
Epoch: 14/64, Loss: 1462.1812744140625
Eval after epoch 15: reward = 98.0
Epoch: 15/64, Loss: 1461.126708984375
Epoch: 16/64, Loss: 1460.30712890625
Epoch: 17/64, Loss: 1459.1807861328125
Epoch: 18/64, Loss: 1458.323486328125
Epoch: 19/64, Loss: 1457.716552734375
Eval after epoch 20: reward = 123.0
Epoch: 20/64, Loss: 1456.4222412109375
Epoch: 21/64, Loss: 1455.5244140625
Epoch: 22/64, Loss: 1455.338623046875


In [None]:
import os
import imageio

# Set up directories and video saving vars
env_name = "CartPole-v1"
date_str = now.strftime("%Y-%m-%d")
time_str = now.strftime("%H%M%S")
base_dir = f"videos/offline/bc/{env_name}/{date_str}_{time_str}"
os.makedirs(base_dir, exist_ok=True)

video_path = os.path.join(
    base_dir,
    f"epoch={epoch+1}_reward={{REWARD_PLACEHOLDER}}.mp4"
)
# Note: reward is not known yet at time of instantiation, so we'll rename at the end

env = gym.make(
    "CartPole-v1",
    render_mode="rgb_array"
)

frames = []
obs, _ = env.reset(seed=42)
terminated, truncated = False, False
accumulated_rew = 0
while not (terminated or truncated):
    frame = env.render()
    frames.append(frame)
    action = policy_net(torch.Tensor(obs)).argmax()
    obs, reward, terminated, truncated, _ = env.step(action.numpy())
    accumulated_rew += reward

env.close()

final_path = os.path.join(
    base_dir,
    f"epoch={NUM_EPOCHS}_reward={int(accumulated_rew)}.mp4"
)
imageio.mimsave(final_path, frames, fps=30)

print("Accumulated rew: ", accumulated_rew)
print("Video saved to:\n", final_path)



Accumulated rew:  500.0
Video saved to: videos/offline/bc/CartPole-v1/2025-11-30_125131/epoch=64_reward=500.mp4
