## Comodo Example JaxSim 
This example loads a basic robot model (i.e., one composed only of basic shapes), modifies the links of the robot model by elongating the legs, defines instances of the TSID (Task Based Inverse Dynamics) and Centroidal MPC controllers, and simulates the behavior of the robot using JaxSim.

In [None]:
# Comodo import
from comodo.jaxsimSimulator import JaxsimSimulator
from comodo.robotModel.robotModel import RobotModel
from comodo.robotModel.createUrdf import createUrdf
from comodo.centroidalMPC.centroidalMPC import CentroidalMPC
from comodo.centroidalMPC.mpcParameterTuning import MPCParameterTuning
from comodo.TSIDController.TSIDParameterTuning import TSIDParameterTuning
from comodo.TSIDController.TSIDController import TSIDController
from comodo.mujocoSimulator.mujocoVisualizer import MujocoVisualizer
from comodo.mujocoSimulator.idyntreeVisualizer import iDynTreeVisualizer
from comodo.mujocoSimulator.mujocoSimulator import MujocoSimulator

In [None]:
# General  import
import xml.etree.ElementTree as ET
import numpy as np
import tempfile
import urllib.request

import os

In [None]:
# Getting stickbot urdf file and convert it to string
# The URDF is downloaded into a named temporary file whose path is then used by
# ET.parse and by createUrdf. Keep `urdf_robot_file` referenced: a
# NamedTemporaryFile (default delete=True) removes the file once the object is
# closed or garbage collected, and createUrdf may need the path later.
# NOTE(review): re-opening a still-open NamedTemporaryFile by name (as
# urlretrieve and ET.parse do here) is not supported on Windows — confirm the
# target platforms.
urdf_robot_file = tempfile.NamedTemporaryFile(mode="w+")
url = "https://raw.githubusercontent.com/icub-tech-iit/ergocub-gazebo-simulations/master/models/stickBot/model.urdf"
urllib.request.urlretrieve(url, urdf_robot_file.name)
# Load the URDF file
tree = ET.parse(urdf_robot_file.name)
root = tree.getroot()

# Convert the XML tree to a string
# (with the default encoding ET.tostring returns bytes; this value is not
# referenced in the visible cells)
robot_urdf_string_original = ET.tostring(root)

# Helper used below to edit link lengths of the downloaded URDF.
# NOTE(review): save_gazebo_plugin=False presumably drops Gazebo plugin tags from
# the written URDF — confirm against createUrdf.
create_urdf_instance = createUrdf(
    original_urdf_path=urdf_robot_file.name, save_gazebo_plugin=False
)

In [None]:
# Define parametric links and controlled joints
# Leg links whose length is parametrized; each name is used with an "l_"/"r_" prefix.
legs_link_names = ["hip_3", "lower_leg"]

# Controlled joints, in the order given to RobotModel. For the first twelve
# entries every axis interleaves left leg, left arm, right leg, right arm.
joint_name_list = [
    "l_hip_pitch", "l_shoulder_pitch", "r_hip_pitch", "r_shoulder_pitch",  # 0-3
    "l_hip_roll", "l_shoulder_roll", "r_hip_roll", "r_shoulder_roll",  # 4-7
    "l_hip_yaw", "l_shoulder_yaw", "r_hip_yaw", "r_shoulder_yaw",  # 8-11
    "l_knee", "l_elbow", "r_knee", "r_elbow",  # 12-15
    "l_ankle_pitch", "r_ankle_pitch", "l_ankle_roll", "r_ankle_roll",  # 16-19
]

In [None]:
# Define the robot modifications
modifications = {}
for item in legs_link_names:
    left_leg_item = "l_" + item
    right_leg_item = "r_" + item
    modifications.update({left_leg_item: 1.0})
    modifications.update({right_leg_item: 1.0})
# Motors Parameters
Im_arms = 1e-3 * np.ones(4)  # from 0-4
Im_legs = 1e-3 * np.ones(6)  # from 5-10
kv_arms = 0.001 * np.ones(4)  # from 11-14
kv_legs = 0.001 * np.ones(6)  # from 20

Im = np.concatenate((Im_arms, Im_arms, Im_legs, Im_legs))
kv = np.concatenate((kv_arms, kv_arms, kv_legs, kv_legs))

