In [None]:
import os

# Ask MuJoCo to use its EGL rendering backend. The variable is set in this
# first cell, before `mujoco` is imported in the next one.
# NOTE(review): presumably the backend is picked up when mujoco is loaded, so
# this must stay ahead of that import -- confirm against the MuJoCo docs.
os.environ["MUJOCO_GL"] = "egl"

In [None]:
import matplotlib.pyplot as plt
import mediapy as media
import mink
import mujoco
import numpy as np

from mujoco_playground.locomotion.go1 import go1_constants

In [None]:
def get_rz(phi: np.ndarray, swing_height=0.08) -> np.ndarray:
  """Returns the foot height for a given gait phase.

  The phase is mapped linearly from [-pi, pi] onto a normalized cycle in
  [0, 1]. During the first half of the cycle the height eases from 0 up to
  `swing_height`; during the second half it eases back down to 0. Both halves
  use the same cubic easing curve `s^3 + 3 s^2 (1 - s)`.

  Args:
    phi: Gait phase(s); the mapping assumes values in [-pi, pi].
    swing_height: Peak height, reached at `phi = 0`.

  Returns:
    Heights evaluated element-wise for `phi`.
  """

  def _blend(lo, hi, s):
    # Cubic easing weight that goes from 0 at s=0 to 1 at s=1.
    weight = s**3 + 3 * (s**2 * (1 - s))
    return lo + (hi - lo) * weight

  # Normalized position inside the gait cycle: -pi -> 0, 0 -> 0.5, pi -> 1.
  cycle = (phi + np.pi) / (2 * np.pi)

  # Both halves are evaluated everywhere; np.where picks the relevant one.
  ascent = _blend(0, swing_height, 2 * cycle)
  descent = _blend(swing_height, 0, 2 * cycle - 1)
  return np.where(cycle <= 0.5, ascent, descent)

In [None]:
# Sample the foot-height profile over one full phase range and plot it.
# The same height value drives both the curve and the dashed reference line so
# the two cannot drift apart if the value is changed.
nominal_swing_height = 0.08
phi = np.linspace(-np.pi, np.pi, 200)
rz = get_rz(phi, swing_height=nominal_swing_height)

fig, ax = plt.subplots()
ax.plot(phi, rz, label="foot height")
ax.axhline(
    y=nominal_swing_height,
    color="r",
    linestyle="--",
    label="nominal swing height",
)
ax.set_xlabel("phase [rad]")
ax.set_ylabel("foot height")
ax.set_title("Foot height over one gait cycle")
ax.legend();

In [None]:
# Tunable parameters for the gait preview generated in the cells below.

# Gait frequency parameter; the per-step phase increment is derived from it.
freq = 0.5
# Length of the generated clip in seconds (steps = duration * ctrl_freq).
duration = 1.0
# Control steps per second; its inverse is the IK time step and the video
# is played back at this frame rate.
ctrl_freq = 60
# Key into `gait_phases`, which sets the per-foot phase offsets.
gait = "walk"  # ["walk", "trot", "pace", "bound", "gallop"]

In [None]:
# Time step between consecutive control steps.
ctrl_dt = 1.0 / ctrl_freq

# Phase advance (radians) applied on every control step.
# NOTE(review): this divides by `freq`, so a larger `freq` advances the phase
# more slowly. If `freq` is meant as a gait frequency in Hz, the usual form
# would be `2 * np.pi * freq * ctrl_dt` -- confirm before changing, since it
# alters how many cycles the clip shows.
dt = 2 * np.pi / freq * ctrl_dt

# Round rather than truncate a float division: int(duration / ctrl_dt) can
# silently drop a step (e.g. 0.3 / 0.1 evaluates to 2.9999999999999996).
n_steps = int(round(duration * ctrl_freq))

