In [1]:
from typing import List, Tuple, Optional, Union

import plotly.graph_objects as go

import torch
from torchvision.transforms import Compose

from e3nn import o3
from e3nn.util.jit import compile_mode, script as e3nn_script

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.block import PoolingBlock, EquiformerBlock
from diffusion_edf.connectivity import RadiusGraph

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Coordinates are rescaled by 1/unit_len below, so the voxel sizes (given as
# 0.01 in the raw coordinate units) are divided by unit_len as well before
# they are handed to Downsample.
unit_len = 0.01
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len

# Forward and inverse coordinate scaling: the two factors are reciprocals.
rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Color normalization and its inverse.
# NOTE(review): the inverse parameters (-mean/std, 1/std) undo the forward
# transform provided NormalizeColor computes (color - mean) / std -- confirm
# against diffusion_edf.preprocess.
normalize_color_fn = NormalizeColor(
    color_mean=torch.tensor([0.5, 0.5, 0.5]),
    color_std=torch.tensor([0.5, 0.5, 0.5]),
)
recover_color_fn = NormalizeColor(
    color_mean=-normalize_color_fn.color_mean / normalize_color_fn.color_std,
    color_std=1 / normalize_color_fn.color_std,
)

# Scene pipeline: rescale -> voxel downsample -> normalize colors.
scene_proc_fn = Compose([
    rescale_fn,
    Downsample(voxel_size=scene_voxel_size, coord_reduction="average"),
    normalize_color_fn,
])
# Reverse order of the invertible steps (downsampling is not undone).
scene_unproc_fn = Compose([recover_color_fn, recover_scale_fn])

# Grasp pipeline: same steps as the scene pipeline, with its own voxel size.
grasp_proc_fn = Compose([
    rescale_fn,
    Downsample(voxel_size=grasp_voxel_size, coord_reduction="average"),
    normalize_color_fn,
])
grasp_unproc_fn = Compose([recover_color_fn, recover_scale_fn])

In [3]:
# ---- Runtime switches ----
device = 'cuda:0'
# These flags are deliberately NOT called `compile` / `eval`: binding those
# names at notebook scope would shadow the Python builtins for every later cell.
use_jit_compile = True   # wrap the modules with e3nn's `script` before use
use_eval_mode = True     # call .eval() on the modules after construction

# ---- Hyperparameters ----
# Several of these (irreps_block_output, number_of_basis, irreps_pre_attn,
# rescale_degree, nonlinear_message, irreps_mlp_mid, norm_layer) are defined
# here but not passed to any constructor in this cell.
irreps_input = o3.Irreps('3x0e')
irreps_node_embedding = o3.Irreps('128x0e+64x1e+32x2e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')
irreps_block_output = o3.Irreps('128x0e+64x1e+32x2e')
number_of_basis = 128
fc_neurons = [128, 64, 64]
# NOTE(review): the degree-1 part is odd parity ('1o') here while every other
# irreps string in this cell uses '1e' -- confirm this is intentional.
irreps_head = o3.Irreps('32x0e+16x1o+8x2e')
num_heads = 4
irreps_pre_attn = None
rescale_degree = False
nonlinear_message = True
alpha_drop = 0.2
proj_drop = 0.0
drop_path_rate = 0.0
irreps_mlp_mid = o3.Irreps('384x0e+192x1e+96x2e')
norm_layer = 'layer'

# ---- Modules under test ----
node_enc = NodeEmbeddingNetwork(irreps_input=irreps_input, irreps_node_emb=irreps_node_embedding)
pool_block = PoolingBlock(irreps_src = irreps_node_embedding,
                          irreps_dst = irreps_node_embedding,
                          irreps_edge_attr = irreps_sh,
                          irreps_head = irreps_head,
                          num_heads = num_heads,
                          fc_neurons = fc_neurons,
                          pool_radius = 2.0,
                          pool_ratio = 0.5,
                          pool_method = 'fps',
                          deterministic = True,
                          # NOTE(review): the literal 3 is passed here, not the
                          # `irreps_mlp_mid` Irreps defined above -- presumably
                          # an integer multiplier; verify against PoolingBlock.
                          irreps_mlp_mid = 3,
                          attn_type='mlp',
                          alpha_drop=alpha_drop,
                          proj_drop=proj_drop,
                          drop_path_rate=drop_path_rate)

# Optionally script the modules, then move them to the target device either way.
if use_jit_compile:
    node_enc = e3nn_script(node_enc)
    pool_block = e3nn_script(pool_block)
node_enc = node_enc.to(device)
pool_block = pool_block.to(device)

if use_eval_mode:
    node_enc.eval()
    pool_block.eval()



# Load demo

In [4]:
# Take the first step of the first demonstration sequence found on disk.
demo_list: List[DemoSequence] = load_demos(dir='demo/test_demo')
demo_seq: DemoSequence = demo_list[0]

demo: TargetPoseDemo = demo_seq[0]
print(demo)

scene_raw: PointCloud = demo.scene_pc
grasp_raw: PointCloud = demo.grasp_pc
# NOTE(review): annotated as SE3 rather than TargetPoseDemo (the type of the
# containing `demo` object); presumably the poses are SE3 -- confirm against
# diffusion_edf.data.
target_poses: SE3 = demo.target_poses

# Apply the preprocessing pipelines and move the results to the compute device.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

# Optional visual sanity check (uncomment to use):
#print(scene_proc)
#go.Figure(scene_unproc_fn(scene_proc).plotly())


# Network inputs: per-point colors as features, point positions as coordinates,
# and an all-zero batch index (every point is assigned to batch element 0).
node_feature = scene_proc.colors
node_coord = scene_proc.points
batch = torch.zeros(len(node_coord), device=device, dtype=torch.long)

TargetPoseDemo  (name: pick_demo)


In [5]:
# Reference forward pass on the un-rotated scene: embed the per-point features,
# then run the pooling block on the original coordinates.
node_emb = node_enc(node_feature)
(
    output_unrot,
    node_coord_dst,
    edge_src,
    edge_dst,
    edge_length,
    edge_attr,
    degree,
    batch_dst,
) = pool_block(node_feature=node_emb, node_coord=node_coord, batch=batch)

# Rotate

In [6]:
from pytorch3d.transforms import quaternion_apply, random_quaternions

# Draw one random rotation as a quaternion and explicitly rescale it to unit norm.
rot = random_quaternions(1, device=device)
rot_norm = rot.norm(dim=-1, keepdim=True)
rot = rot / rot_norm

# A random translation is drawn as well, but it is currently not applied:
# only the rotation is used to build the transformed coordinates.
trans = torch.randn(3,device=device)
node_coord_rot = quaternion_apply(rot, node_coord)  # + trans

In [7]:
# Second forward pass: the input features are the same colors as in the
# un-rotated pass (declared as scalar irreps '3x0e'); only the coordinates
# fed to the pooling block are the rotated ones.
node_feature_rot = scene_proc.colors
node_feature_rot = node_enc(node_feature_rot)
# Every output gets a `_rot` suffix -- including the edge length / attribute
# tensors -- so that the results of the un-rotated pass are not overwritten
# and the notebook state does not depend on which pass ran last.
(
    output_rot,
    node_coord_dst_rot,
    edge_src_rot,
    edge_dst_rot,
    edge_length_rot,
    edge_attr_rot,
    degree_rot,
    batch_dst_rot,
) = pool_block(node_feature=node_feature_rot,
               node_coord=node_coord_rot,
               batch=batch)

In [8]:
# Equivariance check: transform the un-rotated output with the same quaternion
# and compare it element-wise with the output computed from rotated coordinates.
# (TransformFeatureQuaternion presumably applies the rotation to the irreps
# features -- see diffusion_edf.wigner.)
irrep_transform = e3nn_script(
    TransformFeatureQuaternion(irreps=irreps_node_embedding, device=device)
)
a, b = output_rot, output_unrot
isclose = torch.isclose(irrep_transform(b, rot), a, rtol=0.001, atol=0.001)
isclose

  return forward_call(*input, **kwargs)


tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]], device='cuda:0')

In [9]:
# Fraction of output elements that match within tolerance.
# Slight non-equivariance comes from FPS downsampling algorithm.
f"Equivariance ratio: {(isclose.sum() / a.numel()).item()}"

'Equivariance ratio: 0.9955335259437561'