In [None]:
# Modify the robot model and initialize
# Apply the per-link modifications to the downloaded URDF and write the result.
# NOTE(review): the return value of write_urdf_to_file() is handed to RobotModel
# as its URDF input; whether it is a file path or the URDF text cannot be told
# from here (the commented-out iDynTree cell passes it as `urdf_path`) — confirm.
create_urdf_instance.modify_lengths(modifications)
urdf_robot_string = create_urdf_instance.write_urdf_to_file()
# Presumably restores the helper to the unmodified URDF so later edits start fresh.
create_urdf_instance.reset_modifications()
# "stickBot" is presumably the model name; joint_name_list lists the controlled joints.
robot_model_init = RobotModel(urdf_robot_string, "stickBot", joint_name_list)
# Initial walking configuration; s_des and xyz_rpy are used as the simulator's
# initial joint configuration and base pose, H_b presumably the base transform.
s_des, xyz_rpy, H_b = robot_model_init.compute_desired_position_walking()

In [None]:
# Define simulator and set initial position
jax_instance = JaxsimSimulator()
# Load the modified model, starting from the desired walking configuration,
# with the motor parameters defined above.
jax_instance.load_model(robot_model_init, s=s_des, xyz_rpy=xyz_rpy, kv_motors=kv, Im=Im)
# Read back the initial simulator state and base pose/velocity.
s, ds, tau = jax_instance.get_state()
t = 0.0  # jax_instance.get_simulation_time()
H_b = jax_instance.get_base()
w_b = jax_instance.get_base_velocity()
# Enable visualization and render the initial configuration.
jax_instance.visualize_robot_flag = True
jax_instance.render()

In [None]:
# Define the controller parameters  and instantiate the controller
# Controller Parameters
tsid_parameter = TSIDParameterTuning()
mpc_parameters = MPCParameterTuning()

# TSID Instance
# NOTE: despite the keyword name, `frequency` is used as a period in seconds:
# the loop setup divides TSID_controller_instance.frequency by the simulator dt
# to obtain the number of simulator steps per TSID update.
TSID_controller_instance = TSIDController(frequency=0.001, robot_model=robot_model_init)
TSID_controller_instance.define_tasks(tsid_parameter)
TSID_controller_instance.set_state_with_base(s, ds, H_b, w_b, t)

# MPC Instance
# (variable name spelling kept as-is; it is forwarded as `step_length`)
step_lenght = 0.1
mpc = CentroidalMPC(robot_model=robot_model_init, step_length=step_lenght)
mpc.intialize_mpc(mpc_parameters=mpc_parameters)

# Set desired quantities
mpc.configure(s_init=s_des, H_b_init=H_b)
TSID_controller_instance.compute_com_position()
mpc.define_test_com_traj(TSID_controller_instance.COM.toNumPy())

# Set initial robot state  and plan trajectories
# NOTE(review): step() is called without a torque here — confirm the default it applies.
jax_instance.step()

# Reading the state
s, ds, tau = jax_instance.get_state()
H_b = jax_instance.get_base()
w_b = jax_instance.get_base_velocity()
# Time is reset to 0.0 for the MPC rather than read from the simulator clock.
t = 0.0

# MPC
mpc.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)
mpc.initialize_centroidal_integrator(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)
# Initial plan; mpc_output is not referenced in the visible cells.
mpc_output = mpc.plan_trajectory()


In [None]:
# Set loop variables
TIME_TH = 20  # the control loop stops once the simulator time reaches this value

# Define number of steps.
# `TSID_controller_instance.frequency` and `mpc.get_frequency_seconds()` are used
# as periods here. Round the ratios to the nearest integer instead of truncating
# them: int() of a float quotient that lands just below an integer
# (e.g. 0.3 / 0.1 == 2.9999999999999996) silently drops a whole step.
n_step = round(TSID_controller_instance.frequency / jax_instance.dt)  # simulator steps per TSID update
n_step_mpc_tsid = round(
    mpc.get_frequency_seconds() / TSID_controller_instance.frequency
)  # TSID updates per MPC replan
if n_step < 1 or n_step_mpc_tsid < 1:
    # n_step == 0 would presumably never advance the simulator, and
    # n_step_mpc_tsid == 0 would never let `counter` wrap back to 0 in the
    # control loop, so the MPC would plan only once.
    raise ValueError(
        f"Invalid loop rates: n_step={n_step}, n_step_mpc_tsid={n_step_mpc_tsid}; "
        "the TSID period must be >= the simulator dt and the MPC period >= the TSID period."
    )

counter = 0  # TSID iterations since the last MPC replan
mpc_success = True
energy_tot = 0.0
succeded_controller = True
com_state = []  # MPC CoM references collected by the control loop

