# Installations and Imports

In [None]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax
!pip install playground

In [None]:
# @title Import packages for plotting and creating graphics
import json
import itertools
import time
from typing import Callable, List, NamedTuple, Optional, Union
import numpy as np

import matplotlib.pyplot as plt

# More legible printing from numpy: 3 decimal places, fixed-point notation
# instead of scientific notation, and lines wrapped at 100 characters.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

In [None]:
# @title Import MuJoCo, MJX, and Brax
from datetime import datetime
import functools
import os
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.io import html, mjcf, model
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import networks as sac_networks
from brax.training.agents.sac import train as sac
from etils import epath
from flax import struct
from flax.training import orbax_utils
from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
import numpy as np
from orbax import checkpoint as ocp

In [None]:
# @title Import The Playground

from mujoco_playground import wrapper
from mujoco_playground import registry

# Clone Repository

In [None]:
# Fetch the project template; the `examples.*` and `src.*` imports used further
# down are presumably provided by this repository.
!git clone https://github.com/shaoanlu/control_system_project_template.git

In [None]:
# Work from inside the cloned repository so that its packages and the relative
# paths used later (e.g. "examples/mujoco_Go1/...") resolve from here.
%cd control_system_project_template

In [None]:
# Sanity check: print the current working directory.
!pwd

# Run Simulation

In [None]:
import jax
import mujoco
from tqdm import tqdm

from examples.mujoco_Go1.env_wrapper import Go1Env
from examples.mujoco_Go1.ppo import PPO, PPOParams, PPOParamsBuilder
from src.control.controller_factory import ControllerFactory

In [None]:
# [min, max] of the sampled kick magnitude; both zero disables the velocity kick.
velocity_kick_range = [0.0, 0.0]
# [min, max] of the sampled kick duration, in seconds (converted to steps via env.dt).
kick_duration_range = [0.05, 0.2]


def sample_pert(rng, env, state):
    """Sample a perturbation and store it in ``state.info``.

    The magnitude is drawn uniformly from ``velocity_kick_range`` and the
    duration (seconds) uniformly from ``kick_duration_range``. The duration is
    also converted to a whole number of steps using ``env.dt``.

    Args:
        rng: JAX PRNG key; it is split three ways and not reused.
        env: Object exposing the step duration as ``env.dt``.
        state: Object whose ``info`` mapping is updated in place with the keys
            ``"pert_mag"``, ``"pert_duration"`` (int32 steps) and
            ``"pert_duration_seconds"``.

    Returns:
        The first key of the split, to be used as the caller's next ``rng``.
    """
    next_rng, magnitude_key, duration_key = jax.random.split(rng, 3)

    magnitude = jax.random.uniform(
        magnitude_key,
        minval=velocity_kick_range[0],
        maxval=velocity_kick_range[1],
    )
    seconds = jax.random.uniform(
        duration_key,
        minval=kick_duration_range[0],
        maxval=kick_duration_range[1],
    )
    # Round to the nearest whole number of steps.
    steps = jax.numpy.round(seconds / env.dt).astype(jax.numpy.int32)

    for info_key, value in (
        ("pert_mag", magnitude),
        ("pert_duration", steps),
        ("pert_duration_seconds", seconds),
    ):
        state.info[info_key] = value
    return next_rng

## Instantiate Simulator And Controller

In [None]:
# Instantiate the MuJoCo environment wrapper.
# Environment names noted as options here: Go1Handstand, Go1JoystickFlatTerrain.
env_name = "Go1JoystickFlatTerrain"  # Go1Handstand, Go1JoystickFlatTerrain
# Fixed seed; this key is passed to env.reset and threaded through the rollout loop.
rng = jax.random.PRNGKey(0)
env = Go1Env(env_name=env_name)

# Instantiate controller based on env_name.
# The factory is told which controller class goes with PPOParams before
# `build` is called with a PPOParams instance.
factory = ControllerFactory()
factory.register_controller(PPOParams, PPO)
# Relative path: it only resolves when the working directory is the cloned
# repository root (see the earlier `%cd` cell). Presumably points at saved
# network parameters for this environment — TODO confirm.
controller_config = {"npy_path": f"examples/mujoco_Go1/nn_params/{env_name}"}
ppo_params = PPOParamsBuilder().build(config=controller_config)
controller = factory.build(params=ppo_params)

