In [None]:
# Install the gdown command-line tool, then fetch a file by its ID.
# NOTE(review): presumably this downloads the Spider.rar archive that a later
# cell extracts from /kaggle/working — confirm the ID still points to it.
!conda install -y gdown
!gdown 1i6gsJch1_Odnz-ddqGQzpakD5mowwx-6

In [None]:
# System packages for OpenGL rendering and video encoding; all output
# (including errors) is discarded, so a failed install is silent here.
!apt install -y python-opengl ffmpeg > /dev/null 2>&1

# NOTE(review): the xvfb install is commented out; pyvirtualdisplay presumably
# still needs an X virtual framebuffer to be present on the image — confirm.
# !apt install -y xvfb
%pip install pyvirtualdisplay
from pyvirtualdisplay import Display
# Start a hidden 1400x900 virtual display for the rest of the session.
Display(visible=False, size=(1400, 900)).start()

In [None]:
# Simulation stack: MuJoCo, its JAX port (MJX) and Brax.
# The JAX CUDA wheel install is left disabled, so whatever JAX the image
# already provides (or pip resolves as a dependency) is used.
# NOTE(review): versions are unpinned; a fresh run may pull incompatible releases.
# !pip install -U "jax[cuda12]"
!pip install mujoco
!pip install mujoco_mjx
!pip install brax

In [None]:
#@title Check if MuJoCo installation was successful
import os


# Select MuJoCo's GLFW rendering backend. (The message printed below still
# talks about GPU rendering, but the value actually set is `glfw`, not `egl`.)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=glfw

# Smoke test: import mujoco and compile an empty model.
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  # Re-raises the original exception with the explanatory RuntimeError
  # attached as its cause (so both appear in the traceback).
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
# This assignment replaces any XLA_FLAGS value already in the environment and
# is made before the cell that imports jax.
# NOTE(review): XLA flag names change between releases; with an unpinned
# jax/jaxlib some of these may no longer be recognised — verify on the
# installed version.
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)


In [None]:
# Install the unrar tool and the patool wrapper, then unpack the robot
# archive into the Kaggle working directory.
# NOTE(review): `apt install` is run without -y and may wait for confirmation
# in a non-interactive session — confirm.
!sudo apt install unrar
!pip install patool
import patoolib
patoolib.extract_archive("/kaggle/working/Spider.rar",outdir='/kaggle/working/')

In [None]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
import random
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
# mediapy is installed at run time and then imported; the ffmpeg check below is
# disabled because ffmpeg is already installed by an earlier cell.
print('Installing mediapy:')
# !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
# (3 decimals, no scientific notation for small values, 100-character lines)
np.set_printoptions(precision=3, suppress=True, linewidth=100)

In [None]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict
import math


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
# from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.es import train as es
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.es import networks as es_networks
from brax.io import html, mjcf, model


In [None]:
# Absolute path of the MuJoCo XML model handed to HexapodV0_3.
# NOTE(review): presumably produced by extracting Spider.rar into
# /kaggle/working — confirm the archive layout matches this path.
xml_path = r'/kaggle/working/Spider_Assembly_fineMesh_frictionDamp/urdf/final_noFrictionLoss_noCoxaCon_explicitConPair_ellipsoidTibias_kaggle.xml'

