# Task 2: Imitation Learning

### Setup Code / Packages
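If you are running the code locally, install the packages imported in this notebook: `gymnasium`, `mujoco`, `mplib`, `numpy`, `scipy`, `matplotlib`, `torch`, `mani_skill` (used for saving videos) and `IPython`.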

If you are using Google Colab, make sure to run the cell below to install all dependencies.

In [20]:
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
import numpy as np
from scipy.spatial.transform import Rotation as R
import mujoco
import mplib
import matplotlib.pyplot as plt


## Part 0: Creating a Standardized Environment with Gymnasium
In this part, you will implement the `TrajEnv` class to wrap a MuJoCo simulation in a Gymnasium interface, providing a consistent way to reset the environment, step through actions, and collect observations. 

Your task is to decide which observations are useful as states in a behavior cloning setting. You will also implement the `is_grasped` property to define when the robot has successfully grasped the object and the `terminated` property to specify when the cube is in the bin. The `reset_model` method should randomize initial arm and object positions to encourage generalization.

In [6]:
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
import numpy as np
import mujoco


class TrajEnv(MujocoEnv):
    """Gymnasium environment wrapping a MuJoCo pick-and-place scene.

    This is an exercise template: the observation-space shape, the grasp and
    termination checks, ``_get_obs`` and the reset randomization are TODOs.
    """

    # Render modes this environment advertises to Gymnasium.
    metadata = {"render_modes": ["human", "rgb_array", "depth_array"]}

    def __init__(self, xml_file: str, frame_skip: int = 5, **kwargs):
        """Load the MuJoCo scene and look up body ids.

        Args:
            xml_file: Path to the MuJoCo XML model.
            frame_skip: Passed to ``MujocoEnv``; ``step`` forwards it to
                ``do_simulation`` on every call.
            **kwargs: Forwarded to ``MujocoEnv`` (e.g. ``render_mode``,
                ``camera_name``).
        """
        # TODO: Define the observation space and store any useful state variables.
        # NOTE: the data-collection cell also reads ``cube_body_id``,
        # ``bin_body_id`` and ``tcp_body_id`` from the env, so store those here too.
        observation_space = Box(
            low=-np.inf, high=np.inf, shape=(000,), dtype=np.float32  # TODO: set correct shape (000 is a placeholder, i.e. 0)
        )

        super().__init__(
            xml_file,
            frame_skip,
            observation_space=observation_space,
            **kwargs,
        )

        # Body id of the arm base; the collection loop uses its pose as the planner base pose.
        self.panda_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, 'panda_arm')
        

    
    @property
    def is_grasped(self) -> bool:
        """Whether the object is currently held by the gripper (to be implemented)."""
        # TODO: Implement a condition to determine if the object
        # is grasped. Could use distance between gripper and object, contact forces, etc.
        ...

    @property
    def terminated(self) -> bool:
        """Whether the episode is over, i.e. the cube is in the bin (to be implemented)."""
        # TODO: Define the termination condition for an episode.
        ...
    
    def _get_obs(self) -> np.ndarray:
        """Return the current observation vector (to be implemented).

        Its length must match the ``shape`` of the observation space built in ``__init__``.
        """
        # TODO: Return the observations selected for the environment state
        # Example: joint positions, velocities, end-effector pose, object positions
        ...

    def step(self, action) -> tuple[np.ndarray, None, bool, bool, None]:
        """Apply ``action`` via ``do_simulation`` and return the Gymnasium 5-tuple.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``reward`` and
            ``info`` are ``None`` and ``truncated`` is always ``False``.
        """
        # DO NOT MODIFY THIS FUNCTION
        self.do_simulation(action, self.frame_skip)
        obs = self._get_obs()
        terminated = self.terminated

        if self.render_mode == "human":
            self.render()

        # truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
        # NOTE(review): when the env is built directly (not via gym.make) nothing truncates
        # episodes; also, Gymnasium's API expects ``info`` to be a dict rather than None.
        return obs, None, terminated, False, None

    def reset_model(self) -> np.ndarray:
        """Reset to the 'home' keyframe (plus TODO randomization) and return the observation."""
        # Joint configuration stored in the model's 'home' keyframe.
        # NOTE(review): MuJoCo stores qpos as float64; this float32 cast drops precision.
        self.default_qpos = np.array(
                self.model.keyframe('home').qpos,
                dtype=np.float32
            )

        qpos = self.default_qpos.copy()
        qvel = np.zeros(self.model.nv)

        # TODO: Randomize initial positions for the arm (9 joints) and objects
        # to create diverse starting conditions. Ensure valid states and objects are not overlapping.
        ...

        self.set_state(qpos, qvel)
        return self._get_obs()



To understand what is happening in the environment, you can call `env.render()` to visualize it.  
Since we set `render_mode="rgb_array"`, `env.render()` will return an RGB array that can be displayed.  

> **Note:** If the default view of the environment is not ideal, you may need to adjust the camera.  
> You can do this by defining a camera in your MuJoCo XML file and passing its name to the `MujocoEnv` constructor via the `camera_name` argument.


In [None]:
from gymnasium.wrappers import TimeLimit
import matplotlib.pyplot as plt


env = TrajEnv("", 20, render_mode='rgb_array')
image = env.render() 
image = image
plt.imshow(image)


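With these core functions we can build the basic interaction loop: a reset followed by steps until either `terminated` or `truncated` is `True`. We can also record a video of the rollout. Note that `truncated` only becomes `True` when the environment is wrapped in a time limit (e.g. `gymnasium.wrappers.TimeLimit`); otherwise the loop only stops once the task is solved.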

In [None]:
# Roll out uniformly random actions, grabbing a frame after every step.
obs, _ = env.reset()
images = [env.render()]
done = False
while not done:
    random_action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(random_action)
    images.append(env.render())
    # The episode ends on success (terminated) or when the time limit hits (truncated).
    done = terminated or truncated

# Write the collected frames to videos/example.mp4.
from mani_skill.utils.visualization import images_to_video
images_to_video(images, output_dir="videos", video_name="example", fps=20)

In [None]:
from IPython.display import Video
Video("./videos/example.mp4", embed=True, width=640) # Watch our replay

## Part 1: Collect your Imitation Learning Dataset


In [21]:
# ===== Setup =====
# Simulation environment used for data collection (20 physics steps per env.step).
env = TrajEnv(
    xml_file="/home/anthony-roumi/Desktop/sim_onboarding/descriptions/DropCubeInBinPandaEnv.xml",
    frame_skip=20,
)

# TODO: Define the link and joint names for the planner. How can we do this using MuJoCo?
link_names = []
joint_names = []

# Screw-motion planner for the 7-DOF arm; velocity and acceleration are both capped at 0.8 per joint.
planner = mplib.Planner(
    urdf='assets/robots/panda/panda_v2.urdf',
    srdf='assets/robots/panda/panda_v2.srdf',
    user_link_names=link_names,
    user_joint_names=joint_names,
    move_group="panda_hand_tcp",
    joint_vel_limits=np.full(7, 0.8),
    joint_acc_limits=np.full(7, 0.8),
)

# Helper: execute planned motion and record data
def move_to_pose(pose, gripper, episode_data):
    """Plan a screw motion to ``pose``, execute it, and record (obs, action) pairs.

    Args:
        pose: Target pose as a flat array, position followed by quaternion
            (assumed to be in the world frame once ``planner.set_base_pose`` was called).
        gripper: Gripper actuator command held constant along the whole motion.
        episode_data: List that ``(obs, action)`` tuples are appended to in place.

    Returns:
        The last commanded arm joint positions, or ``None`` if planning failed.
    """
    result = planner.plan_screw(
        pose,
        env.data.qpos[:9],
        time_step=env.dt,
    )
    # mplib reports the outcome under result["status"]; "Success" is never a dict key,
    # so testing `'Success' in result` rejected every plan.
    status = result.get("status") if isinstance(result, dict) else result
    if status != "Success":
        print(f"Planner failed: {result}")
        return None  # Signal failure

    # Fallback if the planned path is empty: keep the current arm configuration.
    # NOTE(review): assumes qpos[:7] are the arm joints (qpos[:9] is handed to the planner above).
    pos = env.data.qpos[:7].copy()
    for pos in result["position"]:
        action = np.concatenate([pos, [gripper]])
        obs = env._get_obs()  # observation *before* the action is applied
        env.step(action)
        episode_data.append((obs, action))
    return pos

In [None]:
import pickle  # used for the dataset dump below; Part 2 only imports it later

trajectories = []
num_episodes = 10_000
success_count = 0
ep = 0


def _wxyz_to_rotation(quat_wxyz):
    """Build a SciPy Rotation from a scalar-first (w, x, y, z) quaternion.

    MuJoCo's ``xquat`` and the planner poses are scalar-first, whereas
    ``R.from_quat`` expects scalar-last (x, y, z, w) by default.
    """
    w, x, y, z = quat_wxyz
    return R.from_quat([x, y, z, w])


def _rotation_to_wxyz(rotation):
    """Return ``rotation`` as a scalar-first (w, x, y, z) quaternion array."""
    x, y, z, w = rotation.as_quat()
    return np.array([w, x, y, z])


def _target_quat(tcp_quat, reference_quat):
    """Gripper orientation (w, x, y, z) for approach poses relative to ``reference_quat``.

    NOTE(review): this keeps the original composition ref * (tcp^-1 * ref);
    confirm it is the intended target orientation.
    """
    r_ref = _wxyz_to_rotation(reference_quat)
    r_tcp_rel = _wxyz_to_rotation(tcp_quat).inv() * r_ref
    return _rotation_to_wxyz(r_ref * r_tcp_rel)


def _hold_and_actuate_gripper(arm_qpos, gripper, num_steps, episode_data):
    """Hold the arm at ``arm_qpos`` for ``num_steps`` steps while commanding ``gripper``."""
    for _ in range(num_steps):
        action = np.concatenate([arm_qpos, [gripper]])
        obs = env._get_obs()
        env.step(action)
        episode_data.append((obs, action))


while ep < num_episodes:
    env.reset()
    episode_data = []

    # IDs for objects
    panda_id = env.panda_id
    cube_body_id = env.cube_body_id
    bin_body_id = env.bin_body_id
    tcp_body_id = env.tcp_body_id

    # this sets the planner object up such that you can plan with poses in the world frame.
    planner.set_base_pose(np.concatenate([env.data.xpos[panda_id], env.data.xquat[panda_id]]))

    # Snapshot the initial poses: env.data arrays are live views that change as the sim steps.
    cube_pose = env.data.xpos[cube_body_id].copy()
    cube_quat = env.data.xquat[cube_body_id].copy()
    bin_pose = env.data.xpos[bin_body_id].copy()
    bin_quat = env.data.xquat[bin_body_id].copy()
    tcp_quat = env.data.xquat[tcp_body_id].copy()

    try:
        # ==== Motion Sequence ====
        # 1. Move above cube
        grasp_quat = _target_quat(tcp_quat, cube_quat)
        above_cube = np.concatenate([cube_pose + np.array([0, 0, 0.1]), grasp_quat])
        last_pos = move_to_pose(above_cube, 255, episode_data)
        if last_pos is None:
            raise RuntimeError("Skipping episode due to planner failure at above_cube")

        # 2. Move down to cube
        to_cube = np.concatenate([cube_pose, grasp_quat])
        last_pos = move_to_pose(to_cube, 255, episode_data)
        if last_pos is None:
            raise RuntimeError("Skipping episode due to planner failure at to_cube")

        # 3. Close gripper
        _hold_and_actuate_gripper(last_pos, 0, 50, episode_data)

        # 4. Move above bin
        place_quat = _target_quat(tcp_quat, bin_quat)
        above_bin = np.concatenate([bin_pose + np.array([0, 0, 0.2]), place_quat])
        last_pos = move_to_pose(above_bin, 0, episode_data)
        if last_pos is None:
            raise RuntimeError("Skipping episode due to planner failure at above_bin")

        # 5. Move down to bin
        to_bin = np.concatenate([bin_pose + np.array([0, 0, 0.1]), place_quat])
        last_pos = move_to_pose(to_bin, 0, episode_data)
        if last_pos is None:
            raise RuntimeError("Skipping episode due to planner failure at to_bin")

        # 6. Open gripper
        _hold_and_actuate_gripper(last_pos, 255, 25, episode_data)

    except RuntimeError as e:
        print(e)
        print("Resetting environment and skipping this trajectory.")
        print('='*100)
        continue

    # Only successful demonstrations (cube ends up in the bin) are kept.
    if env.terminated:
        trajectories.append(episode_data)
        print(f"Episode {ep} completed successfully.")
        success_count += 1
    # Every fully executed attempt counts, successful or not; planner failures are retried.
    ep += 1

# ===== Save dataset =====
with open("pick_place_dataset.pkl", "wb") as f:
    pickle.dump(trajectories, f)

print(f"Success rate: {success_count / num_episodes}")
print(f"Total trajectories: {len(trajectories)}")
env.close()

## Part 2: Dataset Class

In [None]:
import torch
from torch.utils.data import Dataset
import pickle

class TrajectoryDataset(Dataset):
    """Flatten recorded expert trajectories into individual (observation, action) pairs.

    Args:
        data: List of trajectories as produced by the collection loop; each
            trajectory is a list of ``(obs, action)`` tuples of array-likes.
    """

    def __init__(self, data):
        pairs = [pair for trajectory in data for pair in trajectory]
        if pairs:
            observations, actions = zip(*pairs)
            # Stack once up front so __getitem__ is a cheap tensor index.
            self.observations = torch.as_tensor(np.stack(observations), dtype=torch.float32)
            self.actions = torch.as_tensor(np.stack(actions), dtype=torch.float32)
        else:
            self.observations = torch.empty((0, 0), dtype=torch.float32)
            self.actions = torch.empty((0, 0), dtype=torch.float32)

    def __len__(self):
        """Total number of (observation, action) pairs across all trajectories."""
        return len(self.observations)

    def __getitem__(self, idx):
        """Return the observation and action at flat index ``idx``."""
        return self.observations[idx], self.actions[idx]


dataset_file = "pick_place_dataset.pkl"
with open(dataset_file, "rb") as f:
    data = pickle.load(f)
dataset = TrajectoryDataset(data)

## Part 3: Define Policy Network


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Per-actuator [low, high] control limits, one row per action dimension
# (presumably 7 arm joints followed by the 0-255 gripper command).
actuator_ctrl_range = torch.tensor([[-2.8973e+00,  2.8973e+00],
       [-1.7628e+00,  1.7628e+00],
       [-2.8973e+00,  2.8973e+00],
       [-3.0718e+00, -6.9800e-02],
       [-2.8973e+00,  2.8973e+00],
       [-1.7500e-02,  3.7525e+00],
       [-2.8973e+00,  2.8973e+00],
       [ 0.0000e+00,  2.5500e+02]])

class GaussianActor(nn.Module):
    """Diagonal-Gaussian policy whose mean is squashed into the actuator control ranges.

    Args:
        sample_obs: Example observation; its last dimension sets the input size.
        sample_act: Example action; its last dimension sets the output size.
        ctrl_ranges: ``(output_dim, 2)`` array-like of per-dimension ``[low, high]`` limits.

    Raises:
        ValueError: If ``ctrl_ranges`` does not have exactly one row per action dimension.
    """

    def __init__(self, sample_obs, sample_act, ctrl_ranges=actuator_ctrl_range):
        super().__init__()

        self.input_dim = sample_obs.shape[-1]
        self.output_dim = sample_act.shape[-1]

        ranges = torch.as_tensor(ctrl_ranges, dtype=torch.float32).detach().clone()
        if tuple(ranges.shape) != (self.output_dim, 2):
            raise ValueError(
                f"ctrl_ranges must have shape ({self.output_dim}, 2), got {tuple(ranges.shape)}"
            )
        # Non-persistent buffers follow actor.to(device) but stay out of state_dict,
        # so checkpoints saved by the previous version still load with strict=True.
        self.register_buffer("ctrl_low", ranges[:, 0].contiguous(), persistent=False)
        self.register_buffer("ctrl_high", ranges[:, 1].contiguous(), persistent=False)

        # shared backbone
        self.net = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.LayerNorm(256),   # helps stability
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )

        # mean head
        self.mean_head = nn.Linear(256, self.output_dim)
        # log std head (state-dependent log standard deviation per action dimension)
        self.log_std_head = nn.Linear(256, self.output_dim)

        # initialize
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
        nn.init.xavier_uniform_(self.mean_head.weight)
        nn.init.zeros_(self.mean_head.bias)
        nn.init.xavier_uniform_(self.log_std_head.weight)
        nn.init.zeros_(self.log_std_head.bias)

    def forward(self, x):
        """Return ``(mean, log_std)``; ``mean`` always lies within ``[ctrl_low, ctrl_high]``."""
        h = self.net(x)

        mean_raw = self.mean_head(h)
        # std = exp(log_std) is therefore confined to [e^-5, e^2] ~= [0.0067, 7.39].
        log_std = self.log_std_head(h).clamp(-5, 2)

        # squash mean into [-1, 1], then map affinely onto the actuator ranges
        mean_squashed = torch.tanh(mean_raw)
        mean = self.ctrl_low + 0.5 * (mean_squashed + 1.0) * (self.ctrl_high - self.ctrl_low)

        return mean, log_std

    def sample(self, x):
        """Draw a reparameterized action sample; returns ``(action, distribution)``."""
        mean, log_std = self.forward(x)
        std = log_std.exp()
        dist = torch.distributions.Normal(mean, std)
        action = dist.rsample()   # reparameterized sample
        return action, dist


## Part 4: Training Loop


In [22]:
def eval_policy(eval_env, actor, num_episodes=5):
    """Roll out the actor's mean action and return the fraction of successful episodes.

    An episode counts as a success when it ends with ``terminated=True``.

    Args:
        eval_env: Gymnasium-style env; it should be time-limited so that
            episodes that never succeed still end via ``truncated``.
        actor: ``nn.Module`` returning ``(mean, log_std)``; only the mean is executed.
        num_episodes: Number of evaluation episodes.

    Returns:
        Success rate in ``[0, 1]`` (``0.0`` when ``num_episodes`` is 0).
    """
    # Run inference on whichever device holds the actor's weights.
    device = next(actor.parameters()).device
    success = 0
    for ep in range(num_episodes):
        obs, _ = eval_env.reset()
        terminated, truncated = False, False

        # Stop on success as well as on the time limit: stepping past `terminated`
        # is undefined in Gymnasium and never ends for an env without a TimeLimit.
        while not (terminated or truncated):
            with torch.no_grad():
                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
                action, _ = actor(obs_t)
            obs, reward, terminated, truncated, info = eval_env.step(action.cpu().numpy())

        if terminated:
            success += 1

        print(f"Episode {ep+1}: Success={terminated}")
    return success / num_episodes if num_episodes else 0.0

In [None]:
def nll_loss(actor, obs, expert_actions):
    """Mean negative log-likelihood of ``expert_actions`` under the actor's Gaussian policy.

    The actor must return ``(mean, log_std)``; action dimensions are treated as
    independent, so per-dimension log-probabilities are summed over the last axis.
    """
    mean, log_std = actor(obs)
    policy = torch.distributions.Normal(mean, torch.exp(log_std))
    per_sample_log_prob = policy.log_prob(expert_actions).sum(dim=-1)
    return -per_sample_log_prob.mean()


# NOTE(review): `actor`, `dataloader`, `device` and `eval_env` are not created in any
# cell shown here; define them before running this cell.
num_epochs = 100
learning_rate = 3e-5
best_success_rate = 0.0
eval_interval = 5
best_loss = float('inf')

optimizer = torch.optim.Adam(actor.parameters(), lr=learning_rate)

# Training loop: behavior cloning by maximizing the likelihood of expert actions.
for epoch in range(num_epochs):
    actor.train()
    total_loss = 0.0

    for obs, actions in dataloader:
        obs, actions = obs.to(device).float(), actions.to(device).float()
        # nll_loss runs the forward pass itself; a separate actor(obs) call would be wasted work.
        loss = nll_loss(actor, obs, actions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # max(..., 1) avoids dividing by zero on an empty dataloader.
    avg_loss = total_loss / max(len(dataloader), 1)
    print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')

    # Track the lowest training loss every epoch, not only on evaluation epochs.
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(actor.state_dict(), 'best_actor_loss.pth')
        print(f'New best model saved with best loss: {best_loss:.4f}')

    if (epoch + 1) % eval_interval == 0:
        actor.eval()
        success_rate = eval_policy(eval_env, actor)
        print(f'Evaluation at epoch {epoch+1}, Success Rate: {success_rate:.2%}')

        if success_rate > best_success_rate:
            best_success_rate = success_rate
            torch.save(actor.state_dict(), 'best_actor_success.pth')
            print(f'New best model saved with success rate: {success_rate:.2%}')


## Part 5: Evaluate in Simulator

In [None]:
# Reload the lowest-loss checkpoint and roll it out in the simulator.
best_state_dict = torch.load('best_actor_loss.pth')
actor.load_state_dict(best_state_dict)
eval_policy(eval_env, actor, num_episodes=5)