In [None]:
# Joint ordering named `mj_list` (presumably the MuJoCo model ordering):
# right arm, left arm, right leg, left leg.
mj_list = [
    "r_shoulder_pitch", "r_shoulder_roll", "r_shoulder_yaw", "r_elbow",  # 0-3
    "l_shoulder_pitch", "l_shoulder_roll", "l_shoulder_yaw", "l_elbow",  # 4-7
    "r_hip_pitch", "r_hip_roll", "r_hip_yaw", "r_knee", "r_ankle_pitch", "r_ankle_roll",  # 8-13
    "l_hip_pitch", "l_hip_roll", "l_hip_yaw", "l_knee", "l_ankle_pitch", "l_ankle_roll",  # 14-19
]


def get_joint_map(from_, to):
    """Return an array with, for every name in ``from_``, its index in ``to``."""
    return np.array([to.index(joint) for joint in from_])


joint_map = get_joint_map(mj_list, jax_instance.model.joint_names())
# By construction the JaxSim joint at index joint_map[i] is mj_list[i], so this
# checks that joint_name_list agrees with the JaxSim joint order at every mapped index.
expected_names = np.array(joint_name_list)[joint_map]
assert all(np.array(mj_list) == expected_names)

In [None]:
# Logs filled by the balancing experiment loop.
mpc_reference = []  # nominal CoM references
reference = []  # lateral (y) CoM reference after the sinusoidal offset

In [None]:
# viewer: iDynTreeVisualizer = iDynTreeVisualizer(model_name="stickBot")
# viewer.prepare_visualization()
# viewer.add_model(robot_model=robot_model_init, urdf_path = urdf_robot_string)
# viewer.update_model(s=s, H_b=H_b)
# viewer.visualize()

In [None]:
import jaxsim.api as js
# Total mass of the loaded JaxSim model. In this notebook it is only read by the
# commented-out balancing setup (vertical foot forces = total_mass * 9.81 / 2).
total_mass = js.model.total_mass(jax_instance.model)

In [None]:
# import manifpy as manif
# s_init = s
# ds_init = np.zeros(robot_model_init.NDoF)
# H_b_init = H_b
# w_b_init = np.zeros(6)
# com_init = TSID_controller_instance.COM.toNumPy()
# forces_left = np.zeros(3)
# forces_left[2] = total_mass * 9.81 / 2
# forces_right = np.zeros(3)
# forces_right[2] = total_mass * 9.81 / 2
# H_left_foot = robot_model_init.H_left_foot(H_b_init, s_init)
# H_rigth_foot = robot_model_init.H_right_foot(H_b_init, s_init)
# quaternion = [0.0, 0.0, 0.0, 1.0]
# leftPosition_casadi = np.array(H_left_foot[:3, 3])
# leftPosition = np.zeros(3)
# leftPosition[0] = float(leftPosition_casadi[0])
# leftPosition[1] = float(leftPosition_casadi[1])
# leftPosition[2] = float(leftPosition_casadi[2])
# leftPosition[2] = 0.0
# contact_left = manif.SE3(position=leftPosition, quaternion=quaternion)
# rightPosition = np.zeros(3)
# rightPosition_casadi = np.array(H_rigth_foot[:3, 3])
# rightPosition[0] = float(rightPosition_casadi[0])
# rightPosition[1] = float(rightPosition_casadi[1])
# rightPosition[2] = float(rightPosition_casadi[2])
# rightPosition[2] = 0.0
# contact_right = manif.SE3(position=rightPosition, quaternion=quaternion)

# TSID_controller_instance.update_task_references_mpc_balancing(
#     com=TSID_controller_instance.COM.toNumPy(),
#     dcom=np.zeros(3),
#     ddcom=np.zeros(3),
#     left_foot_desired=contact_left,
#     right_foot_desired=contact_right,
#     s_desired=s_init ,
#     wrenches_left=forces_left,
#     wrenches_right=forces_right,
# )

# succeded_controller = TSID_controller_instance.run()

# # if not (succeded_controller):
# #     print("Controller failed")
# #     break

# tau_ctrl = TSID_controller_instance.get_torque()


