In [None]:
# Install the simulator (pybullet), RL stack (gymnasium, stable-baselines3) and OpenCV for video export.
# NOTE(review): protobuf is pinned to 3.20.3, presumably for compatibility with another package
# in this environment (e.g. tensorboard / google-cloud) — confirm why the pin is needed.
!pip install pybullet gymnasium stable-baselines3[extra] opencv-python protobuf==3.20.3

In [None]:
# Remove the legacy `gym` package so only `gymnasium` (imported below) is used.
# NOTE(review): presumably to avoid gym/gymnasium API conflicts — confirm.
!pip uninstall -y gym

In [None]:
# Fetch the PhantomX hexapod description; HexapodEnv loads
# "phantomx_description/urdf/phantomx.urdf" relative to the working directory.
# On re-run, git reports that the destination already exists.
!git clone https://github.com/HumaRobotics/phantomx_description.git

In [None]:
import pybullet as p
import pybullet_data
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from tqdm.notebook import trange
import os
import time

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from stable_baselines3.common.callbacks import BaseCallback

In [None]:
# Google Cloud Storage handle for the run-output bucket.
# NOTE(review): `bucket` is not referenced in the visible cells below (uploads
# go through `gsutil`); it is kept for interactive use.
from google.cloud import storage

GCS_BUCKET_NAME = 'issac-output'

client = storage.Client()
bucket = client.get_bucket(GCS_BUCKET_NAME)

In [None]:
class HexapodEnv(gym.Env):
    """PyBullet PhantomX hexapod locomotion task with the gymnasium API.

    Observation (float32, length ``2 * num_joints + 9``): joint angles,
    joint velocities, base roll/pitch/yaw, base linear velocity (3) and
    base angular velocity (3).

    Action (float32, length ``num_joints``): target joint positions in
    ``[-pi, pi]``, applied with POSITION_CONTROL.

    Every PyBullet call is routed through ``self.physicsClient`` so that several
    instances (e.g. inside ``DummyVecEnv``) each drive their own simulation
    instead of all loading into and stepping the default client 0.
    """

    # Base pose the robot is spawned at and returned to on every reset.
    _START_BASE_POS = (0.0, 0.0, 0.1)
    _START_BASE_ORN = (0.0, 0.0, 0.0, 1.0)

    def __init__(self, gui=False):
        """Connect a PyBullet client (GUI or headless) and load plane + robot.

        Args:
            gui: if True, connect with ``p.GUI``; otherwise ``p.DIRECT``.
        """
        super().__init__()
        self.gui = gui
        self.physicsClient = p.connect(p.GUI if gui else p.DIRECT)
        cid = self.physicsClient

        p.setAdditionalSearchPath(pybullet_data.getDataPath(), physicsClientId=cid)
        p.setGravity(0, 0, -9.81, physicsClientId=cid)
        p.setTimeStep(1 / 240., physicsClientId=cid)

        self.plane = p.loadURDF("plane.urdf", physicsClientId=cid)
        self.robot = p.loadURDF(
            "phantomx_description/urdf/phantomx.urdf",
            list(self._START_BASE_POS),
            useFixedBase=False,
            physicsClientId=cid,
        )

        self.num_joints = p.getNumJoints(self.robot, physicsClientId=cid)
        for j in range(self.num_joints):
            p.changeDynamics(
                self.robot, j,
                jointLowerLimit=-np.pi / 2,
                jointUpperLimit=np.pi / 2,
                physicsClientId=cid,
            )

        self.action_space = spaces.Box(-np.pi, np.pi, shape=(self.num_joints,), dtype=np.float32)

        # joint angles + joint velocities + (roll, pitch, yaw) + lin_vel(3) + ang_vel(3)
        obs_len = self.num_joints * 2 + 9
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_len,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Return the robot to its spawn pose with all joints at 0.

        Returns:
            ``(observation, info)`` per the gymnasium API.
        """
        # Seeds self.np_random per the gymnasium Env contract.
        super().reset(seed=seed)
        cid = self.physicsClient

        # Without this, a robot that fell over would begin the next episode
        # lying where it fell, and start_pos would be measured from there.
        p.resetBasePositionAndOrientation(
            self.robot, list(self._START_BASE_POS), list(self._START_BASE_ORN), physicsClientId=cid
        )
        p.resetBaseVelocity(self.robot, [0, 0, 0], [0, 0, 0], physicsClientId=cid)
        for j in range(self.num_joints):
            p.resetJointState(self.robot, j, 0, physicsClientId=cid)

        self.start_pos = p.getBasePositionAndOrientation(self.robot, physicsClientId=cid)[0]
        return self._get_obs(), {}

    def _get_obs(self):
        """Build the flat float32 observation vector described in the class docstring."""
        cid = self.physicsClient
        joint_states = [p.getJointState(self.robot, j, physicsClientId=cid) for j in range(self.num_joints)]
        joint_angles = [state[0] for state in joint_states]
        joint_vels = [state[1] for state in joint_states]

        _, orn = p.getBasePositionAndOrientation(self.robot, physicsClientId=cid)
        lin_vel, ang_vel = p.getBaseVelocity(self.robot, physicsClientId=cid)
        roll, pitch, yaw = p.getEulerFromQuaternion(orn)

        obs = np.array(
            joint_angles + joint_vels + [roll, pitch, yaw] + list(lin_vel) + list(ang_vel),
            dtype=np.float32,
        )
        assert obs.shape[0] == self.observation_space.shape[0], f"Obs length {obs.shape[0]} does not match expected {self.observation_space.shape[0]}"
        return obs

    def step(self, action):
        """Apply ``action`` as joint position targets and advance one physics step.

        Reward = 10 * (base x - base x at reset) - |roll| - |pitch|
                 - 0.01 * sum(joint_velocity ** 2).
        The episode terminates when the base drops below z = 0.05 or
        |roll| / |pitch| exceed 1 rad.

        Returns:
            ``(obs, reward, terminated, truncated, info)``; ``truncated`` is always False.
        """
        cid = self.physicsClient
        for j in range(self.num_joints):
            p.setJointMotorControl2(
                self.robot, j, p.POSITION_CONTROL,
                targetPosition=float(action[j]),
                positionGain=0.1,
                velocityGain=1.0,
                force=2,
                physicsClientId=cid,
            )
        p.stepSimulation(physicsClientId=cid)

        obs = self._get_obs()
        pos, orn = p.getBasePositionAndOrientation(self.robot, physicsClientId=cid)
        roll, pitch, _ = p.getEulerFromQuaternion(orn)

        # NOTE(review): this rewards the total displacement since reset on every
        # step (not per-step progress), so a robot that moved forward once keeps
        # collecting reward while standing still.
        forward = (pos[0] - self.start_pos[0]) * 10
        stability_penalty = -abs(roll) - abs(pitch)
        joint_vels = obs[self.num_joints:self.num_joints * 2]
        vel_penalty = -np.sum(np.square(joint_vels)) * 0.01
        reward = float(forward + stability_penalty + vel_penalty)

        terminated = bool(pos[2] < 0.05 or abs(roll) > 1.0 or abs(pitch) > 1.0)
        return obs, reward, terminated, False, {}

    def close(self):
        """Disconnect this env's PyBullet client; safe to call more than once."""
        if p.isConnected(physicsClientId=self.physicsClient):
            p.disconnect(physicsClientId=self.physicsClient)

