In [None]:
import numpy as np
np.bool = np.bool_ # bad trick to fix numpy version issue :(
import os
import sys
from natsort import natsorted

# Remove every sys.path entry containing '/peract/' (presumably an original PerAct checkout
# that would otherwise shadow the peract_colab modules -- confirm for your setup).
sys.path = list(filter(lambda entry: '/peract/' not in entry, sys.path))

# Rendering / GPU configuration; `PYOPENGL_PLATFORM=egl` is set for pyrender visualizations.
os.environ.update({
    "DISPLAY": ":0",
    "PYOPENGL_PLATFORM": "egl",
    "CUDA_VISIBLE_DEVICES": "0,3",  # Depends on your computer and available GPUs
})

In [None]:
### Depending on your workspace, you may already have this repository installed; otherwise, clone it again.
# Clone the repository only if no `peract_colab` folder exists in the current working directory.
if not os.path.exists(os.path.join(os.getcwd(), 'peract_colab')):
    !git clone https://github.com/yuki1003/peract_colab.git

# Always pull the latest `master` into the local checkout (also runs right after a fresh clone).
!cd peract_colab && git pull origin master

In [None]:
# Probe whether the peract_colab modules are importable; if they are not,
# append the local repository folder to the import path.
try:
    from rlbench.utils import get_stored_demo  # imported only to test availability
except ImportError as import_failure:
    print(import_failure)
    print("Adding peract_colab repository to system path.")
    sys.path.append('peract_colab')

In [None]:
import shutil

import torch
import clip

from arm.replay_buffer import create_replay, fill_replay, uniform_fill_replay, fill_replay_copy_with_crop_from_approach, fill_replay_only_approach_test
from yarr.replay_buffer.wrappers.pytorch_replay_buffer import PyTorchReplayBuffer

from agent.perceiver_io import PerceiverIO
from agent.peract_agent import PerceiverActorAgent
import time
import json
from datetime import datetime

In [None]:
## STATIC VALUES USED IN BELOW FUNCTION: SETTING THEM AS GLOBAL FOR FURTHER USE

#___DATA___
TASK = 'handing_over_banana'

# Data Constants
WORKSPACE_DIR = os.getcwd()
DATA_FOLDER = os.path.join(WORKSPACE_DIR, "task_data", "handoversim")
EPISODES_FOLDER = os.path.join(TASK, "all_variations", "episodes")

EPISODE_FOLDER = 'episode%d'  # name pattern of a single episode folder
SETUP = "s1" # Options: "s1"


def _list_episode_indexes(episodes_path, prefix="episode"):
    """Return the sorted integer indices of all `<prefix><N>` entries in `episodes_path`.

    Entries that do not match the pattern (e.g. stray files such as `.DS_Store`) are
    skipped instead of crashing the `int()` conversion.

    Raises:
        FileNotFoundError: if `episodes_path` is missing or contains no matching episodes,
            so the problem surfaces here rather than deep inside replay-buffer filling.
    """
    indexes = []
    for entry in os.listdir(episodes_path):
        suffix = entry[len(prefix):]
        if entry.startswith(prefix) and suffix.isdecimal():
            indexes.append(int(suffix))
    if not indexes:
        raise FileNotFoundError(f"No '{prefix}<N>' folders found in {episodes_path}")
    return sorted(indexes)


train_data_path = os.path.join(DATA_FOLDER, f"train_{SETUP}", EPISODES_FOLDER)
TRAIN_INDEXES = _list_episode_indexes(train_data_path)
test_data_path = os.path.join(DATA_FOLDER, f"val_{SETUP}", EPISODES_FOLDER)
TEST_INDEXES = _list_episode_indexes(test_data_path)

print(f"TRAIN | Total #: {len(TRAIN_INDEXES)}, indices: {TRAIN_INDEXES}")
print(f"TEST | Total #: {len(TEST_INDEXES)}, indices: {TEST_INDEXES}")

# Replaybuffer related constants
LOW_DIM_SIZE = 4    # 4 dimensions - proprioception: {gripper_open, left_finger_joint, right_finger_joint, timestep}
IMAGE_SIZE = 128  # 128x128 - if you want to use higher voxel resolutions like 200^3, you might want to regenerate the dataset with larger images
# DEMO_AUGMENTATION_EVERY_N = 5#10 NOTE CHANGED through setting # Only select every n-th frame to use for replaybuffer from demo
ROTATION_RESOLUTION = 5 # degree increments per axis
TARGET_OBJ_KEYPOINTS=False # Real - (changed later)
TARGET_OBJ_USE_LAST_KP=False # Real - (changed later)
TARGET_OBJ_IS_AVAIL = True # HandoverSim - (changed later)

