In [1]:
import mujoco
from mujoco import mjx
import jax
import dm_control
from dm_control import mjcf
from dm_control.locomotion.walkers import rescale
import importlib
utils = importlib.import_module("stac-mjx.utils")


In [2]:
params = utils.load_params("params/params.yaml")
mjmodel = mujoco.MjModel.from_xml_path(params["XML_PATH"])
mjdata = mujoco.MjData(mjmodel)
mjx_model = mjx.device_put(mjmodel)
mjx_data = mjx.make_data(mjx_model)

In [3]:
mjcf_root = mjcf.from_path(params["XML_PATH"])

# Rescale
rescale.rescale_subtree(
        mjcf_root,
        params["SCALE_FACTOR"],
        params["SCALE_FACTOR"],
    )

In [4]:
# This function modifies mjcf in place so root should be passed in to make sure we're doing it right
def set_body_sites(root, params):
    # gets part names (but only )
    # part_names = physics.named.data.qpos.axes.row.names
    # for _ in range(6):
    #     part_names.insert(0, part_names[0])
    # print(part_names)
    # mjcf_root = mjcf.from_path(params["XML_PATH"])
    body_sites = []
    for key, v in params["KEYPOINT_MODEL_PAIRS"].items():
        parent = root.find("body", v)
        
        pos = params["KEYPOINT_INITIAL_OFFSETS"][key]
        
        site = parent.add(
            "site",
            name=key,
            type="sphere",
            size=[0.005],
            rgba="0 0 0 1",
            pos=pos,
            group=3,
        )
        body_sites.append(site)
    return body_sites

In [5]:
root = mjcf_root.__copy__()
body_sites = set_body_sites(root, params)

In [6]:
physics = mjcf.Physics.from_mjcf_model(root)
binding = physics.bind(body_sites)

In [28]:
binding.pos