In [None]:
# Simulation-control loop
# Runs until the simulator clock reaches TIME_TH. Each iteration does the following:
#   - reads the simulator state;
#   - replans the MPC when `counter` is 0, i.e. once every `n_step_mpc_tsid` iterations;
#   - passes the MPC CoM/force references and the swing-foot references to TSID;
#   - advances the simulator by `n_step` steps with the TSID torque.
# The loop stops early if the MPC or TSID reports failure.
while t < TIME_TH:

    # Reading robot state from simulator
    s, ds, tau = jax_instance.get_state()
    # NOTE(review): energy_i is computed but never used, and energy_tot is never updated.
    energy_i = np.linalg.norm(tau)
    H_b = jax_instance.get_base()
    w_b = jax_instance.get_base_velocity()
    t = jax_instance.get_simulation_time()

    # Update TSID
    TSID_controller_instance.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)

    # MPC plan (only on iterations where the counter has wrapped back to 0)
    if counter == 0:
        mpc.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)
        mpc.update_references()
        mpc_success = mpc.plan_trajectory()
        mpc.contact_planner.advance_swing_foot_planner()
        if not (mpc_success):
            print("MPC failed")
            break

    # Reading new references
    com, dcom, forces_left, forces_right = mpc.get_references()
    left_foot, right_foot = mpc.contact_planner.get_references_swing_foot_planner()

    # Log the MPC CoM reference (plotted in the next cell)
    com_state.append(com)
    # Update references TSID
    # (CoM acceleration reference is zero; the joint reference is the initial
    # walking configuration s_des)
    TSID_controller_instance.update_task_references_mpc(
        com=com,
        dcom=dcom,
        ddcom=np.zeros(3),
        left_foot_desired=left_foot,
        right_foot_desired=right_foot,
        s_desired=np.array(s_des),
        wrenches_left=forces_left,
        wrenches_right=forces_right,
    )

    # Run control
    succeded_controller = TSID_controller_instance.run()

    if not (succeded_controller):
        print("Controller failed")
        break

    tau = TSID_controller_instance.get_torque()

    # Step the simulator
    # Advance by n_step simulator steps (one TSID period) with the TSID torque.
    jax_instance.step(n_step=n_step, torque=tau)
    counter = counter + 1

    # Wrap the counter so that the MPC replans every n_step_mpc_tsid iterations.
    if counter == n_step_mpc_tsid:
        counter = 0

In [None]:
# Plot the com state
# `com_state` holds the MPC CoM references logged by the control loop, one row
# per loop iteration.

import matplotlib.pyplot as plt
com_state = np.array(com_state)

# The control loop can stop (MPC or TSID failure) before the first reference is
# logged. In that case `com_state` is a 1-D empty array, and `com_state[:, 0]`
# would raise IndexError, so skip the plot.
if com_state.size == 0:
    print("No CoM reference samples were recorded; nothing to plot.")
else:
    # One curve per CoM component.
    for column, label in enumerate(("x", "y", "z")):
        plt.plot(com_state[:, column], label=label)

    plt.legend()
    plt.show()

In [None]:
# Simulation-control loop
# Balancing experiment: TSID tracks the current CoM plus a sinusoidal lateral (y)
# offset, with both feet held at fixed contacts.
# NOTE(review): this cell uses names that are only defined by commented-out cells
# above: `tau_ctrl`, `contact_left`, `contact_right` and `s_init` (balancing setup
# cell) and `viewer` (iDynTree cell). Uncomment and run those cells first,
# otherwise this loop stops with a NameError. After the main loop `t` is normally
# >= TIME_TH, so this loop only runs if that loop stopped early or `t` is reset.
while t < 3:
    print(tau_ctrl)
    # Reading robot state from simulator
    s, ds, tau = jax_instance.get_state()
    energy_i = np.linalg.norm(tau)
    H_b = jax_instance.get_base()
    w_b = jax_instance.get_base_velocity()
    # Take the time from the simulator, as the main loop does. Every iteration
    # steps the simulator n_step times, so adding a single dt to `t` made it
    # (and the sinusoid below) fall behind the simulation whenever n_step > 1.
    t = jax_instance.get_simulation_time()
    print("Time: ", t)

    viewer.update_model(s=s, H_b=H_b)
    viewer.visualize()
    # Update TSID

    #  TODO: BALANCING COMMENT OUT TO LIFT THE FOOT: TSID CLOSES THE LOOP
    # TSID_controller_instance.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)
    # # MPC plan
    # if counter == 0:
    #     mpc.update_references()
    #     mpc_success = mpc.plan_trajectory()
    #     mpc.contact_planner.advance_swing_foot_planner()

    # com, dcom, forces_left, forces_right = mpc.get_references()
    # left_foot, right_foot = mpc.contact_planner.get_references_swing_foot_planner()
    # TSID_controller_instance.update_task_references_mpc(
    #     com=com,
    #     dcom=np.zeros(3),
    #     ddcom=np.zeros(3),
    #     left_foot_desired=left_foot,
    #     right_foot_desired=right_foot,
    #     s_desired=s_init ,
    #     wrenches_left=forces_left,
    #     wrenches_right=forces_right,
    # )
    # succeded_controller = TSID_controller_instance.run()
    # tau_ctrl = TSID_controller_instance.get_torque()

    # TODO: Rigth Left trajectory CoM  TSID closes the loop
    TSID_controller_instance.set_state_with_base(s=s, s_dot=ds, H_b=H_b, w_b=w_b, t=t)
    com_ref = TSID_controller_instance.COM.toNumPy()
    Amplitude = 0.05 # amplitude in meters
    omega_sin = 2*np.pi*0.5 # angular frequency [rad/s] of a 0.5 Hz sinusoid
    lateral_offset = Amplitude*np.sin(omega_sin*t)
    print(lateral_offset)
    # Store a copy: `com_ref` is modified in place just below. Appending the array
    # itself made `mpc_reference` hold the offset CoM instead of the nominal one.
    mpc_reference.append(np.array(com_ref))
    com_ref[1] += lateral_offset
    reference.append(com_ref[1])

    TSID_controller_instance.update_task_references_mpc_balancing(
        com=com_ref,
        dcom=np.zeros(3),
        ddcom=np.zeros(3),
        left_foot_desired=contact_left,
        right_foot_desired=contact_right,
        s_desired=s_init ,
        wrenches_left=forces_left,
        wrenches_right=forces_right,
        )
    succeded_controller = TSID_controller_instance.run()
    tau_ctrl = TSID_controller_instance.get_torque()

    # NOTE(review): debug check comparing the freshly computed torque with the
    # simulator's stored `tau` before it is applied — confirm it is intended.
    assert np.allclose(tau_ctrl, jax_instance.tau)
    jax_instance.step(n_step=n_step, torque=tau_ctrl)
    print(f"Joint positions: {jax_instance.data.joint_positions()}")
    counter = counter + 1
    # break

    if counter == n_step_mpc_tsid:
        counter = 0

