In [1]:
from typing import List, Tuple, Optional, Union

import plotly.graph_objects as go

import torch
from torchvision.transforms import Compose

from e3nn import o3
from e3nn.util.jit import compile_mode, script as e3nn_script

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.block import PoolingBlock

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Length of one model coordinate unit expressed in raw point-cloud units
# (presumably metres, i.e. 1 unit = 1 cm — confirm against the demo data).
unit_len = 0.01

# Voxel sizes are specified in raw units (0.01) and converted to model units,
# because Downsample runs *after* Rescale in the pipelines built below.
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len

# Coordinate scaling: raw -> model units, and its inverse.
rescale_fn = Rescale(rescale_factor=1 / unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Colour normalisation and its inverse. Assuming NormalizeColor computes (c - mean) / std,
# using mean' = -mean/std and std' = 1/std gives (c + mean/std) * std = c * std + mean.
normalize_color_fn = NormalizeColor(color_mean=torch.full((3,), 0.5),
                                    color_std=torch.full((3,), 0.5))
recover_color_fn = NormalizeColor(color_mean=-normalize_color_fn.color_mean / normalize_color_fn.color_std,
                                  color_std=1 / normalize_color_fn.color_std)


def _make_proc_fn(voxel_size):
    """Build the rescale -> voxel-downsample -> colour-normalise preprocessing pipeline."""
    return Compose([rescale_fn,
                    Downsample(voxel_size=voxel_size, coord_reduction="average"),
                    normalize_color_fn])


def _make_unproc_fn():
    """Build the inverse pipeline: undo colour normalisation first, then undo rescaling."""
    return Compose([recover_color_fn, recover_scale_fn])


scene_proc_fn = _make_proc_fn(scene_voxel_size)
scene_unproc_fn = _make_unproc_fn()
grasp_proc_fn = _make_proc_fn(grasp_voxel_size)
grasp_unproc_fn = _make_unproc_fn()

In [3]:
device = 'cuda:0'
compile = True
eval = True

irreps_input = o3.Irreps('3x0e')
irreps_node_embedding = o3.Irreps('128x0e+64x1e+32x2e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')
irreps_block_output = o3.Irreps('128x0e+64x1e+32x2e')
number_of_basis = 128
fc_neurons = [64, 64]
irreps_head = o3.Irreps('32x0e+16x1o+8x2e')
num_heads = 4
irreps_pre_attn = None
rescale_degree = False
nonlinear_message = True
alpha_drop = 0.2
proj_drop = 0.0
drop_path_rate = 0.0
irreps_mlp_mid = o3.Irreps('384x0e+192x1e+96x2e')
norm_layer = 'layer'

In [4]:
node_enc = NodeEmbeddingNetwork(irreps_input=irreps_input, irreps_node_emb=irreps_node_embedding)
pool_block = PoolingBlock(irreps_src = irreps_node_embedding,
                          irreps_dst = irreps_node_embedding,
                          irreps_edge_attr = irreps_sh,
                          irreps_head = irreps_head,
                          num_heads = num_heads,
                          fc_neurons = fc_neurons,
                          pool_radius = 2.0,
                          pool_ratio = 0.5,
                          pool_method = 'fps',
                          deterministic = True,
                          irreps_mlp_mid = 3,
                          attn_type='mlp',
                          alpha_drop=alpha_drop, 
                          proj_drop=proj_drop,
                          drop_path_rate=drop_path_rate)

if compile:
    node_enc = e3nn_script(node_enc).to(device)
    pool_block = e3nn_script(pool_block).to(device)
else:
    node_enc = node_enc.to(device)
    pool_block = pool_block.to(device)

if eval:
    node_enc.eval()
    pool_block.eval()



# Load demo

In [5]:
demo_dir = 'demo/test_demo'
demo_list: List[DemoSequence] = load_demos(dir=demo_dir)
assert len(demo_list) > 0, f"no demos loaded from {demo_dir!r}"
demo_seq: DemoSequence = demo_list[0]

demo: TargetPoseDemo = demo_seq[0]
print(demo)

scene_raw: PointCloud = demo.scene_pc
grasp_raw: PointCloud = demo.grasp_pc
# NOTE(review): previously annotated `TargetPoseDemo` (the type of `demo` itself). SE3 is assumed
# here because it is imported but otherwise unused — confirm against TargetPoseDemo.target_poses.
target_poses: SE3 = demo.target_poses

# Rescale, voxel-downsample and normalise colours, then move the clouds to `device`.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

# Uncomment to inspect the processed scene:
#print(scene_proc)
#go.Figure(scene_unproc_fn(scene_proc).plotly())

# Encoder inputs: per-point colours as raw features; every point belongs to batch index 0.
node_feature = scene_proc.colors
node_coord = scene_proc.points
batch = torch.zeros(len(node_coord), device=device, dtype=torch.long)

TargetPoseDemo  (name: pick_demo)


In [6]:
node_feature = node_enc(node_feature)

In [7]:
pool_block(node_feature=node_feature,
           node_coord=node_coord,
           batch=batch)

(tensor([[-0.3958,  0.4951,  0.2688,  ...,  0.0545, -0.2257,  0.0591],
         [-0.5648,  0.7542,  0.4275,  ..., -0.0663, -0.2869,  0.8080],
         [ 0.2097,  0.8758,  0.1280,  ...,  0.5349, -0.2646,  1.2482],
         ...,
         [-0.5382,  1.5607,  0.8833,  ..., -0.1673, -0.2981,  0.6712],
         [-0.4200,  0.9316,  0.3788,  ...,  0.2293, -0.2905,  0.5461],
         [-0.3254,  1.4588,  0.1213,  ..., -0.8972, -0.2744,  0.2207]],
        device='cuda:0', grad_fn=<AddBackward0>),
 tensor([[-29.4233, -29.5070,   0.0592],
         [ 29.5753,  29.5039,   0.0522],
         [ 29.5753, -29.5039,   0.0521],
         ...,
         [ 23.6634,   2.5225,   0.0531],
         [-22.4204, -13.5039,   0.0480],
         [-29.4188,  -6.4802,   0.0500]], device='cuda:0'),
 tensor([ 1, 60, 61,  ..., 82, 83, 84], device='cuda:0'),
 tensor([   0,    0,    0,  ..., 2403, 2403, 2403], device='cuda:0'),
 tensor([ 3,  5,  4,  ..., 12, 10,  7], device='cuda:0'),
 tensor([0, 0, 0,  ..., 0, 0, 0], device='cu

In [None]:
# Deliberate stop for "Run All": the exploratory cells below call `fps_pool`, which this
# notebook never defines or imports. Halt with an explicit message instead of a stray NameError.
raise RuntimeError("Stopping here: the cells below are exploratory and require `fps_pool`, "
                   "which is not defined in this notebook.")

In [None]:
node_coord_1, node_feature_1, edge_src, edge_dst, degree, batch = fps_pool(node_coord_src = node_coord, node_feature_src = node_feature, batch_src = batch)
print(len(edge_dst), len(node_coord_1))

pooled_pcd = PointCloud(points=node_coord_1.cpu(), colors=node_feature_1.cpu())

print(pooled_pcd)
go.Figure(scene_unproc_fn(pooled_pcd).plotly())

In [None]:
node_coord_2, node_feature_2, edge_src, edge_dst, degree, batch = fps_pool(node_coord_src = node_coord_1, node_feature_src = node_feature_1, batch_src = batch)
print(len(edge_dst), len(node_coord_2))

pooled_pcd = PointCloud(points=node_coord_2.cpu(), colors=node_feature_2.cpu())

print(pooled_pcd)
go.Figure(scene_unproc_fn(pooled_pcd).plotly())

In [None]:
node_coord_3, node_feature_3, edge_src, edge_dst, degree, batch = fps_pool(node_coord_src = node_coord_2, node_feature_src = node_feature_2, batch_src = batch)
print(len(edge_dst), len(node_coord_3))

pooled_pcd = PointCloud(points=node_coord_3.cpu(), colors=node_feature_3.cpu())

print(pooled_pcd)
go.Figure(scene_unproc_fn(pooled_pcd).plotly())

In [None]:
# NOTE(review): dead sketch of an encoder pass, kept only as commented-out code.
# It references names never defined in this notebook (`rbf`, `edge_deg_enc`, `block1`,
# `node_attr`) and uses `edge_src` / `edge_dst`, which only the pooling cells above produced.
# Either delete it or turn it into a working cell.
# edge_vec = node_coord.index_select(0, edge_src) - node_coord.index_select(0, edge_dst)
# edge_sh = o3.spherical_harmonics(l=irreps_sh, x=edge_vec, normalize=True, normalization='component')
# edge_length = edge_vec.norm(dim=1)
# edge_length_emb = rbf(edge_length)

# node_emb = node_enc(node_feature)
# edge_degree_emb = edge_deg_enc(node_input = node_emb, 
#                                edge_attr = edge_sh, 
#                                edge_scalars = edge_length_emb, 
#                                edge_src = edge_src, 
#                                edge_dst = edge_dst)

# node_input = node_emb + edge_degree_emb
# edge_scalars = edge_length_emb


# output_features =  block1(node_input = node_input, 
#                           node_attr = node_attr, 
#                           edge_src = edge_src, 
#                           edge_dst = edge_dst,
#                           edge_attr = edge_sh,
#                           edge_scalars = edge_scalars,
#                           batch = batch)