In [1]:
import os
import sys

# These environment variables are set *between* the imports on purpose: they
# must be in place before mujoco (and any JAX-using package, presumably
# track_mjx) is imported/initialised, because they are read at that point.
# Stop XLA/JAX from preallocating most of the GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Select the EGL backend so MuJoCo / PyOpenGL can render headlessly (no display).
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"
import mediapy as media
import mujoco
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pickle

# NOTE(review): `sys` and `load_from_h5py` are not used in the cells visible
# here — confirm they are needed before keeping them.
from track_mjx.analysis.utils import load_from_h5py

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# All input files are resolved relative to the notebook's working directory.
base_path = Path.cwd()
xml_path = base_path / "rodent.xml"
arena_path = base_path / "arena.xml"
rollout_path = base_path / "coltrane_rollout_0_1000.pkl"

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load rollout pickles produced by a trusted pipeline.
with rollout_path.open("rb") as f:
    data = pickle.load(f)

In [8]:
# Fail fast if the pickle layout changed: later cells index these keys directly.
assert {"qposes_ref", "qposes_rollout"} <= data.keys(), data.keys()
data.keys()

dict_keys(['ctrl', 'qposes_ref', 'qposes_rollout', 'state_rewards'])

In [9]:
# Pull the reference and rollout qpos trajectories out of the loaded pickle.
qposes_ref, qposes_rollout = data["qposes_ref"], data["qposes_rollout"]

In [10]:
# Display the (frames, qpos-dims) shape of the reference trajectory.
np.shape(qposes_ref)

(2000, 74)

## STAC render: reference trajectory alone, then rollout with the reference as a ghost

In [18]:
from vnl_mjx.tasks.utils import _scale_body_tree, _recolour_tree

# Rendering parameters; the video is displayed and written at the same rate.
RENDER_HEIGHT = 1000
RENDER_WIDTH = 1000
VIDEO_FPS = 100
CAMERA = "back_right"

# Build the scene: load the arena, then attach the rodent walker into it.
spec = mujoco.MjSpec.from_file(arena_path.as_posix())

walker_spec = mujoco.MjSpec.from_file(xml_path.as_posix())
# Recolour each top-level walker body light grey with alpha 0.3 (the helper
# presumably recolours the whole subtree). No scaling is applied here.
for body in walker_spec.worldbody.bodies:
    _recolour_tree(body, rgba=(0.8, 0.8, 0.8, 0.3))
# Attach the walker at a small vertical offset; the empty prefix/suffix leaves
# the attached element names unchanged.
pos = (0, 0, 0.05)
frame = spec.worldbody.add_frame(pos=pos, quat=[1, 0, 0, 0])
spawn_body = frame.attach_body(walker_spec.body("walker"), "", suffix="")
mj_model = spec.compile()

# Use the conjugate-gradient constraint solver.
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG

mj_data = mujoco.MjData(mj_model)
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False

# Each reference frame must supply exactly one value per generalised coordinate.
assert qposes_ref.shape[1] == mj_model.nq, (qposes_ref.shape, mj_model.nq)

# Kinematic playback only: every frame sets qpos and runs mj_forward; the
# simulation is never stepped.
# NOTE(review): all frames are held in memory (~3 MB each at 1000x1000 RGB),
# i.e. several GB for a 2000-frame trajectory.
frames = []
with mujoco.Renderer(mj_model, height=RENDER_HEIGHT, width=RENDER_WIDTH) as renderer:
    for qpos in tqdm(qposes_ref):
        mj_data.qpos[:] = qpos
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=CAMERA, scene_option=scene_option)
        frames.append(renderer.render())
media.show_video(frames, fps=VIDEO_FPS)
media.write_video(base_path / "rodent_stac_mujoco.mp4", frames, fps=VIDEO_FPS)

100%|██████████| 2000/2000 [00:05<00:00, 369.59it/s]


0
This browser does not support the video tag.


In [12]:
# Rollout and reference come from the same walker model, so they must have the
# same shape to be stacked side by side. Stack in the (rollout, ref) order used
# by the combined render below.
assert qposes_rollout.shape == qposes_ref.shape, (qposes_rollout.shape, qposes_ref.shape)
np.hstack((qposes_rollout, qposes_ref)).shape

(2000, 148)

In [15]:
from vnl_mjx.tasks.utils import _scale_body_tree, _recolour_tree
from etils import epath

# Rendering parameters; the video is displayed and written at the same rate.
RENDER_HEIGHT = 1000
RENDER_WIDTH = 1000
VIDEO_FPS = 100
CAMERA = "back_right"

# Build the scene from the rodent model itself (this walker plays the rollout),
# then attach a translucent second copy as a ghost that plays the reference.
# NOTE(review): unlike the reference-only render, arena.xml is not loaded
# here — confirm that is intended.
spec = mujoco.MjSpec.from_file(xml_path.as_posix())

walker_spec = mujoco.MjSpec.from_file(xml_path.as_posix())
# Recolour each top-level ghost body light grey with alpha 0.3 (the helper
# presumably recolours the whole subtree). No scaling is applied here.
for body in walker_spec.worldbody.bodies:
    _recolour_tree(body, rgba=(0.8, 0.8, 0.8, 0.3))
# Attach the ghost at a small vertical offset; the "-ghost" suffix keeps its
# element names distinct from the base walker's.
pos = (0, 0, 0.05)
frame = spec.worldbody.add_frame(pos=pos, quat=[1, 0, 0, 0])
spawn_body = frame.attach_body(walker_spec.body("walker"), "", suffix="-ghost")
mj_model = spec.compile()

# Use the conjugate-gradient constraint solver.
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG

mj_data = mujoco.MjData(mj_model)
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False

# Combined qpos is assumed to be laid out as [base walker | ghost], so the rollout
# drives the opaque walker and the reference drives the ghost. TODO confirm
# against mj_model joint order.
qposes_combined = np.hstack((qposes_rollout, qposes_ref))
assert qposes_combined.shape[1] == mj_model.nq, (qposes_combined.shape, mj_model.nq)

# Kinematic playback only: every frame sets qpos and runs mj_forward; the
# simulation is never stepped.
# NOTE(review): all frames are held in memory (~3 MB each at 1000x1000 RGB).
frames = []
with mujoco.Renderer(mj_model, height=RENDER_HEIGHT, width=RENDER_WIDTH) as renderer:
    for qpos in tqdm(qposes_combined):
        mj_data.qpos[:] = qpos
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=CAMERA, scene_option=scene_option)
        frames.append(renderer.render())
media.show_video(frames, fps=VIDEO_FPS)
media.write_video(base_path / "rodent_stac_with_track_mujoco.mp4", frames, fps=VIDEO_FPS)

 63%|██████▎   | 1251/2000 [00:02<00:01, 442.44it/s]

100%|██████████| 2000/2000 [00:04<00:00, 437.60it/s]


0
This browser does not support the video tag.
