# Meta Motivo: online training tutorial
This notebook showcases how to use the library to train an FB-CPR agent. It is not designed to exactly reproduce the results in the paper.

In [None]:
from __future__ import annotations
import torch

# Set PyTorch's float32 matrix-multiplication precision mode to "high" for the rest of the notebook.
torch.set_float32_matmul_precision("high")

import gymnasium
import numpy as np
import dataclasses
from humenv import make_humenv
from humenv.bench.gym_utils.rollouts import rollout
import mediapy as media
from metamotivo.buffers.buffers import DictBuffer, TrajectoryBuffer
from metamotivo.fb_cpr import FBcprAgent, FBcprAgentConfig
from tqdm.notebook import trange, tqdm
import time
from gymnasium import ObservationWrapper

from packaging.version import Version

# Fail fast on gymnasium >= 1.0: this tutorial targets earlier releases and
# defines its own TimeAwareObservation wrapper below for that reason.
if Version(gymnasium.__version__) >= Version("1.0"):
    raise RuntimeError("This tutorial does not support yet gymnasium >= 1.0")

We need to provide the agent with the time step within an episode, since it is used to decide when to switch policy (i.e., the embedding `z`) during a rollout of the online training. Gymnasium >= 1.0 provides this wrapper, but here we include a simpler version for compatibility with earlier versions.

In [None]:
class TimeAwareObservation(ObservationWrapper):
    """Add the current episode time step to every observation.

    Observations returned by ``reset`` and ``step`` become a dict
    ``{"obs": <wrapped observation>, "time": <int32 array of shape (1,)>}``.
    The counter is set to 0 on ``reset`` and incremented on each ``step``
    before the observation is built, so the first observation after a reset
    carries ``time == 0``.

    This is a simplified version of the wrapper shipped with
    gymnasium >= 1.0: wrapped environments whose observation space is a
    ``Dict`` or ``Tuple`` are not supported.

    Raises:
        ValueError: if ``env.spec`` or ``env.spec.max_episode_steps`` is not
            set (it is needed as the upper bound of the ``time`` space).
        AssertionError: if the wrapped observation space is a ``Dict`` or
            ``Tuple`` space.

    The MIT License

    Copyright (c) 2016 OpenAI
    Copyright (c) 2022 Farama Foundation
    """

    def __init__(self, env):
        super().__init__(env)

        # The episode length bounds the `time` space; without it the Box
        # below cannot be built, so report the problem explicitly.
        spec = getattr(env, "spec", None)
        max_timesteps = None if spec is None else spec.max_episode_steps
        if max_timesteps is None:
            raise ValueError(
                "TimeAwareObservation requires `env.spec.max_episode_steps` to be set; "
                "it is used as the upper bound of the `time` observation space."
            )
        self.max_timesteps = max_timesteps
        self.timesteps: int = 0

        # The raw integer counter is exposed as a one-element int32 array.
        self._time_preprocess_func = lambda time: np.array([time], dtype=np.int32)
        time_space = gymnasium.spaces.Box(0, self.max_timesteps, dtype=np.int32)

        assert not isinstance(
            env.observation_space, (gymnasium.spaces.Dict, gymnasium.spaces.Tuple)
        ), "TimeAwareObservation does not support Dict or Tuple observation spaces"

        observation_space = gymnasium.spaces.Dict(
            obs=env.observation_space, time=time_space
        )
        self._append_data_func = lambda obs, time: {"obs": obs, "time": time}
        self.observation_space = observation_space
        # Identity post-processing: the dict observation is returned as is.
        self._obs_postprocess_func = lambda obs: obs

    def observation(self, observation):
        """Return ``{"obs": observation, "time": current time step}``."""
        return self._obs_postprocess_func(
            self._append_data_func(
                observation, self._time_preprocess_func(self.timesteps)
            )
        )

    def step(self, action):
        """Advance the time counter, then step the wrapped environment."""
        # Increment first so the returned observation carries the new step index.
        self.timesteps += 1
        return super().step(action)

    def reset(self, *, seed=None, options=None):
        """Reset the time counter to 0, then reset the wrapped environment."""
        self.timesteps = 0
        return super().reset(seed=seed, options=options)

