# For Rendering Purpose Only

In [1]:
from datetime import datetime
import functools
from IPython.display import HTML
import PIL.Image
import yaml
from typing import List, Dict, Text, Callable, NamedTuple, Optional, Union, Any, Sequence, Tuple
from matplotlib import pyplot as plt
import mediapy as media
import wandb
import os

import numpy as np

from etils import epath
from flax import struct
from ml_collections import config_dict

import mujoco
from mujoco import mjx

from dm_control import mjcf as mjcf_dm
from dm_control import composer
from dm_control.locomotion.examples import basic_rodent_2020
from dm_control.composer.variation import distributions
from dm_control.locomotion.arenas import corridors as corr_arenas
from dm_control.locomotion.tasks import corridors as corr_tasks
from dm_control.locomotion.arenas import bowl as bowl_arenas
from dm_control.locomotion.tasks import escape as bowl_tasks
from dm_control.locomotion.walkers import ant #rodent, ant
from dm_control.locomotion.walkers import legacy_base
from dm_control import viewer
from dm_control import mujoco as mujoco_dm
from dm_control.composer.variation import rotations

import jax
from jax import numpy as jp
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State #MjxEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, model
from brax.io import mjcf as mjcf_brax

# !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
# np.set_printoptions(precision=3, suppress=True, linewidth=100)

import vnl_brax.rodent_base as rodent



In [2]:
class Gap_Vnl(corr_arenas.GapsCorridor):
    def _build(self, corridor_width, corridor_length, visible_side_planes, aesthetic, platform_length, gap_length):
        super()._build(corridor_width=corridor_width,
                       corridor_length=corridor_length,
                       visible_side_planes=visible_side_planes,
                       aesthetic = aesthetic,
                       platform_length = platform_length,
                       gap_length = gap_length)
    
    def regenerate(self, random_state):
        super().regenerate(random_state)

# Task now just serve as a wrapper
class Task_Vnl_Gap(corr_tasks.RunThroughCorridor):
    def __init__(self,
               walker,
               arena,
               walker_spawn_position):
        
        # we don't really need the rest of the reward setup in dm_control, just how the walker is attached to the arena
        spawn_site =  arena.mjcf_model.worldbody.add('site', size=[1e-6]*3, pos = walker_spawn_position)
        self._arena = arena
        self._walker = walker
        self._walker.create_root_joints(self._arena.attach(self._walker, attach_site=spawn_site)) # customize starting environment
        self._walker_spawn_position = walker_spawn_position

_XML_PATH = 'assets/rodent_stac_optimized.xml'

# _RAT_MOCAP_JOINTS = [
#     'vertebra_1_extend', 'vertebra_2_bend', 'vertebra_3_twist',
#     'vertebra_4_extend', 'vertebra_5_bend', 'vertebra_6_twist',
#     'hip_L_supinate', 'hip_L_abduct', 'hip_L_extend', 'knee_L', 'ankle_L',
#     'toe_L', 'hip_R_supinate', 'hip_R_abduct', 'hip_R_extend', 'knee_R',
#     'ankle_R', 'toe_R', 'vertebra_C1_extend', 'vertebra_C1_bend',
#     'vertebra_C2_extend', 'vertebra_C2_bend', 'vertebra_C3_extend',
#     'vertebra_C3_bend', 'vertebra_C4_extend', 'vertebra_C4_bend',
#     'vertebra_C5_extend', 'vertebra_C5_bend', 'vertebra_C6_extend',
#     'vertebra_C6_bend', 'vertebra_C7_extend', 'vertebra_C9_bend',
#     'vertebra_C11_extend', 'vertebra_C13_bend', 'vertebra_C15_extend',
#     'vertebra_C17_bend', 'vertebra_C19_extend', 'vertebra_C21_bend',
#     'vertebra_C23_extend', 'vertebra_C25_bend', 'vertebra_C27_extend',
#     'vertebra_C29_bend', 'vertebra_cervical_5_extend',
#     'vertebra_cervical_4_bend', 'vertebra_cervical_3_twist',
#     'vertebra_cervical_2_extend', 'vertebra_cervical_1_bend',
#     'vertebra_axis_twist', 'vertebra_atlant_extend', 'atlas', 'mandible',
#     'scapula_L_supinate', 'scapula_L_abduct', 'scapula_L_extend', 'shoulder_L',
#     'shoulder_sup_L', 'elbow_L', 'wrist_L', 'finger_L', 'scapula_R_supinate',
#     'scapula_R_abduct', 'scapula_R_extend', 'shoulder_R', 'shoulder_sup_R',
#     'elbow_R', 'wrist_R', 'finger_R'
# ]

