# Inference with a Trained PerAct Agent on HandoverSim Data

In [None]:
import numpy as np
# HACK: some dependencies still reference `np.bool`, an alias that newer NumPy
# versions dropped; point it at `np.bool_` so they keep working.
np.bool = np.bool_
import os
import sys

# Drop every import-path entry containing '/peract/' (presumably so another
# PerAct checkout cannot shadow the packages imported below; confirm).
sys.path = [entry for entry in sys.path if '/peract/' not in entry]

# Rendering / GPU environment:
#   PYOPENGL_PLATFORM=egl is needed for the pyrender visualisations;
#   CUDA_VISIBLE_DEVICES depends on your computer and available GPUs.
os.environ.update({
    "DISPLAY": ":0",
    "PYOPENGL_PLATFORM": "egl",
    "CUDA_VISIBLE_DEVICES": "0,3",
})

In [None]:
import json
import os

from notebook_helpers.constants import * # Load global constant variables from constants.py
from notebook_helpers.build_training import build_agent

# Choose the run
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-10_14-59" # Good {non-uniform 1 kp}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_12-30" # Bad {uniform 2 kp}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_12-30" # Bad
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_13-34"
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-12_17-25" # {crop skip 10}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-12_21-23"
run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-16_11-25/" # {crop skip 10 new}

# Checked before any file I/O so a misconfigured constants.py fails immediately.
if BATCH_SIZE != 1:
    raise ValueError("For Inference, 'batch_size' must be set to 1 in constants.py.")

# Obtain settings. Join onto run_dir itself: os.path.dirname(run_dir) only yields
# the run folder when run_dir ends with "/", and would point one level too high
# for the slash-less alternatives listed above.
path_settings = os.path.join(run_dir, "training_settings.json")
with open(path_settings, 'r') as f:
    settings = json.load(f)
CAMERAS = settings['cameras']

peract_agent = build_agent(settings, training=False)
peract_agent.set_language_goal("handing over banana")

# Choose model: <run_dir>/<iteration>/<best_type>
iteration = "run8000"
best_type = "best_model_general"
model_path = os.path.join(run_dir, iteration, best_type)

peract_agent.load_weights(model_path)

### Inference: Run inference on observation data only and save the result as a video

In [None]:
import imageio
import numpy as np
from matplotlib import pyplot as plt
from rlbench.utils import get_stored_demo
from rlbench.backend.utils import extract_obs
from arm.utils import visualise_voxel
from arm.utils import get_gripper_render_pose


# What to visualize
episode_idx_to_visualize = 646  # Index of the episode to visualize
frame_step = 5  # Render every `frame_step`-th timestep
# Video output path
video_output_path = f"demo{episode_idx_to_visualize}_visualization.mp4"

# Voxel / gripper-mesh rendering parameters (loop-invariant)
voxel_size = 0.045
voxel_scale = voxel_size * 100

# Get demo
demo = get_stored_demo(data_path=test_data_path,
                       index=episode_idx_to_visualize,
                       cameras=CAMERAS,
                       depth_scale=DEPTH_SCALE)

num_timesteps = len(demo._observations)

# Open a video writer; one frame is appended per rendered timestep.
with imageio.get_writer(video_output_path, fps=10) as video_writer:
    for ts in range(0, num_timesteps, frame_step):
        print(ts)
        # Extract obs at timestep and add the gripper state recorded in the demo
        obs_dict = extract_obs(demo._observations[ts], CAMERAS, t=ts)
        obs_dict["gripper_open"] = demo[ts].gripper_open
        obs_dict["gripper_joint_positions"] = demo[ts].gripper_joint_positions

        # Both returned tuples carry a gripper-open entry; bind the second one to
        # `_` so it cannot silently overwrite the first (which is printed below).
        (continuous_trans, continuous_quat, gripper_open, _, _), \
        (voxel_grid, coord_indices, rot_and_grip_indices, _) = peract_agent.forward(obs_dict, ts)
        print(continuous_trans, continuous_quat, gripper_open)

        # Things to visualize
        vis_voxel_grid = voxel_grid[0].detach().cpu().numpy()
        pred_trans_coord = coord_indices[0].detach().cpu().numpy().tolist()

        gripper_pose_mat = get_gripper_render_pose(voxel_scale,
                                                   SCENE_BOUNDS[:3],
                                                   continuous_trans,
                                                   continuous_quat)

        rendered_img_0 = visualise_voxel(vis_voxel_grid,
                                         None,
                                         [pred_trans_coord],
                                         None,
                                         voxel_size=voxel_size,
                                         rotation_amount=np.deg2rad(0),
                                         render_gripper=True,
                                         gripper_pose=gripper_pose_mat,
                                         gripper_mesh_scale=voxel_scale)

        rendered_img_270 = visualise_voxel(vis_voxel_grid,
                                           None,
                                           [pred_trans_coord],
                                           None,
                                           voxel_size=voxel_size,
                                           rotation_amount=np.deg2rad(45+180),
                                           render_gripper=True,
                                           gripper_pose=gripper_pose_mat,
                                           gripper_mesh_scale=voxel_scale)

        # Plot both views side by side
        fig = plt.figure(figsize=(20, 15))
        fig.add_subplot(1, 2, 1)
        plt.imshow(rendered_img_0)
        plt.axis('off')
        plt.title("Front view")
        fig.add_subplot(1, 2, 2)
        plt.imshow(rendered_img_270)
        plt.axis('off')
        plt.title("Side view")

        # Add timestamp as text with white font and black background
        fig.text(0.02, 0.95, f"Timestep: {ts}", ha='left', fontsize=16, color='white', weight='bold',
                 bbox=dict(facecolor='black', edgecolor='none', boxstyle='round,pad=0.3'))

        # Convert the figure to an (H, W, 3) uint8 frame. `tostring_rgb` is
        # deprecated since Matplotlib 3.8 and removed in 3.10; `buffer_rgba`
        # already has the renderer's (H, W, 4) shape, so no manual reshape is needed.
        fig.canvas.draw()
        img_array = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])

        video_writer.append_data(img_array)  # Add frame to video
        plt.close(fig)  # Close the figure to free memory

print(f"Video saved to {video_output_path}")


### Validation: Run inference on input data together with the ground truth to check what happens

In [None]:
def get_gt_from_frame(episode_keypoints_gt_obs_dict, frame_idx):
    """Return the ground-truth observation of the first keypoint at or after a frame.

    Args:
        episode_keypoints_gt_obs_dict: mapping from keypoint frame index to the
            ground-truth observation dict stored for that keypoint.
        frame_idx: frame index to look up.

    Returns:
        The value stored under the smallest key that is ``>= frame_idx``, or
        ``None`` when every key is smaller than ``frame_idx`` (or the mapping
        is empty).
    """
    upcoming_keypoints = (
        keypoint
        for keypoint in sorted(episode_keypoints_gt_obs_dict)
        if frame_idx <= keypoint
    )
    nearest_keypoint = next(upcoming_keypoints, None)
    if nearest_keypoint is None:
        return None
    return episode_keypoints_gt_obs_dict[nearest_keypoint]

In [None]:
# TODO: Ground truth is currently only taken from the last discovered keypoint;
# find a way to collect ground-truth data for all frames.

import imageio
import numpy as np
from matplotlib import pyplot as plt

from rlbench.utils import get_stored_demo
from rlbench.backend.utils import extract_obs, extract_obs_gt

from arm.demo import _keypoint_discovery_available
from arm.utils import point_to_voxel_index, visualise_voxel, get_gripper_render_pose


# What to visualize
episode_idx_to_visualize = 645  # Index of the episode to visualize
frame_step = 5  # Render every `frame_step`-th timestep (skip some frames)
# Video output path
video_output_path = f"demo{episode_idx_to_visualize}_gt_pred_visualization.mp4"

# Voxel / gripper-mesh rendering parameters (loop-invariant)
voxel_size = 0.045
voxel_scale = voxel_size * 100

# Get demo
demo = get_stored_demo(data_path=test_data_path,
                       index=episode_idx_to_visualize,
                       cameras=CAMERAS,
                       depth_scale=DEPTH_SCALE)

# Ground truth is taken from the last discovered keypoint only. Slicing with [-1:]
# yields an empty list (instead of raising IndexError) when no keypoint is found;
# every frame then reports "GT coordinates not available for this frame".
episode_keypoints = _keypoint_discovery_available(demo, approach_distance=0.3)  # NOTE: approach_distance set
episode_keypoints_gt_obs_dict = {
    episode_keypoint: extract_obs_gt(obs=demo._observations[episode_keypoint], cameras=CAMERAS)
    for episode_keypoint in episode_keypoints[-1:]
}

num_timesteps = len(demo._observations)

# Open a video writer; one frame is appended per rendered timestep.
with imageio.get_writer(video_output_path, fps=10) as video_writer:
    for ts in range(0, num_timesteps, frame_step):
        print(ts)
        # Extract obs at timestep and add the gripper state recorded in the demo
        obs_dict = extract_obs(demo._observations[ts], CAMERAS, t=ts)
        obs_dict["gripper_open"] = demo[ts].gripper_open
        obs_dict["gripper_joint_positions"] = demo[ts].gripper_joint_positions

        # Both returned tuples carry a gripper-open entry; bind the second one to
        # `_` so it cannot silently overwrite the first.
        (continuous_trans, continuous_quat, gripper_open, trans_confidence, _), \
        (voxel_grid, coord_indices, rot_and_grip_indices, _) = peract_agent.forward(obs_dict, ts)

        pred_trans_coord = coord_indices[0].detach().cpu().numpy().tolist()

        # Ground truth of the first stored keypoint at or after this frame (if any).
        # `error` stays None when no ground truth is available.
        episode_keypoint_gt_obs_dict = get_gt_from_frame(episode_keypoints_gt_obs_dict, ts)
        gt_trans_coord = None
        error = None
        if episode_keypoint_gt_obs_dict is not None:
            gt_gripper_pose = episode_keypoint_gt_obs_dict["gripper_pose"]
            gt_trans_coord = point_to_voxel_index(gt_gripper_pose[:3], VOXEL_SIZES, SCENE_BOUNDS)[0]
            # Euclidean distance between GT and predicted gripper translation
            error = np.linalg.norm(gt_gripper_pose[:3] - continuous_trans)
            print(f"GT (voxel): {gt_trans_coord} - Prediction (voxel): {pred_trans_coord} - Error: {error} - Prediction-score: {round(trans_confidence,4)}")
        else:
            print("GT coordinates not available for this frame")

        # Things to visualize
        vis_voxel_grid = voxel_grid[0].detach().cpu().numpy()

        gripper_pose_mat = get_gripper_render_pose(voxel_scale,
                                                   SCENE_BOUNDS[:3],
                                                   continuous_trans,
                                                   continuous_quat)

        rendered_img_0 = visualise_voxel(vis_voxel_grid,
                                         None,
                                         [pred_trans_coord],
                                         gt_trans_coord,
                                         alpha = 0.2,
                                         voxel_size=voxel_size,
                                         rotation_amount=np.deg2rad(0),
                                         render_gripper=True,
                                         gripper_pose=gripper_pose_mat,
                                         gripper_mesh_scale=voxel_scale)

        rendered_img_270 = visualise_voxel(vis_voxel_grid,
                                           None,
                                           [pred_trans_coord],
                                           gt_trans_coord,
                                           alpha = 0.2,
                                           voxel_size=voxel_size,
                                           rotation_amount=np.deg2rad(45+180),
                                           render_gripper=True,
                                           gripper_pose=gripper_pose_mat,
                                           gripper_mesh_scale=voxel_scale)

        # Plot both views side by side
        fig = plt.figure(figsize=(20, 15))
        fig.add_subplot(1, 2, 1)
        plt.imshow(rendered_img_0)
        plt.axis('off')
        plt.title("Front view")
        fig.add_subplot(1, 2, 2)
        plt.imshow(rendered_img_270)
        plt.axis('off')
        plt.title("Side view")

        # Add timestamp/score (and error, if GT exists) as white text on a black box.
        # Compare against None rather than truthiness so that a perfect
        # prediction (error == 0.0) still shows its error.
        header = f"Timestep: {ts}, Prediction-score: {round(trans_confidence,4)}"
        if error is not None:
            header += f", Error: {np.round(error, 3)}"
        fig.text(0.02, 0.95, header, ha='left', fontsize=16, color='white', weight='bold',
                 bbox=dict(facecolor='black', edgecolor='none', boxstyle='round,pad=0.3'))

        # Convert the figure to an (H, W, 3) uint8 frame. `tostring_rgb` is
        # deprecated since Matplotlib 3.8 and removed in 3.10; `buffer_rgba`
        # already has the renderer's (H, W, 4) shape, so no manual reshape is needed.
        fig.canvas.draw()
        img_array = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])

        video_writer.append_data(img_array)  # Add frame to video
        plt.close(fig)  # Close the figure to free memory

print(f"Video saved to {video_output_path}")


### Validation: Run Inference with same batches as during training (You can change `batch_size` in <i>constants.py</i>)

In [None]:
# import json
# from notebook_helpers.build_replay import load_replay_buffer

# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-16_11-25/" # {crop skip 10 new}

# path_settings = os.path.join(os.path.dirname(run_dir), "training_settings.json")
# with open(path_settings, 'r') as f:
#     settings = json.load(f)

# train_data_iter, test_data_iter = load_replay_buffer(settings)

In [None]:
import json
import os

from notebook_helpers.constants import * # Load global constant variables from constants.py
from notebook_helpers.build_training import build_agent
from notebook_helpers.build_replay import load_replay_buffer

# Choose the run
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-10_14-59" # Good {non-uniform 1 kp}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_12-30" # Bad {uniform 2 kp}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_12-30" # Bad
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_13-34"
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-12_17-25" # {crop skip 10}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-12_21-23"
run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-16_11-25/" # {crop skip 10 new}

# Obtain settings. Join onto run_dir itself: os.path.dirname(run_dir) only yields
# the run folder when run_dir ends with "/", and would point one level too high
# for the slash-less alternatives listed above.
path_settings = os.path.join(run_dir, "training_settings.json")
with open(path_settings, 'r') as f:
    settings = json.load(f)

# NOTE(review): load_replay_buffer is called with many more arguments in the
# checkpoint-evaluation cell further below; confirm this one-argument call still
# matches its current signature.
train_data_iter, test_data_iter = load_replay_buffer(settings)

peract_agent = build_agent(settings, training=True)
peract_agent.set_language_goal("handing over banana")

# Choose model: <run_dir>/<iteration>/<best_type>
iteration = "run8000"
best_type = "best_model_general"
model_path = os.path.join(run_dir, iteration, best_type)

peract_agent.load_weights(model_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from rlbench.utils import get_stored_demo
from rlbench.backend.utils import extract_obs
from arm.utils import visualise_voxel, get_gripper_render_pose

CAMERAS = settings["cameras"]

# Language goal handed to the agent via peract_agent.set_language_goal(...) in the
# setup cell; keep the two in sync.
lang_goal = "handing over banana"

# what to visualize
episode_idx_to_visualize = 646  # INDEXES[0] # out of 10 demos
# ts = 70#25 # timestep out of total timesteps

# Voxel / gripper-mesh rendering parameters (loop-invariant)
voxel_size = 0.045
voxel_scale = voxel_size * 100


def _show_front_and_side_views(vis_voxel_grid, pred_trans_coord, gripper_pose_mat,
                               render_gripper, **front_view_kwargs):
    """Render the voxel grid at 0 deg ("Front view") and 45 deg ("Side view") and show both.

    Args:
        vis_voxel_grid: voxel grid (numpy) passed to ``visualise_voxel``.
        pred_trans_coord: predicted voxel index, passed to ``visualise_voxel``
            as a one-element coordinate list.
        gripper_pose_mat: pose from ``get_gripper_render_pose``.
        render_gripper: whether ``visualise_voxel`` draws the gripper mesh.
        **front_view_kwargs: extra ``visualise_voxel`` options applied to the
            front view only (e.g. ``perspective=False``).
    """
    shared_kwargs = dict(voxel_size=voxel_size,
                         render_gripper=render_gripper,
                         gripper_pose=gripper_pose_mat,
                         gripper_mesh_scale=voxel_scale)
    front_img = visualise_voxel(vis_voxel_grid, None, [pred_trans_coord], None,
                                rotation_amount=np.deg2rad(0),
                                **shared_kwargs, **front_view_kwargs)
    side_img = visualise_voxel(vis_voxel_grid, None, [pred_trans_coord], None,
                               rotation_amount=np.deg2rad(45),
                               **shared_kwargs)

    fig = plt.figure(figsize=(20, 15))
    for plot_pos, (img, title) in enumerate([(front_img, "Front view"), (side_img, "Side view")], start=1):
        fig.add_subplot(1, 2, plot_pos)
        plt.imshow(img)
        plt.axis('off')
        plt.title(title)
    plt.show()


# get demo
demo = get_stored_demo(data_path=test_data_path,
                       index=episode_idx_to_visualize,
                       cameras=CAMERAS,
                       depth_scale=DEPTH_SCALE,)

print(f"Lang goal: {lang_goal}")

for ts in range(len(demo._observations)):

    # extract obs at timestep and add the gripper state recorded in the demo
    obs_dict = extract_obs(demo._observations[ts], CAMERAS, t=ts)
    obs_dict["gripper_open"] = demo[ts].gripper_open
    obs_dict["gripper_joint_positions"] = demo[ts].gripper_joint_positions

    # plot rgb (top row) and depth (bottom row) for every camera at this timestep
    fig = plt.figure(figsize=(20, 10))
    rows, cols = 2, len(CAMERAS)
    for col, camera in enumerate(CAMERAS, start=1):
        rgb = np.transpose(obs_dict["%s_%s" % (camera, 'rgb')], (1, 2, 0))
        fig.add_subplot(rows, cols, col)
        plt.imshow(rgb)
        plt.axis('off')
        plt.title("%s_rgb | step %s" % (camera, ts))

        depth = np.transpose(obs_dict["%s_%s" % (camera, 'depth')], (1, 2, 0))
        fig.add_subplot(rows, cols, col + len(CAMERAS))
        plt.imshow(depth)
        plt.axis('off')
        plt.title("%s_depth | step %s" % (camera, ts))
    plt.show()

    print(obs_dict)

    # The inference cells unpack five entries from forward()'s first tuple;
    # star-unpacking tolerates those extra entries (a 3-way unpack raises
    # ValueError). The second tuple's gripper-open entry is discarded so it
    # cannot overwrite the first.
    (continuous_trans, continuous_quat, gripper_open, *_), \
    (voxel_grid, coord_indices, rot_and_grip_indices, _) = peract_agent.forward(obs_dict, ts)

    # things to visualize
    vis_voxel_grid = voxel_grid[0].detach().cpu().numpy()
    pred_trans_coord = coord_indices[0].detach().cpu().numpy().tolist()

    # Use forward()'s continuous translation / rotation directly, as the inference
    # cells do when building the gripper render pose.
    gripper_pose_mat = get_gripper_render_pose(voxel_scale,
                                               SCENE_BOUNDS[:3],
                                               continuous_trans,
                                               continuous_quat)

    # Once without the gripper mesh (front view with perspective=False), once with it.
    _show_front_and_side_views(vis_voxel_grid, pred_trans_coord, gripper_pose_mat,
                               render_gripper=False, perspective=False)
    _show_front_and_side_views(vis_voxel_grid, pred_trans_coord, gripper_pose_mat,
                               render_gripper=True)


### Validation: Run inference with the same batches as during training (you can change `batch_size` in <i>constants.py</i>) — check prediction confidence across training iterations (checkpoints)

In [None]:
import json
import os
from natsort import natsorted
import numpy as np


from notebook_helpers.constants import BATCH_SIZE # Load global constant variables from constants.py
from notebook_helpers.build_training import build_agent
from notebook_helpers.build_replay import load_replay_buffer


# Batch sizes accepted for this replay-buffer evaluation.
ALLOWED_BATCH_SIZES = (4, 6)
if BATCH_SIZE not in ALLOWED_BATCH_SIZES:
    raise ValueError(f'Set BATCH_SIZE to one of {ALLOWED_BATCH_SIZES} in notebook_helpers.constants.py (got {BATCH_SIZE})!')

# TASK = "handing_over_banana"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_16-17"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_18-40"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_21-05"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_23-30"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_01-56"

TASK = "handing_over_mug"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_17-28" #camera_setting = {6, 8, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_19-52" #camera_setting = {6, 7, 8, 9, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_22-18" #camera_setting = {0, 1, 2, 3, 4, 5}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_00-43" #camera_setting = {1, 3, 5, 6, 8, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_03-09" #camera_setting = {4, 5, 6, 7, 8, 9}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-22_17-07" #task="handing_over_mug_and_grasp_handle"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-22_23-18" #task="handing_over_mug_and_grasp_rim"
run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-23_06-46" #task={"handing_over_mug_and_grasp_handle", "handing_over_mug_and_grasp_rim"}

# TASK = "handing_over_bowl"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_12-22"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17-13-35"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17-14-49"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_16-11"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_17-30"

# TASK = "handing_over_pitcher_base"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_11-02"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_12-06"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_13-10"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_14-15"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_15-19"

# Obtain settings
path_settings = os.path.join(run_dir, "training_settings.json")
with open(path_settings, 'r') as f:
    settings = json.load(f)

EPISODE_FOLDER = 'episode%d'
SETUP = 's1'

WORKSPACE_DIR = os.getcwd()
DATA_FOLDER = os.path.join(WORKSPACE_DIR, 'task_data', 'handoversim_v4').replace('/peract_colab', '')
EPISODES_FOLDER = os.path.join(TASK, "all_variations", "episodes")

train_data_path = os.path.join(DATA_FOLDER, f"train_{SETUP}", EPISODES_FOLDER)

# True : use the hand-picked handing_over_mug split below; both its train and
#        test episodes live in the train_<SETUP> folder.
# False: use every episode found in the train_<SETUP> / val_<SETUP> folders.
USE_HANDLE_RIM_SPLIT = True

if USE_HANDLE_RIM_SPLIT:
    # TASK = MUG. Separate by handle vs. rim (indices are specific to handing_over_mug).
    test_data_path = train_data_path
    TRAIN_INDEXES = [66, 266, 268, 269, 368]  # handle
    TEST_INDEXES = [465, 466]  # handle
    TRAIN_INDEXES += [167, 169, 265, 365, 366, 367, 369, 468, 566, 567, 568, 569]  # rim
    TEST_INDEXES += [966, 967, 968]  # rim
else:
    test_data_path = os.path.join(DATA_FOLDER, f"val_{SETUP}", EPISODES_FOLDER)
    TRAIN_INDEXES = [int(episode_nr.replace("episode", "")) for episode_nr in natsorted(os.listdir(train_data_path))]
    TEST_INDEXES = [int(episode_nr.replace("episode", "")) for episode_nr in natsorted(os.listdir(test_data_path))]

train_data_iter, test_data_iter = load_replay_buffer(settings,
                                                     WORKSPACE_DIR, SETUP, EPISODE_FOLDER,
                                                     TASK,
                                                     train_data_path, test_data_path, TRAIN_INDEXES, TEST_INDEXES)

peract_agent = build_agent(settings, training=True) # Set training to True for running with replaybuffer
peract_agent.set_language_goal(TASK)

In [None]:
import json
import os

import numpy as np
import torch
from natsort import natsorted

from notebook_helpers.build_training import NumpyEncoder
from notebook_helpers.constants import *


model_runs = natsorted([run for run in os.listdir(run_dir) if "run" in run])

# Choose the loss metric at which model is saved
# chosen_model = "best_model_train"
# chosen_model = "best_model_test"
# chosen_model = "best_model_general"
chosen_model = "last_model"

# Choose validation set. Names containing "val" sample the test buffer;
# otherwise names containing "train" sample the train buffer.
# testing_set = "train"
# testing_set = "train_handle"
testing_set = "val_handle_or_rim"

n_batches = 30  # number of batches sampled per model checkpoint

# Resolve the data source once; an unrecognised name would otherwise leave
# `batch` undefined (or stale from a previous run) inside the loop.
if "val" in testing_set:
    data_iter = test_data_iter
elif "train" in testing_set:
    data_iter = train_data_iter
else:
    raise ValueError(f"testing_set must contain 'train' or 'val', got {testing_set!r}")

# Save results: checkpoint name -> per-sample values, ordered by ascending distance
model_run_scored = dict()
model_run_distances = dict()
# model_run_rotations = dict() # Ignore rotation, location matters most

# Loop over iterations
for model_run_iter in model_runs:

    # Load model if available. Catch Exception (not a bare except) so that
    # KeyboardInterrupt still stops the loop, and report what actually failed.
    try:
        peract_agent.load_weights(os.path.join(run_dir, model_run_iter, chosen_model))
    except Exception as err:
        print(f"Model {model_run_iter} not found, skipping. ({type(err).__name__}: {err})")
        continue

    distances_run = []
    scores_run = []

    for i in range(n_batches):
        batch = next(data_iter)  # collect batch

        lang_goal = batch['lang_goal']
        print(f"batch: {i} - analyzing: {lang_goal}")

        # Keep only tensor entries, move them to the GPU and predict without backprop
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        update_dict = peract_agent.update(None, batch, backprop=False)

        # Prediction score per sample: max over all non-batch dims of q_trans
        prediction_scores = torch.amax(update_dict["q_trans"], dim=(1, 2, 3, 4)).detach().cpu().numpy()
        prediction_scores = np.around(prediction_scores, 4)

        # Distance between predicted and expert translation per sample
        pred_trans = update_dict["pred_action"]["trans"]
        gt_trans = update_dict["expert_action"]["action_trans"]
        dist = np.round(np.linalg.norm(pred_trans - gt_trans, axis=1), 4)

        distances_run.extend(dist.tolist())
        scores_run.extend(prediction_scores.tolist())

    # Order (distance, score) pairs by ascending distance; safe for empty results.
    sorted_pairs = sorted(zip(distances_run, scores_run), key=lambda pair: pair[0])
    model_run_distances[model_run_iter] = tuple(distance for distance, _ in sorted_pairs)
    model_run_scored[model_run_iter] = tuple(score for _, score in sorted_pairs)

with open(os.path.join(run_dir, f"results_distances_{chosen_model}_on_{testing_set}.json"), 'w') as f:
    json.dump(model_run_distances, f, indent=4, cls=NumpyEncoder)

with open(os.path.join(run_dir, f"results_scores_{chosen_model}_on_{testing_set}.json"), 'w') as f:
    json.dump(model_run_scored, f, indent=4, cls=NumpyEncoder)


In [None]:
import os
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


n_samples = 3

# Choose the loss metric at which model is saved
# chosen_model = "best_model_train"
# chosen_model = "best_model_test"
# chosen_model = "best_model_general"
chosen_model = "last_model"

# Choose validation set
# testing_set = "train_handle"
# testing_set = "train_rim"
testing_set = "val_handle"

# TASK = "handing_over_banana"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_16-17"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_18-40"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_21-05"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_23-30"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_01-56"

TASK = "handing_over_mug"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_17-28" #camera_setting = {6, 8, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_19-52" #camera_setting = {6, 7, 8, 9, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_22-18" #camera_setting = {0, 1, 2, 3, 4, 5}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_00-43" #camera_setting = {1, 3, 5, 6, 8, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_03-09" #camera_setting = {4, 5, 6, 7, 8, 9}
run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-22_17-07" #task="handing_over_mug_and_grasp_handle"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-22_23-18" #task="handing_over_mug_and_grasp_rim"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-23_06-46" #task={"handing_over_mug_and_grasp_handle", "handing_over_mug_and_grasp_rim"}

# TASK = "handing_over_bowl"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_12-22"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17-13-35"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17-14-49"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_16-11"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_17-30"

# TASK = "handing_over_pitcher_base"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_11-02"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_12-06"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_13-10"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_14-15"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_15-19"

with open(os.path.join(run_dir, f"results_distances_{chosen_model}_on_{testing_set}.json")) as f:
    model_run_distances = json.load(f)

with open(os.path.join(run_dir, f"results_scores_{chosen_model}_on_{testing_set}.json")) as f:
    model_run_scored = json.load(f)

run_iteration_keys = list(model_run_distances.keys())

# Pick `n_samples` evenly spaced checkpoints, but only plot the last (final) one.
# Replace `[-1:]` with `[:]` to compare all sampled checkpoints.
run_iterations = np.linspace(0, len(run_iteration_keys) - 1, n_samples, dtype=int)
sampled_keys = [run_iteration_keys[run_iteration] for run_iteration in run_iterations][-1:]

model_run_distances_sampled = {key: model_run_distances[key] for key in sampled_keys}
model_run_scored_sampled = {key: model_run_scored[key] for key in sampled_keys}


def _to_long_df(results_per_iteration, value_column):
    """Stack ``{iteration: [values, ...]}`` into a long DataFrame with columns
    ``Iteration`` and ``value_column`` (rows keep the per-iteration list order)."""
    return pd.concat(
        [pd.DataFrame({'Iteration': key, value_column: values})
         for key, values in results_per_iteration.items()],
        ignore_index=True,
    )


df_dist = _to_long_df(model_run_distances_sampled, 'Error')
df_conf = _to_long_df(model_run_scored_sampled, 'Confidence')

fig, ax = plt.subplots()
sns.histplot(data=df_dist, x='Error', hue='Iteration', fill=True, alpha=0.3, ax=ax)
ax.set_title('Distribution of Error Across Iterations')
ax.set_xlabel('Translation error [voxels]')
plt.show()

fig, ax = plt.subplots()
sns.histplot(data=df_conf, x='Confidence', hue='Iteration', fill=True, alpha=0.3, ax=ax)
ax.set_title('Distribution of Confidences Across Iterations')
plt.show()

# The distance and score lists were saved in the same (distance-sorted) order,
# so row i of df_dist and row i of df_conf describe the same sample. Merging on
# 'Iteration' alone would pair every error with every confidence of that
# iteration (a cross product). `.to_numpy()` also makes a length mismatch fail
# loudly instead of silently filling NaNs.
df_merged = df_dist.assign(Confidence=df_conf['Confidence'].to_numpy())
sns.scatterplot(data=df_merged, x='Error', y='Confidence', hue='Iteration', alpha=0.3)

### Validation: Run inference on the same replay-buffer batches as during training (requires `BATCH_SIZE = 1` in <i>constants.py</i>) and check the translation/rotation errors over time per episode, using `last_model`

In [None]:
import json
import os

import numpy as np
from natsort import natsorted


from notebook_helpers.constants import BATCH_SIZE # Only BATCH_SIZE is needed from constants.py here
from notebook_helpers.build_training import build_agent
from notebook_helpers.build_replay import load_replay_buffer


# Replay-buffer validation below processes a single sample per batch.
if BATCH_SIZE != 1:
    raise ValueError('Set BATCH_SIZE = 1 in notebook_helpers.constants.py!')

# ---------------------------------------------------------------------------
# Run selection: keep exactly one TASK / run_dir pair uncommented.
# ---------------------------------------------------------------------------

# TASK = "handing_over_mug"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_17-28" #camera_setting = {6, 8, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_19-52" #camera_setting = {6, 7, 8, 9, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_22-18" #camera_setting = {0, 1, 2, 3, 4, 5}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_00-43" #camera_setting = {1, 3, 5, 6, 8, 10}
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-17_03-09" #camera_setting = {4, 5, 6, 7, 8, 9}
# # run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-22_17-07" #task="handing_over_mug_and_grasp_handle"
# # run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-22_23-18" #task="handing_over_mug_and_grasp_rim"
# # run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-23_06-46" #task={"handing_over_mug_and_grasp_handle", "handing_over_mug_and_grasp_rim"}
# # run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-23_15-11" #task="handing_over_mug_and_grasp_handle" val = {465, 466}

# TASK = "handing_over_pitcher_base"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_11-02"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_12-06"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_13-10"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_14-15"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_15-19"

# TASK = "handing_over_banana"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-16_11-22"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-12_12-06"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_13-10"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_14-15"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-21_15-19"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-31_15-33"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs/models/{TASK}/2025-01-31_17-54"

TASK = "handing_over_banana"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs_rgb_augmentation/models/{TASK}/2025-02-04_08-05"
# run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs_rgb_augmentation/models/{TASK}/2025-02-04_09-11"
run_dir = f"/home/bepgroup/Projects/PerAct_ws/peract_colab/outputs_rgb_augmentation/models/{TASK}/2025-02-04_10-17"

# Training settings stored alongside the selected run.
path_settings = os.path.join(run_dir, "training_settings.json")
with open(path_settings, 'r') as f:
    settings = json.load(f)

EPISODE_FOLDER = 'episode%d'
SETUP = 's1'

WORKSPACE_DIR = os.getcwd()
# Drop '/peract_colab' from the working-directory path so the data folder is
# resolved relative to the workspace root rather than this repository.
DATA_FOLDER = os.path.join(WORKSPACE_DIR, 'task_data', 'handoversim_v4').replace('/peract_colab', '')
EPISODES_FOLDER = os.path.join(TASK, "all_variations", "episodes")

train_data_path = os.path.join(DATA_FOLDER, f"train_{SETUP}", EPISODES_FOLDER)
test_data_path = os.path.join(DATA_FOLDER, f"val_{SETUP}", EPISODES_FOLDER)


def _episode_indexes(episodes_dir):
    """Return the numeric ids of the ``episode<N>`` entries of *episodes_dir*, in natural order."""
    return [int(folder.replace("episode", "")) for folder in natsorted(os.listdir(episodes_dir))]


TRAIN_INDEXES = _episode_indexes(train_data_path)
TEST_INDEXES = _episode_indexes(test_data_path)
#TASK = MUG. Separate by handle vs. rim
# TRAIN_INDEXES = [66, 266, 268, 269, 368, 465, 466]
# TRAIN_INDEXES = [167, 169, 265, 365, 366, 367, 369, 468, 566, 567, 568, 569, 966, 967, 968]
# test_data_path = os.path.join(DATA_FOLDER, f"train_{SETUP}", EPISODES_FOLDER)
# TRAIN_INDEXES = [66, 266, 268, 269, 368] # handle
# TEST_INDEXES = [465, 466] # handle
# TRAIN_INDEXES = [167, 169, 265, 365, 366, 367, 369, 468, 566, 567, 568, 569] # rim
# TEST_INDEXES = [966, 967, 968] # rim

# Only the validation iterator is used by the analysis cells below.
_, test_data_iter = load_replay_buffer(
    settings,
    WORKSPACE_DIR, SETUP, EPISODE_FOLDER,
    TASK,
    train_data_path, test_data_path, TRAIN_INDEXES, TEST_INDEXES,
)

# training=True is needed so the agent can be driven with replay-buffer batches.
peract_agent = build_agent(settings, training=True)

In [None]:
## COLLECT ALL FRAMES USED FOR ANALYSIS

import torch
from collections import defaultdict

# The replay buffer is only sampled randomly, so the frames of each test episode
# are discovered by drawing many batches and recording the distinct lang goals.
# A lang goal splits on '-' into (task, "episode_<id>", frame, keypoint).
#
# All draws are pooled and bucketed by episode in one pass, so every sampled
# batch is used. Drawing separately per episode would discard each batch that
# belongs to a different episode, leaving each episode with only a fraction of
# the draws. The total number of draws is the same as drawing 1000 per episode.
N_DRAWS_PER_EPISODE = 1000

frames_per_episode = defaultdict(set)
for _ in range(N_DRAWS_PER_EPISODE * len(TEST_INDEXES)):
    batch = next(test_data_iter)
    lang_goal = batch['lang_goal'][0][0][0]
    _task, episode, _frame, _kp = lang_goal.split('-')
    frames_per_episode[int(episode.replace('episode_', ''))].add(lang_goal)

# Natural sort of the lang goals (presumably frame order); an episode that was
# never sampled maps to an empty list.
dict_test_episodes_frames = {
    analyzed_episode: natsorted(frames_per_episode[analyzed_episode])
    for analyzed_episode in TEST_INDEXES
}

print(dict_test_episodes_frames)

In [None]:
import os
import numpy as np
np.set_printoptions(suppress=True)
import json

from scipy.spatial.transform import Rotation as Rot

from handover.ycb import YCB
from notebook_helpers.constants import SCENE_BOUNDS, ROTATION_RESOLUTION


def compose_qq(q1, q2):
    """Compose two poses ``q1 * q2``.

    Each pose is ``[x, y, z, qx, qy, qz, qw]`` (translation followed by a
    scalar-last quaternion) in its last axis; leading axes broadcast. The
    result's translation is ``q1`` applied to ``q2``'s translation and its
    rotation is the Hamilton product ``q1_rot * q2_rot``.
    """
    x1, y1, z1, w1 = (q1[..., k] for k in (3, 4, 5, 6))
    x2, y2, z2, w2 = (q2[..., k] for k in (3, 4, 5, 6))

    composed = np.zeros(np.broadcast_shapes(q1.shape, q2.shape))
    composed[..., 0:3] = compose_qp(q1, q2[..., 0:3])
    composed[..., 3] = w1 * x2 + w2 * x1 + y1 * z2 - y2 * z1
    composed[..., 4] = w1 * y2 + w2 * y1 + z1 * x2 - z2 * x1
    composed[..., 5] = w1 * z2 + w2 * z1 + x1 * y2 - x2 * y1
    composed[..., 6] = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    return composed


def compose_qp(q, pt):
    """Apply the pose ``q`` to the point(s) ``pt``.

    ``q`` holds ``[x, y, z, qx, qy, qz, qw]`` (translation + scalar-last
    quaternion) and ``pt`` holds ``[px, py, pz]`` in their last axes; leading
    axes broadcast against each other. Returns ``t + R(q) @ pt`` with shape
    ``(*broadcast_leading_shape, 3)``. The rotation-matrix formula used here
    assumes a unit quaternion.
    """
    px, py, pz = pt[..., 0], pt[..., 1], pt[..., 2]
    tx, ty, tz = q[..., 0], q[..., 1], q[..., 2]
    qx, qy, qz, qw = q[..., 3], q[..., 4], q[..., 5], q[..., 6]

    # Pairwise quaternion products that make up the rotation matrix entries.
    xx, yy, zz = qx**2, qy**2, qz**2
    wx, wy, wz = qw * qx, qw * qy, qw * qz
    xy, xz, yz = qx * qy, qx * qz, qy * qz

    out_shape = (*np.broadcast_shapes(q.shape[:-1], pt.shape[:-1]), 3)
    transformed = np.zeros(out_shape)
    transformed[..., 0] = tx + px + 2 * ((-1 * (yy + zz) * px) + ((xy - wz) * py) + ((wy + xz) * pz))
    transformed[..., 1] = ty + py + 2 * (((wz + xy) * px) + (-1 * (xx + zz) * py) + ((yz - wx) * pz))
    transformed[..., 2] = tz + pz + 2 * (((xz - wy) * px) + ((wx + yz) * py) + (-1 * (xx + yy) * pz))
    return transformed

def normalize_quaternion(quat):
    """Return ``quat`` scaled to unit length along its last axis (batched inputs allowed)."""
    as_array = np.array(quat)
    length = np.linalg.norm(quat, axis=-1, keepdims=True)
    return as_array / length

def quaternion_to_discrete_euler(quaternion, resolution):
    """Discretise scalar-last quaternion(s) into integer Euler-angle bins.

    The rotation is expressed as extrinsic ``'xyz'`` Euler angles in degrees,
    shifted by +180 into ``[0, 360]`` and rounded to bins of ``resolution``
    degrees. The bin for exactly 360 deg wraps around to 0.
    """
    shifted = 180 + Rot.from_quat(quaternion).as_euler('xyz', degrees=True)  # extrinsic rotations
    assert shifted.min() >= 0 and shifted.max() <= 360
    bins = np.around(shifted / resolution).astype(int)
    wrap_bin = int(360 / resolution)
    bins[bins == wrap_bin] = 0
    return bins

def discrete_euler_to_quaternion(discrete_euler, resolution):
    """Invert ``quaternion_to_discrete_euler``: map Euler bins back to scalar-last quaternion(s).

    Bins are scaled by ``resolution`` degrees and shifted by -180 before being
    interpreted as extrinsic ``'xyz'`` Euler angles.
    """
    angles_deg = discrete_euler * resolution - 180
    rotation = Rot.from_euler('xyz', angles_deg, degrees=True)
    return rotation.as_quat()

def get_all_available_grasp_episode(
    episode,
    dex_ycb_cache="/home/bepgroup/Projects/PerAct_ws/handover-sim/handover/data/dex-ycb-cache",
    grasp_dir="/home/bepgroup/Projects/PerAct_ws/handover-sim/handover/data/assets/grasps",
):
    """Return every pre-computed grasp of the handover object, posed at the handover frame.

    The object pose is taken from the last frame of ``pose_y`` in the DexYCB
    cache. The object's grasp library is transformed into that pose and then
    into the end-effector frame.

    Args:
        episode: Integer scene index; selects ``pose_{episode:03d}.npz`` and
            ``meta_{episode:03d}.json`` inside ``dex_ycb_cache``.
        dex_ycb_cache: Directory with the DexYCB pose/meta cache files.
        grasp_dir: Directory with the ``<ycb class name>.npy`` grasp files.

    Returns:
        ``(object_grasps_voxel, object_grasps)``:
            * ``object_grasps_voxel`` -- ``(N, 6)`` int array: grasp translation
              scaled to a 100-cell grid over ``SCENE_BOUNDS`` (assumes a
              100-voxel grid -- TODO confirm against VOXEL_SIZES), followed by
              the discrete Euler bins from ``quaternion_to_discrete_euler``.
            * ``object_grasps`` -- ``(N, 7)`` float array of world-frame poses
              ``[x, y, z, qx, qy, qz, qw]``.
    """
    # Alternative machine paths:
    # dex_ycb_cache = "/home/ywatabe/Projects/PerAct/handover-sim/handover/data/dex-ycb-cache"
    # grasp_dir = "/home/ywatabe/Projects/PerAct/handover-sim/handover/data/assets/grasps"

    # An .npz load returns an NpzFile that keeps the file open until closed;
    # indexing it reads the array into memory, so it is safe to close afterwards.
    pose_file = os.path.join(dex_ycb_cache, "pose_{:03d}.npz".format(episode))
    with np.load(pose_file) as pose_data:
        pose = pose_data["pose_y"]

    meta_file = os.path.join(dex_ycb_cache, "meta_{:03d}.json".format(episode))
    with open(meta_file, "r") as f:
        meta = json.load(f)

    # Get all poses in scene
    ycb_ids = meta["ycb_ids"]
    ycb_grasp_ind = meta["ycb_grasp_ind"]
    pose[:, :, 2] += 0.92  # Increase z-value by table height

    # Handover-object pose at the last frame; its rotation is interpreted as
    # intrinsic 'XYZ' Euler angles in radians.
    ycb_grasp_id = ycb_ids[ycb_grasp_ind]
    object_pose = pose[-1, ycb_grasp_ind]
    object_quat = Rot.from_euler('XYZ', object_pose[3:], degrees=False).as_quat()
    object_pose = np.concatenate([object_pose[:3], object_quat])

    # Handover-object class name and its grasp library. allow_pickle is needed
    # for the stored dict; only load trusted local asset files this way.
    class_name = YCB.CLASSES[ycb_grasp_id]
    grasp_file = os.path.join(grasp_dir, "{}.npy".format(class_name))
    data = np.load(grasp_file, allow_pickle=True, encoding="bytes")
    grasps = data.item()[b"transforms"]  # presumably (N, 4, 4) homogeneous transforms
    grasps_pq = np.zeros((len(grasps), 7))
    grasps_pq[:, 0:3] = grasps[:, :3, 3]
    grasps_pq[:, 3:7] = Rot.from_matrix(grasps[:, :3, :3]).as_quat()

    # Move the grasps into the handover-object pose, then into the end-effector
    # frame (the quaternion [0, 0, 0.7071068, -0.7071068] is a -90 deg rotation
    # about the local z-axis, up to sign).
    object_grasps = compose_qq(object_pose, grasps_pq)
    object_grasps = compose_qq(object_grasps, np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.7071068, -0.7071068]))

    # Convert the handover-object grasp poses to voxel/bin coordinates.
    scene_bounds = np.array(SCENE_BOUNDS)
    object_grasps_trans = (object_grasps[:, :3] - scene_bounds[:3]) / (scene_bounds[3:] - scene_bounds[:3]) * 100
    object_grasps_rot = quaternion_to_discrete_euler(object_grasps[:, 3:7], ROTATION_RESOLUTION)
    object_grasps_voxel = np.concatenate((object_grasps_trans.astype(int), object_grasps_rot), axis=1)
    return object_grasps_voxel, object_grasps

In [None]:
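## Visualize the grasps for the episode using plotly (disabled: uncomment to use; `object_grasps` is the second value returned by `get_all_available_grasp_episode`)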

# import numpy as np
# import plotly.graph_objects as go
# from pyquaternion import Quaternion  # For converting quaternions to rotation vectors

# # Example 2D ndarray of x, y, z, and quaternion (x, y, z, w) coordinates
# data = object_grasps

# # Extract positions and quaternions
# positions = data[:, :3]
# quaternions = data[:, 3:]

# # Create lists to hold the arrow start and end points
# x_start, y_start, z_start = [], [], []
# x_end_x, y_end_x, z_end_x = [], [], []
# x_end_y, y_end_y, z_end_y = [], [], []
# x_end_z, y_end_z, z_end_z = [], [], []

# # Calculate orientation vectors derived from quaternions
# for pos, quat in zip(positions, quaternions):
#     q = Quaternion(quat[3], quat[0], quat[1], quat[2])  # Convert to Quaternion object
#     x_dir = q.rotate([1, 0, 0])  # X-axis direction
#     y_dir = q.rotate([0, 1, 0])  # Y-axis direction
#     z_dir = q.rotate([0, 0, 1])  # Z-axis direction
    
#     # Start points
#     x_start.append(pos[0])
#     y_start.append(pos[1])
#     z_start.append(pos[2])
    
#     # End points for x, y, z directions
#     x_end_x.append(pos[0] + x_dir[0])
#     y_end_x.append(pos[1] + x_dir[1])
#     z_end_x.append(pos[2] + x_dir[2])
    
#     x_end_y.append(pos[0] + y_dir[0])
#     y_end_y.append(pos[1] + y_dir[1])
#     z_end_y.append(pos[2] + y_dir[2])
    
#     x_end_z.append(pos[0] + z_dir[0])
#     y_end_z.append(pos[1] + z_dir[1])
#     z_end_z.append(pos[2] + z_dir[2])

# # Create a 3D scatter plot for positions
# scatter = go.Scatter3d(
#     x=positions[:, 0],
#     y=positions[:, 1],
#     z=positions[:, 2],
#     mode='markers',
#     marker=dict(size=5, color='blue'),
#     name='Positions'
# )

# # Create 3D quiver-like plots for X, Y, and Z orientations
# quiver_x = go.Cone(
#     x=x_start,
#     y=y_start,
#     z=z_start,
#     u=np.array(x_end_x) - np.array(x_start),
#     v=np.array(y_end_x) - np.array(y_start),
#     w=np.array(z_end_x) - np.array(z_start),
#     sizemode="scaled",
#     sizeref=0.5,
#     anchor="tail",
#     colorscale=[[0, 'red'], [1, 'red']],
#     name='X-axis'
# )

# quiver_y = go.Cone(
#     x=x_start,
#     y=y_start,
#     z=z_start,
#     u=np.array(x_end_y) - np.array(x_start),
#     v=np.array(y_end_y) - np.array(y_start),
#     w=np.array(z_end_y) - np.array(z_start),
#     sizemode="scaled",
#     sizeref=0.5,
#     anchor="tail",
#     colorscale=[[0, 'green'], [1, 'green']],
#     name='Y-axis'
# )

# quiver_z = go.Cone(
#     x=x_start,
#     y=y_start,
#     z=z_start,
#     u=np.array(x_end_z) - np.array(x_start),
#     v=np.array(y_end_z) - np.array(y_start),
#     w=np.array(z_end_z) - np.array(z_start),
#     sizemode="scaled",
#     sizeref=0.5,
#     anchor="tail",
#     colorscale=[[0, 'blue'], [1, 'blue']],
#     name='Z-axis'
# )

# # Combine plots
# fig = go.Figure(data=[scatter, quiver_x, quiver_y, quiver_z])

# # Customize the layout
# fig.update_layout(
#     scene=dict(
#         xaxis_title='X',
#         yaxis_title='Y',
#         zaxis_title='Z'
#     ),
#     title='Interactive 3D Visualization of Positions and Orientations'
# )

# fig.show()

In [None]:
from notebook_helpers.build_training import NumpyEncoder
from notebook_helpers.constants import *

import matplotlib.pyplot as plt

from arm.utils import get_gripper_render_pose
from arm.utils import visualise_voxel


# For every test episode, replay each stored frame through the agent and compare
# its predicted gripper pose against
#   (a) the expert keyframe action stored in the replay buffer, and
#   (b) the closest (by translation) of all pre-computed grasps for the object.
# Translation errors are in voxels; rotation errors are 1 - |<q_a, q_b>|.

model_run_iter = natsorted([run for run in os.listdir(run_dir) if "run" in run])[-1]

# Choose the loss metric at which model is saved
# chosen_model = "best_model_train"
# chosen_model = "best_model_test"
# chosen_model = "best_model_general"
chosen_model = "last_model"

# Rendering voxel grid + gripper (prediction / expert / closest GT) is slow.
# When enabled, the three rendered images per frame are stored in
# `voxel_renders`, keyed by the frame's lang goal.
render_voxels = False
voxel_renders = dict()

# Save results
model_run_scored = dict()
model_run_distances = dict()
# model_run_rotations = dict() # Ignore rotation, location matters most

# Load the model. Continuing without trained weights would silently produce
# meaningless curves, so fail loudly instead.
model_path = os.path.join(run_dir, model_run_iter, chosen_model)
try:
    peract_agent.load_weights(model_path)
except Exception as err:
    raise RuntimeError(f"Could not load model weights from {model_path}") from err

# Voxel rendering parameters (only used when render_voxels is True)
voxel_size = 0.045
render_gripper = True
voxel_scale = voxel_size * 100
scene_bounds = np.array(SCENE_BOUNDS)

fig, axs = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

for analyzed_episode in TEST_INDEXES:
    all_frames_episode = dict_test_episodes_frames[analyzed_episode]
    if not all_frames_episode:
        print(f"No frames found for episode {analyzed_episode}, skipping.")
        continue

    # Frames are plotted on a normalised [0, 1] time axis; guard single-frame
    # episodes, which would otherwise divide by zero.
    episode_frame_length = max(len(all_frames_episode) - 1, 1)
    episode_frames_frame_point = []
    episode_frames_dist_expert = []
    episode_frames_rot_expert = []
    episode_frames_dist = []
    episode_frames_rot = []

    # The ground-truth grasp set depends only on the episode: load it once here
    # instead of re-reading the cache files for every frame.
    gt_voxel_all, gt_all = get_all_available_grasp_episode(analyzed_episode)
    gt_trans_all = gt_voxel_all[:, :3]
    gt_rot_all = gt_voxel_all[:, 3:]
    # Decode with the same resolution that get_all_available_grasp_episode encodes with.
    gt_quat_all = discrete_euler_to_quaternion(gt_rot_all, ROTATION_RESOLUTION)
    # Also obtain the flipped (wrist) ground truth: each grasp rotated 180 deg about its own z-axis.
    gt_quat_all_flipped = compose_qq(gt_all, np.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0]))[:, 3:]

    for episode_frame_i, analyzed_frame_episode in enumerate(all_frames_episode):

        # The replay buffer is sampled randomly: draw until this frame comes up.
        while True:
            batch = next(test_data_iter)
            lang_goal = batch['lang_goal'][0][0][0]
            if lang_goal == analyzed_frame_episode:
                batch = {k: v.to(device) for k, v in batch.items() if type(v) == torch.Tensor}
                break

        update_dict = peract_agent.update(None, batch, backprop=False)

        # Get predictions
        pred_trans = update_dict["pred_action"]["trans"]
        pred_rot_grip = update_dict["pred_action"]["rot_and_grip"]
        pred_quat = discrete_euler_to_quaternion(pred_rot_grip[:, :3], ROTATION_RESOLUTION)

        # Get expert action and convert its rotation bins to a quaternion
        gt_trans_expert = update_dict["expert_action"]["action_trans"][0]
        gt_rot_and_grip_expert = update_dict["expert_action"]["rot_and_grip"][0]
        gt_quat_expert = discrete_euler_to_quaternion(gt_rot_and_grip_expert[:3], ROTATION_RESOLUTION)

        # Distance / angular error: prediction vs. expert
        dist_expert = np.round(np.linalg.norm(pred_trans - gt_trans_expert, axis=1), 4)
        angle_dist_expert = 1.0 - np.abs(np.sum(normalize_quaternion(gt_quat_expert) * normalize_quaternion(pred_quat)))
        print(f"prediction: {pred_trans}, {pred_rot_grip[:,:3]}")
        print(f"expert: {gt_trans_expert}, {gt_rot_and_grip_expert[:3]} - trans-diff: {dist_expert} - ang-diff: {angle_dist_expert}")
        episode_frames_dist_expert.append(dist_expert)
        episode_frames_rot_expert.append(angle_dist_expert)

        # Distance / angular error: prediction vs. closest of all ground-truth grasps
        dist_all = np.round(np.linalg.norm(pred_trans - gt_trans_all, axis=1), 4)
        min_dist_idx = np.argmin(dist_all)
        dist = dist_all[min_dist_idx]
        angle_dist = 1.0 - np.abs(np.sum(gt_quat_all[min_dist_idx] * pred_quat))
        angle_dist_flipped = 1.0 - np.abs(np.sum(gt_quat_all_flipped[min_dist_idx] * pred_quat))
        angle_dist = min(angle_dist, angle_dist_flipped)
        print(f"ground-truth: {gt_trans_all[min_dist_idx]}, {gt_rot_all[min_dist_idx]} - trans-diff: {dist} - ang-diff: {angle_dist}")
        print(f"Quaternion ground-truth: {gt_quat_all[min_dist_idx]} {pred_quat}")
        episode_frames_dist.append(dist)
        episode_frames_rot.append(angle_dist)

        episode_frames_frame_point.append(episode_frame_i / episode_frame_length)

        # Optionally render Prediction vs. Expert vs. Closest-GT
        if render_voxels:
            vis_voxel_grid = update_dict['voxel_grid'][0].cpu().numpy()
            render_targets = (
                (pred_trans[0], pred_quat[0]),                            # prediction
                (gt_trans_expert, gt_quat_expert),                        # expert
                (gt_trans_all[min_dist_idx], gt_quat_all[min_dist_idx]),  # closest GT
            )
            frame_renders = []
            for vis_gt_coord, continuous_quat in render_targets:
                continuous_trans = vis_gt_coord * 0.01 + scene_bounds[:3]
                gripper_pose_mat = get_gripper_render_pose(voxel_scale,
                                                           SCENE_BOUNDS[:3],
                                                           continuous_trans,
                                                           continuous_quat)
                frame_renders.append(visualise_voxel(vis_voxel_grid,
                                                     None,
                                                     pred_trans,
                                                     vis_gt_coord,
                                                     voxel_size=voxel_size,
                                                     rotation_amount=np.deg2rad(45),
                                                     render_gripper=render_gripper,
                                                     gripper_pose=gripper_pose_mat,
                                                     gripper_mesh_scale=voxel_scale))
            voxel_renders[analyzed_frame_episode] = frame_renders

    axs[0].plot(episode_frames_frame_point, episode_frames_dist_expert, label=f'episode (expert): {analyzed_episode}', linestyle='dashed')
    axs[0].plot(episode_frames_frame_point, episode_frames_dist, label=f'episode (any): {analyzed_episode}')

    print(episode_frames_frame_point, episode_frames_rot_expert, episode_frames_rot)
    axs[1].plot(episode_frames_frame_point, episode_frames_rot_expert, label=f'episode (expert): {analyzed_episode}', linestyle='dashed')
    axs[1].plot(episode_frames_frame_point, episode_frames_rot, label=f'episode (any): {analyzed_episode}')

x_label = "Episode length [timestep / frame length] (@t=1, timestamp is at approach)"
axs[0].set_xlabel(x_label)
axs[0].set_ylabel("Translation error [voxels]")
axs[1].set_xlabel(x_label)
axs[1].set_ylabel("Rotation error")

# plt.title would only title the bottom subplot; title the whole figure instead.
fig.suptitle(f"{TASK}\n{settings}")
axs[0].legend()
axs[1].legend()
plt.show()