In [14]:
import numpy as np

from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union, Optional
import wandb

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, MjxEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
import os

import yaml
from typing import List, Dict, Text

# ## TODO:
# 
# - Check the healthy z-range of the rodent. Now the training
#     - Check mj_data and how to pull out kinematics of the simulations
# - Check the `brax.envs` and how I can pass the custom parameters

# In[3]:


def load_params(param_path: Text) -> Dict:
    """Parse the YAML file at ``param_path`` and return its contents.

    Args:
        param_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file. The annotation
        says ``Dict``, but no check is made that the top-level YAML node
        is a mapping.
    """
    # Binary mode hands the raw bytes to the YAML loader.
    with open(param_path, "rb") as stream:
        return yaml.safe_load(stream)


# Notebook-level copy of the YAML configuration. `Rodent.__init__` does not
# read this variable; it loads the same file again on every instantiation.
params = load_params("params/params.yaml")

class Rodent(MjxEnv):
    """MJX environment built from the model at ``XML_PATH`` in ``params/params.yaml``.

    Per control step the reward is
    ``forward_reward + healthy_reward - ctrl_cost`` where

    * ``forward_reward`` is ``forward_reward_weight`` times the x-component
      of the change in ``subtree_com[1]`` divided by ``dt``;
    * ``healthy_reward`` is ``healthy_reward`` scaled by the health flag
      (``q[2]`` inside ``healthy_z_range``), or the unscaled constant when
      ``terminate_when_unhealthy`` is set;
    * ``ctrl_cost`` is ``ctrl_cost_weight`` times the sum of squared actions.

    Author's note: ``terminate_when_unhealthy`` is temporarily defaulted to
    ``False`` because with termination enabled the average episode length
    was too short (1 timestep).

    NOTE(review): the reset pose printed further down in this notebook has
    ``q[2]`` of about 0.0755, which is below the default lower bound of
    ``healthy_z_range`` (0.2). The default range therefore likely flags the
    model as unhealthy from the first step -- confirm and retune.
    """

    def __init__(
            self,
            forward_reward_weight=5,
            ctrl_cost_weight=0.1,
            healthy_reward=0.5,
            terminate_when_unhealthy=False,
            healthy_z_range=(0.2, 1.0),
            reset_noise_scale=1e-2,
            exclude_current_positions_from_observation=False,
            **kwargs,
    ):
        """Loads the MuJoCo model and stores the reward/observation settings.

        Args:
            forward_reward_weight: Multiplier on the x-velocity reward term.
            ctrl_cost_weight: Multiplier on the squared-action penalty.
            healthy_reward: Reward granted for being healthy.
            terminate_when_unhealthy: If true, ``done`` is raised when the
                health check fails.
            healthy_z_range: ``(min_z, max_z)`` bounds applied to ``q[2]``.
            reset_noise_scale: Half-width of the uniform noise used in
                ``reset``.
            exclude_current_positions_from_observation: If true, the first
                two ``qpos`` entries are dropped from the observation.
            **kwargs: Forwarded to ``MjxEnv.__init__``; ``n_frames`` defaults
                to 5 when the caller does not provide it.
        """
        config = load_params("params/params.yaml")
        mj_model = mujoco.MjModel.from_xml_path(config["XML_PATH"])

        # Conjugate-gradient solver with 6 solver and 6 line-search iterations.
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        # Five physics steps per control step unless the caller overrides it.
        kwargs.setdefault('n_frames', 5)

        super().__init__(model=mj_model, **kwargs)

        # Reward shaping.
        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._healthy_reward = healthy_reward
        # Health check / termination.
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        # Reset and observation options.
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )

    def reset(self, rng: jp.ndarray) -> State:
        """Returns a fresh state with uniformly perturbed ``qpos`` and ``qvel``.

        ``qpos`` is ``sys.qpos0`` plus noise and ``qvel`` is pure noise, both
        drawn from ``[-reset_noise_scale, reset_noise_scale]``. Reward, done
        and every metric start at zero.
        """
        rng, pos_key, vel_key = jax.random.split(rng, 3)

        noise = self._reset_noise_scale
        init_qpos = self.sys.qpos0 + jax.random.uniform(
            pos_key, (self.sys.nq,), minval=-noise, maxval=noise
        )
        init_qvel = jax.random.uniform(
            vel_key, (self.sys.nv,), minval=-noise, maxval=noise
        )

        pipeline_state = self.pipeline_init(init_qpos, init_qvel)

        # No action has been applied yet, so a zero action is passed along.
        obs = self._get_obs(pipeline_state.data, jp.zeros(self.sys.nu))
        reward, done, zero = jp.zeros(3)
        metric_names = (
            'forward_reward',
            'reward_linvel',
            'reward_quadctrl',
            'reward_alive',
            'x_position',
            'y_position',
            'distance_from_origin',
            'x_velocity',
            'y_velocity',
        )
        metrics = dict.fromkeys(metric_names, zero)
        return State(pipeline_state, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Advances the simulation by one control step and scores it."""
        prev_pipeline = state.pipeline_state
        next_pipeline = self.pipeline_step(prev_pipeline, action)

        # Velocity from the displacement of subtree_com[1] across the step.
        com_prev = prev_pipeline.data.subtree_com[1]
        com_next = next_pipeline.data.subtree_com[1]
        com_velocity = (com_next - com_prev) / self.dt
        forward_reward = self._forward_reward_weight * com_velocity[0]

        # Health flag: 1.0 while q[2] is within [min_z, max_z], else 0.0.
        z_low, z_high = self._healthy_z_range
        height = next_pipeline.q[2]
        out_of_range = jp.logical_or(height < z_low, height > z_high)
        is_healthy = jp.where(out_of_range, 0.0, 1.0)

        if self._terminate_when_unhealthy:
            # The episode ends when unhealthy, so the bonus is unconditional.
            healthy_reward = self._healthy_reward
            done = 1.0 - is_healthy
        else:
            healthy_reward = self._healthy_reward * is_healthy
            done = 0.0

        ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))
        reward = forward_reward + healthy_reward - ctrl_cost

        obs = self._get_obs(next_pipeline.data, action)

        state.metrics.update({
            'forward_reward': forward_reward,
            'reward_linvel': forward_reward,
            'reward_quadctrl': -ctrl_cost,
            'reward_alive': healthy_reward,
            'x_position': com_next[0],
            'y_position': com_next[1],
            'distance_from_origin': jp.linalg.norm(com_next),
            'x_velocity': com_velocity[0],
            'y_velocity': com_velocity[1],
        })

        return state.replace(
            pipeline_state=next_pipeline, obs=obs, reward=reward, done=done
        )

    def _get_obs(
            self, data: mjx.Data, action: jp.ndarray
    ) -> jp.ndarray:
        """Flattens positions, velocities, inertias and actuator forces.

        The result concatenates, in order: ``qpos`` (minus its first two
        entries when positions are excluded), ``qvel``, ``cinert[1:]``,
        ``cvel[1:]`` and ``qfrc_actuator``. External contact forces are not
        part of the observation. ``action`` is accepted but not used.
        """
        qpos = data.qpos
        if self._exclude_current_positions_from_observation:
            qpos = qpos[2:]

        components = [
            qpos,
            data.qvel,
            data.cinert[1:].ravel(),
            data.cvel[1:].ravel(),
            data.qfrc_actuator,
        ]
        return jp.concatenate(components)


# ## training loop

# In[5]:


# Registering the class with `brax.envs` is currently disabled:
# envs.register_environment('rodent', Rodent)

# # instantiate the environment
# env_name = 'rodent'
# Instead, the environment is constructed directly with its default arguments.
rodent_env = Rodent()

In [32]:
# Two environments fetched by name from `brax.envs`; later cells compare their
# observation sizes with the rodent's.
ant_env = envs.get_environment("ant")
humanoid_env = envs.get_environment("humanoid")

In [16]:
# define the jit reset/step functions (currently disabled; `env` is not
# defined anywhere in the visible notebook)
# jit_reset = jax.jit(env.reset)
# jit_step = jax.jit(env.step)

# initialize the PRNG key
# NOTE(review): the next cell creates the same key again, so this assignment
# looks redundant -- confirm before removing.
rng = jax.random.PRNGKey(0)

In [33]:
# initialize the state of each environment; all three resets share the same
# PRNG key (seed 0)
rng = jax.random.PRNGKey(0)
rodent_state = rodent_env.reset(rng)
ant_state = ant_env.reset(rng)
humanoid_state = humanoid_env.reset(rng)

## Notes

Since the observation size is extremely large for our rodent (see the sizes printed below), this might lead to a memory overflow during training. We might want to rethink which observations should be passed into the network.

In [34]:
# Length of each environment's observation vector right after reset,
# reported in the order rodent, ant, humanoid.
print("Rodent, Ant, Humanoid Observation shapes are:")
tuple(
    len(env_state.obs)
    for env_state in (rodent_state, ant_state, humanoid_state)
)

Rodent, Ant, Humanoid Observation shapes are:


(1260, 27, 244)

In [39]:
# Take the `.data` field of the rodent's pipeline state (the same object that
# `Rodent.reset` hands to `_get_obs`).
# NOTE(review): lower-numbered cells further down (In [31], In [25]) also use
# the name `data`; the In [25] traceback shows it was bound to a `State` object
# at that time. Confirm which object each cell expects on a fresh run.
data = rodent_state.pipeline_state.data

In [51]:
# Inspect `qfrc_actuator` at reset; it is the last component that
# `Rodent._get_obs` concatenates into the observation.
data.qfrc_actuator

Array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
       -3.71473469e-02,  1.04761586e-01, -1.62700385e-01,  6.28157929e-02,
        2.54413933e-01,  2.19533630e-02, -3.45028117e-02,  1.01627558e-01,
       -1.59005359e-01,  6.14019856e-02,  2.47480020e-01,  2.22522095e-02,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  

In [38]:
# Print every entry of the rodent observation, one value per line.
# NOTE(review): an earlier cell reports 1260 entries for this vector, so this
# floods the output; consider displaying a slice instead.
for i in rodent_state.obs:
    print(i)

0.040483385
-0.008967142
0.075515956
0.99994427
0.0017467347
0.00456508
0.009354588
0.0049890876
0.0016337344
0.008587049
0.0055944296
0.009385901
0.009719577
0.005601906
-0.00010457542
0.009367904
4.015211e-05
-0.007716243
-0.002734976
-0.0010094289
0.0077305045
0.003209548
0.0035746768
0.009618528
-0.008711902
-0.0034158942
-0.008774319
-0.0069099306
0.004924898
0.008875493
0.00018202513
0.0038369969
-0.009461475
-0.0018994613
-0.009180471
0.008077577
0.002224462
0.009942982
-0.00968115
0.006055489
-0.0069334744
-0.0026698112
-0.007760038
0.0091857575
-0.008716455
0.00023231469
0.0024697995
-0.0018483903
-0.0009160759
0.003836155
0.006834099
0.009972759
-0.00085709523
-0.0071314145
-0.002901101
0.0040631173
-0.0023284983
-0.0071229767
-0.000935724
0.009962346
0.00569658
0.0078045893
0.0027962402
-0.008259692
-0.0016155839
-0.0053429673
0.0073601343
-0.0004749298
-0.0010330556
-0.00022646692
0.0035443567
-0.008738608
-0.0013325643
-0.0074517224
0.006061701
-0.008160801
-0.008099616
-0

In [31]:
# Inspect the `q` field; `Rodent.step` reads `q[2]` from the pipeline state
# for its health check.
# NOTE(review): this cell ran as In [31], before the In [39] cell that binds
# `data` to `pipeline_state.data`. Presumably `data` was the pipeline state
# itself here -- confirm it still resolves on a top-to-bottom run.
data.q

Array([ 4.0483385e-02, -8.9671416e-03,  7.5515956e-02,  9.9994427e-01,
        1.7467347e-03,  4.5650802e-03,  9.3545876e-03,  4.9890876e-03,
        1.6337344e-03,  8.5870493e-03,  5.5944296e-03,  9.3859006e-03,
        9.7195767e-03,  5.6019062e-03, -1.0457542e-04,  9.3679037e-03,
        4.0152110e-05, -7.7162432e-03, -2.7349759e-03, -1.0094289e-03,
        7.7305045e-03,  3.2095481e-03,  3.5746768e-03,  9.6185282e-03,
       -8.7119024e-03, -3.4158942e-03, -8.7743187e-03, -6.9099306e-03,
        4.9248980e-03,  8.8754930e-03,  1.8202513e-04,  3.8369969e-03,
       -9.4614746e-03, -1.8994613e-03, -9.1804713e-03,  8.0775768e-03,
        2.2244621e-03,  9.9429823e-03, -9.6811503e-03,  6.0554892e-03,
       -6.9334744e-03, -2.6698112e-03, -7.7600381e-03,  9.1857575e-03,
       -8.7164547e-03,  2.3231469e-04,  2.4697995e-03, -1.8483903e-03,
       -9.1607589e-04,  3.8361549e-03,  6.8340991e-03,  9.9727586e-03,
       -8.5709523e-04, -7.1314145e-03, -2.9011010e-03,  4.0631173e-03,
      

In [25]:
# The non-position components that `Rodent._get_obs` concatenates.
# NOTE(review): the stored AttributeError below says `data` was a `State`
# object when this ran (In [25]). It likely predates the cell that binds
# `data` to `pipeline_state.data` -- re-run top to bottom to confirm it resolves.
[
            data.qvel,
            data.cinert[1:].ravel(),
            data.cvel[1:].ravel(),
            data.qfrc_actuator,
        ]

AttributeError: 'State' object has no attribute 'qvel'