In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from rpad.rlbench_utils.placement_dataset import RLBenchPlacementDataset, load_handle_mapping, load_state_pos_dict, TASK_DICT
import numpy as np

from rpad.visualize_3d.plots import segmentation_fig
import torch
import matplotlib.pyplot as plt

In [None]:
# Load the umbrella task with phase="all" for demos 0-9.
# Other tasks tried during exploration: stack_wine, insert_onto_square_peg,
# insert_usb_in_computer, phone_on_base, put_toilet_roll_on_stand,
# place_hanger_on_rack, solve_puzzle.
dset = RLBenchPlacementDataset(
    dataset_root="/data/rlbench10/",
    task_name="take_umbrella_out_of_umbrella_stand",
    phase="all",
    demos=range(10),
)

# Show the "initial" and "final" (keyframe) observations of every phase for a single episode.

In [None]:
dset = RLBenchPlacementDataset(
    dataset_root="/data/rlbench10/",
    # task_name="stack_wine",
    # task_name="insert_onto_square_peg",
    # task_name="insert_usb_in_computer",
    # task_name="phone_on_base",
    # task_name="put_toilet_roll_on_stand",
    # task_name="place_hanger_on_rack",
    # task_name="solve_puzzle",
    task_name="take_umbrella_out_of_umbrella_stand",
    demos=[0],
    phase="all",
    use_first_as_init_keyframe=False,
)

# With a single demo and phase="all", each dataset item is presumably one phase of
# that demo (each item's "phase" field is used as the column title below).
N = len(dset)
assert N > 0, "dataset is empty; nothing to plot"

# Grid of images: row 0 = initial front RGB, row 1 = keyframe front RGB,
# one column per dataset item (phase).
# squeeze=False keeps `axes` 2-D even when N == 1, so `axes[row, col]` always works
# (with the default squeeze=True a single column yields a 1-D array).
fig, axes = plt.subplots(2, N, figsize=(5 * N, 10), squeeze=False)
for i in range(N):
    data = dset[i]
    axes[0, i].imshow(data["init_front_rgb"])
    axes[1, i].imshow(data["key_front_rgb"])

    # Title of the column.
    axes[0, i].set_title(data["phase"])

# Label the rows so the figure stands on its own.
axes[0, 0].set_ylabel("init_front_rgb")
axes[1, 0].set_ylabel("key_front_rgb")
plt.show()

# For each phase, show the keyframe ("final") observation of each of the first 10 demos.