## Agent and Train parameters

We start by defining the parameters of the FB-CPR agent.

In [None]:
# Environment created with num_envs=1; it provides the observation/action
# dimensions for the agent configuration below.
env, _ = make_humenv(
    num_envs=1,
    vectorization_mode="sync",
    wrappers=[gymnasium.wrappers.FlattenObservation],
    render_width=320,
    render_height=320,
)

agent_config = FBcprAgentConfig()

# Model settings: dimensions are read from the environment's spaces.
model_cfg = agent_config.model
model_cfg.obs_dim = env.observation_space.shape[0]
model_cfg.action_dim = env.action_space.shape[0]
model_cfg.device = "cpu"
model_cfg.norm_obs = True
model_cfg.seq_length = 1

# Miscellaneous training settings.
agent_config.train.discount = 0.98
agent_config.compile = False
agent_config.cudagraphs = False

# The agent is built from the plain-dict form of the config dataclass.
agent_kwargs = dataclasses.asdict(agent_config)
agent = FBcprAgent(**agent_kwargs)

We also define a few parameters for online training.

In [None]:
# Capacity passed to the online replay buffer (`DictBuffer`).
buffer_size = 1_000_000
# Number of parallel training environments; also the stride of the env-step counter `t`.
online_parallel_envs = 5
# Metrics are printed whenever the env-step counter `t` is a multiple of this value.
log_every_updates = 100
# Total number of environment steps (training-loop bound and progress-bar total).
online_num_env_steps = 2000
# While `t` is below this value, actions are sampled from the action space instead of the agent.
num_seed_steps = 1000

# "Expert" trajectories
FB-CPR leverages expert observation-only trajectories in the training process. For training Meta Motivo you can use the motion capture dataset as described in the HumEnv repository. Here, for simplicity, we create "expert" trajectories by running a random agent.

In [None]:
class RandomAgent:
    """Agent that ignores its inputs and samples actions from the env's action space."""

    def __init__(self, env):
        # The environment is kept so its action space can be sampled in `act`.
        self.env = env

    def act(self, *args, **kwargs):
        """Return a fresh sample from ``self.env.action_space``; all arguments are ignored."""
        action_space = self.env.action_space
        return action_space.sample()


# Roll out four episodes with the random agent; only the episode list is kept.
random_agent = RandomAgent(env)
_, episodes = rollout(env=env, agent=random_agent, num_episodes=4)

# Expert data is observation-only: cast observations to float32 and drop the actions.
for episode in episodes:
    float_observation = episode["observation"].astype(np.float32)
    episode["observation"] = float_observation
    del episode["action"]

We can visualize an episode by reloading `qpos` (and optionally `qvel`) information.

In [None]:
# Replay the first episode: restore each stored `qpos` and render one frame per step.
ep = episodes[0]
frames = []
for qpos in ep["info"]["qpos"]:
    env.unwrapped.set_physics(qpos)
    frame = env.render()
    frames.append(frame)
media.show_video(frames, fps=30)

With this tutorial we provide a simple buffer for storing trajectories (see `examples/trajectory_buffer.py`).

In [None]:
# Store the "expert" episodes in a trajectory buffer whose capacity is set to
# the number of collected episodes.
expert_buffer = TrajectoryBuffer(
    capacity=len(episodes),
    seq_length=agent_config.model.seq_length,  # same value as the agent's model config
    device=agent.device,  # keep the buffer on the agent's device
)
expert_buffer.extend(episodes)
print(expert_buffer)
# The environment used for the expert rollouts is closed here; training
# below creates its own `train_env`.
env.close()

