In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from moviepy import VideoFileClip
from IPython.display import Image, display
from plotly.subplots import make_subplots
from multiprocessing.pool import ThreadPool
import copy
import json

In [None]:
def visualize_pcd_sequence(xyz_list, rgb_list, actions_list, num_cols, pc_labels_list):
    """Plot a grid of 3D point clouds, one subplot per keystep, with red action markers.

    Args:
        xyz_list: sequence of per-step point arrays; columns 0-2 are drawn as x, y, z.
        rgb_list: sequence of per-step (N, 3) colour arrays with values in [0, 1].
            If None or empty, colours are derived from ``pc_labels_list``.
        actions_list: per-step list of actions; entries 0-2 of each action are drawn
            as a red marker. ``None`` means no actions.
        num_cols: number of subplot columns in the grid.
        pc_labels_list: per-step point labels (0 normal point, 1 gripper, 2 object,
            3 target); only used when ``rgb_list`` is None/empty.
    """
    num_keysteps = len(xyz_list)
    if num_keysteps == 0:
        # make_subplots cannot build a zero-row grid; there is nothing to draw.
        return
    num_rows = (num_keysteps + num_cols - 1) // num_cols

    # Create subplot figure
    fig = make_subplots(
        rows=num_rows,
        cols=num_cols,
        specs=[[{'type': 'scene'} for _ in range(num_cols)] for _ in range(num_rows)],
        subplot_titles=[f'Step {i}' for i in range(num_keysteps)]
    )

    if rgb_list is None or len(rgb_list) == 0:
        # Colour by label: 0 normal point (light grey), 1 gripper (blue),
        # 2 object (orange), 3 target (green).
        label_palette = (
            (0, [0.7, 0.7, 0.7]),
            (1, [0.0, 0.0, 1.0]),
            (2, [1.0, 0.65, 0.0]),
            (3, [0.0, 1.0, 0.0]),
        )
        rgb_list = []
        for step_idx, xyz in enumerate(xyz_list):
            labels = pc_labels_list[step_idx]
            rgb = np.zeros((xyz.shape[0], 3))
            for label_id, color in label_palette:
                rgb[labels == label_id] = color
            rgb_list.append(rgb)

    if actions_list is None:
        actions_list = [[] for _ in range(num_keysteps)]

    # Plot PCDs and actions for each keystep
    for idx, (xyz, rgb, actions) in enumerate(zip(xyz_list, rgb_list, actions_list)):
        row = idx // num_cols + 1
        col = idx % num_cols + 1

        # Vectorised [0, 1] float -> [0, 255] int, truncating exactly like int(v * 255).
        rgb_255 = (np.asarray(rgb, dtype=float) * 255).astype(int)
        colors = [f'rgb({r},{g},{b})' for r, g, b in rgb_255]

        fig.add_trace(
            go.Scatter3d(
                x=xyz[:, 0],
                y=xyz[:, 1],
                z=xyz[:, 2],
                mode='markers',
                marker=dict(size=2, color=colors, opacity=0.8),
                showlegend=False
            ),
            row=row, col=col
        )

        if len(actions) > 0:
            # One trace holding every action position of this step (instead of one per action).
            fig.add_trace(
                go.Scatter3d(
                    x=[action[0] for action in actions],
                    y=[action[1] for action in actions],
                    z=[action[2] for action in actions],
                    mode='markers',
                    marker=dict(size=5, color='red', opacity=1.0),
                    name="Action",
                    showlegend=False
                ),
                row=row, col=col
            )

    fig.update_layout(
        height=400 * num_rows,
        width=400 * num_cols,
        margin=dict(l=0, r=0, b=0, t=30),
        paper_bgcolor='white',
        plot_bgcolor='white'
    )

    scene_settings = dict(
        aspectmode='data',
        xaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
        yaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
        zaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True)
    )
    # Style every scene subplot at once (unused trailing grid cells included).
    fig.update_scenes(scene_settings)

    fig.show()

In [None]:
def display_front_view_sequence(obs_dict, col_nb: int = 4):
    """Show the last camera image of every recorded step, ``col_nb`` images per figure.

    Args:
        obs_dict: mapping step id -> step record. The image is read from
            ``record['obs']['rgb'][-1]`` (presumably the front camera -- TODO confirm
            the camera ordering). Steps lacking that entry are skipped.
        col_nb: number of images concatenated horizontally per figure (>= 1).

    Raises:
        ValueError: if ``col_nb`` is smaller than 1.
    """
    if col_nb < 1:
        raise ValueError(f"col_nb must be >= 1, got {col_nb}")

    front_view_sequence = []
    # Iterate over the steps actually present rather than a global ``num_steps``
    # that only exists if another cell happened to define it.
    for step in sorted(obs_dict):
        try:
            obs_rgb = obs_dict[step]['obs']['rgb']
        except (KeyError, TypeError):
            # TypeError covers records whose 'obs' is None.
            continue
        if obs_rgb is None:
            continue
        front_view_sequence.append(obs_rgb[-1])

    for cursor in range(0, len(front_view_sequence), col_nb):
        plt.imshow(np.concatenate(front_view_sequence[cursor:cursor + col_nb], axis=1))
        plt.show()