DEPTH_SCALE = 1000
STOPPING_DELTA = 0.001
SCENE_BOUNDS = [0.11, -0.5, 0.8, 1.11, 0.5, 1.8]  # [x_min, y_min, z_min, x_max, y_max, z_max]; must span 1m per axis

# Training Settings Constants
BATCH_SIZE = 2
TRAINING_ITERATIONS = 10000
LEARNING_RATE = 0.001
TRANSFORM_AUGMENTATION = True


# Model / scheduler settings (all passed on by `train_peract_agent`)
LR_SCHEDULER = False  # NOTE(review): NUM_WARMUP_STEPS / NUM_CYCLES presumably only matter when this is True
NUM_WARMUP_STEPS = 300  # LR_SCHEDULER: losses seem to stabilize after ~300 iterations
NUM_CYCLES = 1  # As per: https://github.com/peract/peract/blob/02fb87681c5a47be9dbf20141bedb836ee2d3ef9/agents/peract_bc/qattention_peract_bc_agent.py#L232
VOXEL_SIZES = [100]  # 100x100x100 voxels
NUM_LATENTS = 512  # PerceiverIO latents: lower-dimension features of input data

In [None]:
def _ensure_directory(path):
    """Create directory `path` (including missing parents) if it does not exist yet, logging the creation."""
    if not os.path.exists(path):
        print(f"Could not find {path}, creating directory.")
        os.makedirs(path)


def _reset_directory(path):
    """Delete `path` and everything below it (if present), then recreate it empty."""
    if os.path.exists(path):
        print(f"Emptying {path}")
        shutil.rmtree(path)
    _ensure_directory(path)


def _select_fill_replay_function(fill_replay_setting):
    """Return the replay-filling routine selected by `fill_replay_setting` (case-insensitive).

    Raises:
        ValueError: if the setting is not one of "uniform", "crop" or "standard".
    """
    fill_functions = {
        "uniform": uniform_fill_replay,
        "crop": fill_replay_copy_with_crop_from_approach,
        "standard": fill_replay_only_approach_test,  # alternative: fill_replay
    }
    fill_function = fill_functions.get(fill_replay_setting.lower())
    if fill_function is None:
        raise ValueError(f"Unknown setting for fill replay buffer: {fill_replay_setting!r}")
    return fill_function


def _fill_replay_buffer(fill_function, replay, data_path, indexes, cameras,
                        demo_augmentation_every_n, use_approach, clip_model, device):
    """Fill `replay` with the demos `indexes` stored under `data_path` using `fill_function`.

    The train and test buffers are filled with identical options; only the buffer, the data
    path and the episode indices differ. Every other option comes from the module-level constants.
    """
    fill_function(
        data_path=data_path,
        episode_folder=EPISODE_FOLDER,
        replay=replay,
        d_indexes=indexes,
        demo_augmentation=True,
        demo_augmentation_every_n=demo_augmentation_every_n,
        cameras=cameras,
        rlbench_scene_bounds=SCENE_BOUNDS,
        voxel_sizes=VOXEL_SIZES,
        rotation_resolution=ROTATION_RESOLUTION,
        crop_augmentation=False,
        depth_scale=DEPTH_SCALE,
        use_approach=use_approach,
        approach_distance=0.3,
        stopping_delta=STOPPING_DELTA,
        target_obj_keypoint=TARGET_OBJ_KEYPOINTS,
        target_obj_use_last_kp=TARGET_OBJ_USE_LAST_KP,
        target_obj_is_avail=TARGET_OBJ_IS_AVAIL,
        clip_model=clip_model,
        device=device,
    )


def _batch_to_device(batch, device):
    """Move every tensor in `batch` to `device`; non-tensor entries are dropped."""
    return {key: value.to(device) for key, value in batch.items() if isinstance(value, torch.Tensor)}


