# Inference with a trained PerAct agent on HandoverSim data

In [None]:
import numpy as np
# Compatibility shim kept from the original notebook: alias `np.bool` to
# `np.bool_` to work around a NumPy version issue.
# NOTE(review): presumably needed because a dependency still references
# `np.bool` -- confirm which one before removing.
np.bool = np.bool_
import os
import sys

# Drop every import-path entry that contains '/peract/'.
sys.path = list(filter(lambda entry: '/peract/' not in entry, sys.path))

# Process environment for rendering and GPU selection.
# `PYOPENGL_PLATFORM=egl` is set for pyrender visualizations.
os.environ.update({
    "DISPLAY": ":0",
    "PYOPENGL_PLATFORM": "egl",
    # Depends on your computer and available GPUs
    "CUDA_VISIBLE_DEVICES": "0,3",
})

In [None]:
import json

from notebook_helpers.constants import * # Load global constant variables from constants.py
from notebook_helpers.build_training import build_agent

# Choose the run (earlier experiments are kept commented out for reference)
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-10_14-59" # Good {non-uniform 1 kp}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_12-30" # Bad {uniform 2 kp}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_12-30" # Bad
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_13-34"
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-12_17-25" # {crop skip 10}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-12_21-23"
run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-16_11-25/" # {crop skip 10 new}

# Obtain settings.
# The settings file is looked up directly inside `run_dir`. Joining onto
# `run_dir` (instead of `os.path.dirname(run_dir)`) gives the same path for a
# value with a trailing slash and no longer jumps to the parent directory when
# `run_dir` is written without one, as in the commented alternatives above.
path_settings = os.path.join(run_dir, "training_settings.json")
with open(path_settings, 'r') as f:
    settings = json.load(f)
CAMERAS = settings['cameras']

# The single-observation inference cells below require a batch size of 1.
if BATCH_SIZE != 1:
    raise ValueError("For Inference, 'batch_size' must be set to 1 in constants.py.")

# Build the agent in non-training mode and give it the language goal.
peract_agent = build_agent(settings, training=False)
peract_agent.set_language_goal("handing over banana")

# Choose model: <run_dir>/<iteration>/<best_type>
iteration = "run8000"
best_type = "best_model_general"
model_path = os.path.join(run_dir, iteration, best_type)

peract_agent.load_weights(model_path)

### Inference: Run Inference on just observation data and Save as video

In [None]:
import imageio
import numpy as np
from matplotlib import pyplot as plt
from rlbench.utils import get_stored_demo
from rlbench.backend.utils import extract_obs
from arm.utils import visualise_voxel
from arm.utils import get_gripper_render_pose


# Episode to visualize and where the resulting video is written
episode_idx_to_visualize = 646  # Index of the episode to visualize
video_output_path = f"demo{episode_idx_to_visualize}_visualization.mp4"

# Load the stored demo for the selected episode
demo = get_stored_demo(data_path=test_data_path,
                       index=episode_idx_to_visualize,
                       cameras=CAMERAS,
                       depth_scale=DEPTH_SCALE)

episode_length = list(range(len(demo._observations)))

# Rendering constants; identical for every frame
voxel_size = 0.045
voxel_scale = voxel_size * 100

