In [None]:
import os

# The variable name indicates it asks TorchScript to use the NNC fuser instead
# of nvFuser. It is set in the very first cell, before `torch` is imported,
# presumably so the setting is seen when torch initializes — TODO confirm for
# the PyTorch version in use.
os.environ.update({"PYTORCH_JIT_USE_NNC_NOT_NVFUSER": "1"})

In [None]:
from typing import List, Tuple, Optional, Union, Iterable

import plotly.graph_objects as go

import torch
from torchvision.transforms import Compose

from e3nn import o3

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.query_model import QueryModel

In [None]:
# Geometry is processed in rescaled units: Rescale(1/unit_len) is applied to the
# point clouds, so every length given in the original units (presumably metres)
# is divided by unit_len as well.
unit_len = 0.01
scene_voxel_size = 0.01 / unit_len   # downsampling voxel size, rescaled units
grasp_voxel_size = 0.01 / unit_len   # downsampling voxel size, rescaled units

# Forward / inverse scaling between original and rescaled units.
rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Assuming NormalizeColor computes (c - mean) / std, its inverse c * std + mean
# is again a NormalizeColor with mean' = -mean / std and std' = 1 / std.
normalize_color_fn = NormalizeColor(color_mean=torch.tensor([0.5, 0.5, 0.5]),
                                    color_std=torch.tensor([0.5, 0.5, 0.5]))
recover_color_fn = NormalizeColor(
    color_mean=-normalize_color_fn.color_mean / normalize_color_fn.color_std,
    color_std=1 / normalize_color_fn.color_std,
)

# Preprocessing: rescale -> voxel downsample (averaging coordinates) -> normalize colors.
scene_proc_fn = Compose([
    rescale_fn,
    Downsample(voxel_size=scene_voxel_size, coord_reduction="average"),
    normalize_color_fn,
])
grasp_proc_fn = Compose([
    rescale_fn,
    Downsample(voxel_size=grasp_voxel_size, coord_reduction="average"),
    normalize_color_fn,
])

# Undo preprocessing for visualization: recover colors first, then the scale.
scene_unproc_fn = Compose([recover_color_fn, recover_scale_fn])
grasp_unproc_fn = Compose([recover_color_fn, recover_scale_fn])

In [None]:
import math

device = 'cuda:0'
eval = True
compile = True

irreps_input = o3.Irreps('3x0e')
irreps_node_embedding = o3.Irreps('32x0e+16x1e+8x2e') #o3.Irreps('128x0e+64x1e+32x2e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')
fc_neurons = [128, 64, 64]
num_heads = 4
alpha_drop = 0.2
proj_drop = 0.0
drop_path_rate = 0.0
irreps_mlp_mid = 2
n_scales = 4
pool_ratio = 0.5

In [None]:
query_model = QueryModel(irreps_input = irreps_input,
                         irreps_emb_init = irreps_node_embedding,
                         irreps_sh = irreps_sh,
                         fc_neurons_init = [32, 16, 16],
                         num_heads = 4,
                         n_scales = 4,
                         pool_ratio = 0.5,
                         dim_mult = [1, 2, 3, 4],
                         n_layers = 2,
                         gnn_radius = 2.0,
                         cutoff_radius = 3.0,
                         weight_feature_dim = 20,
                         query_downsample_ratio = 0.3,
                         deterministic = True,
                         compile_head = False)

if compile:
    query_model = torch.jit.script(query_model)
query_model = query_model.to(device)
if eval:
    query_model = query_model.eval()

# Load a demo and preprocess its point clouds

In [None]:
# Which demo to inspect: index of the DemoSequence in the directory and of the
# TargetPoseDemo inside that sequence.
demo_seq_idx = 1
demo_idx = 1

demo_list: List[DemoSequence] = load_demos(dir='demo/test_demo')
demo_seq: DemoSequence = demo_list[demo_seq_idx]

demo: TargetPoseDemo = demo_seq[demo_idx]
print(demo)

scene_raw: PointCloud = demo.scene_pc
grasp_raw: PointCloud = demo.grasp_pc
# NOTE(review): this was annotated as TargetPoseDemo, which is unlikely to be
# right for an attribute of a TargetPoseDemo. SE3 (imported above, otherwise
# unused) is the presumed type — confirm against diffusion_edf.data.
target_poses: SE3 = demo.target_poses

# Preprocess (rescale, voxel-downsample, normalize colors) and move to the device.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)

# Query-model inputs come from the grasp cloud only: colors are the node
# features, points are the node coordinates, and every node is in batch 0.
node_feature = grasp_proc.colors
node_coord = grasp_proc.points
batch = torch.zeros(len(node_coord), device=device, dtype=torch.long)


# go.Figure(grasp_unproc_fn(grasp_proc).plotly())

In [None]:
# Forward pass on the untransformed grasp cloud. The first output is a 4-tuple,
# unpacked below as (weights, features, coordinates, batch indices) of the queries.
query, query_info = query_model(node_feature=node_feature,
                                node_coord=node_coord,
                                batch=batch)
(query_weight,
 query_feature,
 query_coord,
 query_batch) = query

In [None]:
# Wrap the query points in a PointCloud (placeholder all-zero colors) and undo
# the grasp preprocessing so they share the frame of the plotted grasp cloud.
query_pcd = PointCloud(points=query_coord.detach().cpu(),
                       colors=torch.zeros(len(query_coord), 3))