class Rodent_Vnl(rodent.Rat):
    def _build(self,
             params=None,
             name='walker',
             torque_actuators=False,
             foot_mods=False,
             initializer=None):
        self.params = params
        self._mjcf_root = mjcf_dm.from_path(_XML_PATH)
        if name:
            self._mjcf_root.model = name

        self.body_sites = []
        super()._build(initializer=initializer)
    
    # def actuators(self):
    #     """Return all actuators."""
    #     return tuple(self._mjcf_root.find_all('actuator'))
    
    # def egocentric_camera(self):
    #     """Return the egocentric camera."""
    #     return self._mjcf_root.find('camera', 'egocentric')

    # def end_effectors(self):
    #     """Return end effectors."""
    #     return (self._mjcf_root.find('body', 'lower_arm_R'),
    #             self._mjcf_root.find('body', 'lower_arm_L'),
    #             self._mjcf_root.find('body', 'foot_R'),
    #             self._mjcf_root.find('body', 'foot_L'))

    # def ground_contact_geoms(self):
    #     """Return ground contact geoms."""
    #     return tuple(
    #         self._mjcf_root.find('body', 'foot_L').find_all('geom') +
    #         self._mjcf_root.find('body', 'foot_R').find_all('geom') +
    #         self._mjcf_root.find('body', 'hand_L').find_all('geom') +
    #         self._mjcf_root.find('body', 'hand_R').find_all('geom') +
    #         self._mjcf_root.find('body', 'vertebra_C1').find_all('geom')
    #         )
    
    # def observable_joints(self):
    #     """Return observable joints."""
    #     return tuple(actuator.joint
    #                 for actuator in self.actuators  #  This lint is mistaken; pylint: disable=not-an-iterable
    #                 if actuator.joint is not None)
    
    # def mjcf_model(self):
    #     """Return the model root."""
    #     return self._mjcf_root
    
    # def root_body(self):
    #     """Return the body."""
    #     return self._mjcf_root.find('body', 'torso')

In [3]:
class Bowl_Vnl(bowl_arenas.Bowl):
    def _build(self, size=(10, 10), aesthetic='default', name='bowl'):
        super()._build(size=size, aesthetic=aesthetic, name=name)

    def regenerate(self, random_state):
        super().regenerate(random_state)

class Task_Vnl_Bowl(bowl_tasks.Escape):
    def __init__(self,
               walker,
               arena,
               walker_spawn_position):
        
        # we don't really need the rest of the reward setup in dm_control, just how the walker is attached to the arena
        spawn_site =  arena.mjcf_model.worldbody.add('site', size=[1e-6]*3, pos = walker_spawn_position)
        self._arena = arena
        self._walker = walker
        self._walker.create_root_joints(self._arena.attach(self._walker, attach_site=spawn_site)) # customize starting environment
        self._walker_spawn_position = walker_spawn_position    

In [4]:
arena = Gap_Vnl(platform_length=distributions.Uniform(.3, 2.5),
      gap_length=distributions.Uniform(.3, .5), # can't be too big
      corridor_width=10, # walker width follows corridor width
      corridor_length=100,
      aesthetic='outdoor_natural',
      visible_side_planes=False)
arena.regenerate(random_state=None)

arena_bowl = Bowl_Vnl(size=(10, 10),
                      aesthetic='default',
                      name='bowl')

vnl_ant = ant.Ant(observable_options={'egocentric_camera': dict(enabled=True)})
#rodent = Rodent_Vnl(observable_options={'egocentric_camera': dict(enabled=True)})
rodent = rodent.Rat(observable_options={'egocentric_camera': dict(enabled=True)})

