In [1]:
import utils
import mujoco
import os
import pickle
from scipy.io import savemat 
from dm_control import mjcf
import numpy as np
import jax
from jax import numpy as jnp
import time
from controller import *


In [2]:
# Load the STAC parameters. Avoid a user-specific absolute path: by default the
# path is resolved relative to the notebook's working directory (assumes the
# notebook runs from a subdirectory of the repo root, as the ../models paths
# below already do). Set STAC_PARAMS_PATH to point elsewhere.
PARAMS_PATH = os.environ.get("STAC_PARAMS_PATH", "../params/params.yaml")
utils.init_params(PARAMS_PATH)

ratpath = "../models/rodent_stac.xml"
rat23path = "../models/rat23.mat"

# Standalone MuJoCo model with solver options configured for speed.
# NOTE(review): `model` is not referenced by any later visible cell. The MJX model
# is built from `mj_model` returned by `set_body_sites(root)`, so these settings
# may have no effect on the optimization. Confirm.
model = mujoco.MjModel.from_xml_path(ratpath)
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
# Plain assignment (not |=): replaces any disable flags set in the XML.
model.opt.disableflags = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
model.opt.iterations = 1
model.opt.ls_iterations = 4

# dm_control mjcf representation of the same XML; passed to set_body_sites later.
root = mjcf.from_path(ratpath)

In [3]:
data_path = "../save_data_AVG.mat"

# Default ordering of mj sites is alphabetical, so we reorder to match
kp_names = utils.loadmat(rat23path)["joint_names"]
# Stored in the global params dict; presumably read by the controller helpers.
utils.params["kp_names"] = kp_names

# argsort returns the indices that would sort the array, i.e. the alphabetical
# order of the keypoint names (used to match the alphabetical site order above)
stac_keypoint_order = np.argsort(kp_names)
# NOTE(review): /1000 presumably converts mm to meters. Confirm the units of "pred".
kp_data = utils.loadmat(data_path)["pred"][:] / 1000

# Apply the keypoint reordering (see prep_kp_data for the exact transform).
kp_data = prep_kp_data(kp_data, stac_keypoint_order)

# Setup for the fit: set_body_sites returns the dm_control physics and the
# MuJoCo model (mj_model) that the MJX model is built from in the next cell.
physics, mj_model = set_body_sites(root)
part_opt_setup(physics)

In [5]:
# Create mjx model and data
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.make_data(mjx_model)

# Get and set the offsets of the markers
offsets = jnp.copy(stac_base.get_site_pos(mjx_model))
offsets *= utils.params['SCALE_FACTOR']

# print(mjx_model.site_pos, mjx_model.site_pos.shape)
mjx_model = stac_base.set_site_pos(mjx_model, offsets)

# forward is used to calculate xpos and such
mjx_data = mjx.kinematics(mjx_model, mjx_data)
mjx_data = mjx.com_pos(mjx_model, mjx_data)

Restrict the data to a 2,000-frame clip (the first 2,000 frames).

In [6]:
# Restrict the keypoint data to a short clip of the leading frames.
N_CLIP_FRAMES = 2000
kp_data = kp_data[:N_CLIP_FRAMES]

Root and pose optimization (q-phase)

In [7]:
# q-phase: fit the root first, then the pose, to kp_data. The first run triggers
# a very long XLA compile of q_opt (about 14 minutes in the log below).
mjx_data = root_optimization(mjx_model, mjx_data, kp_data)
# q and x are saved below as "qpos" and "xpos" respectively.
mjx_data, q, walker_body_sites, x = pose_optimization(mjx_model, mjx_data, kp_data)

Root Optimization:


2024-02-06 19:05:51.010309: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_q_opt] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-02-06 19:18:20.557106: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 14m29.553661s

********************************
[Compiling module jit_q_opt] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 174.57347106933594 Stepsize:1.0  Decrease Error:87.08638763427734  Curvature Error:174.57347106933594 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 83.48860168457031 Stepsize:0.5  Decrease Error:20.76815414428711  Curvature Error:83.48860168457031 
INFO: jaxopt.ZoomLineSearch: Iter: 3 Minimum Decrease & Curvature Errors (stop. crit.): 37.93450164794922 Stepsize:0.25  Decrease Error:4.691154479980469  Curvature Error:37.93450164794922 
INFO: jaxopt.ZoomLineSearch: Iter: 4 Minimum Decrease & Curvature Errors (stop. crit.): 15.165499687194824 Stepsize:0.125  Decrease Error:0.9227980375289917  Curvature Error:15.165499687194824 
INFO: jaxopt.ZoomLineSearch: Iter: 5 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.021953148767352104  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors

Initial marker offsets (already scaled by `SCALE_FACTOR`)

In [9]:
# Display the scaled initial marker offsets (for comparison with the learned ones).
offsets

