In [1]:
import numpy as np
import plotly.graph_objects as go

def visualize_point_clouds(xyz, rgb, ee_pose, gripper_pos, gripper_state):
    """
    Visualize 3D point cloud data using Plotly, with gripper position overlay.
    
    Args:
        xyz (numpy.ndarray): Point coordinates with shape (N, 3)
        rgb (numpy.ndarray): RGB colors with shape (N, 3), values in range [0, 1]
        gripper_pos (numpy.ndarray): Gripper position in scene coordinates
        gripper_state (float): Gripper open/close state, 0 for open (red), 1 for closed (green)
    """
    # Convert RGB values to strings for Plotly
    colors = [f'rgb({int(r*255)},{int(g*255)},{int(b*255)})' for r, g, b in rgb]
    
    gripper_color = 'rgb(0,255,0)' if gripper_state > 0.5 else 'rgb(255,0,0)'

    # Create the 3D scatter plot for point cloud
    fig = go.Figure(data=[go.Scatter3d(
        x=xyz[:, 0],
        y=xyz[:, 1],
        z=xyz[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color=colors,
            opacity=0.8
        )
    )])

    # Add the gripper position as a larger marker
    fig.add_trace(go.Scatter3d(
        x=[gripper_pos[0]],
        y=[gripper_pos[1]],
        z=[gripper_pos[2]],
        mode='markers',
        marker=dict(
            size=10,
            color=gripper_color,
            opacity=1.0
        ),
        name="Gripper"
    ))
    
    # Add the end-effector position as a blue marker
    fig.add_trace(go.Scatter3d(
        x=[ee_pose[0]],
        y=[ee_pose[1]],
        z=[ee_pose[2]],
        mode='markers',
        marker=dict(
            size=10,
            color='blue',
            opacity=1.0
        ),
        name="End-Effector"
    ))
    
    fig.update_layout(
        scene=dict(
            aspectmode='data',  # Preserve the point cloud's true shape
            xaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            yaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True),
            zaxis=dict(backgroundcolor='white', gridcolor='lightgrey', showbackground=True)
        ),
        width=800,
        height=800,
        margin=dict(l=0, r=0, b=0, t=0),
        paper_bgcolor='white',  # Set paper background to white
        plot_bgcolor='white'    # Set plot background to white
    )
    
    # Show the plot
    fig.show()


In [2]:
import numpy as np
import os

task = 'push_button+0'
episode_id = 5
root_dir = f"../data/experiments/gembench/3dlotusplus/v1/preds-llm_gt-og_gt_fine/seed200/obs_outs/{task}/{episode_id}"

# Fail early with context: the path is relative to the notebook's working directory.
if not os.path.isdir(root_dir):
    raise FileNotFoundError(
        f"Observation directory not found: {root_dir!r} (cwd: {os.getcwd()!r}); "
        "check `task`, `episode_id` and the relative data path."
    )


def _step_sort_key(filename):
    """Order '<int>.npy' files numerically ('2.npy' before '10.npy'); others go last, by name."""
    stem = os.path.splitext(filename)[0]
    return (0, int(stem), '') if stem.isdigit() else (1, 0, stem)


# os.listdir order is arbitrary, and plain lexicographic order would put '10.npy'
# before '2.npy'; sort numerically so obs_dict[step] follows the on-disk step number.
step_files = sorted(
    (name for name in os.listdir(root_dir) if name.endswith('.npy')),
    key=_step_sort_key,
)

obs_dict = {}
for step, file in enumerate(step_files):
    # allow_pickle=True unpickles arbitrary objects (which can execute code);
    # only load trusted, locally generated files.
    step_array = np.load(os.path.join(root_dir, file), allow_pickle=True)
    # The saved object is wrapped in a 0-d array; [()] unwraps it.
    obs_dict[step] = step_array[()]

num_steps = len(obs_dict)

FileNotFoundError: [Errno 2] No such file or directory: '../data/experiments/gembench/3dlotusplus/v1/preds-llm_gt-og_gt_fine/seed200/obs_outs/push_button+0/5'

In [52]:
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort them
# (numerically by file stem where possible) so the listing is reproducible.
sorted(
    os.listdir(root_dir),
    key=lambda name: (0, int(os.path.splitext(name)[0]), name)
    if os.path.splitext(name)[0].isdigit() else (1, 0, name),
)

['0.npy', '1.npy']

In [None]:
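# Structure of one saved observation file, recorded in an interactive session.
# `obs` is presumably the array returned by np.load(..., allow_pickle=True), as in the loading cell.
# >>> obs = obs[()]
# >>> obs.keys()
# dict_keys(['batch', 'obs', 'valid_actions'])
# >>> obs['batch'].keys()
# dict_keys(['pc_fts', 'pc_labels', 'offset', 'npoints_in_batch', 'pc_centroids', 'pc_radius', 'ee_poses'])
# >>> obs['obs'].keys()
# dict_keys(['rgb', 'depth', 'pc', 'arm_links_info', 'gt_mask', 'gripper'])
# Layout of each row of obs['valid_actions'] (index ranges in brackets):
# [gripper_pos[0:3], gripper_rot[3:7], gripper_state[7], stop_prob[8]]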

In [48]:
# Number of per-step observation files loaded from root_dir for this episode.
num_steps

2

In [51]:
def _prepare_step(step_obs):
    """Extract the plotting inputs from one step's observation dict.

    Args:
        step_obs (dict): One entry of ``obs_dict``; must have 'batch' and 'valid_actions'.

    Returns:
        tuple: (xyz, rgb, ee_pos, gripper_pos, gripper_state, stop_prob, valid_action)
    """
    valid_action = step_obs['valid_actions'][0]
    batch = step_obs['batch']
    pc_fts = batch['pc_fts']

    xyz = pc_fts[:, :3]

    # Colour features are stored in [-1, 1]; rescale them to [0, 1].
    color_fts = (pc_fts[:, 3:] + 1) / 2
    if color_fts.shape[1] == 1:
        # Single feature channel: replicate it into a grey RGB triple.
        rgb = np.repeat(color_fts, 3, axis=1)
    else:
        # Repeating a C-channel feature would give 3*C columns and break the
        # (r, g, b) unpacking in visualize_point_clouds; keep the first three instead.
        # NOTE(review): assumes those are the R, G, B channels — confirm.
        rgb = color_fts[:, :3]

    # Normalize (not de-normalize) the target gripper position into the point
    # cloud's frame: centre on pc_centroids and scale by pc_radius.
    # NOTE(review): presumably the same transform that produced pc_fts xyz — confirm.
    gripper_pos = (valid_action[:3] - batch['pc_centroids']) / batch['pc_radius']

    # NOTE(review): ee_poses is plotted without the normalization applied to
    # gripper_pos; confirm it is already in the point cloud's frame.
    ee_pos = batch['ee_poses'][0][:3]

    return xyz, rgb, ee_pos, gripper_pos, valid_action[7], valid_action[8], valid_action


# Iterate over the loaded steps themselves rather than trusting num_steps to match.
for step in sorted(obs_dict):
    xyz, rgb, ee_pose, gripper_pos, gripper_state, stop_prob, valid_action = _prepare_step(obs_dict[step])

    print(f"Visualizing step {step}, stop_prob next action: {round(stop_prob,1)}")
    print(f"Valid action: {valid_action}")
    visualize_point_clouds(xyz, rgb, ee_pose, gripper_pos, gripper_state)


Visualizing step 0, stop_prob: 0.0


Visualizing step 1, stop_prob: 1.0


In [None]:
# for step in range(10,num_steps):
#     valid_action = obs_dict[step]['valid_actions'][0]
#     pc_fts = obs_dict[step]['batch']['pc_fts']
#     ee_pose = obs_dict[step]['batch']['ee_poses'][0]
#     pc_centroid = obs_dict[step]['batch']['pc_centroids']  # Retrieve centroid
#     pc_radius = obs_dict[step]['batch']['pc_radius']       # Retrieve radius
#     
#     xyz = pc_fts[:, :3]
#     rgb = (pc_fts[:, 3:] + 1) / 2  # Normalize from [-1, 1] to [0, 1]
#     rgb = np.repeat(rgb, 3, axis=1)
# 
#     # Normalize gripper position into the point cloud's (centred, radius-scaled) frame
#     gripper_pos = (valid_action[:3] - pc_centroid)/pc_radius
#     ee_pose = ee_pose[:3]
# 
#     # Visualize the point cloud
#     print(f"Visualizing step {step}")
#     visualize_point_clouds(xyz, rgb, ee_pose, gripper_pos, valid_action[7])
