In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# make a grid of 2x5 images of pngs, read from a folder

import matplotlib.pyplot as plt

import numpy as np
import os

from PIL import Image

# Root folder of the RLBench dataset on disk.
path = os.path.expanduser("~/datasets/rlbench")

# Task to preview; switch tasks by (un)commenting one of the lines below.
# task = "put_toilet_roll_on_stand"
task = "stack_wine"
# task = "phone_on_base"
# task = "insert_onto_square_peg"
# task = "place_hanger_on_rack"
# task = "solve_puzzle"


def _numbered_pngs(directory):
    """Return the `<integer>.png` file names in `directory`, in numeric order.

    The frames are named 1.png, 2.png, 3.png, ..., so a plain lexicographic
    sort would place 10.png before 2.png. Entries that are not of the form
    `<integer>.png` (hidden files, thumbnails, ...) are skipped instead of
    crashing the integer conversion used as the sort key.
    """
    frames = []
    for name in os.listdir(directory):
        stem, ext = os.path.splitext(name)
        if ext.lower() == ".png" and stem.isdecimal():
            frames.append(name)
    return sorted(frames, key=lambda name: int(os.path.splitext(name)[0]))


def _show_image_grid(image_paths, title=None, rows=2, cols=5, figsize=(20, 10)):
    """Draw the images at `image_paths` on a `rows` x `cols` grid of axes.

    Each used panel is titled "Episode <index>", where the index is the
    position of the image in `image_paths`. Panels without an image are left
    blank. If there are more images than panels, the extra images are ignored.

    Returns the `(fig, axs)` pair, with `axs` flattened to one dimension.
    """
    # squeeze=False keeps `axs` a 2D array even for a 1x1 grid, so flatten() always works.
    fig, axs = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    # Hide the frame/ticks of every panel up front so unused panels stay empty.
    for ax in axs:
        ax.axis("off")

    for i, (ax, png) in enumerate(zip(axs, image_paths)):
        # The context manager closes the file handle once the pixels are handed to imshow.
        with Image.open(png) as img:
            ax.imshow(img)
        ax.set_title(f"Episode {i}")

    if title is not None:
        fig.suptitle(title)
    return fig, axs


# One "overhead_rgb" frame folder per episode, for episodes 0..9 of variation 0.
dirs = [os.path.join(path, task, "variation0", "episodes", f"episode{i}", "overhead_rgb") for i in range(10)]
start_pngs = []
end_pngs = []
for d in dirs:
    files = _numbered_pngs(d)
    if not files:
        raise FileNotFoundError(f"no numbered PNG frames found in {d}")

    # First and last frame of the episode.
    start_pngs.append(os.path.join(d, files[0]))
    end_pngs.append(os.path.join(d, files[-1]))

    # Report the name of the last frame of each episode.
    print(files[-1])

# One 2x5 grid for the first frames and one for the last frames.
fig, axs = _show_image_grid(start_pngs, title=f"{task}: first frame of each episode")
fig, axs = _show_image_grid(end_pngs, title=f"{task}: last frame of each episode")




In [None]:
from taxpose.datasets.rlbench import RLBenchPointCloudDataset, RLBenchPointCloudDatasetConfig


# Dataset configuration: the "place" phase of episodes 0..9 for one RLBench task.
# Other task names that have been tried here:
#   "insert_onto_square_peg", "stack_wine", "put_toilet_roll_on_stand",
#   "phone_on_base", "place_hanger_on_rack"
_dset_cfg = RLBenchPointCloudDatasetConfig(
    dataset_root=os.path.expanduser("~/datasets/rlbench"),
    task_name="solve_puzzle",
    episodes=list(range(10)),
    phase="place",
)

dset = RLBenchPointCloudDataset(cfg=_dset_cfg)

In [None]:
# Print the shape of the "points_action" entry for the first ten dataset items.
for i in range(10):
    sample = dset[i]
    print(sample['points_action'].shape)

In [None]:
import open3d as o3d
import open3d.web_visualizer as w3d

def _as_point_cloud(xyz, rgb):
    """Build an Open3D point cloud from point coordinates and per-point colours."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(rgb)
    return cloud


# Only the first dataset item is visualised here.
data = dset[0]

# Action points, coloured by their symmetry RGB values (divided by 255).
pcd = _as_point_cloud(data["points_action"][0], data["action_symmetry_rgb"][0] / 255.0)

# Anchor points, coloured the same way.
pcd1 = _as_point_cloud(data["points_anchor"][0], data["anchor_symmetry_rgb"][0] / 255.0)

# Only the action cloud is drawn; `pcd1` is built but not displayed.
# Desktop-window alternative for the anchor cloud:
#   o3d.visualization.draw_geometries([pcd1])
w3d.draw(pcd)

In [None]:
# Iterate through the dataset, and plot 2D renders of 3D point clouds in matplotlib.
# We want to plot on 3D matplotlib axes.
# Make a grid of 2 x 5

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 2x5 grid of 3D axes, flattened so panel i can be addressed as axs[i].
fig = plt.figure(figsize=(20, 10))
axs = fig.subplots(2, 5, subplot_kw={"projection": "3d"}).flatten()

for i, data in enumerate(dset):
    ax = axs[i]
    anchor_xyz = data["points_anchor"][0]
    action_xyz = data["points_action"][0]

    # Anchor points first, then action points, each coloured by its symmetry RGB / 255.
    ax.scatter(anchor_xyz[:, 0], anchor_xyz[:, 1], anchor_xyz[:, 2], c=data["anchor_symmetry_rgb"][0] / 255.0)
    ax.scatter(action_xyz[:, 0], action_xyz[:, 1], action_xyz[:, 2], c=data["action_symmetry_rgb"][0] / 255.0)

    ax.set_title(f"Episode {i}")

    # Stack both point sets so the limits cover everything drawn in this panel.
    points = np.concatenate([action_xyz, anchor_xyz], axis=0)

    # Fit each axis to the min/max of the matching coordinate column.
    for set_limits, column in zip((ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d), points.T):
        set_limits(column.min(), column.max())

    # Remove the ticks on all three axes.
    for set_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
        set_ticks([])

    # Hide the axes themselves so only the points and the title remain.
    ax.set_axis_off()

In [None]:
# Minimum of column 0 of the first "points_action" entry of `data`.
# `data` is whatever sample an earlier cell last bound to that name.
data["points_action"][0][:, 0].min()

In [None]:
# Display the "action_symmetry_features" entry of `data`.
# `data` is whatever sample an earlier cell last bound to that name.
data["action_symmetry_features"]