In [None]:
def load_one_step(step, steps_root_dir):
    """Read the recorded file ``step_<step>.npy`` from ``steps_root_dir``.

    Returns the raw array produced by ``np.load``; callers unwrap it with ``[()]``.
    """
    print(f"loading step: {step}")
    step_path = f"{steps_root_dir}/step_{step}.npy"
    # allow_pickle=True unpickles arbitrary objects (and can execute code):
    # only point this at trusted experiment recordings.
    return np.load(step_path, allow_pickle=True)

In [None]:
from functools import partial

def load_steps_data(episode_dir, num_threads=None):
    """Load every ``step_<k>.npy`` file of an episode in parallel threads.

    Args:
        episode_dir: episode folder that contains a ``steps`` sub-directory.
        num_threads: number of ThreadPool workers (``None`` lets ThreadPool decide).

    Returns:
        (obs_dict, num_steps): ``obs_dict`` maps each step id to the object stored in
        its file; ``num_steps`` is the number of step files found.
    """
    steps_root_dir = f"{episode_dir}/steps"

    # Only count files named step_<int>.npy: stray entries (e.g. .DS_Store) used to
    # inflate num_steps and trigger loads of step files that do not exist.
    prefix, suffix = "step_", ".npy"
    step_ids = sorted(
        int(name[len(prefix):-len(suffix)])
        for name in os.listdir(steps_root_dir)
        if name.startswith(prefix)
        and name.endswith(suffix)
        and name[len(prefix):-len(suffix)].isdigit()
    )
    num_steps = len(step_ids)
    step_files = [f"{prefix}{step_id}{suffix}" for step_id in step_ids]
    print(f"list of steps: {step_files}")
    if step_ids != list(range(num_steps)):
        # Downstream code iterates range(num_steps), so gaps will surface there.
        print(f"warning: step ids are not contiguous from 0: {step_ids}")

    load_step_partial = partial(load_one_step, steps_root_dir=steps_root_dir)

    with ThreadPool(processes=num_threads) as pool:
        results = pool.map(load_step_partial, step_ids)

    # Each file holds a 0-d object array; [()] unwraps the stored Python object.
    obs_dict = {step: result[()] for step, result in zip(step_ids, results)}

    return obs_dict, num_steps

In [None]:
def load_video_data(episode_dir, gif_output_path="gif/robot_action.gif"):
    """Find an episode's video folder, parse its success rate and convert the video to GIF.

    The video folder is the (alphabetically first) entry of ``episode_dir`` whose name
    contains ``"video"``; names are expected to look like ``video-SR<rate>``.

    Args:
        episode_dir: episode folder.
        gif_output_path: where to write the GIF (default unchanged).

    Returns:
        (success_rate, gif_path): ``success_rate`` is the text after ``video-SR``
        (a string); ``gif_path`` is the path returned by ``convert_video_to_gif``.

    Raises:
        FileNotFoundError: if ``episode_dir`` has no video folder.
    """
    video_dirs = sorted(folder for folder in os.listdir(episode_dir) if "video" in folder)
    if not video_dirs:
        raise FileNotFoundError(f"No video folder found in {episode_dir}")
    video_dir = video_dirs[0]

    # str.strip("video-SR") removes any of those *characters* from both ends rather than
    # the prefix; slice the prefix off when present, keep the old behaviour otherwise.
    prefix = "video-SR"
    if video_dir.startswith(prefix):
        success_rate = video_dir[len(prefix):]
    else:
        success_rate = video_dir.strip(prefix)

    video_path = os.path.join(episode_dir, video_dir, "global.avi")
    gif_path = convert_video_to_gif(video_path, gif_output_path)

    return success_rate, gif_path

In [None]:
def prepare_keysteps_pcd_sequence(obs_dict, num_steps):
    """Turn raw step records into the per-step lists used by the display helpers.

    For each step ``0 .. num_steps-1`` this collects the point cloud of the record's
    ``batch``, its first ``valid_actions`` entry, its plan / action name / cache and
    the gripper pose. Positions (action targets and gripper position) are mapped into
    the point-cloud frame with ``(p - batch['pc_centroids']) / batch['pc_radius']``.

    A step whose ``batch`` is None reuses the most recent non-None batch (only the
    action changed). Release actions recorded with 8 values are padded to 9 with 404.

    Returns:
        xyz_list, actions_list, pc_labels_list, plans_list, action_name_lists,
        cache_list, gripper_pos_list, mp_error_list -- one entry per step.
        ``mp_error_list[k]`` is ``(error, current_gripper_xyz, previous_target_xyz)``
        where ``error`` is the Euclidean distance (rounded to 3 decimals) between the
        gripper position at step k and the action target of step k-1, both expressed
        in step k's frame; step 0 stores ``(0, gripper_xyz, [0, 0, 0])``.

    Raises:
        ValueError: if a step has no batch and no earlier batch exists to reuse.
    """
    xyz_list = []
    actions_list = []
    pc_labels_list = []
    plans_list = []
    action_name_lists = []
    cache_list = []
    gripper_pos_list = []
    mp_error_list = []

    def to_pc_frame(position, frame):
        # Map world xyz into the normalised point-cloud frame of ``frame``.
        return (position - frame['pc_centroids']) / frame['pc_radius']

    last_batch = None
    for step in range(num_steps):
        record = obs_dict[step]
        batch = record['batch']
        if batch is None:
            # Nothing changed except the action: reuse the last known point cloud.
            # (Looking only at step-1 crashed on two consecutive None batches.)
            if last_batch is None:
                raise ValueError(f"Step {step} has no batch and no earlier batch to reuse")
            batch = last_batch
        else:
            last_batch = batch

        xyz_list.append(batch['pc_fts'][:, :3])

        # Only the first valid action is kept. Its position is *normalised* into the
        # point-cloud frame here (the old variable name said "denormalized").
        valid_action = record['valid_actions'][0]
        normalized_action = np.concatenate([to_pc_frame(valid_action[:3], batch), valid_action[3:]])
        if len(normalized_action) == 8:
            # Release actions are recorded with 8 values instead of 9; pad with the
            # sentinel 404 so every action has the same length downstream.
            normalized_action = np.concatenate([normalized_action, [404]])
        actions_list.append([normalized_action])

        pc_labels_list.append(batch['pc_labels'])
        plans_list.append(record['plan'] if record['plan'] is not None else {})
        action_name_lists.append(record['action_name'] if record['action_name'] is not None else [])
        cache_list.append(record['cache'] if record['cache'] is not None else {})

        gripper_pos = copy.deepcopy(record['obs']["gripper"])
        gripper_pos[:3] = to_pc_frame(gripper_pos[:3], batch)
        gripper_pos_list.append(gripper_pos)

        if step == 0:
            mp_error_list.append((0, gripper_pos[:3], [0, 0, 0]))
        else:
            current_gripper_pos = gripper_pos[:3]
            # Previous target expressed in the current frame so both points are comparable.
            previous_target = to_pc_frame(obs_dict[step - 1]['valid_actions'][0][:3], batch)
            mp_error = np.linalg.norm(np.asarray(current_gripper_pos) - np.asarray(previous_target))
            mp_error_list.append((round(mp_error, 3), current_gripper_pos, previous_target))

    return xyz_list, actions_list, pc_labels_list, plans_list, action_name_lists, cache_list, gripper_pos_list, mp_error_list

