In [1]:
import mujoco
from mujoco import mjx
import jax
from jax import numpy as jnp
from jax.lib import xla_bridge
from dm_control import mjcf
import numpy as np

import os
import pickle
import time
import argparse
import random
import logging 
import sys
import controller as ctrl
import utils

from compute_stac import *
import operations as op
import stac_base

#Turning off warnings bc jax of deprecation warnings
import warnings
warnings.filterwarnings("ignore")

%cd stac-mjx


[Errno 2] No such file or directory: 'stac-mjx'
/Users/charleszhang/GitHub/stac-mjx/stac-mjx


In [2]:
def rootopt_then_frameopt(mj_model, kp_data):
    
    # Create mjx model and data
    mjx_model = mjx.put_model(mj_model)
    mjx_data = mjx.make_data(mjx_model)
    
    # Get and set the offsets of the markers
    offsets = jnp.copy(op.get_site_pos(mjx_model))
    offsets *= utils.params['SCALE_FACTOR']
    
    # print(mjx_model.site_pos, mjx_model.site_pos.shape)
    mjx_model = op.set_site_pos(mjx_model, offsets)

    # forward is used to calculate xpos and such
    mjx_data = mjx.kinematics(mjx_model, mjx_data)
    mjx_data = mjx.com_pos(mjx_model, mjx_data)
    
    # Begin optimization steps
    mjx_data = root_optimization(mjx_model, mjx_data, kp_data)

    print("optimizing first frame")
    return pose_optimization(mjx_model, mjx_data, kp_data)

    # print("starting offset optimization")
    # mjx_model, mjx_data = offset_optimization(mjx_model, mjx_data, kp_data, offsets, q)

In [5]:
utils.init_params("../params/params.yaml")

# When using nvidia gpu do this thing
if xla_bridge.get_backend().platform == 'gpu':
    os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
    )
    # Set N_GPUS
    utils.params["N_GPUS"] = jax.local_device_count("gpu")


ratpath = "../models/rodent_stac_optimized.xml"
rat23path = "../models/rat23.mat"
kp_names = utils.loadmat(rat23path)["joint_names"]
utils.params["kp_names"] = kp_names
# argsort returns the indices that would sort the array
stac_keypoint_order = np.argsort(kp_names)   

model = mujoco.MjModel.from_xml_path(ratpath)

model.opt.solver = {
    'cg': mujoco.mjtSolver.mjSOL_CG,
    'newton': mujoco.mjtSolver.mjSOL_NEWTON,
}['cg']

model.opt.iterations = 6
model.opt.ls_iterations = 6

data_path = "../save_data_AVG.mat"

# Load kp_data, /1000 to scale data (from mm to meters)
kp_data = utils.loadmat(data_path)["pred"][:] / 1000

kp_data = ctrl.prep_kp_data(kp_data, stac_keypoint_order)

# Set up mjcf
root = mjcf.from_path(ratpath)
physics, mj_model = ctrl.create_body_sites(root)
ctrl.part_opt_setup(physics)

### running q_phase for a couple frames te tune hyperparams to minimize error

In [6]:
utils.params['Q_LR'] = 1e-3
utils.params['M_LR'] = 1e-4
utils.params['MAXITER'] = 20000
mjx_model, q, x, walker_body_sites, clip_data = ctrl.fit(mj_model, kp_data[:20])

Root Optimization:
q_opt 1 finished in 0.016587018966674805 with an error of 0.0009948238730430603
Replace 1 finished in 0.0013530254364013672
starting q_opt 2
q_opt 2 finished in 0.004716157913208008 with an error of 0.0009773391066119075
Replace 2 finished in 0.0005338191986083984
Root optimization finished in 0.025484800338745117
Calibration iteration: 1/6
Pose Optimization:
Pose Optimization done in 0.17239093780517578
Frame 1 done in 0.016810894012451172 with a final error of 0.00047071382869035006
Frame 2 done in 0.009987115859985352 with a final error of 0.0004154726048000157
Frame 3 done in 0.008481025695800781 with a final error of 0.00037103722570464015
Frame 4 done in 0.008517265319824219 with a final error of 0.0003326745063532144
Frame 5 done in 0.008247852325439453 with a final error of 0.00028430274687707424
Frame 6 done in 0.008331060409545898 with a final error of 0.0002446977305226028
Frame 7 done in 0.008553028106689453 with a final error of 0.00020171154756098986
Fr