## Comodo Example JaxSim 
This example loads a basic robot model (i.e., one composed only of basic shapes), modifies the links of the robot model by elongating the legs, defines instances of the TSID (Task-Space Inverse Dynamics) and Centroidal MPC controllers, and simulates the behavior of the robot using JaxSim.  

In [1]:
import os

# JAX reads these variables when its backend is initialised, which already happens
# while the comodo/jaxsim modules below are imported. So they must be set first.
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide GPUs -> run JAX on CPU
# JAX's documented switch for disabling GPU memory preallocation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Comodo import
from comodo.jaxsimSimulator import JaxsimSimulator
from comodo.robotModel.robotModel import RobotModel
from comodo.robotModel.createUrdf import createUrdf
from comodo.centroidalMPC.centroidalMPC import CentroidalMPC
from comodo.centroidalMPC.mpcParameterTuning import MPCParameterTuning
from comodo.TSIDController.TSIDParameterTuning import TSIDParameterTuning
from comodo.TSIDController.TSIDController import TSIDController

[34mjaxsim[1056615][0m [1;30mINFO[0m Enabling JAX to use 64bit precision
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [2]:
# General  import
import xml.etree.ElementTree as ET
import numpy as np
import tempfile
import urllib.request

In [3]:
# Download the stickBot URDF into a temporary file. The file object is kept as a
# global so its path stays valid while `createUrdf` works on it.
url = "https://raw.githubusercontent.com/icub-tech-iit/ergocub-gazebo-simulations/master/models/stickBot/model.urdf"
urdf_robot_file = tempfile.NamedTemporaryFile(mode="w+")
urllib.request.urlretrieve(url, urdf_robot_file.name)

# Parse the URDF and keep a serialized copy of the original model
# (note: ET.tostring returns bytes here, not str).
tree = ET.parse(urdf_robot_file.name)
root = tree.getroot()
robot_urdf_string_original = ET.tostring(root)

# Helper used below to apply link-length modifications to the downloaded URDF.
create_urdf_instance = createUrdf(
    original_urdf_path=urdf_robot_file.name,
    save_gazebo_plugin=False,
)

In [4]:
# Define parametric links and controlled joints
# Leg link names (without side prefix); "l_"/"r_" are prepended when building
# the length modifications in the next cell.
legs_link_names = ["hip_3", "lower_leg"]
# Joint order on the JaxSim simulator side: desired positions are remapped with
# `to_jaxsim` before load_model, and states read from the simulator are remapped
# with `to_mujoco` before being handed to the controllers.
joint_name_list = [
    "l_hip_pitch",  # 0
    "l_shoulder_pitch",  # 1
    "r_hip_pitch",  # 2
    "r_shoulder_pitch",  # 3
    "l_hip_roll",  # 4
    "l_shoulder_roll",  # 5
    "r_hip_roll",  # 6
    "r_shoulder_roll",  # 7
    "l_hip_yaw",  # 8
    "l_shoulder_yaw",  # 9
    "r_hip_yaw",  # 10
    "r_shoulder_yaw",  # 11
    "l_knee",  # 12
    "l_elbow",  # 13
    "r_knee",  # 14
    "r_elbow",  # 15
    "l_ankle_pitch",  # 16
    "r_ankle_pitch",  # 17
    "l_ankle_roll",  # 18
    "r_ankle_roll",  # 19
]
# Joint order used by RobotModel / TSID / MPC (this list is passed to RobotModel):
# right arm, left arm, right leg, left leg.
mj_list = [
    "r_shoulder_pitch",  # 0
    "r_shoulder_roll",  # 1
    "r_shoulder_yaw",  # 2
    "r_elbow",  # 3
    "l_shoulder_pitch",  # 4
    "l_shoulder_roll",  # 5
    "l_shoulder_yaw",  # 6
    "l_elbow",  # 7
    "r_hip_pitch",  # 8
    "r_hip_roll",  # 9
    "r_hip_yaw",  # 10
    "r_knee",  # 11
    "r_ankle_pitch",  # 12
    "r_ankle_roll",  # 13
    "l_hip_pitch",  # 14
    "l_hip_roll",  # 15
    "l_hip_yaw",  # 16
    "l_knee",  # 17
    "l_ankle_pitch",  # 18
    "l_ankle_roll",  # 19
]

In [5]:
# Define the robot modifications: every parametric leg link (both sides) is
# scaled by the same length factor.
LEG_LENGTH_SCALE = 1.2
modifications = {}
for item in legs_link_names:
    for side_prefix in ("l_", "r_"):
        modifications[side_prefix + item] = LEG_LENGTH_SCALE

# Motor parameters, one entry per joint (presumably rotor inertia `Im` and
# viscous coefficient `kv`, given the `kv_motors`/`Im` names used at load time).
# Concatenation order follows mj_list:
#   right arm 0-3 | left arm 4-7 | right leg 8-13 | left leg 14-19
# All entries are equal, so the ordering has no numerical effect here.
Im_arms = 1e-3 * np.ones(4)  # shoulder pitch/roll/yaw + elbow
Im_legs = 1e-3 * np.ones(6)  # hip pitch/roll/yaw + knee + ankle pitch/roll
kv_arms = 0.001 * np.ones(4)  # shoulder pitch/roll/yaw + elbow
kv_legs = 0.001 * np.ones(6)  # hip pitch/roll/yaw + knee + ankle pitch/roll

Im = np.concatenate((Im_arms, Im_arms, Im_legs, Im_legs))
kv = np.concatenate((kv_arms, kv_arms, kv_legs, kv_legs))

In [6]:
# Apply the leg-length modifications, write the modified URDF, then restore the
# URDF helper to its unmodified state.
create_urdf_instance.modify_lengths(modifications)
urdf_robot_string = create_urdf_instance.write_urdf_to_file()
create_urdf_instance.reset_modifications()

# Build the robot model (mj_list joint order) and query a walking posture.
robot_model_init = RobotModel(urdf_robot_string, "stickBot", mj_list)
s_des, xyz_rpy, H_b = robot_model_init.compute_desired_position_walking()

# Lift the initial base height by a small offset
# (presumably to avoid initial foot/ground interpenetration — TODO confirm).
BASE_HEIGHT_OFFSET = 0.007
xyz_rpy[2] += BASE_HEIGHT_OFFSET


******************************************************************************
This program contains Ipopt, a library for large-scale nonlinear optimization.
 Ipopt is released as open source code under the Eclipse Public License (EPL).
         For more information visit https://github.com/coin-or/Ipopt
******************************************************************************

This is Ipopt version 3.14.14, running with linear solver MUMPS 5.6.2.

Number of nonzeros in equality constraint Jacobian...:      126
Number of nonzeros in inequality constraint Jacobian.:        0
Number of nonzeros in Lagrangian Hessian.............:      142

Total number of variables............................:       27
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        0
                     variables with only upper bounds:        0
Total number of equality constraints.................:       21
Total number of inequality c

In [7]:
# Display the initial base pose (4x4 homogeneous transform) from the walking-posture computation
H_b

DM(
[[1, 0, 0, -0.0593655], 
 [0, 1, 0, -2.43658e-05], 
 [0, 0, 1, 0.638397], 
 [00, 00, 00, 1]])

In [8]:
def get_joint_map(from_, to):
    """Return, as a NumPy array, the index in `to` of every name in `from_`.

    Indexing an array ordered like `to` with the result yields it ordered like `from_`.
    Raises ValueError if a name of `from_` is missing from `to`.
    """
    return np.array([to.index(name) for name in from_])


# jaxsim-ordered array -> mujoco (mj_list) order, and the inverse permutation.
to_mujoco = get_joint_map(mj_list, joint_name_list)
to_jaxsim = get_joint_map(joint_name_list, mj_list)

# Both lists must contain the same joints, only ordered differently.
assert np.array_equal(np.array(joint_name_list)[to_mujoco], mj_list)
assert np.array_equal(np.array(mj_list)[to_jaxsim], joint_name_list)

In [9]:
# Desired joint posture, written in mj_list order and converted to the
# simulator's (joint_name_list) order. Both arms share the same posture.
arm_posture = [0.0, 0.251, 0.0, 0.616]  # shoulder pitch/roll/yaw, elbow
right_leg_posture = [
    0.50082726,
    0.00300592,
    -0.00164537,
    -1.0,
    -0.49917522,
    -0.00342677,
]
left_leg_posture = [
    0.49922533,
    0.00300592,
    -0.00163911,
    -1.0,
    -0.50077713,
    -0.00342377,
]
s_des = np.concatenate(
    (arm_posture, arm_posture, right_leg_posture, left_leg_posture)
)[to_jaxsim]

In [10]:
# s_des = np.array([ 0.        ,  0.251     ,  0.        ,  0.616     ,  0.        ,
#         0.251     ,  0.        ,  0.616     ,  0.50082726,  0.00300592,
#        -0.00164537, -1.        , -0.6, -0.00342677,  0.49922533,
#         0.00300592, -0.00163911, -1.        , -0.60077713, -0.00342377])[get_joint_map(joint_name_list,mj_list)]

In [11]:
# Create the JaxSim simulator and load the modified model at the desired posture
# (s_des is in joint_name_list order at this point).
jax_instance = JaxsimSimulator()
jax_instance.load_model(
    robot_model_init,
    s=s_des,
    xyz_rpy=xyz_rpy,
    kv_motors=kv,
    Im=Im,
)

# Initial state as reported by the simulator.
s, ds, tau = jax_instance.get_state()
t = 0.0
H_b = jax_instance.get_base()
w_b = jax_instance.get_base_velocity()

jax_instance.visualize_robot_flag = False
print(jax_instance.model.contact_model)

[34mjaxsim[1056615][0m [1;30mDEBUG[0m [32mFound model 'stickBot' in SDF resource[0m
  switch_frame_convention(
[34mjaxsim[1056615][0m [1;30mDEBUG[0m [32mModel 'stickBot' is floating-base[0m
[34mjaxsim[1056615][0m [1;30mDEBUG[0m [32mConsidering 'root_link' as base link[0m
[34mjaxsim[1056615][0m [1;30mINFO[0m Lumping chain: r_foot_rear->(r_foot_rear_ft_sensor)->r_ankle_2
[34mjaxsim[1056615][0m [1;30mINFO[0m Lumping chain: r_foot_front->(r_foot_front_ft_sensor)->r_ankle_2
[34mjaxsim[1056615][0m [1;30mINFO[0m Lumping chain: l_foot_rear->(l_foot_rear_ft_sensor)->l_ankle_2
[34mjaxsim[1056615][0m [1;30mINFO[0m Lumping chain: l_foot_front->(l_foot_front_ft_sensor)->l_ankle_2
[34mjaxsim[1056615][0m [1;30mINFO[0m Lumping chain: r_shoulder_3->(r_arm_ft_sensor)->r_shoulder_2
[34mjaxsim[1056615][0m [1;30mINFO[0m Lumping chain: l_shoulder_3->(l_arm_ft_sensor)->l_shoulder_2
[34mjaxsim[1056615][0m [1;30mINFO[0m Lumping chain: r_hip_3->(r_leg_ft_sensor)->r_

RigidContacts(parameters=RigidContactsParams(mu=0.5, K=10000.0, D=100.0), terrain=FlatTerrain(_height=0.0))


In [12]:
# List the links of the loaded JaxSim model and record an initial video frame.
# The link names are bound first and displayed last; otherwise the cell would only
# show the result of record_frame().
jaxsim_link_names = jax_instance.model.link_names()
jax_instance.record_frame()
jaxsim_link_names

('root_link',
 'l_hip_1',
 'l_shoulder_1',
 'r_hip_1',
 'r_shoulder_1',
 'l_hip_2',
 'l_shoulder_2',
 'r_hip_2',
 'r_shoulder_2',
 'l_upper_leg',
 'l_upper_arm',
 'r_upper_leg',
 'r_upper_arm',
 'l_lower_leg',
 'l_elbow_1',
 'r_lower_leg',
 'r_elbow_1',
 'l_ankle_1',
 'r_ankle_1',
 'l_ankle_2',
 'r_ankle_2')

In [13]:
# Define the controller parameters and instantiate the controllers.
# The controllers are built on `robot_model_init` (mj_list joint order), while the
# simulator uses joint_name_list order; every joint-space quantity handed to them
# is therefore remapped with `to_mujoco`.
tsid_parameter = TSIDParameterTuning()
mpc_parameters = MPCParameterTuning()

# TSID Instance (`frequency` is later divided by the simulator dt, i.e. it is used as a period in seconds)
TSID_controller_instance = TSIDController(frequency=0.01, robot_model=robot_model_init)
TSID_controller_instance.define_tasks(tsid_parameter)
TSID_controller_instance.set_state_with_base(s[to_mujoco], ds[to_mujoco], H_b, w_b, t)

# MPC Instance
step_length = 0.1
mpc = CentroidalMPC(robot_model=robot_model_init, step_length=step_length)
mpc.intialize_mpc(mpc_parameters=mpc_parameters)

# Set desired quantities. s_des is stored in simulator order, so convert it to the
# controller (mj_list) order like every other joint vector passed to the MPC.
mpc.configure(s_init=s_des[to_mujoco], H_b_init=H_b)
TSID_controller_instance.compute_com_position()
mpc.define_test_com_traj(TSID_controller_instance.COM.toNumPy())

# Set initial robot state and plan trajectories
jax_instance.step()

# Reading the state
s, ds, tau = jax_instance.get_state()
H_b = jax_instance.get_base()
w_b = jax_instance.get_base_velocity()
t = 0.0

# MPC
mpc.set_state_with_base(s=s[to_mujoco], s_dot=ds[to_mujoco], H_b=H_b, w_b=w_b, t=t)
mpc.initialize_centroidal_integrator(
    s=s[to_mujoco], s_dot=ds[to_mujoco], H_b=H_b, w_b=w_b, t=t
)
mpc_output = mpc.plan_trajectory()

[DEBUG] [2024-09-20 15:41:44.226] [thread: 1056615] [blf] [StdImplementation::getParameterPrivate] Parameter named 'verbosity' not found.
[INFO] [2024-09-20 15:41:44.226] [thread: 1056615] [blf] [QPTSID::initialize] 'verbosity' not found. The following parameter will be used 'false'.
[DEBUG] [2024-09-20 15:41:44.226] [thread: 1056615] [blf] [StdImplementation::getParameterPrivate] Parameter named 'mask' not found.
[INFO] [2024-09-20 15:41:44.226] [thread: 1056615] [blf] [CoMTask::initialize] [CoMTask Task.] Unable to find the mask parameter. The default value is used: true true true.
[DEBUG] [2024-09-20 15:41:44.226] [thread: 1056615] [blf] [StdImplementation::getParameterPrivate] Parameter named 'mask' not found.
[INFO] [2024-09-20 15:41:44.226] [thread: 1056615] [blf] [SE3Task::initialize]  [SE3Task Optimal Control Element - Frame name:  l_sole] Unable to find the mask parameter. The default value is used: true true true.
[DEBUG] [2024-09-20 15:41:44.226] [thread: 1056615] [blf] [Std

AttributeError: 'NoneType' object has no attribute 'write_video'

In [None]:
# Set loop variables
TIME_TH = 2.0  # simulated time horizon [s]

# Number of simulator steps per TSID call, and TSID calls per MPC replanning.
# `TSID_controller_instance.frequency` and `mpc.get_frequency_seconds()` are used
# as periods in seconds. Rounding (instead of int() truncation) avoids losing a
# step to floating-point error, e.g. int(0.3 / 0.1) == 2.
n_step = int(round(TSID_controller_instance.frequency / jax_instance.dt))
n_step_mpc_tsid = int(
    round(mpc.get_frequency_seconds() / TSID_controller_instance.frequency)
)
# A zero step count would stall the simulation loop (time never advances) or
# prevent the MPC from ever replanning (counter never wraps back to 0).
assert n_step >= 1, "TSID period must span at least one simulator step"
assert n_step_mpc_tsid >= 1, "MPC period must span at least one TSID period"
print(n_step_mpc_tsid)

counter = 0
mpc_success = True
energy_tot = 0.0
succeded_controller = True
ref_state, com_state = [], []
import jaxsim.api as js

total_mass = js.model.total_mass(jax_instance.model)

In [15]:
# Extra buffers (not filled anywhere in the visible notebook)
mpc_reference = []
reference = []

In [16]:
# import manifpy as manif
# s_init = s
# ds_init = np.zeros(robot_model_init.NDoF)
# H_b_init = H_b
# w_b_init = np.zeros(6)
# com_init = TSID_controller_instance.COM.toNumPy()
# forces_left = np.zeros(3)
# forces_left[2] = total_mass * 9.81 / 2
# forces_right = np.zeros(3)
# forces_right[2] = total_mass * 9.81 / 2
# H_left_foot = robot_model_init.H_left_foot(H_b_init, s_init)
# H_rigth_foot = robot_model_init.H_right_foot(H_b_init, s_init)
# quaternion = [0.0, 0.0, 0.0, 1.0]
# leftPosition_casadi = np.array(H_left_foot[:3, 3])
# leftPosition = np.zeros(3)
# leftPosition[0] = float(leftPosition_casadi[0])
# leftPosition[1] = float(leftPosition_casadi[1])
# leftPosition[2] = float(leftPosition_casadi[2])
# leftPosition[2] = 0.0
# contact_left = manif.SE3(position=leftPosition, quaternion=quaternion)
# rightPosition = np.zeros(3)
# rightPosition_casadi = np.array(H_rigth_foot[:3, 3])
# rightPosition[0] = float(rightPosition_casadi[0])
# rightPosition[1] = float(rightPosition_casadi[1])
# rightPosition[2] = float(rightPosition_casadi[2])
# rightPosition[2] = 0.0
# contact_right = manif.SE3(position=rightPosition, quaternion=quaternion)

# TSID_controller_instance.update_task_references_mpc_balancing(
#     com=TSID_controller_instance.COM.toNumPy(),
#     dcom=np.zeros(3),
#     ddcom=np.zeros(3),
#     left_foot_desired=contact_left,20
#     right_foot_desired=contact_right,
#     s_desired=s_init ,
#     wrenches_left=forces_left,
#     wrenches_right=forces_right,
# )

# succeded_controller = TSID_controller_instance.run()

# # if not (succeded_controller):
# #     print("Controller failed")
# #     break

# tau_ctrl = TSID_controller_instance.get_torque()

In [None]:
# Result discarded on purpose (presumably a warm-up/compilation call — TODO confirm).
jax_instance.get_feet_wrench()

# Fresh logging buffers filled by the simulation loop.
diff, contacts, gravity_diff, torques = [], [], [], []

jax_instance.record_frame()

print(jax_instance.model.contact_model)
print(total_mass)

In [None]:
# Simulation-control loop
import time

simtime = time.perf_counter()

# Video frames are scheduled in simulated seconds (`t` is in seconds). The old
# check `t % int(1e9 / fps) == 0` compared seconds against a nanosecond period
# and practically never fired.
frame_period = 1.0 / jax_instance.recorder.fps
next_frame_time = t + frame_period

while t < TIME_TH:
    print(f"==== Time: {t:.4f}s ====", flush=True, end="\r")

    # Reading robot state from simulator (joint_name_list order)
    s, ds, tau = jax_instance.get_state()
    energy_i = np.linalg.norm(tau)
    H_b = jax_instance.get_base()
    w_b = jax_instance.get_base_velocity()
    t = jax_instance.get_simulation_time()

    # Update TSID (controllers work in mj_list order)
    TSID_controller_instance.set_state_with_base(
        s=s[to_mujoco], s_dot=ds[to_mujoco], H_b=H_b, w_b=w_b, t=t
    )

    # MPC plan, once every n_step_mpc_tsid TSID iterations
    if counter == 0:
        mpc.set_state_with_base(
            s=s[to_mujoco], s_dot=ds[to_mujoco], H_b=H_b, w_b=w_b, t=t
        )
        mpc.update_references()
        mpc_success = mpc.plan_trajectory()
        mpc.contact_planner.advance_swing_foot_planner()
        if not mpc_success:
            print("MPC failed")
            break

    # Reading new references
    com, dcom, forces_left, forces_right, ang_mom = mpc.get_references()
    left_foot, right_foot = mpc.contact_planner.get_references_swing_foot_planner()

    # Measured feet wrenches: read once (the simulator is not stepped until
    # below) and reused for every log entry of this iteration.
    real_left_foot, real_right_foot = jax_instance.get_feet_wrench()

    diff.append([forces_left - real_left_foot[:3], forces_right - real_right_foot[:3]])

    gravity_diff.append(real_left_foot[2] + real_right_foot[2] - total_mass * 9.81)

    TSID_controller_instance.compute_com_position()
    com_state.append(TSID_controller_instance.COM.toNumPy())
    ref_state.append(com)

    contacts.append((real_left_foot, real_right_foot))

    # Update references TSID. s_des is stored in simulator order, so it must be
    # remapped to the controller order like the measured state above.
    TSID_controller_instance.update_task_references_mpc(
        com=com,
        dcom=dcom,
        ddcom=np.zeros(3),
        left_foot_desired=left_foot,
        right_foot_desired=right_foot,
        s_desired=s_des[to_mujoco],
        wrenches_left=forces_left,
        wrenches_right=forces_right,
    )

    # Run control
    succeded_controller = TSID_controller_instance.run()

    if not succeded_controller:
        print("Controller failed")
        break

    # TSID torques are in mj_list order; the simulator expects joint_name_list order.
    tau = TSID_controller_instance.get_torque()[to_jaxsim]
    torques.append(tau)

    # Step the simulator
    jax_instance.step(n_step=n_step, torque=tau)
    counter += 1

    if t >= next_frame_time:
        jax_instance.record_frame()
        # Skip any frame slots already passed so the schedule never lags behind.
        while next_frame_time <= t:
            next_frame_time += frame_period

    if counter == n_step_mpc_tsid:
        counter = 0

    # Stop the simulation if the robot fell down
    if jax_instance.data.base_position()[2] < 0.5:
        print(f"Robot fell down at t={t:.4f}s.")
        break

print(f"Time elapsed: {time.perf_counter() - simtime}")

In [None]:
# Per-link contact forces for the simulator's current state
js.model.link_contact_forces(data=jax_instance.data, model=jax_instance.model)

In [None]:
# Plot the com state and the reference
import matplotlib.pyplot as plt

# Each loop iteration advances the simulator by n_step * dt seconds. `time` keeps
# its original length (len(torques) + 1) because later cells slice it with [:-1].
time = np.linspace(0, len(torques) * n_step * jax_instance.dt, len(torques) + 1)

# New names: do not overwrite the lists that the simulation loop appends to.
com_state_arr = np.array(com_state)
ref_state_arr = np.array(ref_state)
# The logs can be one sample shorter than `time` (e.g. normal loop exit).
com_time = time[: len(com_state_arr)]

for axis_idx, axis_name in enumerate(("x", "y", "z")):
    plt.plot(com_time, com_state_arr[:, axis_idx], label=f"com_{axis_name}")
    plt.plot(com_time, ref_state_arr[:, axis_idx], label=f"ref_{axis_name}")
plt.legend()
plt.ylabel("COM position [m]")
plt.xlabel("Time [s]")
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Each loop iteration advances the simulator by n_step * dt seconds; `time` keeps
# its original length (len(torques) + 1) for the cells below.
time = np.linspace(0, len(torques) * n_step * jax_instance.dt, len(torques) + 1)

# Measured total vertical foot force minus the robot weight (new name so the
# loop's list is not overwritten).
gravity_diff_arr = np.array(gravity_diff)
plt.plot(time[: len(gravity_diff_arr)], gravity_diff_arr)
plt.ylabel("F_z - Mg [N]")
plt.xlabel("Time [s]")
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Each loop iteration advances the simulator by n_step * dt seconds; `time` keeps
# its original length (len(torques) + 1) because the torque plot uses time[:-1].
time = np.linspace(0, len(torques) * n_step * jax_instance.dt, len(torques) + 1)

# Logged as [left, right] per iteration, each = MPC reference force minus the
# measured force (first three wrench components).
force_diff = np.array(diff)
diff_time = time[: len(force_diff)]

left_diff = force_diff[:, 0]
right_diff = force_diff[:, 1]

fig, axs = plt.subplots(1, 3)
fig.set_figwidth(20)

for comp_idx, (ax, comp) in enumerate(zip(axs, ("x", "y", "z"))):
    comp_upper = comp.upper()
    ax.plot(diff_time, left_diff[:, comp_idx], label=f"left_diff_{comp}")
    ax.plot(diff_time, right_diff[:, comp_idx], label=f"right_diff_{comp}")
    ax.legend()
    ax.set_ylabel(f"Diff {comp_upper} [N]")
    ax.set_xlabel("Time [s]")
    ax.set_title(f"Diff {comp_upper}")

fig.tight_layout()

In [None]:
import matplotlib.pyplot as plt

# Each loop iteration advances the simulator by n_step * dt seconds.
time = np.linspace(0, len(torques) * n_step * jax_instance.dt, len(torques) + 1)

# New name so the loop's `torques` list is not overwritten.
torque_history = np.array(torques)

fig, axs = plt.subplots(1, 1)
fig.set_figwidth(15)

# Plot all the torques in the same plot. The logged torques were remapped with
# `to_jaxsim`, so their columns follow joint_name_list (not mj_list) order.
for joint_idx, joint_name in enumerate(joint_name_list):
    axs.plot(time[: len(torque_history)], torque_history[:, joint_idx], label=joint_name)
axs.set_ylabel("Torque [Nm]")
axs.set_xlabel("Time [s]")
axs.legend(ncol=4, loc="center", bbox_to_anchor=(0.5, 0.8))

fig.tight_layout()

In [None]:
import matplotlib.pyplot as plt

# Each loop iteration advances the simulator by n_step * dt seconds.
time = np.linspace(0, len(torques) * n_step * jax_instance.dt, len(torques) + 1)

# contacts holds one (left, right) wrench pair per loop iteration.
contacts_arr = np.array(contacts)
contact_time = time[: len(contacts_arr)]
left_foot_force = contacts_arr[:, 0]
right_foot_force = contacts_arr[:, 1]

fig, axs = plt.subplots(2, 3)
fig.set_figwidth(10)

for comp_idx, comp in enumerate(("x", "y", "z", "Rx", "Ry", "Rz")):
    ax = axs[comp_idx // 3, comp_idx % 3]
    ax.plot(contact_time, left_foot_force[:, comp_idx], "tab:purple", label="left")
    ax.plot(contact_time, right_foot_force[:, comp_idx], "tab:red", label="right")
    ax.set_title(f"Foot Force -- {comp}")
axs[0, 0].legend()

fig.tight_layout()

In [None]:
# Save the recorded frames as an MP4 under ./results, timestamped with the current time.
import datetime
import pathlib

current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
video_path = pathlib.Path.cwd() / "results" / f"{current_time}simulation_comodo.mp4"
jax_instance.recorder.save_video(path=video_path, exist_ok=True)

# Clean up the recorder: drop the stored frames and close its renderer.
jax_instance.recorder.frames = []
jax_instance.recorder.renderer.close()