In [1]:
import utils
import mujoco
import os
import pickle
from scipy.io import savemat 
from dm_control import mjcf
import numpy as np
import jax
from jax import numpy as jnp


In [2]:
# If your machine is low on RAM:
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5'

In [3]:
def save(fit_data, save_path):
    """Save fit data to disk as a pickle (``.p``) or MATLAB (``.mat``) file.

    The output format is selected from the extension of ``save_path``. Missing
    parent directories are created before writing.

    Args:
        fit_data: Object to save. For ``.mat`` output it is handed directly to
            ``scipy.io.savemat``; for ``.p`` output it is pickled.
        save_path (Text): Destination path ending in ``.p`` or ``.mat``.

    Raises:
        ValueError: If ``save_path`` has any other extension. Such a call used
            to return silently without writing anything.
    """
    _, file_extension = os.path.splitext(save_path)
    # Validate first so a bad path neither creates directories nor fails silently.
    if file_extension not in (".p", ".mat"):
        raise ValueError(
            "Unsupported file extension %r for save_path %r; expected '.p' or '.mat'."
            % (file_extension, save_path)
        )

    parent_dir = os.path.dirname(save_path)
    if parent_dir != "":
        os.makedirs(parent_dir, exist_ok=True)

    if file_extension == ".p":
        with open(save_path, "wb") as output_file:
            pickle.dump(fit_data, output_file, protocol=2)
    else:
        savemat(save_path, fit_data)

In [4]:
from controller import *

In [5]:
# Relative paths do not currently work from the notebook, so absolute paths are
# used. Set the STAC_MJX_DIR / STAC_MJX_DATA_PATH environment variables to run
# on another machine; the defaults are the original author's locations.
repo_dir = os.environ.get(
    "STAC_MJX_DIR", "/Users/charleszhang/github/stac-mjx"
)  # e.g. "/home/charles/github/stac-mjx"
utils.init_params(os.path.join(repo_dir, "params", "params.yaml"))
ratpath = os.path.join(repo_dir, "models", "rodent.xml")
rat23path = os.path.join(repo_dir, "models", "rat23.mat")

model = mujoco.MjModel.from_xml_path(ratpath)
# Newton solver limited to one solver iteration and one line-search iteration.
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
model.opt.iterations = 1
model.opt.ls_iterations = 1

# Need to download this data file and provide the path
data_path = os.environ.get(
    "STAC_MJX_DATA_PATH",
    "/Users/charleszhang/Research Projects/VNL/save_data_AVG.mat",
)  # e.g. "/home/charles/Desktop/save_data_AVG.mat"
offset_path = "offset.p"

root = mjcf.from_path(ratpath)

# Default ordering of mj sites is alphabetical, so we reorder to match
kp_names = utils.loadmat(rat23path)["joint_names"]
utils.params["kp_names"] = kp_names

# argsort returns the indices that would sort the array
stac_keypoint_order = np.argsort(kp_names)
# Load kp_data; the division by 1000 is presumably a mm -> m conversion (TODO confirm).
kp_data = utils.loadmat(data_path)["pred"][:] / 1000


In [6]:
# Prepare the keypoint data: reorder it with stac_keypoint_order, then chunk it.
# TODO: store kp_data used in fit in another variable (small slice of kpdata)
# NOTE(review): kp_data is overwritten in place here, so re-running this cell
# feeds already-processed data back into prep_kp_data.
kp_data = prep_kp_data(kp_data, stac_keypoint_order)
# Chunk it so it can be passed into vmapped functions.
kp_data, n_envs = chunk_kp_data(kp_data)

In [7]:
# Keep only the first entry along axis 0 of the chunked data for fitting.
fit_kp_data = kp_data[:1]
fit_kp_data.shape

(1, 250, 69)

In [8]:
# Shape of the single chunk that is handed to fit().
fit_kp_data[0].shape

(250, 69)

In [9]:
# Fit on the first chunk of keypoint data, then write the result to offset_path.
# FIXME: the saved output of this cell ends in a ValueError raised by vmap
# (axis 0 requested for an argument of rank 0), so save() is never reached.
fit_data = fit(root, fit_kp_data[0])
save(fit_data, offset_path)

Root Optimization:
[ True  True  True  True  True  True  True False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False]
q-phase:
[-0.08807933  0.35345283  0.04053693  1.          0.          0.
  0.        ]
optimized params: [-0.08807933  0.35345283  0.04053693  1.          0.          0.
  0.        ]