In [None]:
def display_plan(plans, actions_list):
    """Print, for each step, its plan, its actions and the commanded gripper state.

    The gripper command is read from index 7 of each action and reported as "Open"
    when it is above 0.5, the same threshold ``display_story`` uses. Non-binary
    predicted values are therefore classified instead of all showing "Closed".
    """
    for step, plan in enumerate(plans):
        step_actions = actions_list[step]
        print(f"# Step {step}:")
        print(f"Plan: {plan}")
        print(f"Action: {step_actions}")
        gripper_states = [
            f"{action[7]}, {'Open' if action[7] > 0.5 else 'Closed'}" for action in step_actions
        ]
        print(f"Gripper states: {gripper_states}")

In [None]:
def quaternion_to_rotation_matrix(qx, qy, qz, qw):
    """Return the *transpose* of the rotation matrix of quaternion (qx, qy, qz, qw).

    The quaternion is normalised first, so slightly non-unit inputs (e.g. network
    predictions) still give an orthonormal matrix. Because the transpose R.T is
    returned, row-vector points ``p`` are rotated by R via ``p @ result``.

    Raises:
        ValueError: if the quaternion has zero norm.
    """
    norm = np.sqrt(qx**2 + qy**2 + qz**2 + qw**2)
    if norm == 0:
        raise ValueError("Cannot convert a zero quaternion to a rotation matrix")
    qx, qy, qz, qw = qx / norm, qy / norm, qz / norm, qw / norm

    R = np.zeros((3, 3))
    R[0, 0] = 1 - 2*qy**2 - 2*qz**2
    R[0, 1] = 2*qx*qy - 2*qz*qw
    R[0, 2] = 2*qx*qz + 2*qy*qw
    R[1, 0] = 2*qx*qy + 2*qz*qw
    R[1, 1] = 1 - 2*qx**2 - 2*qz**2
    R[1, 2] = 2*qy*qz - 2*qx*qw
    R[2, 0] = 2*qx*qz - 2*qy*qw
    R[2, 1] = 2*qy*qz + 2*qx*qw
    R[2, 2] = 1 - 2*qx**2 - 2*qy**2

    return R.T

def transform_gripper_points(xyz, labels, action, gripper_pos):
    """Move the observed gripper points to the pose requested by ``action``.

    Gripper points (``labels == 1``) are expressed relative to the current gripper
    position ``gripper_pos[:3]``. They are rotated by the relative rotation from the
    current orientation ``gripper_pos[3:7]`` to the target orientation ``action[3:7]``
    (both unpacked as qx, qy, qz, qw), then translated to ``action[:3]``.

    Returns:
        (M, 3) array of transformed gripper points, or an empty array when no point
        carries the gripper label.
    """
    # Extract gripper points
    gripper_mask = labels == 1
    gripper_points = xyz[gripper_mask].copy()

    if len(gripper_points) == 0:
        return np.array([])

    # Center the points on the current gripper position
    current_centroid = gripper_pos[:3]
    centered_points = gripper_points - current_centroid

    target_position = action[:3]
    target_quat = np.asarray(action[3:7], dtype=float)
    current_quat = np.asarray(gripper_pos[3:7], dtype=float)

    # quaternion_to_rotation_matrix returns the transpose, i.e. R_current_T = R_cur^T and
    # R_target_T = R_tgt^T. The relative rotation current -> target is R_tgt @ R_cur^T;
    # for row vectors that is points @ R_cur @ R_tgt^T = points @ R_current_T.T @ R_target_T.
    # (The previous `@ R_current_T @ R_target_T` applied the current rotation a second
    # time instead of undoing it.)
    R_current_T = quaternion_to_rotation_matrix(*current_quat)
    R_target_T = quaternion_to_rotation_matrix(*target_quat)

    # q and -q encode the same rotation, so measure the distance up to sign.
    distance_quaternions_gripper_pos_target_pose = min(
        np.linalg.norm(target_quat - current_quat),
        np.linalg.norm(target_quat + current_quat),
    )
    print(f"distance_quaternions_gripper_pos_target_pose: {distance_quaternions_gripper_pos_target_pose}")
    if distance_quaternions_gripper_pos_target_pose < 0.01:
        # Orientation practically unchanged: skip the (near-identity) rotation.
        rotated_points = centered_points
    else:
        rotated_points = centered_points @ R_current_T.T @ R_target_T

    # Translate to target position
    return rotated_points + target_position

