# Notebook For Computational Structure of Brain

`dm_control` provides a very high-level abstraction over MuJoCo; both are developed by Google DeepMind. **Parsing through this repository, understanding it, and seeing how everything fits together takes a lot of time, but it is necessary, because this is a fundamental skill for research and for developing new things:**
1. It is the first step towards understanding what you need to build.
2. It is a generally useful skill when learning a powerful, highly abstracted new tool: parsing through many classes in a repository, seeing how the class hierarchy is connected, finding the `right level of abstraction`, and following inheritance relationships all the way back to the base classes.


**Technically, could we create a species that does not exist, let it learn to move through Reinforcement Learning (and maybe evolve through a genetic algorithm), and then _accept_ that this is how it would behave if it actually existed?**

Breakthroughs sometimes come at the point where previously learned knowledge converges.

# Catalog for this Project
1. [Stage 1: Building Pipeline Across Systems](#stage-1-building-pipelines-across-systems)
    - dm_control Environment Setup
    - dm_control Task Level Handling
2. [Stage 1: dm_control Arena Backbone Level Engineering](#stage-1-dm_control---arena-backbone-level-engineering)
    - Debug Log for "bridging issue"
3. [Stage 1: Inherent Binding Class for brax-Level Environment](#stage-1-inherent-binding-class-for-brax-engineering)
    - Debug Log for "binding issue"
4. [Stage 1: Rendering Check Using Mujoco](#stage-1-rendering-checking-states-in-dm_controlviewer)
5. [Stage 2: Walker Class Adapted Engineering](#stage-2-walker-class-adapted-engineering)
    - Debug Log for "updated version"
    - State Object in brax
    - Subtree Object in mjx
    - Observe Data Feedback
    - Rewards Setting
    - Mujoco Model Solver Setting
6. [Stage 2: Previous Debug Log for "Rendering Issue"](#stage-2-previous-debugging-logs--ideas-for-rendering)
7. [Stage 3: ConvNet Adapted](#stage-3-convnet-adaptation)
    - Vision Encoder Ideas & Pipelines
    - All Data Collected
    - Data Flow in brax
    - Debug Log for Vision Encoder Input
    - Debug Log for Vision Encoder Wrapper Class for Data Pathway Separation
    - Debug Log for Vision Encoder Wrapper Class for Network Building
    - Debug Log for "not walking issue"
8. [Stage 3: Vision Network Using `network_base.py`](#stage-3-vision-network-using-network_basepy)
    - General Structure
    - Modification Try 1
    - PPO Model Architecture
    - Modification Try 2
    - Report from Week 8 (Out->In neural network structure)
    - Proposition from Week 8 (integrated, all-in-one neural network structure)
    - Small Fixes With Shapes
    - Convolution Neural Network Structure
    - New Vision Encoder Structure
    - Report from Week 9 (excessive memory usage)
    - Debugging for GPU Memory Overflow
    - Report from Week 10 (excessive memory usage)
    - Report from Week 11 (gap widens, performance drops)
9. [Training Configuration & Rendering Post-trained Rollout](#training-configuration--rendor-trajectory)
    - Rendering Rollout in brax

## Collective Efforts
This is a lower level of abstraction than the previous one, but it is still very high level: it calls functions and environment setups that dm_control has already implemented.

The skill here is to parse through these convoluted layers to see the `real idea` behind the implementation and the `pipeline` of how each piece is called. That way you understand everything down to the implementation level, and can then decide which `level of abstraction` to stay on: the one that achieves your goal while changing as little as possible and reusing functions that others have already written.

> You can spend your time building on top of existing work and doing more useful things to create new innovations, instead of doing redundant work. Later, your work may itself become a function, an abstraction level, that other people call to build more work in the field. This is a collaborative project, a collective effort towards a common goal, and this is the power of Python libraries: not to reinvent the wheel, but to use the wheel to build the car; not to recreate data structures for storing data, but to use the data for more things.

Therefore, staying at the right level of abstraction and using already-implemented functions is very important (just as we used brax's implementation of PPO, we can use dm_control's implementation of customized environments).

- Brax MJX -> MuJoCo env -> modify using dm_control's packages -> dm_control MJCF
  - On the MJX level? On the level of dm_control's example environments? On the MJCF level? On the XML level?

***Note that these examples are already at the task level, which is already an execution of (model + environment). The correct level of `abstraction` is just the background environment level, not the task level; the same holds for brax.***

***
# Stage 1: Building Pipelines Across Systems
***

### Linux GPU Rendering
These are the imports, installation steps, and backend switches needed for `Linux GPU Rendering`:
- OSMesa works and all of its dependencies are installed, but MuJoCo rendering here does not work through OSMesa, only through the other backends, which have installation issues
- specifically the `mujoco.viewer` package
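MuJoCo's Python bindings (and dm_control's renderer) choose their OpenGL backend from the `MUJOCO_GL` environment variable (`egl`, `osmesa`, or `glfw`), and PyOpenGL reads `PYOPENGL_PLATFORM`. A minimal sketch of switching backends is below; the helper name `select_gl_backend` is hypothetical, and it assumes the variables are set before `mujoco` or `dm_control` is first imported, since the backend is picked at import time.

```python
import os

# Backends recognised through the MUJOCO_GL environment variable.
_VALID_BACKENDS = ("egl", "osmesa", "glfw")


def select_gl_backend(backend: str) -> dict:
    """Hypothetical helper: set the rendering-backend environment variables.

    Call this before `import mujoco` / `from dm_control import ...`.
    """
    backend = backend.lower()
    if backend not in _VALID_BACKENDS:
        raise ValueError(f"unknown backend {backend!r}; expected one of {_VALID_BACKENDS}")
    os.environ["MUJOCO_GL"] = backend
    # Headless backends need PyOpenGL pointed at the same platform.
    if backend in ("egl", "osmesa"):
        os.environ["PYOPENGL_PLATFORM"] = backend
    else:
        os.environ.pop("PYOPENGL_PLATFORM", None)  # let PyOpenGL use its default for GLFW
    return {key: os.environ.get(key) for key in ("MUJOCO_GL", "PYOPENGL_PLATFORM")}


print(select_gl_backend("egl"))
```

On a headless GPU server EGL is usually the first thing to try; OSMesa is a software renderer and is much slower.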

## `dm_control` Environment
References to dm_control's implementation of customized environments. These are very high-level libraries: the locomotion library should have everything needed to get an RL agent running in dm_control, and the mjcf library should have everything needed to build a connecting class.
- [dm_control locomotion lib](https://github.com/google-deepmind/dm_control/tree/main/dm_control/locomotion)
- [dm_control mjcf lib](https://github.com/google-deepmind/dm_control/tree/main/dm_control/mjcf)
- [dm_control lib](https://github.com/google-deepmind/dm_control/tree/main)
- [mujoco](https://mujoco.readthedocs.io/en/3.1.1/overview.html)
- [brax](https://github.com/google/brax)


- `locomotion class`
  - `arena subclass`
    - i.e. corridors -> gap corridors
  - `walkers subclass`
    - assets (xml)
    - i.e. rodent -> rat
  - `tasks subclass`
    - i.e. jump gaps
- `mjcf class`
  - class we want to build using idea from corridors
- `composer class`
  - arena.py/arena.xml provides foundation for `locomotion class`

  This took quite a long time to figure out:
  1. We can use the `tasks subclass` by converting it to an MJCF model, then to an `MjModel`
  2. Inheritance can be used to build our own functions needed in `task`, `arena`, and `walker` (work smart)
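As far as we can tell from the dm_control source, `composer.Entity.__init__` forwards its arguments to a `_build` hook, which is why the custom arena further down (`Gap_Vnl`) overrides `_build` and calls `super()._build(...)` instead of overriding `__init__`. The sketch below shows that pattern with hypothetical stand-in classes (`Entity`, `Corridor`, `MyCorridor`), not dm_control's real ones:

```python
class Entity:
    """Hypothetical stand-in for a composer-style base class.

    The constructor forwards its arguments to a `_build` hook, so subclasses
    customise construction by overriding `_build`, not `__init__`.
    """

    def __init__(self, *args, **kwargs):
        self.parts = []
        self._build(*args, **kwargs)

    def _build(self, *args, **kwargs):
        raise NotImplementedError


class Corridor(Entity):
    """Hypothetical library arena: builds a floor of a given length."""

    def _build(self, corridor_length=40):
        self.corridor_length = corridor_length
        self.parts.append(f"floor(length={corridor_length})")


class MyCorridor(Corridor):
    """Our customisation: reuse the parent build, then add one extra piece."""

    def _build(self, corridor_length=40, spawn_pos=(0, 0, 0)):
        super()._build(corridor_length=corridor_length)  # keep the library behaviour
        self.parts.append(f"site(pos={spawn_pos})")        # add our own spawn site


arena = MyCorridor(corridor_length=10, spawn_pos=(5, 0, 0))
print(arena.parts)
```

Only the piece we need is overridden; everything else still comes from the parent class.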

In [2]:
from datetime import datetime
import functools
from IPython.display import HTML
import PIL.Image
import yaml
from typing import List, Dict, Text, Callable, NamedTuple, Optional, Union, Any, Sequence, Tuple
from matplotlib import pyplot as plt
import mediapy as media
import wandb
import os
import numpy as np
from etils import epath
from flax import struct
from ml_collections import config_dict
import mujoco
from mujoco import mjx
from dm_control import mjcf as mjcf_dm
from dm_control import composer
from dm_control.locomotion.examples import basic_rodent_2020
from dm_control.composer.variation import distributions
from dm_control.locomotion.arenas import corridors as corr_arenas
from dm_control.locomotion.tasks import corridors as corr_tasks
from dm_control.locomotion.walkers import rodent, ant
from dm_control import viewer
from dm_control import mujoco as mujoco_dm
from dm_control.composer.variation import rotations
import jax
from jax import numpy as jp
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State #MjxEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, model
from brax.io import mjcf as mjcf_brax

# !command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
# np.set_printoptions(precision=3, suppress=True, linewidth=100)



## `dm_control` - Task Level Running
This code runs directly in the MuJoCo framework.

There are existing examples in the framework that call the more basic components of the dm_control library. These examples are already very high-level abstractions.

In [3]:
_CONTROL_TIMESTEP = .02
_PHYSICS_TIMESTEP = 0.001
random_state=None

walker = rodent.Rat(observable_options={'egocentric_camera': dict(enabled=True)})

# Build a corridor-shaped arena with gaps, where the sizes of the gaps and
# platforms are uniformly randomized.
arena = corr_arenas.GapsCorridor(
      platform_length=distributions.Uniform(.4, .8),
      gap_length=distributions.Uniform(.05, .2),
      corridor_width=2,
      corridor_length=40,
      aesthetic='outdoor_natural')

# Build a task that rewards the agent for running down the corridor at a
# specific velocity.
task = corr_tasks.RunThroughCorridor(
      walker=walker,
      arena=arena,
      walker_spawn_position=(5, 0, 0),
      walker_spawn_rotation=0,
      target_velocity=1.0,
      contact_termination=False,
      terminate_at_height=-0.3,
      physics_timestep=_PHYSICS_TIMESTEP,
      control_timestep=_CONTROL_TIMESTEP)

env_composed = composer.Environment(time_limit=30,
                              task=task,
                              random_state=random_state,
                              strip_singleton_obs_buffer_dim=True)

#viewer.launch(environment_loader=env_composed)
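The two timestep constants above determine how many physics steps run per agent action and how many actions fit into the 30 s `time_limit`. A small arithmetic sketch (the helper name `steps_per_episode` is hypothetical, not a dm_control function):

```python
def steps_per_episode(control_timestep, physics_timestep, time_limit):
    """Return (physics substeps per agent action, agent actions per episode)."""
    # Each agent action is held for control_timestep seconds, during which the
    # simulator advances in physics_timestep increments.
    n_sub_steps = round(control_timestep / physics_timestep)
    # round() guards against floating-point error in the divisions.
    n_control_steps = round(time_limit / control_timestep)
    return n_sub_steps, n_control_steps


print(steps_per_episode(0.02, 0.001, 30))  # (20, 1500)
```

So with `_CONTROL_TIMESTEP = .02` and `_PHYSICS_TIMESTEP = 0.001`, each action spans 20 physics steps, and one episode contains at most 1500 actions.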

***
# Stage 1: `dm_control` - Arena Backbone Level Engineering
***

<div class="alert alert-warning">

## Debug Log for Initialization of Binding

These are just the backbone skeletons in the [Arena_folder](https://github.com/google-deepmind/dm_control/tree/main/dm_control/locomotion/arenas):
- Gap Corridor -> Empty Corridor -> Corridor -> composer.arena -> xml
- Rodent -> xml
### First Idea:
Make an intermediate class that connects to the corridor-defined environment (the corridor environment is in the arena folder, not the task folder):
* [Tasks_folder](https://github.com/google-deepmind/dm_control/tree/main/dm_control/locomotion/tasks) already runs the full task directly
* [Arena_folder](https://github.com/google-deepmind/dm_control/tree/main/dm_control/locomotion/arenas) provides the actual skeleton; this is the correct level of `abstraction` that we are looking for
* **Tasks = Walker + Arena**

**Problem: we cannot directly use either the arena or the walker implemented by dm_control; they are designed too specifically for dm_control's own tasks.**

### Second Idea:
1. Directly build an arena using intuitions from the corridor class
2. Link the rodent or humanoid XML into that environment
3. Because it is created by dm_control's mjcf class, we can directly export it as a fixed file that can be used in brax
    - essentially we skip some of dm_control's really convoluted implementation chain, `tasks -> arena + walker -> composer class -> xml`
    - we build it from scratch using `mjcf.RootElement()`
4. The MJCF model here directly uses MuJoCo, so there is no need to start another MuJoCo instance in the brax training loop; we pass the model directly into brax

**Problem: this is far too inefficient, and there should be a much smarter way to work with it.**

### Third Idea:
1. It seems the Corridor class can be transferred directly into an MJCF model via `.to_mjcf_model()`
2. We also implemented our own Corridor class directly on top of `mjcf.RootElement()`, which returns the model through an `.out()` function
3. Try to inherit directly from the Corridor class and assign some more things

**Problem: so far the inherited-class idea works quite well, but the binding is not assigned.**

### Fourth Idea:
1. Wrap everything directly at the `tasks` level once the `binding` between the agent and the environment is done
2. The task level can also be exported as an MJCF model and, once we take `.ptr`, stored as an `MjModel` in the same way as the brax humanoid previously implemented in the training loop.
3. Create a subclass of the task module and then directly bind it to an existing walker construction file (fix ours later)

**Problem: this `task-wrapper-bound model` has a friction setup that MuJoCo does not support.**

</div>

***
# Stage 1: Inherent Binding Class for Brax Engineering
***

### Primary Goal:
1. Create direct inheritance from dm_control (`walker, arena, tasks -> wrapper to physics`)
2. Build a pipeline that uses dm_control's customization functions for brax training, by implementing particular classes that we can change to fit our needs for the `arena`, the `walker`, and the `tasks binding` (which binds walker and arena).

<div class="alert alert-warning">

## Debug Log for Binding
### We need to unbind the ant model from the ground so it can actually move when trained
- `attach` function documentation [here](https://github.com/google-deepmind/dm_control/blob/9e360bf8a069868107e52960a97dcafe2292113f/dm_control/composer/entity.py#L299)
- maybe we can use this function [here](https://github.com/google-deepmind/dm_control/blob/9e360bf8a069868107e52960a97dcafe2292113f/dm_control/composer/arena.py#L62), but it also uses the `attach` function.
- we cannot initialize the model and then put it back, because the output is only the arena file
- switch `._mjcf_root` to `.mjcf_model`; the model needs to be attached to a valid "child" without being rigidly fixed to the ground, so a `site` should work: it adds a new object of a given size into the arena, and the rendered videos show that this works.
- there might then be a problem with learning to walk in brax.

</div>

In [None]:
class Gap_Vnl(corr_arenas.GapsCorridor):
    def _build(self, corridor_width, corridor_length, visible_side_planes, aesthetic, platform_length, gap_length):
        super()._build(corridor_width=corridor_width,
                       corridor_length=corridor_length,
                       visible_side_planes=visible_side_planes,
                       aesthetic = aesthetic,
                       platform_length = platform_length,
                       gap_length = gap_length)

        # self.walker = mjcf.from_path("./models/rodent_dm_control.xml")
        # # Initiate walker
        # self.spawn_pos = (0, 0, 0)
        # self.spawn_site =  self._mjcf_root.worldbody.add('site', pos=self.spawn_pos)
        # self.spawn_site.attach(self.walker).add('freejoint')
    
    def regenerate(self, random_state):
        super().regenerate(random_state)

# Task now just serves as a wrapper
class Task_Vnl(corr_tasks.RunThroughCorridor):
    def __init__(self,
               walker,
               arena,
               walker_spawn_position):
        
        # we don't really need the rest of the reward setup in dm_control, just how the walker is attached to the arena
        spawn_site =  arena.mjcf_model.worldbody.add('site', size=[1e-6]*3, pos = walker_spawn_position)
        self._arena = arena
        self._walker = walker
        self._walker.create_root_joints(self._arena.attach(self._walker, attach_site=spawn_site)) # customize starting environment
        self._walker_spawn_position = walker_spawn_position
        #self._arena.add_free_entity(self._walker)

        # Instead of relying on env to instantiate task and bind artificial agent to arena, the task class itself now can bind artificial agent to arena

In [None]:
from dm_control.locomotion.arenas import bowl as bowl_arenas
from dm_control.locomotion.tasks import escape as bowl_tasks

class Bowl_Vnl(bowl_arenas.Bowl):
    def _build(self, size=(10, 10), aesthetic='default', name='bowl'):
        super()._build(size=size, aesthetic=aesthetic, name=name)

    def regenerate(self, random_state):
        super().regenerate(random_state)

# Renamed from Task_Vnl so it does not shadow the corridor Task_Vnl defined above,
# which is the class instantiated below with the gaps-corridor arena.
class Task_Bowl_Vnl(bowl_tasks.Escape):
    def __init__(self,
               walker,
               arena,
               walker_spawn_position):
        
        # we don't really need the rest of the reward setup in dm_control, just how the walker is attached to the arena
        spawn_site =  arena.mjcf_model.worldbody.add('site', size=[1e-6]*3, pos = walker_spawn_position)
        self._arena = arena
        self._walker = walker
        self._walker.create_root_joints(self._arena.attach(self._walker, attach_site=spawn_site)) # customize starting environment
        self._walker_spawn_position = walker_spawn_position
Initializing Conditions

In [None]:
arena = Gap_Vnl(platform_length=distributions.Uniform(1.5, 2.0),
      gap_length=distributions.Uniform(.05, .2),
      corridor_width=5, # walker width follows corridor width
      corridor_length=40,
      aesthetic='outdoor_natural',
      visible_side_planes=False)

# arena.regenerate(random_state=None)
# physics = mjcf_dm.Physics.from_mjcf_model(arena.mjcf_model)
# PIL.Image.fromarray(physics.render())

walker = ant.Ant(observable_options={'egocentric_camera': dict(enabled=True)})

arena_bowl = Bowl_Vnl()

task = Task_Vnl(
    walker=walker,
    arena=arena,
    walker_spawn_position=(5, 0, 0))

# we need to bind everything to self.arena because that is the only thing we are putting into the MjModel
random_state = np.random.RandomState(123456)
task.initialize_episode_mjcf(random_state)

physics = mjcf_dm.Physics.from_mjcf_model(task.root_entity.mjcf_model) #.root_entity only returns the arena model, no reward, nothing
#physics = mjcf_dm.Physics.from_mjcf_model(walker.mjcf_model)

There is no need to step the environment in dm_control and then extract the model again, like this:
- `walker_joints = walker.mjcf_model.find_all('joint')`
- `physics.bind(walker_joints).qpos = random_state.uniform(size=len(walker_joints))`
- `task.initialize_episode(physics, random_state)  # self.walker position changes`
- `physics2 = mjcf_dm.Physics.from_mjcf_model(task.root_entity.mjcf_model)`

[This](https://github.com/google-deepmind/dm_control/blob/9e360bf8a069868107e52960a97dcafe2292113f/dm_control/composer/entity.py#L299) is how we fixed the issue: we changed how `attach` is called by passing it a predefined, existing location (a site) in the MJCF file.

***
# Stage 1: Rendering checking states in `dm_control.viewer`
***

In [None]:
# env_composed = composer.Environment(time_limit=50,
#                               task=task,
#                               random_state=random_state,
#                               strip_singleton_obs_buffer_dim=True)
# viewer.launch(environment_loader=env_composed)

## Rendering with MuJoCo
We want to check whether the rendering problem that appears after entering the brax env is caused by the `MjModel` itself, so we render it directly here first to see whether it works.

- From playing around with rendering in MuJoCo, we figured out a key idea: ***cameras in the dm_control-built `MjModel` are selected by an integer index***
1. -1 is the far side view (index -1 selects MuJoCo's free camera)
2. 0 is the closer side view
3. 1 is the back view
4. 2 is the closer back view
5. 3 is the egocentric view

In [None]:
mj_model = physics.model.ptr
mj_model.opt.cone = mujoco.mjtCone.mjCONE_PYRAMIDAL

mjx_model = mjx.put_model(mj_model)
mj_data = mujoco.MjData(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model)

# enable joint visualization option:
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 5  # (seconds)
framerate = 60  # (Hz)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  if len(frames) < mj_data.time * framerate:
    renderer.update_scene(mj_data, scene_option=scene_option,camera=0)
    pixels = renderer.render()
    frames.append(pixels)

# Simulate and display video.
media.show_video(frames, fps=framerate)
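The `len(frames) < mj_data.time * framerate` check decouples the video frame rate from the physics timestep: a frame is grabbed only when the simulation clock has moved past the next 1/60 s mark. The pure-Python replay below reproduces just that bookkeeping; `capture_times` is a hypothetical helper, and the 0.002 s timestep is an assumption (MuJoCo's default `opt.timestep`), not necessarily this model's value.

```python
def capture_times(duration, framerate, timestep):
    """Return the simulation times at which the rendering loop above grabs a frame."""
    frames, t = [], 0.0
    while t < duration:
        t += timestep                    # stands in for mujoco.mj_step
        if len(frames) < t * framerate:  # same test as the rendering loop
            frames.append(t)
    return frames


times = capture_times(duration=5, framerate=60, timestep=0.002)
print(len(times))
```

The result is roughly `duration * framerate` frames, spaced about 1/60 s apart (each gap is off by at most one physics step), regardless of how small the timestep is.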

***
# Stage 2: Walker Class Adapted Engineering
***

Instead of loading the MuJoCo XML or MJCF file directly, we obtain the model by instantiating a customized `dm_control` PyMJCF class and extracting it from there.

<div class="alert alert-warning">

## Debug Log: Updated Version?
- A really funny mistake that caused 2 hours of debugging: the code directly inherited from the previous class, which caused all the problems. On the upside, it gave a better understanding of the pipeline and all the complex relationships in Brax. **If there are official changes in Brax, they will notify the user.**
- Your code should keep working even after package updates; the main logic should be fine, and only small updates should be needed to make it work again
    - Make sure your environment really is the one you think you are working in
    - Keep a record of all environment settings so you can reset them on site if necessary

</div>

## State Object in Brax
The [State Object](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/envs/base.py#L35C1-L43C60) contains all information about the environment at a given step; it is defined in the [environment folder](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/envs/base.py). The State object inherits from the [Base class](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/base.py#L38).
- The State object is not the [MjxState Object](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/mjx/base.py#L22); that is just an intermediate step for passing data
- The State object refers to the [Base.State Object](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/base.py#L397)

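```python
from typing import Any, Dict, Optional

import jax
from brax import base
from flax import struct


@struct.dataclass  # immutable pytree; updated with state.replace(...)
class State(base.Base):
  """Environment state for training and inference."""
  pipeline_state: Optional[base.State]
  obs: jax.Array
  reward: jax.Array
  done: jax.Array
  metrics: Dict[str, jax.Array] = struct.field(default_factory=dict)
  info: Dict[str, Any] = struct.field(default_factory=dict)
```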
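Because the State is a frozen dataclass, an environment step never mutates it in place; brax environments return a new object via `state.replace(...)`. The sketch below mimics that behaviour with the standard-library `dataclasses` module so it runs without JAX; `ToyState` is a hypothetical stand-in with the same field names, not brax's class.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in mirroring the fields of brax's env State."""
    pipeline_state: Any
    obs: list
    reward: float
    done: float
    metrics: Dict[str, float] = dataclasses.field(default_factory=dict)
    info: Dict[str, Any] = dataclasses.field(default_factory=dict)


s0 = ToyState(pipeline_state=None, obs=[0.0], reward=0.0, done=0.0)
# Fields cannot be assigned in place; a "step" builds a new state instead,
# analogous to state.replace(obs=..., reward=...) in brax.
s1 = dataclasses.replace(s0, obs=[1.0], reward=5.0)
print(s0.reward, s1.reward)
```

This functional style is what lets JAX trace and `jit` the whole `step` function.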

## Subtree Object in mjx
Let's look at this line:
- `com_before = data0.data.subtree_com[1]`

### data0
- `data0` in this implementation is a [`pipeline_state`](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/envs/base.py#L38) object, which is an [`mjx.base.State`](https://github.com/google/brax/blob/532a88a030a0761f9c83279cd7c5f028bd5aa320/brax/mjx/base.py#L22) object (a transition class); its data is in `mjx.Data` format


### subtree_com[1]
- **Center of Mass (COM)**: This is the point in an object or system of objects where the entire mass can be considered to be concentrated. For physical simulations, especially those involving dynamics and kinematics, calculating the COM is essential for accurately determining how forces and movements will affect the object.

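- **Subtree**: A subtree refers to a part of the overall tree structure. For a physical body, this might mean a limb or a combination of a limb and its subsequent parts. For example, an entire arm, from shoulder to fingertips, can be considered a subtree of the body. In MuJoCo, `subtree_com[i]` is the center of mass of body `i` together with all of its descendants, and index 0 is the world body; so `subtree_com[1]` is the COM of the subtree rooted at the first body after the world, which for a single free-floating walker rooted at body 1 is the COM of the whole walker.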
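The idea behind a subtree COM can be sketched with a backward pass over the body tree: each body's mass-weighted position is pushed into its parent, then divided by the accumulated mass. This is a NumPy sketch of the concept, not MuJoCo's code; it assumes MuJoCo-style ordering (every child index is larger than its parent's) and that `body_com[i]` holds body `i`'s own COM (what MuJoCo calls `xipos`).

```python
import numpy as np


def subtree_com(parent, mass, body_com):
    """Mass-weighted COM of every body's subtree.

    parent[i]   : parent index of body i (parent[0] == 0 for the world body)
    mass[i]     : mass of body i alone
    body_com[i] : COM position of body i alone, shape (nbody, 3)
    """
    weighted = mass[:, None] * body_com        # m_i * x_i for each body
    total_mass = mass.astype(float).copy()
    # Backward pass: children come after parents, so iterate from the last body.
    for i in range(len(parent) - 1, 0, -1):
        weighted[parent[i]] += weighted[i]
        total_mass[parent[i]] += total_mass[i]
    safe = np.where(total_mass > 0, total_mass, 1.0)  # avoid dividing by zero mass
    return weighted / safe[:, None]


# world (0) -> torso (1, 2 kg at x=0) -> leg (2, 1 kg at x=3)
parent = np.array([0, 0, 1])
mass = np.array([0.0, 2.0, 1.0])
com = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
print(subtree_com(parent, mass, com)[1])  # torso subtree COM: x = (2*0 + 1*3) / 3 = 1
```

Comparing `subtree_com[1]` before and after a step is therefore a displacement of the walker's whole-body COM, which is what the forward reward needs.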

## `brax` get_obs Data Feedback
This is all the data fed back from the environment for determining rewards:
1. `data.qpos` - **Position**
2. `data.qvel` - **Velocity**
3. `data.cinert[1:].ravel()` - **Inertia Matrix** (COM-based inertia and mass, 10 numbers per body)
4. `data.cvel[1:].ravel()` - **COM-based Velocity** (6D rotational and translational velocity per body)
5. `data.qfrc_actuator` - **Actuator Forces**



### Detailed Observation Components:
1. **Position**: Extracts the positions (`qpos`) of the humanoid body. (If `self._exclude_current_positions_from_observation` is True, it excludes the first two elements of the position vector, the global x and y coordinates, so the observation does not depend on where the humanoid is in the world.)

2. **Velocities**: Appends the velocities (qvel) of the humanoid body.

3. **Inertia Matrix**: Appends the COM-based inertia (`cinert`) of every body, excluding the first row, which belongs to the world body. Inertia describes the distribution of mass in the humanoid, and therefore the relationship between the forces applied to it and the resulting acceleration or deceleration.

4. **COM-based Velocity**: Appends the COM-based velocity (`cvel`) of every body, again excluding the first row (the world body).

5. **Actuator Forces**: Appends the actuator forces (qfrc_actuator). Actuators are typically modeled as components that generate forces or torques to drive the movement of joints in a simulated robotic system. These forces or torques are applied to the joints of the simulated body, affecting its motion.
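Putting the five pieces together, the observation is one flat vector. The NumPy sketch below assembles it from placeholder arrays; the helper name `humanoid_style_obs` and the sizes `nq`, `nv`, `nbody` are hypothetical, while the per-body widths (10 for `cinert`, 6 for `cvel`) follow MuJoCo's `mjData` layout.

```python
import numpy as np


def humanoid_style_obs(qpos, qvel, cinert, cvel, qfrc_actuator,
                       exclude_current_positions=True):
    """Concatenate position, velocity, inertia, COM velocity and actuator forces."""
    position = qpos[2:] if exclude_current_positions else qpos  # drop global x, y
    return np.concatenate([
        position,
        qvel,
        cinert[1:].ravel(),  # skip body 0 (world); 10 numbers per body
        cvel[1:].ravel(),    # skip body 0 (world); 6 numbers per body
        qfrc_actuator,       # one entry per degree of freedom
    ])


nq, nv, nbody = 10, 9, 4  # hypothetical model sizes
obs = humanoid_style_obs(np.zeros(nq), np.zeros(nv), np.zeros((nbody, 10)),
                         np.zeros((nbody, 6)), np.zeros(nv))
print(obs.shape)  # (nq - 2) + nv + 10*(nbody - 1) + 6*(nbody - 1) + nv = 74
```

Knowing this length matters later, when a flattened camera image is appended to the same vector.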

## `brax` Rewards

### 3 Main Rewards
The reward consists of three parts:
- **reward_alive**: Every timestep that the humanoid is alive, it gets a reward of 5.

- **forward_reward**: A reward for walking forward, measured as *1.25 * (average center of mass after action - average center of mass before action) / dt*, using the x-component. *dt* is the time between actions; the default is *dt = 0.015*. This reward is positive when the humanoid walks forward (in the +x direction), as desired. The center-of-mass calculation is defined in the Humanoid's `.py` file.

- **reward_quadctrl**: A negative reward penalising the humanoid for using too large a control force. If there are *nu* actuators/controls, the control has shape `nu x 1`, and the penalty is *0.1 **x** sum(control<sup>2</sup>)*. ***Essentially, the control-force penalty makes the movement much more realistic***
    - Control forces are often the output of a control system or a learned policy (in the case of reinforcement learning), dictating how the entity should move to achieve certain objectives, like walking, jumping, or picking up objects.
    - High control forces can lead to aggressive, unstable, or unsafe movements, which might damage the simulated entity or the environment.
    - High control forces often result in abrupt, jerky movements that are not only inefficient but can also be unrealistic, especially in simulations aiming to mimic biological movements.
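The three terms combine as healthy bonus + forward reward - control cost. A NumPy sketch using the weights listed above (5, 1.25, 0.1, dt = 0.015); the function name `humanoid_style_reward` is hypothetical and the inputs are made-up numbers:

```python
import numpy as np


def humanoid_style_reward(com_before, com_after, action, dt=0.015,
                          forward_weight=1.25, healthy_reward=5.0, ctrl_weight=0.1):
    """Healthy bonus + forward-velocity reward - quadratic control cost."""
    velocity = (np.asarray(com_after) - np.asarray(com_before)) / dt  # after - before
    forward_reward = forward_weight * velocity[0]            # x is "forward"
    ctrl_cost = ctrl_weight * np.sum(np.square(action))      # 0.1 * sum(u^2)
    return healthy_reward + forward_reward - ctrl_cost


# COM moves 0.015 m forward in one 0.015 s step -> 1 m/s -> forward reward 1.25
r = humanoid_style_reward(com_before=[0.0, 0.0, 1.4], com_after=[0.015, 0.0, 1.4],
                          action=[1.0, -1.0])
print(r)  # 5 + 1.25 - 0.1 * 2
```

Because the control cost is quadratic, doubling every actuator command quadruples the penalty, which is what discourages jerky, high-force motions.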

### Starting State
All observations start in state (0.0, 0.0, 1.4, 1.0, 0.0 ... 0.0), with uniform noise in the range [-0.01, 0.01] added to the position and velocity values for stochasticity. Note that the initial z coordinate is intentionally high, indicating a standing humanoid. The initial orientation makes it face forward as well.
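A NumPy sketch of that reset noise is below. brax does this with `jax.random` inside `reset`; here `numpy.random.default_rng` stands in, the helper name `noisy_initial_state` is hypothetical, and the 7-element `init_qpos` (x, y, z plus a unit quaternion) is only the free-joint part of the state.

```python
import numpy as np


def noisy_initial_state(init_qpos, init_qvel, rng, scale=0.01):
    """Add U[-scale, scale] noise to the nominal qpos and qvel."""
    qpos = init_qpos + rng.uniform(-scale, scale, size=init_qpos.shape)
    qvel = init_qvel + rng.uniform(-scale, scale, size=init_qvel.shape)
    return qpos, qvel


rng = np.random.default_rng(0)
init_qpos = np.array([0.0, 0.0, 1.4, 1.0, 0.0, 0.0, 0.0])  # x, y, z=1.4, quaternion (1, 0, 0, 0)
qpos, qvel = noisy_initial_state(init_qpos, np.zeros(6), rng)
print(qpos.round(3))
```

Each episode therefore starts from a slightly different pose, so the policy cannot memorise a single trajectory.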

### Episode Termination
The episode terminates when any of the following happens:

1. The episode duration reaches 1000 timesteps
2. The z-coordinate of the torso (index 0 in the observation, or index 2 in `qpos`) is **not** in the range `[0.8, 2.1]` (the humanoid has fallen or is about to fall beyond recovery).
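Both termination conditions fit in a few lines; `is_done` is a hypothetical helper using the range and limit listed above. In brax the 1000-step limit is typically enforced by the training wrapper rather than inside the environment's own `done` flag, so treat this as a combined illustration.

```python
def is_done(torso_z, step, z_range=(0.8, 2.1), max_steps=1000):
    """Episode ends on timeout or when the torso height leaves z_range."""
    low, high = z_range
    fallen = not (low <= torso_z <= high)
    return fallen or step >= max_steps


print(is_done(1.4, 10), is_done(0.5, 10), is_done(1.4, 1000))  # False True True
```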

## `Mujoco` Model Solver
A "solver" typically refers to the algorithm or computational method used to resolve the physical equations that govern the dynamics of the system being simulated, including handling contacts, frictions, and constraints among the bodies in the simulation.

### Solver Parameters
MuJoCo allows users to configure various parameters of the solver to balance between accuracy and performance. These parameters can include:

1. **Iterations**: The number of iterations the solver performs to resolve contacts and constraints. More iterations can lead to more accurate results but at the cost of computational time.

2. **Tolerance**: The tolerance level for the solver's accuracy. A lower tolerance means the solver will work harder to minimize errors, which can be necessary for precise simulations but again increases computation time.

3. **Solver Type**: MuJoCo offers three constraint solvers, PGS, CG, and Newton (selected through `opt.solver`), each with its own trade-offs between stability and computational load.

### Conjugate Gradient (CG) Solver
The Conjugate Gradient solver is one of the three solver types available in MuJoCo. MuJoCo's CG is a nonlinear conjugate gradient method applied to the constraint optimization problem solved at every simulation step. Each iteration is cheaper than a Newton iteration, but it usually needs more iterations, which can make it efficient for large models with many joints and contacts.

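```python
import mujoco

# Select the CG solver on an existing MjModel (e.g. `mj_model` from the rendering cell above)
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
```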

### Newton Solver
Offers high accuracy at the cost of computational resources, suitable for simulations where precision is paramount.

1. High Accuracy: The Newton solver can achieve high levels of accuracy because it iteratively refines its guesses until it converges on a solution that satisfies the system of equations within a predefined tolerance. This characteristic makes it highly effective for simulations where precision is paramount.

2. Efficiency in Nonlinear Systems: It is particularly suited for solving nonlinear problems, which are common in physics simulations involving complex interactions between objects. Nonlinear dynamics are typical in systems with large deformations, sophisticated material models, or intricate boundary conditions.

3. Convergence Properties: When the initial guess is close to the true solution, the Newton solver converges rapidly due to its quadratic convergence property. However, its performance is highly dependent on the quality of the initial guess and the nature of the system being solved.
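MuJoCo's Newton solver works on a constrained optimization problem, not a scalar equation, so the sketch below only illustrates the quadratic convergence described in point 3, using the textbook scalar Newton iteration on f(x) = x² - 2. The helper name `newton_sqrt2` is hypothetical.

```python
def newton_sqrt2(x0=1.0, tol=1e-12, max_iter=50):
    """Newton's method on f(x) = x**2 - 2; returns (root, error after each iteration)."""
    root = 2.0 ** 0.5
    x, errors = x0, []
    for _ in range(max_iter):
        x = x - (x * x - 2.0) / (2.0 * x)  # x_{k+1} = x_k - f(x_k) / f'(x_k)
        errors.append(abs(x - root))
        if errors[-1] < tol:
            break
    return x, errors


x, errors = newton_sqrt2()
for k, err in enumerate(errors, start=1):
    print(k, f"{err:.1e}")  # the number of correct digits roughly doubles each iteration
```

Starting close to the root (x0 = 1), the error squares at every step, so machine precision is reached in about five iterations; a poor starting point would lose this advantage.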


### PGS (Projected Gauss-Seidel)
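Good for general-purpose simulations, balancing speed and accuracy. Unlike CG and Newton, MuJoCo's PGS works on the dual problem; note that MJX's documentation lists only CG and Newton as supported solvers.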

***
# Stage 2: Previous Debugging Logs & Ideas for Rendering
***

<div class="alert alert-warning">

## Debugging Logs for Rendering:

***Key: debug line by line, get into all the layers, and try things one by one to see where the error actually occurs, then figure out why; it is like solving a puzzle.***

1. Follow the error message and function descriptions, checking all the way down to the `mujoco` render level, then up to `brax`
2. Ran MuJoCo rendering separately; the `MjModel` is able to render
3. brax rendering uses MuJoCo rendering, and MuJoCo rendering works fine, so the `MjModel` should be correct; the problem may be with the `states` in the rollout `trajectory`
4. The problem probably occurs on this [line](https://github.com/google/brax/blob/a89322496dcb07ac5a7e002c2e1d287c8c64b7dd/brax/envs/base.py#L214), where `d` has shape 8 but the renderer only accepts shape 7
5. We may need to define our own rendering function that shrinks some of the input arrays in `d`
6. It may be that the brax env setup here is for the humanoid, not for the ant (the dm_control humanoid model has an `mj_model.geom_condim` issue, since MJX only supported a `geom_condim` of 3 at the time, and the rodent has a tendon issue)
7. Reimplemented the ant using a simpler implementation instead of the complex one; the issue still seems to occur
8. May have to reimplement the walker class first for this to work
    - check geom_condim
    - check tendon
9. The mismatch issue may come from a mismatched training loop
10. The previous understanding was wrong: the training loop for the brax env is correct. It looked wrong before because we loaded the state directly instead of its `.data`
11. It is indeed still this [line](https://github.com/google/brax/blob/a89322496dcb07ac5a7e002c2e1d287c8c64b7dd/brax/envs/base.py#L214) where the error occurs: I implemented the rendering function step by step, and the error message `ValueError: could not broadcast input array from shape (8,) into shape (7,)` appears at this step
12. Now everything actually works! Good practice in debugging down to a very detailed level and trying things out bit by bit.

### More Understanding
What `mjx.get_data` receives is the data of a single state of the rollout, `rollout[i]`, i.e. `state.data`.

```python
def render(
      self, trajectory: List[base.State], camera: Union[str, None] = None
  ) -> Sequence[np.ndarray]:
    """Renders a trajectory using the MuJoCo renderer."""
    renderer = mujoco.Renderer(self._model)
    camera = camera or -1

    def get_image(state: mjx.Data):
      d = mjx.get_data(self._model, state) # The error occurs on this line (same error message as above)
      mujoco.mj_forward(self._model, d)
      renderer.update_scene(d, camera=camera)
      return renderer.render()

    return [get_image(s.data) for s in trajectory]  # pytype: disable=attribute-error
```

#### Idea 1: Know which component is which when calling, take important part that is relevant to debugging and then run separately.

#### Idea 2: Take out a big chunk of the code and remove pieces one at a time to see what causes the shape error.
- The `ValueError: could not broadcast input array` comes from assigning `np.arange(0, nefc * m.nv, m.nv)` into `result_i.efc_J_rowadr[:]`.
- **When running the whole pipeline, the shapes seem to match up -> try calling render -> try calling mediapy**
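
The error above is NumPy refusing to copy an 8-element row-address array into a 7-element buffer. A minimal NumPy sketch of that failure mode, assuming (from the message) that the constraint count `nefc` used to build the `np.arange` is 8 while the preallocated host buffer only has room for 7 rows. `fill_efc_rowadr` is a hypothetical helper that mirrors the single assignment in mjx's `get_data_into`, and `nv = 14` is taken from the `qvel` shape listed further below:

```python
import numpy as np

NV = 14  # degrees of freedom (qvel has 14 entries in the field list below)


def fill_efc_rowadr(rowadr_buffer: np.ndarray, nefc: int, nv: int) -> None:
    """Hypothetical helper: write one dense row address per constraint row."""
    # Row addresses 0, nv, 2*nv, ... for nefc constraint rows.
    rowadr_buffer[:] = np.arange(0, nefc * nv, nv)


# Buffer sized for 7 constraint rows, but the MJX data reports 8.
host_buffer = np.zeros(7, dtype=np.int64)
try:
    fill_efc_rowadr(host_buffer, nefc=8, nv=NV)
except ValueError as err:
    print(err)  # could not broadcast input array from shape (8,) into shape (7,)

# Sizing the buffer from the same nefc makes the copy succeed.
matched_buffer = np.zeros(8, dtype=np.int64)
fill_efc_rowadr(matched_buffer, nefc=8, nv=NV)
```

This suggests checking that the `MjData` being filled and the `mjx.Data` being copied agree on the number of constraint rows.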

#### Idea 3: Separate fundamental functions to examine them separately.
The problem seems to lie in the `mjx.get_data` function, defined [here](https://github.com/google-deepmind/mujoco/blob/8be966cdf9073813ec8b494062f4d97848432057/mjx/mujoco/mjx/_src/io.py#L235), and in the related `get_data_into` function, particularly this [line](https://github.com/google-deepmind/mujoco/blob/8be966cdf9073813ec8b494062f4d97848432057/mjx/mujoco/mjx/_src/io.py#L293). I think this is related to [this line's shape problem](https://github.com/google-deepmind/mujoco/blob/8be966cdf9073813ec8b494062f4d97848432057/mjx/mujoco/mjx/_src/io.py#L284C4-L284C47).

</div>

***
# Stage 3: ConvNet Adaptation
***

## Idea for Adding a `brax` Visual Encoder
1. We can use a rendering function in the `Walker` class, but that rendering function relies on a helper function defined outside the `Walker` class.
2. We can render from a first-person perspective.
3. We integrate this directly into the system, following the process described below.

### Implementation Details
1. Use the same idea as brax's `env.render`: grab the `mjx.Data` and pass it through MuJoCo rendering to get one frame.
2. Render one frame of the state in real time (first-person view) directly from the single `mjx.Data` at hand and feed it into the training loop -> **this happens during training (training images), unlike the rollout rendering in the post-training steps**
3. Add it to the observation space in `get_obs`; the observation should allow this addition.
4. Handle the new visual data in the `reset` and `step` functions.
5. Preprocess it using a function in the training loop:
    - pass it through a CNN or a pretrained image encoder (inference only), or possibly train the encoder in real time during training
    - we need a way to route data into the CNN and **separate the data pathways (keys) for vision and for everything else**
        - modify `make_policy` and put it into the PPO code?
        - directly inherit from the PPO code and add the vision encoder information we need?
        - [dm_control vision encoder link](https://github.com/talmolab/vanilla_rl_rodent/blob/main/main_ppo_rodent_vision.py#L129)
        - [brax ppo](https://github.com/google/brax/tree/main/brax/training/agents/ppo)
    - We use the same training procedure, since we want the ConvNet to maximize reward as well: the image data are trained exactly like all other data, with the PPO algorithm, the only difference being that the input layers are now an AlexNet.
    - Or just extend the front end and connect it to the PPO training network.
    - The vision encoder "encodes" the image into activations: what PPO actually takes in are the CNN's activations, not the raw images, and the same holds for the inference function.
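
The last bullet can be sketched in NumPy: the encoder turns the image into a small activation vector, and PPO only sees that vector concatenated with proprioception. This is a minimal sketch under stated assumptions: a fixed random linear projection (`encode_image`, hypothetical) stands in for the CNN and is never trained, the 16 vision activations and the 13 + 14 proprioceptive values are the sizes used elsewhere in these notes.

```python
import numpy as np

IMAGE_SHAPE = (240, 320, 3)  # rendered egocentric frame
FEATURE_SIZE = 16            # vision activations handed to PPO
rng = np.random.default_rng(0)

# Stand-in for the CNN: a fixed random projection (NOT a trained encoder).
projection = rng.normal(size=(int(np.prod(IMAGE_SHAPE)), FEATURE_SIZE)) / 1000.0


def encode_image(image: np.ndarray) -> np.ndarray:
    """Hypothetical encoder: uint8 image -> small non-negative activation vector."""
    pixels = image.astype(np.float32).reshape(-1) / 255.0  # scale to [0, 1]
    return np.maximum(pixels @ projection, 0.0)             # ReLU activations


def policy_input(image, position, velocity):
    """What PPO would consume: encoder activations plus proprioception, not raw pixels."""
    return np.concatenate([encode_image(image), position, velocity])


frame = rng.integers(0, 256, size=IMAGE_SHAPE, dtype=np.uint8)
obs = policy_input(frame, np.zeros(13), np.zeros(14))
print(obs.shape)  # (43,)
```

The design point is that the policy's input size no longer depends on the image resolution, only on `FEATURE_SIZE`.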

### Pipeline of Visual Encoder

<img src=./images/encoder_plan_2.png width=60%>
<img src=./images/encoder_plan_3.png width=60%>

## All `State.pipeline_state.data` Fields Collected
### All of these are from [mjx.Data](https://github.com/google-deepmind/mujoco/blob/c6a41fbfe64ee7b2680a6bde90200ca660d08c2a/mjx/mujoco/mjx/_src/types.py#L590)

Each field was printed as a tracer, `Traced<ShapedArray(dtype[shape])>with<DynamicJaxprTrace(level=1/0)>`; only the dtype and shape are kept below.

```text
solver_niter        int32[]
time                float32[]  (weak_type=True)
qpos                float32[15]
qvel                float32[14]
act                 float32[0]
qacc_warmstart      float32[14]
ctrl                float32[8]
qfrc_applied        float32[14]
xfrc_applied        float32[17,6]
eq_active           int32[0]
qacc                float32[14]
act_dot             float32[0]
xpos                float32[17,3]
xquat               float32[17,4]
xmat                float32[17,3,3]
xipos               float32[17,3]
ximat               float32[17,3,3]
xanchor             float32[9,3]
xaxis               float32[9,3]
geom_xpos           float32[38,3]
geom_xmat           float32[38,3,3]
site_xpos           float32[35,3]
site_xmat           float32[35,3,3]
subtree_com         float32[17,3]
cdof                float32[14,6]
cinert              float32[17,10]
crb                 float32[17,10]
actuator_length     float32[8]
actuator_moment     float32[8,14]
qM                  float32[81]
qLD                 float32[81]
qLDiagInv           float32[14]
qLDiagSqrtInv       float32[14]
contact = Contact(
    dist            float32[669]
    pos             float32[669,3]
    frame           float32[669,3,3]
    includemargin   float32[669]
    friction        float32[669,5]
    solref          float32[669,2]
    solreffriction  float32[669,2]
    solimp          float32[669,5]
    geom1           int32[669]
    geom2           int32[669]
)
efc_J               float32[2684,14]
efc_frictionloss    float32[2684]
efc_D               float32[2684]
actuator_velocity   float32[8]
cvel                float32[17,6]
cdof_dot            float32[14,6]
qfrc_bias           float32[14]
qfrc_passive        float32[14]
efc_aref            float32[2684]
actuator_force      float32[8]
qfrc_actuator       float32[14]
qfrc_smooth         float32[14]
qacc_smooth         float32[14]
qfrc_constraint     float32[14]
qfrc_inverse        float32[14]
efc_force           float32[2684]
```

<br>

In the `get_obs` function:
1. `Position_data` = Traced<ShapedArray(float32[13])>with<DynamicJaxprTrace(level=1/0)>
2. `Velocity_data` = Traced<ShapedArray(float32[14])>with<DynamicJaxprTrace(level=1/0)>
3. `Image_data` = Traced<ShapedArray(uint8[230400])>with<DynamicJaxprTrace(level=1/0)>

In the `step` function:
1. `COM` data is Traced<ShapedArray(float32[17,3])>with<DynamicJaxprTrace(level=1/0)>
2. The reward is not actually calculated from the `obs` space, **so could we use a dict, since it would not affect the reward?**
3. `obs` is a concatenated array, **maybe create a `Data` class just like brax does with its own data class?**


```python
import jax
from mujoco.mjx._src.dataclasses import PyTreeNode


# This class inherits directly from the same PyTreeNode base as mjx.Data.
class BraxData(PyTreeNode):
    position: jax.Array
    velocity: jax.Array
    image: jax.Array
```

## Data Flow in Brax
This class inherits directly from the same base class as `mjx.Data`, which gives us:
1. a customized data class
2. customized tags
3. easy calling, because it is the same kind of type

Remember, the official brax tutorial doesn't really use the observation data (e.g. for rewards) but rather the whole data pool.

<img src=./images/data_flow.png width=50%>

## How PPO Training Works
### Reward flow
1. brax PPO's `train.py` calls the [unroll function](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L296), the crucial step for generating new actions!
2. This leads to the `acting.py` file, which has the [unroll function](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/acting.py#L57), which calls the [step function](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/acting.py#L34). That in turn calls the [policy function](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/acting.py#L42).
3. Rewards are gathered in this [Transition class](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/acting.py#L45) for taking gradients later!

* One way to separate the vision data from the proprioceptive data is to modify the [data output by the unroll function](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L303).
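
The flow in steps 1 to 3 can be mirrored in plain Python to see where the reward ends up. This is a simplified sketch, not brax's code: brax runs the loop with `jax.lax.scan` inside `generate_unroll` on batched JAX arrays, and its `Transition` lives in `brax.training.types`. Only the field names below follow brax's `Transition`; `toy_env_step` and `toy_policy` are hypothetical stand-ins.

```python
from typing import Any, Dict, List, NamedTuple, Tuple


class Transition(NamedTuple):
    """Simplified stand-in for brax's Transition record (same field names)."""
    observation: Any
    action: Any
    reward: Any
    discount: Any
    next_observation: Any
    extras: Any = ()


def actor_step(env_step, state: Dict, policy) -> Tuple[Dict, Transition]:
    """One step: the policy acts, the env steps, and the reward is recorded."""
    action, policy_extras = policy(state["obs"])
    next_state = env_step(state, action)
    return next_state, Transition(
        observation=state["obs"],
        action=action,
        reward=next_state["reward"],  # reward comes from the env step, not from obs
        discount=1.0 - next_state["done"],
        next_observation=next_state["obs"],
        extras={"policy_extras": policy_extras},
    )


def unroll(env_step, state: Dict, policy, unroll_length: int):
    """Collect a fixed-length rollout; the stored rewards feed the loss later."""
    transitions: List[Transition] = []
    for _ in range(unroll_length):
        state, transition = actor_step(env_step, state, policy)
        transitions.append(transition)
    return state, transitions


# Toy 1-D "walker": the reward is the forward displacement of each step.
def toy_env_step(state, action):
    new_x = state["obs"] + action
    return {"obs": new_x, "reward": new_x - state["obs"], "done": 0.0}


def toy_policy(obs):
    return 0.5, {"log_prob": 0.0}


start = {"obs": 0.0, "reward": 0.0, "done": 0.0}
final_state, rollout = unroll(toy_env_step, start, toy_policy, unroll_length=4)
print([t.reward for t in rollout])  # [0.5, 0.5, 0.5, 0.5]
```

Because each `Transition` stores the observation as one object, separating vision from proprioception at this point would change what `observation` holds, while `reward` is unaffected.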

### Input Layer of the Neural Network
1. [This](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L369) is where the networks are created, starting from random initialization.
2. We also need to modify the container for all parameters in [`losses.py`](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/losses.py#L31).
3. The [optimizer](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L237C3-L238C1) takes in the parameters for optimization:
    ```python
    optimizer = optax.adam(learning_rate=learning_rate)
    ```
    - This should also be fine when it takes in another set of parameters, since optax optimizers work on arbitrary parameter pytrees; see [optax](https://github.com/google-deepmind/optax).
4. The parameters are wrapped in the [TrainingState class](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L371) and then [converted for JAX](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L377).
5. This is fed into [training_epoch_with_timing](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L423)
    - which feeds into [training_epoch](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L350C14-L350C28)
    - [training_epoch](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L333) returns the PPO loss metrics

6. General Plan:
    1. **Separate the vision and proprioceptive state**
    2. **Retrieve the parameters once within the training loop using the `Training_State.train` function [here](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/train.py#L73)**
        - Might need to adjust some intermediate data structures
        - Might need to re-create the `Training_State.train` function
        - **Use the same train function but with the network adjusted: retrieve the parameters, put them into the state space, call the train function again**
    3. **Feed it back in to PPO again**
    
**Why don't we just build a separate CNN and then encode its activations into the state, since we are going to change the algorithm anyway?**

**Solution: Modify the [`network.py`](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/networks.py) file directly**
- Create a new `network.py` with the appropriate modifications for the encoder.
    - We don't need to go deep into `train.py` and change the whole training loop.
    - Brax may look messy, but everything is developed in a very modular and ordered fashion: you can route the input directly in `network.py` and leave `train.py` untouched.
- The original Brax observation is a concatenated JAX array, but the network should also accept the `mjx.Data`-style (`BraxData`) type; if that doesn't work, we can just concatenate later.


<div class="alert alert-warning">

## Debug Log for Vision Encoder Input
1. [Brax Rendering](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/io/image.py#L26)
2. [Mujoco Rendering](https://github.com/google-deepmind/mujoco/blob/c146bb40033a1e080aa0d881c476e4a0549238be/python/mujoco/renderer.py#L27)
3. [Transformation to Egocentric Camera Frame](https://github.com/google-deepmind/dm_control/blob/3adfe8c3a1d1dff491956a290f7205f19dbb55f3/dm_control/locomotion/walkers/base.py#L77)

### First Challenge:
The renderer needs a `mujoco.MjModel` and a `mujoco.MjData`. We have the model, but the data passed into the `get_obs` function is `pipeline_state.data`: `pipeline_state` is the state, and its data is an `mjx.Data`, so we need to convert it to a `mujoco.MjData`.

This was resolved because `mjx.Data` also seems to work, but the **problem now moves to the JAX step**.

### Second Challenge:
There are extra JAX arrays present in brax's `mjx.Data` type. In the actual rollout, we build up the trajectory by stepping the environment, which requires the environment itself.

Solved using the new `brax` renderer documentation.
```python
d = mujoco.MjData(self._model)
```

The resulting image has shape (240, 320, 3); it is flattened into a 1D array by converting it to a `jax.numpy` array and calling `.flatten()` on it.

1. **Now that a single-frame image can be stored in the observation space, we can start training.**
2. **Rendering time seems quite acceptable; it didn't add much overhead.**
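
Flattening is lossless as long as the network later unflattens with the same shape and the same row-major order. A NumPy sketch with a synthetic frame (NumPy instead of `jax.numpy`, whose `flatten`/`reshape` use the same C order by default):

```python
import numpy as np

H, W, C = 240, 320, 3  # rendered frame shape

# Synthetic uint8 frame with varying values so the round trip is meaningful.
frame = (np.arange(H * W * C) % 256).astype(np.uint8).reshape(H, W, C)

flat = frame.flatten()  # row-major (C-order) 1-D copy
print(flat.shape)       # (230400,)

# Unflattening with the same shape and order restores the original image exactly.
restored = flat.reshape(H, W, C)
print(np.array_equal(restored, frame))  # True
```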
</div>

<div class="alert alert-warning">

## Debug Log for Vision Encoder Wrapper Class
### For Data Pathway Separation
1. We can create a wrapper class that separates the data pathways by creating the different `keys` needed, and then figure out how to feed each one into PPO differently.
2. Maybe we can create such a wrapper directly outside the env class: a wrapper that keeps every aspect of the `Walker.env` class and differs only in the data? **The key is that this happens after the environment is created; add tags?**

### Example of Idea 2
```python
class ObsWrapper(...):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = spaces.Dict({
            'velocity': ...,
            'position': ...,
            'image': ...,
        })
```

- We now make **a dictionary** to tag and separate the data (**directly in the `get_obs` function**; it can become a data class later if needed). Now the data are separated!
```python
# Inside get_obs: tag the vision and proprioceptive streams separately.
observation = {
    'vision': image_jax,
    'proprioceptive': jp.concatenate([position, velocity]),
}
return observation
```
- However, `obs` can only be one long array; maybe chunk it later?
- We package our data just like `mjx.Data`, but through our own inherited class; this way we can customize the **data class** and its **tags**, and directly have **separated data**.
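
One way to reconcile the dictionary with brax's single-array `obs` is to flatten the dictionary into one array and record where each tag lives, so the network can cut it apart again. A minimal NumPy sketch; `pack_observation`/`unpack_observation` and the proprioception-first ordering are hypothetical choices, and the layout is computed in plain Python outside any jitted code so the slice bounds stay static:

```python
import numpy as np


def pack_observation(obs, keys=("proprioceptive", "vision")):
    """Concatenate tagged streams into one flat float32 array and record each stream's slot."""
    parts, layout, offset = [], {}, 0
    for key in keys:
        value = np.asarray(obs[key], dtype=np.float32)
        layout[key] = (offset, offset + value.size, value.shape)
        parts.append(value.reshape(-1))
        offset += value.size
    return np.concatenate(parts), layout


def unpack_observation(flat, layout):
    """Recover every tagged stream; leading batch axes of `flat` are preserved."""
    return {key: flat[..., start:stop].reshape(*flat.shape[:-1], *shape)
            for key, (start, stop, shape) in layout.items()}


obs = {"proprioceptive": np.arange(27), "vision": np.ones((64, 64, 3))}
flat, layout = pack_observation(obs)
print(flat.shape)                # (12315,)
unpacked = unpack_observation(flat, layout)
print(unpacked["vision"].shape)  # (64, 64, 3)
```

Because the slicing uses `...` on the last axis, the same `unpack_observation` also works on a batch of packed observations.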

</div>

<div class="alert alert-warning">

## Debug Log for Vision Encoder Building Network

[This](https://github.com/google/brax/blob/3c109cfd131e01691a53891e8f2ec9b32cf97670/brax/training/agents/ppo/networks.py#L37) is how brax PPO builds its policy network, which is applied to a `types.Observation` input:

```python
from typing import Sequence

from brax.training import distribution, networks, types
from brax.training.agents.ppo.networks import PPONetworks
from flax import linen


def make_ppo_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
    value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
    activation: networks.ActivationFn = linen.swish) -> PPONetworks:
  """Make PPO networks with preprocessor."""
  # Tanh-squashed Gaussian; its param_size is 2 * action_size (location and scale).
  parametric_action_distribution = distribution.NormalTanhDistribution(
      event_size=action_size)

  policy_network = networks.make_policy_network(
      parametric_action_distribution.param_size,
      observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      hidden_layer_sizes=policy_hidden_layer_sizes,
      activation=activation)

  value_network = networks.make_value_network(
      observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      hidden_layer_sizes=value_hidden_layer_sizes,
      activation=activation)

  return PPONetworks(
      policy_network=policy_network,
      value_network=value_network,
      parametric_action_distribution=parametric_action_distribution)
```

We probably need to write our own version of `make_policy_network`.

</div>

<div class="alert alert-warning">

## Debug Log for Not Walking
1. Not a binding issue: rendering with random control forces shows the walker moving, and dropping it from a height works, so it is not attached to the arena. **Not an env definition issue**.
2. Not an observation-space mismatch: the observation space tells the agent everything about itself in the environment, and since the healthy reward works, the observation space matches (it has proprioceptive input). **Not a Brax env issue**.
3. Does the agent have enough information about its velocity and about walking forward? Does it know that? Brax treats the whole "Walker class" as the agent and defines rewards accordingly, with `self` referring to the artificial agent. However, we passed the arena, with the walker inside it, into that class.
    - The ant moves just as well whether we pass in only the ant model or the arena-bound model. The model passed into the brax class may not be the issue, since randomly initialized control forces reach the walker correctly.
    - The velocity (key to the forward reward) is computed from the center-of-mass movement of the torso, so we need to identify which center of mass we are looking at. The center-of-mass array has shape float32[17,3]; we need to find **the right center of mass**.

### Findings:
1. Reward assignment needs to be much more specific: we are dealing with a more complicated situation, where the model passed in is an arena and a walker combined.
    - They are not fully bound together, so we need to extract the correct subcomponents of each to assign rewards much more accurately (a COM check, maybe?).
    - Some rewards work, and the observation-space feedback is successful.
    - In Brax, the class is designed for a single artificial agent, so grabbing the whole COM and assigning the reward to `self` is much simpler. In this modified case, we use the class as a monitor of the full environment, so reward assignment must target the correct `sub_tree` section precisely.
2. The brax class is now a big monitor of both the `arena` and the walker, so we need to make sure rewards are assigned to the correct `sub_tree`.
3. This was the correct perspective: choosing the right COM now gives a non-zero distance and forward reward! Maybe we can use `subtree_linvel`.
    - The subtree linear velocity is the linear velocity of the center of mass of the subtree rooted at the specified body, i.e. of that body and all bodies attached below it.
    - Actually no: this requires using the physics module directly, which we don't have.
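
Picking "the right center of mass" amounts to picking the right row of `subtree_com`. A NumPy sketch of a forward-velocity term computed from one row, under stated assumptions: `TORSO_BODY = 1` and `DT = 0.05` are hypothetical values, and the (17, 3) arrays are toy data. Note that body 0 is MuJoCo's world body, whose `subtree_com` row is the center of mass of the entire model (arena included), so it is usually not the row you want for the walker.

```python
import numpy as np

DT = 0.05        # hypothetical control timestep (seconds)
TORSO_BODY = 1   # hypothetical index of the walker's root body in the combined model


def forward_velocity(com_before, com_after, body, dt):
    """Forward (+x) velocity of one body's subtree center of mass between two steps."""
    velocity = (com_after[body] - com_before[body]) / dt  # (3,) vector
    return float(velocity[0])


# subtree_com has shape (nbody, 3) = (17, 3) in the field list above.
com_before = np.zeros((17, 3))
com_after = np.zeros((17, 3))
com_after[TORSO_BODY, 0] = 0.01  # toy data: the torso subtree moved 1 cm forward

print(forward_velocity(com_before, com_after, TORSO_BODY, DT))  # ≈ 0.2
print(forward_velocity(com_before, com_after, 5, DT))           # 0.0 (a row that did not move here)
```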

</div>

***
# Stage 3: Vision Network using `network_base.py`
***

## General Structure
1. The top-level class is `PPONetworks`, which contains 3 networks.
2. `make_ppo_networks` instantiates all the networks using the base `networks.py` file and returns the wrapper class object.
3. `make_inference_fn` is the main training function:
    - BraxData is converted back into the form PPO expects here
    - the vision network is instantiated first to retrieve its parameters
    - these are concatenated with the velocity and position data and fed into PPO's `train.py`

* The `train` function in `train.py` is called with an extra parameter that feeds in the customized file.

**Do we need to step through the environment to get more "stepped" data for the vision input? Or are the parameters good enough?**

```python
# ppo.train calls make_policy, which returns the post-processed actions and an
# extras dictionary; we do the same here with the vision-derived parameters.
if deterministic:
    return ppo_networks.parametric_action_distribution.mode(visio_param), {}
raw_actions = parametric_action_distribution.sample_no_postprocessing(
    visio_param, key_sample)
log_prob = parametric_action_distribution.log_prob(visio_param, raw_actions)

postprocessed_actions = parametric_action_distribution.postprocess(raw_actions)
extras = {'log_prob': log_prob, 'raw_action': raw_actions}  # returned alongside the actions

# The resulting policy is then used to collect a rollout.
next_state, data = acting.generate_unroll(
    env,
    current_state,
    policy,
    current_key,
    unroll_length,
    extra_fields=('truncation',))
```

<div class="alert alert-warning">

## Modification Try 1
1. Currently using a fake image (random pixels) to test the pipeline.
2. The transitions seem to be okay, but there is a shape error: `ValueError: Incompatible shapes for broadcasting: shapes=[(128, 230400), (230427,)]`
    - Maybe caused by `env_state.obs.shape[-1]` in `train.py`, where the observation space defined in training is either **inconsistent (image + proprioceptive)**, or the **shape attribute** returned by the customized BraxData is wrong.
3. `DataClass.image` differs between `base.py` and `network_vision.py`:
    - Traced<ShapedArray(float32[230400])>with<DynamicJaxprTrace(level=3/0)>
    - Traced<ShapedArray(float32[128,230400])>with<DynamicJaxprTrace(level=3/0)>
    - This is due to `vmap` splitting the data into smaller per-environment pieces.
    - `vmap` treats a concatenated JAX array and the brax data class the same way: both are split.
    - The shape mismatch issue still exists!
    - Resolved this issue!
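
The broadcasting error pairs a batched slice of the image (`(128, 230400)`) with a vector sized for the full observation (`(230427,)`). Indexing the last axis with `...` keeps the same code valid with or without the batch axis that `vmap` adds or removes. A NumPy sketch with small stand-in sizes (12 image values and 3 proprioceptive values instead of 230400 and 27), assuming the image comes first in the observation:

```python
import numpy as np

IMAGE_SIZE = 12    # stands in for 230400
PROPRIO_SIZE = 3   # stands in for 27


def split_obs(obs):
    """Split along the LAST axis so it works for one env (N,) or a batch (B, N)."""
    vision = obs[..., :IMAGE_SIZE]
    proprio = obs[..., IMAGE_SIZE:IMAGE_SIZE + PROPRIO_SIZE]
    return vision, proprio


single = np.arange(IMAGE_SIZE + PROPRIO_SIZE, dtype=np.float32)  # shape (15,)
batch = np.tile(single, (4, 1))                                   # shape (4, 15)

for obs in (single, batch):
    vision, proprio = split_obs(obs)
    print(obs.shape, vision.shape, proprio.shape)
# (15,) (12,) (3,)
# (4, 15) (4, 12) (4, 3)

# Mixing a batched image slice with a full-size per-observation vector reproduces the error.
try:
    batch[:, :IMAGE_SIZE] * np.ones(IMAGE_SIZE + PROPRIO_SIZE)
except ValueError as err:
    print("broadcast error:", err)
```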

</div>

## PPO Model Architecture
This is the parameter tree of the policy network that is used, printed together with the observation normalizer's running statistics (the `Traced<ShapedArray(...)>with<DynamicJaxprTrace(level=1/0)>` wrappers are abbreviated to dtype and shape):

```python
RunningStatisticsState(
    mean=float32[230427],
    std=float32[230427],
    count=float32[],
    summed_variance=float32[230427],
)
{'params': {
    'hidden_0': {'bias': float32[256], 'kernel': float32[230427, 256]},
    'hidden_1': {'bias': float32[256], 'kernel': float32[256, 256]},
    'hidden_2': {'bias': float32[256], 'kernel': float32[256, 256]},
    'hidden_3': {'bias': float32[256], 'kernel': float32[256, 256]},
    'hidden_4': {'bias': float32[256], 'kernel': float32[256, 256]},
    'hidden_5': {'bias': float32[16], 'kernel': float32[256, 16]},
}}
```
This is where the `230427` comes from: it is the observation size (240 × 320 × 3 = 230400 flattened image values plus 13 position and 14 velocity values), so the first kernel alone is 230427 × 256. The final layer outputs 16 = 2 × 8 values, a location and a scale for each of the 8 actuators.
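
Since the flattened image goes straight into a dense layer, that first kernel dominates the network. A plain-Python parameter count from the layer widths in the tree above (weights and biases only, normalizer state ignored); the 64 × 64 × 3 line is an illustrative comparison using the smaller image size tried later, not a measured run:

```python
def dense_param_count(sizes):
    """Weights (n_in x n_out) plus biases (n_out) for each consecutive pair of layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


HIDDEN = [256] * 5   # hidden_0 ... hidden_4
ACTION_PARAMS = 16   # hidden_5 output

full = dense_param_count([230427, *HIDDEN, ACTION_PARAMS])
small = dense_param_count([64 * 64 * 3 + 27, *HIDDEN, ACTION_PARAMS])

print(full)               # 59256848
print(full * 4 / 2**20)   # ≈ 226 MiB of float32 weights
print(small)              # 3420176 with a 64 x 64 x 3 image instead
```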

<div class="alert alert-warning">

## Modification Try 2
1. The current implementation un-maps the JAX array (`_unpmap`) and then re-maps it (`_re_vmap`) after everything has been processed.
2. The whole pipeline works, but some of the basic brax training code uses the obs's `shape`, which is meant to belong to a **concatenated JAX array** whose shape changes under `vmap`. This mechanism breaks when we simply implement a `shape` attribute on the **BraxData class**.
    - `TypeError: cannot reshape array of shape (128,) (size 128) into shape [230427] (size 230427)`
    - This picks index 0 instead of index 1; the axes are flipped.
    - Actually, this may be because the shape tuple is not transformed by `vmap`; it needs to be.
    - `shape[-1]` is probably used to sidestep `vmap` issues and always get the observation size.
    - There is a type check: `TypeError: where requires ndarray or scalar arguments, got <class 'vnl_brax.data.BraxData'> at position 1.`
    - This occurs because the function uses `jp.where`, which only works on arrays.
3. Reduced the un-mapping/re-mapping problem by operating along `axis=1`.

</div>

## Report from Week 8 (Out->In Neural Network Structure)
Modularized Training:

1. Succeeded in building the pipeline, handling the `vmap` situation, and matching dimensions -> everything in `network.py` is good, except that the parameter sizes need more customization.
2. Need to find a better way to modify the `network.py` file:
    - customize the parameter sizes
    - feed the (27 proprioceptive + 16 vision activation) array right back into the observation space
3. Note there are 2 environments, `env_state` and `eval_env`; `eval_env` can be passed in separately.
    - There is a type check, due to the use of `jp.where()` later on, in `eval_env`.
    - Resolve this by using BraxData for the policy and then switching back to brax's original data form after making the policy?
    - `eval_env` is not instantiated; it is a pure `Walker` class.
    - This could be resolved by modifying the `unroll` function in `acting.py`, but that would cause issues when passing back to `network.py` again.
    - Changed the post-processed state's obs space back to the original obs space (needs modification later).
    - **Problem**:
    
```python
# this call goes into network.py, which needs BraxData while other code needs an ndarray: problematic
(final_state, _), data = jax.lax.scan(
    f, (env_state, key), (), length=unroll_length)
```
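
`jax.lax.scan` is what turns the per-step function into a rollout: it threads a carry through `length` calls of `f` and collects each call's second output. The plain-Python loop below follows the reference semantics given in the JAX documentation (minus stacking the outputs into arrays), with a hypothetical toy `f`. One consequence relevant to the problem above: in real `lax.scan`, the carry returned by `f` must keep the same pytree structure and types on every step, so the carry cannot switch between `BraxData` and a plain array mid-rollout.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def scan_reference(f: Callable, init: Any, xs: Optional[Sequence] = None,
                   length: Optional[int] = None) -> Tuple[Any, List[Any]]:
    """Loop semantics of jax.lax.scan (outputs kept as a list instead of stacked arrays).

    xs=None or an empty tuple means "no per-step input": f is called `length` times.
    """
    if xs is None or (isinstance(xs, tuple) and len(xs) == 0):
        xs = [xs] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def f(carry, unused_t):
    """Hypothetical step: the carry mimics (env_state, key) in the brax call above."""
    state, key = carry
    next_state = state + 1   # stand-in for env.step(...)
    next_key = key + 1       # stand-in for splitting the PRNG key
    return (next_state, next_key), next_state * 10  # second output is the per-step "data"


(final_state, _), data = scan_reference(f, (0, 100), (), length=3)
print(final_state, data)  # 3 [10, 20, 30]
```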

## Proposition from Week 8 (All-in-One Neural Network)
1. This is not a problem with the `BraxData` class, since we need to feed such results back in anyway. The key is to make sure we change both environments once the activation parameters are retrieved. We also need to consider whether these activations alone are enough, and whether the gradient can update them and flow back to the vision network.

2. We may need to consider a separate **conv_net** in a separate file that trains the CNN and then feeds it into the `obs_space`.
3. We need to reconsider the whole architecture of the network: **can we add conv layers directly, leave the rest as it is, and just specify the input? Not the output of one network as the input of another, but one network as a whole** -> (check the `.apply` functions in the basic network file to see how to route parts of the obs to particular layers) -> [all-in-one neural network](https://github.com/google/flax/blob/main/examples/ppo/models.py)
    - -> the current method requires far too many modifications
    - -> we may also adjust the BraxData class to make this easier: just use one array and slice it later
    - -> minimize touching brax's files

**The basic `network.py` file in brax builds its networks from Flax `linen.Module` classes, which take a JAX array as input; we can customize where the data goes by designing the data flow inside the class.**

> ##### Really great experience in understanding the whole flow of data, the complexity of the project, the **right level of abstraction**, **involvement in an ML pipeline**, and **seeing the collective effort**

**Start looking at this method when thinking about how the `params` get distributed by the `.apply` function.**

After building the network and then solving some integration issues with:
- input data dimensions,
- 2-dimensional slicing,
- parameter input tuning,
- reshaping image data into a form that `vmap` accepts,
- `vmap` data handling (which caused an unequal-split problem),

The pipeline seems to work! One iteration performed!

**After working with it for a while, you get a sense of what might be causing the actual bug (e.g. some array has only one dimension).**


<img src=./images/vision_encoder.png width=90%>

## Small Fixes with Shape
- `network_vision.py` is not imported directly by PPO's `train.py`; it is fed in through the `network_factory` argument!
```python
TypeError: cannot reshape array of shape (5, 485, 230427) (size 558785475) into shape (5, 240, 320, 3) (size 1152000)
```
- This array grows dynamically along its leading axes, which is why `[-1]` was previously used in `train.py` to access the obs size.
```python
TypeError: cannot reshape array of shape (5, 485, 230427) (size 558785475) into shape (-1, 240, 320, 3) because the product of specified axis sizes (230400) does not evenly divide 558785475
```
- Figure out a good shape, or maybe there is no need to reshape at all, since `vmap` is already applied!
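
The reshape fails because it tries to turn the whole observation (image plus 27 proprioceptive values, with leading unroll/batch axes) into images. Slicing the image part off the last axis first and keeping all leading axes avoids guessing the batch size. A NumPy sketch with small stand-in sizes, assuming the image occupies the first H × W × C entries of the last axis (swap the slice if `get_obs` puts proprioception first):

```python
import numpy as np

H, W, C = 8, 8, 3            # stands in for 240 x 320 x 3 (or 64 x 64 x 3)
PROPRIO = 27                 # proprioceptive values appended after the image
OBS = H * W * C + PROPRIO    # 219 here; 230427 in the real run


def image_from_obs(obs: np.ndarray) -> np.ndarray:
    """Keep all leading (unroll/batch) axes, slice the image part of the last axis, unflatten it."""
    image_flat = obs[..., : H * W * C]
    return image_flat.reshape(*obs.shape[:-1], H, W, C)


obs = np.zeros((5, 7, OBS), dtype=np.float32)  # e.g. (unroll, envs, obs)
print(image_from_obs(obs).shape)                # (5, 7, 8, 8, 3)

# Reshaping the whole observation fails: 5 * 7 * 219 is not a multiple of 8 * 8 * 3 = 192.
try:
    obs.reshape(-1, H, W, C)
except ValueError as err:
    print("reshape error:", err)
```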

## Convolution Structure
* Convolutional Layers: Three convolutional layers are applied in sequence to vision_data, each followed by a ReLU activation function. These layers progressively apply filters to extract features from the input images:

    - The first conv layer applies 32 filters of size (8, 8) with a stride of (4, 4).
    - The second conv layer applies 64 filters of size (4, 4) with a stride of (2, 2).
    - The third conv layer applies 64 filters of size (3, 3) with a stride of (1, 1).

* Flattening: After passing through the convolutional layers, vision_data is flattened into a two-dimensional array with the second dimension being 76800. This operation is crucial for transitioning from convolutional layers to fully connected layers.
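
The 76800 figure can be checked by hand. It equals 30 × 40 × 64, which is what a 240 × 320 input gives if every convolution uses `'SAME'` padding (the default padding of Flax's `nn.Conv`); with `'VALID'` padding the number would differ. A small pure-Python helper, assuming the standard output-size formulas for the two padding modes:

```python
def conv_out(size: int, kernel: int, stride: int, padding: str) -> int:
    """Spatial output size of one convolution along one axis."""
    if padding == "SAME":
        return -(-size // stride)             # ceil(size / stride)
    if padding == "VALID":
        return (size - kernel) // stride + 1
    raise ValueError(f"unknown padding {padding!r}")


def flattened_size(height, width, layers, padding):
    """Apply (features, kernel, stride) layers to an H x W image; return H' * W' * features."""
    for features, kernel, stride in layers:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
    return height * width * features


LAYERS = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]  # the three layers listed above

print(flattened_size(240, 320, LAYERS, "SAME"))   # 76800  (30 x 40 x 64)
print(flattened_size(240, 320, LAYERS, "VALID"))  # 59904  (26 x 36 x 64)
print(flattened_size(64, 64, LAYERS, "SAME"))     # 4096   (8 x 8 x 64)
```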

## New Vision Encoder Architecture
<img src=./images/vision_encoder_updated.png width=70%>

## Report from Week 9 (Memory Overload)
1. The pipeline works in general! Handling split batches of image data also works!
2. Encountered a memory issue.
3. Need to solve the rendering issue.
4. Consider reducing the number of parallel envs.
5. Consider replacing the `mjx.render` function with the newly updated version.

<img src=./images/gpu_overload.png width=90%>

- Some rendering tips [here](https://pytorch.org/rl/reference/generated/knowledge_base/MUJOCO_INSTALLATION.html) from Charles
- Switch to `jax.lax.scan` or other methods instead of Python for-loops
- Switch to mujoco renderer

<div class="alert alert-warning">

## Debugging for Memory Overload
- Incremental Development:
    - ConvNet shape doesn't seem to be too big
    - Checking without the ConvNet, using only proprioceptive data


**There is no for-loop issue and no reshape issue. Running just the proprioceptive data through the network (no ConvNet) suggests the problem is the original `obs_array` size, so we may want to reduce the image size a bit.**
- Consider 64 x 64 x 3 array?
- `this is out of full ppo network (5, 64, 256)` -> `this is out of full ppo network (5, 1, 1)`

```python
Peak buffers:
        Buffer 1:
                Size: 8.79GiB
                Operator: op_name="pmap(training_epoch)/jit(main)/while/body/transpose[permutation=(0, 2, 1, 3)]" source_file="/home/jovyan/Brax-Rodent-Run/train.py" source_line=110
                XLA Label: fusion
                Shape: f32[32,64,5,230427]
                ==========================

        Buffer 2:
                Size: 8.79GiB
                Operator: op_name="pmap(training_epoch)/jit(main)/while/body/transpose[permutation=(0, 2, 1, 3)]" source_file="/home/jovyan/Brax-Rodent-Run/train.py" source_line=110
                XLA Label: fusion
                Shape: f32[32,64,5,230427]
                ==========================

        Buffer 3:
                Size: 8.79GiB
                Operator: op_name="pmap(training_epoch)/jit(main)/while/body/while/body/jit(_take)/select_n" source_file="/home/jovyan/Brax-Rodent-Run/train.py" source_line=110
                XLA Label: fusion
                Shape: f32[1,2048,5,230427]
                ==========================
```
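The buffer sizes in the log follow directly from the shapes. A quick check, assuming float32 (4 bytes per element) and that the last dimension 230427 is 27 proprioceptive values plus a flattened 240x320x3 frame (230427 - 230400 = 27); `buffer_gib` is a hypothetical helper:

```python
import math

def buffer_gib(shape, bytes_per_elem=4):
    # float32 = 4 bytes per element; 1 GiB = 2**30 bytes
    return math.prod(shape) * bytes_per_elem / 2**30

proprio_dim = 27
full_obs = proprio_dim + 240 * 320 * 3     # full rendered frame
cropped_obs = proprio_dim + 64 * 64 * 3    # 64x64x3 crop

print(full_obs)                                        # 230427
print(round(buffer_gib((32, 64, 5, full_obs)), 2))     # 8.79
print(round(buffer_gib((32, 64, 5, cropped_obs)), 2))  # 0.47
```

Cropping to 64x64x3 shrinks each of these buffers by roughly 19x, which matches the observation that training fits in memory after reducing the image size.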


</div>

## Report from Week 10 (Memory Overload)
1. The above idea is very correct! -> when reducing the image size, training works!
2. mjx rendering with the rendering tips [here](https://pytorch.org/rl/reference/generated/knowledge_base/MUJOCO_INSTALLATION.html) from Charles works now! -> image sliced to 64x64x3
3. There were quite some big updates to Brax's MJX backend training loop, but the vision network seems to work pretty well!
4. Remember that in JAX the reward is not calculated from the observation space but from the data (pipeline state), which collects all environment information
    - [jax.lax.scan](https://github.com/google/brax/blob/b68d9387f8c0b05271e0a5fd4cff8f851a256995/brax/envs/base.py#L121) in the `pipeline_step` function is where the error comes from
    - Previously we had not dealt with the data side directly; no `step` function was called directly
    - This is not a problem with the image: the same error occurs even when only proprioceptive data is fed in!
5. According to issues on brax, the leaked tracer may come from not applying `jit` correctly
6. Imports from brax seem to behave incorrectly!
7. Both problems are resolved: it was a problem with brax's dependency on flax -> `flax` and `jax` need to be updated whenever `brax` is updated
8. Add a distance reward based on the distance from the origin rather than on velocity (velocity counts backward motion as negative) -> **use the Euclidean distance from the origin**
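A minimal sketch of one possible distance term: reward the per-step change in Euclidean distance from the origin in the x-y plane, so moving away in any direction is positive and drifting back toward the origin is negative. This is a hypothetical variant (names and weight are illustrative); the notebook's `Walker` only logs the absolute `jp.linalg.norm(com_after)` as a metric, and `jax.numpy` can replace NumPy here unchanged.

```python
import numpy as np

def distance_reward(com_before, com_after, weight=1.0):
    """Change in x-y distance from the origin between two center-of-mass positions."""
    d_before = np.linalg.norm(com_before[:2])   # ignore height (z)
    d_after = np.linalg.norm(com_after[:2])
    return weight * (d_after - d_before)

print(distance_reward(np.array([3.0, 0.0, 0.5]), np.array([3.0, 4.0, 0.5])))  # 2.0 (moved away)
print(distance_reward(np.array([3.0, 4.0, 0.5]), np.array([0.0, 0.0, 0.5])))  # -5.0 (moved back)
```

Dividing by the control timestep, as the forward reward does, would turn this into a radial speed.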

## Report from Week 11 (Larger Gap, Lower Performance)
1. Started making a bowl-escape class from dm_control for multiple-class training
2. Collision between `mjGEOM_HFIELD` and `mjGEOM_SPHERE` seems to be unsupported in mjx, although it is supported in MuJoCo (**heightfield collision is not supported in mjx yet**)
3. Now consider running two trainings in sequence with different gap lengths, then doing neural representation analysis?
    - COM changes
    - Solver changes
4. Considering increasing the gap length during training to see if gap behavior emerges
5. Added Euclidean distance
6. Figuring out the COM is really hard (should link to Brax, not dm_control)
7. The gap ant is not learning. Possible causes: (1) the gap is too big? (2) the Euclidean distance term gives too much punishment?
    - Now using the exact hyperparameters that the dm_control implementation uses.
8. **Seems like the CNN is not working too well**:
    - maybe try an edge detector?
    - maybe try a pretrained model?
    - maybe try a grayscale version of the full image?
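The grayscale and edge-detector ideas can be prototyped as a preprocessing step before the CNN. A minimal NumPy sketch, assuming (H, W, 3) RGB frames; the ITU-R BT.601 luma weights and the 3x3 Sobel kernels are standard choices picked here for illustration, not something the notebook already uses:

```python
import numpy as np

def to_grayscale(image: np.ndarray) -> np.ndarray:
    """Convert an (H, W, 3) RGB image to (H, W) luminance with ITU-R BT.601 weights."""
    return image[..., :3].astype(float) @ np.array([0.299, 0.587, 0.114])

def sobel_edges(gray: np.ndarray) -> np.ndarray:
    """Sobel gradient magnitude over the valid region, shape (H - 2, W - 2)."""
    kx = np.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])  # horizontal gradient
    ky = kx.T                                                            # vertical gradient
    h, w = gray.shape
    gx = np.zeros((h - 2, w - 2))
    gy = np.zeros((h - 2, w - 2))
    # Accumulate the 3x3 correlation as nine shifted, weighted copies of the image
    for i in range(3):
        for j in range(3):
            patch = gray[i:i + h - 2, j:j + w - 2]
            gx += kx[i, j] * patch
            gy += ky[i, j] * patch
    return np.hypot(gx, gy)

# Usage: a 64x64 frame, black on the left half and white on the right half
frame = np.zeros((64, 64, 3), dtype=np.uint8)
frame[:, 32:] = 255
edges = sobel_edges(to_grayscale(frame))
print(edges.shape)      # (62, 62)
print(edges[0, 29:33])  # nonzero only at the two columns next to the boundary
```

Grayscale alone cuts the input to a third of its size; the edge map additionally removes flat texture so the CNN sees mostly contours such as gap edges.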

***
# Training Configuration & Render Trajectory
***

This is a quick overview of some of the most important hyperparameters we need to tune. Also see the PPO documentation for more details (https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py).
1. `num_envs`:
    - The number of parallel environments the agent is trained on. In other words, how many instances of the registered environment are active and trained at the same time.
    - It produces a more robust and diverse policy
    - The agent learns faster because it gathers experience from all of them and uses it to update a single shared policy
2. `num_timesteps`:
    - Total number of interactions that will happen between the agent and environment
3. `eval_every`:
    - `eval_every` is the learning interval: how often the policy parameters are updated. This is the same as the horizon N in the simple PyTorch PPO we implemented
4. `episode_length`:
    - Episode length is the number of timesteps that make up one episode. Here it would be 1000 * 10_000 episodes.
5. `num_evals`:
    - The total number of evaluations; it should be num_timesteps / eval_every
6. `batch_size`:
    - Batch size is the number of samples drawn from the "replay buffer" as previous experience for each optimization step.
    - One batch is used for a single SGD step; we run SGD multiple times on multiple subsamples, or "minibatches"
    - Brax's PPO implementation has no replay buffer (PPO is on-policy and trains on freshly collected rollouts), but the idea is the same.
7. `num_minibatches`:
    - How many minibatches the collected data is split into
    - This is also how many SGD steps are run per pass over the data (Brax repeats the pass `num_updates_per_batch` times)
    - num_minibatches * batch_size = all data collected in one training step (counted in unrolls)
    - As the amount of collected data grows with num_minibatches fixed, each batch_size increases
8. `unroll_length`:
    - The number of timesteps to unroll in each environment. Instead of looking at a single timestep of the trajectory, unroll_length looks at n consecutive steps and collects all of that data at once for computational efficiency
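These parameters are coupled. A sketch of the bookkeeping as we read Brax's `ppo/train.py` (verify against the installed version): one training step collects `batch_size * num_minibatches` unrolls of `unroll_length` steps, spread evenly over `num_envs`, so that product must be divisible by `num_envs`. The helper name and example values are hypothetical.

```python
def ppo_data_budget(num_envs, batch_size, num_minibatches, unroll_length, action_repeat=1):
    """Hypothetical helper summarizing how much data one PPO training step collects."""
    total_unrolls = batch_size * num_minibatches
    # every environment must contribute the same number of unrolls
    assert total_unrolls % num_envs == 0, "batch_size * num_minibatches must be divisible by num_envs"
    return {
        "unrolls_per_env": total_unrolls // num_envs,
        "transitions_per_minibatch": batch_size * unroll_length,
        "env_steps_per_training_step": total_unrolls * unroll_length * action_repeat,
    }

print(ppo_data_budget(num_envs=2048, batch_size=1024, num_minibatches=32, unroll_length=20))
# {'unrolls_per_env': 16, 'transitions_per_minibatch': 20480, 'env_steps_per_training_step': 655360}
```

Raising `num_envs` with the other values fixed only reduces how many sequential unrolls each environment performs; the total data per training step stays the same.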

**The number of evaluations is proportional to the number of frames that can be rendered later on. However, the difference between evaluating at 500 steps and at 1000 steps is not just extra frames: each evaluation reflects a policy that has learned more, giving completely new insights.**

In [None]:
# There have been quite some big updates to Brax's MJX backend
class Walker(PipelineEnv):
  '''
  The rewards are highly customizable: this is where reward engineering happens
  '''
  def __init__(
      self,
      forward_reward_weight=5.0,
      ctrl_cost_weight=0.1,
      healthy_reward=0.5,
      terminate_when_unhealthy=True, # should be False when rendering
      healthy_z_range=(0.0, 1.0), # the healthy reward discourages falling; this is the contact_termination in dm_control
      train_reward=5.0,
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,):
    '''
    Defining initialization of the agent
    '''

    mj_model = physics.model.ptr
    # `physics` is defined in an earlier cell; `.ptr` is already a mujoco.MjModel (the same type brax loaded before)
    # Loading the XML directly creates a new MjModel that carries the full configuration, including mjtCone,
    # but the model passed in here may carry a different setting (e.g. an elliptic cone instead of the
    # default mjCONE_PYRAMIDAL), so the solver options are reset explicitly below

    # solver is an optimization system
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON #.mjSOL_CG
    mj_model.opt.cone = mujoco.mjtCone.mjCONE_PYRAMIDAL # Read documentation

    #Iterations for solver
    mj_model.opt.iterations = 2
    mj_model.opt.ls_iterations = 4

    sys = mjcf_brax.load_model(mj_model)

    # Default number of physics frames per control step is 3; it can be overridden via kwargs['n_frames']
    physics_steps_per_control_step = 3
    
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    # Parent initialization from the PipelineEnv class
    #super().__init__(model=mj_model, **kwargs)
    super().__init__(sys, **kwargs)

    # Store settings as attributes for later use
    self._model = mj_model
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._train_reward = train_reward
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (exclude_current_positions_from_observation)

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""

    #Creating random keys
    #rng = random number generator key used for the random initialization
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale

    #Vectors of generalized joint position in the configuration space
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )

    #Vectors of generalized joint velocities in the configuration space
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    #Reset everything
    obs = self._get_obs(data, jp.zeros(self.sys.nu)) #['proprioceptive']
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'train_reward': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics) # State is a big wrapper that contains all information about the environment

  def step(self, state: State, action: jp.ndarray) -> State: # push towards another State
    """Runs one timestep of the environment's dynamics."""
    #Previous Pipeline
    data0 = state.pipeline_state

    #Current pipeline state after one control step
    #pipeline_step advances the physics by n_frames and returns the new pipeline state
    data = self.pipeline_step(data0, action)

    #Forward (velocity) tracking based on center-of-mass movement
    com_before = data0.subtree_com[3]
    com_after = data.subtree_com[3]

    #print(data.data)
    
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    #Reaching the target location using Euclidean distance and considering moving backwards
    # def euclidean_distance(point1, point2):
    #   squared_diff = jp.square(point1 - point2)
    #   distance = jp.sqrt(jp.sum(squared_diff))
    #   return distance
    # distance_reward = self._distance_reward * euclidean_distance(com_before, com_after)
    # def negate_distance_reward(_):
    #   return -distance_reward
    # def identity_distance_reward(_):
    #   return distance_reward
    # condition = jp.dot(com_before, com_after) < 0
    # distance_reward = jax.lax.cond(condition, 
    #                           negate_distance_reward, 
    #                           identity_distance_reward, 
    #                           None)

    train_reward = self._train_reward * self.dt # constant per-step bonus, so surviving longer accumulates more reward

    #Height being healthy
    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)

    #Termination condition
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    #Control force cost
    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    #Feedback from env
    obs = self._get_obs(data, action)
    reward = forward_reward + train_reward + healthy_reward - ctrl_cost

    #print(obs)

    #Termination State
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        train_reward=train_reward,
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )
    return state.replace(pipeline_state=data, obs=obs, reward=reward, done=done)

  def _get_obs(self, data: mjx.Data, action: jp.ndarray) -> jp.ndarray:
    """Environment feedback: the walker's proprioceptive and vision data"""

    # Vision data, MuJoCo version
    # The passed-in `data` is the pipeline state (an mjx.Data object)
    renderer = mujoco.Renderer(model = self._model)

    # Syncing the render with the current state would need mjx.get_data (commented out below).
    # NOTE: a fresh MjData sits at the default pose (qpos0), so as written the image does not
    # track `data`; under jax.jit this Python-side rendering also runs only once, at trace time.
    #d = mjx.get_data(self._model, data)
    d = mujoco.MjData(self._model)

    mujoco.mj_forward(self._model, d)
    renderer.update_scene(d, camera=3) # cameras can also be selected by name
    image = renderer.render()
    image_jax = jax.numpy.array(image)
    print(f'image out of mujoco is {image_jax.shape}')
    # cam = mujoco.MjvCamera()

    # fake_image = jax.numpy.array(np.random.rand(64, 64, 3))
    # image_jax = fake_image.flatten() # fit into jp array

    # Center-crop a 64x64x3 patch (mujoco.Renderer defaults to a 240x320x3 frame)
    o_height, o_width, _ = image_jax.shape
    c_x,  c_y = o_width//2, o_height//2
    cropped_jax_image = image_jax[c_y-32:c_y+32, c_x-32:c_x+32, :]
    print(f'image cropped {cropped_jax_image.shape}')

    image_jax = cropped_jax_image.flatten()
    image_jax_noise = image_jax * 1e-12 # scales pixel values to ~0 (multiplicative, not additive noise)
    print(f'image cropped flattened {image_jax_noise.shape}')

    # Proprioceptive data
    position = data.qpos
    velocity = data.qvel
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    proprioception = jp.concatenate([position, velocity])
    
    # buffer_proprioception = jax.numpy.array(np.random.rand(27,))

    # num = (230427-(27+16)) # image size - (proprioreception + activation parameter)
    # buffer_vision = jax.numpy.array(np.random.rand(num,))

    # # for shape call in train.py of ppo
    # shape = jp.concatenate([proprioception,image_jax]).shape[0] # shape -1 is one number, give as shape tuple

    # full = jp.concatenate([proprioception,image_jax])
  
    # return BraxData(
    #   proprioception = proprioception,
    #   vision = image_jax,
    #   full=full,
    #   buffer_proprioception = buffer_proprioception,
    #   buffer_vision = buffer_vision,
    #   shape = (128, shape) # this works, but there is a type check in jax
    # )

    return jp.concatenate([proprioception, image_jax_noise])

# Register the environment under the name 'walker'
envs.register_environment('walker', Walker)
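To sanity-check the reward weights outside of JAX, the reward and termination logic of `Walker.step` can be mirrored as a plain NumPy function. This is a sketch that copies the `__init__` defaults; the function name, the inputs and `dt = 0.01` are hypothetical (in the class, `root_z` is `data.q[2]` and the COM is `data.subtree_com[3]`).

```python
import numpy as np

def walker_reward(com_before, com_after, action, root_z, dt,
                  forward_reward_weight=5.0, ctrl_cost_weight=0.1,
                  healthy_reward=0.5, train_reward=5.0,
                  healthy_z_range=(0.0, 1.0), terminate_when_unhealthy=True):
    """NumPy mirror of the reward/done logic in Walker.step (defaults copied from __init__)."""
    velocity = (com_after - com_before) / dt              # COM velocity
    forward_reward = forward_reward_weight * velocity[0]  # reward moving along +x
    train_bonus = train_reward * dt                       # constant per-step bonus
    min_z, max_z = healthy_z_range
    is_healthy = 1.0 if min_z <= root_z <= max_z else 0.0
    # Same branching as Walker.step: with termination on, the alive bonus is always paid
    alive = healthy_reward if terminate_when_unhealthy else healthy_reward * is_healthy
    ctrl_cost = ctrl_cost_weight * np.sum(np.square(action))
    reward = forward_reward + train_bonus + alive - ctrl_cost
    done = 1.0 - is_healthy if terminate_when_unhealthy else 0.0
    return reward, done

# Hypothetical numbers: dt = 0.01 and a 1 cm step along x
reward, done = walker_reward(np.array([0.0, 0.0, 0.3]), np.array([0.01, 0.0, 0.3]),
                             action=np.array([0.5, -0.5]), root_z=0.3, dt=0.01)
print(reward, done)  # reward is about 5.5 (5.0 forward + 0.05 bonus + 0.5 alive - 0.05 control); done is 0.0
```

With these defaults the forward term dominates, so a small control cost barely discourages large actions.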

## Rendering a Rollout

You need `ffmpeg` to render a rollout. Rendering on a Mac avoids the GL error seen on the Linux server, and rendering on the Mac now works just fine!
The environment actually works -> it initializes and steps successfully.
- For rendering on a Linux server, you need to run this pipeline on a GPU machine with the necessary Colab imports and installation, and handle the GL error.

The dm_control model does implement `camera`, but the camera names differ from those in brax's model, so we need to find the specific name for the specific model (cameras are usually defined in the walker's XML file rather than added to the MjModel directly, although they can be set there as well):

The `camera` argument of `render` actually defaults to None in brax, see this [line](https://github.com/google/brax/blob/a89322496dcb07ac5a7e002c2e1d287c8c64b7dd/brax/envs/base.py#L205)
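Since the cameras come from the XML, a quick way to list candidate names is to scan the MJCF for `<camera>` elements. A minimal standard-library sketch; the example XML is hypothetical, files pulled in through `<include>` are not followed, and dm_control may prefix attached names (for example with `walker/`):

```python
import xml.etree.ElementTree as ET

def camera_names(mjcf_xml: str):
    """Return the names of all <camera> elements in an MJCF string, in document order."""
    root = ET.fromstring(mjcf_xml)
    return [cam.get("name") for cam in root.iter("camera")]

example_xml = """
<mujoco>
  <worldbody>
    <camera name="overhead" pos="0 0 5"/>
    <body name="torso">
      <camera name="egocentric" pos="0 0 0.1"/>
      <geom type="sphere" size="0.1"/>
    </body>
  </worldbody>
</mujoco>
"""
print(camera_names(example_xml))  # ['overhead', 'egocentric']
```

Document order usually matches the integer indices used by `camera=1` or `camera=3`, but that is not guaranteed; on the compiled model, `mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_CAMERA, i)` for `i in range(model.ncam)` gives the authoritative index-to-name mapping.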

Rendering now works!

In [None]:
env = envs.get_environment(env_name='walker')

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# initialize the state
state = jit_reset(jax.random.PRNGKey(0))

#Creating an container for rollout states
rollout = [state.pipeline_state]

# grab a trajectory
for _ in range(500):
    ctrl = jax.numpy.array(np.random.uniform(-1,1, env.sys.nu))
    # -1 * jp.ones(env.sys.nu) is just gravity
    # positive is flattening
    # negative grabs tight
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

In [None]:
media.show_video(env.render(rollout,camera=1), fps=1.0 / env.dt)