In [1]:
from typing import List, Tuple, Optional, Union

import plotly.graph_objects as go

import torch
from torchvision.transforms import Compose

from e3nn import o3
from e3nn.util.jit import compile_mode, script as e3nn_script

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.block import PoolingBlock, EquiformerBlock, DownBlock, EdfExtractor
from diffusion_edf.connectivity import RadiusGraph

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Preprocessing configuration.
# Coordinates are multiplied by 1/unit_len before downsampling, so the voxel
# sizes below are divided by `unit_len` to express them in the rescaled units.
unit_len = 0.01
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len

# Forward / backward coordinate scaling.
rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Forward colour normalisation with mean 0.5 and std 0.5 on each of 3 channels.
normalize_color_fn = NormalizeColor(
    color_mean=torch.tensor([0.5, 0.5, 0.5]),
    color_std=torch.tensor([0.5, 0.5, 0.5]),
)
# Parameters of the reverse transform are derived from the forward one.
# NOTE(review): this is the exact inverse only if NormalizeColor computes
# (color - mean) / std -- confirm against diffusion_edf.preprocess.
fwd_color_mean = normalize_color_fn.color_mean
fwd_color_std = normalize_color_fn.color_std
recover_color_fn = NormalizeColor(
    color_mean=-fwd_color_mean / fwd_color_std,
    color_std=1 / fwd_color_std,
)


def _build_proc_fn(voxel_size):
    """Compose rescale -> voxel downsample (averaged coords) -> colour normalisation."""
    steps = [
        rescale_fn,
        Downsample(voxel_size=voxel_size, coord_reduction="average"),
        normalize_color_fn,
    ]
    return Compose(steps)


# The "unproc" pipelines undo colour normalisation first, then the rescaling.
scene_proc_fn = _build_proc_fn(scene_voxel_size)
scene_unproc_fn = Compose([recover_color_fn, recover_scale_fn])
grasp_proc_fn = _build_proc_fn(grasp_voxel_size)
grasp_unproc_fn = Compose([recover_color_fn, recover_scale_fn])

In [3]:
# --- Runtime configuration ---------------------------------------------------
# Seed torch's global RNG so that the weight initialisation below and the
# random draws in later cells (query points, random rotation) are repeatable
# under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Use the first CUDA device when one is present; otherwise fall back to CPU so
# the notebook does not crash on GPU-less machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# These flags are deliberately NOT called `compile` / `eval`: those names would
# shadow the Python builtins for every cell executed afterwards.
use_jit = True      # wrap the modules with e3nn's TorchScript helper
eval_mode = True    # call .eval() on the modules after construction

# --- Model hyperparameters ---------------------------------------------------
irreps_input = o3.Irreps('3x0e')
irreps_node_embedding = o3.Irreps('64x0e+32x1e+16x2e') #o3.Irreps('128x0e+64x1e+32x2e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')
irreps_block_output = o3.Irreps('64x0e+32x1e+16x2e') #o3.Irreps('128x0e+64x1e+32x2e')
number_of_basis = 128
fc_neurons = [128, 64, 64]
irreps_head = o3.Irreps('16x0e+8x1e+4x2e') #o3.Irreps('32x0e+16x1o+8x2e')
num_heads = 4
irreps_pre_attn = None
rescale_degree = False
nonlinear_message = True
alpha_drop = 0.2
proj_drop = 0.0
drop_path_rate = 0.0
irreps_mlp_mid = 2
norm_layer = 'layer'
n_scales = 4
# NOTE(review): irreps_block_output, number_of_basis, irreps_pre_attn,
# rescale_degree, nonlinear_message and norm_layer are not passed to any
# constructor in the cells visible here; they are kept in case other code
# reads them.

# --- Modules -----------------------------------------------------------------
# Embeds the 3-channel scalar input features into `irreps_node_embedding`.
node_enc = NodeEmbeddingNetwork(irreps_input=irreps_input, irreps_node_emb=irreps_node_embedding)
# Multi-scale encoder; most of its hyperparameters are reused by the extractor
# built in a later cell.
down_block = DownBlock(irreps = irreps_node_embedding,
                       irreps_edge_attr = irreps_sh,
                       irreps_head = irreps_head,
                       num_heads = num_heads,
                       fc_neurons = fc_neurons,
                       init_radius = 2.0,
                       pool_ratio = 0.5,
                       n_scales = n_scales,
                       n_layers_per_scale = 2,
                       pool_method = 'fps',
                       deterministic = True,
                       irreps_mlp_mid = irreps_mlp_mid,
                       attn_type='mlp',
                       alpha_drop=alpha_drop, 
                       proj_drop=proj_drop,
                       drop_path_rate=drop_path_rate)

# Optionally script the modules, then move them to the target device.
if use_jit:
    node_enc = e3nn_script(node_enc)
    down_block = e3nn_script(down_block)
node_enc = node_enc.to(device)
down_block = down_block.to(device)

if eval_mode:
    node_enc.eval()
    down_block.eval()



In [4]:
# Build one input spec per encoder scale. A fresh Irreps object / list is
# created for every scale, exactly as a per-scale comprehension would.
extractor_irreps_inputs = [
    o3.Irreps('64x0e+32x1e+16x2e') for _scale in range(n_scales)
]
extractor_fc_neurons_inputs = [
    [64, 32, 32] for _scale in range(n_scales)
]