def _label_colors(labels, num_points):
    """Return one plotly 'rgb(r,g,b)' string per point, coloured by its label.

    Labels: 0 light grey, 1 blue (gripper), 2 orange (object), 3 green (target);
    any other label stays black.
    """
    rgb = np.zeros((num_points, 3))
    rgb[labels == 0] = [0.7, 0.7, 0.7]  # Light grey
    rgb[labels == 1] = [0.0, 0.0, 1.0]  # Blue
    rgb[labels == 2] = [1.0, 0.65, 0.0]  # Orange
    rgb[labels == 3] = [0.0, 1.0, 0.0]  # Green
    return [f'rgb({int(r*255)},{int(g*255)},{int(b*255)})' for r, g, b in rgb]


def _marker_trace(x, y, z, size, color, name=None):
    """Return a legend-less go.Scatter3d marker trace with opacity 0.8."""
    return go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker=dict(size=size, color=color, opacity=0.8),
        name=name,
        showlegend=False
    )


def display_story(start_idx, plans_list, xyz_list, actions_list, pc_labels_list, obs_dict, cache_list, action_name_lists, gripper_pos_list, mp_error_list):
    """Print and plot, step by step, what the policy planned and where it moved.

    The k-th entry of ``plans_list`` belongs to step ``start_idx + k``. For each step
    this prints the action name, the plan fields and the motion-planning error. It then
    shows a figure with the front-camera image and the labelled point cloud, overlaid with:
    the predicted gripper points (red, or yellow when the gripper open/close state
    changes or the plan action is 'release'), the predicted target position (pink), the
    current gripper position (green) and, after step 0, the previous target (purple).

    ``cache_list`` is accepted for interface compatibility but is not used.
    """
    for offset, plan in enumerate(plans_list):
        step = start_idx + offset
        print(f"\033[1m# Step {step}:\033[0m")
        print("Plan:")
        print(f"Action_name: {action_name_lists[step]}")
        for key in plan:
            print(f"\033[4m{key}\033[0m: {plan[key]}")
        print(f"MP error = {mp_error_list[step][0]}")

        fig = make_subplots(
            rows=1, cols=2,
            specs=[[{"type": "image"}, {"type": "scene"}]],
            subplot_titles=["Front View", "Point Cloud"]
        )

        try:
            front_view = obs_dict[step]['obs']['rgb'][-1]
            fig.add_trace(go.Image(z=front_view), row=1, col=1)
        except KeyError:
            print(f"No front view image for step {step}")

        try:
            xyz = xyz_list[step]
            labels = pc_labels_list[step]
            gripper_pos = gripper_pos_list[step]

            fig.add_trace(
                _marker_trace(xyz[:, 0], xyz[:, 1], xyz[:, 2], size=2,
                              color=_label_colors(labels, xyz.shape[0])),
                row=1, col=2
            )

            gripper_state_current = 1 if gripper_pos[-1] > 0.5 else 0
            for action in actions_list[step]:
                # Transform gripper points ("imagine" the target pose)
                transformed_points = transform_gripper_points(xyz, labels, action, gripper_pos)
                gripper_state_target = 1 if action[-2] > 0.5 else 0
                print(f"Current Gripper State: {'Open' if gripper_state_current == 1 else 'Closed'}")
                print(f"Target Gripper State: {'Open' if gripper_state_target == 1 else 'Closed'}")
                print(f'Action: {action}')
                print(f'Gripper pos: {gripper_pos}')

                # Yellow highlights a gripper state change or an explicit release.
                # Steps without a plan are stored as {}, hence .get() instead of ['action'].
                if gripper_state_current != gripper_state_target or plan.get('action') == 'release':
                    color_target_action = 'yellow'
                else:
                    color_target_action = 'red'

                if len(transformed_points) > 0:
                    fig.add_trace(
                        _marker_trace(transformed_points[:, 0], transformed_points[:, 1],
                                      transformed_points[:, 2], size=2,
                                      color=color_target_action, name="Predicted Gripper"),
                        row=1, col=2
                    )
                # The target position does not depend on gripper points: always draw it.
                fig.add_trace(
                    _marker_trace([action[0]], [action[1]], [action[2]], size=6,
                                  color='pink', name="Predicted Gripper Pose"),
                    row=1, col=2
                )

            fig.add_trace(
                _marker_trace([gripper_pos[0]], [gripper_pos[1]], [gripper_pos[2]], size=4,
                              color='green', name="Gripper Position"),
                row=1, col=2
            )
            if step != 0:
                print(f"mp_error_list: {mp_error_list[step]}")
                previous_target = mp_error_list[step][2]
                fig.add_trace(
                    _marker_trace([previous_target[0]], [previous_target[1]], [previous_target[2]],
                                  size=6, color='purple', name="Previous Target Gripper Position"),
                    row=1, col=2
                )
        except IndexError:
            print(f"No point cloud data for step {step}")

        fig.update_layout(
            height=500,
            width=1000,
            margin=dict(l=0, r=0, b=0, t=30),
            paper_bgcolor='white',
            plot_bgcolor='white'
        )

        scene_settings = dict(
            aspectmode='data',
            xaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            yaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            zaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True)
        )
        fig.update_scenes(scene_settings)

        fig.show()