query_pcd = grasp_unproc_fn(query_pcd)

# Uniform opacity; use e.g. `query_weight ** 1` instead to shade queries by weight.
query_opacity = torch.ones_like(query_weight).cpu()
query_trace = PointCloud.points_to_plotly(pcd=query_pcd, point_size=7.0,
                                          opacity=query_opacity / query_opacity.max())

# Overlay the queries on the grasp cloud. grasp_proc was produced by
# grasp_proc_fn, so it is un-preprocessed with the matching grasp_unproc_fn.
go.Figure([grasp_unproc_fn(grasp_proc).plotly(), query_trace])

# Rotate: check SE(3) equivariance of the query model under a random rigid transform

In [None]:
from diffusion_edf.quaternion_utils import quaternion_apply, random_quaternions

# Random rigid transform applied to the grasp cloud. Seeding makes the
# equivariance checks below reproducible across Restart & Run All (assuming
# random_quaternions draws from torch's RNG; torch.manual_seed also seeds CUDA).
rot_seed = 0
torch.manual_seed(rot_seed)
rot = random_quaternions(1, device=device)
rot = rot / rot.norm(dim=-1, keepdim=True)   # re-normalize to a unit quaternion
trans = torch.randn(3, device=device)        # translation, in the same rescaled units as node_coord
node_coord_rot = quaternion_apply(rot, node_coord) + trans
# Node features are scalar (3x0e) colors, so a rigid transform leaves them unchanged.
node_feature_rot = node_feature

In [None]:
# Same forward pass on the transformed grasp cloud. The node count is unchanged,
# so the all-zero `batch` vector from the untransformed input is reused.
query_rot, query_info_rot = query_model(node_feature=node_feature_rot,
                                        node_coord=node_coord_rot,
                                        batch=batch)
(query_weight_rot,
 query_feature_rot,
 query_coord_rot,
 query_batch_rot) = query_rot

In [None]:
irrep_transform = TransformFeatureQuaternion(irreps = o3.Irreps(query_model.irreps_emb), device=device)

# Equivariance check for query features: rotating the features of the original
# input by `rot` should reproduce the features computed from the transformed input.
# NOTE(review): irrep_transform(...) is indexed with [0]; presumably its first
# element holds the transformed features — confirm.
feature_rot_expected = irrep_transform(query_feature, rot)[0]
feature_rot_actual = query_feature_rot
# Guard against torch.isclose silently broadcasting mismatched shapes.
assert feature_rot_expected.shape == feature_rot_actual.shape, \
    (feature_rot_expected.shape, feature_rot_actual.shape)
feature_isclose = torch.isclose(feature_rot_expected, feature_rot_actual, atol=0.001, rtol=0.001)
# Fraction of matching entries; numel()/mean also works on non-contiguous tensors, unlike view(-1).
# Slight non-equivariance comes from FPS downsampling algorithm.
print(f"Feature Equivariance ratio: {feature_isclose.float().mean().item()}")

In [None]:
# Query weights are compared directly (no transform applied), i.e. they should be
# invariant. Report the fraction of matches instead of dumping the full boolean tensor.
assert query_weight_rot.shape == query_weight.shape, (query_weight_rot.shape, query_weight.shape)
weight_isclose = torch.isclose(query_weight_rot, query_weight, atol=0.001, rtol=0.001)
print(f"Weight invariance ratio: {weight_isclose.float().mean().item()}")

In [None]:
# Query coordinates should be equivariant: queries of the transformed input should
# equal the transformed queries of the original input. Report the fraction of
# matching entries instead of dumping the full boolean tensor.
coord_rot_expected = quaternion_apply(rot, query_coord) + trans
assert query_coord_rot.shape == coord_rot_expected.shape, (query_coord_rot.shape, coord_rot_expected.shape)
coord_isclose = torch.isclose(query_coord_rot, coord_rot_expected, atol=0.001, rtol=0.001)
print(f"Coordinate equivariance ratio: {coord_isclose.float().mean().item()}")

In [None]:
# Wrap the transformed-input query points in a PointCloud (placeholder all-zero
# colors) and undo the grasp preprocessing.
query_pcd_rot = PointCloud(points=query_coord_rot.detach().cpu(),
                           colors=torch.zeros(len(query_coord_rot), 3))
query_pcd_rot = grasp_unproc_fn(query_pcd_rot)

# Uniform opacity; use e.g. `query_weight_rot ** 1` instead to shade queries by weight.
query_opacity_rot = torch.ones_like(query_weight_rot).cpu()
query_trace_rot = PointCloud.points_to_plotly(pcd=query_pcd_rot, point_size=7.0,
                                              opacity=query_opacity_rot / query_opacity_rot.max())

# Apply the same rigid transform to the grasp cloud for the overlay. The pose is
# [quaternion, translation] concatenated along the last axis (shape (1, 7)
# assuming rot is (1, 4)).
rot_pose = torch.cat([rot, trans.unsqueeze(0)], dim=-1)
grasp_proc_rot = grasp_proc.transformed(rot_pose)[0]
go.Figure([grasp_unproc_fn(grasp_proc_rot).plotly(), query_trace_rot])