In [None]:

import jax.numpy as jnp
import matplotlib.pyplot as plt
import mujoco.mjx as mjx 
import mujoco
import jax
import os
import numpy as np
from functools import partial
import time
import mediapy as media


In [None]:
# Rollout settings shared by the cells below.
timestep = 0.05  # simulation step size, written to model.opt.timestep
num_steps = 50   # samples per joint in every trajectory
num_batch = 500  # batch size handed to Simulator
num_dof = 6      # joints per trajectory

In [None]:
# Load the sampled joint-velocity trajectories: one CSV row per trajectory,
# holding num_dof * num_steps values laid out joint by joint.
trajectories_file_path = f"{os.getcwd()}/samples/trajectories.csv"
# atleast_2d keeps a single-row file as a (1, N) batch instead of a flat vector.
trajectories = np.atleast_2d(np.genfromtxt(trajectories_file_path, delimiter=','))
assert trajectories.shape[1] == num_dof * num_steps, (
    f"expected {num_dof * num_steps} values per trajectory, got {trajectories.shape[1]}"
)

# Draw every trajectory of one joint in a single call with a fixed colour.
# Plotting trajectory by trajectory lets the colour cycle drift, so the legend
# colours would only match the first trajectory.
fig, ax = plt.subplots()
per_joint = trajectories.reshape(-1, num_dof, num_steps)
for joint in range(num_dof):
    joint_lines = ax.plot(per_joint[:, joint, :].T, color=f"C{joint}")
    # Label a single line per joint so the legend has exactly num_dof entries.
    joint_lines[0].set_label(f"joint {joint + 1}")
ax.set_title("Trajectories")
ax.set_xlabel("Step")
ax.set_ylabel("Velocity")
ax.legend(loc='upper left')

plt.show()

In [None]:
# Load the scene and apply the notebook-wide simulation step size.
model_path = f"{os.getcwd()}/ur5e_hande_mjx/scene.xml"
model = mujoco.MjModel.from_xml_path(model_path)
model.opt.timestep = timestep

# Initial joint configuration. data.qpos is a host-side NumPy buffer, so use a
# NumPy vector: a jnp array would be created on the device (float32 unless x64
# is enabled) only to be copied straight back to the host.
home_qpos = np.array([1.5, -1.8, 1.75, -1.25, -1.6, 0.0])
data = mujoco.MjData(model)
data.qpos[:home_qpos.size] = home_qpos

# Recompute derived quantities for the new qpos before rendering or uploading.
mujoco.mj_forward(model, data)
renderer = mujoco.Renderer(model)

# Camera looking at the world origin from a distance of 3.
camera = mujoco.MjvCamera()
camera.lookat[:] = [0.0, 0.0, 0.0]
camera.distance = 3.0

# Hide every site group, then re-enable group 1 only.
scene_option = mujoco.MjvOption()
scene_option.sitegroup[:] = False
scene_option.sitegroup[1] = True

In [None]:
# Render the current state once and show it as a still image without axes.
renderer.update_scene(data, camera=camera, scene_option=scene_option)
pixels = renderer.render()

fig, ax = plt.subplots()
ax.imshow(pixels)
ax.set_axis_off()
plt.show()