# Render every 5th timestep into one video frame
with imageio.get_writer(video_output_path, fps=10) as video_writer:
    for ts in episode_length[::5]:
        print(ts)

        # Build the observation dict for this timestep
        obs_dict = extract_obs(demo._observations[ts], CAMERAS, t=ts)
        # gripper_pose = demo[ts].gripper_pose
        gripper_open = demo[ts].gripper_open
        gripper_joint_positions = demo[ts].gripper_joint_positions
        obs_dict["gripper_open"] = gripper_open
        obs_dict["gripper_joint_positions"] = gripper_joint_positions

        # Run the agent; `gripper_open` ends up holding the value from the
        # second tuple because it is assigned last.
        action_out, voxel_out = peract_agent.forward(obs_dict, ts)
        continuous_trans, continuous_quat, gripper_open, _, _ = action_out
        voxel_grid, coord_indices, rot_and_grip_indices, gripper_open = voxel_out
        print(continuous_trans, continuous_quat, gripper_open)

        # Things to visualize
        vis_voxel_grid = voxel_grid[0].detach().cpu().numpy()
        pred_trans_coord = coord_indices[0].detach().cpu().numpy().tolist()

        gripper_pose_mat = get_gripper_render_pose(voxel_scale,
                                                   SCENE_BOUNDS[:3],
                                                   continuous_trans,
                                                   continuous_quat)

        # Keyword arguments shared by both rendered views
        shared_render_args = dict(voxel_size=voxel_size,
                                  render_gripper=True,
                                  gripper_pose=gripper_pose_mat,
                                  gripper_mesh_scale=voxel_scale)

        rendered_img_0 = visualise_voxel(vis_voxel_grid,
                                         None,
                                         [pred_trans_coord],
                                         None,
                                         rotation_amount=np.deg2rad(0),
                                         **shared_render_args)

        rendered_img_270 = visualise_voxel(vis_voxel_grid,
                                           None,
                                           [pred_trans_coord],
                                           None,
                                           rotation_amount=np.deg2rad(45+180),
                                           **shared_render_args)

        # Lay both views out side by side in one figure
        fig = plt.figure(figsize=(20, 15))
        views = [(rendered_img_0, "Front view"), (rendered_img_270, "Side view")]
        for position, (view_img, view_title) in enumerate(views, start=1):
            fig.add_subplot(1, 2, position)
            plt.imshow(view_img)
            plt.axis('off')
            plt.title(view_title)

        # Add timestamp as text with white font and black background
        fig.text(0.02, 0.95, f"Timestep: {ts}", ha='left', fontsize=16, color='white', weight='bold',
                 bbox=dict(facecolor='black', edgecolor='none', boxstyle='round,pad=0.3'))

        # Rasterize the figure and reshape the flat RGB buffer to (height, width, 3)
        fig.canvas.draw()
        flat_rgb = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        canvas_width, canvas_height = fig.canvas.get_width_height()
        img_array = flat_rgb.reshape((canvas_height, canvas_width, 3))

        video_writer.append_data(img_array)  # Add frame to video
        plt.close(fig)  # Close the figure to free memory

print(f"Video saved to {video_output_path}")


### Validation: Run inference on input data together with ground truth to check what happens

In [None]:
def get_gt_from_frame(episode_keypoints_gt_obs_dict, frame_idx):
    """Return the ground-truth entry of the next keypoint for a frame.

    Args:
        episode_keypoints_gt_obs_dict: Mapping from keypoint frame index to
            the ground-truth data stored for that keypoint.
        frame_idx: Frame index to look up.

    Returns:
        The value stored under the smallest key that is greater than or equal
        to ``frame_idx``, or ``None`` when no such key exists.
    """
    # Keys are visited in ascending order, so the first match is the nearest
    # keypoint at or after the frame.
    upcoming_keypoints = (
        keypoint
        for keypoint in sorted(episode_keypoints_gt_obs_dict.keys())
        if frame_idx <= keypoint
    )
    for nearest_keypoint in upcoming_keypoints:
        return episode_keypoints_gt_obs_dict[nearest_keypoint]

    return None

In [None]:
#TODO: NEED SOME WAY TO COLLECT THE GROUND TRUTH DATA FOR ALL FRAMES

import imageio
import numpy as np
from matplotlib import pyplot as plt

from rlbench.utils import get_stored_demo
# `extract_obs` is used in the loop below; import it here so this cell does
# not depend on an earlier cell having been executed first.
from rlbench.backend.utils import extract_obs, extract_obs_gt

from arm.demo import _keypoint_discovery_available
from arm.utils import point_to_voxel_index, visualise_voxel, get_gripper_render_pose


# What to visualize
episode_idx_to_visualize = 645  # Index of the episode to visualize
# Video output path
video_output_path = f"demo{episode_idx_to_visualize}_gt_pred_visualization.mp4"

# Get demo
demo = get_stored_demo(data_path=test_data_path,
                       index=episode_idx_to_visualize,
                       cameras=CAMERAS,
                       depth_scale=DEPTH_SCALE)