task = Task_Vnl_Gap(
    walker=rodent,
    arena=arena,
    walker_spawn_position=(3, 0, 0))

# we need to bind everything to self.arena because that is the only thing we are putting into the MjModel
random_state = np.random.RandomState(123456)
task.initialize_episode_mjcf(random_state)

physics = mjcf_dm.Physics.from_mjcf_model(task.root_entity.mjcf_model) #.root_entity only returns the arena model, no reward, nothing

ValueError: Compile error raised by Mujoco; run again with --pymjcf_debug for additional debug information.
Error: plane only allowed in static bodies: geom 'walker/floor' (id = 61)
Object name = walker/floor, id = 61, line = 231
<geom name="walker/floor" class="walker/collision_floor" type="plane" size="10 10 0.025000000000000001" material="walker/grid" pos="0 0 -0.0050000000000000001"/>

In [None]:
class Walker(PipelineEnv):
  '''
  This is greatly coustomizable of what reward you want to give: reward engineering
  '''
  def __init__(
      self,
      forward_reward_weight=5.0,
      ctrl_cost_weight=0.1,
      healthy_reward=0.5,
      terminate_when_unhealthy=False, # should be false in rendering
      healthy_z_range=(0.0, 1.0), # healthy reward takes care of not falling, this is the contact_termination in dm_control
      train_reward=5.0,
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,):
    '''
    Defining initilization of the agent
    '''

    #mj_model = mujoco.MjModel.from_xml_path(_XML_PATH)
    mj_model = physics.model.ptr
    # this is directly a mj_model already of type mujoco_py.MjModel (This is already a MJModel, same as previously in brax)
    # the original xml load is directly creaing an new MjModel instance, which carries the configuration of everything, including mjtCone
    # but this pass in one doesn't, it uses the default mjCONE_PYRAMIDAL, but MjModel now uses the eliptic model, so reset is needed

    # solver is an optimization system
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON #.mjSOL_CG
    mj_model.opt.cone = mujoco.mjtCone.mjCONE_PYRAMIDAL # Read documentation

    #Iterations for solver
    mj_model.opt.iterations = 2
    mj_model.opt.ls_iterations = 4

    sys = mjcf_brax.load_model(mj_model)

    # Defult framne to be 5, but can self define in kwargs
    physics_steps_per_control_step = 3
    
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    # Parents inheritence from MjxEnv class
    #super().__init__(model=mj_model, **kwargs)
    super().__init__(sys, **kwargs)

    # Global vraiable for later calling them
    self._model = mj_model
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._train_reward = train_reward
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (exclude_current_positions_from_observation)

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""

    #Creating randome keys
    #rng = random number generator key for starting random initiation
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale

    #Vectors of generalized joint position in the configuration space
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )

    #Vectors of generalized joint velocities in the configuration space
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    #Reset everything
    obs = self._get_obs(data, jp.zeros(self.sys.nu)) #['proprioceptive']
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'train_reward': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics) # State is a big wrapper that contains all information about the environment

  def step(self, state: State, action: jp.ndarray) -> State: # push towards another State
    """Runs one timestep of the environment's dynamics."""
    #Previous Pipeline
    data0 = state.pipeline_state

    #Current pipeline state, step 1
    #Looking at the documenttaion of pipeline_step, "->" means return a modified State
    data = self.pipeline_step(data0, action)

    #Running forward (Velocity) tracking base on center of mass movement
    com_before = data0.subtree_com[3]
    com_after = data.subtree_com[3]

    #print(data.data)
    
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    train_reward = self._train_reward * self.dt # as more training, more rewards

    #Height being healthy
    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)

    #Termination condition
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    #Control force cost
    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    #Feedback from env
    obs = self._get_obs(data, action)
    reward = forward_reward + train_reward + healthy_reward - ctrl_cost

    #print(obs)

    #Termination State
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        train_reward=train_reward,
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )
    return state.replace(pipeline_state=data, obs=obs, reward=reward, done=done)

  def _get_obs(self, data: mjx.Data, action: jp.ndarray) -> jp.ndarray:
    """environment feedback of observing walker's proprioreceptive and vision data"""

    # Vision Data Mujoco Version
    # passed in data is a pipeline_state.data object, pipeline_state is the sate
    renderer = mujoco.Renderer(model = self._model)

    # this here is the correct format, need qpos in calling
    #d = mjx.get_data(self._model, data)
    d = mujoco.MjData(self._model)

    mujoco.mj_forward(self._model, d)
    renderer.update_scene(d, camera=3) # can call via name too!
    image = renderer.render()
    image_jax = jax.numpy.array(image)
    print(f'image out of mujoco is {image_jax.shape}')
    # cam = mujoco.MjvCamera()

    # fake_image = jax.numpy.array(np.random.rand(64, 64, 3))
    # image_jax = fake_image.flatten() # fit into jp array

    o_height, o_width, _ = 240,320,3
    c_x,  c_y = o_width//2, o_height//2
    cropped_jax_image = image_jax[c_y-32:c_y+32, c_x-32:c_x+32, :]
    print(f'image cropped {cropped_jax_image.shape}')

    image_jax = cropped_jax_image.flatten()
    image_jax_noise = image_jax * 1e-12 # noise added
    print(f'image cropped flatened {image_jax_noise.shape}')

    # Proprioreceptive Data
    position = data.qpos
    velocity = data.qvel
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    proprioception = jp.concatenate([position, velocity])

    return jp.concatenate([proprioception, image_jax_noise])