In [None]:
# For each phase of the umbrella task, show the keyframe front RGB image of each of
# the first N_DEMOS demos (one figure per phase).
UMBRELLA_TASK = "take_umbrella_out_of_umbrella_stand"
N_DEMOS = 10
N_GRID_ROWS = 2
# Ceil-divide so the grid always has room for N_DEMOS images and the indexing below
# agrees with the grid shape (the column count used to be hard-coded as 5).
N_GRID_COLS = -(-N_DEMOS // N_GRID_ROWS)

for phase in TASK_DICT[UMBRELLA_TASK]["phase"].keys():
    # Create a dataset restricted to this phase.
    dset = RLBenchPlacementDataset(
        dataset_root="/data/rlbench10/",
        task_name=UMBRELLA_TASK,
        demos=range(N_DEMOS),
        phase=phase,
    )

    # squeeze=False keeps `axes` 2-D for every grid shape.
    fig, axes = plt.subplots(N_GRID_ROWS, N_GRID_COLS, figsize=(16, 8), squeeze=False)
    # Without a title every per-phase figure looks identical.
    fig.suptitle(f"Phase: {phase}")
    for i in range(N_DEMOS):
        data = dset[i]
        ax = axes[divmod(i, N_GRID_COLS)]
        ax.imshow(data["key_front_rgb"])
        ax.set_title(f"Demo {i}")

    # Hide any unused grid cells (only present when N_DEMOS is odd).
    for ax in axes.flat[N_DEMOS:]:
        ax.axis("off")
    plt.show()

# Plot the keyframes for each task.

In [None]:
# For each task, load one demo and show the front RGB image at every keyframe.

RLBENCH10_TASKS = [
    # "insert_onto_square_peg",
    # "pick_and_lift",
    # "put_knife_on_chopping_board",
    # "take_money_out_safe",
    # "pick_up_cup",
    # "put_money_in_safe",
    # "slide_block_to_target",
    # "take_umbrella_out_of_umbrella_stand",
    # "push_button",
    # "reach_target",
    "stack_wine",
]

# NOTE(review): this cell reads from a user-specific path, while the rest of the
# notebook uses /data/rlbench10/. Change it here to switch datasets.
KEYFRAME_DATASET_ROOT = "/home/beisner/datasets/rlbench/"
# KEYFRAME_DATASET_ROOT = "/data/rlbench10/"

for task in RLBENCH10_TASKS:
    dset = RLBenchPlacementDataset(
        dataset_root=KEYFRAME_DATASET_ROOT,
        task_name=task,
        n_demos=1,
    )

    data = dset[0]
    keyframe_ixs = data["keyframes"]
    if len(keyframe_ixs) == 0:
        # plt.subplots(ncols=0) would raise; report and move on instead.
        print(f"Task {task}: demo has no keyframes, skipping.")
        continue

    # One row (the single loaded demo), one column per keyframe.
    # squeeze=False keeps `axes` 2-D even for a single keyframe, so no manual
    # `[axes]` wrapping is needed.
    n_keyframes = len(keyframe_ixs)
    fig, axes = plt.subplots(
        nrows=1, ncols=n_keyframes, figsize=(5 * n_keyframes, 5), squeeze=False
    )
    for i, (ax, ix) in enumerate(zip(axes[0], keyframe_ixs)):
        # `ix` indexes into the demo's observation sequence.
        keyframe = data["demo"][ix]
        ax.imshow(keyframe.front_rgb)
        ax.set_title(f"Keyframe {i} (frame {ix})")
        ax.axis("off")

    fig.suptitle(f"Task: {task}")
    plt.show()

# Visualize the initial frame and the grasp-phase keyframe (RGB and masks) to see which objects matter, using an interactive plot.

In [None]:
# task_name = "pick_and_lift"
# task_name = "put_knife_on_chopping_board"
# task_name = "take_money_out_safe"
# task_name = "put_money_in_safe"
# task_name = "slide_block_to_target"
# task_name = "take_umbrella_out_of_umbrella_stand"
# task_name = "push_button"
task_name = "reach_target"

# Getting individual frames.
dset = RLBenchPlacementDataset(
    dataset_root="/data/rlbench10/",
    # task_name="pick_and_lift",
    task_name=task_name,
    n_demos=1,
    phase="grasp",
    debugging=True,
)

data = dset[0]

# Plot RGB image of the initial rgb and final rgb, as well as initial mask and final mask.
import matplotlib.pyplot as plt

%matplotlib widget

plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.imshow(data["init_front_rgb"])
plt.subplot(2, 2, 2)
plt.imshow(data["key_front_rgb"])
plt.subplot(2, 2, 3)
plt.imshow(data["init_front_mask"])
plt.subplot(2, 2, 4)
plt.imshow(data["key_front_mask"])
plt.show()

# Extract various properties we can use to annotate the keyframes.

In [None]:
# Get a mapping from handle id to handle name.
# Invert the handle mapping so a name can be looked up from a handle id.
# NOTE(review): assumes load_handle_mapping returns {name: id}. The inversion keeps
# only the last name when several names share an id.
handle_mapping = load_handle_mapping("/data/rlbench10/", task_name, 0)
rev_handle_mapping = dict(zip(handle_mapping.values(), handle_mapping.keys()))

# Name of a handle id of interest.
q_id = 100
rev_handle_mapping[q_id]

In [None]:
# Name of another handle id of interest (id chosen by inspection).
other_handle_id = 137
rev_handle_mapping[other_handle_id]

In [None]:
# Load and display the state/position dict for the current task.
# NOTE(review): the meaning of the two trailing integer arguments is defined by
# load_state_pos_dict (presumably demo and frame indices) — confirm.
state_pos_dict = load_state_pos_dict(
    "/data/rlbench10/",
    task_name,
    0,
    0,
)
state_pos_dict

In [None]:
import numpy as np
import plotly.graph_objs as go
import numpy.typing as npt
from typing import Dict, Optional, Sequence

from rpad.visualize_3d.plots import _3d_scene, _segmentation_traces

def segmentation_fig_rc(
    data: npt.ArrayLike,
    labels: npt.ArrayLike,
    labelmap: Optional[Dict[int, str]] = None,
    sizes: Optional[Sequence[int]] = None,
    fig: Optional[go.Figure] = None,
    row: int = 1,
    column: int = 1,
    n_cols: int = 5,
):
    """Add a 3D segmentation plot to one cell of a subplot grid.

    Works like ``segmentation_fig``, but places the traces in subplot
    (``row``, ``column``) of a figure built with ``plotly.subplots.make_subplots``
    whose cells are all of type ``"scene"``.

    Args:
        data: Points to plot; forwarded to ``_segmentation_traces`` and ``_3d_scene``.
        labels: Per-point labels; forwarded to ``_segmentation_traces``.
        labelmap: Optional mapping from label id to legend name.
        sizes: Optional marker sizes; forwarded to ``_segmentation_traces``.
        fig: Existing subplot figure to add to. If None, a new figure with a
            ``row`` x ``column`` grid of 3D scenes is created.
        row: 1-based subplot row.
        column: 1-based subplot column.
        n_cols: Number of columns in ``fig``'s grid, used to compute the row-major
            scene index. Defaults to 5 (the previous hard-coded value). It is
            ignored when ``fig`` is None.

    Returns:
        The figure, with the traces added and the target scene's layout set.
    """
    if fig is None:
        # add_traces(rows=..., cols=...) requires a figure with a subplot grid;
        # a bare go.Figure() makes that call raise.
        from plotly.subplots import make_subplots

        fig = make_subplots(
            rows=row,
            cols=column,
            specs=[[{"type": "scene"}] * column for _ in range(row)],
        )
        n_cols = column

    # Scenes are numbered row-major across an all-"scene" grid, so the index
    # depends on the grid's column count.
    scene_num = (row - 1) * n_cols + column
    scene_name = f"scene{scene_num}"

    fig.add_traces(
        _segmentation_traces(data, labels, labelmap, scene_name, sizes),
        rows=row,
        cols=column,
    )

    fig.update_layout(
        showlegend=True,
        margin=dict(l=0, r=0, b=0, t=40),
        legend=dict(x=1.0, y=0.75),
        **{scene_name: _3d_scene(data)},
    )

    return fig

In [None]:
# Inspect the fields of a dataset item. `data` was last assigned by the grasp-phase
# visualization cell above.
data.keys()

In [None]:
# For the FIRST phase only, plot the segmentation of each of the first N_DEMOS demos:
# initial action points, keyframe action points and keyframe anchor points.

from plotly.subplots import make_subplots

SEG_TASK = "take_umbrella_out_of_umbrella_stand"
phase = next(iter(TASK_DICT[SEG_TASK]["phase"]))

# segmentation_fig_rc numbers scenes assuming a 5-column grid by default, so keep
# N_COLS = 5 unless that numbering is also adjusted.
N_ROWS, N_COLS = 2, 5
N_DEMOS = N_ROWS * N_COLS

# Create a dataset for that phase.
dset = RLBenchPlacementDataset(
    dataset_root="/data/rlbench10/",
    task_name=SEG_TASK,
    demos=range(N_DEMOS),
    phase=phase,
)

# Label id -> point-cloud key in the dataset item; the key doubles as the legend name.
LABELMAP = {0: "init_action_pc", 1: "key_action_pc", 2: "key_anchor_pc"}

fig = make_subplots(
    rows=N_ROWS,
    cols=N_COLS,
    specs=[[{"type": "scene"}] * N_COLS for _ in range(N_ROWS)],
)

for i in range(N_DEMOS):
    data = dset[i]
    clouds = [data[key] for key in LABELMAP.values()]
    pcd = np.concatenate(clouds, axis=0)
    # One integer label per point, matching the concatenation order above.
    labels = np.concatenate(
        [np.full(cloud.shape[0], label_id) for label_id, cloud in zip(LABELMAP, clouds)]
    ).astype(int)
    grid_row, grid_col = divmod(i, N_COLS)
    fig = segmentation_fig_rc(
        pcd, labels, LABELMAP, fig=fig, row=grid_row + 1, column=grid_col + 1
    )

fig.update_layout(title_text=f"{SEG_TASK} / phase {phase}: first {N_DEMOS} demos")
fig.show()

# Visualize all the phases in plotly for a single task.

In [None]:
# Task to inspect; alternatives are kept for quick switching.
# task_name = "stack_wine"
# task_name = "reach_target"
# task_name = "put_money_in_safe"
# task_name = "take_money_out_safe"
task_name = "put_knife_on_chopping_board"

# Point clouds to overlay, in label order: the label id is the list index.
PHASE_PC_KEYS = ["init_action_pc", "init_anchor_pc", "key_action_pc"]
PHASE_LABELMAP = {0: "init_action", 1: "init_anchor", 2: "key_action"}

for phase in TASK_DICT[task_name]["phase_order"]:
    dset = RLBenchPlacementDataset(
        dataset_root="/data/rlbench10/",
        task_name=task_name,
        n_demos=1,
        phase=phase,
        debugging=False,
        use_first_as_init_keyframe=False,
    )

    data = dset[0]
    clouds = [data[key] for key in PHASE_PC_KEYS]

    # Stack the clouds into one tensor, plus one int32 label per point in the same
    # order (equivalent to the previous float-ones concatenation followed by .int()).
    points = torch.cat(clouds)
    seg = torch.cat(
        [
            torch.full((cloud.shape[0],), label_id, dtype=torch.int32)
            for label_id, cloud in enumerate(clouds)
        ]
    )

    fig = segmentation_fig(points, seg, labelmap=PHASE_LABELMAP)
    # Title the figure with the phase instead of printing it separately.
    fig.update_layout(title_text=f"{task_name}: phase {phase}")
    fig.show()

    

In [None]:
# List the raw files stored on disk for one episode.
# NOTE(review): the path is hard-coded to put_money_in_safe / variation0 / episode0
# and does not follow `task_name` from the cells above.
!ls /data/rlbench10/put_money_in_safe/variation0/episodes/episode0