In [1]:
import utils
import mujoco
import os
import pickle
from scipy.io import savemat 
from dm_control import mjcf
import numpy as np
import jax
from jax import numpy as jnp


In [3]:
# Disable XLA's up-front GPU memory preallocation so memory is allocated as needed.
# NOTE(review): XLA reads this when the JAX backend is first initialized, so it only
# takes effect if no JAX computation has run yet in this kernel (an earlier import
# such as `utils` could trigger one) — consider setting it before `import jax`.
# If your GPU is low on memory, you can instead cap the preallocated fraction:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="false")

In [4]:
def save(fit_data, save_path):
    """Save fit data to disk, choosing the format from the file extension.

    Missing parent directories of ``save_path`` are created before writing.

    Args:
        fit_data: Data to save. Pickled as-is for ``.p`` files; for ``.mat``
            files it is passed to ``scipy.io.savemat``, which expects a dict
            mapping variable names to values.
        save_path (Text): Destination path. Must end in ``.p`` (pickle) or
            ``.mat`` (MATLAB).

    Raises:
        ValueError: If ``save_path`` has any other extension.
    """
    _, file_extension = os.path.splitext(save_path)
    # Validate before touching the filesystem so a bad path doesn't leave
    # empty directories behind or silently write nothing.
    if file_extension not in (".p", ".mat"):
        raise ValueError(
            f"Unsupported file extension {file_extension!r} for {save_path!r}; "
            "expected '.p' or '.mat'."
        )

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if file_extension == ".p":
        with open(save_path, "wb") as output_file:
            # NOTE(review): protocol=2 presumably keeps the file readable by
            # older (Python 2) tooling — confirm before changing.
            pickle.dump(fit_data, output_file, protocol=2)
    else:
        savemat(save_path, fit_data)

In [9]:
from controller import *
import stac_base

In [6]:
# Relative paths don't resolve from this notebook's working directory yet, so
# absolute paths are used. Override the defaults with environment variables
# instead of editing the code.
STAC_MJX_DIR = os.environ.get("STAC_MJX_DIR", "/home/charles/github/stac-mjx")
utils.init_params(os.path.join(STAC_MJX_DIR, "params", "params.yaml"))
ratpath = os.path.join(STAC_MJX_DIR, "models", "rodent.xml")
rat23path = os.path.join(STAC_MJX_DIR, "models", "rat23.mat")

model = mujoco.MjModel.from_xml_path(ratpath)
# Newton solver limited to a single solver iteration and a single line-search iteration.
model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
# NOTE(review): plain assignment replaces any disable flags set in the XML rather
# than adding to them — confirm that is intended.
model.opt.disableflags = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
model.opt.iterations = 1
model.opt.ls_iterations = 1

# The keypoint data file must be downloaded separately; point STAC_DATA_PATH at it.
data_path = os.environ.get("STAC_DATA_PATH", "/home/charles/Desktop/save_data_AVG.mat")
# Destination for the fit result saved later in the notebook.
offset_path = "offset.p"

root = mjcf.from_path(ratpath)

# MuJoCo sites are ordered alphabetically by default, so keypoints are reordered to match.
kp_names = utils.loadmat(rat23path)["joint_names"]
utils.params["kp_names"] = kp_names

# argsort gives the permutation that sorts kp_names alphabetically.
stac_keypoint_order = np.argsort(kp_names)

# Load keypoint predictions.
# NOTE(review): dividing by 1000 presumably converts mm -> m; confirm the units.
kp_data = utils.loadmat(data_path)["pred"][:] / 1000


In [7]:
# Prepare the keypoint data with the alphabetical site order computed above, then
# split it into chunks so it can be passed into vmapped functions.
# NOTE(review): this cell overwrites `kp_data` with its own output, so running it
# twice re-processes already-prepared data — re-run the loading cell first.
# TODO: store the kp_data used in the fit in another variable (a small slice of kp_data).
kp_data, n_envs = chunk_kp_data(prep_kp_data(kp_data, stac_keypoint_order))

In [8]:
# Number of leading chunks used for the fit (a small slice of kp_data).
N_FIT_CHUNKS = 100
fit_kp_data = kp_data[:N_FIT_CHUNKS]
# Recorded output for this cell: (100, 150, 69).
fit_kp_data.shape

(100, 150, 69)

In [8]:
# Run the fit on the selected chunks, then write the result to offset_path.
fit_data = test_opt(root, fit_kp_data)
save(fit_data, save_path=offset_path)

first forward step done in 23.4201557636261
Root Optimization:
useless forward step done  in 47.260462522506714
next useless forward step done in 0.007994890213012695
begin optimizing:


Test a single optimization by looping over a jitted `solver.update()` call and plotting the error at each iteration. This also helps tune the hyperparameters.

In [None]:
# Benchmark jaxopt L-BFGS line searches on stac_base.q_loss.

import time
import jaxopt

import numpy as onp

import matplotlib.pyplot as plt

maxiters = 25


def benchmark_jaxopt(linesearch, fun=None, init=None, n_iters=None, **fun_kwargs):
    """Run L-BFGS one update at a time and record the solver error per iteration.

    Args:
        linesearch: jaxopt.LBFGS line search name ("backtracking", "zoom" or
            "hager-zhang").
        fun: Objective to minimize. Defaults to ``stac_base.q_loss``.
        init: Initial parameters. Defaults to ``jnp.zeros((X.shape[1], 5))``
            built from a global ``X``.
        n_iters: Number of solver updates to run. Defaults to ``maxiters``.
        **fun_kwargs: Keyword arguments forwarded to ``fun`` on every call.
            Defaults to ``data=data`` using a global ``data``.

    Returns:
        onp.ndarray of length ``n_iters`` holding ``state.error`` after each update.
    """
    if fun is None:
        fun = stac_base.q_loss
    if init is None:
        # NOTE(review): this default (and the `data` kwarg below) looks inherited
        # from jaxopt's logistic-regression benchmark; neither `X` nor `data` is
        # defined in the visible notebook.
        # TODO: pass the real initial q and q_loss arguments explicitly.
        init = jnp.zeros((X.shape[1], 5))
    if not fun_kwargs:
        fun_kwargs = {"data": data}
    if n_iters is None:
        n_iters = maxiters

    # jaxopt.LBFGS is the solver that accepts a `linesearch` choice; LBFGSB has no
    # such option and its init_state/update additionally require `bounds`.
    solver = jaxopt.LBFGS(fun=fun, linesearch=linesearch)
    state = solver.init_state(init, **fun_kwargs)
    errors = onp.zeros(n_iters)
    params = init

    # Jit once, outside the loop, so every iteration reuses the compiled update.
    jit_update = jax.jit(solver.update)
    for it in range(n_iters):
        params, state = jit_update(params, state, **fun_kwargs)
        errors[it] = state.error

    return errors


errors_backtracking = benchmark_jaxopt("backtracking")
errors_zoom = benchmark_jaxopt("zoom")

# Size the x-axis from the recorded errors so it always matches their length.
fig, ax = plt.subplots()
ax.plot(onp.arange(len(errors_backtracking)), errors_backtracking, label="backtracking")
ax.plot(onp.arange(len(errors_zoom)), errors_zoom, label="zoom")
ax.set_xlabel("Iterations")
ax.set_ylabel("Gradient error")
ax.set_yscale("log")
ax.set_title("L-BFGS line search comparison")
ax.legend(loc="best")
plt.show()
