# Task 3: Reinforcement Learning

## Setup Code / Packages
If you are running the code locally, follow the instructions [here](../docs/setup.md).

If you are using Google Colab, make sure to run the two cells below to install all dependencies. You will also need to upload your XML scene to the `assets` folder.

In [None]:
# Install MuJoCo, MJX, and Brax.
!pip install mujoco
!pip install mujoco_mjx
!pip install brax

# Download the assets folder from the `onboarding` branch of the
# triton-droids/simulation repository. A sparse checkout is used so that only
# paths matching "assets/*" are checked out into the working tree.
!git init temp-repo
%cd temp-repo
!git remote add -f origin https://github.com/triton-droids/simulation.git
!git config core.sparseCheckout true
!echo "assets/*" >> .git/info/sparse-checkout
!git pull origin onboarding
%cd ..
# Move the downloaded assets next to the notebook and delete the temporary clone.
# NOTE(review): if ./assets already exists, `mv` puts the new folder inside it
# (./assets/assets); remove or rename the old folder before re-running.
!mv temp-repo/assets ./assets
!rm -rf temp-repo

In [None]:
# Configure MuJoCo for Colab: select the EGL rendering backend so frames can
# be rendered headlessly (no display server available).
import os
os.environ.update(MUJOCO_GL="egl")

## Part 1: Implementing the Pick-and-Place MJX Environment

In this part, you will implement the `PickAndPlace` environment using MuJoCo and MJX.

- `_init_env`: Initialize important environment variables such as the robot’s home pose, cube and bin locations, and body/joint IDs.
- `reset`: Reset the simulation state with randomized initial positions for the robot arm, cube, and bin. 
- `step`: Advance the simulation by one step given an action. Update the environment state, check whether the cube has been grasped, compute the reward, and determine whether the episode has ended.
- `_get_obs`: Define the observation space.
- `check_grasp`: Determine whether the cube is grasped by the robot.
- `check_success`: Define the success condition by checking if the cube is inside the bin boundaries.
- `_compute_reward`: Design a shaped reward that encourages approaching the cube, grasping it, and successfully placing it into the bin.

In [None]:
from brax.envs.base import PipelineEnv, State
from mujoco.mjx._src import support
from brax.io import mjcf
from jax import numpy as jp
from typing import Any
import mujoco
import jax

# Absolute path of the MJCF scene used by PickAndPlace, resolved against the
# current working directory.
XML_PATH = os.path.abspath(os.path.join("assets", "descriptions", "DropCubeInBinEnv.xml"))