# Field extractor that is later queried at arbitrary points.
# NOTE(review): `cutoffs` and `offsets` are fixed 4-element lists while the
# per-scale inputs follow `n_scales` (4 in the config cell) -- presumably these
# must stay the same length; confirm before changing n_scales.
extractor = EdfExtractor(
    irreps_inputs=extractor_irreps_inputs,
    fc_neurons_inputs=extractor_fc_neurons_inputs,
    irreps_emb=irreps_node_embedding,
    irreps_edge_attr=irreps_sh,
    irreps_head=irreps_head,
    num_heads=num_heads,
    fc_neurons=fc_neurons,
    n_layers=1,
    cutoffs=[2., 3.5, 5., 7.5],
    offsets=[0.01, 0.01, 0.01, 0.01],
    query_radius=8.,
    irreps_mlp_mid=irreps_mlp_mid,
    attn_type='mlp',
    alpha_drop=alpha_drop,
    proj_drop=proj_drop,
    drop_path_rate=drop_path_rate,
)

In [5]:
# Move the extractor to the configured device, then call .eval() on it.
extractor = extractor.to(device)
# Trailing semicolon keeps the notebook from echoing the return value.
extractor.eval();

# Load demo

In [6]:
# Read the stored demonstrations and pick the first step of the first sequence.
demo_list: List[DemoSequence] = load_demos(dir='demo/test_demo')
demo_seq: DemoSequence = demo_list[0]
demo: TargetPoseDemo = demo_seq[0]
print(demo)

# Raw observations carried by the selected demo.
scene_raw: PointCloud = demo.scene_pc
grasp_raw: PointCloud = demo.grasp_pc
# NOTE(review): annotated as TargetPoseDemo although it is an attribute *of* a
# TargetPoseDemo -- the annotation looks wrong; confirm the real type.
target_poses: TargetPoseDemo = demo.target_poses

# Run both point clouds through their preprocessing pipelines and move the
# results to the compute device.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

# Optional sanity check: uncomment to inspect / plot the processed scene.
#print(scene_proc)
#go.Figure(scene_unproc_fn(scene_proc).plotly())

# Encoder inputs: per-point colours as features, point positions as
# coordinates, and an all-zero long index tensor with one entry per point.
node_coord = scene_proc.points
node_feature = scene_proc.colors
batch = torch.zeros(len(node_coord), dtype=torch.long, device=device)

# Three random 3-D points at which the field will be evaluated.
query_points = torch.randn(3, 3, device=device)

TargetPoseDemo  (name: pick_demo)


In [7]:
# Embed the per-point features, then run the multi-scale encoder.
node_emb = node_enc(node_feature)
outputs = down_block(node_feature=node_emb, node_coord=node_coord, batch=batch)

# Only every second entry of `outputs` is handed to the extractor; from each
# selected entry, fields 0, 1 and 7 are used as features, coordinates and
# batch indices respectively.
selected_outputs = outputs[::2]

# NOTE(review): zeros_like keeps the (float) dtype of `query_points`, whereas
# `batch` above is torch.long -- confirm the extractor accepts a float index.
query_batch = torch.zeros_like(query_points[...,0])

field_val = extractor(
    query_coord=query_points,
    query_batch=query_batch,
    node_features=[entry[0] for entry in selected_outputs],
    node_coords=[entry[1] for entry in selected_outputs],
    node_batches=[entry[7] for entry in selected_outputs],
)

# Rotate

In [8]:
# NOTE(review): this import would normally live in the top import cell.
from pytorch3d.transforms import quaternion_apply, random_quaternions

# Draw one random quaternion and explicitly rescale it to unit norm.
rot = random_quaternions(1, device=device)
rot_norm = rot.norm(dim=-1, keepdim=True)
rot = rot / rot_norm

# A random translation is drawn as well but is currently NOT applied: only the
# rotation is tested below.
trans = torch.randn(3, device=device)

# Rotate the scene coordinates and the query points with the same quaternion.
node_coord_rot, query_points_rot = (
    quaternion_apply(rot, points) for points in (node_coord, query_points)
)

In [9]:
# Same pipeline as for the unrotated scene: the input features (colours) are
# reused unchanged, only the coordinates are the rotated ones.
node_feature_rot = scene_proc.colors
node_emb_rot = node_enc(node_feature_rot)
outputs_rot = down_block(node_feature=node_emb_rot, node_coord=node_coord_rot, batch=batch)

# Every second encoder output is consumed; fields 0, 1 and 7 of each are used.
selected_outputs_rot = outputs_rot[::2]

# The batch index is built from the unrotated `query_points`, as before.
query_batch_rot = torch.zeros_like(query_points[...,0])

field_val_rot = extractor(
    query_coord=query_points_rot,
    query_batch=query_batch_rot,
    node_features=[entry[0] for entry in selected_outputs_rot],
    node_coords=[entry[1] for entry in selected_outputs_rot],
    node_batches=[entry[7] for entry in selected_outputs_rot],
)

In [10]:
# Equivariance check: transforming the field computed on the original scene
# with `rot` should reproduce the field computed on the rotated scene.
irrep_transform = e3nn_script(
    TransformFeatureQuaternion(irreps=irreps_node_embedding, device=device)
)
a = field_val_rot
b = field_val

# First element of the transform's output is compared element-wise with `a`.
b_transformed = irrep_transform(b, rot)[0]
isclose = torch.isclose(b_transformed, a, atol=0.001, rtol=0.001)
# print(isclose)

# Fraction of entries that agree within tolerance.
# Slight non-equivariance comes from FPS downsampling algorithm.
equivariance_ratio = (isclose.sum() / a.numel()).item()
print(f"Equivariance ratio: {equivariance_ratio}")

Equivariance ratio: 1.0


  return forward_call(*input, **kwargs)