# Per-foot phase offsets (radians) for each gait, ordered FR, FL, RR, RL.
gait_phases = {
    "walk": np.array([0, 0.5 * np.pi, np.pi, 1.5 * np.pi]),
    "trot": np.array([0, np.pi, np.pi, 0]),
    "pace": np.array([0, np.pi, 0, np.pi]),
    "bound": np.array([0, 0, np.pi, np.pi]),
    "gallop": np.array([0, 0, 0, 0]),
}

phase_shifts = gait_phases[gait]
phases = np.zeros((n_steps, 4))  # Wrapped phase of each foot at each step.
rs = np.zeros((n_steps, 4))  # Matching foot heights from get_rz.
t = 0
for i in range(n_steps):
  t += dt
  # Shift by pi, wrap to [0, 2*pi), then shift back so phases lie in [-pi, pi).
  phases[i] = np.fmod(phase_shifts + t + np.pi, 2 * np.pi) - np.pi
  rs[i] = get_rz(phases[i])

# plt.plot(np.cos(phases[:, 0]), label="FR")
# plt.plot(np.cos(phases[:, 1]), label="FL")
# plt.plot(np.cos(phases[:, 2]), label="RR")
# plt.plot(np.cos(phases[:, 3]), label="RL")
# plt.legend()
# plt.show()

# plt.plot(rs[:, 0], label="FR")
# plt.plot(rs[:, 1], label="FL")
# plt.plot(rs[:, 2], label="RR")
# plt.plot(rs[:, 3], label="RL")
# plt.show()

# Load the Go1 model referenced by go1_constants.FEET_ONLY_XML and wrap it in
# a mink configuration that the IK solver operates on.
model = mujoco.MjModel.from_xml_path(str(go1_constants.FEET_ONLY_XML))
configuration = mink.Configuration(model)
feet = ["FR", "FL", "RR", "RL"]

# Frame task on the "trunk" body with equal position and orientation cost.
base_task = mink.FrameTask(
    frame_name="trunk",
    frame_type="body",
    position_cost=1.0,
    orientation_cost=1.0,
)

# Posture task with a much smaller cost than the frame tasks.
posture_task = mink.PostureTask(model, cost=1e-5)

# One frame task per foot site; orientation cost is zero, so only the site
# position contributes to the objective.
feet_tasks = [
    mink.FrameTask(
        frame_name=site_name,
        frame_type="site",
        position_cost=1.0,
        orientation_cost=0.0,
    )
    for site_name in feet
]

tasks = [base_task, posture_task, *feet_tasks]

# From here on use the model/data owned by the configuration.
model = configuration.model
data = configuration.data
solver = "quadprog"

# Start from the "home_higher" keyframe and use it as the target for both the
# posture task and the base task.
configuration.update_from_keyframe("home_higher")
posture_task.set_target_from_configuration(configuration)
base_task.set_target_from_configuration(configuration)

# Snapshot the current world position of every foot site, one row per foot.
feet_positions = np.array(
    [data.site(site_name).xpos.copy() for site_name in feet]
)

# Turn on contact-point visualization for the rendered frames.
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True

frames = []
with mujoco.Renderer(model, height=480, width=640) as renderer:
  for heights in rs:
    # Each foot keeps its stored position; only the last (z) component is
    # replaced with the height computed for this step.
    for task, rest_pos, height in zip(feet_tasks, feet_positions, heights):
      target_pos = rest_pos.copy()
      target_pos[-1] = height
      task.set_target(mink.SE3.from_translation(target_pos))

    # Solve for a velocity that tracks all tasks, then integrate one step.
    vel = mink.solve_ik(configuration, tasks, ctrl_dt, solver, 1e-5)
    configuration.integrate_inplace(vel, ctrl_dt)
    mujoco.mj_forward(model, data)

    # Capture the updated scene from the "side" camera.
    renderer.update_scene(data, camera="side", scene_option=scene_option)
    frames.append(renderer.render())

# Play the clip back at the control rate.
media.show_video(frames, fps=(1.0 / ctrl_dt), loop=False)