# Collect ground-truth observations, keyed by keypoint frame index.
# Only the LAST discovered keypoint is kept.
episode_keypoints_gt_obs_dict = dict()
episode_keypoints = _keypoint_discovery_available(demo, approach_distance=0.3) #NOTE: Approach_distance Set
episode_keypoints = [episode_keypoints[-1]]
for episode_keypoint in episode_keypoints:
    episode_keypoints_gt_obs_dict[episode_keypoint] = extract_obs_gt(obs = demo._observations[episode_keypoint],
                                                                  cameras=CAMERAS)

episode_length = list(range(len(demo._observations)))
# episode_length = [40]

# Rendering constants; identical for every frame
voxel_size = 0.045
voxel_scale = voxel_size * 100

# Open a video writer
with imageio.get_writer(video_output_path, fps=10) as video_writer:
    for ts in episode_length[::5]: # Skip some frames
        print(ts)
        # Extract obs at timestep
        obs_dict = extract_obs(demo._observations[ts], CAMERAS, t=ts)
        # gripper_pose = demo[ts].gripper_pose
        gripper_open = demo[ts].gripper_open
        gripper_joint_positions = demo[ts].gripper_joint_positions

        obs_dict["gripper_open"] = gripper_open
        obs_dict["gripper_joint_positions"] = gripper_joint_positions

        (continuous_trans, continuous_quat, gripper_open, trans_confidence, _), \
        (voxel_grid, coord_indices, rot_and_grip_indices, gripper_open) = peract_agent.forward(obs_dict, ts)

        pred_trans_coord = coord_indices[0].detach().cpu().numpy().tolist()

        # Get the ground truth of the next keypoint at or after this frame.
        # `error` stays None when there is none, so a genuine error of exactly
        # 0.0 is still reported (a falsy sentinel would hide it).
        episode_keypoint_gt_obs_dict = get_gt_from_frame(episode_keypoints_gt_obs_dict, ts)
        if episode_keypoint_gt_obs_dict is not None:
            gt_gripper_pose = episode_keypoint_gt_obs_dict["gripper_pose"]
            gt_trans_coord = point_to_voxel_index(gt_gripper_pose[:3], VOXEL_SIZES, SCENE_BOUNDS)[0]
            # Euclidean distance between ground-truth and predicted translation
            error = np.linalg.norm(gt_gripper_pose[:3] - continuous_trans)
            print(f"GT (voxel): {gt_trans_coord} - Prediction (voxel): {pred_trans_coord} - Error: {error} - Prediction-score: {round(trans_confidence,4)}")
        else:
            gt_trans_coord = None
            error = None
            print("GT coordinates not available for this frame")

        # Things to visualize
        vis_voxel_grid = voxel_grid[0].detach().cpu().numpy()

        gripper_pose_mat = get_gripper_render_pose(voxel_scale,
                                                   SCENE_BOUNDS[:3],
                                                   continuous_trans,
                                                   continuous_quat)

        rendered_img_0 = visualise_voxel(vis_voxel_grid,
                                         None,
                                         [pred_trans_coord],
                                         gt_trans_coord,
                                         alpha = 0.2,
                                         voxel_size=voxel_size,
                                         rotation_amount=np.deg2rad(0),
                                         render_gripper=True,
                                         gripper_pose=gripper_pose_mat,
                                         gripper_mesh_scale=voxel_scale)

        rendered_img_270 = visualise_voxel(vis_voxel_grid,
                                           None,
                                           [pred_trans_coord],
                                           gt_trans_coord,
                                           alpha = 0.2,
                                           voxel_size=voxel_size,
                                           rotation_amount=np.deg2rad(45+180),
                                           render_gripper=True,
                                           gripper_pose=gripper_pose_mat,
                                           gripper_mesh_scale=voxel_scale)

        # Plot figures into a NumPy array
        fig = plt.figure(figsize=(20, 15))
        fig.add_subplot(1, 2, 1)
        plt.imshow(rendered_img_0)
        plt.axis('off')
        plt.title("Front view")
        fig.add_subplot(1, 2, 2)
        plt.imshow(rendered_img_270)
        plt.axis('off')
        plt.title("Side view")

        # Add timestamp as text with white font and black background; the
        # error is appended only when ground truth exists for this frame.
        overlay_text = f"Timestep: {ts}, Prediction-score: {round(trans_confidence,4)}"
        if error is not None:
            overlay_text += f", Error: {np.round(error, 3)}"
        fig.text(0.02, 0.95, overlay_text, ha='left', fontsize=16, color='white', weight='bold',
                 bbox=dict(facecolor='black', edgecolor='none', boxstyle='round,pad=0.3'))

        # Convert the matplotlib figure to a NumPy array
        fig.canvas.draw()
        img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,))

        video_writer.append_data(img_array)  # Add frame to video
        # plt.show()
        plt.close(fig)  # Close the figure to free memory


