In [None]:
import numpy as np
from mjx_planner import cem_planner
import mujoco.mjx as mjx 
import mujoco
import time
import jax.numpy as jnp
import jax
import os
import matplotlib.pyplot as plt

# --- Planner configuration -------------------------------------------------
# Problem size for the CEM planner: joints, sampled trajectories per CEM
# iteration, steps per trajectory, and number of CEM iterations.
num_dof = 6
num_batch = 500
num_steps = 50
maxiter_cem = 10

# Construct the planner and report how long construction took (the constructor
# itself prints its configuration, as seen in this cell's output).
start_time = time.time()
cem = cem_planner(
    num_dof=num_dof,
    num_batch=num_batch,
    num_steps=num_steps,
    maxiter_cem=maxiter_cem,
    w_pos=5,          # weight passed through as `w_pos` (presumably the position-cost weight — confirm)
    num_elite=0.02,   # presumably the elite fraction of num_batch — confirm in cem_planner
)
elapsed = time.time() - start_time
print(f"Total time: {round(elapsed, 2)}s")

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))



 Default backend: gpu
 Model path: /home/hurova/thesis_2025/mjx_planner_v2/ur5e_hande_mjx/scene.xml 
 Timestep: 0.02 
 Target position: [-0.3  0.   0.8] 
 CEM Iter: 10 
 Number of batches: 500 
 Number of steps per trajectory: 50 
 Time per trajectory: 1.0
Total time: 10.97s


In [None]:
# MPC rollout on the MJX (JAX) copy of the model: at every step solve CEM from
# the current state, apply the velocity at index 1 of the best plan, step the
# simulation, and record joint positions, commanded velocities and the
# end-effector distance to the target.
theta_list = list()
thetadot_list = list()
costs = list()
mjx_model = cem.mjx_model
mjx_data = cem.mjx_data

# Mean of the CEM sampling distribution, zero-initialised exactly like the
# executed CPU-MuJoCo MPC cell.
xi_mean = jnp.zeros(cem.nvar)

thetadot = np.zeros(num_dof)

theta_list.append(mjx_data.qpos[:num_dof])
thetadot_list.append(thetadot)
costs.append(np.linalg.norm(mjx_data.xpos[cem.hande_id] - cem.target_pos))

mpc_steps = 100
for step in range(mpc_steps):
    # Same (xi_mean, qpos, qvel) -> 5-tuple signature as the executed CPU loop;
    # the old single-argument, 4-value call no longer matches compute_cem.
    cost, best_cost_g, best_vels, best_traj, _ = cem.compute_cem(
        xi_mean, mjx_data.qpos[:num_dof], mjx_data.qvel[:num_dof]
    )

    thetadot = best_vels[1]

    # MJX data is immutable: build a new qvel and a new data object.
    qvel = mjx_data.qvel.at[:num_dof].set(thetadot)
    mjx_data = mjx_data.replace(qvel=qvel)
    mjx_data = cem.jit_step(mjx_model, mjx_data)

    # NOTE(review): as with mujoco.mj_step, position-dependent fields such as
    # xpos may still reflect the pre-step configuration here — confirm before
    # relying on `costs` for the final step.
    eef_pos = mjx_data.xpos[cem.hande_id]
    cost_ = np.linalg.norm(eef_pos - cem.target_pos)
    theta_list.append(mjx_data.qpos[:num_dof])
    thetadot_list.append(thetadot)
    costs.append(cost_)




In [None]:
# One CEM solve from the current CPU MuJoCo state with a zero-mean sampling
# distribution; the results are visualised by the diagnostic plot cell.
# Uses the (xi_mean, qpos, qvel) -> 5-tuple signature exercised by the executed
# MPC cell; the former zero-argument, 4-value call does not match it.
cost, best_cost_g, best_vels, best_traj, _ = cem.compute_cem(
    jnp.zeros(cem.nvar), cem.data.qpos[:num_dof], cem.data.qvel[:num_dof]
)

In [2]:
# MPC rollout on the CPU MuJoCo model: at every step solve CEM from the current
# joint state, apply the velocity at index 1 of the best plan, step the
# simulation, and record joint positions, commanded velocities and the
# end-effector distance to the target. Per-iteration wall time is printed.
theta_list = list()
thetadot_list = list()
costs = list()

model = cem.model
data = cem.data
# NOTE(review): xi_mean is never updated (compute_cem's fifth return value is
# discarded), so every solve starts from a zero mean — confirm whether
# warm-starting from the previous solution was intended.
xi_mean = jnp.zeros(cem.nvar)

thetadot = np.zeros(num_dof)

# Make xpos consistent with the current qpos of *this* data object before
# measuring the initial cost (previously read from cem.mjx_data instead).
mujoco.mj_kinematics(model, data)

# qpos slices are views into MuJoCo's buffer; copy to snapshot the state.
theta_list.append(data.qpos[:num_dof].copy())
thetadot_list.append(thetadot)
costs.append(np.linalg.norm(data.xpos[cem.hande_id] - cem.target_pos))