class PickAndPlace(PipelineEnv):
    """MJX pick-and-place environment: grasp a cube and drop it into a bin.

    This class is a skeleton. The TODO sections must be completed before the
    environment can be reset or stepped: ``reset`` and ``step`` reference names
    (``obs``, ``reward``, ``done``, ``metrics``) that the TODOs are expected to
    define.
    """

    # Per-episode step budget. Once the step counter in ``state.info`` exceeds
    # this value, the episode is terminated and the counter restarts at 0.
    _MAX_EPISODE_STEPS = 500

    def __init__(self, xml_path: str, cfg, **kwargs: Any):
        #DO NOT MODIFY
        mj_model = mujoco.MjModel.from_xml_path(xml_path)

        sys = mjcf.load_model(
            mj_model
        )
        sys = sys.tree_replace(
            {
                "opt.timestep": cfg.sim.timestep,
                "opt.solver": cfg.sim.solver,
                "opt.iterations": cfg.sim.iterations,
                "opt.ls_iterations": cfg.sim.ls_iterations,
            }
        )
        kwargs["n_frames"] = cfg.action.n_frames
        kwargs["backend"] = "mjx"

        super().__init__(sys, **kwargs)

        self._init_env()

    def _init_env(self) -> None:
        """Cache the home pose, model dimensions, and frequently used IDs."""
        # Initial joint configuration from the "home" keyframe in the XML.
        self.init_q = jp.array(self.sys.mj_model.keyframe("home").qpos)

        self.nu = self.sys.nu  # number of actuators
        self.nq = self.sys.nq  # number of generalized positions
        self.nv = self.sys.nv  # number of generalized velocities

        # Body ID of the gripper frame, looked up through MuJoCo's public API
        # rather than the private ``mujoco.mjx._src.support`` module.
        self.tcp_body_id = mujoco.mj_name2id(
            self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY, "panda_hand"
        )

        #TODO: Store any useful environment variables


    def reset(self, rng: jp.ndarray) -> State:
        """Start a new episode.

        Args:
          rng: JAX PRNG key; stored in ``state.info["rng"]``.

        Returns:
          The initial ``State``. Its ``info`` dict starts with ``step == 0``
          and ``done == False``.
        """
        qpos = self.init_q.copy()
        qvel = jp.zeros(self.nv)

        # TODO: Randomize the initial positions of the robot arm, cube, and bin.


        pipeline_state = self.pipeline_init(qpos, qvel)

        state_info = {
            "rng": rng,
            "step": 0,
            "done": jp.array(False),
        }

        # TODO: Initialize the environment state information.


        return State(
            pipeline_state, obs, reward, done, metrics, state_info
        )

    def step(self, state: State, action: jp.ndarray) -> State:
        """Advance the simulation by one control step.

        Args:
          state: Current environment state.
          action: Control input forwarded to ``pipeline_step``.

        Returns:
          The updated ``State``. ``done`` is 1.0 when the task succeeded or the
          step budget was exceeded. On that same step, the step counter in
          ``state.info`` is reset to 0.
        """
        pipeline_state = self.pipeline_step(state.pipeline_state, action)

        success = self.check_success(pipeline_state)

        # TODO: Update the environment state after each step.

        state.info["step"] += 1

        # Episode is done if task is successful or max steps exceeded.
        # The timeout must be evaluated BEFORE the counter is reset below.
        # Otherwise the reset zeroes the counter first and the timeout can
        # never trigger termination.
        timeout = state.info["step"] > self._MAX_EPISODE_STEPS
        done = success | timeout
        state.info["step"] = jp.where(done, 0, state.info["step"])
        state.info["done"] = done

        return state.replace(
            pipeline_state = pipeline_state,
            obs = obs,
            reward = reward,
            done = done.astype(jp.float32),
        )

    def _get_obs(
        self,
        pipeline_state: State,
        info: dict[str, Any],
    ):
        """Build the observation dict for the agent.

        Returns a dict with ``"state"`` and ``"privileged_state"`` entries. With
        the default ``PPONetworksConfig``, these are the keys read by the
        policy and value networks respectively.
        """
        # TODO: Construct the observations for the agent.

        return {
            "state": obs,
            "privileged_state": obs
        }

    def check_success(self, pipeline_state: State):
        """Return whether the episode's success condition holds (TODO)."""
        # TODO: Define the success condition for an episode.
        pass

    def check_grasp(self, pipeline_state: State):
        """Return whether the cube is currently grasped (TODO)."""
        # TODO: Implement a condition to determine if the object
        # is grasped. Could use distance between gripper and object, contact forces, etc.
        pass


    def _compute_reward(self, pipeline_state: State, info: dict[str, Any]):
        """Return the shaped reward for the current step (TODO)."""
        # TODO: Implement the reward function.
        # HINT: How can we break this up into subtasks ?
        pass


## Using Configs in the Codebase
In our codebase, we use Hydra and dataclasses to manage configurations. This keeps experiment settings (like simulation parameters, observation sizes, or agent hyperparameters) separate from the actual environment and training code, making it easy to reproduce results or run variations of the same experiment.

The `PickAndPlaceConfig` defines all the key parameters needed for this task:
- **sim**: Physics engine settings such as timestep and solver iterations.
- **action**: How actions are repeated in the environment.
- **ppo_agent**: PPO hyperparameters.
- **ppo_networks**: Neural network configuration. 
- **seed**: Random seed for reproducibility.

In [3]:
from dataclasses import dataclass, field
from typing import Tuple