In [None]:
def convert_video_to_gif(video_path, output_path, fps=10):
    """Convert a video file to a GIF and print the GIF's size.

    Args:
        video_path: path of the input video.
        output_path: GIF path to write; its parent directory is created if missing.
        fps: frame rate of the written GIF (default 10, as before).

    Returns:
        ``output_path``.
    """
    print(f"Converting video to GIF...")
    output_dir = os.path.dirname(output_path)
    if output_dir:
        # write_gif fails when the target directory (e.g. "gif/") does not exist.
        os.makedirs(output_dir, exist_ok=True)

    video_clip = VideoFileClip(video_path)
    try:
        video_clip.write_gif(output_path, fps=fps)
    finally:
        # Release the video reader even if writing the GIF fails.
        video_clip.close()

    gif_size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"GIF size: {gif_size_mb:.2f} MB")

    return output_path

In [None]:
# Parameters: which benchmark / checkpoint / seed / task-variation to inspect
benchmark = "peract"
seed = 200
model = "3dlotusplus/v2_mix"
taskvar = "put_item_in_drawer_peract+2"

# Layout: ../data/experiments/<benchmark>/<model>/records/seed<seed>/<taskvar>
taskvar_records_dir = "/".join(
    ["../data/experiments", benchmark, model, "records", f"seed{seed}", taskvar]
)

# <span style="color:blue"> ------------ Run ------------  </span>

In [None]:
# List all episodes (folders ep_<int>) with their success rate and number of recorded steps.
def _episode_number(name):
    # Numeric part of ep_<int>; anything else sorts last.
    suffix = name[len("ep_"):]
    return int(suffix) if name.startswith("ep_") and suffix.isdigit() else float("inf")

episode_folders = sorted(
    (name for name in os.listdir(taskvar_records_dir) if name.startswith("ep_")),
    key=_episode_number,
)

for folder in episode_folders:
    episode_id = folder.split('_')[1]
    episode_dir = f"{taskvar_records_dir}/{folder}"
    video_dirs = sorted(name for name in os.listdir(episode_dir) if "video" in name)
    if not video_dirs:
        print(f"Ep {episode_id}: no video folder found")
        continue
    video_dir = video_dirs[0]
    # Slice the "video-SR" prefix off; str.strip would remove a *set* of characters.
    sr_prefix = "video-SR"
    success_rate = video_dir[len(sr_prefix):] if video_dir.startswith(sr_prefix) else video_dir.strip(sr_prefix)
    steps_dir = f"{episode_dir}/steps"
    if os.path.isdir(steps_dir):
        nb_steps = sum(
            1 for name in os.listdir(steps_dir) if name.startswith("step_") and name.endswith(".npy")
        )
    else:
        nb_steps = 0
    print(f"Ep {episode_id}: SR={success_rate}, steps={nb_steps}")

In [None]:
# Episode to inspect below; its records live in the folder ep_<episode_id>.
episode_id = 22

In [None]:
episode_dir = f"{taskvar_records_dir}/ep_{episode_id}"  # folder of the selected episode

In [None]:
# Load every recorded step of the episode (step id -> record) and count them.
obs_dict, num_steps = load_steps_data(episode_dir=episode_dir)

In [None]:
# Success rate parsed from the video folder name, plus a GIF of the rollout.
success_rate, gif_path = load_video_data(episode_dir=episode_dir)

In [None]:
# Per-step lists consumed by the display helpers below.
(
    xyz_list,
    actions_list,
    pc_labels_list,
    plans_list,
    action_name_lists,
    cache_list,
    gripper_pos_list,
    mp_error_list,
) = prepare_keysteps_pcd_sequence(obs_dict=obs_dict, num_steps=num_steps)

In [None]:
# Show the rollout GIF of the episode.
rollout_gif = Image(filename=gif_path)
display(rollout_gif)

In [None]:
# Show the LLM task-planning input, its raw plans and the parsed plans for this episode.
# A context manager closes the file (json.load(open(...)) leaked the handle).
with open(f"{episode_dir}/highlevel_plans.json", "r") as plans_file:
    highlevel_plans = json.load(plans_file)

print(f"### {taskvar}, SR: {success_rate}")
print("# Task Planning - LLM")
print(f"\033[4m Input \033[0m: {highlevel_plans['instruction']}")
print("\033[4m Output \033[0m:")
for i, plan in enumerate(highlevel_plans['plans']):
    print(f" Step {i}: {plan}")
print("\033[4m Output parsed \033[0m:")
for i, plan in enumerate(highlevel_plans['parsed_plans']):
    print(f"Step {i}: {plan}")

