In [1]:
# rendering related
import os

# MuJoCo (and dm_control) choose their OpenGL backend from MUJOCO_GL when they are
# imported, so the variable must be set *before* the imports below; assigning it
# afterwards has no effect. setdefault keeps a backend already exported by the shell.
os.environ.setdefault("MUJOCO_GL", "egl")

import imageio
import jax
import mujoco
from dm_control.mujoco import wrapper
from dm_control.mujoco.wrapper.mjbindings import enums
import pickle
from tqdm import tqdm

In [2]:
# loading data
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open("./../clips/all_snips.p", "rb") as f:
    all_traj = pickle.load(f)

# Model + data used to replay the recorded qpos frames.
mj_model = mujoco.MjModel.from_xml_path("../assets/rodent.xml")
# Solver settings (mj_forward also runs the constraint solver, so these still apply).
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
mj_model.opt.jacobian = mujoco.mjtJacobian.mjJAC_DENSE
mj_data = mujoco.MjData(mj_model)

# NOTE: MUJOCO_GL is read when mujoco is imported; setting it at this point would
# not change the renderer's GL backend, so it has to be configured before import.
renderer = mujoco.Renderer(mj_model, height=512, width=512)

# render overlay: visualization options. A raw mujoco.MjvOption is used because
# mujoco.Renderer.update_scene does not accept dm_control's wrapper type.
# MjvOption.flags is indexed by mjtVisFlag.
scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = 1
scene_option.sitegroup[2] = 1
scene_option.sitegroup[3] = 1
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True
scene_option.flags[mujoco.mjtVisFlag.mjVIS_LIGHT] = False
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONVEXHULL] = True

# Rendering flags (mjtRndFlag) index MjvScene.flags, NOT MjvOption.flags. Writing
# them into scene_option.flags clobbers unrelated visual flags (mjRND_SHADOW and
# mjVIS_CONVEXHULL are both index 0), so they are set on the renderer's scene.
renderer.scene.flags[mujoco.mjtRndFlag.mjRND_SHADOW] = False
renderer.scene.flags[mujoco.mjtRndFlag.mjRND_REFLECTION] = False
renderer.scene.flags[mujoco.mjtRndFlag.mjRND_SKYBOX] = False
renderer.scene.flags[mujoco.mjtRndFlag.mjRND_FOG] = False

mujoco.mj_kinematics(mj_model, mj_data)

In [3]:
# How many snippets (clips) the dataset contains.
num_snips = len(all_traj["snips_order"])
num_snips

842

In [9]:
# Shape of the flat qpos table stored in the dataset.
qpos_flat = all_traj["qpos"]
qpos_flat.shape

(210500, 74)

In [13]:
# Split the flat (total_frames, nq) qpos table into one (clip_len, nq) array per
# snippet. Sizes come from the data and the model instead of hard-coded
# 250 frames x 74 coordinates, so a mismatch fails here instead of silently pairing
# clips with the wrong names later.
# Assumes frames are stored clip-by-clip in snips_order order and that every clip
# has the same length -- the same assumptions the fixed-size reshape made.
n_snips = len(all_traj["snips_order"])
total_frames, n_qpos = all_traj["qpos"].shape
assert n_qpos == mj_model.nq, f"qpos has {n_qpos} columns but the model has nq={mj_model.nq}"
assert total_frames % n_snips == 0, (
    f"{total_frames} frames cannot be split evenly into {n_snips} snippets"
)
qpos_clips = all_traj["qpos"].reshape((n_snips, total_frames // n_snips, n_qpos))

In [20]:
# Clip name of the first snippet: last path component, cut at its first ".".
all_traj["snips_order"][0].rsplit("/", 1)[-1].partition(".")[0]

'FastWalk_171'

In [24]:
# Render each clip to videos/<clip name>.mp4 by posing the model from the recorded
# qpos frame by frame (kinematic replay: mj_forward only, no physics is stepped).
# Frames are streamed straight to the writer rather than also kept in a list,
# which held ~200 MB per 250-frame 512x512 clip without ever being used.
os.makedirs("videos", exist_ok=True)  # the writer does not create missing directories

# Apply the overlay options configured for this renderer; update_scene ignores them
# unless passed explicitly. mujoco.Renderer needs a raw mujoco.MjvOption, so unwrap
# a dm_control wrapper (which exposes the raw struct as `.ptr`) if given one.
render_option = getattr(scene_option, "ptr", scene_option)

for i, name in enumerate(tqdm(all_traj["snips_order"])):
    # e.g. ".../FastWalk_171.<ext>" -> "FastWalk_171"
    name = name.split("/")[-1].split(".")[0]
    video_path = f"videos/{name}.mp4"
    # NOTE(review): fps=50 presumably matches the clips' frame rate -- confirm.
    with imageio.get_writer(video_path, fps=50) as video:
        for qpos in tqdm(qpos_clips[i], leave=False):
            # Copy this frame's generalized positions, then recompute the derived
            # poses (bodies, geoms, sites, cameras) the renderer reads.
            mj_data.qpos[:] = qpos
            mujoco.mj_forward(mj_model, mj_data)
            renderer.update_scene(
                mj_data, camera="close_profile", scene_option=render_option
            )
            video.append_data(renderer.render())

100%|██████████| 842/842 [07:22<00:00,  1.90it/s]


# Good Clips to Imitate

- FastWalk_61
- FastWalk_65
- FastWalk_80
- FastWalk_93
- Walk_20
- Walk_27
- Walk_34
- Walk_49