class HexapodV0_3(PipelineEnv):
    """Hexapod locomotion environment (Brax ``PipelineEnv`` on the MJX backend).

    Every episode carries a schedule of commanded planar speeds and yaw rates.
    The reward favours matching the current command while keeping the base
    level and the actions smooth.  An episode terminates when the base tilts
    past a threshold or drops below a height limit (each check can be
    switched off).

    The observation is a flat stack of the ``nStacks`` most recent frames,
    newest first.  One frame is laid out as::

        [base angular velocity (3, optional),
         commanded speed, commanded yaw rate,
         joint positions qpos[7:],
         tibia-tip contact flags (6, optional)]
    """

    # Upper bound on the number of command changes in one episode.  The command
    # arrays are always allocated at this size so their shapes stay static
    # under jit; the number of changes actually used is sampled per episode.
    _MAX_TRANSITIONS = 10

    def __init__(self,
                 xml_path,
                 terminateWhenTilt=True,
                 terminateWhenTiltGreaterThan=30 * math.pi / 180,
                 baseTiltSigma=0.2,
                 baseTiltCoef=1,

                 terminateWhenFumersColide=False,
                 femurCollisionSigma=0.06,
                 femurCollisionCoef=0,

                 correctDirectionSigma=0.2,
                 correctDirectionWeight=1,
                 deviationAngleSigma=0.3,
                 deviationAngleWeight=1,

                 baseHeightSigma=0.027,
                 baseHeightCoef=1,
                 terminateWhenLow=True,
                 baseHeightLowerLimit=0.15,
                 baseOscillationSigma=0.5,
                 baseOscillationCoef=1.0,

                 rewardForTibiaTip=True,
                 tibiaRewardSigma=0.05,
                 tibiaRewardCoef=0,

                 powerCoef=0.001,
                 continuityCoef=0.5,

                 includeBaseAngularVels=True,
                 includeTibiaTipSensors=False,
                 nStacks=3,
                 physics_steps_per_control_step=10,

                 resetPosLowHigh=[jp.array([-0.2, -0.2, 0.23]), jp.array([0.2, 0.2, 0.4])],
                 # Yaw bounds are both -pi, so the initial yaw is not randomized.
                 resetOriLowHigh=[jp.array([-math.pi/12, -math.pi/12, -math.pi]), jp.array([math.pi/12, math.pi/12, -math.pi])],
                 resetJointsPosLowHigh = [jp.array([-math.pi/12]*18), jp.array([math.pi/12]*18)],
                 # NOTE(review): low == high == -0.3 here, so every initial
                 # velocity component is -0.3 instead of being randomized.
                 # This looks like a typo for a +0.3 upper bound; the default
                 # is left untouched to keep the interface stable.
                 resetJointsVelsLowHigh=[jp.array([-0.3]*24), jp.array([-0.3]*24)],
                 **kwargs
                 ):
        """Loads the model and stores the reward / reset configuration.

        Args:
            xml_path: path of the MuJoCo XML model to load.
            terminateWhenTilt, terminateWhenTiltGreaterThan: enable and
                threshold (radians) of the tilt termination check.
            baseTiltSigma, baseTiltCoef: shape of the base-tilt reward term.
            correctDirectionSigma, correctDirectionWeight: width and weight of
                the speed-tracking term; the sigma is also used for the
                yaw-rate-tracking term.
            terminateWhenLow, baseHeightLowerLimit: enable and threshold of
                the low-base termination check.
            continuityCoef: weight of the squared action-change penalty.
            femurCollisionSigma, femurCollisionCoef, tibiaRewardSigma,
                tibiaRewardCoef: used only by the femur / tibia helper
                rewards, which ``step`` currently does not call.
            includeBaseAngularVels, includeTibiaTipSensors: optional parts of
                each observation frame.
            nStacks: number of observation frames stacked together.
            physics_steps_per_control_step: default for ``n_frames`` when the
                caller does not pass ``n_frames`` explicitly.
            resetPosLowHigh, resetOriLowHigh, resetJointsPosLowHigh,
                resetJointsVelsLowHigh: ``[low, high]`` bounds of the uniform
                distributions sampled in ``reset`` (orientation as
                roll/pitch/yaw).
            **kwargs: forwarded to ``PipelineEnv.__init__``; ``backend`` is
                always forced to ``'mjx'``.

        The remaining arguments (terminateWhenFumersColide, deviationAngle*,
        baseHeightSigma/Coef, baseOscillation*, rewardForTibiaTip, powerCoef)
        are stored on the instance but not read by the current reward.
        """
        self.mj_model = mujoco.MjModel.from_xml_path(xml_path)
        self.mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
        self.mj_model.opt.iterations = 6
        self.mj_model.opt.ls_iterations = 6

        sys = mjcf.load_model(self.mj_model)
        sys = sys.tree_replace({'opt.timestep': 0.002})

        kwargs['n_frames'] = kwargs.get(
            'n_frames', physics_steps_per_control_step)
        kwargs['backend'] = 'mjx'
        super().__init__(sys, **kwargs)

        # Termination and base-tilt reward.
        self._terminateWhenTilt = terminateWhenTilt
        self._terminateWhenTiltGreaterThan = terminateWhenTiltGreaterThan
        self._baseTiltSigma = baseTiltSigma
        self._baseTiltCoef = baseTiltCoef

        # Femur collision.
        self._terminateWhenFumersColide = terminateWhenFumersColide
        self._femurCollisionSigma = femurCollisionSigma
        self._femurCollisionCoef = femurCollisionCoef

        # Command tracking.
        self._correctDirectionSigma = correctDirectionSigma
        self._correctDirectionWeight = correctDirectionWeight
        self._deviationAngleSigma = deviationAngleSigma
        self._deviationAngleWeight = deviationAngleWeight

        # Base height and oscillation.
        self._baseHeightSigma = baseHeightSigma
        self._baseHeightCoef = baseHeightCoef
        self._terminateWhenLow = terminateWhenLow
        self._baseHeightLowerLimit = baseHeightLowerLimit
        self._baseOscillationSigma = baseOscillationSigma
        self._baseOscillationCoef = baseOscillationCoef

        # Tibia-tip contact reward.
        self._rewardForTibiaTip = rewardForTibiaTip
        self._tibiaRewardSigma = tibiaRewardSigma
        self._tibiaRewardCoef = tibiaRewardCoef

        # Effort and smoothness.
        self._powerCoef = powerCoef
        self._continuityCoef = continuityCoef

        # Observation layout.
        self._includeBaseAngularVels = includeBaseAngularVels
        self._includeTibiaTipSensors = includeTibiaTipSensors
        self._nStacks = nStacks
        self._physics_steps_per_control_step = physics_steps_per_control_step

        # Reset distributions.
        self._resetPosLowHigh = resetPosLowHigh
        self._resetOriLowHigh = resetOriLowHigh
        self._resetJointsPosLowHigh = resetJointsPosLowHigh
        self._resetJointsVelsLowHigh = resetJointsVelsLowHigh

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to a randomized initial state.

        Samples the base pose, joint positions and velocities from the
        configured uniform ranges and draws a fresh command schedule.

        Args:
            rng: JAX PRNG key; every random quantity is derived from it.

        Returns:
            The initial ``State``.  ``state.info`` holds the command schedule
            (``desired_vels``, ``desired_angle_vels``, ``transition_steps``,
            ``num_transitions``), the active command index ``current_idx``,
            the control-step counter ``step`` and ``last_action``.
        """
        (_, pos_rng, ori_rng, joints_rng, vel_rng, num_transitions_rng,
         desired_vel_rng, desired_angle_rng, transition_steps_rng) = jax.random.split(rng, 9)

        pos_low, pos_high = self._resetPosLowHigh[0], self._resetPosLowHigh[1]
        base_pos = jax.random.uniform(key=pos_rng, shape=(3,), minval=pos_low, maxval=pos_high)

        ori_low, ori_high = self._resetOriLowHigh[0], self._resetOriLowHigh[1]
        base_orientation_euler = jax.random.uniform(
            key=ori_rng, shape=(3,), minval=ori_low, maxval=ori_high)
        base_orientation = self._euler_to_quaternion(base_orientation_euler)

        joints_pos = jax.random.uniform(
            key=joints_rng, shape=(18,),
            minval=self._resetJointsPosLowHigh[0], maxval=self._resetJointsPosLowHigh[1])
        qpos = jp.concatenate((base_pos, base_orientation, joints_pos), axis=0)

        qvel = jax.random.uniform(
            key=vel_rng, shape=(24,),
            minval=self._resetJointsVelsLowHigh[0], maxval=self._resetJointsVelsLowHigh[1])

        data = self.pipeline_init(qpos, qvel)

        # One frame = 2 commands + 18 joint positions (+3 angular velocities)
        # (+6 contact flags); the history holds nStacks such frames.
        frame_size = self._includeBaseAngularVels * 3 + self._includeTibiaTipSensors * 6 + 18 + 2
        obs_history = jp.zeros(self._nStacks * frame_size)
        reward, done, zero = jp.zeros(3)

        # The number of command changes is sampled with the JAX key rather than
        # Python's `random`: a Python draw is evaluated once at trace time, so
        # under jit/vmap every environment would share one unseeded value.
        # Arrays are allocated at the maximum size to keep shapes static.
        max_transitions = self._MAX_TRANSITIONS
        num_transitions = jax.random.randint(
            key=num_transitions_rng, shape=(), minval=1, maxval=max_transitions + 1)
        desired_vels = jax.random.uniform(
            key=desired_vel_rng, shape=(max_transitions + 1,), minval=0.1, maxval=1.2)
        desired_angle_vels = jax.random.uniform(
            key=desired_angle_rng, shape=(max_transitions + 1,), minval=-1, maxval=1)
        transition_steps = jax.random.randint(
            key=transition_steps_rng, shape=(max_transitions,), minval=50, maxval=951)
        # Slots beyond the sampled count get -1, which the step counter (always
        # >= 1 when compared) can never equal, so they never trigger a change.
        transition_steps = jp.where(
            jp.arange(max_transitions) < num_transitions, transition_steps, -1)

        obs = self._get_obs(data, jp.zeros(self.sys.nu), desired_vels[0], desired_angle_vels[0], obs_history)
        state_info = {
            'last_action': jp.zeros(self.sys.nu),
            'num_transitions': num_transitions,
            'desired_vels': desired_vels,
            'desired_angle_vels': desired_angle_vels,
            'transition_steps': transition_steps,
            # Explicit int32 scalars keep the info pytree's dtypes identical
            # between reset and step.
            'current_idx': jp.zeros((), dtype=jp.int32),
            'step': jp.zeros((), dtype=jp.int32),
        }
        metrics = {
            'correct_direction_reward': zero,
            'base_tilt': zero,
            'base_height': zero,
            'movement_angle': zero,
            'continuity': zero,
            'power': zero,
            'tibia_tip_contact': zero,
            'femur_collision': zero,
            'x_position': zero,
            'y_position': zero,
            'distance_from_origin': zero,
            'x_velocity': zero,
            'y_velocity': zero,
            'base_tilt_reward': zero,
            'tibia_reward': zero,
            'total_reward': zero
        }
        return State(data, obs, reward, done, metrics, state_info)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Runs one control step and scores it against the active command.

        Args:
            state: current environment state.
            action: control vector passed to ``pipeline_step``.

        Returns:
            A new ``State`` with updated pipeline state, observation, reward,
            done flag, metrics and info.  The input state's ``info`` and
            ``metrics`` dictionaries are left unmodified.
        """
        prev_pipeline_state = state.pipeline_state
        pipeline_state = self.pipeline_step(prev_pipeline_state, action)

        last_action = state.info['last_action']
        current_idx = state.info['current_idx']
        desired_vel = state.info['desired_vels'][current_idx]
        desired_angle_vel = state.info['desired_angle_vels'][current_idx]

        # NOTE(review): subtree_com[0] is used as the base position; presumably
        # index 0 is the root of the model tree — confirm against the XML.
        prev_base_pos = prev_pipeline_state.subtree_com[0]
        base_pos = pipeline_state.subtree_com[0]
        displacement = base_pos - prev_base_pos
        velocity = pipeline_state.xd.vel[0, :]

        base_ori = pipeline_state.x.rot[0, :]
        # Tilt = magnitude of the (roll, pitch) pair.
        base_tilt = jp.linalg.norm(self._quaternion_to_euler(base_ori)[0:2])
        # Heading of the last displacement, measured from the +x axis
        # (arctan2 takes dy first, then dx).
        movement_angle = jp.arctan2(displacement[1], displacement[0])

        # Gaussian-shaped tracking terms: 1 when the command is matched exactly.
        speed_error = desired_vel - jp.linalg.norm(velocity[0:2])
        correctDirectionReward = self._correctDirectionWeight * jp.exp(
            -speed_error**2 / self._correctDirectionSigma**2)
        yaw_rate_error = desired_angle_vel - pipeline_state.xd.ang[0, -1]
        correctAngVelReward = 1 * jp.exp(-yaw_rate_error**2 / self._correctDirectionSigma**2)

        # Tilt term: -(1 - coef * exp(-tilt^2 / sigma^2)).
        baseTiltReward = -1 * (1 - self._baseTiltCoef * jp.exp(-base_tilt**2 / self._baseTiltSigma**2))

        tilted_too_far = jp.logical_and(
            base_tilt > self._terminateWhenTiltGreaterThan, self._terminateWhenTilt)
        base_too_low = jp.logical_and(
            base_pos[2] < self._baseHeightLowerLimit, self._terminateWhenLow)
        termination = jp.logical_or(tilted_too_far, base_too_low)
        terminated = termination.astype(jp.float32)

        # Penalize large changes between consecutive actions.
        continuity_reward = self._continuityCoef * ((action - last_action)**2).sum()

        reward = (correctDirectionReward + correctAngVelReward - 1 * terminated
                  + baseTiltReward - continuity_reward)
        done = terminated

        # Advance the command schedule: move to the next command whenever the
        # step counter hits one of the scheduled transition steps.
        step_count = state.info['step'] + 1
        at_transition = jp.any(step_count == state.info['transition_steps'])
        current_idx = jp.where(at_transition, current_idx + 1, current_idx)

        # Restart the schedule when the episode terminates.  Auto-reset
        # wrappers restore the physics state and observation but keep `info`,
        # so without this the next episode would continue the old counters.
        # Episodes cut off by an external step limit are not visible here.
        step_count = jp.where(termination, 0, step_count)
        current_idx = jp.where(termination, 0, current_idx)

        # Show the policy the command that will be scored on the next step.
        obs = self._get_obs(
            pipeline_state, action,
            state.info['desired_vels'][current_idx],
            state.info['desired_angle_vels'][current_idx],
            obs_history=state.obs)

        # Fresh dictionaries so the caller's state is not mutated in place.
        info = dict(state.info)
        info['last_action'] = last_action.at[:].set(action)
        info['step'] = step_count
        info['current_idx'] = current_idx

        metrics = dict(state.metrics)
        metrics.update(
            correct_direction_reward=correctDirectionReward,
            base_tilt=base_tilt,
            base_height=base_pos[2],
            movement_angle=movement_angle,
            continuity=continuity_reward,
            x_position=base_pos[0],
            y_position=base_pos[1],
            distance_from_origin=jp.linalg.norm(base_pos[0:2]),
            x_velocity=velocity[0],
            y_velocity=velocity[1],
            base_tilt_reward=baseTiltReward,
            total_reward=reward,
        )

        return state.replace(
            pipeline_state=pipeline_state, obs=obs, reward=reward, done=done,
            metrics=metrics, info=info
        )

    def _get_obs(
            self, data: base.State, action: jp.ndarray, desired_velx, desired_vely, obs_history: jax.Array
                ) -> jp.ndarray:
        """Builds the newest observation frame and pushes it onto the stack.

        Args:
            data: pipeline state to observe.
            action: accepted for call compatibility; not used.
            desired_velx: first command value placed in the frame.  Callers in
                this class pass the commanded planar speed.
            desired_vely: second command value placed in the frame.  Callers
                in this class pass the commanded yaw rate.
            obs_history: previous stacked observation.

        Returns:
            ``obs_history`` shifted back by one frame with the new frame
            written at the front.
        """
        current_obs = jp.concatenate(
            (jp.array([desired_velx]), jp.array([desired_vely]), data.qpos[7:]), axis=0)

        if self._includeBaseAngularVels:
            # The angular velocity goes in front of the commands.
            current_obs = jp.append(data.xd.ang[0, :], current_obs)

        if self._includeTibiaTipSensors:
            # NOTE(review): mjx.get_data copies to a host-side MjData; this
            # branch presumably cannot run under jit — confirm before enabling.
            mj_data = mjx.get_data(self.mj_model, data)
            for sensor_idx in range(6):
                contact = jp.array(mj_data.sensordata[sensor_idx], dtype=jp.bool_)
                current_obs = jp.append(current_obs, contact)

        frame_size = current_obs.size
        shifted_history = jp.roll(obs_history, frame_size)
        return shifted_history.at[:frame_size].set(current_obs)

    def _get_femur_reward(self, pipeline_state):
        """Returns a Gaussian-shaped term of the smallest contact distance from index 6 on."""
        femur_dists = pipeline_state.contact.dist[6:]
        closest = femur_dists.min()
        return self._femurCollisionCoef * jp.exp(-closest**2 / self._femurCollisionSigma**2)

    def _get_tibia_reward(self, pipeline_state: base.State) -> jp.ndarray:
        """Returns a Gaussian-shaped term of the summed heights of touching tibia tips.

        The first six contacts count as touching when their absolute distance
        is below 0.03; for those, the z coordinate of sites 2..7 is summed.
        """
        contact_dists = pipeline_state.contact.dist[0:6]
        touching = (jp.abs(contact_dists) < 0.03) * jp.ones(6)
        tip_heights = pipeline_state.site_xpos[2:8, 2]
        tibia_tip_dists = jp.sum(touching * tip_heights)
        return self._tibiaRewardCoef * jp.exp(-tibia_tip_dists**2 / self._tibiaRewardSigma**2)

    def _euler_to_quaternion(self, euler):
        """Converts (roll, pitch, yaw) Euler angles to a (w, x, y, z) quaternion."""
        roll, pitch, yaw = euler
        sin_r, cos_r = jp.sin(roll / 2), jp.cos(roll / 2)
        sin_p, cos_p = jp.sin(pitch / 2), jp.cos(pitch / 2)
        sin_y, cos_y = jp.sin(yaw / 2), jp.cos(yaw / 2)

        qx = sin_r * cos_p * cos_y - cos_r * sin_p * sin_y
        qy = cos_r * sin_p * cos_y + sin_r * cos_p * sin_y
        qz = cos_r * cos_p * sin_y - sin_r * sin_p * cos_y
        qw = cos_r * cos_p * cos_y + sin_r * sin_p * sin_y
        return jp.array([qw, qx, qy, qz])

    def _quaternion_to_euler(self, quaternion):
        """Converts a (w, x, y, z) quaternion to (roll, pitch, yaw) Euler angles."""
        qw, qx, qy, qz = quaternion

        # Roll (x-axis rotation)
        sinr_cosp = 2 * (qw * qx + qy * qz)
        cosr_cosp = 1 - 2 * (qx * qx + qy * qy)
        roll = jp.arctan2(sinr_cosp, cosr_cosp)

        # Pitch (y-axis rotation); clamp to +/- pi/2 where arcsin is undefined.
        sinp = 2 * (qw * qy - qz * qx)
        pitch = jp.where(jp.abs(sinp) >= 1, jp.sign(sinp) * (jp.pi / 2), jp.arcsin(sinp))

        # Yaw (z-axis rotation)
        siny_cosp = 2 * (qw * qz + qx * qy)
        cosy_cosp = 1 - 2 * (qy * qy + qz * qz)
        yaw = jp.arctan2(siny_cosp, cosy_cosp)

        return jp.array([roll, pitch, yaw])

    def render(
            self, trajectory: List[base.State], camera = None,
            width: int = 640, height: int = 480,
            ) -> Sequence[np.ndarray]:
        """Renders a trajectory, defaulting to the 'track' camera.

        Args:
            trajectory: pipeline states to render.
            camera: camera passed to the parent renderer; ``None`` selects
                ``'track'``.  Any other value (including id ``0``) is passed
                through unchanged.
            width: frame width in pixels.
            height: frame height in pixels.

        Returns:
            Whatever ``PipelineEnv.render`` returns for these arguments.
        """
        if camera is None:
            camera = 'track'
        return super().render(trajectory, camera=camera, width=width, height=height)



