# `Brax` pipeline adapted from `dm_control`
<img src="https://github.com/google/brax/raw/main/docs/img/a1.gif" width="200" height="200"> 
<img src="https://github.com/google/brax/raw/main/docs/img/humanoid_v2.gif" width="200" height="200"> 
<img src="https://github.com/google/brax/raw/main/docs/img/ant_v2.gif" width="200" height="200"> 
<img src="https://github.com/google/brax/raw/main/docs/img/ur5e.gif" width="200" height="200"> 

`dm_control` provides a very high-level abstraction of MuJoCo; both are developed by Google DeepMind. **It takes a lot of time to parse through this repository, understand it, and see how the pieces fit together, but doing so is necessary because it is a fundamental skill in research and in developing new things:**
1. It is the first step towards understanding what you need to do to create something new.
2. It is also a good general skill to have when learning a powerful, highly abstracted new tool: parsing through a repository's many classes, seeing how the class hierarchy is connected, finding the `right level of abstraction`, and tracing inheritance relationships between classes all the way back to the beginning.

# Linux GPU Rendering
##### These are the imports, installations, and backend switches needed for `Linux GPU Rendering`
- OSMesa works and all dependencies are installed, but MuJoCo rendering does not support OSMesa, only the other backends, which do have installation issues
- specifically the `mujoco.viewer` package
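
The backend has to be chosen before `mujoco` or `dm_control` is imported, so a minimal sketch of the switch could look like the following. The helper name `select_gl_backend` is hypothetical; `MUJOCO_GL` and `PYOPENGL_PLATFORM` are the environment variables already used in the setup cell of this notebook, and the accepted values are assumed to be the three backends MuJoCo's Python bindings document.

```python
import os

# Backends that MuJoCo's Python bindings read from MUJOCO_GL (assumed list).
_VALID_BACKENDS = ("glfw", "egl", "osmesa")


def select_gl_backend(name: str) -> str:
    """Set the rendering backend. Call this before importing mujoco."""
    if name not in _VALID_BACKENDS:
        raise ValueError(f"unknown backend {name!r}, expected one of {_VALID_BACKENDS}")
    os.environ["MUJOCO_GL"] = name
    # Mirror the headless choices for PyOpenGL, as the setup cell does for osmesa.
    if name in ("egl", "osmesa"):
        os.environ["PYOPENGL_PLATFORM"] = name
    return name


select_gl_backend("osmesa")
```

Keeping the switch in one function makes it harder to set the two variables to different values by accident.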

In [1]:
# #@title Colab Only
# !pip install brax
# !pip install wandb
# !pip install mujoco
# !pip install dm_control
# !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
# !pip install -q mediapy
# !pip install tqdm

In [2]:
# import distutils.util
# import os
# import subprocess

# if subprocess.run('nvidia-smi').returncode:
#   raise RuntimeError(
#       'Cannot communicate with GPU. '
#       'Make sure you are using a GPU Colab runtime. '
#       'Go to the Runtime menu and select Choose runtime type.')

# # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# # This is usually installed as part of an Nvidia driver package, but the Colab
# # kernel doesn't install its driver via APT, and as a result the ICD is missing.
# # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
# NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
# if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
#   with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
#     f.write("""{
#     "file_format_version" : "1.0.0",
#     "ICD" : {
#         "library_path" : "libEGL_nvidia.so.0"
#     }
# }
# """)

# # Configure the MuJoCo rendering backend (OSMesa here; 'egl' is the GPU backend)
# print('Setting environment variable to use GPU rendering:')

# !PYOPENGL_PLATFORM=osmesa
# %env MUJOCO_GL=osmesa
# os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
# !pip install --upgrade PyOpenGL PyOpenGL_accelerate

# Environment Building
References for dm_control's implementation of customized environments. These are the highest-level libraries: the locomotion lib should have everything needed to get an RL agent running in dm_control, and the mjcf lib should have everything needed to build a connection class.
- [dm_control locomotion lib](https://github.com/google-deepmind/dm_control/tree/main/dm_control/locomotion)
- [dm_control mjcf lib](https://github.com/google-deepmind/dm_control/tree/main/dm_control/mjcf)
- [dm_control lib](https://github.com/google-deepmind/dm_control/tree/main)
- [mujoco](https://mujoco.readthedocs.io/en/3.1.1/overview.html)
- [brax](https://github.com/google/brax)

# Organization
- `locomotion class`
  - `arena subclass`
    - i.e. corridors -> gap corridors
  - `walkers subclass`
    - assets (xml)
    - i.e. rodent -> rat
  - `tasks subclass`
    - i.e. jump gaps
- `mjcf class`
  - class we want to build using idea from corridors
- `composer class`
  - arena.py/arena.xml provides foundation for `locomotion class`

  This took quite a long time to figure out:
  1. We can use the `tasks subclass` by converting it to an mjcf model, then to an MjModel
  2. Inheritance can be used to build our own functions needed in `task`, `arena`, and `walker` (work smart)

## Collective Efforts
This is a lower level of abstraction compared to the previous one, but it is still very high level: it calls functions and environment setup that dm_control has already implemented.

The skill now is to parse through these convoluted layers to see the `real idea` behind the implementation and the `pipeline` of how each piece is called. That way you understand things down to the implementation level and can then decide which `level of abstraction` to stay on to achieve your purpose, changing as little as possible and reusing functions that others have already written, so that
>### You can spend the time building on top of it and doing more useful things to create new innovations instead of doing redundant work. Later, your work may become a function, an abstraction level, that other people call to build more work in the field. This is a collaborative project that takes a collective effort towards a common goal. This is the power of Python libraries: not reinventing the wheel but using the wheels to build the car, not recreating data structures for storing data but using the data to do more.

Therefore, staying on the right level of abstraction and using already-implemented functions is very important (just as we used Brax's implementation of PPO, we can use dm_control's implementation of customized environments).

- Brax mjx -> Mujoco env -> modify using dm_control implemented packages -> dm_control mjcf
 - On the mjx level? On the dm_control implemented env sample level? On the mjcf level? On the xml level?

 ***Note that these things are already at the task level, which is already an execution of (model + environment). The correct level of `abstraction` should be just the background environment level, not the task level; this is the same with Brax.***

In [3]:
from datetime import datetime
import functools
from IPython.display import HTML
import PIL.Image
import yaml
from typing import List, Dict, Text, Callable, NamedTuple, Optional, Union, Any, Sequence, Tuple
from matplotlib import pyplot as plt
import mediapy as media
import wandb
import os

import numpy as np

from etils import epath
from flax import struct
from ml_collections import config_dict

import mujoco
from mujoco import mjx

from dm_control import mjcf as mjcf_dm
from dm_control import composer
from dm_control.locomotion.examples import basic_rodent_2020
from dm_control.composer.variation import distributions
from dm_control.locomotion.arenas import corridors as corr_arenas
from dm_control.locomotion.tasks import corridors as corr_tasks
from dm_control.locomotion.walkers import rodent, ant
from dm_control import viewer
from dm_control import mujoco as mujoco_dm
from dm_control.composer.variation import rotations

import jax
from jax import numpy as jp
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, MjxEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, model
from brax.io import mjcf as mjcf_brax

# !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
# np.set_printoptions(precision=3, suppress=True, linewidth=100)



## `dm_control` - Task Level Running
This code is implemented directly in the MuJoCo framework:

There are existing examples in the framework that are built from more basic components of the dm_control lib. These examples are already very high-level abstractions.

In [4]:
_CONTROL_TIMESTEP = .02
_PHYSICS_TIMESTEP = 0.001
random_state=None

walker = rodent.Rat(observable_options={'egocentric_camera': dict(enabled=True)})

  # Build a corridor-shaped arena with gaps, where the sizes of the gaps and
  # platforms are uniformly randomized.
arena = corr_arenas.GapsCorridor(
      platform_length=distributions.Uniform(.4, .8),
      gap_length=distributions.Uniform(.05, .2),
      corridor_width=2,
      corridor_length=40,
      aesthetic='outdoor_natural')

  # Build a task that rewards the agent for running down the corridor at a
  # specific velocity.
task = corr_tasks.RunThroughCorridor(
      walker=walker,
      arena=arena,
      walker_spawn_position=(5, 0, 0),
      walker_spawn_rotation=0,
      target_velocity=1.0,
      contact_termination=False,
      terminate_at_height=-0.3,
      physics_timestep=_PHYSICS_TIMESTEP,
      control_timestep=_CONTROL_TIMESTEP)

env_composed = composer.Environment(time_limit=30,
                              task=task,
                              random_state=random_state,
                              strip_singleton_obs_buffer_dim=True)

In [5]:
#viewer.launch(environment_loader=env_composed)

***
# `dm_control` Arena Backbone Level Engineering
***

These are just the backbone skeleton in the [Arena_folder](https://github.com/google-deepmind/dm_control/tree/main/dm_control/locomotion/arenas)
- Gap Corridor -> Empty Corridor -> Corridor -> composer.arena -> xml
- Rodent -> xml
## First Idea:
Make an intermediate class that connects to the corridor-defined environment (the corridor environment is in the arena folder, not the task folder)
* [Tasks_folder](https://github.com/google-deepmind/dm_control/tree/main/dm_control/locomotion/tasks) is already directly running the algorithm
* [Arena_folder](https://github.com/google-deepmind/dm_control/tree/main/dm_control/locomotion/arenas) provides the actual skeleton; this is the correct level of `abstraction` that we are looking for
* > ##### Tasks = Walker + Arena

> ##### problem: we cannot directly use the arena or the walker implemented by dm_control; they are too specifically designed for dm_control

## Second Idea:
1. Directly build an arena using intuitions from the Corridor class
2. Link the rodent or humanoid xml to that environment
3. Because it is created by the mjcf class in dm_control, we can directly export it as a fixed file and it can be used in Brax
    - essentially we skip some of the really convoluted implementation of dm_control from `tasks -> arena + walker -> composer class -> xml`
    - we are directly building it from scratch using `mjcf.RootElement()`
4. The mjcf model here directly uses MuJoCo, so there is no need to initiate another MuJoCo instance in the Brax training loop; we directly pass a model into Brax

> ##### problem: this is far too inefficient, and there should be a much smarter way to work with it

## Third Idea:
1. It seems that the Corridor class can be directly converted into an mjcf model by `.to_mjcf_model()`
2. Also self-implemented a Corridor class that is based directly on `mjcf.RootElement()` and returns the model using `.out()` functions
3. Try to inherit directly from the Corridor class and assign some more things

> ##### problem: so far the inherited-class idea is implemented quite well, but the binding is not assigned

## Fourth Idea:
1. Directly wrap everything once `binding` is done with the agent and the environment at the `tasks` level
2. The task level can also be exported as an mjcf_model file and stored the same way in an MjModel via `ptr`, same as the Brax humanoid that was previously implemented in the training loop.
3. Created inheritance from the task module and then directly bound it to an existing walker construct file (fix ours later)

> ##### problem: this `task wrapper bound model` has a friction setup that Mujoco does not support

***
# Direct Inheritance Binding Class for Brax Engineering
***

## Primary Goal:
1. Create direct inheritance from dm_control (`walker, arena, tasks -> wrapper to physics`)
2. Build a pipeline for using dm_control's customization functions for Brax training by implementing particular classes that we can change to fit our needs for the `arena`, `walker`, and `tasks binding` (binds walker and arena).

## Debug Log for Binding
### We need to try to unbind the ant model so it can actually move when trained
- attach function documentation [here](https://github.com/google-deepmind/dm_control/blob/9e360bf8a069868107e52960a97dcafe2292113f/dm_control/composer/entity.py#L299)
- maybe we can use this function [here](https://github.com/google-deepmind/dm_control/blob/9e360bf8a069868107e52960a97dcafe2292113f/dm_control/composer/arena.py#L62), but this function also uses the `attach` function.
- cannot initialize the model and then put it back, because the output is only the arena file
- switch `._mjcf_root` to `.mjcf_model`; the model needs to be attached to a valid "child" but not fully connected to the ground. A `site` should work: it adds a new object of a certain size into the arena. From the rendered videos, we can also see that it works.
- the problem might then be with learning to walk in Brax.

In [6]:
class Gap_Vnl(corr_arenas.GapsCorridor):
    def _build(self, corridor_width, corridor_length, visible_side_planes, aesthetic, platform_length, gap_length):
        super()._build(corridor_width=corridor_width,
                       corridor_length=corridor_length,
                       visible_side_planes=visible_side_planes,
                       aesthetic = aesthetic,
                       platform_length = platform_length,
                       gap_length = gap_length)

        # self.walker = mjcf.from_path("./models/rodent_dm_control.xml")
        # # Initiate walker
        # self.spawn_pos = (0, 0, 0)
        # self.spawn_site =  self._mjcf_root.worldbody.add('site', pos=self.spawn_pos)
        # self.spawn_site.attach(self.walker).add('freejoint')
    
    def regenerate(self, random_state):
        super().regenerate(random_state)

# The task now just serves as a wrapper
class Task_Vnl(corr_tasks.RunThroughCorridor):
    def __init__(self,
               walker,
               arena,
               walker_spawn_position):
        
        # we don't really need the rest of the reward setup in dm_control, just how the walker is attached to the arena
        spawn_site =  arena.mjcf_model.worldbody.add('site', size=[1e-6]*3, pos = walker_spawn_position)
        self._arena = arena
        self._walker = walker
        self._walker.create_root_joints(self._arena.attach(self._walker, attach_site=spawn_site)) # customize starting environment
        self._walker_spawn_position = walker_spawn_position
        #self._arena.add_free_entity(self._walker)

        # Instead of relying on env to instantiate task and bind artificial agent to arena, the task class itself now can bind artificial agent to arena
        
# class Walker_Vnl(ant.Ant):
#     def _build(self, name='walker', marker_rgba=None, initializer=None):
#         super()._build(initializer=initializer)
#         self._appendages_sensors = []
#         self._bodies_pos_sensors = []
#         self._bodies_quats_sensors = []

#         _XML_DIRNAME = os.path.join(os.path.dirname(__file__), '../../third_party/ant')
#         _XML_FILENAME = 'ant.xml'

#         self._mjcf_root = mjcf_dm.from_path(os.path.join(_XML_DIRNAME, _XML_FILENAME))
        
#         if name:
#             self._mjcf_root.model = name

#         # Set corresponding marker color if specified.
#         if marker_rgba is not None:
#             for geom in self.marker_geoms:
#                 geom.set_attributes(rgba=marker_rgba)

#         # Initialize previous action.
#         self._prev_action = np.zeros(shape=self.action_spec.shape,
#                                     dtype=self.action_spec.dtype)


In [7]:
arena = Gap_Vnl(platform_length=distributions.Uniform(1.5, 2.0),
      gap_length=distributions.Uniform(.05, .2),
      corridor_width=5, # walker width follows corridor width
      corridor_length=40,
      aesthetic='outdoor_natural',
      visible_side_planes=False)

# arena.regenerate(random_state=None)
# physics = mjcf_dm.Physics.from_mjcf_model(arena.mjcf_model)
# PIL.Image.fromarray(physics.render())

walker = ant.Ant(observable_options={'egocentric_camera': dict(enabled=True)})

task = Task_Vnl(
    walker=walker,
    arena=arena,
    walker_spawn_position=(5, 0, 0))

# we need to bind everything to self.arena because that is the only thing we are putting into the MjModel
random_state = np.random.RandomState(123456)
task.initialize_episode_mjcf(random_state)

physics = mjcf_dm.Physics.from_mjcf_model(task.root_entity.mjcf_model) #.root_entity only returns the arena model, no reward, nothing
#physics = mjcf_dm.Physics.from_mjcf_model(walker.mjcf_model)

There is no need to step the environment in dm_control and then take the model again like this:
- walker_joints = walker.mjcf_model.find_all('joint')
- physics.bind(walker_joints).qpos = random_state.uniform(size=len(walker_joints))
- task.initialize_episode(physics, random_state) # self.walker position changes
- physics2 = mjcf_dm.Physics.from_mjcf_model(task.root_entity.mjcf_model)

[This](https://github.com/google-deepmind/dm_control/blob/9e360bf8a069868107e52960a97dcafe2292113f/dm_control/composer/entity.py#L299) is how we fixed this issue: we directly change the attach call by passing in a predefined, existing location in the mjcf file.

# Checking rendered states in `dm_control.viewer`

In [8]:
# env_composed = composer.Environment(time_limit=50,
#                               task=task,
#                               random_state=random_state,
#                               strip_singleton_obs_buffer_dim=True)
# viewer.launch(environment_loader=env_composed)

# Rendering with Mujoco
We want to check whether the rendering problem after entering the Brax env is caused by the MjModel itself, so we render it here directly first to see if it actually works!

- From playing around with the rendering in MuJoCo, we figured out a key idea: ***camera options in the dm_control MjModel are indicated by an integer id***
  - `-1` is the far side view
  - `0` is the closer side view
  - `1` is the back view
  - `2` is the closer back view
  - `3` is the egocentric view
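
To avoid scattering magic numbers through the rendering code, the ids above can be kept in a small lookup. This is only a sketch: the names `CAMERA_VIEWS` and `camera_id` are hypothetical, and the mapping is the one observed for this particular model, so it may differ for another arena or walker.

```python
# Camera ids observed for the dm_control-built model in this notebook.
# -1 selects MuJoCo's free camera; the other ids are cameras defined in the model.
CAMERA_VIEWS = {
    -1: "far side view",
    0: "closer side view",
    1: "back view",
    2: "closer back view",
    3: "egocentric view",
}


def camera_id(view: str) -> int:
    """Look up the integer camera id for a human-readable view name."""
    for cam_id, description in CAMERA_VIEWS.items():
        if description == view:
            return cam_id
    raise KeyError(view)


egocentric = camera_id("egocentric view")
```

The returned id is what the rendering cell passes as `camera=` to `renderer.update_scene`.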

In [9]:
mj_model = physics.model.ptr
mj_model.opt.cone = mujoco.mjtCone.mjCONE_PYRAMIDAL

mjx_model = mjx.put_model(mj_model)
mj_data = mujoco.MjData(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model)

# enable joint visualization option:
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 5  # (seconds)
framerate = 60  # (Hz)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  if len(frames) < mj_data.time * framerate:
    renderer.update_scene(mj_data, scene_option=scene_option,camera=0)
    pixels = renderer.render()
    frames.append(pixels)

# Simulate and display video.
media.show_video(frames, fps=framerate)



# Jax stepping version
Checking with JAX stepping

In [10]:
# jit_step = jax.jit(mjx.step)

# frames = []
# mujoco.mj_resetData(mj_model, mj_data)
# mjx_data = mjx.put_data(mj_model, mj_data)
# while mjx_data.time < duration:
#   mjx_data = jit_step(mjx_model, mjx_data)
#   if len(frames) < mjx_data.time * framerate:
#     mj_data = mjx.get_data(mj_model, mjx_data)
#     renderer.update_scene(mj_data, scene_option=scene_option)
#     pixels = renderer.render()
#     frames.append(pixels)

# media.show_video(frames, fps=framerate)

***
# Walker Class Adapted Engineering
***

Instead of directly loading the `Mujoco xml` or `mjcf` file, we obtain the model by instantiating a customized `dm_control pymjcf` class and then extracting it from there.

## State Object in Brax
[State Object](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/envs/base.py#L35C1-L43C60) contains all information about the environment at a given state; it is in the [environment folder](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/envs/base.py). The State object inherits from the [Base class](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/base.py#L38).
- The State object does not refer to the [MjxState Object](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/mjx/base.py#L22); that is just an intermediate step for data passing
- The State object refers to the [Base.State Object](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/base.py#L397)
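
The environment state is treated as an immutable record: each `step` returns a new state through `.replace(...)` instead of mutating the old one. The sketch below imitates that pattern with the standard library only. `ToyState` is a hypothetical stand-in and not the Brax class; its fields merely mirror the names shown in the next cell.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Stand-in for an environment state; fields mirror the Brax State names."""
    pipeline_state: Any
    obs: tuple
    reward: float
    done: float
    metrics: Dict[str, float] = dataclasses.field(default_factory=dict)

    def replace(self, **changes) -> "ToyState":
        # Return a copy with the given fields swapped; the original is untouched.
        return dataclasses.replace(self, **changes)


s0 = ToyState(pipeline_state=None, obs=(0.0,), reward=0.0, done=0.0)
s1 = s0.replace(obs=(0.1,), reward=1.5)
```

Because nothing is modified in place, `s0` can still be inspected after the step, which is the property that makes this style easy to combine with `jax.jit`.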

In [11]:
# class State(base.Base):
#   """Environment state for training and inference."""
#   pipeline_state: Optional[base.State]
#   obs: jax.Array
#   reward: jax.Array
#   done: jax.Array
#   metrics: Dict[str, jax.Array] = struct.field(default_factory=dict)
#   info: Dict[str, Any] = struct.field(default_factory=dict)

## Updated Version?
- A really funny mistake that caused 2 hours of debugging: the code directly inherited from the previous class, which caused all the problems. However, it led to a better understanding of the pipeline and all the complex relationships in Brax. **If there are official changes in Brax, they will notify the user**
- All your code should work even after package updates; the main parts should be fine, and with small updates it should work again
    - Make sure that the environment you are working in is really the one you think it is
    - Keep all the environment settings so that, if necessary, you can reset it on the site

## Subtree Object in Mjx
Let's look at this line:
- `com_before = data0.data.subtree_com[1]`

### data0
- `data0` in the implementation is a [`pipeline_state`](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/envs/base.py#L38) object -> which is a [`mjx.base.State`](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/mjx/base.py#L22) object (transition class) -> its `data` field is in `mjx.Data` format


### subtree_com[1]
- **Center of Mass (COM)**: This is the point in an object or system of objects where the entire mass can be considered to be concentrated. For physical simulations, especially those involving dynamics and kinematics, calculating the COM is essential for accurately determining how forces and movements will affect the object.

- **Subtree**: A subtree refers to a part of the overall tree structure. For a physical body, this might mean a limb or a combination of a limb and its subsequent parts. For example, an entire arm, from shoulder to fingertips, can be considered a subtree of the body.
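
To make the two definitions concrete, here is a small pure-Python sketch of how a per-body subtree COM can be computed as a mass-weighted average over a body and all of its descendants. The function and the three-body example are illustrative assumptions, not MuJoCo's implementation; the only property borrowed from MuJoCo is that body 0 is the world and every parent index is smaller than its child's index.

```python
def subtree_com(parent, mass, pos):
    """Mass-weighted centre of the subtree rooted at each body.

    parent[i] is the index of body i's parent; body 0 is the world.
    Bodies must be ordered so that parent[i] < i for every i > 0.
    """
    sub_mass = list(mass)
    # First moment (mass * position) of each body on its own.
    moment = [[m * c for c in p] for m, p in zip(mass, pos)]
    # Walk from the leaves towards the root so children fold into parents.
    for i in range(len(parent) - 1, 0, -1):
        p = parent[i]
        sub_mass[p] += sub_mass[i]
        moment[p] = [a + b for a, b in zip(moment[p], moment[i])]
    return [
        [c / m if m > 0 else 0.0 for c in mo]
        for mo, m in zip(moment, sub_mass)
    ]


# Hypothetical model: world -> torso (2 kg) -> leg (1 kg).
parent = [0, 0, 1]
mass = [0.0, 2.0, 1.0]
pos = [(0.0, 0.0, 0.0), (0.0, 0.0, 1.0), (3.0, 0.0, 1.0)]
com = subtree_com(parent, mass, pos)
```

Here `com[1]` is the COM of the torso together with its leg, while `com[2]` is the leg alone, which is why the index used in the reward matters.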

## Observe Data Feedback
This is all the data fed back from the environment for determining any rewards:
1. data.qpos - **Position**
2. data.qvel - **Velocity**
3. data.cinert[1:].ravel() - **Inertia Matrix**
4. data.cvel[1:].ravel() - **Velocity of Inertia**
5. data.qfrc_actuator - **Actuator Forces**



### Detailed Observation Components:
1. **Position**: Extracts the positions (qpos) of the humanoid body. (If self._exclude_current_positions_from_observation is True, it excludes the first two elements of the position vector. This could be useful if you want to exclude certain position information from the observation.)

2. **Velocities**: Appends the velocities (qvel) of the humanoid body.

3. **Inertia Matrix**: Appends the inertia matrix (cinert) excluding the first row, which belongs to the static world body. (This matrix represents the inertia of the body segments.) Inertia describes how mass is distributed in the humanoid, which determines how it responds to the forces generated to accelerate or decelerate it.

4. **Velocity of Inertia**: Appends the center-of-mass-based body velocities (cvel) excluding the first row.

5. **Actuator Forces**: Appends the actuator forces (qfrc_actuator). Actuators are typically modeled as components that generate forces or torques to drive the movement of joints in a simulated robotic system. These forces or torques are applied to the joints of the simulated body, affecting its motion.
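
As a rough illustration of how these components end up in one flat vector, the sketch below concatenates only position and velocity, which is the reduced observation the `Walker` class later in this notebook uses. The function name `build_observation` and the array sizes are hypothetical.

```python
def build_observation(qpos, qvel, exclude_current_positions=True):
    """Concatenate position and velocity into one flat observation list."""
    position = list(qpos)
    if exclude_current_positions:
        # Drop the first two entries (global x and y) so the policy
        # cannot key on absolute location.
        position = position[2:]
    return position + list(qvel)


# Hypothetical sizes: 7 generalized positions, 6 generalized velocities.
qpos = [0.0, 0.0, 1.4, 1.0, 0.0, 0.0, 0.0]
qvel = [0.0] * 6
obs = build_observation(qpos, qvel)
```

The inertia, body velocity and actuator force terms would be appended in the same way if they were enabled.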

## Rewards

### 3 Main Rewards
The reward consists of three parts:
- **reward_alive**: Every timestep that the humanoid is alive, it gets a reward of 5.

- **forward_reward**: A reward for walking forward, measured as *1.25 * (average center of mass after action - average center of mass before action) / dt*. *dt* is the time between actions; the default is *dt = 0.015*. This reward is positive if the humanoid walks forward (to the right), as desired. The calculation of the center of mass is defined in the `.py` file for the Humanoid.

- **reward_quadctrl**: A negative reward that penalises the humanoid if its control force is too large. If there are *nu* actuators/controls, then the control has shape  `nu x 1`. It is measured as *0.1 **x** sum(control<sup>2</sup>)*. ***Essentially, the control force penalty makes the movement much more realistic***
    - Control forces are often the output of a control system or a learned policy (in the case of reinforcement learning), dictating how the entity should move to achieve certain objectives, like walking, jumping, or picking up objects.
    - High control forces can lead to aggressive, unstable, or unsafe movements, which might damage the simulated entity or the environment.
    - High control forces often result in abrupt, jerky movements that are not only inefficient but can also be unrealistic, especially in simulations aiming to mimic biological movements.
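
The three terms can be combined in a few lines. The following is a simplified sketch using the default weights quoted above (1.25, 0.1, 5) and a scalar x position for the centre of mass; the function name and the example numbers are assumptions made for illustration only.

```python
def step_reward(x_before, x_after, action, dt=0.015,
                forward_weight=1.25, ctrl_weight=0.1, alive_bonus=5.0):
    """Combine the alive, forward and control terms for one step."""
    # After minus before, so moving in +x gives a positive velocity.
    velocity = (x_after - x_before) / dt
    forward_reward = forward_weight * velocity
    # Quadratic penalty on the control signal.
    ctrl_cost = ctrl_weight * sum(a * a for a in action)
    return {
        "forward_reward": forward_reward,
        "reward_alive": alive_bonus,
        "reward_quadctrl": -ctrl_cost,
        "total": forward_reward + alive_bonus - ctrl_cost,
    }


example = step_reward(x_before=0.0, x_after=0.03, action=[1.0, -1.0])
```

In this example the body moves 0.03 in one control step, so the velocity is 2 and the forward term is 2.5, while the two unit-magnitude controls cost 0.2.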

### Starting State
All observations start in state (0.0, 0.0,  1.4, 1.0, 0.0  ... 0.0) with a uniform noise in the range of [-0.01, 0.01] added to the positional and velocity values (values in the table) for stochasticity. Note that the initial z coordinate is intentionally selected to be high, thereby indicating a standing up humanoid. The initial orientation is designed to make it face forward as well.
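
A minimal sketch of this reset, using the standard library's random generator instead of `jax.random` so that it runs anywhere. The name `noisy_reset`, the seed, and the example `qpos0` are hypothetical.

```python
import random


def noisy_reset(qpos0, nv, noise_scale=0.01, seed=0):
    """Perturb the nominal pose and draw small random velocities."""
    rng = random.Random(seed)
    low, hi = -noise_scale, noise_scale
    qpos = [q + rng.uniform(low, hi) for q in qpos0]
    qvel = [rng.uniform(low, hi) for _ in range(nv)]
    return qpos, qvel


# Nominal standing pose: x, y, z followed by a unit quaternion.
qpos0 = [0.0, 0.0, 1.4, 1.0, 0.0, 0.0, 0.0]
qpos, qvel = noisy_reset(qpos0, nv=6)
```

The noise is small enough that the perturbed pose stays close to standing, but it prevents the policy from memorising a single start state.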

### Episode Termination
The episode terminates when any of the following happens:

1. The episode duration reaches 1000 timesteps
2. The z-coordinate of the torso (index 0 in state space OR index 2 in the table) is **not** in the range `[0.8, 2.1]` (the humanoid has fallen or is about to fall beyond recovery).
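
The two termination conditions can be written as a pair of small predicates. This is a sketch with hypothetical function names, using the range and step limit quoted above as defaults.

```python
def is_healthy(z, healthy_z_range=(0.8, 2.1)):
    """True while the torso height stays inside the healthy range."""
    min_z, max_z = healthy_z_range
    return min_z <= z <= max_z


def episode_done(z, step_count, max_steps=1000, healthy_z_range=(0.8, 2.1)):
    """True when the step limit is reached or the torso leaves the range."""
    return step_count >= max_steps or not is_healthy(z, healthy_z_range)
```

The `Walker` class further down uses a different range, `(0.0, 1.0)`, and only terminates on it when `terminate_when_unhealthy` is set.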

In [12]:
# MjxEnv is directly an API to the Mujoco mjx
class Walker(MjxEnv):
  '''
  The reward is highly customizable here: this is where reward engineering happens
  '''
  def __init__(
      self,
      forward_reward_weight=5.0,
      ctrl_cost_weight=0.1,
      healthy_reward=0.5,
      terminate_when_unhealthy=False,
      healthy_z_range=(0.0, 1.0), # healthy reward takes care of not falling, this is the contact_termination in dm_control
      distance_reward=5.0,
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,):
    '''
    Defines the initialization of the agent
    '''

    mj_model = physics.model.ptr
    # this is already an MjModel instance (mujoco.MjModel), the same type previously used in brax
    # loading the original xml creates a new MjModel instance that carries the full configuration, including mjtCone
    # the model passed in here does not: the expected default is mjCONE_PYRAMIDAL, but this MjModel uses the elliptic cone, so it is reset below

    # solver is an optimization system
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.cone = mujoco.mjtCone.mjCONE_PYRAMIDAL # Read documentation

    #Iterations for solver
    mj_model.opt.iterations = 1
    mj_model.opt.ls_iterations = 4

    # Default number of frames is 5, but it can be overridden through kwargs
    physics_steps_per_control_step = 5
    
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)

    # Parents inheritence from MjxEnv class
    super().__init__(model=mj_model, **kwargs)

    # Instance attributes used later in reset/step
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._distance_rewaed = distance_reward
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (exclude_current_positions_from_observation)

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""

    #Creating random keys
    #rng = random number generator key for the random initialization
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale

    #Vectors of generalized joint position in the configuration space
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )

    #Vectors of generalized joint velocities in the configuration space
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    #Reset everything
    obs = self._get_obs(data.data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_reward': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics) # State is a big wrapper that contains all information about the environment

  def step(self, state: State, action: jp.ndarray) -> State: # push towards another State
    """Runs one timestep of the environment's dynamics."""
    #Previous Pipeline
    data0 = state.pipeline_state

    #Current pipeline state, step 1
    #Looking at the documentation of pipeline_step, "->" means it returns a modified State
    data = self.pipeline_step(data0, action)

    #Running forward (Velocity) tracking base on center of mass movement
    com_before = data0.data.subtree_com[5]
    com_after = data.data.subtree_com[5]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    #Reaching the target location distance
    #distance = state.metrics['distance_from_origin']
    #distance_reward = [self._distance_rewaed * distance if isinstance(distance, int) else 0]
    distance_reward = self._distance_rewaed * velocity[0] * self.dt

    #Height being healthy
    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)

    #Termination condition
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    #Control force cost
    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    #Feedback from env
    obs = self._get_obs(data.data, action)
    reward = forward_reward + distance_reward + healthy_reward - ctrl_cost

    #Termination State
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_reward=distance_reward,
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )
    return state.replace(pipeline_state=data, obs=obs, reward=reward, done=done)

  def _get_obs(self, data: mjx.Data, action: jp.ndarray) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos

    if self._exclude_current_positions_from_observation:
      position = position[2:]

    # external_contact_forces are excluded
    # environment observation described later
    return jp.concatenate([
        position,
        data.qvel,
        # data.cinert[1:].ravel(),
        # data.cvel[1:].ravel(),
        # data.qfrc_actuator,
    ])

In [13]:
# Register the environment class in envs under the name 'walker'
envs.register_environment('walker', Walker)

## Not Walking Debug Log
1. Not a binding issue: rendering with random control forces shows the walker moving, and dropping it from a height works, so it is not attached to the arena. **Not an environment-definition issue.**
2. Not an observation-space mismatch: the observation space gives the agent all the information about itself in the environment. If the healthy reward works, the observation space matches, because the agent has proprioceptive input. **Not a Brax env issue.**
3. Does the agent have sufficient information about its velocity and about walking forward? Brax treats the whole `Walker` class as the agent and defines the reward in that sense (`self` refers to the artificial agent). However, we passed in the arena with the walker inside it.
    - The ant walks just as well when only the ant model is passed in as when the arena-bound model is passed in. So the model passed to the Brax class may not be the issue, since randomly initialized control forces do reach the walker correctly.
    - The velocity (the key to the forward reward) is computed from the center-of-mass movement for the torso, so we need to identify which center of mass we are looking at. The center-of-mass array has shape `float32[17,3]`; we need to find **the right center of mass**.

### Findings:
1. Reward assignment needs to be much more specific: the model passed in is an arena and a walker combined, which is a more complicated situation.
    - They are not fully bound together, so we need to extract the correct subcomponents of each to assign reward accurately (a COM check, maybe?).
    - Some rewards work, and observation-space feedback is successful.
    - In Brax, the class is designed for a single artificial agent, so grabbing all the COMs and assigning reward to `self` is simple. In this modified case the class monitors the full environment, so reward assignment must target the correct subtree.
2. The Brax class is now a monitor of both the `arena` and the `walker`, so we need to make sure reward is assigned to the correct subtree.
3. This is the correct perspective: changing the COM index now gives non-zero distance and forward rewards! Maybe we can use `subtree_linvel`.
    - The subtree linear velocity represents the velocity of all the connected bodies under the specified body.
    - Actually no: this requires using the physics module directly, which we don't have.
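
To make the "right center of mass" point concrete, here is a minimal NumPy sketch of the finite-difference velocity used in `step`, applied to one row of an array shaped like `subtree_com`. The helper name `com_velocity`, the toy values, and the choice of row 1 as a static arena body are assumptions for illustration only; in the real model the row has to be looked up for the walker's root body.

```python
import numpy as np

def com_velocity(com_before, com_after, body_id, dt):
    """Finite-difference velocity of one body's subtree center of mass.

    com_before / com_after: arrays of shape (nbody, 3), like `subtree_com`.
    body_id: row to track (hypothetical here; look it up for your model).
    """
    return (com_after[body_id] - com_before[body_id]) / dt

# Toy data: 17 bodies, only the subtree rooted at body 5 moves along +x.
nbody, dt = 17, 0.01
before = np.zeros((nbody, 3))
after = before.copy()
after[5, 0] = 0.02  # body 5's subtree COM moved 2 cm forward in one step

v_arena = com_velocity(before, after, body_id=1, dt=dt)   # assumed static arena row
v_walker = com_velocity(before, after, body_id=5, dt=dt)  # walker row

print(v_arena)   # no motion, so a reward built on this row stays at zero
print(v_walker)  # forward motion along x
```

Tracking a row that belongs to the arena gives zero velocity no matter how the walker moves, which matches the zero forward reward seen before the COM index was changed.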

***
# Training Configuration
***

This is a quick overview of the most important hyperparameters that we need to tune. See the PPO training code for more details (https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py).
1. `num_env`:
    - The number of parallel environments the agent is trained on. In other words, it is how many instances of the registered environment are active and stepped at the same time. (The corresponding argument of Brax's PPO `train` function is spelled `num_envs`.)
    - More environments give more diverse experience, which tends to make the policy more robust.
    - The agent learns quicker in wall-clock time because each update draws on experience gathered from all environments.
2. `num_timesteps`:
    - Total number of interaction steps between the agent and the environment during training.
3. `eval_every`:
    - `eval_every` sets the learning interval: how often the policy parameters are updated. It plays the same role as the horizon N in the simple PyTorch PPO we implemented.
4. `episode_length`:
    - The number of timesteps that make up one episode. Here it would be 1000 steps for each of 10_000 episodes.
5. `num_evals`:
    - The total number of evaluations; it should be `num_timesteps / eval_every`.
6. `batch_size`:
    - In the replay-buffer picture, the batch size is the number of samples drawn from stored experience for one optimization batch.
    - One batch is used for a single SGD step; SGD is run several times over different subsamples, or "mini-batches".
    - Brax's PPO implementation has no replay buffer (PPO is on-policy and trains on freshly collected rollouts), but the idea is the same.
7. `num_minibatches`:
    - How many mini-batches the collected data is split into.
    - This is also how many SGD steps are run per pass over the data.
    - `num_minibatches * batch_size = all_data`
    - As the amount of collected data grows with `num_minibatches` fixed, each batch gets larger.
8. `unroll_length`:
    - The number of timesteps to unroll in each environment. Instead of collecting one timestep of the trajectory at a time, `unroll_length` collects n consecutive steps at once for computational efficiency.
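
The sketch below shows how these numbers fit together. The formulas are my reading of the linked `train.py` and should be treated as an assumption to check against the Brax version in use. The hyperparameter values are illustrative, not the settings used in this notebook.

```python
import math

# Illustrative values only (not the settings used in this notebook).
num_timesteps = 10_000_000
num_evals = 10
unroll_length = 10
num_minibatches = 32
batch_size = 1024
action_repeat = 1

# Environment steps consumed by one training step (assumed formula):
# batch_size * num_minibatches unrolls, each unroll_length steps long.
env_steps_per_training_step = (
    batch_size * unroll_length * num_minibatches * action_repeat
)

# Assumed: the first evaluation runs before any training, so the training
# budget is shared between the remaining num_evals - 1 intervals.
num_evals_after_init = max(num_evals - 1, 1)
training_steps_per_eval = math.ceil(
    num_timesteps / (num_evals_after_init * env_steps_per_training_step)
)

print("env steps per training step:", env_steps_per_training_step)
print("training steps between evals:", training_steps_per_eval)
```

Because of the ceiling, the run can consume somewhat more than `num_timesteps` environment steps in total.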

***
# Rendering a Rollout
***

You need `ffmpeg` to render a rollout. Rendering on a Mac avoids the GL error seen on the Linux server, and rendering on Mac now works just fine!
The environment itself works: it initializes and steps successfully.
- For Linux server rendering, you need to run this pipeline on a GPU machine with the necessary Colab imports and installation, and handle the GL error.

The dm_control model does implement a `camera`, but its name differs from the one in Brax's models, so we need to find the specific camera name for each model. (Cameras are usually defined in the walker's XML file; they are not something an `MjModel` normally carries by default, but one can be set.)

The default `camera` argument is actually `None` in the Brax source, on this [line](https://github.com/google/brax/blob/a89322496dcb07ac5a7e002c2e1d287c8c64b7dd/brax/envs/base.py#L205)
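
The `render` method reproduced later in this notebook resolves that default with `camera = camera or -1`, where `-1` selects MuJoCo's free camera. A small pure-Python sketch of what that expression does is below. The helper names `resolve_camera` and `resolve_camera_strict` are hypothetical and exist only for this illustration.

```python
def resolve_camera(camera):
    """Mimics `camera = camera or -1` from the render method."""
    return camera or -1

def resolve_camera_strict(camera):
    """Hypothetical alternative: only None falls back to the free camera."""
    return -1 if camera is None else camera

print(resolve_camera(None))      # -1
print(resolve_camera(1))         # 1
print(resolve_camera("side"))    # side
print(resolve_camera(0))         # -1, because 0 is falsy in Python
print(resolve_camera_strict(0))  # 0
```

Under this reading, asking for camera index `0` silently gives the free camera, which may be one reason the notebook ends up using `camera=1` and `camera=2`.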

In [14]:
env = envs.get_environment(env_name='walker')

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# initialize the state
state = jit_reset(jax.random.PRNGKey(0))

# Create a container for rollout states
rollout = [state.pipeline_state]

# grab a trajectory
for i in (range(500)):
    ctrl = jax.numpy.array(np.random.uniform(-1,1, env.sys.nu))
    # -1 * jp.ones(env.sys.nu) is just gravity
    # positive is flattening
    # negative grabs tight
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

# media.show_video(env.render(rollout,camera=None), fps=1.0 / env.dt)

In [15]:
media.show_video(env.render(rollout,camera=1), fps=1.0 / env.dt)



## Debugging Logs & Ideas:
1. Follow the error message and function descriptions all the way down to the `mujoco` render level, then back up to `brax`.
2. Ran the mujoco renderer separately: the `MjModel` renders fine.
3. The Brax renderer uses the mujoco renderer, which works, so the `MjModel` should be correct. The problem may be with the `states` in the rollout `trajectory`.
4. The problem probably occurs on this [line](https://github.com/google/brax/blob/a89322496dcb07ac5a7e002c2e1d287c8c64b7dd/brax/envs/base.py#L214), where `d` has shape 8 but the renderer only takes shape 7.
5. We may need to define our own rendering function that shrinks some of the input arrays in `d`.
6. It may be that the Brax env setup here is for the humanoid, not the ant. (The humanoid model in dm_control has an `mj_model.geom_condim` issue, since MJX only supports a `geom_condim` of 3, and the rodent has a tendon issue.)
7. Reimplemented the ant with a simpler implementation instead of the complex one; the issue still seems to occur.
8. May have to reimplement the walker class first for this to work:
    - check `geom_condim`
    - check tendons
9. The mismatch may come from a mismatched training loop.
10. That understanding was wrong: the training loop for the Brax env is correct. It only looked wrong because we loaded the state directly instead of loading its `.data`.
11. It is indeed still this [line](https://github.com/google/brax/blob/a89322496dcb07ac5a7e002c2e1d287c8c64b7dd/brax/envs/base.py#L214) where the error occurs. I reimplemented the rendering function step by step, and this is the step where `ValueError: could not broadcast input array from shape (8,) into shape (7,)` is raised.
12. Now everything actually works! Good practice in debugging down to a very fine level of detail and trying things out bit by bit.
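
For reference, the error in item 11 is NumPy's generic message for a slice assignment whose source and destination lengths differ. The sketch below reproduces it in isolation. The sizes `nv`, `nefc_alloc`, and `nefc_active` are made-up stand-ins, and `rowadr` only plays the role of a buffer like `efc_J_rowadr`; nothing here calls MuJoCo.

```python
import numpy as np

nv = 6           # hypothetical number of degrees of freedom
nefc_alloc = 7   # rows available in the destination buffer
nefc_active = 8  # rows counted as active on the source side

rowadr = np.zeros(nefc_alloc, dtype=np.int32)  # stand-in for efc_J_rowadr
try:
    # Same pattern as `result_i.efc_J_rowadr[:] = np.arange(0, nefc * m.nv, m.nv)`
    rowadr[:] = np.arange(0, nefc_active * nv, nv)
    message = None
except ValueError as err:
    message = str(err)

print(message)
```

The right-hand side always has one entry per active constraint row, so the message appears whenever that count disagrees with the size of the buffer being written into.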

### Really getting in-depth understanding
What `mjx.get_data` receives is really the data of a single state of the rollout, `rollout[i]`, i.e. `state.data`.

In [16]:
def render(
      self, trajectory: List[base.State], camera: Union[str, None] = None
  ) -> Sequence[np.ndarray]:
    """Renders a trajectory using the MuJoCo renderer."""
    renderer = mujoco.Renderer(self._model)
    camera = camera or -1

    def get_image(state: mjx.Data):
      d = mjx.get_data(self._model, state)  # This line raises the error (same error message)
      mujoco.mj_forward(self._model, d)
      renderer.update_scene(d, camera=camera)
      return renderer.render()

    return [get_image(s.data) for s in trajectory]  # pytype: disable=attribute-error

In [17]:
from mujoco.mjx._src import constraint
from mujoco.mjx._src import types

##### Idea 1: Identify which component is which in the call chain, extract the parts relevant to debugging, and run them separately

In [18]:
# brax renderer level
d = rollout[0].data
m = mj_model

# mjx.get_data level
result = mujoco.MjData(mj_model)

# mjx.get_data_into level
d = jax.device_get(d)
batched = isinstance(result, list)
batch_size = d.qpos.shape[0] if batched else 1

# for i in range(batch_size):
#     d_i = jax.tree_map(lambda x, i=i: x[i], d) if batched else d
#     result_i = result[i] if batched else result
#     efc_active = (d_i.efc_J != 0).any(axis=1)
#     nefc = efc_active.sum()
#     result_i.efc_J_rownnz[:] = np.repeat(m.nv, nefc)
#     result_i.efc_J_rowadr[:] = np.arange(0, nefc * m.nv, m.nv)

##### Idea 2: Take out a big chunk of the code and remove different pieces to see what causes the shape error

In [19]:
ne, nf, nl, nc = constraint.count_constraints(m)
efc_type = np.array([
    mujoco.mjtConstraint.mjCNSTR_EQUALITY,
    mujoco.mjtConstraint.mjCNSTR_FRICTION_DOF,
    mujoco.mjtConstraint.mjCNSTR_LIMIT_JOINT,
    mujoco.mjtConstraint.mjCNSTR_CONTACT_PYRAMIDAL,
]).repeat([ne, nf, nl, nc])

dof_i, dof_j = [], []
for i in range(m.nv):
    j = i
    # Walk up the parent chain of DOF i; the loop must run once per DOF
    while j > -1:
        dof_i.append(i)
        dof_j.append(j)
        j = m.dof_parentid[j]

for i in range(batch_size):
    d_i = jax.tree_map(lambda x, i=i: x[i], d) if batched else d
    result_i = result[i] if batched else result
    ncon = (d_i.contact.dist <= 0).sum()
    efc_active = (d_i.efc_J != 0).any(axis=1)
    efc_con = efc_type == mujoco.mjtConstraint.mjCNSTR_CONTACT_PYRAMIDAL
    nefc, nc = efc_active.sum(), (efc_active & efc_con).sum()
    result_i.nnzJ = nefc * m.nv
    mujoco._functions._realloc_con_efc(result_i, ncon=ncon, nefc=nefc)  # pylint: disable=protected-access
    result_i.efc_J_rownnz[:] = np.repeat(m.nv, nefc)
    result_i.efc_J_rowadr[:] = np.arange(0, nefc * m.nv, m.nv)
    result_i.efc_J_colind[:] = np.tile(np.arange(m.nv), nefc)

ValueError: could not broadcast input array from shape (8,) into shape (7,). Here the input is `np.arange(0, nefc * m.nv, m.nv)` and the destination is `result_i.efc_J_rowadr[:]`.

***When running the whole structure, the shapes seem to match up -> try calling render -> try calling mediapy***
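
One detail worth checking when copying chunks out of `get_data_into`: the walk over `dof_parentid` is meant to run once per degree of freedom, so the `while` loop belongs inside the `for` loop. A pure-Python sketch of that traversal follows. The function name `dof_ancestor_pairs` and the three-DOF chain are hypothetical.

```python
def dof_ancestor_pairs(dof_parentid):
    """Pairs each DOF with itself and every ancestor DOF.

    dof_parentid[j] is the parent DOF of j, or -1 for a root.
    """
    dof_i, dof_j = [], []
    for i in range(len(dof_parentid)):
        j = i
        while j > -1:  # walk up until the root marker -1
            dof_i.append(i)
            dof_j.append(j)
            j = dof_parentid[j]
    return dof_i, dof_j

# Hypothetical chain: DOF 0 is a root, 1 hangs off 0, 2 hangs off 1.
dof_i, dof_j = dof_ancestor_pairs([-1, 0, 1])
print(list(zip(dof_i, dof_j)))
```

If the `while` loop is dedented out of the `for` loop, only the last DOF's chain is recorded, which changes the lengths of `dof_i` and `dof_j`.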

##### Idea 3: Separate the fundamental functions and examine them individually

The problem seems to lie in the `mjx.get_data` function, which is [here](https://github.com/google-deepmind/mujoco/blob/8be966cdf9073813ec8b494062f4d97848432057/mjx/mujoco/mjx/_src/io.py#L235), and in the related `get_data_into` function, particularly this [line](https://github.com/google-deepmind/mujoco/blob/8be966cdf9073813ec8b494062f4d97848432057/mjx/mujoco/mjx/_src/io.py#L293). I think it is related to [this line's shape problem](https://github.com/google-deepmind/mujoco/blob/8be966cdf9073813ec8b494062f4d97848432057/mjx/mujoco/mjx/_src/io.py#L284C4-L284C47).

In [20]:
def get_data(
    m: mujoco.MjModel, d: types.Data
) -> Union[mujoco.MjData, List[mujoco.MjData]]:
  """Gets mjx.Data from a device, resulting in mujoco.MjData or List[MjData]."""
  batched = len(d.qpos.shape) > 1
  batch_size = d.qpos.shape[0] if batched else 1

  if batched:
    result = [mujoco.MjData(m) for _ in range(batch_size)]
  else:
    result = mujoco.MjData(m)

  get_data_into(result, m, d)

  return result

In [21]:
def get_data_into(
    result: Union[mujoco.MjData, List[mujoco.MjData]],
    m: mujoco.MjModel,
    d: types.Data,
):
  """Gets mjx.Data from a device into an existing mujoco.MjData or list."""
  batched = isinstance(result, list)
  if batched and len(d.qpos.shape) < 2:
    raise ValueError('dst is a list, but d is not batched.')
  if not batched and len(d.qpos.shape) >= 2:
    raise ValueError('dst is an MjData, but d is batched.')

  d = jax.device_get(d)

  batch_size = d.qpos.shape[0] if batched else 1
  
  ne, nf, nl, nc = constraint.count_constraints(m)
  efc_type = np.array([
      mujoco.mjtConstraint.mjCNSTR_EQUALITY,
      mujoco.mjtConstraint.mjCNSTR_FRICTION_DOF,
      mujoco.mjtConstraint.mjCNSTR_LIMIT_JOINT,
      mujoco.mjtConstraint.mjCNSTR_CONTACT_PYRAMIDAL,
  ]).repeat([ne, nf, nl, nc])

  dof_i, dof_j = [], []
  for i in range(m.nv):
    j = i
    while j > -1:
      dof_i.append(i)
      dof_j.append(j)
      j = m.dof_parentid[j]

  for i in range(batch_size):
    d_i = jax.tree_map(lambda x, i=i: x[i], d) if batched else d
    result_i = result[i] if batched else result
    ncon = (d_i.contact.dist <= 0).sum()
    efc_active = (d_i.efc_J != 0).any(axis=1)
    efc_con = efc_type == mujoco.mjtConstraint.mjCNSTR_CONTACT_PYRAMIDAL
    nefc, nc = efc_active.sum(), (efc_active & efc_con).sum()
    result_i.nnzJ = nefc * m.nv
    mujoco._functions._realloc_con_efc(result_i, ncon=ncon, nefc=nefc)  # pylint: disable=protected-access
    result_i.efc_J_rownnz[:] = np.repeat(m.nv, nefc)
    result_i.efc_J_rowadr[:] = np.arange(0, nefc * m.nv, m.nv)
    result_i.efc_J_colind[:] = np.tile(np.arange(m.nv), nefc)

***
# Render a Trajectory Post-Training
***

Rendering now works!

In [22]:
#@title Load Model and Define Inference Function
model_path = './model_checkpoints/brax_ant_task_finished'
params = model.load_params(model_path)

In [23]:
from brax.training.agents.ppo import networks as brax_networks

def make_inference_fn(observation_size: int, 
                      action_size: int, 
                      normalize_observations: bool = False, 
                      network_factory_kwargs: Optional[Dict[str, Any]] = None):
    
    normalize = lambda x, y: x
    
    ppo_network = brax_networks.make_ppo_networks(
        observation_size,
        action_size,
        preprocess_observations_fn=normalize,
        **(network_factory_kwargs or {})
        )
    
    make_policy = brax_networks.make_inference_fn(ppo_network)
    
    return make_policy

make_policy = make_inference_fn(observation_size=env.observation_size, action_size=env.action_size)

In [24]:
env_name = 'walker'
env = envs.create(env_name=env_name)

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(make_policy(params, deterministic=False))

# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]

# grab a trajectory
n_steps = 1000
render_every = 1

# might be because brax does not clip the action to the xml limits in the model
for i in (range(n_steps)):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)

    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

    if state.done:
        break

In [30]:
media.show_video(env.render(rollout[::render_every], camera=2), fps=1.0 / env.dt / render_every)
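
The `fps=1.0 / env.dt / render_every` argument keeps playback at real-time speed when frames are skipped. A minimal sketch of that arithmetic is below. The helper `playback_fps` and the value `dt = 0.02` are assumptions for illustration, not values read from the environment.

```python
def playback_fps(dt, render_every):
    """Frame rate that keeps a subsampled video at real-time speed."""
    return 1.0 / dt / render_every

dt = 0.02        # hypothetical env.dt in seconds
n_states = 1001  # initial state plus 1000 steps

for render_every in (1, 2, 5):
    # Same count as len(rollout[::render_every])
    frames = len(range(0, n_states, render_every))
    fps = playback_fps(dt, render_every)
    duration = frames / fps  # seconds of video
    print(render_every, frames, round(duration, 2))
```

Skipping frames lowers both the frame count and the frame rate by the same factor, so the video length stays close to the simulated time.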

