In [None]:
!git clone https://github.com/triton-droids/pupper-simulations.git
%cd pupper-simulations

fatal: destination path 'pupper-simulations' already exists and is not an empty directory.
/content/pupper-simulations


In [None]:
!pip install -r training_requirements.txt
%cd locomotion

/content/pupper-simulations/locomotion


In [None]:
from bittle_env import BittleEnv
from training_helpers import progress, domain_randomize

from brax import envs
import functools
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from brax.io import model

from datetime import datetime
import sys
from etils import epath
from matplotlib import pyplot as plt

from flax.training import orbax_utils
from orbax import checkpoint as ocp

import jax
from jax import numpy as jp

import mediapy as media



# **Train Policy**

In [None]:
# Name under which the Bittle environment is registered with Brax, and the
# scene file handed to the environment constructor.
env_name = 'bittle'
xml_path = 'bittle_adapted_scene.xml'

# Register the environment class, then build an instance through the registry.
envs.register_environment(env_name, BittleEnv)
env = envs.get_environment(env_name, xml_path=xml_path)

# Directory that receives the per-step parameter checkpoints.
ckpt_path = epath.Path('/tmp/quadrupred_joystick/ckpts')
ckpt_path.mkdir(exist_ok=True, parents=True)

def policy_params_fn(current_step, make_policy, params):
  """Checkpoint `params` under `ckpt_path/<current_step>`.

  This function is handed to `ppo.train` as `policy_params_fn`.

  Args:
    current_step: Step counter used as the checkpoint directory name.
    make_policy: Unused here; accepted because the callback is invoked with it.
    params: Parameter pytree to write to disk.
  """
  step_dir = ckpt_path / f'{current_step}'
  checkpointer = ocp.PyTreeCheckpointer()
  # force=True is passed to save(); presumably this allows overwriting an
  # existing checkpoint at the same step -- confirm against the Orbax docs.
  checkpointer.save(
      step_dir,
      params,
      force=True,
      save_args=orbax_utils.save_args_from_target(params),
  )


# Network factory: PPO networks whose policy has four hidden layers of 128
# units each.
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=(128, 128, 128, 128))

# PPO training configuration; the environments and the progress callback are
# supplied later when train_fn is called.
#
# num_evals is the number of evaluation rounds over the whole run (see the
# Brax PPO `train` documentation). It was previously 10_000_000, which asks
# for ten million evaluation rounds in a 100M-step run, so training is
# interrupted by evaluation almost continuously and does not finish in any
# practical time. Ten evenly spaced evaluations is the intended scale.
train_fn = functools.partial(
    ppo.train, num_timesteps=100_000_000, num_evals=10,
    reward_scaling=1, episode_length=1000, normalize_observations=True,
    action_repeat=1, unroll_length=20, num_minibatches=32,
    num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,
    entropy_cost=1e-2, num_envs=8192, batch_size=256,
    network_factory=make_networks_factory,
    randomization_fn=domain_randomize,
    policy_params_fn=policy_params_fn,
    seed=0)

# Notebook-level bookkeeping for training progress.
# NOTE(review): nothing visible in this notebook reads x_data, y_data,
# ydataerr, max_y or min_y; they are kept in case other code relies on them.
x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 40, 0


def _timed_progress(*args, **kwargs):
  """Record a timestamp in `times`, then delegate to the imported `progress`.

  `progress` lives in `training_helpers` and therefore cannot append to the
  `times` list defined in this notebook; without this wrapper `times` would
  keep its single initial entry and the timing report below would raise
  IndexError on `times[1]`.
  """
  times.append(datetime.now())
  return progress(*args, **kwargs)


# Reset environments since internals may be overwritten by tracers from the
# domain randomization function.
env = envs.get_environment(env_name, xml_path=xml_path)
eval_env = envs.get_environment(env_name, xml_path=xml_path)
make_inference_fn, params, _ = train_fn(environment=env,
                                        progress_fn=_timed_progress,
                                        eval_env=eval_env)

# times[0] is the start of this cell and times[1] the first progress callback;
# report only when at least one callback was recorded.
if len(times) > 1:
  print(f'time to jit: {times[1] - times[0]}')
  print(f'time to train: {times[-1] - times[1]}')
else:
  print('no progress callbacks were recorded; timing unavailable')

# Save and reload params.
model_path = '/tmp/mjx_brax_quadruped_policy'
model.save_params(model_path, params)
params = model.load_params(model_path)

# Build the policy from the reloaded params and JIT-compile it for the rollout.
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)



Bittle has 9 actuators
Bittle has 9 position DOFs
Bittle has 9 velocity DOFs
Found lower leg body: servos_rf_1 (id=2)
Found lower leg body: servos_rr_1 (id=4)
Found lower leg body: servos_lf_1 (id=7)
Found lower leg body: servos_lr_1 (id=9)
Bittle has 9 actuators
Bittle has 9 position DOFs
Bittle has 9 velocity DOFs
Found lower leg body: servos_rf_1 (id=2)
Found lower leg body: servos_rr_1 (id=4)
Found lower leg body: servos_lf_1 (id=7)
Found lower leg body: servos_lr_1 (id=9)
Bittle has 9 actuators
Bittle has 9 position DOFs
Bittle has 9 velocity DOFs
Found lower leg body: servos_rf_1 (id=2)
Found lower leg body: servos_rr_1 (id=4)
Found lower leg body: servos_lf_1 (id=7)
Found lower leg body: servos_lr_1 (id=9)


  return literals.TypedNdArray(np.asarray(x, dtype), weak_type=False)
Exception ignored in: <function _xla_gc_callback at 0x7e1712463880>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/jax/_src/lib/__init__.py", line 129, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 


# **Visualize Policy**

In [None]:
# Build the evaluation environment from the same scene file that the training
# environments were built with; every other get_environment call in this
# notebook passes xml_path, and omitting it here relied on whatever default
# BittleEnv may (or may not) provide.
eval_env = envs.get_environment(env_name, xml_path=xml_path)

# JIT-compile reset and step once so the rollout runs compiled code.
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

In [None]:
# @markdown Commands **only used for Barkour Env**:
x_vel = 1.0  #@param {type: "number"}
y_vel = 0.0  #@param {type: "number"}
ang_vel = -0.5  #@param {type: "number"}

# Number of policy steps to simulate, and the frame stride used for rendering.
n_steps = 500
render_every = 2

# Command vector assembled from the three form values above.
the_command = jp.array([x_vel, y_vel, ang_vel])

# Reset with a fixed PRNG key and attach the command to the state's info dict.
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command

# Step the trained policy, keeping every pipeline state for rendering.
rollout = [state.pipeline_state]
for step_idx in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

# Render every `render_every`-th state and scale the frame rate to match.
frames = eval_env.render(rollout[::render_every], camera='track')
playback_fps = 1.0 / eval_env.dt / render_every
media.show_video(frames, fps=playback_fps)