In [None]:
# Automatically reload edited project modules (e.g. taxpose) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

from rpad.pyg.dataset import CachedByKeyDataset
from rpad.visualize_3d.plots import segmentation_fig
import torch

import taxpose.datasets.pm_placement as place

# Define the parameters shared by the raw and the cached datasets below.

In [None]:
# Scenes to load. Each tuple is passed as-is to PlaceDataset (`scene_ids`) and is
# also used as the cache key list for the cached dataset.
# NOTE(review): the meaning of the individual tuple fields is not visible here.
scene_ids = [("11299", "ell", "0", "in")]

# Dataset location. Set the PARTNET_MOBILITY_ROOT environment variable to point
# elsewhere instead of editing the notebook; defaults to the original path.
root = os.path.expanduser(
    os.environ.get("PARTNET_MOBILITY_ROOT", "~/datasets/partnet-mobility")
)

# Sampling options, forwarded unchanged to both the raw and the cached dataset.
randomize_camera = True
mode = "obs"
snap_to_surface = True
full_obj = True
even_downsample = True
rotate_anchor = True

# Create a raw (uncached) dataset. Indexing it repeatedly draws new random variations of the same scene.

In [None]:
# Raw (uncached) PlaceDataset over the configured scenes, built from the shared
# parameters. NOTE(review): repeated indexing presumably re-samples; the next
# cell reads dset[0] three times and expects different configurations.
raw_dset_kwargs = dict(
    root=root,
    scene_ids=scene_ids,
    mode=mode,
    randomize_camera=randomize_camera,
    snap_to_surface=snap_to_surface,
    full_obj=full_obj,
    even_downsample=even_downsample,
    rotate_anchor=rotate_anchor,
)
dset = place.PlaceDataset(**raw_dset_kwargs)

In [None]:
# Sample 3 different configurations, to show how the dataset is constructed.
# The same index is read three times on purpose: with randomization enabled,
# each access presumably yields a different variation of scene 0.
datas = [dset[0] for _ in range(3)]
for data in datas:
    n_action, n_anchor = len(data.action_pos), len(data.anchor_pos)
    # Stack action and anchor points into one cloud for a single figure.
    pos = torch.cat([data.action_pos, data.anchor_pos], dim=0)
    # Segmentation labels: 1 for action points, 0 for anchor points (int32).
    labels = torch.cat(
        [
            torch.ones(n_action, dtype=torch.int),
            torch.zeros(n_anchor, dtype=torch.int),
        ],
        dim=0,
    )
    segmentation_fig(pos, labels).show()

# Use a cached dataset instead: pre-sample a fixed set of variations once, then index into them.

In [None]:
# We sample from the same distribution, but instead of making each worker sample
# every time, we can pre-sample.
cached_dset = CachedByKeyDataset(
    dset_cls=place.PlaceDataset,
    dset_kwargs={
        "root": root,
        "randomize_camera": randomize_camera,
        "snap_to_surface": snap_to_surface,
        "full_obj": full_obj,
        "even_downsample": even_downsample,
        "rotate_anchor": rotate_anchor,
        "scene_ids": scene_ids,
        "mode": mode,
    },
    data_keys=scene_ids,
    root=root,
    # Use the configured `mode` (previously hard-coded as "obs") so the cache
    # directory name stays in sync with the `mode` in `dset_kwargs`.
    # NOTE(review): assumes the first positional argument of processed_dir_name
    # is the mode. Also, `rotate_anchor` is not part of the directory name, so
    # caches built with different rotate_anchor values may collide — confirm.
    processed_dirname=place.PlaceDataset.processed_dir_name(
        mode,
        randomize_camera,
        snap_to_surface,
        full_obj,
        even_downsample,
    ),
    n_repeat=50,  # presumably: number of pre-sampled variations per data key
    n_workers=0,
    n_proc_per_worker=2,
    seed=123456,  # fixed seed, presumably so the pre-sampled set is reproducible
)

In [None]:
# Sample 3 different configurations, to show how the dataset is constructed.
# Unlike the raw dataset, distinct indices are read here, since each index
# presumably selects a distinct pre-sampled variation.
datas = [cached_dset[i] for i in range(3)]
for data in datas:
    n_action, n_anchor = len(data.action_pos), len(data.anchor_pos)
    # Stack action and anchor points into one cloud for a single figure.
    pos = torch.cat([data.action_pos, data.anchor_pos], dim=0)
    # Segmentation labels: 1 for action points, 0 for anchor points (int32).
    labels = torch.cat(
        [
            torch.ones(n_action, dtype=torch.int),
            torch.zeros(n_anchor, dtype=torch.int),
        ],
        dim=0,
    )
    segmentation_fig(pos, labels).show()