In [27]:
import mujoco
from mujoco import mjx
import jax
import dm_control
from dm_control import mjcf
from dm_control.locomotion.walkers import rescale
import importlib
utils = importlib.import_module("stac-mjx.utils")


In [28]:
# Load the run configuration (keys used in this notebook include "XML_PATH").
params = utils.load_params("params/params.yaml")
# Build the MuJoCo model from the XML file named in the config, plus a data object for it.
# NOTE(review): mjmodel comes straight from the XML file, so the rescaling applied to
# mjcf_root later in this notebook does not appear to affect it - confirm that is intended.
mjmodel = mujoco.MjModel.from_xml_path(params["XML_PATH"])
mjdata = mujoco.MjData(mjmodel)
# MJX counterparts: the model is transferred with device_put and fresh data is made from it.
mjx_model = mjx.device_put(mjmodel)
mjx_data = mjx.make_data(mjx_model)

In [29]:
# Parse the same XML file into a dm_control MJCF object model.
mjcf_root = mjcf.from_path(params["XML_PATH"])

# Rescale the tree, passing the configured SCALE_FACTOR for both scaling arguments.
# The return value is not captured, so mjcf_root itself is presumably modified in place.
rescale.rescale_subtree(mjcf_root, params["SCALE_FACTOR"], params["SCALE_FACTOR"])

In [30]:
def set_body_sites(root, params):
    """Add one keypoint marker site to each paired body of an MJCF model.

    The MJCF tree is modified in place, so pass in the exact root (or copy of
    it) that will be compiled into physics afterwards.

    Args:
        root: MJCF element to search; must provide ``find("body", name)``
            returning the body element, or ``None`` when there is no match.
        params: Mapping with two entries:
            ``"KEYPOINT_MODEL_PAIRS"`` maps keypoint name -> body name, and
            ``"KEYPOINT_INITIAL_OFFSETS"`` maps keypoint name -> the value
            passed as the new site's ``pos``.

    Returns:
        list: The created site elements, in the iteration order of
        ``params["KEYPOINT_MODEL_PAIRS"]``.

    Raises:
        ValueError: If a paired body name is not found under ``root``.
        KeyError: If a keypoint has no entry in ``KEYPOINT_INITIAL_OFFSETS``.
    """
    body_sites = []
    for keypoint_name, body_name in params["KEYPOINT_MODEL_PAIRS"].items():
        parent = root.find("body", body_name)
        if parent is None:
            # Fail with the offending names rather than an opaque
            # "'NoneType' object has no attribute 'add'" further down.
            raise ValueError(
                f"Keypoint {keypoint_name!r} is paired with body {body_name!r}, "
                "which was not found in the MJCF model."
            )

        offset = params["KEYPOINT_INITIAL_OFFSETS"][keypoint_name]

        # Small black sphere placed in site group 3.
        site = parent.add(
            "site",
            name=keypoint_name,
            type="sphere",
            size=[0.005],
            rgba="0 0 0 1",
            pos=offset,
            group=3,
        )
        body_sites.append(site)
    return body_sites

In [31]:
# set_body_sites adds sites to whatever root it is given, so hand it a copy instead of
# mjcf_root (presumably so the rescaled original stays free of sites - TODO confirm).
root = mjcf_root.__copy__()
body_sites = set_body_sites(root, params)

In [32]:
# Build a Physics instance from the site-augmented MJCF copy, then bind the new site
# elements to it so their attributes can be read through the physics.
physics = mjcf.Physics.from_mjcf_model(root)
binding = physics.bind(body_sites)

In [33]:
# Display the `pos` attribute of the bound keypoint sites (one xyz row per site in the output).
binding.pos

SynchronizingArrayWrapper([[-0.03230154, -0.00472705, -0.02205959],
                           [-0.03230154,  0.00472705, -0.02205959],
                           [ 0.        ,  0.        ,  0.        ],
                           [ 0.        ,  0.        ,  0.        ],
                           [-0.01593121,  0.01035529, -0.02230369],
                           [-0.01593121, -0.01035529, -0.02230369],
                           [ 0.02109994,  0.00761433, -0.00275217],
                           [ 0.02109994, -0.00761433, -0.00275217],
                           [ 0.00303175,  0.00151587, -0.0083373 ],
                           [ 0.00303175, -0.00151587, -0.0083373 ],
                           [-0.015     ,  0.015     ,  0.        ],
                           [-0.015     , -0.015     ,  0.        ],
                           [ 0.01542265,  0.017479  , -0.02570441],
                           [ 0.01542265, -0.017479  , -0.02570441],
                           [ 0.0287    ,  0.0098