### Training loop
This section contains the training loop, which should be self-explanatory.

In [None]:
# Parallel training environments. Each one is flattened first and then wrapped
# so that observations are dicts with "obs" and "time" entries.
train_env, _ = make_humenv(
    num_envs=online_parallel_envs,
    vectorization_mode="sync",
    wrappers=[
        gymnasium.wrappers.FlattenObservation,
        TimeAwareObservation,
    ],
    render_width=320,
    render_height=320,
)

# Online transitions go into "train"; the expert trajectories built earlier
# are exposed under "expert_slicer".
train_buffer = DictBuffer(capacity=buffer_size, device=agent.device)
replay_buffer = {"train": train_buffer, "expert_slicer": expert_buffer}

# Sanity check: show the keys of the dict observation.
obs, _ = train_env.reset()
print(obs.keys())

In [None]:
# Online training loop: at every iteration all parallel envs take one step,
# the transitions are appended to the replay buffer, and the agent is updated once.
progb = tqdm(total=online_num_env_steps)
td, info = train_env.reset()
total_metrics, context = None, None
# Number of agent updates summed into `total_metrics` since the last log line.
updates_since_log = 0
start_time = time.time()
# `t` counts environment steps, advancing by the number of parallel envs.
for t in range(0, online_num_env_steps, online_parallel_envs):
    with torch.no_grad():
        obs = torch.tensor(td["obs"], dtype=torch.float32, device=agent.device)
        step_count = torch.tensor(td["time"], device=agent.device)
        # The per-env time step lets the agent decide when to switch the embedding `z`.
        context = agent.maybe_update_rollout_context(z=context, step_count=step_count)
        if t < num_seed_steps:
            # Seed phase: random actions instead of the agent's policy.
            action = train_env.action_space.sample().astype(np.float32)
        else:
            # this works in inference mode
            action = agent.act(obs=obs, z=context, mean=False).cpu().detach().numpy()
    new_td, reward, terminated, truncated, new_info = train_env.step(action)
    real_next_obs = new_td["obs"].astype(np.float32).copy()
    done = np.logical_or(terminated.ravel(), truncated.ravel())
    # For finished episodes the stored next observation must be the last
    # observation of that episode, taken from `final_observation`
    # (presumably `new_td` already holds the post-reset observation — confirm
    # against the vector env's autoreset behaviour).
    for idx in np.flatnonzero(done):
        real_next_obs[idx] = new_info["final_observation"][idx]["obs"].astype(
            np.float32
        )
    data = {
        "observation": obs,
        "action": action,
        "z": context,
        "step_count": step_count,
        "next": {
            "observation": real_next_obs,
            "terminated": terminated.reshape(-1, 1),
            "truncated": truncated.reshape(-1, 1),
            "reward": reward.reshape(-1, 1),
        },
    }
    replay_buffer["train"].extend(data)

    metrics = agent.update(replay_buffer, t)

    # Keep a running sum of every metric between two log lines.
    if total_metrics is None:
        # `* 1` presumably produces a fresh value so the sum does not alias `metrics`.
        total_metrics = {k: metrics[k] * 1 for k in metrics.keys()}
    else:
        total_metrics = {k: total_metrics[k] + metrics[k] for k in metrics.keys()}
    updates_since_log += 1

    if t % log_every_updates == 0:
        # Average over the updates actually accumulated: `t` advances by
        # `online_parallel_envs` per update, so dividing by
        # `log_every_updates` would under-report the means.
        m_dict = {}
        for k in sorted(total_metrics):
            tmp = total_metrics[k] / updates_since_log
            m_dict[k] = np.round(tmp.mean().item(), 6)
        m_dict["duration"] = time.time() - start_time
        print(f"Steps: {t}\n{m_dict}")
        total_metrics = None
        updates_since_log = 0
    progb.update(online_parallel_envs)
    td = new_td