# Build one environment with all default settings. The training cell below
# creates its own `env` again, so this instance mainly checks that the model loads.
env = HexapodV0_3(xml_path=xml_path)


In [None]:
# PPO network factory: the policy uses four hidden layers of 128 units.
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
        policy_hidden_layer_sizes=(128, 128, 128, 128))
# PPO training entry point with every hyperparameter bound; only the
# environment(s) and the progress callback are supplied at call time.
# `progress` reads `train_fn.keywords['num_timesteps']` to size its x-axis.
train_fn = functools.partial(
      ppo.train, num_timesteps=200_000_000, num_evals=20,
      reward_scaling=1, episode_length=1000, normalize_observations=True,
      action_repeat=1, unroll_length=20, num_minibatches=64,
      num_updates_per_batch=10, discounting=0.97, learning_rate=3.0e-4,
      entropy_cost=1e-2, num_envs=int(8192), batch_size=256,
      network_factory=make_networks_factory,
      seed=0)

def progress(num_steps, metrics):
  """Training callback: logs the evaluation reward and redraws the learning curve.

  Prints the mean evaluation episode reward, timestamps the call in `times`,
  appends the new point to the module-level `x_data` / `y_data` / `ydataerr`
  lists and plots reward (with its standard deviation as error bars) against
  the number of environment steps.
  """
  episode_reward = metrics['eval/episode_reward']
  print(episode_reward)
  times.append(datetime.now())

  # Record the new evaluation point.
  x_data.append(num_steps)
  y_data.append(episode_reward)
  ydataerr.append(metrics['eval/episode_reward_std'])

  # Leave 25% headroom past the configured number of training steps.
  x_upper = train_fn.keywords['num_timesteps'] * 1.25

  plt.figure()
  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.xlim([0, x_upper])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'y={y_data[-1]:.3f}')
  plt.show()