print(f"Video saved to {video_output_path}")


### Validation: Run Inference with same batches as during training (You can change `batch_size` in <i>constants.py</i>)

In [None]:
# import json
# from notebook_helpers.build_replay import load_replay_buffer

# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-16_11-25/" # {crop skip 10 new}

# path_settings = os.path.join(os.path.dirname(run_dir), "training_settings.json")
# with open(path_settings, 'r') as f:
#     settings = json.load(f)

# train_data_iter, test_data_iter = load_replay_buffer(settings)

In [None]:
import json

from notebook_helpers.constants import * # Load global constant variables from constants.py
from notebook_helpers.build_training import build_agent
from notebook_helpers.build_replay import load_replay_buffer

# Choose the run (earlier experiments are kept commented out for reference)
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-10_14-59" # Good {non-uniform 1 kp}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_12-30" # Bad {uniform 2 kp}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_12-30" # Bad
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-11_13-34"
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-12_17-25" # {crop skip 10}
# run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-12_21-23"
run_dir = "/home/ywatabe/Projects/PerAct/outputs/models/handing_over_banana/2024-12-16_11-25/" # {crop skip 10 new}

# Obtain settings.
# The settings file is looked up directly inside `run_dir`. Joining onto
# `run_dir` (instead of `os.path.dirname(run_dir)`) gives the same path for a
# value with a trailing slash and no longer jumps to the parent directory when
# `run_dir` is written without one, as in the commented alternatives above.
path_settings = os.path.join(run_dir, "training_settings.json")
with open(path_settings, 'r') as f:
    settings = json.load(f)

# Replay-buffer iterators built from the same settings as the run
train_data_iter, test_data_iter = load_replay_buffer(settings)

# Build the agent in training mode and give it the language goal.
peract_agent = build_agent(settings, training=True)
peract_agent.set_language_goal("handing over banana")

# Choose model: <run_dir>/<iteration>/<best_type>
iteration = "run8000"
best_type = "best_model_general"
model_path = os.path.join(run_dir, iteration, best_type)

peract_agent.load_weights(model_path)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from rlbench.utils import get_stored_demo
from rlbench.backend.utils import extract_obs

# Imports hoisted out of the per-timestep loop
from arm.utils import visualise_voxel
from arm.utils import discrete_euler_to_quaternion, get_gripper_render_pose

CAMERAS = settings["cameras"]

batch = next(train_data_iter)
# Language goal printed after every timestep. It is read from the batch
# fetched above so this cell no longer relies on a later cell defining
# `lang_goal`.
# NOTE(review): indexing mirrors the evaluation cell that reads
# batch['lang_goal'][0][0][0] from the same iterator -- confirm the layout.
lang_goal = batch['lang_goal'][0][0][0]

# what to visualize
episode_idx_to_visualize = 646#INDEXES[0] # out of 10 demos
# ts = 70#25 # timestep out of total timesteps

# get demo
demo = get_stored_demo(data_path=test_data_path,
                    index=episode_idx_to_visualize,
                    cameras=CAMERAS,
                    depth_scale=DEPTH_SCALE,)

# Rendering constants; identical for every timestep
voxel_size = 0.045
voxel_scale = voxel_size * 100

