# **PerAct**

**Important: Before starting, change the runtime to GPU.**

### Modified
This notebook is a modified version of the original `PerAct_Tutorial.ipynb` that explains the training of [Perceiver-Actor (PerAct)](https://peract.github.io/). We will look at training a single-task agent on the `...` task. The tutorial starts from loading calibrated RGB-D images and ends with visualizing *action detections* in voxelized observations. Overall, this guide is meant to complement the [paper](https://peract.github.io/) by providing concrete implementation details. In addition, this modified notebook allows for storing the model weights.

### Full Code
See [this GitHub repository](https://github.com/peract/peract) for the full code, pre-trained checkpoints, and pre-generated datasets. You should be able to use the pre-generated datasets with this notebook.

### Credit
This notebook builds heavily on data-loading and preprocessing code from [`ARM`](https://github.com/stepjam/ARM), [`YARR`](https://github.com/stepjam/YARR), [`PyRep`](https://github.com/stepjam/PyRep), [`RLBench`](https://github.com/stepjam/RLBench) by [James et al.](https://stepjam.github.io/) The [PerceiverIO](https://arxiv.org/abs/2107.14795) code is adapted from [`perceiver-pytorch`](https://github.com/lucidrains/perceiver-pytorch) by [Phil Wang (lucidrains)](https://github.com/lucidrains). The optimizer is based on [this LAMB implementation](https://github.com/cybertronai/pytorch-lamb). See the corresponding licenses below.

<img src="https://peract.github.io/media/figures/sim_task.jpg" alt="drawing" style="width:100px;"/>

### Licenses
- [PerAct License (Apache 2.0)](https://github.com/peract/peract/blob/main/LICENSE) - Perceiver-Actor Transformer
- [ARM License](https://github.com/peract/peract/blob/main/ARM_LICENSE) - Voxelization and Data Preprocessing
- [YARR License (Apache 2.0)](https://github.com/stepjam/YARR/blob/main/LICENSE)
- [RLBench License](https://github.com/stepjam/RLBench/blob/master/LICENSE)
- [PyRep License (MIT)](https://github.com/stepjam/PyRep/blob/master/LICENSE)
- [Perceiver PyTorch License (MIT)](https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE)
- [LAMB License (MIT)](https://github.com/cybertronai/pytorch-lamb/blob/master/LICENSE)
- [CLIP License (MIT)](https://github.com/openai/CLIP/blob/main/LICENSE)


### Google Colab or Kaggle
Due to the compute constraints of my local GPU, some parts have to be run on a cloud GPU. Here, we can use [Google Colab](https://colab.research.google.com/) or [Kaggle](https://www.kaggle.com/) to run training on their GPUs. For this reason, some dependencies and environment variables have to be set to make this notebook work there. This option is set through the flag `colab`.
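The `colab` flag is set by hand in the next cell. As a minimal sketch, the hosted runtime could also be detected automatically. The helper name `detect_hosted_runtime` is hypothetical, and the two checks (the `google.colab` module being loaded, the `KAGGLE_KERNEL_RUN_TYPE` environment variable being set) are assumptions about those platforms rather than something this notebook relies on.

```python
import os
import sys


def detect_hosted_runtime(environ=None, modules=None):
    """Return "colab", "kaggle", or None when running on a local machine.

    Assumption: Colab preloads the `google.colab` module and Kaggle sets the
    `KAGGLE_KERNEL_RUN_TYPE` environment variable. Verify both on your platform.
    """
    environ = os.environ if environ is None else environ
    modules = sys.modules if modules is None else modules
    if "google.colab" in modules:
        return "colab"
    if "KAGGLE_KERNEL_RUN_TYPE" in environ:
        return "kaggle"
    return None


hosted = detect_hosted_runtime()
print(f"hosted runtime: {hosted}")  # a non-None value could stand in for `colab = True`
```

Passing `environ` and `modules` explicitly keeps the function easy to check without actually running on either platform.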

In [1]:
# Settings for Colab/Kaggle vs. a local computer
colab = False
data_origin = "handoversim" # Choose {"example", "real", "handoversim"}
visualize_image = True
visualize_voxel = True
visualize_pcd = True

#### Install Libraries

In [2]:
if colab:
    !pip install scipy ftfy regex tqdm torch open3d tensorboard natsort transformers git+https://github.com/openai/CLIP.git einops pyrender==0.1.45 trimesh==3.9.34 pycollada==0.6

### Clone Repo and Setup

Clone [https://github.com/peract/peract_colab.git](https://github.com/peract/peract_colab.git), if not cloned yet.

This repo contains barebones code from [`ARM`](https://github.com/stepjam/ARM), [`YARR`](https://github.com/stepjam/YARR), [`PyRep`](https://github.com/stepjam/PyRep), [`RLBench`](https://github.com/stepjam/RLBench) to get started with PerAct without the actual [V-REP](https://www.coppeliarobotics.com/) simulator.

The repo also contains a pre-generated RLBench dataset of 10 expert demonstrations for the `open_drawer` task. This task has three variations: "open the top drawer", "open the middle drawer", and "open the bottom drawer".



In [3]:
import numpy as np
np.bool = np.bool_ # bad trick to fix numpy version issue :(
import os
import sys
import pickle
import random
from natsort import natsorted

import matplotlib.pyplot as plt
%matplotlib inline

# Set `PYOPENGL_PLATFORM=egl` for pyrender visualizations
os.environ["DISPLAY"] = ":0"
os.environ["PYOPENGL_PLATFORM"] = "egl"

if not colab:
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,3" # Depends on your computer and available GPUs

In [4]:
### Depending on your workspace, you may already have this repository installed; otherwise clone it
if not os.path.exists(os.path.join(os.getcwd(), 'peract_colab')):
    !git clone https://github.com/yuki1003/peract_colab.git

sys.path = [p for p in sys.path if '/peract/' not in p]

If you forked this repo, you might want to pull the latest changes.

In [None]:
!cd peract_colab && git pull origin master

Define some constants and setting variables.

The `BATCH_SIZE` is 1 to fit the model on a single GPU. But you can play around with the voxel sizes and Transformer layers to increase this.  

In the paper, we use `NUM_LATENTS=2048` by default, but smaller latents like `512` are also fine (see Appendix G).
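The constants cell that follows notes that `SCENE_BOUNDS` must span 1 m along each axis. A minimal sketch of a sanity check for that constraint is shown here; the helper names `scene_extents` and `check_scene_bounds` are hypothetical and not part of `peract_colab`. It also reports the metric edge length of one voxel, which is 1 cm for a 1 m cube split into 100 voxels per axis.

```python
import math


def scene_extents(scene_bounds):
    """Return the (x, y, z) side lengths of [x_min, y_min, z_min, x_max, y_max, z_max]."""
    x_min, y_min, z_min, x_max, y_max, z_max = scene_bounds
    return (x_max - x_min, y_max - y_min, z_max - z_min)


def check_scene_bounds(scene_bounds, voxel_size, side=1.0, tol=1e-6):
    """Return the voxel edge length in metres, or raise if the volume is not a `side`-metre cube."""
    extents = scene_extents(scene_bounds)
    if not all(math.isclose(e, side, abs_tol=tol) for e in extents):
        raise ValueError(f"expected {side} m per axis, got extents {extents}")
    return side / voxel_size


# HandoverSim bounds from the constants cell, with 100 voxels per axis
print(check_scene_bounds([0.11, -0.5, 0.8, 1.11, 0.5, 1.8], 100))  # 0.01
```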

In [6]:
# Constants
WORKSPACE_DIR = os.getcwd()

# Replaybuffer related constants
LOW_DIM_SIZE = 4    # 4 dimensions - proprioception: {gripper_open, left_finger_joint, right_finger_joint, timestep}
IMAGE_SIZE =  128  # 128x128 - if you want to use higher voxel resolutions like 200^3, you might want to regenerate the dataset with larger images
DEMO_AUGMENTATION_EVERY_N = 5 # Only select every n-th frame to use for replaybuffer from demo
ROTATION_RESOLUTION = 5 # degree increments per axis
TARGET_OBJ_KEYPOINTS=False # Real - (changed later)
TARGET_OBJ_USE_LAST_KP=False # Real - (changed later)
TARGET_OBJ_IS_AVAIL = False # HandoverSim - (changed later)

## Paths for collected data & Assigned tasks
if data_origin in ["real", "handoversim"]: ## Custom Data paths of task demos

    DEPTH_SCALE = 1000
    STOPPING_DELTA = 0.001

    # Choose Task
    TASK = 'handing_over_banana'
    EPISODES_FOLDER = f'{TASK}/all_variations/episodes'
    
    if colab:
        DATA_FOLDER = "/kaggle/input/peract-task-data"
    else:
        if data_origin == "handoversim":
            DATA_FOLDER = os.path.join(WORKSPACE_DIR, "task_data", "handoversim_v2") # Change directory
            # DATA_FOLDER = "/media/ywatabe/ESD-USB/task_data/handoversim_v4"
            CAMERAS = [f"view_{camera_i}" for camera_i in range(3)]#+ ['wrist']#,'shoulder']#,'wrist']  # TODO: Depends on available cameras from collected data
            SCENE_BOUNDS = [0.11, -0.5, 0.8, 1.11, 0.5, 1.8] #NOTE: must be 1m each
            TARGET_OBJ_IS_AVAIL = True # NOTE: Object locations are available for HandoverSim
        elif data_origin == "real":
            DATA_FOLDER = os.path.join(WORKSPACE_DIR, "task_data", "real")
            CAMERAS = ['front']
            SCENE_BOUNDS = [0.0, -0.5, -0.2, 1., 0.5, 0.8] #NOTE: must be 1m each
            TARGET_OBJ_KEYPOINTS=True # TODO: Choose based on task (commonly True)
        else:
            raise NotImplementedError
    

elif data_origin in ["example"]: ## Running the tutorial data from Colab PerAct_Tutorial
    DEPTH_SCALE = 2**24 -1
    STOPPING_DELTA = 0.05

    # Example Task provided by Tutorial
    TASK = 'open_drawer'
    EPISODES_FOLDER = f'colab_dataset/{TASK}/all_variations/episodes'

    DATA_FOLDER = os.path.join(WORKSPACE_DIR, 'peract_colab', 'data') # Change directory

    CAMERAS = ['front', 'left_shoulder', 'right_shoulder', 'wrist'] # TODO: Depends on available cameras from collected data

    TARGET_OBJ_USE_LAST_KP=True # TODO: Choose based on task (commonly True)
    SCENE_BOUNDS = [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] # [x_min, y_min, z_min, x_max, y_max, z_max] - the metric volume to be voxelized

else:
    raise NotImplementedError


# === Data Handling ===
# Construct paths for data processing
data_path = os.path.join(DATA_FOLDER, EPISODES_FOLDER)
EPISODE_FOLDER = 'episode%d'

In [None]:
## Train Settings
BATCH_SIZE = 1
TRAINING_ITERATIONS = 10000 # 2400
LEARNING_RATE = 0.001
LR_SCHEDULER = False
NUM_WARMUP_STEPS = 300 # FOR LR_SCHEDULER - after ~300 iterations, the losses tend to fluctuate
NUM_CYCLES = 1 #TODO: FOR LR_SCHEDULER - same as paper: https://github.com/peract/peract/blob/02fb87681c5a47be9dbf20141bedb836ee2d3ef9/agents/peract_bc/qattention_peract_bc_agent.py#L232
TRANSFORM_AUGMENTATION = True
RGB_AUGMENTATION = "None" #"partial" # "None", "full", "partial"
FILL_REPLAY_UNIFORM = False
VOXEL_SIZES = [100] # 100x100x100 voxels
NUM_LATENTS = 512 # PerceiverIO latents (This is a lower-dim space/features capturing the input data)

print("RUN PROPERTIES")
print(f"task: {TASK}")
print(f"transform augmentation: {TRANSFORM_AUGMENTATION}")
print(f"rgb augmentation: {RGB_AUGMENTATION}")
print(f"learning_rate: {LEARNING_RATE}")
print(f"cameras: {CAMERAS}")
print(f"fill_replay_uniform: {FILL_REPLAY_UNIFORM}")

# === Data Handling ===
# Construct paths for data processing
data_path = os.path.join(DATA_FOLDER, EPISODES_FOLDER)
if data_origin == "handoversim":
    data_path = os.path.join(DATA_FOLDER, "train_s1", EPISODES_FOLDER)
EPISODE_FOLDER = 'episode%d'

## TRAIN/VALIDATE/TEST - here VALIDATE = TEST & TEST = TEST_TASK
TOTAL_DEMOS = len(os.listdir(data_path))
## Determine Training/Testing set
TRAIN_FRAC = 0.8
if data_origin == "handoversim":
    TRAIN_FRAC = 1
SPLIT_POINT = int(TOTAL_DEMOS * TRAIN_FRAC) # #TODO OR CHOOSE FIXED
INDEXES = [int(episode_nr.replace("episode","")) for episode_nr in natsorted(os.listdir(data_path))] #np.arange(0, TOTAL_DEMOS)
shuffle = False
if shuffle:
    random_seed = 10
    random.Random(random_seed).shuffle(INDEXES) # IF you want to randomise the demos
TRAIN_INDEXES, TEST_INDEXES = np.split(INDEXES, [SPLIT_POINT]) # Split demos to train/test sets
TEST_INDEXES=TEST_INDEXES

print(f"DEMOS | Total #: {len(INDEXES)}, indexes: {INDEXES}")
print(f"Split index: {SPLIT_POINT}, shuffle: {shuffle}")
print(f"TRAIN | Total #: {len(TRAIN_INDEXES)}, indices: {TRAIN_INDEXES}")
print(f"TEST | Total #: {len(TEST_INDEXES)}, indices:", TEST_INDEXES)


## Testing (USED AT LAST CELL)
test_on_diff_set = True
TEST_TASK = TASK

print(f"Data TRAIN/TEST: {data_path}")

if not test_on_diff_set:
    test_data_path = data_path

else: # Choose a different task for testing
    test_task = TEST_TASK
    test_task_folder = f'{test_task}/all_variations/episodes'    
    test_data_path = os.path.join(DATA_FOLDER, "val_s1", test_task_folder)
    TEST_INDEXES = [int(episode_nr.replace("episode","")) for episode_nr in natsorted(os.listdir(test_data_path))]
    print(f"TEST_DIFF | Total #: {len(TEST_INDEXES)}, indices: {TEST_INDEXES}")
    print(f"Data TEST (updated): {test_data_path}")



Add `peract_colab` to the system path and make a directory for storing the replay buffer. For now, we will store the replay buffer on disk to avoid the memory issues of keeping everything in RAM.
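The cells below empty and recreate the train and test replay directories with two near-identical blocks. As a sketch, the same behaviour can be wrapped in one small helper; `reset_dir` is a hypothetical name, and the example runs on a throwaway temporary directory rather than the real replay folders.

```python
import os
import shutil
import tempfile


def reset_dir(path):
    """Delete `path` if it exists, then recreate it as an empty directory."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path)
    return path


# Demonstrated on a scratch directory; "replay_demo" is a made-up folder name.
scratch_root = tempfile.mkdtemp()
demo_dir = os.path.join(scratch_root, "replay_demo")
reset_dir(demo_dir)
open(os.path.join(demo_dir, "stale.replay"), "w").close()  # simulate leftover data
reset_dir(demo_dir)
print(os.listdir(demo_dir))  # []
shutil.rmtree(scratch_root)
```

Starting from an empty directory matters because stale replay files from a previous run would otherwise be mixed into the new buffer.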

In [8]:
# Check if modules in peract_colab repository are recognized.
try: # test import
    from rlbench.utils import get_stored_demo
except ImportError as error_message:
    print(error_message)
    print("Adding peract_colab repository to system path.")
    sys.path.append('peract_colab')
  

In [None]:
import shutil
train_replay_storage_dir = os.path.join(WORKSPACE_DIR,'replay_train')
if os.path.exists(train_replay_storage_dir):
    print(f"Emptying {train_replay_storage_dir}")
    shutil.rmtree(train_replay_storage_dir)
if not os.path.exists(train_replay_storage_dir):
    print(f"Could not find {train_replay_storage_dir}, creating directory.")
    os.mkdir(train_replay_storage_dir)

test_replay_storage_dir = os.path.join(WORKSPACE_DIR,'replay_test')
if os.path.exists(test_replay_storage_dir):
    print(f"Emptying {test_replay_storage_dir}")
    shutil.rmtree(test_replay_storage_dir)
if not os.path.exists(test_replay_storage_dir):
    print(f"Could not find {test_replay_storage_dir}, creating directory.")
    os.mkdir(test_replay_storage_dir)

## Data Loading & Preprocessing

An expert demonstration recorded at ~20Hz contains hundreds of individual timesteps in a sequence. In the RLBench example data, each timestep contains observations recorded from 4 calibrated cameras (front, left_shoulder, right_shoulder, and wrist) and other proprioception sensors; the custom datasets use the cameras listed in `CAMERAS`. "Calibrated" means we know the extrinsics and intrinsics.
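To make concrete why calibration matters, here is a minimal sketch of back-projecting a depth image into a world-frame point cloud. It is not the RLBench implementation. It assumes a pinhole camera, depth already converted to metres along the optical axis, and a 4x4 camera-to-world extrinsic matrix; the function name `depth_to_world_points` is hypothetical.

```python
import numpy as np


def depth_to_world_points(depth_m, intrinsics, cam_to_world):
    """Back-project an (H, W) metric depth image to an (H, W, 3) world-frame point cloud.

    Assumptions: pinhole model, depth measured along the camera z axis,
    `cam_to_world` is a 4x4 homogeneous camera-to-world transform.
    """
    h, w = depth_m.shape
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    u, v = np.meshgrid(np.arange(w), np.arange(h))  # pixel columns (u) and rows (v)
    x = (u - cx) * depth_m / fx
    y = (v - cy) * depth_m / fy
    points_cam = np.stack([x, y, depth_m, np.ones_like(depth_m)], axis=-1)  # (H, W, 4)
    points_world = points_cam @ cam_to_world.T  # apply the extrinsics to every pixel
    return points_world[..., :3]


# Toy camera: focal length 100 px, principal point (2, 2), shifted 1 m along world z.
K = np.array([[100.0, 0.0, 2.0], [0.0, 100.0, 2.0], [0.0, 0.0, 1.0]])
pose = np.eye(4)
pose[2, 3] = 1.0
points = depth_to_world_points(np.full((4, 4), 2.0), K, pose)
print(points.shape)  # (4, 4, 3)
```

With both matrices known, the point clouds from all cameras land in one shared frame, which is what allows them to be merged into a single voxel grid later.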

Let's take a look at what these observations look like. Play around with different `episode_idx_to_visualize` and timesteps `ts`.

In [None]:
from rlbench.utils import get_stored_demo
from rlbench.backend.utils import extract_obs

if visualize_image:
    # what to visualize
    episode_idx_to_visualize = INDEXES[0] # out of 10 demos
    ts = 21 # timestep out of total timesteps

    # get demo
    demo = get_stored_demo(data_path=data_path,
                        index=episode_idx_to_visualize,
                        cameras=CAMERAS,
                        depth_scale=DEPTH_SCALE)

    # extract obs at timestep
    obs_dict = extract_obs(demo._observations[ts], CAMERAS, t=ts)
    gripper_pose = demo[ts].gripper_pose
    gripper_open = demo[ts].gripper_open

    # total timesteps in demo
    print(f"Demo {episode_idx_to_visualize} | {len(demo._observations)} total steps\n")
    print(f"The gripper is at: {gripper_pose[:3]}")
    print(f"gripper_open: {gripper_open}")

    # plot rgb and depth at timestep
    fig = plt.figure(figsize=(20, 10))
    rows, cols = 2, len(CAMERAS)

    plot_idx = 1
    for camera in CAMERAS:
        # rgb
        rgb_name = "%s_%s" % (camera, 'rgb')
        rgb = np.transpose(obs_dict[rgb_name], (1, 2, 0))
        fig.add_subplot(rows, cols, plot_idx)
        plt.imshow(rgb)
        plt.axis('off')
        plt.title("%s_rgb | step %s" % (camera, ts))

        # depth
        depth_name = "%s_%s" % (camera, 'depth')
        # depth = np.transpose(obs_dict[depth_name], (1, 2, 0)).reshape(IMAGE_SIZE, IMAGE_SIZE)
        depth = np.transpose(obs_dict[depth_name], (1, 2, 0))
        fig.add_subplot(rows, cols, plot_idx+len(CAMERAS))
        plt.imshow(depth)
        plt.axis('off')
        plt.title("%s_depth | step %s" % (camera, ts))

        # # mask
        # mask_name = "%s_%s" % (camera, 'mask')
        # # depth = np.transpose(obs_dict[depth_name], (1, 2, 0)).reshape(IMAGE_SIZE, IMAGE_SIZE)
        # mask = np.transpose(obs_dict[mask_name], (1, 2, 0))
        # fig.add_subplot(rows, cols, plot_idx+len(CAMERAS))
        # plt.imshow(mask)
        # plt.axis('off')
        # plt.title("%s_mask | step %s" % (camera, ts))

        plot_idx += 1

    plt.show()

### Create Replay Buffer

As described in **Section 3.4** of the paper, PerAct is trained with discrete-time input-action tuples from a dataset of demonstrations. These tuples are stored in a Replay Buffer following the [`ARM`](https://github.com/stepjam/ARM) codebase. You can use your own storage format, but here we follow `ARM` to benchmark against baselines and other methods by James et al.

This replay buffer stores **<observation, language goal, keyframe action>** tuples sampled from demonstrations.

In [11]:
from arm.replay_buffer import create_replay


train_replay_buffer = create_replay(batch_size=BATCH_SIZE,
                                    timesteps=1,
                                    save_dir=train_replay_storage_dir,
                                    cameras=CAMERAS,
                                    voxel_sizes=VOXEL_SIZES,
                                    image_size=IMAGE_SIZE,
                                    low_dim_size=LOW_DIM_SIZE)

test_replay_buffer = create_replay(batch_size=BATCH_SIZE,
                                   timesteps=1,
                                   save_dir=test_replay_storage_dir,
                                   cameras=CAMERAS,
                                   voxel_sizes=VOXEL_SIZES,
                                   image_size=IMAGE_SIZE,
                                   low_dim_size=LOW_DIM_SIZE)


### Fill Replay with Demos

#### Keyframe Extraction

Instead of directly trying to predict every action in the demonstration, which could be very noisy and inefficient, we extract keyframe actions that capture **bottleneck** poses \[[James et al.](https://arxiv.org/abs/2105.14829)\]. This extraction is done with a simple heuristic: an action is a keyframe action if (1) the joint velocities are near zero and (2) the gripper open state has not changed. Every timestep in the demonstration can then be cast as a "predict the next (best) keyframe" classification task, like the orange points in this figure:

<div>
<img src="https://peract.github.io/media/figures/keypoints.jpg" alt="drawing"  width="300"/>
</div>
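The two-part heuristic above can be sketched in a few lines of plain Python. This is a simplified illustration, not the `_keypoint_discovery` function from `ARM`: the name `find_keyframes` is hypothetical, only the first timestep of each pause is kept, and the final timestep is always treated as a keyframe (both are assumptions of this sketch).

```python
def find_keyframes(joint_velocities, gripper_open, stopping_delta=0.05):
    """Return the timestep indices of keyframes in one demonstration.

    joint_velocities: per-timestep sequences of joint velocities.
    gripper_open: per-timestep gripper-open flags.
    """
    num_steps = len(gripper_open)
    keyframes = []
    was_stopped = False
    for t in range(1, num_steps):
        near_zero = all(abs(v) <= stopping_delta for v in joint_velocities[t])
        gripper_unchanged = gripper_open[t] == gripper_open[t - 1]
        stopped = near_zero and gripper_unchanged
        if stopped and not was_stopped:  # keep only the first step of each pause
            keyframes.append(t)
        was_stopped = stopped
    # Assumption: the final timestep always counts as a keyframe.
    if num_steps > 0 and (not keyframes or keyframes[-1] != num_steps - 1):
        keyframes.append(num_steps - 1)
    return keyframes


# Toy single-joint demo: move, pause, move again.
velocities = [[0.5], [0.5], [0.0], [0.0], [0.5], [0.5]]
gripper = [1, 1, 1, 1, 1, 1]
print(find_keyframes(velocities, gripper))  # [2, 5]
```

A smaller `stopping_delta` makes the detector stricter about what counts as "stopped", which is why the notebook uses different values for the simulated and the custom datasets.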

Let's take a look at what these keyframe actions look like.  



In [None]:
from arm.demo import _keypoint_discovery, _keypoint_discovery_available, _target_object_discovery

# Display for every demo (i.e. episode)
# Display Demo 1 like last time
if visualize_image:
    for i in INDEXES:
        episode_idx_to_visualize = i#INDEXES#746#INDEXES[0]#746#146#INDEXES[0]

        demo = get_stored_demo(data_path=data_path,
                                index=episode_idx_to_visualize,
                                cameras=CAMERAS,
                                depth_scale=DEPTH_SCALE)

        # total timesteps
        print("Demo %s | %s total steps" % (episode_idx_to_visualize, len(demo._observations)))

        # Use the heuristic to extract keyframes (aka keypoints).
        # NOTE: STOPPING_DELTA is the per-joint absolute tolerance used to decide that the arm has stopped.

        if data_origin in ["example"]:
            episode_keypoints = _keypoint_discovery(demo, stopping_delta=STOPPING_DELTA)
            episode_target_object = _target_object_discovery(demo)
        elif data_origin in ["real"]:
            # Assumed fix: `episode_keypoints` was never set in this branch, so reuse the same heuristic
            episode_keypoints = _keypoint_discovery(demo, stopping_delta=STOPPING_DELTA)
            episode_target_object = _target_object_discovery(demo, keypoints=TARGET_OBJ_KEYPOINTS, stopping_delta=STOPPING_DELTA, last_kp=TARGET_OBJ_USE_LAST_KP)
        elif data_origin in ["handoversim"]:
            episode_keypoints = _keypoint_discovery(demo, stopping_delta=STOPPING_DELTA)#_keypoint_discovery_available(demo, 0.3)
            episode_target_object = _target_object_discovery(demo, is_available=True)

        # visualize rgb observations from these keyframes
        for kp_idx, kp in enumerate(episode_keypoints):
            
            obs_dict = extract_obs(demo._observations[kp], CAMERAS, t=kp)
            gripper_pose = demo[kp].gripper_pose
            gripper_open = demo[kp].gripper_open
            target_object = episode_target_object[kp]

            # plot rgb and depth at timestep
            fig = plt.figure(figsize=(10, 5))
            rows, cols = 1, len(CAMERAS)

            plot_idx = 1

            for camera in CAMERAS:

                rgb_name = "%s_%s" % (camera, 'rgb')
                rgb = np.transpose(obs_dict[rgb_name], (1, 2, 0))
                fig.add_subplot(rows, cols, plot_idx)
                plt.imshow(rgb)
                plt.axis('off')
                fig.suptitle("step %s | \n gripper at %s \n gripper open %s \n object at %s" % (kp, gripper_pose[:3], gripper_open, target_object[:3]))

                plot_idx += 1
            
            fig.tight_layout()
            fig.subplots_adjust(top=0.88)

            plt.show()

Notice that the motion planner used to generate demonstrations might take various paths to execute the "opening" motion, but all paths strictly pass through these **bottleneck** poses, since that's how the expert demonstrations were collected in RLBench. This essentially circumvents the issue of training directly on randomized motion paths from sampling-based motion planners, which can be quite noisy to learn from for end-to-end methods.

#### Fill Replay

Fill the replay buffer

Load a [pre-trained CLIP model](https://arxiv.org/abs/2103.00020) to extract language features. You can probably swap this with other language models, but CLIP's language features were trained to be aligned with image features, which might give it a multi-modal edge over text-only models 🤷

Finally fill the replay buffer.

In [None]:
import torch
import clip

from arm.replay_buffer import fill_replay, uniform_fill_replay
from yarr.replay_buffer.wrappers.pytorch_replay_buffer import PyTorchReplayBuffer

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("RN50", device=device) # CLIP-ResNet50

# if FILL_REPLAY_UNIFORM:
#     fill_replay = uniform_fill_replay

print("-- Train Buffer --")
fill_replay( # fill_replay_copy_with_crop_from_approach
            data_path=data_path,
            episode_folder=EPISODE_FOLDER,
            replay=train_replay_buffer,
            # start_idx=0,
            # num_demos=NUM_DEMOS,
            d_indexes=TRAIN_INDEXES,
            demo_augmentation=True,
            demo_augmentation_every_n=DEMO_AUGMENTATION_EVERY_N,
            cameras=CAMERAS,
            rlbench_scene_bounds=SCENE_BOUNDS,
            voxel_sizes=VOXEL_SIZES,
            rotation_resolution=ROTATION_RESOLUTION,
            crop_augmentation=False,
            depth_scale=DEPTH_SCALE,
            use_approach=True,
            approach_distance=0.3,
            stopping_delta=STOPPING_DELTA,
            target_obj_keypoint=TARGET_OBJ_KEYPOINTS,
            target_obj_use_last_kp=TARGET_OBJ_USE_LAST_KP,
            target_obj_is_avail=TARGET_OBJ_IS_AVAIL,
            clip_model=clip_model,
            device=device,
            )
    
print("-- Test Buffer --")
fill_replay( # fill_replay_copy_with_crop_from_approach
            data_path=test_data_path,
            episode_folder=EPISODE_FOLDER,
            replay=test_replay_buffer,
            # start_idx=start_idx,
            # num_demos=num_demos,
            d_indexes=TEST_INDEXES,
            demo_augmentation=True,
            demo_augmentation_every_n=DEMO_AUGMENTATION_EVERY_N,
            cameras=CAMERAS,
            rlbench_scene_bounds=SCENE_BOUNDS,
            voxel_sizes=VOXEL_SIZES,
            rotation_resolution=ROTATION_RESOLUTION,
            crop_augmentation=False,
            depth_scale=DEPTH_SCALE,
            use_approach=True,
            approach_distance=0.3,
            stopping_delta=STOPPING_DELTA,
            target_obj_keypoint=TARGET_OBJ_KEYPOINTS,
            target_obj_use_last_kp=TARGET_OBJ_USE_LAST_KP,
            target_obj_is_avail=TARGET_OBJ_IS_AVAIL,
            clip_model=clip_model,
            device=device,
            )

# delete the CLIP model since we have already extracted language features
del clip_model

# wrap buffer with PyTorch dataset and make iterator
train_wrapped_replay = PyTorchReplayBuffer(train_replay_buffer)
train_dataset = train_wrapped_replay.dataset()
train_data_iter = iter(train_dataset)

test_wrapped_replay = PyTorchReplayBuffer(test_replay_buffer)
test_dataset = test_wrapped_replay.dataset()
test_data_iter = iter(test_dataset)

## Training PerAct

### Voxelization

Now we define a class for voxelizing calibrated RGB-D observations following [C2FARM \(James et al.\)](https://arxiv.org/pdf/2106.12534.pdf).

The input to the voxelizer is:
- Flattened RGB images
- Flattened global-coordinate point clouds
- Scene bounds in metric units that specify the volume to be voxelized

The output is a 10-dimensional voxel grid (see Appendix B for details).
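As a rough illustration of the core step, the sketch below maps metric points to voxel indices and builds a plain occupancy grid with NumPy. It is not the `VoxelGrid` class used in this notebook, which also aggregates features per voxel. The function names are hypothetical, and dropping out-of-bounds points is an assumption of this sketch.

```python
import numpy as np


def points_to_voxel_indices(points, scene_bounds, voxel_size):
    """Map (N, 3) metric points to integer (N, 3) voxel indices, clipped to the grid."""
    bounds = np.asarray(scene_bounds, dtype=np.float64)
    lo, hi = bounds[:3], bounds[3:]
    resolution = (hi - lo) / voxel_size  # metric edge length of one voxel per axis
    idx = np.floor((np.asarray(points, dtype=np.float64) - lo) / resolution)
    return np.clip(idx.astype(np.int64), 0, voxel_size - 1)


def occupancy_grid(points, scene_bounds, voxel_size):
    """Return a boolean (V, V, V) grid marking voxels that contain at least one point."""
    points = np.asarray(points, dtype=np.float64)
    bounds = np.asarray(scene_bounds, dtype=np.float64)
    # Assumption: points outside the scene bounds are simply discarded.
    inside = np.all((points >= bounds[:3]) & (points <= bounds[3:]), axis=1)
    idx = points_to_voxel_indices(points[inside], scene_bounds, voxel_size)
    grid = np.zeros((voxel_size,) * 3, dtype=bool)
    grid[idx[:, 0], idx[:, 1], idx[:, 2]] = True
    return grid


unit_cube = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
cloud = np.array([[0.05, 0.05, 0.05], [0.95, 0.55, 0.15], [2.0, 2.0, 2.0]])
grid = occupancy_grid(cloud, unit_cube, voxel_size=10)
print(grid.sum())  # 2 occupied voxels; the third point lies outside the bounds
```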

Let's try to use this voxelizer on observation samples from the replay buffer.

But first, let's define some helper functions to normalize and format RGB and pointcloud input:

The rgb and pointcloud inputs have to be flattened before feeding them into the voxelizer:

#### Voxel: Original (no-augmentation vs. translation/rotation-augmentation vs. RGB-augmentation)

In [None]:
import torch
import torchvision.transforms as T

from agent.utils import _preprocess_inputs, pcd_bbox
from agent.voxel_grid import VoxelGrid
from arm.utils import visualise_voxel, point_to_voxel_index, voxel_index_to_point
from arm.augmentation import apply_se3_augmentation, perturb_se3

RGB_AUGMENTATION = "None"

if visualize_voxel or visualize_pcd:

    # initialize voxelizer
    vox_grid = VoxelGrid(
        coord_bounds=SCENE_BOUNDS,
        voxel_size=VOXEL_SIZES[0],
        device=device,
        batch_size=BATCH_SIZE,
        feature_size=3,
        max_num_coords=np.prod([IMAGE_SIZE, IMAGE_SIZE]) * len(CAMERAS),
    )

    # sample from dataset
    batch = next(train_data_iter)
    lang_goal = batch['lang_goal'][0][0][0]
    print(batch[f"{CAMERAS[0]}_rgb"].shape)
    batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
    print(batch[f"{CAMERAS[0]}_rgb"].shape)

    # preprocess observations
    rgbs_pcds, _ = _preprocess_inputs(batch, CAMERAS)
    pcds = [rp[1] for rp in rgbs_pcds]

    # batch_size
    bs = rgbs_pcds[0][0].shape[0]

    # metric scene bounds
    bounds = torch.tensor(SCENE_BOUNDS,device=device).unsqueeze(0)

    # identity matrix
    identity_4x4 = torch.eye(4).unsqueeze(0).repeat(bs, 1, 1).to(device=device)
    
    # sample
    action_trans = batch['trans_action_indicies'][:, -1, :3].int()
    action_rot_grip = batch['rot_grip_action_indicies'][:, -1].int()
    action_ignore_collisions = batch['ignore_collisions'][:, -1].int()
    action_gripper_pose = batch['gripper_pose'][:, -1]
    gripper_pose = batch['gripper_state'][:, -1]
    object_pose = batch['object_state'][:, -1]
    lang_goal_embs = batch['lang_goal_embs'][:, -1].float()

    # Get the pose voxel location
    gripper_pcd = pcd_bbox(gripper_pose, 1, VOXEL_SIZES[0], bounds, bs, device)
    object_pcd = pcd_bbox(object_pose, 1, VOXEL_SIZES[0], bounds, bs, device)
    
    if RGB_AUGMENTATION.lower() == "partial":
        
        gripper_bbox_pcd = pcd_bbox(gripper_pose, 10, VOXEL_SIZES[0], bounds, bs, device)
        object_bbox_pcd = pcd_bbox(object_pose, 10, VOXEL_SIZES[0], bounds, bs, device)
        
    if TRANSFORM_AUGMENTATION:
        action_trans, \
        action_rot_grip, \
        pcds, \
        trans_shift_4x4, \
        rot_shift_4x4, \
        action_gripper_4x4 = apply_se3_augmentation(pcds,
                                                    action_gripper_pose,
                                                    action_trans,
                                                    action_rot_grip,
                                                    bounds,
                                                    0,  
                                                    [0.125, 0.125, 0.125], # Translation range
                                                    [0.0, 0.0, 45.0], # Rotation range
                                                    5, # Increments of rotation
                                                    VOXEL_SIZES[0],
                                                    ROTATION_RESOLUTION, # Resolution to fix to
                                                    device)
        
        gripper_pcd = perturb_se3(gripper_pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds)
        object_pcd = perturb_se3(object_pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds)
        
        if RGB_AUGMENTATION.lower() == "partial":
            gripper_bbox_pcd = perturb_se3(gripper_bbox_pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds)
            object_bbox_pcd = perturb_se3(object_bbox_pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds)

    # flatten observations
    pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(bs, -1, 3) for p in pcds], 1)

    rgb = [rp[0] for rp in rgbs_pcds] # Loop per camera
    
    feat_size = rgb[0].shape[1]
    flat_imag_features = torch.cat(
        [p.permute(0, 2, 3, 1).reshape(bs, -1, feat_size) for p in rgb], 1)
    
    # voxelize!
    voxel_grid = vox_grid.coords_to_bounding_voxel_grid(pcd_flat,
                                                        flat_imag_features,
                                                        coord_bounds=bounds)
    # swap to channels first
    # -> (B, C, V, V, V); the code below treats channels 3:6 as RGB
    voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
    
    
    if RGB_AUGMENTATION.lower() in ["full", "partial"]:
        transform = T.Compose([
            T.ColorJitter(brightness=10.2, contrast=10.2, saturation=10.2, hue=0.1),  # Create transformation for brightness/contrast/saturation/hue
        ])
        image_features_scaled = torch.cat([(rgb_batch + 1) / 2 for rgb_batch in rgb]) # Convert [-1, 1] -> [0, 1] and stack cameras along the batch dimension
        image_features_scaled_augmented = transform(image_features_scaled) # Apply the color jitter
        rgb_unscaled = image_features_scaled_augmented * 2 - 1 # Convert back to [-1, 1]

        rgb_augmented = [rgb_unscaled[i:i + bs] for i in range(0, len(rgb_unscaled), bs)] # Split back into one batch per camera
        flat_imag_features_augmented = torch.cat(
            [p.permute(0, 2, 3, 1).reshape(bs, -1, feat_size) for p in rgb_augmented], 1)
        
        voxel_grid_augmented = vox_grid.coords_to_bounding_voxel_grid(pcd_flat,
                                                                    coord_features=flat_imag_features_augmented,
                                                                    coord_bounds=bounds)

        # swap to channels first (same channel layout as `voxel_grid`)
        voxel_grid_augmented = voxel_grid_augmented.permute(0, 4, 1, 2, 3).detach()
        
        if RGB_AUGMENTATION.lower() == "partial":

            bounding_box_pcd = [torch.cat((gripper_bbox_pcd[0], object_bbox_pcd[0]), dim=2)]
            bounding_box_pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(bs, -1, 3) for p in bounding_box_pcd], 1)
            
            bounding_box_indices = []
            for b, bounding_box_pcd_flat_b in enumerate(bounding_box_pcd_flat):
                bounding_box_indices_b = point_to_voxel_index(bounding_box_pcd_flat_b.cpu().numpy(),
                                                            VOXEL_SIZES[0],
                                                            bounds[0].cpu().numpy())
                bounding_box_indices.append(bounding_box_indices_b)
                voxel_grid_augmented[b, 3:6, 
                                        bounding_box_indices_b[:, 0], 
                                        bounding_box_indices_b[:, 1], 
                                        bounding_box_indices_b[:, 2]] = voxel_grid[b, 3:6, 
                                                                                bounding_box_indices_b[:, 0], 
                                                                                bounding_box_indices_b[:, 1], 
                                                                                bounding_box_indices_b[:, 2]]

    # expert action voxel indices and coordinates
    vis_gt_coords = action_trans[:, :3].int().detach().cpu().numpy()
    gt_coord_pcd = np.array([voxel_index_to_point(vis_gt_coord, VOXEL_SIZES[0], bounds[0].cpu().numpy()) 
                             for vis_gt_coord in vis_gt_coords])

    # Current gripper and (estimated) target object location
    # gripper_object_pcd = [torch.cat((gripper_pcd[0], object_pcd[0]), dim=2)]
    gripper_object_pcd = [object_pcd[0]]
    gripper_object_pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(bs, -1, 3) for p in gripper_object_pcd], 1)
    
    gripper_object_indices = []
    for b, gripper_object_pcd_flat_b in enumerate(gripper_object_pcd_flat):
        gripper_object_indices_b = point_to_voxel_index(gripper_object_pcd_flat_b.cpu().numpy(),
                                                    VOXEL_SIZES[0],
                                                    bounds[0].cpu().numpy())
        gripper_object_indices.append(gripper_object_indices_b)

    for b in range(bs):
        print(f"Expert Action Voxel Indices: {vis_gt_coords[b]}")
        gripper_object_indices_b = gripper_object_indices[b]
        idx_gripper_object = int(len(gripper_object_indices_b)/2)
        print(f"Gripper Voxel Indices: {gripper_object_indices_b[:idx_gripper_object]}")
        print(f"Object (estimated) Voxel Indices: {gripper_object_indices_b[idx_gripper_object:]}")

##### Visualizing the voxel grid

In [None]:
if not colab and visualize_voxel: # Does NOT work within colab
    # render voxel grid with expert action (blue)
    #@markdown #### Show voxel grid and expert action (blue)
    #@markdown Adjust `rotation_amount` to change the camera yaw angle for rendering.
    # rotation_amount = -90 #@param {type:"slider", min:-180, max:180, step:5}

    gripper_pose_indices = [None]        
    # gripper_pose_indices = [gripper_object_indices[0]] if RGB_AUGMENTATION == "full" else [bounding_box_indices[0]] # bounding box

    angle = 0

    rendered_img_0 = visualise_voxel(voxel_grid[0].cpu().numpy(),
                                None,
                                gripper_pose_indices[0],
                                vis_gt_coords[0],
                                highlight_alpha=1.0,
                                voxel_size=0.03,
                                rotation_amount=np.deg2rad(0-angle),
                                perspective = False)

    rendered_img_90 = visualise_voxel(voxel_grid[0].cpu().numpy(),
                                None,
                                gripper_pose_indices[0],
                                vis_gt_coords[0],
                                highlight_alpha=1.0,
                                voxel_size=0.03,
                                rotation_amount=np.deg2rad(90-angle),
                                perspective = False)
    
    rendered_img_180 = visualise_voxel(voxel_grid[0].cpu().numpy(),
                                None,
                                gripper_pose_indices[0],
                                vis_gt_coords[0],
                                highlight_alpha=1.0,
                                voxel_size=0.03,
                                rotation_amount=np.deg2rad(180-angle),
                                perspective = False)
    
    rendered_img_270 = visualise_voxel(voxel_grid[0].cpu().numpy(),
                                None,
                                gripper_pose_indices[0],
                                vis_gt_coords[0],
                                highlight_alpha=1.0,
                                voxel_size=0.03,
                                rotation_amount=np.deg2rad(270-angle),
                                perspective = False)
    
    rendered_img_0_persp = visualise_voxel(voxel_grid[0].cpu().numpy(),
                                None,
                                gripper_pose_indices[0],
                                vis_gt_coords[0],
                                highlight_alpha=1.0,
                                voxel_size=0.03,
                                rotation_amount=np.deg2rad(0-angle))
    
    rendered_img_side_persp = visualise_voxel(voxel_grid[0].cpu().numpy(),
                                None,
                                gripper_pose_indices[0],
                                vis_gt_coords[0],
                                highlight_alpha=1.0,
                                voxel_size=0.03,
                                rotation_amount=np.deg2rad(45-angle))
    
                    
    
    fig = plt.figure(figsize=(20, 15))
    fig.add_subplot(3, 2, 1)
    plt.imshow(rendered_img_0)
    plt.title("0-degree view")
    fig.add_subplot(3, 2, 2)
    plt.imshow(rendered_img_90)
    plt.title("90-degree view")
    fig.add_subplot(3, 2, 3)
    plt.imshow(rendered_img_180)
    plt.title("180-degree view")
    fig.add_subplot(3, 2, 4)
    plt.imshow(rendered_img_270)
    plt.title("270-degree view")
    fig.add_subplot(3, 2, 5)
    plt.imshow(rendered_img_0_persp)
    plt.axis('off')
    plt.title("0-degree view (perspective)")
    fig.add_subplot(3, 2, 6)
    plt.imshow(rendered_img_side_persp)
    plt.axis('off')
    plt.title("side view")

    print(f"Lang goal: {lang_goal}")

In [16]:
if not colab and visualize_voxel and (RGB_AUGMENTATION.lower() in ["full", "partial"]): # Does NOT work within colab
    # render voxel grid with expert action (blue)
    #@markdown #### Show voxel grid and expert action (blue)
    #@markdown Adjust `rotation_amount` to change the camera yaw angle for rendering.
    # rotation_amount = -90 #@param {type:"slider", min:-180, max:180, step:5}

    gripper_pose_indices = [None]
    # gripper_pose_indices = [gripper_object_indices[0]] # gripper and object location highlights

    rendered_img_0 = visualise_voxel(voxel_grid_augmented[0].cpu().numpy(),
                                None,
                                gripper_pose_indices[0],
                                vis_gt_coords[0],
                                # highlight_alpha=0.3,
                                voxel_size=0.03,
                                rotation_amount=np.deg2rad(0))

    rendered_img_side = visualise_voxel(voxel_grid_augmented[0].cpu().numpy(),
                                None,
                                gripper_pose_indices[0],
                                vis_gt_coords[0],
                                # highlight_alpha=0.3,
                                voxel_size=0.03,
                                rotation_amount=np.deg2rad(45))
                    
    fig = plt.figure(figsize=(20, 15))
    fig.add_subplot(1, 2, 1)
    plt.imshow(rendered_img_0)
    plt.axis('off')
    plt.title("Front view")
    fig.add_subplot(1, 2, 2)
    plt.imshow(rendered_img_side)
    plt.axis('off')
    plt.title("Side view")

    print(f"Lang goal: {lang_goal}")

In [None]:
if visualize_pcd:

    import open3d as o3d
    
    def create_geometry_at_points(points, radius = 0.03):
        geometries = o3d.geometry.TriangleMesh()
        for point in points:
            sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius) #create a small sphere to represent point
            sphere.translate(point) #translate this sphere to point
            geometries += sphere
        geometries.paint_uniform_color([1.0, 1.0, 0.0])
        return geometries

    def points_in_scene_bounds(points, scene_bounds=np.array(SCENE_BOUNDS)):
        x_min, y_min, z_min, x_max, y_max, z_max = scene_bounds
        mask = (
            (points[:, 0] >= x_min) & (points[:, 0] <= x_max) &  # x bounds
            (points[:, 1] >= y_min) & (points[:, 1] <= y_max) &  # y bounds
            (points[:, 2] >= z_min) & (points[:, 2] <= z_max)    # z bounds
        )
        return points[mask]
    
    # General point cloud from camera(s)
    point_cloud_o3d = o3d.geometry.PointCloud()
    pcd_np = pcd_flat[0].cpu().numpy()
    pcd_np_filtered = points_in_scene_bounds(pcd_np)
    point_cloud_o3d.points = o3d.utility.Vector3dVector(pcd_np_filtered)

    # Ground truth to next action
    gt_coord_cloud_o3d = o3d.geometry.PointCloud()
    gt_coord_cloud_o3d.points = o3d.utility.Vector3dVector(gt_coord_pcd)
    gt_coord_o3d = create_geometry_at_points(gt_coord_cloud_o3d.points)

    # Gripper and object location
    gripper_object_coord_cloud_o3d = o3d.geometry.PointCloud()
    gripper_object_pcd_np = gripper_object_pcd_flat[0].cpu().numpy()
    gripper_object_coord_cloud_o3d.points = o3d.utility.Vector3dVector(np.array(gripper_object_pcd_np))
    gripper_object_o3d = create_geometry_at_points(gripper_object_coord_cloud_o3d.points)

    # Gripper and object bbox point cloud
    gripper_object_bbox_cloud_o3d = o3d.geometry.PointCloud()
    # gripper_object_bbox_pcd_np = bounding_box_pcd_flat[0].cpu().numpy()
    # gripper_object_coord_cloud_o3d.points = o3d.utility.Vector3dVector(gripper_object_bbox_pcd_np)
    # gripper_object_coord_cloud_o3d.colors = o3d.utility.Vector3dVector(gripper_object_bbox_pcd_np)

    # Used for fixed axis
    fixed_axis_cloud_o3d = o3d.geometry.PointCloud()
    fixed_axis_pcd_np = np.array([[-2,-2,-2],
                                  [2, 2, 2]])
    fixed_axis_cloud_o3d.points = o3d.utility.Vector3dVector(fixed_axis_pcd_np)

    o3d.visualization.draw_plotly([point_cloud_o3d,
                                   gt_coord_o3d,
                                #    gripper_object_coord_cloud_o3d,
                                #    gripper_object_o3d,
                                   fixed_axis_cloud_o3d,
                                   ])

    # # camera position: 0.3732, -0.7225, 0.4119

The voxel visualization shows a grid of 100x100x100 = 1 million voxels and one expert keyframe action (blue voxel). PerAct is trained on samples like these: given a language goal and a voxel grid, we train a detector to predict the next best action with supervised learning.
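To make the mapping between metric coordinates and the voxel grid concrete, here is a minimal sketch of how a 3D point can be converted to a voxel index. It is an illustration under stated assumptions, not the `point_to_voxel_index` helper used above: the name `point_to_voxel_index_sketch` and the example bounds are hypothetical, and the bounds are assumed to be ordered `[x_min, y_min, z_min, x_max, y_max, z_max]`, as in `points_in_scene_bounds`.

```python
import math


def point_to_voxel_index_sketch(point, voxel_size, bounds):
    """Map a 3D point to integer voxel indices in a cubic grid.

    Hypothetical helper for illustration only. `bounds` is assumed to be
    [x_min, y_min, z_min, x_max, y_max, z_max].
    """
    mins, maxs = bounds[:3], bounds[3:]
    index = []
    for p, lo, hi in zip(point, mins, maxs):
        resolution = (hi - lo) / voxel_size            # edge length of one voxel
        i = math.floor((p - lo) / resolution)          # voxel the point falls into
        index.append(min(max(i, 0), voxel_size - 1))   # clamp to the grid
    return tuple(index)


# Made-up example: a 2 m cube split into 4 voxels per axis.
example_bounds = [0.0, 0.0, 0.0, 2.0, 2.0, 2.0]
print(point_to_voxel_index_sketch((0.6, 1.9, 0.0), 4, example_bounds))
```

Points outside the bounds are clamped to the nearest boundary voxel in this sketch; filtering them out beforehand is an equally reasonable choice.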

### PerceiverIO

Now we can start implementing the actual Transformer backbone of PerAct from **Section 3.3**.

The input grid is 100×100×100 = 1 million voxels. Even if we extract non-overlapping 5×5×5 patches, the input is still 20×20×20 = 8000 embeddings long. This sequence is far too long for a standard Transformer, whose self-attention cost grows as O(n^2) in the sequence length. So we use the [PerceiverIO architecture](https://arxiv.org/abs/2107.14795) instead.
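The arithmetic behind these numbers can be checked with a short sketch. The helper names `num_patches` and `self_attention_pairs` are hypothetical and only illustrate the counting; the patch size and stride of 5 match `voxel_patch_size` and `voxel_patch_stride` in the `PerceiverIO` call in this notebook.

```python
def num_patches(voxel_size, patch_size, stride):
    """Number of cubic patches extracted from a voxel_size^3 grid."""
    per_axis = (voxel_size - patch_size) // stride + 1
    return per_axis ** 3


def self_attention_pairs(seq_len):
    """Query-key pairs in one full self-attention layer."""
    return seq_len * seq_len


seq_len = num_patches(voxel_size=100, patch_size=5, stride=5)
print(seq_len)                        # 8000
print(self_attention_pairs(seq_len))  # 64000000
```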

Perceiver uses a small set of **latent vectors** to encode the input. These latent vectors are randomly initialized and trained end-to-end. This approach decouples the depth of the Transformer self-attention layers from the dimensionality of the input space, which allows us to train PerAct on very large input voxel grids. We can potentially scale the input to 200^3 voxels without increasing the number of self-attention layer parameters.

Refer to **Appendix B** in the paper for additional details.
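As a rough illustration of why the latent bottleneck helps, the sketch below counts attention pairs for a PerceiverIO-style layout: one cross-attention from the inputs to the latents, `depth` self-attention layers over the latents only, and one cross-attention back to the outputs. The function name and `ASSUMED_NUM_LATENTS = 512` are assumptions made for this example (the notebook's `NUM_LATENTS` is defined elsewhere), and the count ignores heads and feature dimensions.

```python
def perceiver_attention_pairs(seq_len, num_latents, depth):
    """Return (cross_attention_pairs, latent_self_attention_pairs)."""
    cross = 2 * seq_len * num_latents   # inputs -> latents, then latents -> outputs
    latent = depth * num_latents ** 2   # self-attention among the latents only
    return cross, latent


ASSUMED_NUM_LATENTS = 512  # placeholder value, not taken from this notebook
DEPTH = 6                  # matches depth=6 in the PerceiverIO call

for voxel_size in (100, 200):
    seq_len = (voxel_size // 5) ** 3  # non-overlapping 5x5x5 patches
    cross, latent = perceiver_attention_pairs(seq_len, ASSUMED_NUM_LATENTS, DEPTH)
    print(f"{voxel_size}^3 grid: seq_len={seq_len}, cross={cross}, latent={latent}")
```

Under these assumptions, the latent self-attention term is the same for both grid sizes; only the two cross-attention steps grow, and they grow linearly with the input length.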

<div>
<img src="https://peract.github.io/media/figures/perceiver.png" alt="drawing"  width="600"/>
</div>

In [None]:
# initialize PerceiverIO Transformer
from agent.perceiver_io import PerceiverIO


perceiver_encoder = PerceiverIO(
    depth=6,
    iterations=1,
    voxel_size=VOXEL_SIZES[0],
    initial_dim=3 + 3 + 1 + 3,
    low_dim_size=4,
    layer=0,
    num_rotation_classes=72,
    num_grip_classes=2,
    num_collision_classes=2,
    num_latents=NUM_LATENTS,
    latent_dim=512,
    cross_heads=1,
    latent_heads=8,
    cross_dim_head=64,
    latent_dim_head=64,
    weight_tie_layers=False,
    activation='lrelu',
    input_dropout=0.1,
    attn_dropout=0.1,
    decoder_dropout=0.0,
    voxel_patch_size=5,
    voxel_patch_stride=5,
    final_dim=64,
)

### Q Functions

Finally we put everything together to make PerAct's Q-Functions.  

This module voxelizes RGB-D input, encodes per-voxel features, and predicts discretized actions.  

<div>
<img src="https://peract.github.io/media/figures/arch.png" alt="drawing"  width="1000"/>
</div>
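To give a feel for what "discretized actions" means for rotation, here is a minimal sketch that bins an angle in degrees into one of 72 classes. It assumes the 72 classes per axis cover the full 360 degrees, which implies 5-degree bins; `ROTATION_RESOLUTION` is defined elsewhere in the notebook, so treat that value as an assumption. The helper names are hypothetical, and the actual conversion in the codebase may use a different angle offset or convention.

```python
def discretize_angle(angle_deg, resolution_deg):
    """Map an angle in degrees to a bin index in [0, 360 / resolution_deg)."""
    num_classes = int(360 // resolution_deg)
    # Wrap into [0, 360), round to the nearest bin, and wrap 360 back to bin 0.
    return int(round((angle_deg % 360) / resolution_deg)) % num_classes


def bin_to_angle(index, resolution_deg):
    """Map a bin index back to the angle (in degrees) at the bin centre."""
    return index * resolution_deg


ASSUMED_RESOLUTION = 5  # degrees; 360 / 72 classes, assumed for this sketch

for angle in (0, 12, -90, 359):
    idx = discretize_angle(angle, ASSUMED_RESOLUTION)
    print(angle, "->", idx, "->", bin_to_angle(idx, ASSUMED_RESOLUTION))
```

Each Euler axis gets its own 72-way classification head in this picture, so a predicted rotation is three bin indices rather than a continuous quaternion.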

### PerAct Agent

Let's initialize PerAct and define an update function for the training loop.

The keyframe actions used for supervision are represented as one-hot vectors. Then we use cross-entropy loss to train PerAct, just like a standard classifier. This training method is also closely related to [Energy-Based Models (EBMs)](https://arxiv.org/abs/2109.00137).
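The idea can be sketched for the translation head in plain Python: flatten the voxel grid into a single class axis, treat the expert voxel as the one-hot target, and apply softmax cross-entropy. The helper names and the tiny 2×2×2 grid are assumptions for illustration; the agent itself works on batched tensors.

```python
import math


def flatten_voxel_index(x, y, z, voxel_size):
    """Row-major index of voxel (x, y, z) in a flattened voxel_size^3 grid."""
    return (x * voxel_size + y) * voxel_size + z


def cross_entropy_one_hot(logits, target_index):
    """Cross-entropy between softmax(logits) and a one-hot target."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    return log_sum_exp - logits[target_index]


# Toy example: a 2x2x2 grid (8 classes) with the expert action at voxel (1, 0, 1).
voxel_size = 2
target = flatten_voxel_index(1, 0, 1, voxel_size)
q_values = [0.0] * voxel_size ** 3
q_values[target] = 3.0  # the network scores the expert voxel highest

print(target)
print(cross_entropy_one_hot(q_values, target))
```

With a one-hot target, the loss reduces to the negative log-probability of the expert voxel, so raising its score relative to all other voxels lowers the loss.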

In [None]:
# initialize PerceiverActor
from agent.peract_agent import PerceiverActorAgent


peract_agent = PerceiverActorAgent(
    coordinate_bounds=SCENE_BOUNDS,
    perceiver_encoder=perceiver_encoder,
    camera_names=CAMERAS,
    batch_size=BATCH_SIZE,
    voxel_size=VOXEL_SIZES[0],
    voxel_feature_size=3,
    num_rotation_classes=72,
    rotation_resolution=ROTATION_RESOLUTION,
    training_iterations = TRAINING_ITERATIONS,
    lr=LEARNING_RATE,
    lr_scheduler=LR_SCHEDULER,
    num_warmup_steps = NUM_WARMUP_STEPS,
    num_cycles = NUM_CYCLES,
    image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
    lambda_weight_l2=0.000001,
    transform_augmentation=TRANSFORM_AUGMENTATION,
    rgb_augmentation=RGB_AUGMENTATION,
    optimizer_type='lamb',
)
peract_agent.build(training=True, device=device)

### Training Loop

The final training loop samples data from the replay buffer and trains the agent with supervised learning. 2400 iterations should take about 130 minutes.

❗2400 iterations is probably not enough for robust performance, so you might see some weird predictions. Training for longer periods, particularly with data augmentation (see **Appendix E**), will improve performance. But we will stick with this for now to avoid Colab timeouts.
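The bookkeeping in the training cell (log every `LOG_FREQ` iterations, save a checkpoint only when the loss improves) can be summarized in isolation. This is a minimal sketch with a hypothetical `BestLossTracker` class and made-up loss values; it does not call the agent or write any files.

```python
class BestLossTracker:
    """Remember the lowest loss seen so far and report improvements."""

    def __init__(self):
        self.best = float("inf")

    def update(self, loss):
        """Return True if `loss` is a new best (i.e. a checkpoint should be saved)."""
        if loss < self.best:
            self.best = loss
            return True
        return False


SAVE_MODEL_FREQ = 50
fake_losses = {0: 9.0, 50: 7.5, 100: 8.1, 150: 6.9}  # made-up values per iteration

tracker = BestLossTracker()
saved_at = []
for iteration in range(200):
    if iteration % SAVE_MODEL_FREQ == 0 and tracker.update(fake_losses[iteration]):
        saved_at.append(iteration)  # a real loop would call save_weights(...) here

print(saved_at)  # [0, 50, 150]
```

Keeping separate trackers for the training and test losses mirrors the two "best model" directories created in the training cell.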

#### Using Tensorboard in PyTorch

Let’s now try using TensorBoard with PyTorch! Before logging anything, we need to create a `SummaryWriter` instance. By default (when `log_dir` is not set), the writer outputs to the `./runs/` directory.

In [None]:
# Creating TensorBoard log

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

In [None]:
# # Load the TensorBoard notebook extension

# # `%load_ext` is an IPython magic, so the same line works locally and in Colab
# %load_ext tensorboard

In [None]:
import time
import json
from datetime import datetime

LOG_FREQ = 10
SAVE_MODELS = True
SAVE_MODEL_FREQ = 50
calc_test_loss = True

# Misc
train_loss = 1e8
test_loss = 1e8

# Directories where to save the best models for train/test
model_run_time = datetime.now()
model_save_dir = os.path.join(WORKSPACE_DIR,"outputs", "models", TASK, model_run_time.strftime("%Y-%m-%d_%H-%M"))
model_save_dir_best_train = os.path.join(model_save_dir, "best_model_train")
model_save_dir_best_test = os.path.join(model_save_dir, "best_model_test")
metrics_save_path = os.path.join(model_save_dir, "training_metrics.json")  # JSON file to save metrics

# Create directories
if not os.path.exists(model_save_dir_best_train):
    print(f"Could not find {model_save_dir_best_train}, creating directory.")
    os.makedirs(model_save_dir_best_train)
if not os.path.exists(model_save_dir_best_test):
    print(f"Could not find {model_save_dir_best_test}, creating directory.")
    os.makedirs(model_save_dir_best_test)

start_time = time.time()

# Initialize metrics dictionary
metrics = {
    "train": [],
    "test": []
}

for iteration in range(TRAINING_ITERATIONS):
    batch = next(train_data_iter)
    batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
    update_dict = peract_agent.update(iteration, batch) # Here backprop == True (training step), hence total_loss is the training loss

    if iteration % LOG_FREQ == 0:
        elapsed_time = (time.time() - start_time) / 60.0

        # Log training metrics
        train_metrics = {
            "iteration": iteration,
            "learning_rate": update_dict['learning_rate'],
            "total_loss": update_dict['total_loss'],
            "trans_loss": update_dict['trans_loss'],
            "rot_loss": update_dict['rot_loss'],
            "col_loss": update_dict['col_loss'],
            "elapsed_time": elapsed_time
        }
        metrics["train"].append(train_metrics)

        writer.add_scalar("Learning Rate", update_dict['learning_rate'], iteration)
            
        writer.add_scalar("Loss/train", update_dict['total_loss'], iteration)
        writer.add_scalar("Trans Loss/train", update_dict['trans_loss'], iteration)
        writer.add_scalar("Rot Loss/train", update_dict['rot_loss'], iteration)
        writer.add_scalar("Collision Loss/train", update_dict['col_loss'], iteration)

        if calc_test_loss:
            batch = next(test_data_iter)
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            test_update_dict = peract_agent.update(iteration, batch, backprop=False) # Here backprop == False: for evaluation, hence test_loss == total_loss

            # Log test metrics
            test_metrics = {
                "iteration": iteration,
                "total_loss": test_update_dict['total_loss'],
                "trans_loss": test_update_dict['trans_loss'],
                "rot_loss": test_update_dict['rot_loss'],
                "col_loss": test_update_dict['col_loss']
            }
            metrics["test"].append(test_metrics)
            
            writer.add_scalar("Loss/test", test_update_dict['total_loss'], iteration)
            writer.add_scalar("Trans Loss/test", test_update_dict['trans_loss'], iteration)
            writer.add_scalar("Rot Loss/test", test_update_dict['rot_loss'], iteration)
            writer.add_scalar("Collision Loss/test", test_update_dict['col_loss'], iteration)

            print("Iteration: %d/%d | Learning Rate: %f | Train Loss [tot,trans,rot,col]: [%0.2f, %0.2f, %0.2f, %0.2f] | Test Loss [tot,trans,rot,col]: [%0.2f, %0.2f, %0.2f, %0.2f] | Elapsed Time: %0.2f mins"\
                   % (iteration, TRAINING_ITERATIONS, 
                      update_dict['learning_rate'],
                      update_dict['total_loss'], update_dict['trans_loss'], update_dict['rot_loss'], update_dict['col_loss'],
                      test_update_dict['total_loss'], test_update_dict['trans_loss'], test_update_dict['rot_loss'], test_update_dict['col_loss'], 
                      elapsed_time))
        else:
            print("Iteration: %d/%d | Learning Rate: %f| Train Loss [tot,trans,rot,col]: [%0.2f, %0.2f, %0.2f, %0.2f] | Elapsed Time: %0.2f mins"\
                   % (iteration, TRAINING_ITERATIONS, 
                      update_dict['learning_rate'],
                      update_dict['total_loss'], update_dict['trans_loss'], update_dict['rot_loss'], update_dict['col_loss'],
                      elapsed_time))
            
    if SAVE_MODELS and iteration % SAVE_MODEL_FREQ == 0:

        # Only save a checkpoint when the loss improves on the best seen so far
        if update_dict['total_loss'] < train_loss:
            print("Saving Best Model - Train")
            train_loss = update_dict['total_loss']
            peract_agent.save_weights(model_save_dir_best_train)
        # test_update_dict only exists when calc_test_loss is enabled; it is fresh here
        # as long as SAVE_MODEL_FREQ is a multiple of LOG_FREQ
        if calc_test_loss and test_update_dict['total_loss'] < test_loss:
            print("Saving Best Model - Test")
            test_loss = test_update_dict['total_loss']
            peract_agent.save_weights(model_save_dir_best_test)

# Save the last checkpoint
model_save_dir_last = os.path.join(model_save_dir, "last")

if not os.path.exists(model_save_dir_last):
    print(f"Could not find {model_save_dir_last}, creating directory.")
    os.makedirs(model_save_dir_last)

print("Saving Model - Last")
peract_agent.save_weights(model_save_dir_last)

# Save metrics to JSON file
with open(metrics_save_path, 'w') as f:
    json.dump(metrics, f, indent=4)

print(f"Training metrics saved to {metrics_save_path}")

writer.flush()

In [None]:
# # Plot training and test (val) losses

# if not colab:
#     %tensorboard --logdir=runs
# else:
#     !tensorboard --logdir=runs

## Inference and Visualization

Let's see how PerAct does on held-out test data.  

PerAct should be evaluated in simulation on scenes with randomized object poses and object instances. But this Colab notebook doesn't support the V-REP simulator (for RLBench tasks). So for now we will do inference on a static test dataset.

In [None]:
# from arm.utils import visualise_voxel, discrete_euler_to_quaternion, get_gripper_render_pose
# from scipy.spatial.transform import Rotation as R

# # create_another_test_set = True
# create_another_test_set = False


# if create_another_test_set:
#     clip_model, preprocess = clip.load("RN50", device=device) # CLIP-ResNet50
#     test_task = 'go_to_apple_test'
#     test_task_folder = f'{test_task}/all_variations/episodes'
#     data_path = os.path.join(DATA_FOLDER, test_task_folder)

#     test_replay_buffer = create_replay(batch_size=BATCH_SIZE,
#                                     timesteps=1,
#                                     save_dir=test_replay_storage_dir,
#                                     cameras=CAMERAS,
#                                     voxel_sizes=VOXEL_SIZES)

#     print("-- Test Buffer --")
#     fill_replay(
#                 data_path=data_path,
#                 replay=test_replay_buffer,
# #                 start_idx=0,
# #                 num_demos=4,
#                 demo_augmentation=True,
#                 demo_augmentation_every_n=DEMO_AUGMENTATION_EVERY_N,
#                 cameras=CAMERAS,
#                 rlbench_scene_bounds=SCENE_BOUNDS,
#                 voxel_sizes=VOXEL_SIZES,
#                 rotation_resolution=ROTATION_RESOLUTION,
#                 crop_augmentation=False,
#                 clip_model=clip_model,
#                 device=device)
#     del clip_model

#     test_wrapped_replay = PyTorchReplayBuffer(test_replay_buffer)
#     test_dataset = test_wrapped_replay.dataset()
#     test_data_iter = iter(test_dataset)


# batch = next(test_data_iter) #Change to make it either 
# lang_goal = batch['lang_goal'][0][0][0]
# print(lang_goal)
# batch = {k: v.to(device) for k, v in batch.items() if type(v) == torch.Tensor}
# update_dict = peract_agent.update(iteration, batch, backprop=False)

# #things to print
# loss = update_dict['total_loss']
# rot_loss = update_dict['rot_loss']#.detach().cpu().numpy()
# trans_loss = update_dict['trans_loss']#.detach().cpu().numpy()
# col_loss = update_dict['col_loss']#.detach().cpu().numpy()
# print('The loss of this prediction is: ', loss)
# print('The rotational loss of this prediction is:', rot_loss)
# print('The translational loss of this prediction is:', trans_loss)
# print('The collision loss of this prediction is:', col_loss)

# # things to visualize
# vis_voxel_grid = update_dict['voxel_grid'][0].detach().cpu().numpy()
# vis_trans_q = update_dict['q_trans'][0].detach().cpu().numpy()
# vis_trans_coord = update_dict['pred_action']['trans'][0].detach().cpu().numpy()
# vis_gt_coord = update_dict['expert_action']['action_trans'][0].detach().cpu().numpy()

# # discrete to continuous
# continuous_trans = update_dict['pred_action']['continuous_trans'][0].detach().cpu().numpy()
# continuous_quat = discrete_euler_to_quaternion(update_dict['pred_action']['rot_and_grip'][0][:3].detach().cpu().numpy(),
#                                                resolution=peract_agent._rotation_resolution)
# gripper_open = bool(update_dict['pred_action']['rot_and_grip'][0][-1].detach().cpu().numpy())
# ignore_collision = bool(update_dict['pred_action']['collision'][0][0].detach().cpu().numpy())

# # gripper visualization pose
# voxel_size = 0.045
# voxel_scale = voxel_size * 100
# gripper_pose_mat = get_gripper_render_pose(voxel_scale,
#                                            SCENE_BOUNDS[:3],
#                                            continuous_trans,
#                                            continuous_quat)

In [None]:
# if not colab:
#     #@markdown #### Show Q-Prediction and Best Action
#     show_expert_action = True  #@param {type:"boolean"}
#     show_q_values = True  #@param {type:"boolean"}
#     render_gripper = True  #@param {type:"boolean"}
#     rotation_amount = -90 #@param {type:"slider", min:-180, max:180, step:5}

#     rendered_img_0 = visualise_voxel(vis_voxel_grid,
#                                 vis_trans_q if show_q_values else None,
#                                 vis_trans_coord,
#                                 vis_gt_coord if show_expert_action else None,
#                                 voxel_size=voxel_size,
#                                 rotation_amount=np.deg2rad(0),
#                                 render_gripper=render_gripper,
#                                 gripper_pose=gripper_pose_mat,
#                                 gripper_mesh_scale=voxel_scale)

#     rendered_img_270 = visualise_voxel(vis_voxel_grid,
#                                 vis_trans_q if show_q_values else None,
#                                 vis_trans_coord,
#                                 vis_gt_coord if show_expert_action else None,
#                                 voxel_size=voxel_size,
#                                 rotation_amount=np.deg2rad(-90),
#                                 render_gripper=render_gripper,
#                                 gripper_pose=gripper_pose_mat,
#                                 gripper_mesh_scale=voxel_scale)


#     fig = plt.figure(figsize=(20, 15))
#     fig.add_subplot(1, 2, 1)
#     plt.imshow(rendered_img_0)
#     plt.axis('off')
#     plt.title("Front view")
#     fig.add_subplot(1, 2, 2)
#     plt.imshow(rendered_img_270)
#     plt.axis('off')
#     plt.title("Side view")

#     # print(f"Lang goal: {lang_goal}")

In [None]:
# # Save the trained PerAct model
# model_save_dir = os.path.join(WORKSPACE_DIR,"outputs", "models")

# if not os.path.exists(model_save_dir):
#     print(f"Could not find {model_save_dir}, creating directory.")
#     os.mkdir(model_save_dir)

# peract_agent.save_weights(model_save_dir)