In [None]:
from typing import List, Tuple, Optional, Union

import plotly.graph_objects as go

import torch
from torchvision.transforms import Compose

from e3nn import o3
from e3nn.util.jit import compile_mode as e3nn_compile_mode, script as e3nn_script

from diffusion_edf.equiformer.graph_attention_transformer import TransBlock, EdgeDegreeEmbeddingNetwork, FeedForwardNetwork, SeparableFCTP
from diffusion_edf.equiformer.gaussian_rbf import GaussianRadialBasisLayer
from diffusion_edf.equiformer.tensor_product_rescale import LinearRS, irreps2gate, FullyConnectedTensorProductRescaleSwishGate
from diffusion_edf.equiformer.fast_activation import Activation, Gate

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.pooling import RadiusGraph, FpsPool

In [None]:
# Length bookkeeping: `unit_len` is the factor between raw coordinates and the
# rescaled coordinates the pipeline works in (see `rescale_fn` below).
unit_len = 0.01
scene_voxel_size = 0.01
grasp_voxel_size = 0.01

# Express both voxel sizes in rescaled units.
scene_voxel_size /= unit_len
grasp_voxel_size /= unit_len


# Forward / inverse coordinate scaling.
rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Forward / inverse colour normalisation.
normalize_color_fn = NormalizeColor(
    color_mean=torch.tensor([0.5, 0.5, 0.5]),
    color_std=torch.tensor([0.5, 0.5, 0.5]),
)
# Assuming NormalizeColor applies (x - mean) / std (TODO confirm), using
# mean' = -mean/std and std' = 1/std yields x * std + mean, i.e. the inverse.
recover_color_fn = NormalizeColor(
    color_mean=-normalize_color_fn.color_mean / normalize_color_fn.color_std,
    color_std=1 / normalize_color_fn.color_std,
)


def _make_proc_fn(voxel_size):
    """Build the rescale -> voxel-downsample -> colour-normalise pipeline.

    The rescale and colour transforms are the shared module-level instances;
    only the `Downsample` step is created per pipeline.
    """
    downsample_fn = Downsample(voxel_size=voxel_size, coord_reduction="average")
    return Compose([rescale_fn, downsample_fn, normalize_color_fn])


def _make_unproc_fn():
    """Build the pipeline that undoes colour normalisation, then the rescaling."""
    return Compose([recover_color_fn, recover_scale_fn])


scene_proc_fn = _make_proc_fn(scene_voxel_size)
scene_unproc_fn = _make_unproc_fn()
grasp_proc_fn = _make_proc_fn(grasp_voxel_size)
grasp_unproc_fn = _make_unproc_fn()

In [None]:
# Execution target. Falls back to CPU when CUDA is unavailable so this cell
# does not fail on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE(review): `compile` and `eval` shadow the Python builtins of the same
# name for the rest of the kernel session. The names are kept because the
# branches further down in this cell read them.
compile = True   # script the modules before moving them to `device`
eval = True      # call .eval() on every module after construction

# Irreps layouts, in e3nn notation: <multiplicity>x<degree><parity>.
irreps_input = o3.Irreps('3x0e')                          # three scalar input channels
irreps_node_embedding = o3.Irreps('128x0e+64x1e+32x2e')
irreps_node_attr = o3.Irreps('1x0e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')                   # used as the edge-attribute irreps below
irreps_block_output = o3.Irreps('128x0e+64x1e+32x2e')
number_of_basis = 128                                     # number of radial basis functions
fc_neurons = [64, 64]                                     # appended to [number_of_basis] for the edge MLPs
# Every other irreps in this cell is even parity ('e'). The head previously
# declared '16x1o'; tensor products of even-parity irreps only yield
# even-parity outputs, so odd-parity head channels could never be reached from
# these inputs. The degree-1 head channels are therefore declared as '1e'.
irreps_head = o3.Irreps('32x0e+16x1e+8x2e')
num_heads = 4
irreps_pre_attn = None
rescale_degree = False
# NOTE(review): `nonlinear_message` is not forwarded to any constructor in this
# cell; presumably `attn_type='mlp'` on TransBlock supersedes it — TODO confirm.
nonlinear_message = True
alpha_drop = 0.2
proj_drop = 0.0
drop_path_rate = 0.0
irreps_mlp_mid = o3.Irreps('384x0e+192x1e+96x2e')
norm_layer = 'layer'

# Neighbourhood radii: twice the voxel size (in rescaled units).
r_scene = scene_voxel_size * 2
# NOTE(review): `r_grasp` is not read anywhere in the visible cells.
r_grasp = grasp_voxel_size * 2



# Layer widths of the MLP fed with edge scalars: the radial-basis width
# followed by the hidden sizes. Each consumer gets its own list copy.
_edge_mlp_dims = [number_of_basis] + fc_neurons

# Radial basis over edge lengths; the cutoff sits just below the graph radius.
rbf = GaussianRadialBasisLayer(number_of_basis, cutoff=r_scene * 0.99)

# Maps raw node features to the node-embedding irreps.
node_enc = NodeEmbeddingNetwork(
    irreps_input=irreps_input,
    irreps_node_emb=irreps_node_embedding,
)

# Edge-degree embedding.
edge_deg_enc = EdgeDegreeEmbeddingNetwork(
    irreps_node_embedding=irreps_node_embedding,
    irreps_edge_attr=irreps_sh,
    fc_neurons=list(_edge_mlp_dims),
    avg_aggregate_num=4,  # NOTE(review): hard-coded; meaning not visible here — TODO confirm
)

# Single transformer block, configured from the hyperparameters above.
_block_irreps = dict(
    irreps_node_input=irreps_node_embedding,
    irreps_node_attr=irreps_node_attr,
    irreps_edge_attr=irreps_sh,
    irreps_node_output=irreps_block_output,
    irreps_head=irreps_head,
    irreps_pre_attn=irreps_pre_attn,
    irreps_mlp_mid=irreps_mlp_mid,
)
_block_regularisation = dict(
    alpha_drop=alpha_drop,
    proj_drop=proj_drop,
    drop_path_rate=drop_path_rate,
)
block1 = TransBlock(
    fc_neurons=list(_edge_mlp_dims),
    num_heads=num_heads,
    rescale_degree=rescale_degree,
    attn_type='mlp',
    norm_layer=norm_layer,
    **_block_irreps,
    **_block_regularisation,
)

# Graph builders: a plain radius graph (self-connections enabled) and a pooling
# layer that keeps half of the nodes, both using the scene radius.
no_pool = RadiusGraph(r=r_scene, self_connect=True, max_num_neighbors=1000)
fps_pool = FpsPool(ratio=0.5, random_start=True, r=r_scene, max_num_neighbors=1000)

# Optionally script the modules. The e3nn-aware `e3nn_script` is used for the
# node/edge encoders and the transformer block; plain `torch.jit.script` is
# used for the radial basis and the two graph builders.
if compile:
    rbf, no_pool, fps_pool = [
        torch.jit.script(_m) for _m in (rbf, no_pool, fps_pool)
    ]
    node_enc, edge_deg_enc, block1 = [
        e3nn_script(_m) for _m in (node_enc, edge_deg_enc, block1)
    ]

# Move every module (scripted or not) to the target device.
rbf, node_enc, edge_deg_enc, block1, no_pool, fps_pool = [
    _m.to(device)
    for _m in (rbf, node_enc, edge_deg_enc, block1, no_pool, fps_pool)
]

# Optionally switch all modules to evaluation mode.
if eval:
    for _module in (rbf, node_enc, edge_deg_enc, block1, no_pool, fps_pool):
        _module.eval()

# Load demo

In [None]:
# Load the demonstration sequences and work with the first step of the first one.
demo_list: List[DemoSequence] = load_demos(dir='demo/test_demo')
if len(demo_list) == 0:
    # Fail with an explicit message instead of a bare IndexError below.
    raise ValueError("load_demos(dir='demo/test_demo') returned no demo sequences")
demo_seq: DemoSequence = demo_list[0]

demo: TargetPoseDemo = demo_seq[0]
print(demo)

# Unpack the raw point clouds and the target poses of this demo step.
scene_raw: PointCloud = demo.scene_pc
grasp_raw: PointCloud = demo.grasp_pc
# NOTE(review): previously annotated as TargetPoseDemo, i.e. the type of the
# parent object `demo`. SE3 (imported above) is presumably the intended type —
# confirm against diffusion_edf.data.
target_poses: SE3 = demo.target_poses

# Preprocess both clouds and move them to the compute device.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

# Show the processed scene after undoing the colour/scale normalisation.
print(scene_proc)
go.Figure(scene_unproc_fn(scene_proc).plotly())

In [None]:
# Per-node tensors of the processed scene.
node_coord = scene_proc.points
node_feature = scene_proc.colors
num_nodes = len(node_coord)

# Every node gets batch index 0.
batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
# Constant placeholder node attribute with one channel per node.
node_attr = torch.ones((num_nodes, 1), device=device) # dummy

In [None]:
node_coord_1, node_feature_1, edge_src, edge_dst, degree, batch = fps_pool(node_coord_src = node_coord, node_feature_src = node_feature, batch_src = batch)
print(len(edge_dst), len(node_coord_1))

pooled_pcd = PointCloud(points=node_coord_1.cpu(), colors=node_feature_1.cpu())

print(pooled_pcd)
go.Figure(scene_unproc_fn(pooled_pcd).plotly())

In [None]:
node_coord_2, node_feature_2, edge_src, edge_dst, degree, batch = fps_pool(node_coord_src = node_coord_1, node_feature_src = node_feature_1, batch_src = batch)
print(len(edge_dst), len(node_coord_2))

pooled_pcd = PointCloud(points=node_coord_2.cpu(), colors=node_feature_2.cpu())

print(pooled_pcd)
go.Figure(scene_unproc_fn(pooled_pcd).plotly())

In [None]:
node_coord_3, node_feature_3, edge_src, edge_dst, degree, batch = fps_pool(node_coord_src = node_coord_2, node_feature_src = node_feature_2, batch_src = batch)
print(len(edge_dst), len(node_coord_3))

pooled_pcd = PointCloud(points=node_coord_3.cpu(), colors=node_feature_3.cpu())

print(pooled_pcd)
go.Figure(scene_unproc_fn(pooled_pcd).plotly())

In [None]:
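# NOTE(review): disabled draft of the forward pass. `edge_src`/`edge_dst` are
# last assigned by the third `fps_pool` call above, whereas `node_coord`,
# `node_feature` and `node_attr` describe the unpooled scene; the indices
# presumably refer to different node sets — reconcile before re-enabling.
# edge_vec = node_coord.index_select(0, edge_src) - node_coord.index_select(0, edge_dst)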
# edge_sh = o3.spherical_harmonics(l=irreps_sh, x=edge_vec, normalize=True, normalization='component')
# edge_length = edge_vec.norm(dim=1)
# edge_length_emb = rbf(edge_length)

# node_emb = node_enc(node_feature)
# edge_degree_emb = edge_deg_enc(node_input = node_emb, 
#                                edge_attr = edge_sh, 
#                                edge_scalars = edge_length_emb, 
#                                edge_src = edge_src, 
#                                edge_dst = edge_dst)

# node_input = node_emb + edge_degree_emb
# edge_scalars = edge_length_emb


# output_features =  block1(node_input = node_input, 
#                           node_attr = node_attr, 
#                           edge_src = edge_src, 
#                           edge_dst = edge_dst,
#                           edge_attr = edge_sh,
#                           edge_scalars = edge_scalars,
#                           batch = batch)