In [1]:
import mujoco
from mujoco import mjx
import jax
from jax import numpy as jnp
from jax.lib import xla_bridge
from dm_control import mjcf
import numpy as np

import os
import pickle
import time
import argparse
import random
import logging 
import sys

import utils
# Gotta do this before importing controller
utils.init_params("../params/params.yaml")
utils.params['LR'] = 1e-1
utils.params['MAXITER'] = 4000

import controller as ctrl
from compute_stac import *
import operations as op

#Turning off warnings bc jax of deprecation warnings
import warnings
warnings.filterwarnings("ignore")

%cd stac-mjx


[Errno 2] No such file or directory: 'stac-mjx'
/Users/charleszhang/GitHub/stac-mjx/stac-mjx


In [2]:
def rootopt_then_frameopt(mj_model, kp_data):
    
    # Create mjx model and data
    mjx_model = mjx.put_model(mj_model)
    mjx_data = mjx.make_data(mjx_model)
    
    # Get and set the offsets of the markers
    offsets = jnp.copy(op.get_site_pos(mjx_model))
    offsets *= utils.params['SCALE_FACTOR']
    
    # print(mjx_model.site_pos, mjx_model.site_pos.shape)
    mjx_model = op.set_site_pos(mjx_model, offsets)

    # forward is used to calculate xpos and such
    mjx_data = mjx.kinematics(mjx_model, mjx_data)
    mjx_data = mjx.com_pos(mjx_model, mjx_data)
    
    # Begin optimization steps
    mjx_data = root_optimization(mjx_model, mjx_data, kp_data)

    print("optimizing first frame")
    return pose_optimization(mjx_model, mjx_data, kp_data)

    # print("starting offset optimization")
    # mjx_model, mjx_data = offset_optimization(mjx_model, mjx_data, kp_data, offsets, q)

In [3]:
# utils.init_params("../params/params.yaml")

# When using nvidia gpu do this thing
if xla_bridge.get_backend().platform == 'gpu':
    os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
    )
    # Set N_GPUS
    utils.params["N_GPUS"] = jax.local_device_count("gpu")


ratpath = "../models/rodent_stac_optimized.xml"
rat23path = "../models/rat23.mat"
kp_names = utils.loadmat(rat23path)["joint_names"]
utils.params["kp_names"] = kp_names
# argsort returns the indices that would sort the array
stac_keypoint_order = np.argsort(kp_names)   

model = mujoco.MjModel.from_xml_path(ratpath)

model.opt.solver = {
    'cg': mujoco.mjtSolver.mjSOL_CG,
    'newton': mujoco.mjtSolver.mjSOL_NEWTON,
}['cg']

model.opt.iterations = 6
model.opt.ls_iterations = 6

data_path = "../save_data_AVG.mat"

# Load kp_data, /1000 to scale data (from mm to meters)
kp_data = utils.loadmat(data_path)["pred"][:] / 1000

kp_data = ctrl.prep_kp_data(kp_data, stac_keypoint_order)

# Set up mjcf
root = mjcf.from_path(ratpath)
physics, mj_model = ctrl.create_body_sites(root)
ctrl.part_opt_setup(physics)

### running q_phase for a couple frames te tune hyperparams to minimize error

In [4]:
utils.params['LR'] = 1e-1
utils.params['MAXITER'] = 4000
mjx_model, q, x, walker_body_sites, clip_data = ctrl.fit(mj_model, kp_data[:20])

Root Optimization:
q_opt 1 finished in 138.0289797782898 with an error of 0.000966157647781074
Replace 1 finished in 51.051597118377686
starting q_opt 2
q_opt 1 finished in 0.0060160160064697266 with an error of 0.0009694193722680211
Replace 2 finished in 0.0007390975952148438
Root optimization finished in 189.18756294250488
Calibration iteration: 1/6
Pose Optimization:
Pose Optimization done in 0.32557082176208496
Frame 1 done in 0.05457496643066406 with a final error of 0.00010361943714087829
Frame 2 done in 0.01512598991394043 with a final error of 9.16629214771092e-05
Frame 3 done in 0.013934850692749023 with a final error of 9.17407960514538e-05
Frame 4 done in 0.006763935089111328 with a final error of 9.077288268599659e-05
Frame 5 done in 0.012431144714355469 with a final error of 9.095269342651591e-05
Frame 6 done in 0.014518022537231445 with a final error of 9.203859372064471e-05
Frame 7 done in 0.013838052749633789 with a final error of 9.255456097889692e-05
Frame 8 done in 0

2024-04-09 11:34:56.357496: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_m_opt] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-04-09 11:35:42.665374: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m46.313586s

********************************
[Compiling module jit_m_opt] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


Final error of nan
offset optimization finished in 189.64097905158997
Calibration iteration: 2/6
Pose Optimization:
Pose Optimization done in 0.26142406463623047
Frame 1 done in 0.013848066329956055 with a final error of 9.074561967281625e-05
Frame 2 done in 0.011557817459106445 with a final error of 9.303536353399977e-05
Frame 3 done in 0.013826131820678711 with a final error of 9.458755812374875e-05
Frame 4 done in 0.012115001678466797 with a final error of 9.145402145804837e-05
Frame 5 done in 0.012033939361572266 with a final error of 9.091140964301303e-05
Frame 6 done in 0.012900829315185547 with a final error of 9.164124639937654e-05
Frame 7 done in 0.011234045028686523 with a final error of 9.329720342066139e-05
Frame 8 done in 0.012665987014770508 with a final error of 9.17804500204511e-05
Frame 9 done in 0.015472888946533203 with a final error of 9.232705633621663e-05
Frame 10 done in 0.01606297492980957 with a final error of 9.200299973599613e-05
Frame 11 done in 0.0128309726