In [None]:
def make_env(gui=False):
    """Return a zero-argument factory that builds one monitored HexapodEnv.

    DummyVecEnv takes callables rather than env instances. Each env is wrapped
    in SB3's Monitor so that episode reward/length statistics are recorded and
    can be logged during training.

    Args:
        gui: forwarded to HexapodEnv (True opens a PyBullet GUI window).
    """
    from stable_baselines3.common.monitor import Monitor

    def _init():
        return Monitor(HexapodEnv(gui=gui))
    return _init


num_envs = 4
vec_env = DummyVecEnv([make_env(gui=False) for _ in range(num_envs)])

In [None]:
device = "cpu"

# Policy/value MLP: two hidden layers of 256 units.
policy_kwargs = {"net_arch": [256, 256]}

# With num_envs = 4 parallel envs, each rollout collects 4 * n_steps = 4096
# transitions; batch_size=4096 therefore uses the whole rollout as one minibatch.
ppo_hyperparams = {
    "n_steps": 1024,
    "batch_size": 4096,
    "learning_rate": 3e-4,
}

model = PPO(
    "MlpPolicy",
    vec_env,
    device=device,
    verbose=0,
    tensorboard_log="./tensorboard_logs/",
    policy_kwargs=policy_kwargs,
    **ppo_hyperparams,
)

print(f"Training on device: {device}")

In [None]:
# Show TensorBoard inline, reading the same directory passed to PPO as tensorboard_log.
%load_ext tensorboard
%tensorboard --logdir ./tensorboard_logs/

In [None]:
class LossLoggerCallback(BaseCallback):
    """Print PPO's most recent training losses once per rollout.

    PPO records its losses (``train/policy_gradient_loss``, ``train/value_loss``,
    ``train/entropy_loss``) in ``train()``, which runs after each rollout. The
    values are therefore available at the start of the next rollout, so they
    are printed from ``_on_rollout_start`` instead of on every env step.
    """

    def _on_rollout_start(self) -> None:
        values = self.logger.name_to_value
        policy_loss = values.get("train/policy_gradient_loss")
        value_loss = values.get("train/value_loss")
        entropy_loss = values.get("train/entropy_loss")
        # Nothing has been trained yet (e.g. first rollout of a learn() call).
        if any(v is None for v in (policy_loss, value_loss, entropy_loss)):
            return
        print(f"Policy Loss: {policy_loss:.4f}, Value Loss: {value_loss:.4f}, Entropy: {entropy_loss:.4f}")

    def _on_step(self) -> bool:
        # Required by BaseCallback; returning True keeps training going.
        return True