In [None]:
# Split the plans into three consecutive chunks:
#   indices [0, n//3], [n//3 + 1, 2n//3], [2n//3 + 1, n - 1]
n_plans = len(plans_list)
subplan1 = plans_list[: n_plans // 3 + 1]
subplan2 = plans_list[n_plans // 3 + 1 : (2 * n_plans) // 3 + 1]
subplan3 = plans_list[(2 * n_plans) // 3 + 1 :]

In [None]:
# First third of the episode (starts at step 0).
display_story(
    start_idx=0,
    plans_list=subplan1,
    xyz_list=xyz_list,
    actions_list=actions_list,
    pc_labels_list=pc_labels_list,
    obs_dict=obs_dict,
    cache_list=cache_list,
    action_name_lists=action_name_lists,
    gripper_pos_list=gripper_pos_list,
    mp_error_list=mp_error_list,
)

In [None]:
# Second third of the episode (first index after subplan1).
display_story(
    start_idx=len(plans_list) // 3 + 1,
    plans_list=subplan2,
    xyz_list=xyz_list,
    actions_list=actions_list,
    pc_labels_list=pc_labels_list,
    obs_dict=obs_dict,
    cache_list=cache_list,
    action_name_lists=action_name_lists,
    gripper_pos_list=gripper_pos_list,
    mp_error_list=mp_error_list,
)

In [None]:
# Last third of the episode (first index after subplan2).
display_story(
    start_idx=2 * len(plans_list) // 3 + 1,
    plans_list=subplan3,
    xyz_list=xyz_list,
    actions_list=actions_list,
    pc_labels_list=pc_labels_list,
    obs_dict=obs_dict,
    cache_list=cache_list,
    action_name_lists=action_name_lists,
    gripper_pos_list=gripper_pos_list,
    mp_error_list=mp_error_list,
)

In [None]:
from genrobo3d.utils.point_cloud import voxelize_pcd, get_pc_foreground_mask
from genrobo3d.utils.robot_box import RobotBox
import open3d as o3d
from scipy.spatial.transform import Rotation as R
import json
import copy
import numpy as np
import torch
import random

class GroundtruthVision(object):
    """Build a point-cloud batch dict from ground-truth camera data.

    ``__call__`` crops points to ``workspace`` with ``get_pc_foreground_mask``,
    voxelises them, optionally removes robot points with ``RobotBox``, subsamples to
    ``num_points``, labels points (1 = inside ``RobotBox(keep_gripper=False)``,
    2 = "object", 3 = "target" from the ground-truth label file, 0 otherwise), and
    finally shifts/scales the coordinates.
    """

    # Supported values of the ``rm_robot`` and ``xyz_shift`` options.
    _RM_ROBOT_OPTIONS = ("none", "box", "box_keep_gripper")
    _XYZ_SHIFT_OPTIONS = ("none", "center", "gripper")

    def __init__(
        self,
        gt_label_file,
        num_points=4096,
        voxel_size=0.01,
        same_npoints_per_example=False,
        rm_robot="box_keep_gripper",
        xyz_shift="center",
        xyz_norm=False,
        use_height=True,
        pc_label_type="coarse",
        use_color=False,
    ):
        """Load the ground-truth label JSON and store the processing options.

        Raises:
            ValueError: if ``rm_robot`` or ``xyz_shift`` is not a supported value
                (previously this only failed later in ``__call__`` with an unbound
                local variable).
        """
        if rm_robot not in self._RM_ROBOT_OPTIONS:
            raise ValueError(f"rm_robot must be one of {self._RM_ROBOT_OPTIONS}, got {rm_robot!r}")
        if xyz_shift not in self._XYZ_SHIFT_OPTIONS:
            raise ValueError(f"xyz_shift must be one of {self._XYZ_SHIFT_OPTIONS}, got {xyz_shift!r}")

        # Context manager closes the label file (json.load(open(...)) leaked it).
        with open(gt_label_file) as label_file:
            self.taskvar_gt_target_labels = json.load(label_file)
        self.workspace = {
            'TABLE_HEIGHT': 0.7505,
            'X_BBOX': (-0.5, 1.5),
            'Y_BBOX': (-1, 1),
            'Z_BBOX': (0.2, 2)
        }
        self.TABLE_HEIGHT = self.workspace["TABLE_HEIGHT"]

        self.num_points = num_points
        self.voxel_size = voxel_size
        self.pc_label_type = pc_label_type
        self.same_npoints_per_example = same_npoints_per_example
        self.rm_robot = rm_robot
        self.xyz_shift = xyz_shift
        self.xyz_norm = xyz_norm
        self.use_height = use_height
        self.use_color = use_color

    def get_target_labels(self, taskvar, step_id, episode):
        """Return the ground-truth target labels of ``step_id`` for ``taskvar``.

        The label file is indexed either directly as ``[taskvar][step_id]`` or, when
        that key is missing, per episode as ``[taskvar][episode][step_id]``.

        Args:
            taskvar: task-variation key, e.g. "<task>+<variation>".
            step_id: step key into the label file.
            episode: episode key (built as "episode<id>" by ``__call__``).

        Returns:
            The label entry for that step (a dict, as used by ``__call__``).

        Raises:
            KeyError: if neither indexing scheme finds the step.
        """
        try:
            # First try direct step_id indexing
            return self.taskvar_gt_target_labels[taskvar][step_id]
        except KeyError:
            return self.taskvar_gt_target_labels[taskvar][episode][step_id]

    def __call__(
        self,
        taskvar,
        step_id,
        pcd_images,
        sem_images,
        gripper_pose,
        arm_links_info,
        rgb_images=None,
        episode_id=None,
    ):
        """Process one observation into a batch dict.

        Args:
            taskvar: task-variation key into the label file.
            step_id: step key into the label file.
            pcd_images: per-camera xyz images, flattened to (-1, 3).
            sem_images: per-camera semantic ids, flattened to (-1,).
            gripper_pose: gripper pose whose first 3 entries are its position. It is
                not modified; a copy is normalised.
            arm_links_info: robot link info passed to ``RobotBox``.
            rgb_images: per-camera colours (assumed 0-255), required when ``use_color``.
            episode_id: used to build the "episode<id>" fallback key.

        Returns:
            dict with "pc_fts" (xyz [+ height] [+ rgb mapped to [-1, 1]]),
            "pc_labels", "offset", "npoints_in_batch", "pc_centroids",
            "pc_radius" and "ee_poses".
        """
        episode = f"episode{episode_id}"
        # Work on a copy: the caller's pose array used to be overwritten in place.
        gripper_pose = np.array(gripper_pose)
        pcd_xyz = pcd_images.reshape(-1, 3)
        pcd_sem = sem_images.reshape(-1)
        if self.use_color:
            assert rgb_images is not None
            pcd_rgb = rgb_images.reshape(-1, 3)

        # remove background and table points
        fg_mask = get_pc_foreground_mask(pcd_xyz, self.workspace)
        pcd_xyz = pcd_xyz[fg_mask]
        pcd_sem = pcd_sem[fg_mask]
        if self.use_color:
            pcd_rgb = pcd_rgb[fg_mask]

        pcd_xyz, idxs = voxelize_pcd(pcd_xyz, voxel_size=self.voxel_size)
        pcd_sem = pcd_sem[idxs]
        if self.use_color:
            pcd_rgb = pcd_rgb[idxs]

        if self.rm_robot != "none":
            # "box_keep_gripper" asks RobotBox to keep gripper points; "box" does not.
            robot_box = RobotBox(arm_links_info, keep_gripper=(self.rm_robot == "box_keep_gripper"))
            robot_point_idxs = robot_box.get_pc_overlap_ratio(
                xyz=pcd_xyz, return_indices=True
            )[1]
            robot_point_idxs = np.array(list(robot_point_idxs))
            if len(robot_point_idxs) > 0:
                mask = np.ones((pcd_xyz.shape[0],), dtype=bool)
                mask[robot_point_idxs] = False
                pcd_xyz = pcd_xyz[mask]
                pcd_sem = pcd_sem[mask]
                if self.use_color:
                    pcd_rgb = pcd_rgb[mask]

        # sample points
        if len(pcd_xyz) > self.num_points:
            point_idxs = np.random.permutation(len(pcd_xyz))[: self.num_points]
        elif self.same_npoints_per_example:
            point_idxs = np.random.choice(
                pcd_xyz.shape[0], self.num_points, replace=True
            )
        else:
            point_idxs = np.arange(pcd_xyz.shape[0])
        pcd_xyz = pcd_xyz[point_idxs]
        pcd_sem = pcd_sem[point_idxs]
        # Height above the table, computed before any shift/scale.
        height = pcd_xyz[..., 2] - self.TABLE_HEIGHT
        if self.use_color:
            pcd_rgb = pcd_rgb[point_idxs]

        # robot pcd_label
        pcd_label = np.zeros_like(pcd_sem)
        robot_box = RobotBox(arm_links_info, keep_gripper=False)
        robot_point_idxs = robot_box.get_pc_overlap_ratio(
            xyz=pcd_xyz, return_indices=True
        )[1]
        robot_point_idxs = np.array(list(robot_point_idxs))
        if len(robot_point_idxs) > 0:
            pcd_label[robot_point_idxs] = 1

        # The label lookup does not depend on the query key: do it once.
        target_labels = self.get_target_labels(taskvar, step_id, episode)
        for query_key, query_label_id in zip(["object", "target"], [2, 3]):
            if target_labels is None or query_key not in target_labels:
                continue

            gt_target_labels = target_labels[query_key]

            if self.pc_label_type != "mix":
                pc_label_type = self.pc_label_type
            else:
                pc_label_type = random.choice(["coarse", "fine"])

            labels = (
                gt_target_labels[pc_label_type]
                if pc_label_type in gt_target_labels
                else gt_target_labels["fine"]
            )
            # Points whose semantic id is any of the requested ids.
            gt_query_mask = np.isin(pcd_sem, labels)
            if "zrange" in gt_target_labels:
                gt_query_mask = (
                    gt_query_mask
                    & (pcd_xyz[..., 2] > gt_target_labels["zrange"][0])
                    & (pcd_xyz[..., 2] < gt_target_labels["zrange"][1])
                )
            if "xy_bbox" in gt_target_labels and self.pc_label_type != "coarse":
                bbox_offset = gt_target_labels["xy_bbox"]["bbox_offset"]
                bbox_size = gt_target_labels["xy_bbox"]["bbox_size"]

                obj_pose = gt_target_labels["xy_bbox"]["obj_pose"]
                bbox_pos = obj_pose[:3]
                gripper_quat = obj_pose[3:]

                bbox_rot = R.from_quat(gripper_quat).as_matrix()

                # Rotate the offset by the rotation matrix
                rotated_offset = bbox_rot @ bbox_offset

                obbox = o3d.geometry.OrientedBoundingBox(bbox_pos, bbox_rot, bbox_size)
                obbox.translate(rotated_offset, relative=True)

                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(pcd_xyz)
                box_idx = obbox.get_point_indices_within_bounding_box(pcd.points)

                bbox_mask = np.zeros((pcd_xyz.shape[0],), dtype=bool)
                if box_idx:  # If not empty
                    indices = np.array(box_idx, dtype=np.int64)
                    bbox_mask[indices] = True

                gt_query_mask = gt_query_mask & bbox_mask
            pcd_label[gt_query_mask] = query_label_id

        # normalize point cloud
        if self.xyz_shift == "none":
            pc_centroid = np.zeros((3,))
        elif self.xyz_shift == "center":
            pc_centroid = np.mean(pcd_xyz, 0)
        elif self.xyz_shift == "gripper":
            pc_centroid = copy.deepcopy(gripper_pose[:3])
        else:
            raise ValueError(f"Unsupported xyz_shift: {self.xyz_shift!r}")
        if self.xyz_norm:
            pc_radius = np.max(np.sqrt(np.sum((pcd_xyz - pc_centroid) ** 2, axis=1)))
        else:
            pc_radius = 1
        pcd_xyz = (pcd_xyz - pc_centroid) / pc_radius
        gripper_pose[:3] = (gripper_pose[:3] - pc_centroid) / pc_radius

        pcd_ft = pcd_xyz
        if self.use_height:
            pcd_ft = np.concatenate([pcd_ft, height[:, None]], -1)
        if self.use_color:
            pcd_rgb = (pcd_rgb / 255.0) * 2 - 1
            pcd_ft = np.concatenate([pcd_ft, pcd_rgb], -1)

        outs = {
            "pc_fts": torch.from_numpy(pcd_ft).float(),
            "pc_labels": torch.from_numpy(pcd_label).long(),
            "offset": torch.LongTensor([pcd_xyz.shape[0]]),
            "npoints_in_batch": [pcd_xyz.shape[0]],
            "pc_centroids": pc_centroid,
            "pc_radius": pc_radius,
            "ee_poses": torch.from_numpy(gripper_pose).float().unsqueeze(0),
        }

        return outs

In [None]:
# Step whose ground-truth point cloud is rebuilt and plotted below.
step = 3

In [None]:
# Rebuild the labelled point cloud of `step` from its raw observation.
obs_step = obs_dict[step]['obs']
pcd_images = obs_step['pc']
rgb_images = obs_step["rgb"]
sem_images = obs_step["gt_mask"]
arm_links_info = obs_step["arm_links_info"]
# Copy so the pipeline cannot modify the recorded observation.
gripper_pose = copy.deepcopy(obs_step["gripper"])
gt_label_file = f"../assets/{benchmark}/taskvars_target_label_zrange_{benchmark}.json"

vlm_pipeline = GroundtruthVision(
    gt_label_file,
    num_points=4096,
    voxel_size=0.01,
    same_npoints_per_example=False,
    rm_robot="box_keep_gripper",
    xyz_shift="none",
    xyz_norm=False,
    use_height=True,
    pc_label_type="mix",
    use_color=False,
)

highlevel_step_id = cache_list[step]["highlevel_step_id_norelease"]
batch = vlm_pipeline(
    taskvar,
    highlevel_step_id,
    pcd_images,
    sem_images,
    gripper_pose,
    arm_links_info,
    rgb_images=rgb_images,
    episode_id=episode_id,
)

In [None]:
# Point labels and xyz coordinates (first three feature columns) of the rebuilt batch.
labels = batch['pc_labels']
xyz = batch['pc_fts'][:, :3]

In [None]:
# Side-by-side: front camera image and the re-labelled point cloud.
fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "image"}, {"type": "scene"}]],
    subplot_titles=["Front View", "Point Cloud"],
)

