In [1]:
# import os
# os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"

from typing import List, Tuple, Optional, Union, Iterable, NamedTuple
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, FeaturedPoints, merge_featured_points
from diffusion_edf import preprocess

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Notebook-wide settings read by the cells below.
# `device` is handed to DemoSeqDataset(device=...); note that a later cell
# rebinds it to 'cuda:0'.
device = 'cuda'
# Voxel size that proc_fn passes to preprocess.downsample.
voxel_size = 0.01
# Selects which collate_fn is defined: 'pick' or 'place'; any other value
# raises ValueError in the collate_fn cell.
task = 'pick'

def proc_fn(demo_batch: List[TargetPoseDemo]):
    """Downsample a batch of demos and pack it into batched structures.

    Each demo is voxel-downsampled with the notebook-level ``voxel_size``
    (coordinates reduced by averaging). Its scene and grasp point clouds are
    converted to featured points tagged with the demo's position in the
    batch, then merged across the batch; the target poses are stacked along
    a new leading dimension.

    Args:
        demo_batch: Demos to process, one per batch element.

    Returns:
        Tuple ``(scene_pcd, grasp_pcd, target_poses)``.
    """
    downsampled = [
        preprocess.downsample(data=raw_demo, voxel_size=voxel_size, coord_reduction='average')
        for raw_demo in demo_batch
    ]

    # Shapes as noted by the original author (b: N_batch, p: N_points, g: N_poses):
    #   merged points -> x: (b*p, 3), f: (b*p, 3), b: (b*p, )
    #   target poses  -> (b, g, 4+3)
    scene_pcd = merge_featured_points(
        [item.scene_pc.to_featured_points(batch_idx=idx) for idx, item in enumerate(downsampled)]
    )
    grasp_pcd = merge_featured_points(
        [item.grasp_pc.to_featured_points(batch_idx=idx) for idx, item in enumerate(downsampled)]
    )
    target_poses = torch.stack([item.target_poses.poses for item in downsampled], dim=0)

    return scene_pcd, grasp_pcd, target_poses

In [3]:
n_batch = 3

# Position of the sub-demo inside each DemoSequence, per supported task.
_TASK_TO_DEMO_INDEX = {'pick': 0, 'place': 1}
if task not in _TASK_TO_DEMO_INDEX:
    raise ValueError(f"Unknown task name: {task}")
_DEMO_INDEX = _TASK_TO_DEMO_INDEX[task]


def collate_fn(data_batch: List[DemoSequence]) -> List[TargetPoseDemo]:
    """Pick the sub-demo matching the configured task from every sequence.

    Args:
        data_batch: Demo sequences gathered by the DataLoader.

    Returns:
        One TargetPoseDemo per sequence, in batch order.
    """
    return [sequence[_DEMO_INDEX] for sequence in data_batch]

# Demo sequences read from `test_data`, described by the `data.yaml` annotation file.
trainset = DemoSeqDataset(
    dataset_dir="test_data",
    annotation_file="data.yaml",
    device=device,
)

# Shuffled loader yielding `n_batch` sequences at a time; collate_fn reduces
# each sequence to the sub-demo for the configured task.
train_dataloader = DataLoader(
    trainset,
    batch_size=n_batch,
    shuffle=True,
    collate_fn=collate_fn,
)

In [4]:
# Draw a single batch from the shuffled loader; collate_fn makes it a list of
# TargetPoseDemo (per its return annotation).
data = next(iter(train_dataloader))
# Downsample and merge the batch: merged scene points, merged grasp points,
# and the stacked target poses.
scene_pcd, grasp_pcd, target_poses = proc_fn(data)

In [5]:
# NOTE(review): `sadasfddsffa` is not defined in any visible cell, so this cell
# raises NameError (see its recorded output) and halts a "Run All" at this
# point. Presumably a deliberate stop marker placed before the unexecuted
# cells that follow -- confirm, then delete it or replace it with an explicit
# `raise` carrying a message.
sadasfddsffa

NameError: name 'sadasfddsffa' is not defined

In [None]:
# NOTE(review): Rescale, NormalizeColor, Compose and Downsample are not
# imported in the visible cells -- confirm where they come from.
unit_len = 0.01

# Voxel sizes expressed in multiples of `unit_len`, because downsampling runs
# after the rescale step in the pipelines built below.
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len

# Scale transform and its inverse.
rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Color normalization and a second NormalizeColor parameterized to undo it
# (assumes NormalizeColor computes (color - mean) / std -- TODO confirm).
normalize_color_fn = NormalizeColor(
    color_mean=torch.tensor([0.5, 0.5, 0.5]),
    color_std=torch.tensor([0.5, 0.5, 0.5]),
)
recover_color_fn = NormalizeColor(
    color_mean=-normalize_color_fn.color_mean / normalize_color_fn.color_std,
    color_std=1 / normalize_color_fn.color_std,
)


def _build_proc_fn(voxel_size):
    """Compose rescale -> voxel downsample (average) -> color normalization."""
    return Compose([rescale_fn,
                    Downsample(voxel_size=voxel_size, coord_reduction="average"),
                    normalize_color_fn])


# Forward pipelines and their inverses (color first, then scale).
scene_proc_fn = _build_proc_fn(scene_voxel_size)
scene_unproc_fn = Compose([recover_color_fn, recover_scale_fn])
grasp_proc_fn = _build_proc_fn(grasp_voxel_size)
grasp_unproc_fn = Compose([recover_color_fn, recover_scale_fn])

In [None]:
# Model / runtime settings for the EDF cells that follow.
# NOTE(review): `math` is not used in any visible cell, and `o3` is not
# imported in the visible cells -- confirm both.
import math

