In [None]:
# Environment setup has to run BEFORE the library imports below: MuJoCo reads
# MUJOCO_GL to choose its OpenGL backend when `mujoco` is imported, so setting
# it afterwards has no effect on a fresh kernel.
import os

# glfw is faster, but use osmesa if glfw is not available
os.environ["MUJOCO_GL"] = "glfw"
os.environ["PYOPENGL_PLATFORM"] = "glfw"
# Disable jax's memory preallocation if you're running multiple notebooks using jax
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import mujoco
import numpy as np
import stac_mjx
from pathlib import Path
import h5py
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import re
import time
import imageio
import mediapy as media
from jax import numpy as jp
import hydra
from omegaconf import DictConfig, OmegaConf

# Custom OmegaConf resolvers (presumably referenced from the YAML configs as
# ${eq:...}, ${contains:...}, ${resolve_default:...} — confirm in configs/).
# `replace=True` keeps this cell re-runnable: registering a name that is
# already registered otherwise raises.

# eq: case-insensitive string equality.
OmegaConf.register_new_resolver(
    "eq", lambda x, y: x.lower() == y.lower(), replace=True
)
# contains: case-insensitive test that x is a substring of y.
OmegaConf.register_new_resolver(
    "contains", lambda x, y: x.lower() in y.lower(), replace=True
)
# resolve_default: fall back to `default` when `arg` is the empty string.
OmegaConf.register_new_resolver(
    "resolve_default",
    lambda default, arg: default if arg == "" else arg,
    replace=True,
)

In [None]:
# glfw is faster, but use osmesa if glfw not available
%env MUJOCO_GL=glfw
%env PYOPENGL_PLATFORM=glfw

# Disable jax's memory preallocation if you're running multiple notebooks using jax
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false" 

base_path = Path.cwd()


# Load configs
cfg = stac_mjx.load_configs(base_path / "configs")

stac_cfg = cfg.stac
model_cfg = cfg.model

file_path = stac_cfg.data_path

In [None]:
# Load tracked keypoints. Each keypoint is expected to contribute three
# columns named `<name>_x`, `<name>_y`, `<name>_z`; other columns are ignored.
df = pd.read_csv(file_path)

# Keypoints to leave out of the fit, e.g. {"l_antenna_tip", "r_antenna_tip"}.
# Empty means every keypoint found in the CSV is used.
EXCLUDED_KEYPOINTS = set()

# --- Extract coordinate columns (ending in _x/_y/_z) ---
kp_coord_cols = df.filter(regex=r"_[xyz]$", axis=1).columns.tolist()

# --- Parse keypoint base names, de-duplicated in first-seen column order ---
kp_base_names = [re.sub(r"_[xyz]$", "", col) for col in kp_coord_cols]
kp_names_ordered = [
    name for name in dict.fromkeys(kp_base_names) if name not in EXCLUDED_KEYPOINTS
]

# --- Build ordered list of coord columns: [kp1_x, kp1_y, kp1_z, kp2_x, ...] ---
sorted_kp_coord_cols = [
    f"{name}_{axis}" for name in kp_names_ordered for axis in "xyz"
]

# Fail early with a readable message if a keypoint is missing one of its axes
# (indexing the frame below would otherwise raise an opaque KeyError).
available_cols = set(df.columns)
missing_cols = [col for col in sorted_kp_coord_cols if col not in available_cols]
if missing_cols:
    raise ValueError(
        f"Keypoint coordinate columns missing from {file_path}: {missing_cols}"
    )

# --- Extract and scale data ---
# One row per CSV row, 3 * len(kp_names_ordered) columns in kp1_x, kp1_y, kp1_z,
# kp2_x, ... order.
actual_kp_df = df[sorted_kp_coord_cols]
flat_kp_data = actual_kp_df.to_numpy()
kp_data = model_cfg["MOCAP_SCALE_FACTOR"] * flat_kp_data
sorted_kp_names = kp_names_ordered

# Surface missing tracking values before they reach the (expensive) fit.
n_missing = int(actual_kp_df.isna().to_numpy().sum())
if n_missing:
    print(f"Warning: {n_missing} missing (NaN) coordinate values in keypoint data")

print(f"{len(sorted_kp_names)} keypoints:")
print("\n".join(sorted_kp_names))

In [None]:
# Run stac
fit_path, ik_only_path = stac_mjx.run_stac(
 cfg,
 kp_data, 
 sorted_kp_names, 
 base_path=base_path
)

In [None]:
# Render the first `n_frames` frames of the saved fit from the "side_alt"
# camera, save them as an MP4 and show them inline.
# NOTE(review): this path is hard-coded; it presumably should match `fit_path`
# returned by `stac_mjx.run_stac` — confirm against the stac config.
data_path = base_path / "stick_fit_offsets.h5"
n_frames = 250
save_path = base_path / "videos" / "render.mp4"

# Make sure the output directory exists before the video is written.
save_path.parent.mkdir(parents=True, exist_ok=True)

# `viz_stac` returns a config (used below for RENDER_FPS) together with the
# rendered frames. Bind it to its own name so the `cfg` used by the STAC cell
# is not overwritten; otherwise re-running that cell would silently depend on
# having run this one.
viz_cfg, frames = stac_mjx.viz_stac(
    data_path=data_path,
    n_frames=n_frames,
    save_path=save_path,
    start_frame=0,
    camera="side_alt",
    base_path=base_path,
    width=1920,
    height=1200,
)

# Show the video in the notebook (it is also saved to save_path)
media.show_video(frames, fps=viz_cfg.model.RENDER_FPS)