In [None]:
# import matplotlib.pyplot as plt

# left_foot_force = np.array(contact_forces)[:,0]
# left_foot_force_x = left_foot_force[:,0]
# left_foot_force_y = left_foot_force[:,1]
# left_foot_force_z = left_foot_force[:,2]
# left_foot_force_Rx = left_foot_force[:,3]
# left_foot_force_Ry = left_foot_force[:,4]
# left_foot_force_Rz = left_foot_force[:,5]

# right_foot_force = np.array(contact_forces)[:,1]
# right_foot_force_x = right_foot_force[:,0]
# right_foot_force_y = right_foot_force[:,1]
# right_foot_force_z = right_foot_force[:,2]
# right_foot_force_Rx = right_foot_force[:,3]
# right_foot_force_Ry = right_foot_force[:,4]
# right_foot_force_Rz = right_foot_force[:,5]
# time = np.arange(0, len(left_foot_force_x), 1)
# fig, axs = plt.subplots(2, 3)
# fig.set_figwidth(10)

# axs[0, 0].plot(time, left_foot_force_x, 'tab:purple')
# axs[0, 0].plot(time, right_foot_force_x, 'tab:red')
# axs[0, 0].set_title('foot_force_x')

# axs[0, 1].plot(time, left_foot_force_y, 'tab:purple')
# axs[0, 1].plot(time, right_foot_force_y, 'tab:red')
# axs[0, 1].set_title('foot_force_y')

# axs[0, 2].plot(time, left_foot_force_z, 'tab:purple')
# axs[0, 2].plot(time, right_foot_force_z, 'tab:red')
# axs[0, 2].set_title('foot_force_z')

# axs[1, 0].plot(time, left_foot_force_Rx, 'tab:purple')
# axs[1, 0].plot(time, right_foot_force_Rx, 'tab:red')
# axs[1, 0].set_title('foot_force_Rx')

# axs[1, 1].plot(time, left_foot_force_Ry, 'tab:purple')
# axs[1, 1].plot(time, right_foot_force_Ry, 'tab:red')
# axs[1, 1].set_title('foot_force_Ry')

# axs[1, 2].plot(time, left_foot_force_Rz, 'tab:purple')
# axs[1, 2].plot(time, right_foot_force_Rz, 'tab:red')
# axs[1, 2].set_title('foot_force_Rz')

# fig.tight_layout()

In [None]:
# reference

In [None]:
# import matplotlib.pyplot as plt

# time = np.arange(0, len(reference), 1)

# # plt.plot(mpc_reference[:1], label=["mpc_ref_x", "mpc_ref_y", "mpc_ref_z"])
# plt.plot(reference, label="reference")
# plt.legend()
# plt.show()