@dataclass
class PickAndPlaceConfig:
    """Configuration class for the MJX environment.

    Settings are grouped into nested dataclasses by concern. Every field has a
    default, so ``PickAndPlaceConfig()`` is a complete configuration.
    """

    @dataclass
    class SimConfig:
        """Physics settings written into the MJX model's ``opt`` fields."""

        timestep: float = 2e-3  # physics timestep (seconds)
        solver: int = 2  # MuJoCo mjtSolver value; 2 is mjSOL_NEWTON
        iterations: int = 5  # solver iterations
        ls_iterations: int = 8  # line-search iterations

    @dataclass
    class ActionConfig:
        """Action settings; ``n_frames`` is handed to the pipeline env."""

        n_frames: int = 5

    @dataclass
    class PPOAgentConfig:
        """Hyperparameters passed as keyword arguments to Brax's ``ppo.train``."""

        num_timesteps: int = 10_000_000
        num_evals: int = 100
        episode_length: int = 500
        unroll_length: int = 50
        num_minibatches: int = 32
        num_updates_per_batch: int = 4
        discounting: float = 0.97
        learning_rate: float = 3e-4
        entropy_cost: float = 0.02
        clipping_epsilon: float = 0.2
        num_envs: int = 1024
        batch_size: int = 256
        seed: int = 42
        normalize_observations: bool = True
        action_repeat: int = 1
        max_grad_norm: float = 1.0
        num_resets_per_eval: int = 1

    @dataclass
    class PPONetworksConfig:
        """Layer sizes and observation keys for ``make_ppo_networks``."""

        policy_hidden_layer_sizes: tuple[int, ...] = (256, 256, 256)
        value_hidden_layer_sizes: tuple[int, ...] = (256, 256, 256)
        policy_obs_key: str = "state"  # observation entry read by the policy
        value_obs_key: str = "privileged_state"  # entry read by the value net

    sim: SimConfig = field(default_factory=SimConfig)
    action: ActionConfig = field(default_factory=ActionConfig)
    ppo_agent: PPOAgentConfig = field(default_factory=PPOAgentConfig)
    ppo_networks: PPONetworksConfig = field(default_factory=PPONetworksConfig)
    seed: int = PPOAgentConfig.seed  # mirrors the PPO agent's default seed

config = PickAndPlaceConfig()  # default configuration used by later cells

Next, we create our training and evaluation environments. We also define a factory function for building PPO networks with the architecture and observation settings specified in the config.

We set things up this way because we will use Brax’s PPO implementation, which expects an environment and a networks factory to run training.

In [None]:
from brax.training.agents.ppo import networks as ppo_networks
import functools

# Two independent instances built from the same scene and config: one used
# for training rollouts and one for evaluation.
env = PickAndPlace(xml_path=XML_PATH, cfg=config)
eval_env = PickAndPlace(xml_path=XML_PATH, cfg=config)

# Pre-bind the network architecture and observation keys from the config so
# the result can be called in place of ``make_ppo_networks``.
_net_cfg = config.ppo_networks
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks,
    policy_hidden_layer_sizes=_net_cfg.policy_hidden_layer_sizes,
    value_hidden_layer_sizes=_net_cfg.value_hidden_layer_sizes,
    policy_obs_key=_net_cfg.policy_obs_key,
    value_obs_key=_net_cfg.value_obs_key,
)

We also define a checkpointing function, `policy_params_fn`. When it is passed to Brax's PPO trainer, it runs roughly every `num_timesteps / num_evals` environment steps.


In [None]:
import os 
from flax.training import orbax_utils
from brax.io import model
from orbax import checkpoint as ocp