episode_length = list(range(len(demo._observations)))
for ts in episode_length:

    # extract obs at timestep
    obs_dict = extract_obs(demo._observations[ts], CAMERAS, t=ts)
    gripper_pose = demo[ts].gripper_pose
    gripper_open = demo[ts].gripper_open
    gripper_joint_positions = demo[ts].gripper_joint_positions

    # obs_dict["gripper_pose"] = gripper_pose
    obs_dict["gripper_open"] = gripper_open
    obs_dict["gripper_joint_positions"] = gripper_joint_positions

    # plot rgb (top row) and depth (bottom row) for every camera at this timestep
    fig = plt.figure(figsize=(20, 10))
    rows, cols = 2, len(CAMERAS)

    for plot_idx, camera in enumerate(CAMERAS, start=1):
        # rgb
        rgb = np.transpose(obs_dict[f"{camera}_rgb"], (1, 2, 0))
        fig.add_subplot(rows, cols, plot_idx)
        plt.imshow(rgb)
        plt.axis('off')
        plt.title(f"{camera}_rgb | step {ts}")

        # depth
        depth = np.transpose(obs_dict[f"{camera}_depth"], (1, 2, 0))
        fig.add_subplot(rows, cols, plot_idx + len(CAMERAS))
        plt.imshow(depth)
        plt.axis('off')
        plt.title(f"{camera}_depth | step {ts}")

    plt.show()

    print(obs_dict)

    # NOTE(review): the other inference cells in this notebook unpack FIVE
    # values from the first tuple returned by `forward`; this cell unpacks
    # three. Confirm which arity the current agent returns.
    (continuous_trans, continuous_quat, gripper_open), (voxel_grid, coord_indices, rot_and_grip_indices, gripper_open) = peract_agent.forward(obs_dict, ts)

    # things to visualize
    vis_voxel_grid = voxel_grid[0].detach().cpu().numpy()
    pred_trans_coord = coord_indices[0].detach().cpu().numpy().tolist()

    # discrete to continuous
    continuous_trans = continuous_trans[0].detach().cpu().numpy()
    continuous_quat = discrete_euler_to_quaternion(rot_and_grip_indices[0][:3].detach().cpu().numpy(),
                                                resolution=peract_agent._rotation_resolution)
    gripper_open = bool(rot_and_grip_indices[0][-1].detach().cpu().numpy())
    # The former `ignore_collision = ...test_update_dict[...]` line was removed:
    # `test_update_dict` is not defined by any live cell and the value was unused.

    # gripper visualization pose
    gripper_pose_mat = get_gripper_render_pose(voxel_scale,
                                            SCENE_BOUNDS[:3],
                                            continuous_trans,
                                            continuous_quat)

    # Two render passes per timestep, each producing its own figure:
    #   1) without the gripper mesh (front view rendered with perspective=False)
    #   2) with the gripper mesh
    for render_gripper, front_view_extra in ((False, {"perspective": False}), (True, {})):
        rendered_img_0 = visualise_voxel(vis_voxel_grid,
                                    None,
                                    [pred_trans_coord],
                                    None,
                                    voxel_size=voxel_size,
                                    rotation_amount=np.deg2rad(0),
                                    render_gripper=render_gripper,
                                    gripper_pose=gripper_pose_mat,
                                    gripper_mesh_scale=voxel_scale,
                                    **front_view_extra)

        rendered_img_270 = visualise_voxel(vis_voxel_grid,
                                    None,
                                    [pred_trans_coord],
                                    None,
                                    voxel_size=voxel_size,
                                    rotation_amount=np.deg2rad(45),
                                    render_gripper=render_gripper,
                                    gripper_pose=gripper_pose_mat,
                                    gripper_mesh_scale=voxel_scale)

        fig = plt.figure(figsize=(20, 15))
        fig.add_subplot(1, 2, 1)
        plt.imshow(rendered_img_0)
        plt.axis('off')
        plt.title("Front view")
        fig.add_subplot(1, 2, 2)
        plt.imshow(rendered_img_270)
        plt.axis('off')
        plt.title("Side view")

    print(f"Lang goal: {lang_goal}")


In [None]:
# Evaluate every saved checkpoint directory ("run*") of the selected run:
# for each one, load the chosen weights, run a fixed number of replay batches
# without backprop, and record the prediction/ground-truth distance together
# with the maximum translation Q-value.
# NOTE(review): `natsorted`, `torch` and `device` are not imported in any
# visible cell -- presumably provided by the star import from constants; verify.
model_runs = natsorted([os.path.join(run_dir, run) for run in os.listdir(run_dir) if "run" in run])