# Rebinds the `device` chosen in the earlier configuration cell.
device = 'cuda:0'
# NOTE(review): this name shadows the builtin `eval`. If this cell is skipped,
# `if eval:` in the model cell still passes because the builtin is truthy.
# Consider renaming it together with its use in the model cell.
eval = True

# Irreps of the input features, the node embedding, and the spherical harmonics.
irreps_input = o3.Irreps('3x0e')
irreps_node_embedding = o3.Irreps('32x0e+16x1e+8x2e') #o3.Irreps('128x0e+64x1e+32x2e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')
# NOTE(review): fc_neurons, alpha_drop, proj_drop, drop_path_rate and
# irreps_mlp_mid are not referenced by any visible cell -- confirm whether
# they are still needed.
fc_neurons = [128, 64, 64]
num_heads = 4
alpha_drop = 0.2
proj_drop = 0.0
drop_path_rate = 0.0
irreps_mlp_mid = 2
n_scales = 4
pool_ratio = 0.5

In [None]:
# Build the EDF model from the settings defined in the hyperparameter cell.
# num_heads / n_scales / pool_ratio are taken from that cell instead of being
# repeated here as literals, so the two cells cannot drift apart (the values
# are the same: 4, 4 and 0.5).
# NOTE(review): `EDF` is not imported in the visible cells -- confirm where it
# comes from.
edf = EDF(irreps_input = irreps_input,
          irreps_emb_init = irreps_node_embedding,
          irreps_sh = irreps_sh,
          fc_neurons_init = [32, 16, 16],
          num_heads = num_heads,
          n_scales = n_scales,
          pool_ratio = pool_ratio,
          dim_mult = [1, 2, 3, 4],
          n_layers = 2,
          gnn_radius = 2.0,
          cutoff_radius = 3.0,
          deterministic = True,
          compile_head=True)

edf = edf.to(device)
# `eval` here is the flag set in the hyperparameter cell (it shadows the
# builtin of the same name).
if eval:
    edf = edf.eval()

# Load demo

In [None]:
# NOTE(review): `load_demos` is not imported in the visible cells -- confirm
# where it comes from.
demo_list: List[DemoSequence] = load_demos(dir='demo/test_demo')
demo_seq: DemoSequence = demo_list[0]

# Index 1 is the sub-demo that collate_fn selects for the 'place' task.
demo: TargetPoseDemo = demo_seq[1]
print(demo)

# Plain assignments on purpose: module-level annotations are evaluated at
# runtime and `PointCloud` is not imported here, so annotating with it raises
# NameError. `target_poses` is the demo's pose container (proc_fn reads its
# `.poses`), not a TargetPoseDemo. It also rebinds the batched `target_poses`
# produced by the earlier proc_fn cell.
scene_raw = demo.scene_pc
grasp_raw = demo.grasp_pc
target_poses = demo.target_poses

# Rescale, downsample and color-normalize, then move to the compute device.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

#print(scene_proc)
#go.Figure(scene_unproc_fn(scene_proc).plotly())


# Scene nodes: colors as features, points as coordinates, all in batch 0.
node_feature = scene_proc.colors
node_coord = scene_proc.points
batch = torch.zeros(len(node_coord), device=device, dtype=torch.long)

# Three random query points, also in batch 0. The batch index is created as
# torch.long so it matches `batch`; zeros_like alone would inherit the float
# dtype of `query_coord`.
query_coord = torch.randn(3,3, device=device)
query_batch = torch.zeros_like(query_coord[...,0], dtype=torch.long)

In [None]:
# Evaluate the field at the query points on the original (un-rotated) scene.
field_val, edf_info = edf.forward(
    query_coord=query_coord,
    query_batch=query_batch,
    node_feature=node_feature,
    node_coord=node_coord,
    batch=batch,
)

# Rotate

In [None]:
from diffusion_edf.quaternion_utils import quaternion_apply, random_quaternions

# Draw one random rotation (renormalized to unit length) and one random
# translation. The draw order is kept: rotation first, then translation.
rot = random_quaternions(1, device=device)
rot = rot / rot.norm(dim=-1, keepdim=True)
trans = torch.randn(3, device=device)


def _apply_rigid_transform(points):
    """Rotate `points` by `rot`, then shift them by `trans`."""
    return quaternion_apply(rot, points) + trans


# Move the scene nodes and the query points with the same rigid transform.
node_coord_rot = _apply_rigid_transform(node_coord)
query_coord_rot = _apply_rigid_transform(query_coord)

# The node features are passed through unchanged.
node_feature_rot = node_feature

In [None]:
# Evaluate the field again on the rigidly transformed scene and query points.
field_val_rot, edf_info_rot = edf.forward(
    query_coord=query_coord_rot,
    query_batch=query_batch,
    node_feature=node_feature_rot,
    node_coord=node_coord_rot,
    batch=batch,
)

In [None]:
# NOTE(review): e3nn_script, TransformFeatureQuaternion and o3 are not imported
# in the visible cells -- confirm where they come from.
irrep_transform = e3nn_script(TransformFeatureQuaternion(irreps = o3.Irreps(edf.irreps_emb), device=device))
a = field_val_rot
b = field_val
# Equivariance check: transforming the original field values by `rot` should
# reproduce the field values computed on the rotated inputs (within tolerance).
isclose = torch.isclose(irrep_transform(b, rot)[0], a, atol=0.01, rtol=0.01)
# Fraction of matching entries. numel() counts elements directly and, unlike
# len(a.view(-1)), does not require `a` to be contiguous.
equivariance_ratio = isclose.sum() / a.numel()
print(f"Equivariance ratio: {equivariance_ratio.item()}") # Slight non-equivariance comes from FPS downsampling algorithm.