# Buffers filled by `progress` on every evaluation. `times` starts with the
# launch time so the first callback marks the end of the first (compile-heavy) phase.
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]

# Separate environment instances for training and evaluation.
env = HexapodV0_3(xml_path=xml_path)
eval_env = HexapodV0_3(xml_path=xml_path)
# Runs the full PPO training; returns the inference-function factory, the
# trained parameters and the final metrics.
make_inference_fn, params, metrics = train_fn(environment=env,
                                       progress_fn=progress,
                                       eval_env=eval_env)

# times[0] = launch, times[1] = first progress callback, times[-1] = last one.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

In [None]:
# Persist the trained parameters to the Kaggle working directory.
model_path = '/kaggle/working/mjx_brax_policy_22'
model.save_params(model_path, params)

In [None]:
# Reload the saved parameters and build a jitted policy from them.
# `make_inference_fn` comes from the training cell, so that cell must have
# run in this kernel session even when only the saved weights are wanted.
params = model.load_params(model_path)

inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

In [None]:
# Fresh evaluation environment with jit-compiled reset and step functions
# for the rollout below.
eval_env = HexapodV0_3(xml_path=xml_path)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

In [None]:
# Roll the trained policy out from a fixed seed, keeping every pipeline state
# (for rendering) and every info dict (for the command overlay).
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]
rollout_info = [state.info]