chosen_model = "best_model_train"
num_eval_batches = 20  # number of replay batches evaluated per checkpoint

# Both dicts are keyed by the checkpoint directory name (e.g. "run8000").
model_run_distances = dict()
model_run_scores = dict()

for model_run_iter in model_runs:

    # Checkpoints that cannot be loaded are skipped, but the reason is
    # reported and KeyboardInterrupt is no longer swallowed.
    try:
        peract_agent.load_weights(os.path.join(model_run_iter, chosen_model))
    except Exception as exc:
        print(f"Skipping {model_run_iter}: could not load '{chosen_model}' ({exc})")
        continue

    distances_run = []
    scores_run = []

    for i in range(num_eval_batches):
        print(i)
        batch = next(train_data_iter) # NOTE: Choose which set to infer
        lang_goal = batch['lang_goal'][0][0][0]
        # Keep only tensor entries and move them to the compute device
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}

        update_dict = peract_agent.update(None, batch, backprop=False)
        prediction_score = round(update_dict["q_trans"].max().item(), 4)
        pred_trans = update_dict["pred_action"]["trans"].detach().cpu().numpy()[0]
        gt_trans = update_dict["expert_action"]["action_trans"].detach().cpu().numpy()[0]
        # Euclidean distance between predicted and expert translation
        dist = np.round(np.linalg.norm(pred_trans-gt_trans), 4)

        distances_run.append(dist)
        scores_run.append(prediction_score)

    # Sort the (distance, score) pairs by ascending distance, then split them
    # back into two aligned tuples.
    sorted_lists = sorted(zip(distances_run, scores_run), key=lambda x: x[0])
    sorted_distances, sorted_scores = zip(*sorted_lists)

    # Key by the checkpoint's own directory name. `os.path.dirname` would give
    # the shared parent `run_dir` for every checkpoint, so each iteration
    # would overwrite the previous result.
    run_name = os.path.basename(model_run_iter)
    model_run_distances[run_name] = sorted_distances
    model_run_scores[run_name] = sorted_scores


In [None]:
# from rlbench.utils import get_stored_demo
# from rlbench.backend.utils import extract_obs

# batch = next(train_data_iter)

# # what to visualize
# episode_idx_to_visualize = 846#INDEXES[0] # out of 10 demos
# # ts = 70#25 # timestep out of total timesteps

# # get demo
# demo = get_stored_demo(data_path=test_data_path,
#                     index=episode_idx_to_visualize,
#                     cameras=CAMERAS,
#                     depth_scale=DEPTH_SCALE,)

# episode_length = list(range(len(demo._observations)))
# for ts in episode_length:

#     # extract obs at timestep
#     obs_dict = extract_obs(demo._observations[ts], CAMERAS, t=ts)
#     gripper_pose = demo[ts].gripper_pose
#     gripper_open = demo[ts].gripper_open
#     gripper_joint_positions = demo[ts].gripper_joint_positions

#     # obs_dict["gripper_pose"] = gripper_pose
#     obs_dict["gripper_open"] = gripper_open
#     obs_dict["gripper_joint_positions"] = gripper_joint_positions

#     # plot rgb and depth at timestep
#     fig = plt.figure(figsize=(20, 10))
#     rows, cols = 2, len(CAMERAS)

#     plot_idx = 1
#     for camera in CAMERAS:
#         # rgb
#         rgb_name = "%s_%s" % (camera, 'rgb')
#         rgb = np.transpose(obs_dict[rgb_name], (1, 2, 0))
#         fig.add_subplot(rows, cols, plot_idx)
#         plt.imshow(rgb)
#         plt.axis('off')
#         plt.title("%s_rgb | step %s" % (camera, ts))

#         # depth
#         depth_name = "%s_%s" % (camera, 'depth')
#         depth = np.transpose(obs_dict[depth_name], (1, 2, 0))
#         fig.add_subplot(rows, cols, plot_idx+len(CAMERAS))
#         plt.imshow(depth)
#         plt.axis('off')
#         plt.title("%s_depth | step %s" % (camera, ts))

#         plot_idx += 1

#     plt.show()

#     print(obs_dict)