total_timesteps = 10000000
save_interval = 100000

# NOTE(review): this directory is created, but checkpoints are written to the
# working directory, which is where the upload and video cells look for them.
save_model_dirc = "saved_models/"
os.makedirs(save_model_dirc, exist_ok=True)

for i in trange(0, total_timesteps, save_interval):
    model.learn(total_timesteps=save_interval, reset_num_timesteps=False, progress_bar=False, callback=LossLoggerCallback())

    # Label each checkpoint with the timestep budget trained so far (end of this
    # chunk): the first save happens after save_interval steps, not after 0.
    local_checkpoint_path = f"hexapod_ppo_checkpoint_{i + save_interval}"
    model.save(local_checkpoint_path)

In [None]:
# Shut down the training environments.
# NOTE(review): DummyVecEnv.close() presumably calls close() on each sub-env,
# and HexapodEnv.close() disconnects that env's PyBullet client.
vec_env.close()
print("Training complete. Checkpoints saved for later use.")

# Save checkpoints and TensorBoard logs to GCS

In [None]:
# Upload every checkpoint archive (*.zip) in the working directory to the run2 prefix.
!gsutil cp *.zip gs://issac-output/run2/

In [None]:
# Upload the TensorBoard event directory recursively to the same prefix.
!gsutil cp -r tensorboard_logs/ gs://issac-output/run2/

# Make a video

In [None]:
import pybullet as p

# Drop any stale connection on the default PyBullet client before the video
# env connects. HexapodEnv.close() already disconnects each training env's
# client, so an unconditional p.disconnect() raises "Not connected to physics
# server" on a fresh Restart & Run All.
if p.isConnected():
    p.disconnect()

In [None]:
# Reload a trained checkpoint for video generation. Re-importing PPO keeps
# this cell self-contained.
from stable_baselines3 import PPO

model_path = "hexapod_ppo_checkpoint_1500000.zip"
model = PPO.load(path=model_path)

In [None]:
class HexapodEnvVideo(HexapodEnv):
    """HexapodEnv that renders RGB frames from a fixed camera for video export."""

    # Rendered frame size in pixels.
    FRAME_WIDTH = 640
    FRAME_HEIGHT = 480

    def __init__(self, gui=False):
        super().__init__(gui=gui)
        self.render_mode = "rgb_array"

    def render(self):
        """Return an (H, W, 3) uint8 RGB frame, or None if render_mode is not "rgb_array"."""
        if self.render_mode != "rgb_array":
            return None

        width, height = self.FRAME_WIDTH, self.FRAME_HEIGHT
        # Camera looks at the world origin from 3 m away, 30 degrees above the ground.
        view_matrix = p.computeViewMatrixFromYawPitchRoll(
            cameraTargetPosition=[0, 0, 0],
            distance=3,
            yaw=0,
            pitch=-30,
            roll=0,
            upAxisIndex=2,
        )
        proj_matrix = p.computeProjectionMatrixFOV(
            fov=60,
            aspect=width / height,
            nearVal=0.1,
            farVal=100,
        )
        # Render from this env's own client rather than the default client 0.
        img = p.getCameraImage(
            width, height,
            viewMatrix=view_matrix,
            projectionMatrix=proj_matrix,
            renderer=p.ER_TINY_RENDERER,
            physicsClientId=self.physicsClient,
        )
        img_width, img_height, rgba_pixels = img[0], img[1], img[2]
        # Depending on how pybullet was built, the RGBA buffer is either an
        # (H, W, 4) array or a flat sequence; reshape explicitly and force uint8
        # so the result is a standard 8-bit image.
        rgba = np.reshape(np.asarray(rgba_pixels, dtype=np.uint8), (img_height, img_width, 4))
        return rgba[:, :, :3]


env = HexapodEnvVideo()
obs, _ = env.reset()

In [None]:
from tqdm.notebook import tqdm
import cv2

# Roll out the loaded policy and stream each rendered frame straight to disk.
# Writing frames as they are produced avoids buffering all of them in memory
# (1024 frames of 640x480x3 bytes is roughly 0.9 GB).
# NOTE(review): each env.step advances one 1/240 s physics step, so at 30 fps
# the video plays about 8x slower than simulated time.
video_path = 'hexapod.mp4'
video_fps = 30
video_steps = 1024

out = None
obs, _ = env.reset()
try:
    for step in tqdm(range(video_steps), desc="Generating video"):
        action, _ = model.predict(obs)
        obs, _, terminated, truncated, _ = env.step(action)

        frame = cv2.cvtColor(env.render(), cv2.COLOR_RGB2BGR)
        if out is None:
            # Open the writer lazily so its size matches the rendered frames.
            height, width, _ = frame.shape
            out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (width, height))
            if not out.isOpened():
                raise RuntimeError(f"Could not open video writer for {video_path!r}")
        out.write(frame)

        if terminated or truncated:
            obs, _ = env.reset()
finally:
    # Finalize the file even if the rollout is interrupted part-way.
    if out is not None:
        out.release()