In [None]:
envs.register_environment('walker', Walker)
env = envs.get_environment(env_name='walker')

In [None]:
# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# initialize the state
state = jit_reset(jax.random.PRNGKey(0))

#Creating an container for rollout states
rollout = [state.pipeline_state]

# grab a trajectory
for i in (range(500)):
    ctrl = jax.numpy.array(np.random.uniform(-1,1, env.sys.nu))
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

In [None]:
media.show_video(env.render(rollout,camera=2), fps=1.0 / env.dt)

## Actual Post Trained Model

In [None]:
#@title Load Model and Define Inference Function
model_path = './good_models/brax_vision_0.4'
params = model.load_params(model_path)

In [None]:
# This mimics the steps in the train.py file
import vnl_brax.networks_vision as brax_networks
#from brax.training.agents.ppo import networks as brax_networks

def make_inference_fn(observation_size: int, 
                      action_size: int,
                      normalize_observations: bool = False, 
                      network_factory_kwargs: Optional[Dict[str, Any]] = None):
    
    normalize = lambda x, y:x
    inference_ppo_network = brax_networks.make_ppo_networks(observation_size,
                                                  action_size,
                                                  preprocess_observations_fn=normalize,
                                                  **(network_factory_kwargs or {}))
    
    return brax_networks.make_inference_fn(inference_ppo_network)

make_policy = make_inference_fn(observation_size=env.observation_size, action_size=env.action_size)

# This should be correct

Remanber all jax's reward is not calculated from the obs space but rather from the data, which is collecting all environment information
- [jax.lax.scan](https://github.com/google/brax/blob/b68d9387f8c0b05271e0a5fd4cff8f851a256995/brax/envs/base.py#L121) in the `pipeline_step` function is where the error comes from
- Previously have not dealing with the data side directly, no `step` function is called directly
- This is not a problem with the image, the same error occurs even when just proprioreceptive data is feed into it!

Ask Charles

In [None]:
# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(make_policy(params, deterministic=False))

# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]

# grab a trajectory
n_steps = 3000
render_every = 1

# might becasue brax does not clip the action to the xml limit in the model
for i in (range(n_steps)):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    #ctrl = jax.numpy.array(np.random.uniform(-1,1, env.sys.nu))
    #action = jp.zeros((env.action_size,))
    # print(ctrl) ctrl is correct
    # print(state.pipeline_state)

    state = jit_step(state, ctrl) # this line causing error
    rollout.append(state.pipeline_state)

    if state.done:
        break

In [None]:
media.show_video(env.render(rollout[::render_every], camera=2), fps=1.0 / env.dt / render_every)

In [None]:
media.show_video(env.render(rollout[::render_every], camera=3), fps=1.0 / env.dt / render_every)