Array([[-0.02616425, -0.00382891, -0.01786827],
       [-0.02616425,  0.00382891, -0.01786827],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [-0.01290428,  0.00838778, -0.01806599],
       [-0.01290428, -0.00838778, -0.01806599],
       [ 0.01709095,  0.00616761, -0.00222925],
       [ 0.01709095, -0.00616761, -0.00222925],
       [ 0.00245571,  0.00122786, -0.00675322],
       [ 0.00245571, -0.00122786, -0.00675322],
       [-0.01215   ,  0.01215   ,  0.        ],
       [-0.01215   , -0.01215   ,  0.        ],
       [ 0.01249234,  0.01415799, -0.02082057],
       [ 0.01249234, -0.01415799, -0.02082057],
       [ 0.023247  ,  0.0079704 , -0.0205902 ],
       [ 0.023247  , -0.0079704 , -0.0205902 ],
       [ 0.        ,  0.        ,  0.        ],
       [-0.01215   ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.   

Save the intermediate (q-phase) results to `pose_opt_qs.p`

In [2]:
# Persist the q-phase results so the m-phase section below can start from a
# fresh kernel. This cell needs the variables from the cells above in the SAME
# kernel; running it after a restart fails with NameError (see recorded output).
data = {
    "kp_data": kp_data,
    "mjx_model": mjx_model,
    "mjx_data": mjx_data,
    "qpos": q,
    "offsets": offsets,
    "walker_body_sites": walker_body_sites,
    "xpos": x,
    "physics": physics,
    "site_index_map": utils.params["site_index_map"],
    "names_qpos": utils.params["part_names"],
    "names_xpos": physics.named.data.xpos.axes.row.names,
}
save_path = "pose_opt_qs.p"
save_dir = os.path.dirname(save_path)
if save_dir:  # empty string means the current directory, nothing to create
    os.makedirs(save_dir, exist_ok=True)
with open(save_path, "wb") as output_file:
    # Protocol 2 is a Python-2-era format; the highest protocol is faster and
    # (from protocol 4 on) supports objects larger than 4 GiB.
    pickle.dump(data, output_file, protocol=pickle.HIGHEST_PROTOCOL)

NameError: name 'kp_data' is not defined

Load the saved data and run the m-phase (marker offset optimization)

In [2]:
# Restart point for the m-phase: re-initialise params and reload the q-phase
# results saved above. Avoid a user-specific absolute path: the default is
# resolved relative to the notebook's working directory (assumes the notebook
# runs from a subdirectory of the repo root); override with STAC_PARAMS_PATH.
PARAMS_PATH = os.environ.get("STAC_PARAMS_PATH", "../params/params.yaml")
utils.init_params(PARAMS_PATH)

# pickle.load can execute arbitrary code; only load files produced by this
# notebook, never files from untrusted sources.
with open("pose_opt_qs.p", "rb") as file:
    data = pickle.load(file)

mjx_model = data["mjx_model"]
mjx_data = data["mjx_data"]
kp_data = data["kp_data"]
offsets = data["offsets"]
q = data["qpos"]
x = data["xpos"]
physics = data["physics"]
walker_body_sites = data["walker_body_sites"]

# Restore the utils.params entries that were saved with the q-phase results.
utils.params["site_index_map"] = data["site_index_map"]
utils.params["part_names"] = data["names_qpos"]
# NOTE(review): utils.params["kp_names"] (set in the data-loading cell) is not
# saved or restored here. Confirm the m-phase does not need it.



In [3]:
# m-phase: optimize the marker offsets, starting from `offsets`, using the q-phase
# poses `q`. The updated model is returned; the `offsets` variable itself is
# not reassigned. The first run triggers a long XLA compile of m_opt.
mjx_model, mjx_data = offset_optimization(mjx_model, mjx_data, kp_data, offsets, q)

# Bundle the results and write them to m_phase.p.
m_path = "m_phase.p"
m_data = package_data(mjx_model, physics, q, x, walker_body_sites, kp_data)
save(m_data, m_path)

Begining offset optimization:


2024-02-08 10:32:49.796429: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_m_opt] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-02-08 10:40:07.508537: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 9m17.717593s

********************************
[Compiling module jit_m_opt] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


learned offsets: [-0.02907139 -0.00425435 -0.01985363 -0.02907139 -0.01985096 -0.01985363
  0.          0.0124962   0.          0.          0.          0.
 -0.01433809  0.00931976 -0.02007332 -0.01433809 -0.00931976 -0.02007332
  0.01898995  0.00685289 -0.00247695  0.01898995 -0.00587683 -0.00247695
  0.00245571  0.00122786 -0.00675322  0.00245571 -0.00122786 -0.00675322
 -0.0135      0.          0.         -0.0135     -0.0135      0.
  0.01388038  0.0157311  -0.02313397  0.01388038 -0.0157311   0.00382891
  0.02583     0.008856   -0.022878    0.02583    -0.008856   -0.022878
  0.          0.          0.         -0.01215     0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
offset optimization finished in 1066.6220920085907


In [4]:
# `offsets` still holds the initial (pre-optimization) values: offset_optimization
# returns an updated mjx_model and does not reassign this variable. Compare with
# the "learned offsets" printed by the previous cell.
offsets

Array([[-0.02616425, -0.00382891, -0.01786827],
       [-0.02616425,  0.00382891, -0.01786827],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [-0.01290428,  0.00838778, -0.01806599],
       [-0.01290428, -0.00838778, -0.01806599],
       [ 0.01709095,  0.00616761, -0.00222925],
       [ 0.01709095, -0.00616761, -0.00222925],
       [ 0.00245571,  0.00122786, -0.00675322],
       [ 0.00245571, -0.00122786, -0.00675322],
       [-0.01215   ,  0.01215   ,  0.        ],
       [-0.01215   , -0.01215   ,  0.        ],
       [ 0.01249234,  0.01415799, -0.02082057],
       [ 0.01249234, -0.01415799, -0.02082057],
       [ 0.023247  ,  0.0079704 , -0.0205902 ],
       [ 0.023247  , -0.0079704 , -0.0205902 ],
       [ 0.        ,  0.        ,  0.        ],
       [-0.01215   ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.   