front_view = obs_dict[step]['obs']['rgb'][-1]
fig.add_trace(go.Image(z=front_view), row=1, col=1)

# Label id -> colour: 0 light grey, 1 blue, 2 orange, 3 green
label_palette = [
    [0.7, 0.7, 0.7],
    [0.0, 0.0, 1.0],
    [1.0, 0.65, 0.0],
    [0.0, 1.0, 0.0],
]
rgb = np.zeros((xyz.shape[0], 3))
for label_id, label_color in enumerate(label_palette):
    rgb[labels == label_id] = label_color
colors = [f'rgb({int(r*255)},{int(g*255)},{int(b*255)})' for r, g, b in rgb]

point_cloud_trace = go.Scatter3d(
    x=xyz[:, 0],
    y=xyz[:, 1],
    z=xyz[:, 2],
    mode='markers',
    marker=dict(size=2, color=colors, opacity=0.8),
    showlegend=False,
)
fig.add_trace(point_cloud_trace, row=1, col=2)

fig.update_layout(
    height=500,
    width=1000,
    margin=dict(l=0, r=0, b=0, t=30),
    paper_bgcolor='white',
    plot_bgcolor='white',
)

# White background with light-grey grid on all three axes
axis_style = dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True)
scene_settings = dict(
    aspectmode='data',
    xaxis=dict(axis_style),
    yaxis=dict(axis_style),
    zaxis=dict(axis_style),
)
fig.update_scenes(scene_settings)

fig.show()