#     (continuous_trans, continuous_quat, gripper_open), (voxel_grid, coord_indices, rot_and_grip_indices, gripper_open) = peract_agent_inference.forward(obs_dict, ts)

#     from arm.utils import visualise_voxel
#     from arm.utils import discrete_euler_to_quaternion, get_gripper_render_pose

#     # things to visualize
#     vis_voxel_grid = voxel_grid[0].detach().cpu().numpy()
#     pred_trans_coord = coord_indices[0].detach().cpu().numpy().tolist()

#     # discrete to continuous
#     continuous_trans = continuous_trans[0].detach().cpu().numpy()
#     continuous_quat = discrete_euler_to_quaternion(rot_and_grip_indices[0][:3].detach().cpu().numpy(),
#                                                 resolution=peract_agent_inference._rotation_resolution)
#     gripper_open = bool(rot_and_grip_indices[0][-1].detach().cpu().numpy())
#     ignore_collision = bool(test_update_dict['pred_action']['collision'][0][0].detach().cpu().numpy())

#     # # gripper visualization pose
#     voxel_size = 0.045
#     voxel_scale = voxel_size * 100
#     gripper_pose_mat = get_gripper_render_pose(voxel_scale,
#                                             SCENE_BOUNDS[:3],
#                                             continuous_trans,
#                                             continuous_quat)

#     # #@markdown #### Show Q-Prediction and Best Action
#     show_expert_action = True  #@param {type:"boolean"}
#     show_q_values = False  #@param {type:"boolean"}
#     render_gripper = False  #@param {type:"boolean"}
#     rotation_amount = -90 #@param {type:"slider", min:-180, max:180, step:5}

#     rendered_img_0 = visualise_voxel(vis_voxel_grid,
#                                 None,
#                                 [pred_trans_coord],
#                                 None,
#                                 voxel_size=voxel_size,
#                                 rotation_amount=np.deg2rad(0),
#                                 render_gripper=render_gripper,
#                                 gripper_pose=gripper_pose_mat,
#                                 gripper_mesh_scale=voxel_scale,
#                                 perspective=False)

#     rendered_img_270 = visualise_voxel(vis_voxel_grid,
#                                 None,
#                                 [pred_trans_coord],
#                                 None,
#                                 voxel_size=voxel_size,
#                                 rotation_amount=np.deg2rad(45),
#                                 render_gripper=render_gripper,
#                                 gripper_pose=gripper_pose_mat,
#                                 gripper_mesh_scale=voxel_scale)


#     fig = plt.figure(figsize=(20, 15))
#     fig.add_subplot(1, 2, 1)
#     plt.imshow(rendered_img_0)
#     plt.axis('off')
#     plt.title("Front view")
#     fig.add_subplot(1, 2, 2)
#     plt.imshow(rendered_img_270)
#     plt.axis('off')
#     plt.title("Side view")


#     # #@markdown #### Show Q-Prediction and Best Action
#     show_expert_action = True  #@param {type:"boolean"}
#     show_q_values = True  #@param {type:"boolean"}
#     render_gripper = True  #@param {type:"boolean"}
#     rotation_amount = -90 #@param {type:"slider", min:-180, max:180, step:5}

#     rendered_img_0 = visualise_voxel(vis_voxel_grid,
#                                 None,
#                                 [pred_trans_coord],
#                                 None,
#                                 voxel_size=voxel_size,
#                                 rotation_amount=np.deg2rad(0),
#                                 render_gripper=render_gripper,
#                                 gripper_pose=gripper_pose_mat,
#                                 gripper_mesh_scale=voxel_scale)

#     rendered_img_270 = visualise_voxel(vis_voxel_grid,
#                                 None,
#                                 [pred_trans_coord],
#                                 None,
#                                 voxel_size=voxel_size,
#                                 rotation_amount=np.deg2rad(45),
#                                 render_gripper=render_gripper,
#                                 gripper_pose=gripper_pose_mat,
#                                 gripper_mesh_scale=voxel_scale)


#     fig = plt.figure(figsize=(20, 15))
#     fig.add_subplot(1, 2, 1)
#     plt.imshow(rendered_img_0)
#     plt.axis('off')
#     plt.title("Front view")
#     fig.add_subplot(1, 2, 2)
#     plt.imshow(rendered_img_270)
#     plt.axis('off')
#     plt.title("Side view")