def train_peract_agent(settings):
    """Fill the train/test replay buffers, train a PerceiverActor agent and save checkpoints + metrics.

    Args:
        settings: run configuration dict. Keys read here:
            'fill_replay_setting' ("uniform" | "crop" | "standard", case-insensitive),
            'cameras' (list of camera names), 'RGB_AUGMENTATION' (forwarded to the agent),
            'keypoint_approach' (forwarded as `use_approach`) and 'demo_augm_n'
            (forwarded as `demo_augmentation_every_n`). The whole dict is written to
            `training_settings.json` at the end of the run.
            NOTE(review): 'only_learn_approach' is stored in that JSON but not read here.

    Raises:
        ValueError: if `settings['fill_replay_setting']` is not a known option (checked
            before any directory is wiped or CLIP is loaded).
    """
    # RUN SETTINGS
    FILL_REPLAY_SETTING = settings['fill_replay_setting']
    CAMERAS = settings['cameras']
    RGB_AUGMENTATION = settings['RGB_AUGMENTATION']
    USE_APPROACH = settings['keypoint_approach']
    DEMO_AUGMENTATION_EVERY_N = settings['demo_augm_n']

    # Fail fast on an unknown fill setting, before wiping replay directories or loading CLIP.
    fill_function = _select_fill_replay_function(FILL_REPLAY_SETTING)

    # Summary of run properties
    print("\nExperiment Setup")
    print(f"Task: {TASK} - SETUP: {SETUP} - Cameras: {len(CAMERAS)}")
    print("Run Properties")
    print(f"TrAugm: {TRANSFORM_AUGMENTATION} - RGBAugm: {RGB_AUGMENTATION} - FillReplay: {FILL_REPLAY_SETTING}")

    #___REPLAY-BUFFER___
    # Both replay buffers are rebuilt from scratch for every run.
    train_replay_storage_dir = os.path.join(WORKSPACE_DIR, 'replay_train')
    test_replay_storage_dir = os.path.join(WORKSPACE_DIR, 'replay_test')
    _reset_directory(train_replay_storage_dir)
    _reset_directory(test_replay_storage_dir)

    replay_options = dict(batch_size=BATCH_SIZE,
                          timesteps=1,
                          cameras=CAMERAS,
                          voxel_sizes=VOXEL_SIZES,
                          image_size=IMAGE_SIZE,
                          low_dim_size=LOW_DIM_SIZE)
    train_replay_buffer = create_replay(save_dir=train_replay_storage_dir, **replay_options)
    test_replay_buffer = create_replay(save_dir=test_replay_storage_dir, **replay_options)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, _ = clip.load("RN50", device=device)  # CLIP-ResNet50, passed to the fill routines

    print("-- Train Buffer --")
    _fill_replay_buffer(fill_function, train_replay_buffer, train_data_path, TRAIN_INDEXES,
                        CAMERAS, DEMO_AUGMENTATION_EVERY_N, USE_APPROACH, clip_model, device)
    print("-- Test Buffer --")
    _fill_replay_buffer(fill_function, test_replay_buffer, test_data_path, TEST_INDEXES,
                        CAMERAS, DEMO_AUGMENTATION_EVERY_N, USE_APPROACH, clip_model, device)

    # delete the CLIP model since we have already extracted language features
    del clip_model
    if device == "cuda":
        torch.cuda.empty_cache()

    # wrap buffers with PyTorch datasets and make iterators
    train_wrapped_replay = PyTorchReplayBuffer(train_replay_buffer)
    train_dataset = train_wrapped_replay.dataset()
    train_data_iter = iter(train_dataset)

    test_wrapped_replay = PyTorchReplayBuffer(test_replay_buffer)
    test_dataset = test_wrapped_replay.dataset()
    test_data_iter = iter(test_dataset)

    #___AGENT___
    # One rotation class per ROTATION_RESOLUTION-degree bin (72 for 5-degree increments).
    num_rotation_classes = int(360 // ROTATION_RESOLUTION)

    perceiver_encoder = PerceiverIO(
        depth=6,
        iterations=1,
        voxel_size=VOXEL_SIZES[0],
        initial_dim=3 + 3 + 1 + 3,  # NOTE(review): presumably rgb + xyz + occupancy + voxel position -- confirm in PerceiverIO
        low_dim_size=LOW_DIM_SIZE,
        layer=0,
        num_rotation_classes=num_rotation_classes,
        num_grip_classes=2,
        num_collision_classes=2,
        num_latents=NUM_LATENTS,
        latent_dim=512,
        cross_heads=1,
        latent_heads=8,
        cross_dim_head=64,
        latent_dim_head=64,
        weight_tie_layers=False,
        activation='lrelu',
        input_dropout=0.1,
        attn_dropout=0.1,
        decoder_dropout=0.0,
        voxel_patch_size=5,
        voxel_patch_stride=5,
        final_dim=64,
    )

    peract_agent = PerceiverActorAgent(
        coordinate_bounds=SCENE_BOUNDS,
        perceiver_encoder=perceiver_encoder,
        camera_names=CAMERAS,
        batch_size=BATCH_SIZE,
        voxel_size=VOXEL_SIZES[0],
        voxel_feature_size=3,
        num_rotation_classes=num_rotation_classes,
        rotation_resolution=ROTATION_RESOLUTION,
        training_iterations=TRAINING_ITERATIONS,
        lr=LEARNING_RATE,
        lr_scheduler=LR_SCHEDULER,
        num_warmup_steps=NUM_WARMUP_STEPS,
        num_cycles=NUM_CYCLES,
        image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
        lambda_weight_l2=0.000001,
        transform_augmentation=TRANSFORM_AUGMENTATION,
        rgb_augmentation=RGB_AUGMENTATION,
        optimizer_type='lamb',
    )
    peract_agent.build(training=True, device=device)

    #___TRAINING___

    LOCAL_FREQ = 10     # log / evaluate one test batch / update best models every LOCAL_FREQ iterations
    SAVE_MODELS = True
    GLOBAL_FREQ = 1000  # checkpoints are grouped into one "run<N>" folder per GLOBAL_FREQ iterations
    calc_test_loss = True

    # Best losses seen so far over the whole training run (not reset per GLOBAL_FREQ window)
    train_loss = 1e8
    test_loss = 1e8
    general_loss = [1e8, 1e8]  # [train, test] total loss of the best "general" model

    # Directories where to save the best models for train/test
    model_run_time = datetime.now()
    model_save_dir = os.path.join(WORKSPACE_DIR, "outputs", "models", TASK, model_run_time.strftime("%Y-%m-%d_%H-%M"))

    model_save_dir_iter = os.path.join(model_save_dir, "run%d")

    model_save_dir_best_general_iter = os.path.join(model_save_dir_iter, "best_model_general")
    model_save_dir_best_train_iter = os.path.join(model_save_dir_iter, "best_model_train")
    model_save_dir_best_test_iter = os.path.join(model_save_dir_iter, "best_model_test")
    model_save_dir_last_iter = os.path.join(model_save_dir_iter, "last_model")
    metrics_save_path_iter = os.path.join(model_save_dir_iter, "training_metrics.json")  # JSON file to save metrics

    metrics_save_path = os.path.join(model_save_dir, "training_metrics.json")  # JSON file to save metrics
    settings_save_path = os.path.join(model_save_dir, "training_settings.json")

    start_time = time.time()

    # Initialize metrics dictionary
    metrics = {
        "train": [],
        "test": []
    }

    for iteration in range(TRAINING_ITERATIONS):

        batch = _batch_to_device(next(train_data_iter), device)
        update_dict = peract_agent.update(iteration, batch)  # backprop enabled: training step, so total_loss is the training loss

        if iteration % LOCAL_FREQ != 0:
            continue

        elapsed_time = (time.time() - start_time) / 60.0

        # Log training metrics
        metrics["train"].append({
            "iteration": iteration,
            "learning_rate": update_dict['learning_rate'],
            "total_loss": update_dict['total_loss'],
            "trans_loss": update_dict['trans_loss'],
            "rot_loss": update_dict['rot_loss'],
            "col_loss": update_dict['col_loss'],
            "elapsed_time": elapsed_time
        })
        log_line = ("Iteration: %d/%d | Learning Rate: %f | Train Loss [tot,trans,rot,col]: [%0.2f, %0.2f, %0.2f, %0.2f]"
                    % (iteration, TRAINING_ITERATIONS,
                       update_dict['learning_rate'],
                       update_dict['total_loss'], update_dict['trans_loss'], update_dict['rot_loss'], update_dict['col_loss']))

        # Evaluate one test batch without backprop; stays None when test losses are disabled.
        test_update_dict = None
        if calc_test_loss:
            batch = _batch_to_device(next(test_data_iter), device)
            test_update_dict = peract_agent.update(iteration, batch, backprop=False)

            # Log test metrics
            metrics["test"].append({
                "iteration": iteration,
                "total_loss": test_update_dict['total_loss'],
                "trans_loss": test_update_dict['trans_loss'],
                "rot_loss": test_update_dict['rot_loss'],
                "col_loss": test_update_dict['col_loss']
            })
            log_line += (" | Test Loss [tot,trans,rot,col]: [%0.2f, %0.2f, %0.2f, %0.2f]"
                         % (test_update_dict['total_loss'], test_update_dict['trans_loss'],
                            test_update_dict['rot_loss'], test_update_dict['col_loss']))

        print(log_line + (" | Elapsed Time: %0.2f mins" % elapsed_time))

        if not SAVE_MODELS:
            continue

        window_start = (iteration // GLOBAL_FREQ) * GLOBAL_FREQ
        best_general_dir = model_save_dir_best_general_iter % window_start
        best_train_dir = model_save_dir_best_train_iter % window_start
        best_test_dir = model_save_dir_best_test_iter % window_start
        for directory in (best_general_dir, best_train_dir, best_test_dir):
            _ensure_directory(directory)

        # Only save when the loss beats the best seen so far in the whole run, so a window's
        # best_model_* folder stays empty if nothing improved during that window.
        train_total_loss = update_dict['total_loss']
        test_total_loss = None if test_update_dict is None else test_update_dict['total_loss']
        if (test_total_loss is not None
                and train_total_loss < general_loss[0] and test_total_loss < general_loss[1]):
            print("Saving Best Model - General")
            general_loss = [train_total_loss, test_total_loss]
            peract_agent.save_weights(best_general_dir)
        if train_total_loss < train_loss:
            print("Saving Best Model - Train")
            train_loss = train_total_loss
            peract_agent.save_weights(best_train_dir)
        if test_total_loss is not None and test_total_loss < test_loss:
            print("Saving Best Model - Test")
            test_loss = test_total_loss
            peract_agent.save_weights(best_test_dir)

        # With GLOBAL_FREQ a multiple of LOCAL_FREQ this is the last logged iteration of the
        # current window: store the latest weights and the metrics collected so far.
        if (iteration + LOCAL_FREQ) % GLOBAL_FREQ == 0:
            last_model_dir = model_save_dir_last_iter % window_start
            _ensure_directory(last_model_dir)

            print(f"Saving Model - Last stage: {window_start}")
            peract_agent.save_weights(last_model_dir)

            # Save metrics to JSON file
            with open(metrics_save_path_iter % window_start, 'w') as f:
                json.dump(metrics, f, indent=4, cls=NumpyEncoder)

    # model_save_dir is otherwise only created as a side effect of checkpointing (not at all
    # with SAVE_MODELS=False), so make sure it exists before writing the run-level JSON files.
    os.makedirs(model_save_dir, exist_ok=True)

    # Save metrics to JSON file
    with open(metrics_save_path, 'w') as f:
        json.dump(metrics, f, indent=4, cls=NumpyEncoder)

    # Save training settings to JSON file
    with open(settings_save_path, 'w') as f:
        json.dump(settings, f, indent=4, cls=NumpyEncoder)

    print(f"Training metrics and settings saved to {metrics_save_path} and {settings_save_path}")

    # Drop the model references and return cached CUDA memory before the next grid-search run.
    del peract_agent, perceiver_encoder
    if device == "cuda":
        torch.cuda.empty_cache()

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes NumPy arrays and NumPy scalar values.

    Arrays become (nested) lists and NumPy scalars such as ``np.float32`` or ``np.int64``
    become the equivalent Python number; the stdlib encoder rejects both. Any other
    unsupported object still raises ``TypeError`` via the base class.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalar (e.g. a float32 loss value) -> plain Python scalar
            return obj.item()
        return super().default(obj)

In [None]:
import itertools

available_cameras = [f"view_{camera_i}" for camera_i in range(3)]

# Grid search: each key maps to the list of values to try; one training run per combination.
grid = {
    'fill_replay_setting': ["standard"],  # other options: "uniform", "crop"
    'cameras': [available_cameras],  # e.g. add [available_cameras[0]] for a single-camera run
    'RGB_AUGMENTATION': ['None', 'partial', 'full'],
    'demo_augm_n': [5],
    'keypoint_approach': [True],  # options: True, False
    'only_learn_approach': [True],
}

# Expand the grid into one settings dict per combination (cartesian product of the value lists).
lst_settings = [dict(zip(grid.keys(), values)) for values in itertools.product(*grid.values())]
for counter, settings in enumerate(lst_settings):
    print(counter, settings)

In [None]:
# Run every grid-search combination sequentially: one training run per settings dict.
for settings_for_run in lst_settings:
    print(settings_for_run)
    train_peract_agent(settings_for_run)

In [None]:
## For single (test) run
# available_cameras = [f"view_{camera_i}" for camera_i in range(3)]
# run_settings = {
#     'fill_replay_setting': "crop",
#     'cameras': available_cameras,
#     'RGB_AUGMENTATION': 'partial',
#     'keypoint_approach': True,
#     'demo_augm_n': 5,
#     'only_learn_approach': True
# }
# print(run_settings)

# train_peract_agent(run_settings)