# grab a trajectory
n_steps = 1000
render_every = 1  # not used below: every state is rendered

for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  # for i in range(env._physics_steps_per_control_step):
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
  rollout_info.append(state.info)

  # Stop as soon as the environment reports termination.
  if state.done:
    break

# Rendering uses `env` (the training instance) although the rollout came from
# `eval_env`; both are built from the same XML with default settings.
# NOTE(review): fps is 1/dt halved, which presumably plays the video at half
# real-time speed — confirm that this is intended.
frames = env.render(rollout, camera='hexapod_camera')
media.show_video(frames, fps=1.0 / env.dt / 2)

In [None]:
import cv2


def add_capsule_and_text(frames, desired_vels, desired_angle_vels):
    """Writes the commanded speed and yaw rate onto every frame.

    The capsule drawing that the name refers to is currently disabled; only
    the text overlay is produced.

    Args:
        frames: mutable sequence of images; it is updated in place.
        desired_vels: commanded speed for each frame (same indexing as frames).
        desired_angle_vels: commanded angular velocity for each frame, in
            radians per unit time; displayed in degrees.

    Returns:
        The same `frames` sequence, with annotated images.

    Raises:
        IndexError: if either command sequence is shorter than `frames`.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    first_baseline, line_spacing = 30, 20

    for i, frame in enumerate(frames):
        desired_vel = float(desired_vels[i])
        desired_angle_vel = float(desired_angle_vels[i])

        # ASCII only: OpenCV's Hershey fonts cannot draw the degree sign and
        # would render it as '?'.  The second value is an angular velocity,
        # so it is labelled as a rate rather than as an angle.
        overlay_lines = (
            f"Velocity: {desired_vel:.2f}",
            f"Yaw rate: {math.degrees(desired_angle_vel):.2f} deg/s",
        )

        for j, line in enumerate(overlay_lines):
            origin = (10, first_baseline + j * line_spacing)
            # Thick white pass followed by a thin black pass gives outlined
            # text that stays readable on any background.
            frame = cv2.putText(frame, line, origin, font, font_scale, (255, 255, 255), 2, cv2.LINE_AA)
            frame = cv2.putText(frame, line, origin, font, font_scale, (0, 0, 0), 1, cv2.LINE_AA)

        frames[i] = frame
    return frames

# `frames` and `rollout_info` come from the rollout cell above and have one
# entry per recorded state.

# For every recorded step, look up the command that was active at that step.
desired_vels = [info['desired_vels'][info['current_idx']] for info in rollout_info]
desired_angle_vels = [info['desired_angle_vels'][info['current_idx']] for info in rollout_info]

# Annotate the frames in place; the result is displayed by the next cell.
frames = add_capsule_and_text(frames, desired_vels, desired_angle_vels)

In [None]:
# Display the annotated rollout with the same frame rate as the raw video above.
media.show_video(frames, fps=1.0 / env.dt / 2)