#     print(f"Lang goal: {lang_goal}")


In [None]:
# from arm.utils import discrete_euler_to_quaternion, get_gripper_render_pose

# batch = next(test_data_iter)

# lang_goal = batch['lang_goal'][0][0][0]
# print(lang_goal)

# batch = {k: v.to(device) for k, v in batch.items() if type(v) == torch.Tensor}
# test_update_dict = peract_agent_inference.update(0, batch, backprop=False) # Here backprop == False: for evaluation, hence test_loss == total_loss

# # Log test metrics
# test_metrics = {
#     "total_loss": test_update_dict['total_loss'],
#     "trans_loss": test_update_dict['trans_loss'],
#     "rot_loss": test_update_dict['rot_loss'],
#     "col_loss": test_update_dict['col_loss']
# }
# for episode_kp, value in test_metrics.items():
#     print(episode_kp, value)


# from arm.utils import visualise_voxel

# # # things to visualize
# vis_voxel_grid = test_update_dict['voxel_grid'][0].detach().cpu().numpy()
# vis_trans_q = test_update_dict['q_trans'][0].detach().cpu().numpy()
# pred_trans_coord = test_update_dict['pred_action']['trans'][0].detach().cpu().numpy().tolist()
# vis_gt_coord = test_update_dict['expert_action']['action_trans'][0].detach().cpu().numpy()

# # discrete to continuous
# continuous_trans = test_update_dict['pred_action']['continuous_trans'][0].detach().cpu().numpy()
# continuous_quat = discrete_euler_to_quaternion(test_update_dict['pred_action']['rot_and_grip'][0][:3].detach().cpu().numpy(),
#                                             resolution=peract_agent_inference._rotation_resolution)
# gripper_open = bool(test_update_dict['pred_action']['rot_and_grip'][0][-1].detach().cpu().numpy())
# ignore_collision = bool(test_update_dict['pred_action']['collision'][0][0].detach().cpu().numpy())

# # gripper visualization pose
# voxel_size = 0.045
# voxel_scale = voxel_size * 100
# # print(continuous_trans, continuous_quat)
# gripper_pose = batch['gripper_state'][:, -1][0].detach().cpu().numpy()
# continuous_trans = gripper_pose[:3]
# continuous_quat = gripper_pose[3:7]
# gripper_pose_mat = get_gripper_render_pose(voxel_scale,
#                                         SCENE_BOUNDS[:3],
#                                         continuous_trans,
#                                         continuous_quat)


# # #@markdown #### Show Q-Prediction and Best Action
# show_expert_action = True  #@param {type:"boolean"}
# show_q_values = True  #@param {type:"boolean"}
# render_gripper = True  #@param {type:"boolean"}
# rotation_amount = -90 #@param {type:"slider", min:-180, max:180, step:5}

# rendered_img_0 = visualise_voxel(vis_voxel_grid,
#                             vis_trans_q if show_q_values else None,
#                             [pred_trans_coord],
#                             vis_gt_coord if show_expert_action else None,
#                             voxel_size=voxel_size,
#                             rotation_amount=np.deg2rad(0),
#                             render_gripper=render_gripper,
#                             gripper_pose=gripper_pose_mat,
#                             gripper_mesh_scale=voxel_scale)

# rendered_img_270 = visualise_voxel(vis_voxel_grid,
#                             vis_trans_q if show_q_values else None,
#                             [pred_trans_coord],
#                             vis_gt_coord if show_expert_action else None,
#                             voxel_size=voxel_size,
#                             rotation_amount=np.deg2rad(45),
#                             render_gripper=render_gripper,
#                             gripper_pose=gripper_pose_mat,
#                             gripper_mesh_scale=voxel_scale)


# fig = plt.figure(figsize=(20, 15))
# fig.add_subplot(1, 2, 1)
# plt.imshow(rendered_img_0)
# plt.axis('off')
# plt.title("Front view")
# fig.add_subplot(1, 2, 2)
# plt.imshow(rendered_img_270)
# plt.axis('off')
# plt.title("Side view")

# # print(f"Lang goal: {lang_goal}")