# Documentation for Explaining Brax Rodent
Theoretical understanding comes from mathematical maturity, and implications come directly from data science maturity.

## Installation
0. First create a virtual environment specifically for this project (e.g. Python `venv`) and install all the dependencies into it. Run every step below from inside the environment in which you want to run your program.
1. `git clone "https://github.com/talmolab/Brax-Rodent-Run.git"`
2. `pip install -r requirements.txt`
    - mujoco
    - mujoco-mjx
    - brax
    - wandb
    - mediapy
3. `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`
    - specifically for jax
4. `pip install -U numba`

In [None]:
import numpy as np

from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
import wandb

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, MjxEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx

import yaml
from typing import List, Dict, Text
from tqdm import tqdm
from IPython.display import HTML, clear_output

## Load the Rodent Model

In [None]:
# Parameter loading
def load_params(param_path: Text) -> Dict:
    """Read a YAML file and return its parsed contents.

    Args:
        param_path: Path of the YAML file to read.

    Returns:
        Whatever `yaml.safe_load` produces for the file (annotated as a dict).
    """
    with open(param_path, "rb") as stream:
        return yaml.safe_load(stream)

# Load the project configuration once at notebook scope.
# NOTE: the name `params` is rebound later by the training cell (to a value
# returned by `train_fn`), so this YAML config is only available under this
# name until training runs.
params = load_params("params/params.yaml")

## Setting Rodent Env

In [None]:
class Rodent(MjxEnv):
    """MJX environment that rewards the rodent model for moving along +x.

    The per-step reward is ``forward_reward + healthy_reward - ctrl_cost``.
    """

    def __init__(
            self,
            forward_reward_weight=5,
            ctrl_cost_weight=0.1,
            healthy_reward=0.5,
            terminate_when_unhealthy=False,
            healthy_z_range=(0.2, 1.0), # from ant documentation
            reset_noise_scale=1e-2,
            exclude_current_positions_from_observation=False,
            **kwargs,
    ):
        """Load the MuJoCo model and store the reward/termination settings.

        Args:
            forward_reward_weight: Multiplier on the x-velocity reward.
            ctrl_cost_weight: Multiplier on the squared-action penalty.
            healthy_reward: Reward granted while the rodent counts as healthy.
            terminate_when_unhealthy: If true, `done` is set when unhealthy.
            healthy_z_range: (min, max) bounds that `q[2]` must stay within.
            reset_noise_scale: Half-width of the uniform noise used in reset.
            exclude_current_positions_from_observation: If true, the first
                two qpos entries are dropped from the observation.
            **kwargs: Forwarded to the `MjxEnv` constructor; `n_frames`
                defaults to 5 when not supplied.
        """
        # The XML path of the model comes from the project's YAML parameters.
        xml_path = load_params("params/params.yaml")["XML_PATH"]
        mj_model = mujoco.MjModel.from_xml_path(xml_path)

        # Solver settings applied before the model is handed to MjxEnv.
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        # Number of physics steps per control step, unless the caller
        # already chose one.
        default_n_frames = 5
        kwargs.setdefault('n_frames', default_n_frames)

        super().__init__(model=mj_model, **kwargs)

        # Reward shaping.
        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._healthy_reward = healthy_reward

        # Health / termination.
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range

        # Reset and observation options.
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )

    def reset(self, rng: jp.ndarray) -> State:
        """Return a fresh initial state with small uniform noise applied.

        Positions start at `sys.qpos0` plus noise; velocities are pure noise.
        Both are drawn from ``[-reset_noise_scale, reset_noise_scale]``.
        """
        _, qpos_key, qvel_key = jax.random.split(rng, 3)

        noise = self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(
            qpos_key, (self.sys.nq,), minval=-noise, maxval=noise
        )
        qvel = jax.random.uniform(
            qvel_key, (self.sys.nv,), minval=-noise, maxval=noise
        )

        pipeline_state = self.pipeline_init(qpos, qvel)

        # No action has been taken yet, so observe with an all-zero action.
        obs = self._get_obs(pipeline_state.data, jp.zeros(self.sys.nu))
        reward, done, zero = jp.zeros(3)

        # Every metric that `step` updates starts at zero.
        metric_names = (
            'forward_reward',
            'reward_linvel',
            'reward_quadctrl',
            'reward_alive',
            'x_position',
            'y_position',
            'distance_from_origin',
            'x_velocity',
            'y_velocity',
        )
        metrics = {name: zero for name in metric_names}
        return State(pipeline_state, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Advance the simulation by one control step and score it."""
        previous = state.pipeline_state

        # This is the call the wrapped environment exposes to PPO.
        current = self.pipeline_step(previous, action)

        # Velocity is the finite difference of subtree_com[1] over one step.
        com_before = previous.data.subtree_com[1]
        com_after = current.data.subtree_com[1]
        velocity = (com_after - com_before) / self.dt
        forward_reward = self._forward_reward_weight * velocity[0]

        # Healthy means q[2] lies inside the configured range.
        min_z, max_z = self._healthy_z_range
        height = current.q[2]
        out_of_range = jp.logical_or(height < min_z, height > max_z)
        is_healthy = jp.where(out_of_range, 0.0, 1.0)

        # With termination enabled the full bonus is paid every step;
        # otherwise the bonus is gated by the health flag.
        healthy_reward = (
            self._healthy_reward
            if self._terminate_when_unhealthy
            else self._healthy_reward * is_healthy
        )

        ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))
        reward = forward_reward + healthy_reward - ctrl_cost

        # The episode only ends on an unhealthy step when termination is on.
        if self._terminate_when_unhealthy:
            done = 1.0 - is_healthy
        else:
            done = 0.0

        obs = self._get_obs(current.data, action)

        state.metrics.update({
            'forward_reward': forward_reward,
            'reward_linvel': forward_reward,
            'reward_quadctrl': -ctrl_cost,
            'reward_alive': healthy_reward,
            'x_position': com_after[0],
            'y_position': com_after[1],
            'distance_from_origin': jp.linalg.norm(com_after),
            'x_velocity': velocity[0],
            'y_velocity': velocity[1],
        })

        return state.replace(
            pipeline_state=current, obs=obs, reward=reward, done=done
        )

    def _get_obs(
            self, data: mjx.Data, action: jp.ndarray
    ) -> jp.ndarray:
        """Build the flat observation vector for the rodent.

        Concatenates qpos (optionally without its first two entries), qvel,
        the flattened cinert and cvel rows from index 1 onward, and
        qfrc_actuator. `action` is accepted but not used.
        """
        qpos = data.qpos
        if self._exclude_current_positions_from_observation:
            qpos = qpos[2:]

        # external_contact_forces are excluded
        pieces = [
            qpos,
            data.qvel,
            data.cinert[1:].ravel(),
            data.cvel[1:].ravel(),
            data.qfrc_actuator,
        ]
        return jp.concatenate(pieces)

# Register the class with brax under the name 'rodent' so that it can be
# looked up by name through `envs.get_environment`.
envs.register_environment('rodent', Rodent)

## Instantiate Env

In [None]:
# Instantiate the environment registered under `env_name`. No keyword
# overrides are passed here, so the constructor defaults are used.
env_name = 'rodent'
env = envs.get_environment(env_name)

# Define the jit-compiled reset/step functions.
# "jit" stands for "just in time": jax.jit compiles a function so that repeated calls run faster.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [None]:
# Initialize the state and start the trajectory with it.
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# The control signal is identical on every step, so build it once instead of
# re-allocating the array inside the loop.
ctrl = -0.1 * jp.ones(env.sys.nu)

# Grab a short trajectory (tqdm displays a progress bar).
for _ in tqdm(range(10)):
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

# Render from the 'side' camera with one video frame per control step.
media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)

In [None]:
# Reset with the same seed and display the initial pose using brax's HTML
# renderer (a single-frame trajectory).
state = jit_reset(jax.random.PRNGKey(0))
HTML(html.render(env.sys, [state.pipeline_state]))

## Configuration
This is a quick overview of the most important hyperparameters that we need to tune. Also visit the PPO implementation for more details (https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py).
1. `num_envs`:
    - The number of parallel environments that the agent is trained on. In other words, it is how many instances of the registered environment are active and being trained on at the same time.
    - It creates a more robust and diverse policy
    - The agent learns more quickly because it gathers experience from all of them and just chooses the best policy
2. `num_timesteps`:
    - Total number of interactions that will happen between the agent and environment
3. `eval_every`:
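    - `eval_every` is the learning time: how often we update the policy parameters. This is the same as the horizon idea N in the simple PyTorch version of PPO we implemented. In the training cell below it is used to compute the number of evaluations as `num_timesteps / eval_every`.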
4. `episode_length`:
    - Episode length is the number of timesteps that constitute one episode. Here it is 1000, so `num_timesteps` = 1000 * 10_000 episodes.
5. `num_evals`:
    - The total number of evaluations; it should equal `num_timesteps / eval_every`
6. `batch_size`:
    - Batch size is the number of samples drawn from the "replay buffer" as previous experience each time the Bellman-equation optimization runs on 1 batch.
    - This 1 batch is used for 1 SGD step only; we run SGD multiple times with multiple subsamples or "mini batches"
    - The brax implementation of PPO doesn't have a replay buffer, but the idea is the same.
7. `num_minibatches`:
    - This is how many batches there are when all the data in the replay buffer is split into small chunks
    - This is also how many times stochastic gradient descent is run
    - num_minibatches * batch_size = all_data
    - as the replay buffer grows, num_minibatches stays the same and each batch_size increases

In [None]:
# Hyperparameters for this run. The training cell reads individual entries
# from this dict and also passes the whole dict to wandb.init as the run config.
config = {
    "env_name": env_name,
    "algo_name": "ppo",
    "task_name": "run",
    "num_envs": 2048,
    "num_timesteps": 10_000_000,
    "eval_every": 10_000,
    "episode_length": 1000,
    # Equals num_timesteps / eval_every. The training cell recomputes that
    # ratio itself rather than reading this entry.
    "num_evals": 1000,
    "batch_size": 512,
    "learning_rate": 3e-4,
    # NOTE(review): the environment is created with
    # envs.get_environment(env_name) and no keyword arguments, so this flag
    # is recorded here but not passed to the env — confirm that is intended.
    "terminate_when_unhealthy": False
}

## Train, Record on Wandb, and Save Models

In [None]:
# Bind the PPO hyperparameters once; the environment and the progress
# callback are supplied later, when `train_fn` is called.
train_fn = functools.partial(
    ppo.train,
    num_timesteps=config["num_timesteps"],
    # Floor division keeps the evaluation count exact for integer inputs
    # instead of round-tripping through a float.
    num_evals=int(config["num_timesteps"] // config["eval_every"]),
    reward_scaling=0.1,
    episode_length=config["episode_length"],
    normalize_observations=True,
    action_repeat=1,
    unroll_length=10,
    num_minibatches=8,
    num_updates_per_batch=4,
    discounting=0.98,
    learning_rate=config["learning_rate"],
    entropy_cost=1e-3,
    num_envs=config["num_envs"],
    batch_size=config["batch_size"],
    seed=0,
)

# Start a wandb run in the "vnl" project and attach the hyperparameter dict
# as the run's config.
run = wandb.init(project="vnl", config=config)

# Name the run "<env_name>_<task_name>_<algo_name>_brax".
wandb.run.name = f"{config['env_name']}_{config['task_name']}_{config['algo_name']}_brax"

def wandb_progress(num_steps, metrics):
    """Log one set of training metrics to wandb.

    Args:
        num_steps: Step count supplied by the caller; logged under the
            key "num_steps".
        metrics: Mapping of metric names to values.

    The caller's `metrics` mapping is left untouched: a copy that includes
    the extra "num_steps" entry is logged instead of mutating it in place.
    """
    wandb.log({**metrics, "num_steps": num_steps})
    
# Run training on `env`, reporting progress through wandb_progress.
# train_fn returns three values; the third is discarded.
# NOTE: this rebinds `params`, replacing the YAML config that was loaded
# under the same name earlier in the notebook.
make_inference_fn, params, _= train_fn(environment=env, progress_fn=wandb_progress) # wandb_progress is used directly as the progress function

#@title Save Model
# NOTE(review): hard-coded absolute path — confirm it exists and is writable
# on the machine running this notebook.
model_path = '/cps/brax_ppo_rodent_run'
model.save_params(model_path, params)