In [1]:
import utils
import mujoco
import os
import pickle
from scipy.io import savemat 
from dm_control import mjcf
import numpy as np
import jax
from jax import numpy as jnp
import time
from controller import *
import stac_base

# jax.disable_jit(disable=True)

<contextlib._GeneratorContextManager at 0x140d46240>

In [2]:
# If your machine is low on ram:
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.6'
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

def get_clip(kp_data, n_frames):
    import random
    max_index = kp_data.shape[0] - n_frames + 1
    rand_start = random.randint(0, max_index)
    return kp_data[rand_start:rand_start+n_frames,:]

In [3]:
param_path = "../params/params.yaml"
utils.init_params(param_path)

In [4]:
rat_xml = "../models/rodent_stac.xml"
rat23 = "../models/rat23.mat"
data_path = "../save_data_AVG.mat"
# data_path = "/n/holylabs/LABS/olveczky_lab/holylfs02/Everyone/dannce_rig/dannce_ephys/art/2020_12_22_1/DANNCE/predict03/save_data_AVG.mat" 

fit_path = "floating_fit.p"
transform_path = "floating_transform.p"
utils.params['FTOL'] = 1e-05
utils.params['n_fit_frames'] = 500
utils.params['N_ITERS'] = 1
skip_transform = True

In [5]:
model = mujoco.MjModel.from_xml_path(rat_xml)
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
model.opt.disableflags = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
model.opt.iterations = 1
model.opt.ls_iterations = 4

In [6]:
start_time = time.time()

root = mjcf.from_path(rat_xml)

# Default ordering of mj sites is alphabetical, so we reorder to match
kp_names = utils.loadmat(rat23)["joint_names"]
utils.params["kp_names"] = kp_names

# argsort returns the indices that would sort the array
stac_keypoint_order = np.argsort(kp_names)
# Load kp_data, /1000 to scale data (from mm to meters i think?)
kp_data = utils.loadmat(data_path)["pred"][:] / 1000

kp_data = prep_kp_data(kp_data, stac_keypoint_order)

# setup for fit
physics, mj_model = set_body_sites(root)
part_opt_setup(physics)

# Run root optimization
# Create mjx model and data
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.make_data(mjx_model)

# Get and set the offsets of the markers
offsets = jnp.copy(stac_base.get_site_pos(mjx_model))
offsets *= utils.params['SCALE_FACTOR']

# print(mjx_model.site_pos, mjx_model.site_pos.shape)
mjx_model = stac_base.set_site_pos(mjx_model, offsets)

# forward is used to calculate xpos and such
mjx_data = mjx.kinematics(mjx_model, mjx_data)
mjx_data = mjx.com_pos(mjx_model, mjx_data)
mjx_data = root_optimization(mjx_model, mjx_data, kp_data)

Root Optimization:


TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

save in the same format and use viz.py

In [None]:
kp_data.shape

(360000, 69)

In [None]:
offset_path = "root.p"
data = {
        "kp_data": kp_data,
        "qpos": [mjx_data.qpos[:]],
        "offsets": offsets,
        # "walker_body_sites": [stac_base.get_site_xpos(mjx_data)],
        "xpos": [mjx_data.xpos[:]],
        "names_qpos": initialize_part_names(physics) # utils.params["part_names"],
        # "names_xpos": physics.named.data.xpos.axes.row.names,
    }
if os.path.dirname(offset_path) != "":
    os.makedirs(os.path.dirname(offset_path), exist_ok=True)
with open(offset_path, "wb") as output_file:
    pickle.dump(data, output_file, protocol=2)

In [None]:
import viz
from IPython.display import Video

save_path = "../videos/root.mp4"
viz.render_mujoco(
    "../params/params.yaml",
    offset_path,
    save_path,
    frames = np.arange(1),
)
Video(save_path)

(1, 74)
walker: <walkers.Rat object at 0x17ea7f230>
arena: <arenas.DannceArena object at 0x3b46a12b0>
<dm_control.composer.environment.Environment object at 0x3ac21e750>


  from pkg_resources import resource_filename


In [None]:

scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = 1
scene_option.sitegroup[2] = 1
scene_option.sitegroup[3] = 1
scene_option.flags[enums.mjtVisFlag.mjVIS_TRANSPARENT] = True
scene_option.flags[enums.mjtVisFlag.mjVIS_LIGHT] = False
scene_option.flags[enums.mjtVisFlag.mjVIS_CONVEXHULL] = True
# render the frame (frame 0)
renderer = mujoco.Renderer(model)
mj_data = mjx.get_data(mj_model, mjx_data)
renderer.update_scene(mj_data, scene_option=scene_option)
pixels = renderer.render()

NameError: name 'enums' is not defined