In [None]:
# Per-step logs collected during the rollout.
rollout = []  # environment state returned by every env.step call
modify_scene_fns = []  # not filled in this cell; name kept for any later cell that uses it
swing_peak = []  # state.info["swing_peak"] after each step
linvel = []  # env.get_global_linvel(state.data) after each step
angvel = []  # env.get_gyro(state.data) after each step

# Command held constant for the whole episode.
# Presumably (x velocity, y velocity, yaw rate) — TODO confirm against the env.
command = jax.numpy.array([0, 0, 0.9])
state = env.reset(rng)

# Report a perturbation-sampling failure only once instead of on every step.
pert_failure_reported = False

for _step in tqdm(range(env.env_cfg.episode_length), desc="Simulating"):
    state.info["command"] = command

    # Perturbation sampling is best-effort: if the perturbation timers are
    # missing from state.info (or sampling fails), the rollout continues
    # without it. Catching `Exception` rather than using a bare `except`
    # keeps KeyboardInterrupt / SystemExit working, so the kernel can still
    # be interrupted here.
    try:
        if state.info["steps_since_last_pert"] < state.info["steps_until_next_pert"]:
            rng = sample_pert(rng, env, state)
    except Exception as exc:
        if not pert_failure_reported:
            print(f"Perturbation sampling skipped: {exc!r}")
            pert_failure_reported = True

    # The controller call below takes no PRNG key, so the first half of the
    # split is not needed; the split itself is kept so `rng` advances exactly
    # as before.
    rng = jax.random.split(rng)[1]
    ctrl = controller.control(state.obs["state"])

    state = env.step(state, ctrl)
    rollout.append(state)
    swing_peak.append(state.info["swing_peak"])
    linvel.append(env.get_global_linvel(state.data))
    angvel.append(env.get_gyro(state.data))

# Visualize Simulation Result

In [None]:
# --- Figure 1: swing peak per foot, one panel per foot in a 2x2 grid ---
swing_peak = jax.numpy.array(swing_peak)
names = ["FR", "FL", "RR", "RL"]
colors = ["r", "g", "b", "y"]
max_foot_height = env.env_cfg.reward_config.max_foot_height

fig, axs = plt.subplots(2, 2)
for i, (ax, foot_name, foot_color) in enumerate(zip(axs.flat, names, colors)):
    ax.plot(swing_peak[:, i], color=foot_color)
    # Leave 25% headroom above the configured maximum foot height.
    ax.set_ylim([0, max_foot_height * 1.25])
    # Dashed reference line at the configured maximum foot height.
    ax.axhline(max_foot_height, color="k", linestyle="--")
    ax.set_title(foot_name)
    ax.set_xlabel("time")
    ax.set_ylabel("height")
plt.tight_layout()
plt.show()

# --- Figure 2: smoothed velocities against the commanded values ---
linvel_all = jax.numpy.array(linvel)
angvel_all = jax.numpy.array(angvel)

# 10-sample moving average; mode="same" keeps the series length unchanged.
smoothing_kernel = jax.numpy.ones(10) / 10
linvel_x = jax.numpy.convolve(linvel_all[:, 0], smoothing_kernel, mode="same")
linvel_y = jax.numpy.convolve(linvel_all[:, 1], smoothing_kernel, mode="same")
angvel_yaw = jax.numpy.convolve(angvel_all[:, 2], smoothing_kernel, mode="same")

labels = ["dx", "dy", "dyaw"]
# Symmetric y-limits per panel come from command_config.a.
# Presumably the per-axis command range — TODO confirm.
command_limits = env.env_cfg.command_config.a

fig, axes = plt.subplots(3, 1, figsize=(10, 10))
for i, (ax, series) in enumerate(zip(axes, (linvel_x, linvel_y, angvel_yaw))):
    ax.plot(series)
    ax.set_ylim(-command_limits[i], command_limits[i])
    # Dashed red line: the command stored in the final state.
    ax.axhline(state.info["command"][i], color="red", linestyle="--")
    ax.set_ylabel(labels[i])