In [1]:
import utils
import mujoco
import os
import pickle
from scipy.io import savemat 
from dm_control import mjcf
import numpy as np
import jax
from jax import numpy as jnp
import time
from controller import *
import os


  from jax.config import config
  from jax.config import config
  from jax.config import config
  from jax.config import config
  from jax.config import config


In [2]:
# Limit the GPU memory XLA preallocates for JAX to a 0.20 fraction of the device.
# NOTE(review): XLA reads this when the JAX backend initializes; `jax` is already
# imported in the first cell, so this only takes effect if no JAX device has been
# touched yet in this kernel — consider setting it before `import jax`.
# Alternative (currently disabled): XLA_PYTHON_CLIENT_PREALLOCATE="false" turns
# preallocation off entirely.
os.environ.update({"XLA_PYTHON_CLIENT_MEM_FRACTION": ".20"})

def get_clip(kp_data, n_frames, rng=None):
    """Return a random window of ``n_frames`` consecutive rows of ``kp_data``.

    Args:
        kp_data: Array-like sliced as ``kp_data[frames, :]`` (so at least 2-D);
            the first axis is treated as frames.
        n_frames: Number of consecutive frames in the clip
            (``0 <= n_frames <= kp_data.shape[0]``).
        rng: Optional ``random.Random`` instance (anything with ``randint``) to
            make clip selection reproducible. Defaults to the global ``random``
            module, matching the previous behavior.

    Returns:
        ``kp_data[start:start + n_frames, :]``, where ``start`` is drawn uniformly
        from every offset at which a full-length clip fits.

    Raises:
        ValueError: If ``n_frames`` is negative or longer than ``kp_data``.
    """
    import random

    if rng is None:
        rng = random

    n_total = kp_data.shape[0]
    if not 0 <= n_frames <= n_total:
        raise ValueError(
            f"n_frames={n_frames} must be between 0 and the number of frames ({n_total})"
        )

    # random.randint includes BOTH endpoints, so the last start index that still
    # leaves room for n_frames rows is n_total - n_frames (no "+ 1").
    rand_start = rng.randint(0, n_total - n_frames)
    return kp_data[rand_start:rand_start + n_frames, :]

In [3]:
# Load the base parameter file (presumably populates utils.params — the data-path
# cell further down overrides several of its entries).
params_yaml = "../params/params.yaml"
utils.init_params(params_yaml)

### Define data paths and parameters

In [4]:
# Model XML and keypoint-name file, relative to this notebook's directory.
rat_xml = "../models/rodent_stac.xml"
rat23 = "../models/rat23.mat"

# Keypoint predictions to fit. Set the STAC_DATA_PATH environment variable
# (e.g. to a local copy such as "save_data_AVG.mat") instead of editing this
# absolute, cluster-specific default.
data_path = os.environ.get(
    "STAC_DATA_PATH",
    "/n/holylabs/LABS/olveczky_lab/holylfs02/Everyone/dannce_rig/dannce_ephys/art/2020_12_22_1/DANNCE/predict03/save_data_AVG.mat",
)

# Pickle outputs written by the run cell.
fit_path = "floating_fit.p"
transform_path = "floating_transform.p"

# Per-run overrides of entries in utils.params.
utils.params['Q_TOL'] = 1e-05
utils.params['M_TOL'] = 1e-05
utils.params['n_fit_frames'] = 500  # length of the random clip passed to fit()
utils.params['N_ITERS'] = 1

# When True, the run cell stops after fit() and does not run transform().
skip_transform = True

### Set up MuJoCo model

In [5]:
# Build the MuJoCo model from the rodent XML, then override its solver options.
# NOTE(review): the run cell builds its own `mj_model` via `set_body_sites(root)`
# and never references `model`; confirm these options are meant to reach it.
# NOTE(review): assigning `disableflags` replaces any disable flags set in the
# XML rather than adding EULERDAMP to them.
model = mujoco.MjModel.from_xml_path(rat_xml)
solver_overrides = {
    "solver": mujoco.mjtSolver.mjSOL_NEWTON,
    "disableflags": mujoco.mjtDisableBit.mjDSBL_EULERDAMP,
    "iterations": 1,
    "ls_iterations": 4,
}
for opt_name, opt_value in solver_overrides.items():
    setattr(model.opt, opt_name, opt_value)

### Run STAC: `fit()` on a random clip, then `transform()` unless `skip_transform` is set

In [None]:
# Run STAC end to end: load keypoints, fit() on a random clip and save it, then
# (unless skip_transform) transform() the chunked keypoints and save that too.
start_time = time.time()

root = mjcf.from_path(rat_xml)

# Default ordering of mj sites is alphabetical, so we reorder to match:
# argsort returns the indices that would sort kp_names.
kp_names = utils.loadmat(rat23)["joint_names"]
utils.params["kp_names"] = kp_names
stac_keypoint_order = np.argsort(kp_names)

# Load the keypoint predictions; /1000 rescales them
# (presumably mm -> m — TODO confirm the units of "pred").
kp_data = utils.loadmat(data_path)["pred"][:] / 1000
kp_data = prep_kp_data(kp_data, stac_keypoint_order)

# Setup for fit
physics, mj_model = set_body_sites(root)
part_opt_setup(physics)

# --- fit() on a random clip of n_fit_frames frames ---
print(f"kp_data shape: {kp_data.shape}")
print(f"Running fit() on {utils.params['n_fit_frames']}")
clip = get_clip(kp_data, utils.params['n_fit_frames'])
print(f"clip shape: {clip.shape}")
# fit()'s keypoint output covers only the clip; keep it under its own name so
# `kp_data` still holds the full recording for transform() below.
mjx_model, q, x, walker_body_sites, fit_kp_data = fit(mj_model, clip)

fit_data = package_data(
    mjx_model, physics, q, x, walker_body_sites, fit_kp_data
)

print(f"saving data to {fit_path}")
save(fit_data, fit_path)

# --- transform() using the fitted offsets ---
# A bare top-level `return` is a SyntaxError in a notebook cell, so the skip is
# expressed as a branch.
if skip_transform:
    print("skipping transform()")
else:
    print("Running transform()")
    # Reload the fit from disk so transform() consumes exactly what was saved.
    with open(fit_path, "rb") as file:
        fit_data = pickle.load(file)

    offsets = fit_data["offsets"]
    chunked_kp_data, n_envs = chunk_kp_data(kp_data)
    mjx_model, q, x, walker_body_sites, transform_kp_data = transform(
        mj_model, chunked_kp_data, offsets
    )

    transform_data = package_data(
        mjx_model, physics, q, x, walker_body_sites, transform_kp_data, batched=True
    )

    print(f"saving data to {transform_path}")
    save(transform_data, transform_path)

print(f"Job complete in {time.time()-start_time}")


kp_data shape: (360000, 69)
Running fit() on 500
clip shape: (500, 69)
Root Optimization:
data.qpos: [2.78957345e-02 1.85267244e-07 6.12806154e-02 1.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00