In [None]:
class Simulator:
    """Batched MJX rollout and scoring of joint-velocity trajectories.

    Each candidate trajectory is a flat vector of ``num_dof * num_steps``
    joint velocities stored joint by joint. A trajectory is rolled out from
    the state captured in ``data`` at construction time and scored by how far
    the ``tcp`` site position and the ``hande`` body orientation are from the
    ``target`` body of the model.

    The jitted methods mark ``self`` as a static argument, so instance
    attributes are baked into the compiled code on first call; changing them
    afterwards does not affect functions that are already compiled.
    """

    def __init__(self, model, data, num_dof=6, num_batch=100, num_steps=200):
        """Capture the target pose, body/site ids and device copies of the model.

        Args:
            model: ``mujoco.MjModel`` containing a body named ``target``, a
                body named ``hande`` and a site named ``tcp``.
            data: ``mujoco.MjData`` whose current state is the start of every
                rollout.
            num_dof: number of leading ``qvel`` entries driven by a trajectory.
            num_batch: stored on the instance; not used by the methods here.
            num_steps: number of simulation steps per trajectory (kept in
                ``self.num``).
        """
        self.num_dof = num_dof
        self.num_batch = num_batch
        self.num = num_steps

        target = model.body(name="target")
        self.target_pos = target.pos
        self.target_rot = target.quat

        self.hande_id = model.body(name="hande").id
        self.tcp_id = model.site(name="tcp").id

        self.mjx_model = mjx.put_model(model)
        self.mjx_data = mjx.put_data(model, data)
        self.jit_step = jax.jit(mjx.step)

        # Batched versions map over the leading (trajectory) axis of every argument.
        self.compute_rollout_batch = jax.vmap(self.compute_rollout_single, in_axes=0)
        self.compute_cost_batch = jax.vmap(self.compute_cost_single, in_axes=0)

    @partial(jax.jit, static_argnums=(0,))
    def compute_cost_single(self, eef_pos, eef_rot):
        """Return the scalar cost of one rollout.

        Args:
            eef_pos: per-step ``tcp`` positions, one row per step.
            eef_rot: per-step ``hande`` quaternions, one row per step.

        Returns:
            ``1 * cost_g + 0.5 * cost_r`` where ``cost_g`` is the norm of the
            position error taken over the whole trajectory and ``cost_r`` is
            the sum over steps of ``2 * arccos(|<q, q_target>|)``.
        """
        cost_g = jnp.linalg.norm(eef_pos - self.target_pos)

        # Normalise both quaternions; |dot| treats q and -q as the same rotation.
        eef_unit = eef_rot / jnp.linalg.norm(eef_rot, axis=1, keepdims=True)
        target_unit = self.target_rot / jnp.linalg.norm(self.target_rot)
        dot_product = jnp.abs(jnp.dot(eef_unit, target_unit))
        # Guard arccos against values pushed slightly past 1 by rounding.
        dot_product = jnp.clip(dot_product, -1.0, 1.0)
        cost_r = jnp.sum(2 * jnp.arccos(dot_product))

        return 1 * cost_g + 0.5 * cost_r

    @partial(jax.jit, static_argnums=(0,))
    def mjx_step(self, mjx_data, thetadot_single):
        """Scan body: impose one step of joint velocities and advance the simulation.

        Args:
            mjx_data: simulation state carried through the scan.
            thetadot_single: velocities written to the first ``num_dof``
                entries of ``qvel`` before stepping.

        Returns:
            ``(new_state, (theta, eef_pos, eef_rot))`` with the first
            ``num_dof`` joint positions, the ``tcp`` site position and the
            ``hande`` body quaternion after the step.
        """
        qvel = mjx_data.qvel.at[:self.num_dof].set(thetadot_single)
        mjx_data = self.jit_step(self.mjx_model, mjx_data.replace(qvel=qvel))

        theta = mjx_data.qpos[:self.num_dof]
        eef_rot = mjx_data.xquat[self.hande_id]
        eef_pos = mjx_data.site_xpos[self.tcp_id]

        return mjx_data, (theta, eef_pos, eef_rot)

    @partial(jax.jit, static_argnums=(0,))
    def compute_rollout_single(self, thetadot):
        """Roll out one flat trajectory from the initial state.

        Args:
            thetadot: flat vector of ``num_dof * num`` velocities, joint by joint.

        Returns:
            ``(theta, eef_pos, eef_rot)``: joint positions flattened joint by
            joint, plus the per-step ``tcp`` positions and ``hande`` quaternions.
        """
        # Scan iterates over the leading axis, so put the step axis first.
        per_step = thetadot.reshape(self.num_dof, self.num).T
        _, out = jax.lax.scan(self.mjx_step, self.mjx_data, per_step, length=self.num)
        theta, eef_pos, eef_rot = out
        return theta.T.flatten(), eef_pos, eef_rot

    @partial(jax.jit, static_argnums=(0,))
    def get_best_traj(self, thetadot):
        """Roll out a batch of trajectories and return the cheapest one.

        Args:
            thetadot: batch of flat trajectories, one per row.

        Returns:
            ``(cost, best_traj)``: the minimum cost over the batch and the
            matching trajectory reshaped to one row per step and one column
            per joint.
        """
        _, eef_pos, eef_rot = self.compute_rollout_batch(thetadot)
        cost_batch = self.compute_cost_batch(eef_pos, eef_rot)

        # argmin over the whole cost vector: indexing a single element first
        # would make argmin see a scalar and always return 0.
        idx_min = jnp.argmin(cost_batch)
        cost = cost_batch[idx_min]
        best_traj = thetadot[idx_min].reshape((self.num_dof, self.num)).T
        return cost, best_traj
		
    

In [None]:
# Build the simulator around the loaded scene and its current state.
sim = Simulator(
    model=model,
    data=data,
    num_dof=num_dof,
    num_batch=num_batch,
    num_steps=num_steps,
)

In [None]:
# Validate the batch before the jitted rollout: a wrong row length only shows
# up as a reshape error during tracing, and NaNs (which np.genfromtxt emits for
# unparsable fields) would silently corrupt the cost ranking.
assert trajectories.ndim == 2 and trajectories.shape[1] == num_dof * num_steps, (
    f"trajectories must have shape (batch, {num_dof * num_steps}), got {trajectories.shape}"
)
assert not np.isnan(trajectories).any(), "trajectories contain NaN values"

# Roll out every candidate and keep the cheapest one.
cost, best_traj = sim.get_best_traj(trajectories)

In [None]:
# Report the selected cost and plot the selected trajectory, one line per joint.
print(f"Trajectory Cost: {cost}")

joint_labels = ['joint 1', 'joint 2', 'joint 3', 'joint 4', 'joint 5', 'joint 6']
fig, ax = plt.subplots()
ax.plot(best_traj)
ax.set(title="Trajectories", xlabel="Step", ylabel="Velocity")
ax.legend(joint_labels, loc='upper left')
plt.show()

In [None]:
framerate = 30
frames = list()

# Replay on a private MjData seeded from `data`, so `data` keeps the initial
# state and re-running this cell always starts from the same pose at time 0.
# Only qpos and qvel are copied; any other state stays at the model defaults.
replay_data = mujoco.MjData(model)
replay_data.qpos[:] = data.qpos
replay_data.qvel[:] = data.qvel

# Convert once to a host array instead of pulling one row per step from the device.
for qvel_cmd in np.asarray(best_traj):
    replay_data.qvel[:num_dof] = qvel_cmd
    mujoco.mj_step(model, replay_data)

    # Capture a frame only while fewer frames exist than the elapsed simulated
    # time calls for at `framerate`.
    if len(frames) < replay_data.time * framerate:
        renderer.update_scene(replay_data, scene_option=scene_option, camera=camera)
        frames.append(renderer.render())

# No wall-clock pacing: the frames are played back afterwards, so sleeping
# between steps would only slow the cell down.

In [None]:
# Play back the frames collected in the replay loop at `framerate` frames per second.
media.show_video(frames, fps=framerate)