def policy_params_fn(current_step: int, make_policy: Any, params: Any):
    """Checkpoint the training parameters reported at ``current_step``.

    Saves the whole ``params`` pytree with Orbax to
    ``<cwd>/checkpoints/step_<current_step>``. Inside that directory, it also
    writes a ``policy`` file containing ``(params[0], params[1].policy)`` via
    ``brax.io.model.save_params``.

    Args:
      current_step: Environment step count supplied by the trainer.
      make_policy: Policy constructor supplied by the trainer; unused here.
      params: Parameter pytree supplied by the trainer.
        NOTE(review): this assumes ``params`` is
        ``(normalizer_params, network_params)`` with a ``.policy`` attribute on
        the second element. The tuple passed to ``policy_params_fn`` differs
        between Brax versions, so confirm against the installed version.
    """
    del make_policy  # unused
    # Orbax rejects relative checkpoint directories, so resolve to an absolute path.
    path = os.path.abspath(os.path.join("checkpoints", f"step_{current_step}"))
    os.makedirs(path, exist_ok=True)
    orbax_checkpointer = ocp.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(params)
    orbax_checkpointer.save(path, params, force=True, save_args=save_args)
    policy_path = os.path.join(path, "policy")
    model.save_params(policy_path, (params[0], params[1].policy))

Lastly, we set up Weights & Biases to monitor training progress and log metrics in real time, and we pass the algorithm parameters to the training function.


In [None]:
from brax.training.agents.ppo import train as ppo
from dataclasses import asdict
import functools
import wandb
import time

run_name = "PickandPlace"
# Log the PPO hyperparameters to W&B as a plain dict.
wandb.init(
    project="OnBoarding",
    name=run_name,
    config=asdict(config.ppo_agent),
)


# PPO trainer with the config's hyperparameters. It also receives the custom
# network factory and checkpoint callback from the previous cells; without
# them, Brax uses its default network factory and never calls
# ``policy_params_fn``.
train_fn = functools.partial(
    ppo.train,
    network_factory=make_networks_factory,
    policy_params_fn=policy_params_fn,
    **asdict(config.ppo_agent),
)

times = [time.time()]

last_ckpt_step = 0
best_ckpt_step = 0
best_episode_reward = -float("inf")
last_video_step = 0

def progress(num_steps, metrics):
    """Trainer progress callback: log metrics and track the best eval reward.

    Args:
      num_steps: Environment steps completed so far.
      metrics: Metrics dict from the trainer. ``"eval/episode_reward"`` is read
        when present and treated as 0 when missing.
    """
    global best_episode_reward, best_ckpt_step, last_ckpt_step, last_video_step

    times.append(time.time())
    wandb.log(metrics, step=num_steps)

    last_ckpt_step = num_steps

    episode_reward = float(metrics.get("eval/episode_reward", 0))
    if episode_reward > best_episode_reward:
        best_episode_reward = episode_reward
        best_ckpt_step = num_steps
    # Reuse the value read above; indexing metrics directly would raise
    # KeyError when the eval reward is absent.
    print(f"{num_steps}: {episode_reward}")

We are ready to start training!

In [None]:
# Run PPO training. progress() receives the trainer's metrics as it goes.
make_inference_fn, params, _ = train_fn(
    environment=env,
    eval_env=eval_env,
    progress_fn=progress,
)

# times[0] is recorded before training and times[1] at the first progress
# callback; that first interval is reported as JIT (compilation) time.
jit_seconds = times[1] - times[0]
train_seconds = times[-1] - times[1]
print(f"time to jit: {jit_seconds}")
print(f"time to train: {train_seconds}")
print(f"best checkpoint step: {best_ckpt_step}")
print(f"best episode reward: {best_episode_reward}")

## Visualize Rollouts

In [None]:
import jax

# Compile reset/step and a deterministic version of the trained policy once,
# so the rollout loop below runs compiled code.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)
jit_inference_fn = jax.jit(
    make_inference_fn(params, deterministic=True)
)

In [None]:
import mediapy as media

rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1
n_steps = 200  # control steps recorded per episode

for _episode in range(n_episodes):
    # Use a fresh subkey for the reset instead of reusing the key that is
    # subsequently split for actions.
    reset_rng, rng = jax.random.split(rng)
    state = jit_reset(reset_rng)
    rollout.append(state)
    for _step in range(n_steps):
        act_rng, rng = jax.random.split(rng)
        ctrl, _ = jit_inference_fn(state.obs, act_rng)
        state = jit_step(state, ctrl)
        rollout.append(state)

render_every = 1
# PipelineEnv.render expects pipeline (physics) states, not environment
# States, so extract them before rendering.
frames = env.render([s.pipeline_state for s in rollout[::render_every]])
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)