In [2]:
import os
import open3d as o3d
import lmdb
import msgpack
import msgpack_numpy
import copy
msgpack_numpy.patch()

import plotly.graph_objects as go
import json
import numpy as np

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [1]:
def visualize_point_clouds(xyz, rgb, gripper_pos, noisy_point=None):
    """
    Visualize 3D point cloud data using Plotly
    
    Args:
        xyz (numpy.ndarray): Point coordinates with shape (N, 3)
        rgb (numpy.ndarray): RGB colors with shape (N, 3), values in range [0, 1]
    """
    # Convert RGB values to strings for Plotly
    colors = [f'rgb({r},{g},{b})' 
              for r, g, b in rgb]
    
    # Create the 3D scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=xyz[:, 0],
        y=xyz[:, 1],
        z=xyz[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color=colors,
            opacity=0.8
        )
    )])
    
    fig.add_trace(go.Scatter3d(
        x=[gripper_pos[0]],
        y=[gripper_pos[1]],
        z=[gripper_pos[2]],
        mode='markers',
        marker=dict(
            size=10,
            color="blue",
            opacity=1.0
        ),
        name="Gripper"
    ))
    
    if noisy_point is not None:
        fig.add_trace(go.Scatter3d(
            x=[noisy_point[0]],
            y=[noisy_point[1]],
            z=[noisy_point[2]],
            mode='markers',
            marker=dict(
                size=10,
                color="red",
                opacity=1.0
            ),
            name="Noisy Point"
        ))
    
    fig.update_layout(
        scene=dict(
            aspectmode='data',  # Preserve the point cloud's true shape
            xaxis=dict(
                backgroundcolor='white',
                gridcolor='lightgrey',
                showbackground=True
            ),
            yaxis=dict(
                backgroundcolor='white',
                gridcolor='lightgrey',
                showbackground=True
            ),
            zaxis=dict(
                backgroundcolor='white',
                gridcolor='lightgrey',
                showbackground=True
            )
        ),
        width=800,
        height=800,
        margin=dict(l=0, r=0, b=0, t=0),
        paper_bgcolor='white',  # Set paper background to white
        plot_bgcolor='white'    # Set plot background to white
    )
    
    # Show the plot
    fig.show()

In [4]:
def load_data_point(data_dir):
    """
    Load LMDB data points from directory
    
    Args:
        data_dir (str): Path to data directory containing LMDB files
        
    Returns:
        tuple: (lmdb_txns, data_ids) where data_ids is list of (taskvar, episode_id) tuples
    """
    taskvars = os.listdir(data_dir)
    lmdb_envs, lmdb_txns = {}, {}
    data_ids = []
    
    for taskvar in taskvars:
        lmdb_envs[taskvar] = lmdb.open(os.path.join(data_dir, taskvar), readonly=True)
        lmdb_txns[taskvar] = lmdb_envs[taskvar].begin()
        
        data_ids.extend(
            [(taskvar, key) for key in lmdb_txns[taskvar].cursor().iternext(values=False)]
        )
    
    return lmdb_txns, data_ids

def load_episode(taskvar, episode_id, lmdb_txns):
    """
    Load specific episode data from LMDB
    
    Args:
        taskvar (str): Task variant name
        episode_id (bytes): Episode ID
        lmdb_txns (dict): Dictionary of LMDB transactions
        
    Returns:
        tuple: (episode_data, num_steps)
    """
    data = msgpack.unpackb(lmdb_txns[taskvar].get(episode_id))
    num_steps = len(data['xyz'])
    
    return data, num_steps

In [5]:
data_dir = '../data/peract/train_dataset/motion_keysteps_bbox_pcd/seed0/voxel1cm'
data_dict = {}

print(f"Loading coarse data")
lmdb_txns, data_ids = load_data_point(data_dir)

Loading coarse data


In [6]:
taskvar = "close_jar_peract+19"

episode_id = b'episode20'

data, num_steps = load_episode(taskvar, episode_id, lmdb_txns)
data_dict['coarse'] = {'data': data, 'num_steps': num_steps}

In [7]:
data.keys()

dict_keys(['xyz', 'rgb', 'sem', 'ee_pose', 'bbox_info', 'pose_info', 'trajs', 'end_open_actions', 'is_new_keystep'])

In [10]:
num_steps

7

In [9]:
data['ee_pose']