SynchronizingArrayWrapper([[-0.03230154, -0.00472705, -0.02205959],
                           [-0.03230154,  0.00472705, -0.02205959],
                           [ 0.        ,  0.        ,  0.        ],
                           [ 0.        ,  0.        ,  0.        ],
                           [-0.01593121,  0.01035529, -0.02230369],
                           [-0.01593121, -0.01035529, -0.02230369],
                           [ 0.02109994,  0.00761433, -0.00275217],
                           [ 0.02109994, -0.00761433, -0.00275217],
                           [ 0.00303175,  0.00151587, -0.0083373 ],
                           [ 0.00303175, -0.00151587, -0.0083373 ],
                           [-0.015     ,  0.015     ,  0.        ],
                           [-0.015     , -0.015     ,  0.        ],
                           [ 0.01542265,  0.017479  , -0.02570441],
                           [ 0.01542265, -0.017479  , -0.02570441],
                           [ 0.0287    ,  0.0098

In [14]:
physics.data.site_xpos

array([[ 2.78957345e-02,  1.85267244e-07,  5.85806154e-02],
       [ 1.06342352e-02,  1.60370410e-08,  7.37424266e-02],
       [ 1.06342352e-02,  1.60370410e-08,  7.37424266e-02],
       [ 5.56749049e-03,  2.35136027e-06,  7.42213096e-02],
       [-2.97959516e-04, -7.01820552e-06,  7.41361047e-02],
       [-6.39789328e-03,  8.53285497e-07,  7.29667214e-02],
       [-1.25989532e-02, -4.73061479e-06,  7.07184629e-02],
       [-1.81657092e-02, -1.61687166e-09,  6.74414527e-02],
       [-2.99299775e-02,  3.16028242e-06,  6.01325681e-02],
       [-3.42592406e-02,  1.03452890e-02,  5.12335272e-02],
       [-3.42592406e-02, -1.03389684e-02,  5.12335272e-02],
       [-4.49299775e-02,  1.50031603e-02,  6.01325681e-02],
       [-4.49299775e-02, -1.49968397e-02,  6.01325681e-02],
       [-2.99299775e-02,  3.16028242e-06,  6.01325681e-02],
       [-2.99299775e-02,  3.16028242e-06,  6.01325681e-02],
       [-3.42592406e-02,  1.03452890e-02,  5.12335272e-02],
       [-2.03788598e-02,  2.60763872e-02

In [20]:
params = utils.load_params("params/params.yaml")
mjmodel = mujoco.MjModel.from_xml_path(params["XML_PATH"])
mjdata = mujoco.MjData(mjmodel)
mjdata.site_xpos.shape

(86, 3)

In [36]:
physics.named.data.xpos

FieldIndexer(xpos):
                         x         y         z         
 0               world [ 0         0         0       ]
 1               torso [ 0.0279    1.85e-07  0.0586  ]
 2          vertebra_1 [ 0.0106    1.6e-08   0.0737  ]
 3          vertebra_2 [ 0.00557   2.35e-06  0.0742  ]
 4          vertebra_3 [-0.000298 -7.02e-06  0.0741  ]
 5          vertebra_4 [-0.0064    8.53e-07  0.073   ]
 6          vertebra_5 [-0.0126   -4.73e-06  0.0707  ]
 7          vertebra_6 [-0.0182   -1.62e-09  0.0674  ]
 8              pelvis [-0.0299    3.16e-06  0.0601  ]
 9         upper_leg_L [-0.0343    0.0103    0.0512  ]
10         lower_leg_L [-0.0204    0.0261    0.0281  ]
11              foot_L [-0.0495    0.0218    0.00825 ]
12               toe_L [-0.0305    0.0287    0.00577 ]
13         upper_leg_R [-0.0343   -0.0103    0.0512  ]
14         lower_leg_R [-0.0204   -0.0261    0.0281  ]
15              foot_R [-0.0495   -0.0218    0.00825 ]
16               toe_R [-0.0305   -0.0287   

In [41]:
physics.named.data.xpos.

AttributeError: 'numpy.ndarray' object has no attribute '_attributes'

In [52]:
for i, x in enumerate(physics.named.data.site_xpos._field):
    print(i, x)

0 [2.78957345e-02 1.85267244e-07 5.85806154e-02]
1 [1.06342352e-02 1.60370410e-08 7.37424266e-02]
2 [1.06342352e-02 1.60370410e-08 7.37424266e-02]
3 [5.56749049e-03 2.35136027e-06 7.42213096e-02]
4 [-2.97959516e-04 -7.01820552e-06  7.41361047e-02]
5 [-6.39789328e-03  8.53285497e-07  7.29667214e-02]
6 [-1.25989532e-02 -4.73061479e-06  7.07184629e-02]
7 [-1.81657092e-02 -1.61687166e-09  6.74414527e-02]
8 [-2.99299775e-02  3.16028242e-06  6.01325681e-02]
9 [-0.03425924  0.01034529  0.05123353]
10 [-0.03425924 -0.01033897  0.05123353]
11 [-0.04492998  0.01500316  0.06013257]
12 [-0.04492998 -0.01499684  0.06013257]
13 [-2.99299775e-02  3.16028242e-06  6.01325681e-02]
14 [-2.99299775e-02  3.16028242e-06  6.01325681e-02]
15 [-0.03425924  0.01034529  0.05123353]
16 [-0.02037886  0.02607639  0.02809956]
17 [-0.0188366   0.02782429  0.02552912]
18 [-0.02037886  0.02607639  0.02809956]
19 [-0.04945025  0.02182204  0.00824593]
20 [-0.0526804   0.02134933  0.00603997]
21 [-0.04945025  0.02182204  

In [53]:
import numpy as np

def get_name_arr_and_len(fi, dim_idx):
      """Returns a string array of element names and the max name length."""
      axis = fi._axes[dim_idx]
      for name in axis.names:
        axis.convert_key_item(name)
      
      size = fi._field.shape[dim_idx]
      try:
        name_arr = np.zeros(size, dtype='S{}'.format(name_len))
        for name in axis.names:
          if name:
            # Use the `Axis` object to convert the name into a numpy index, then
            # use this index to write into name_arr.
            name_arr[axis.convert_key_item(name)] = name
      except AttributeError:
        name_arr = np.zeros(size, dtype='S0')  # An array of zero-length strings
      return name_arr

In [55]:
names, _ = get_name_arr_and_len(physics.named.data.site_xpos, 0)

In [58]:
names

array([b'tracking[torso]', b'tracking[vertebra_1]', b'SpineM',
       b'tracking[vertebra_2]', b'tracking[vertebra_3]',
       b'tracking[vertebra_4]', b'tracking[vertebra_5]',
       b'tracking[vertebra_6]', b'tracking[pelvis]', b'hip_L', b'hip_R',
       b'HipL', b'HipR', b'SpineL', b'TailBase', b'tracking[upper_leg_L]',
       b'knee_L', b'KneeL', b'tracking[lower_leg_L]', b'ankle_L',
       b'AnkleL', b'tracking[foot_L]', b'toe_L', b'FootL',
       b'tracking[toe_L]', b'sole_L', b'tracking[upper_leg_R]', b'knee_R',
       b'KneeR', b'tracking[lower_leg_R]', b'ankle_R', b'AnkleR',
       b'tracking[foot_R]', b'toe_R', b'FootR', b'tracking[toe_R]',
       b'sole_R', b'tracking[vertebra_C1]', b'tracking[vertebra_C2]',
       b'tracking[vertebra_C3]', b'tracking[vertebra_C4]',
       b'tracking[vertebra_C5]', b'tracking[vertebra_C6]',
       b'tracking[vertebra_C7]', b'tracking[vertebra_C8]',
       b'tracking[vertebra_C9]', b'tracking[vertebra_C10]',
       b'tracking[vertebra_C11]', 

In [None]:
def get_indices(body_sites, physics):
    """generates a mapping of site name to site_xpos mapping

    Args:
        body_sites (_type_): _description_
        physics (_type_): _description_
    """
    for key, v in params["KEYPOINT_MODEL_PAIRS"].items():
        parent = root.find("body", v)
        pos = params["KEYPOINT_INITIAL_OFFSETS"][key]
        site = parent.add(
            "site",
            name=key,
            type="sphere",
            size=[0.005],
            rgba="0 0 0 1",
            pos=pos,
            group=3,
        )
        body_sites.append(site)

In [63]:
axis = physics.named.data.site_xpos._axes[0]
site_index_map = {key: axis.convert_key_item(key) for key in params["KEYPOINT_MODEL_PAIRS"].keys()}
site_index_map

{'AnkleL': 20,
 'AnkleR': 31,
 'EarL': 77,
 'EarR': 78,
 'ElbowL': 86,
 'ElbowR': 100,
 'FootL': 23,
 'FootR': 34,
 'HandL': 92,
 'HandR': 106,
 'HipL': 11,
 'HipR': 12,
 'KneeL': 17,
 'KneeR': 28,
 'ShoulderL': 83,
 'ShoulderR': 97,
 'Snout': 79,
 'SpineF': 68,
 'SpineL': 13,
 'SpineM': 2,
 'TailBase': 14,
 'WristL': 89,
 'WristR': 103}