mpc_steps = 100
for step in range(mpc_steps):
    start_time = time.time()
    cost, best_cost_g, best_vels, best_traj, _ = cem.compute_cem(
        xi_mean, data.qpos[:num_dof], data.qvel[:num_dof]
    )

    thetadot = best_vels[1]

    data.qvel[:num_dof] = thetadot
    mujoco.mj_step(model, data)
    # mj_step integrates qpos/qvel but leaves xpos from the kinematics computed
    # at the start of the step; refresh so the cost reflects the new state.
    mujoco.mj_kinematics(model, data)

    eef_pos = data.xpos[cem.hande_id]
    cost_ = np.linalg.norm(eef_pos - cem.target_pos)
    theta_list.append(data.qpos[:num_dof].copy())
    thetadot_list.append(thetadot)
    costs.append(cost_)
    print(f'Iter #{step}: {round(time.time()-start_time, 2)}s')


Iter #0: 23.61s
Iter #1: 0.36s
Iter #2: 0.36s
Iter #3: 0.36s
Iter #4: 0.36s
Iter #5: 0.36s
Iter #6: 0.36s
Iter #7: 0.35s
Iter #8: 0.35s
Iter #9: 0.35s
Iter #10: 0.35s
Iter #11: 0.35s
Iter #12: 0.36s
Iter #13: 0.36s
Iter #14: 0.36s
Iter #15: 0.36s
Iter #16: 0.36s
Iter #17: 0.35s
Iter #18: 0.36s
Iter #19: 0.36s
Iter #20: 0.36s
Iter #21: 0.36s
Iter #22: 0.36s
Iter #23: 0.36s
Iter #24: 0.36s
Iter #25: 0.36s
Iter #26: 0.36s
Iter #27: 0.36s
Iter #28: 0.36s
Iter #29: 0.36s
Iter #30: 0.36s
Iter #31: 0.36s
Iter #32: 0.36s
Iter #33: 0.36s
Iter #34: 0.36s
Iter #35: 0.36s
Iter #36: 0.36s
Iter #37: 0.36s
Iter #38: 0.36s
Iter #39: 0.36s
Iter #40: 0.36s
Iter #41: 0.36s
Iter #42: 0.36s
Iter #43: 0.36s
Iter #44: 0.36s
Iter #45: 0.36s
Iter #46: 0.36s
Iter #47: 0.36s
Iter #48: 0.36s
Iter #49: 0.36s
Iter #50: 0.36s
Iter #51: 0.36s
Iter #52: 0.36s
Iter #53: 0.36s
Iter #54: 0.36s
Iter #55: 0.36s
Iter #56: 0.36s
Iter #57: 0.36s
Iter #58: 0.36s
Iter #59: 0.36s
Iter #60: 0.36s
Iter #61: 0.36s
Iter #62: 0.36s
I

In [3]:
# Persist the commanded joint velocities from the MPC rollout.
# np.savetxt does not create missing directories, so ensure data/ exists first.
os.makedirs('data', exist_ok=True)
np.savetxt('data/best_vels.csv', thetadot_list, delimiter=",")

In [None]:

# Diagnostic plots for the most recent CEM solve:
# cost, best_cost_g, best_vels and best_traj as returned by compute_cem.
# Each plot gets its own figure (instead of re-using figure 0) so plots can
# never draw on top of each other, and the joint legend follows num_dof.
joint_labels = [f'joint {i + 1}' for i in range(num_dof)]

fig, ax = plt.subplots()
ax.plot(cost)
ax.set(title="Output Costs", xlabel="Iteration", ylabel="Cost")
plt.show()

fig, ax = plt.subplots()
ax.plot(best_cost_g)
ax.set(title="Best Cost Goal", xlabel="Step", ylabel="Cost")
plt.show()

fig, ax = plt.subplots()
ax.plot(best_vels)
ax.set(title="Best Velocities", xlabel="Step", ylabel="Velocity")
ax.legend(joint_labels, loc='upper left')
plt.show()

fig, ax = plt.subplots()
ax.plot(best_traj)
ax.set(title="Best Trajectory", xlabel="Step", ylabel="Joint States")
ax.legend(joint_labels, loc='upper left')
plt.show()


In [None]:
# Plots of the closed-loop MPC rollout: commanded joint velocities, resulting
# joint positions, and end-effector distance to the target at every step.
# Joint legend follows num_dof instead of a hard-coded six entries.
joint_labels = [f'joint {i + 1}' for i in range(num_dof)]

fig, ax = plt.subplots()
ax.plot(thetadot_list)
ax.set(title="Best Velocities", xlabel="Step", ylabel="Velocity")
ax.legend(joint_labels, loc='upper left')
plt.show()

fig, ax = plt.subplots()
ax.plot(theta_list)
ax.set(title="Best Trajectory", xlabel="Step", ylabel="Joint States")
ax.legend(joint_labels, loc='upper left')
plt.show()

fig, ax = plt.subplots()
ax.plot(costs)
ax.set(title="Best Cost", xlabel="Step", ylabel="Cost")
plt.show()