array([[ 2.78457761e-01, -8.14718567e-03,  1.47197890e+00,
        -2.26825080e-07,  9.92663085e-01,  1.09368307e-06,
         1.20913170e-01,  1.00000000e+00],
       [ 2.17031807e-01,  1.51475072e-01,  9.29268777e-01,
        -1.20047271e-01,  9.92768168e-01, -7.27345177e-05,
        -3.81539518e-04,  1.00000000e+00],
       [ 2.17071697e-01,  1.51747972e-01,  7.58757830e-01,
        -1.20201498e-01,  9.92749453e-01, -7.52486594e-05,
        -5.24296425e-04,  0.00000000e+00],
       [ 2.16754004e-01,  1.50478184e-01,  9.26252961e-01,
        -1.20212406e-01,  9.92747426e-01, -1.87031779e-04,
        -1.27542298e-03,  0.00000000e+00],
       [ 2.69367099e-01, -2.78138995e-01,  9.27740932e-01,
        -1.22757763e-01,  9.92436349e-01,  5.41877700e-04,
        -6.27792790e-04,  0.00000000e+00],
       [ 2.69397646e-01, -2.78723121e-01,  8.67343545e-01,
        -1.22859091e-01,  9.92423534e-01,  5.91070624e-04,
        -9.21464758e-04,  0.00000000e+00],
       [ 2.69598693e-01, -2.788286

In [None]:
for t in range(2):
    xyz, rgb = data['xyz'][t], data['rgb'][t]
    for idx, action in enumerate(data['trajs'][t]):
        print(f"displaying action step {idx}: {action}")
        gripper_pos = data['ee_pose'][t][:3]
        print(f"x: {xyz.shape}, rgb: {rgb.shape}")
        visualize_point_clouds(xyz, rgb, gripper_pos)
        if input(f"Step {t}/{num_steps-1}. Press Enter to continue, 'q' to quit: ") == 'q':
            break

In [16]:
max_traj_len = 1

In [31]:
gt_act_obj_label_file = '../assets/taskvars_target_label_zrange_peract.json'

gt_act_obj_labels = json.load(open(gt_act_obj_label_file))

In [33]:
gt_act_obj_labels = gt_act_obj_labels[taskvar][episode_id.decode()]
print(gt_act_obj_labels)

[{'action': 'grasp', 'object': {'name': 'gray lid', 'fine': [87], 'coarse': [87]}}, {'action': 'move grasped object', 'target': {'name': 'white jar', 'fine': [85], 'coarse': [85]}}, {'action': 'rotate grasped object'}]


In [23]:
def get_mask_with_label_ids(sem, label_ids):
    mask = sem == label_ids[0]
    for label_id in label_ids[1:]:
        mask = mask | (sem == label_id)
    return mask

In [42]:
gt_act_obj_labels[3]['action']

'rotate grasped object'

In [43]:
data = data_dict['coarse']['data']

keystep = 0
num_steps = len(data['xyz'])
for t in range(num_steps):
    print(f"keystep: {keystep}")
    if data['is_new_keystep'][t]:
        keystep += 1

    xyz, rgb, gt_sem = data['xyz'][t], data['rgb'][t], data['sem'][t]
    arm_links_info = (
        {k: v[t] for k, v in data['bbox_info'].items()},
        {k: v[t] for k, v in data['pose_info'].items()}
    )

    if t < num_steps - 1:
        gt_traj_len = len(data['trajs'][t])
        gt_trajs = copy.deepcopy(data['trajs'][t])[:max_traj_len]
    else:
        gt_traj_len = 1
        gt_trajs = copy.deepcopy(data['trajs'][-2][-1:])
    ee_pose = copy.deepcopy(data['ee_pose'][t])

    action_name = gt_act_obj_labels[keystep]['action']

    pc_label = np.zeros((gt_sem.shape[0], ), dtype=np.int32)
    for oname in ['object', 'target']:
        if oname in gt_act_obj_labels[keystep]:
            v = gt_act_obj_labels[keystep][oname]
            obj_label_ids = v['fine']
            obj_mask = get_mask_with_label_ids(gt_sem, obj_label_ids)
    if True:
        if 'object' in gt_act_obj_labels[keystep]:
            action_name = f"{action_name} {gt_act_obj_labels[keystep]['object']['name']}"
        if 'target' in gt_act_obj_labels[keystep]:
            action_name = f"{action_name} to {gt_act_obj_labels[keystep]['target']['name']}"


keystep: 0
keystep: 1
keystep: 1
keystep: 2
keystep: 2
keystep: 2


IndexError: list index out of range