last forward step
[ True  True  True  True  True  True  True False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False Fal

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

In [None]:
import jax
from jax import numpy as jnp
import mujoco
from mujoco import mjx

def get_site_xpos(mjx_data):
    """Collect the positions of the keypoint body sites.

    Sites are selected with the index values stored in
    ``utils.params["site_index_map"]``, in that mapping's iteration order.

    Args:
        mjx_data: Object exposing an indexable ``site_xpos`` attribute
            (presumably an MJX data structure -- confirm against callers).

    Returns:
        jax.Array: One ``site_xpos`` entry per mapped site, stacked along axis 0.
    """
    site_indices = utils.params["site_index_map"].values()
    keypoint_positions = []
    for site_index in site_indices:
        keypoint_positions.append(mjx_data.site_xpos[site_index])
    return jnp.array(keypoint_positions)

def q_loss(
    q: jnp.ndarray,
    mjx_model,
    mjx_data,
    kp_data: jnp.ndarray,
    qs_to_opt: jnp.ndarray = None,
    q_copy: jnp.ndarray = None,
    kps_to_opt: jnp.ndarray = None,
) -> jnp.ndarray:
    """Compute the marker residual for q_phase optimization.

    The residual is ``kp_data`` minus the marker positions produced by
    ``q_joints_to_markers`` for the given pose. It is returned as an array,
    not reduced to a scalar.

    Args:
        q (jnp.ndarray): Qpos for the current frame, or only the entries being
            optimized when ``qs_to_opt`` is given.
        mjx_model: Model passed through to ``q_joints_to_markers``.
        mjx_data: Data passed through to ``q_joints_to_markers``.
        kp_data (jnp.ndarray): Reference keypoint data.
        qs_to_opt (jnp.ndarray, optional): Binary vector of qposes to optimize.
            When given, ``q`` is written into ``q_copy`` at these positions.
        q_copy (jnp.ndarray, optional): Copy of the current qpos, used when
            optimizing a subset. Required when ``qs_to_opt`` is given.
        kps_to_opt (jnp.ndarray, optional): Vector denoting which entries of the
            residual to keep.

    Returns:
        jnp.ndarray: Residual between reference and predicted marker positions.

    Raises:
        ValueError: If ``qs_to_opt`` is given without ``q_copy``.
    """
    # If optimizing an arbitrary subset of qpos, merge the optimized entries
    # into the full pose. JAX arrays are immutable, so no defensive copies are
    # needed around the functional ``.at[...].set`` update.
    if qs_to_opt is not None:
        if q_copy is None:
            raise ValueError("q_copy must be provided when qs_to_opt is given")
        q = q_copy.at[qs_to_opt].set(q)

    residual = kp_data - q_joints_to_markers(q, mjx_model, mjx_data)
    if kps_to_opt is not None:
        residual = residual[kps_to_opt]

    return residual

def q_joints_to_markers(q: jnp.ndarray, mjx_model, mjx_data) -> jnp.ndarray:
    """Compute flattened keypoint site positions for a given pose.

    Args:
        q (jnp.ndarray): Postural state written into ``qpos``.
        mjx_model: Model passed to ``mjx.forward``.
        mjx_data: Data whose ``qpos`` is replaced by a copy of ``q`` before the
            forward pass.

    Returns:
        jnp.ndarray: Flattened array of marker positions.
    """
    posed_data = mjx_data.replace(qpos=q.copy())
    # Run the forward pass so site positions reflect the new pose.
    forwarded_data = mjx.forward(mjx_model, posed_data)
    marker_positions = get_site_xpos(forwarded_data)
    return marker_positions.flatten()

In [None]:
# NOTE(review): this import rebinds q_loss, so the q_loss defined in the
# previous cell is not the function that gets jitted below.
from stac_base import q_loss

# Compile q_loss with jax.jit.
jit_q_loss = jax.jit(q_loss)
# Zero pose used as a dummy input; 74 presumably matches the model's qpos
# length -- TODO confirm and derive it from the model instead.
z = jnp.zeros(74)
# NOTE(review): rebinds `model`, replacing the MjModel loaded from XML earlier.
physics, model = set_body_sites(root)
# Build the MJX model and a matching MJX data object from it.
mjx_model = mjx.put_model(model)
mjx_data = mjx.make_data(mjx_model)

In [None]:
import time

# Time a single call; if this is the first call to the jitted function, the
# elapsed time includes tracing and compilation. JAX dispatches work
# asynchronously, so block on the result before stopping the clock.
start = time.perf_counter()
jax.block_until_ready(jit_q_loss(z, mjx_model, mjx_data, fit_kp_data[0]))
print(time.perf_counter() - start)

141.9522738456726


In [None]:
%timeit jit_q_loss(z, mjx_model, mjx_data, fit_kp_data[0])

636 µs ± 4.71 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