In [None]:
# Persist the MPC rollout: joint positions and commanded joint velocities.
# np.savetxt does not create missing directories, so ensure data/ exists first.
os.makedirs('data', exist_ok=True)
np.savetxt('data/best_traj.csv', theta_list, delimiter=",")
np.savetxt('data/best_vels.csv', thetadot_list, delimiter=",")

In [None]:
# Closed-loop rollout on the CPU MuJoCo model that records joint positions
# (`pos`), commanded joint velocities (`vel`) and end-effector distance to the
# target (`costs`) for mpc_steps - 1 steps.
# NOTE(review): cem.data is not reset, so this continues from whatever state a
# previous rollout cell left it in — confirm this is intended.
pos = list()
vel = list()
costs = list()

model = cem.model
data = cem.data
xi_mean = jnp.zeros(cem.nvar)

# data.qpos / data.qvel slices are views into MuJoCo's buffers; without .copy()
# every recorded entry would alias the live state and end up identical.
mujoco.mj_kinematics(model, data)
pos.append(data.qpos[:num_dof].copy())
vel.append(data.qvel[:num_dof].copy())
costs.append(np.linalg.norm(data.xpos[cem.hande_id] - cem.target_pos))

for step in range(mpc_steps - 1):
    # Same (xi_mean, qpos, qvel) -> 5-tuple signature as the executed MPC cell;
    # the previous (qpos, rec) call returned a tuple that was assigned to qvel.
    _, _, best_vels, _, _ = cem.compute_cem(
        xi_mean, data.qpos[:num_dof], data.qvel[:num_dof]
    )
    rec = np.asarray(best_vels[1])

    data.qvel[:num_dof] = rec
    mujoco.mj_step(model, data)
    # Refresh xpos for the post-step configuration before measuring the cost.
    mujoco.mj_kinematics(model, data)

    pos.append(data.qpos[:num_dof].copy())
    vel.append(rec)
    costs.append(np.linalg.norm(data.xpos[cem.hande_id] - cem.target_pos))


In [None]:
# Sample one batch of CEM parameters and push it through the projection filter
# to inspect the filtered trajectories. Every sample starts from the same fixed
# joint configuration and from the velocity recorded at vel[3].
home_qpos = [1.5, -1.8, 1.75, -1.25, -1.6, 0]
batch_shape = (num_batch, num_dof)

theta_init = np.tile(home_qpos, (num_batch, 1))
thetadot_init = np.tile(vel[3], (num_batch, 1))
thetaddot_init = np.zeros(batch_shape)
thetadot_fin = np.zeros(batch_shape)
thetaddot_fin = np.zeros(batch_shape)

# Boundary conditions stacked column-wise in the order
# theta_init, thetadot_init, thetaddot_init, thetadot_fin, thetaddot_fin.
state_term = jnp.asarray(
    np.concatenate(
        [theta_init, thetadot_init, thetaddot_init, thetadot_fin, thetaddot_fin],
        axis=1,
    )
)

# Projection-filter settings (bounds presumably rad/s, rad/s^2, rad — confirm).
maxiter_projection = 20
v_max = 0.8
a_max = 1.8
p_max = 180*np.pi/180

# Sampling distribution: zero mean, isotropic covariance of 5.
xi_mean = jnp.zeros(cem.nvar)
xi_cov = 5*jnp.identity(cem.nvar)

key, subkey = jax.random.split(cem.key)

# xi_samples: one row per batch sample, cem.nvar (= num_dof * nvar_single) columns.
xi_samples, key = cem.compute_xi_samples(key, xi_mean, xi_cov)
xi_filtered = cem.compute_projection_filter(
    xi_samples, state_term, maxiter_projection, v_max, a_max, p_max
)

# Map filtered parameters to per-sample joint positions and velocities.
theta_batch = jnp.dot(cem.A_theta, xi_filtered.T).T
thetadot = jnp.dot(cem.A_thetadot, xi_filtered.T).T

In [None]:
# Split each flattened velocity row into per-joint blocks:
# (rows, num_dof, cols // num_dof). Assumes each row is laid out joint-major
# (all steps of joint 1, then joint 2, ...) — TODO confirm against A_thetadot.
n_rows, n_cols = thetadot.shape
thetadot_ = thetadot.reshape(n_rows, num_dof, n_cols // num_dof)


In [None]:
# Velocity profiles of every projection-filtered sample (not just the best one).
# One colour per joint across all samples, so the legend identifies joints for
# the whole batch rather than only the first sample's lines.
fig, ax = plt.subplots()
for joint in range(num_dof):
    color = f'C{joint}'
    # Columns of the transposed slice are individual samples.
    ax.plot(np.asarray(thetadot_[:, joint, :]).T, color=color, linewidth=0.5, alpha=0.3)
    # Empty proxy line carries the legend entry at full opacity.
    ax.plot([], [], color=color, label=f'joint {joint + 1}')
ax.set(title="Filtered Sample Velocities", xlabel="Step", ylabel="Velocity")
ax.legend(loc='upper left')
plt.show()