In [None]:
import os
# Set before `torch` is imported in the next cell, presumably because the variable is
# read at import time (it appears to make the TorchScript JIT use NNC rather than nvFuser).
# NOTE(review): confirm the effect against the installed PyTorch version.
os.environ.update({"PYTORCH_JIT_USE_NNC_NOT_NVFUSER": "1"})

In [None]:
from typing import List, Tuple, Optional, Union, Iterable
import datetime

import plotly.graph_objects as go
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3
from open3d.visualization.tensorboard_plugin import summary
from torch.utils.tensorboard import SummaryWriter

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.score_model import ScoreModel
from diffusion_edf import transforms
from diffusion_edf.loss import SE3DenoisingDiffusion
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score



# Compact tensor printing for the debug output below: 4 decimals, never scientific notation.
torch.set_printoptions(sci_mode=False, precision=4)

In [None]:
# Model length unit: raw coordinates are divided by `unit_len` (rescale_fn) and
# multiplied back by it (recover_scale_fn).
unit_len = 0.01

# Voxel sizes are specified in raw units and converted to model units.
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len

rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Assuming NormalizeColor maps c -> (c - mean) / std, its inverse is another
# NormalizeColor with mean' = -mean / std and std' = 1 / std,
# because (c' + mean / std) * std = c' * std + mean.
normalize_color_fn = NormalizeColor(color_mean=torch.tensor([0.5, 0.5, 0.5]),
                                    color_std=torch.tensor([0.5, 0.5, 0.5]))
_color_mean = normalize_color_fn.color_mean
_color_std = normalize_color_fn.color_std
recover_color_fn = NormalizeColor(color_mean=-_color_mean / _color_std,
                                  color_std=1 / _color_std)


def _make_proc_fn(voxel_size: float) -> Compose:
    """Build a pipeline: rescale -> voxel-downsample (averaged coordinates) -> normalize colors."""
    return Compose([rescale_fn,
                    Downsample(voxel_size=voxel_size, coord_reduction="average"),
                    normalize_color_fn])


# Each *_unproc_fn undoes color normalization and then rescaling (downsampling is not undone).
scene_proc_fn = _make_proc_fn(scene_voxel_size)
scene_unproc_fn = Compose([recover_color_fn, recover_scale_fn])
grasp_proc_fn = _make_proc_fn(grasp_voxel_size)
grasp_unproc_fn = Compose([recover_color_fn, recover_scale_fn])

In [None]:
import math  # NOTE(review): `math` is not referenced in the visible cells.

# Runtime settings.
device = 'cuda:0'
compile = False  # forwarded as `compile_head` to ScoreModel; NOTE: shadows the builtin `compile()`.

# Equivariant feature layouts (e3nn irreps).
irreps_input = o3.Irreps('3x0e')                       # input features are point colors (scene_proc.colors)
irreps_node_embedding = o3.Irreps('32x0e+16x1e+8x2e')  # alternative: o3.Irreps('128x0e+64x1e+32x2e')
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')                # l = 0, 1, 2; presumably spherical-harmonic edge features

# Architecture hyperparameters.
# NOTE(review): the ScoreModel cell passes its own literals for most of these
# (e.g. pool_ratio=0.3, fc_neurons_init=[32, 16, 16]); confirm which values are intended.
fc_neurons = [128, 64, 64]
num_heads, n_scales = 4, 4
irreps_mlp_mid = 2
pool_ratio = 0.5

# Regularization.
alpha_drop, proj_drop, drop_path_rate = 0.2, 0.0, 0.0

In [None]:
# Score network configuration.
# NOTE(review): several literals here differ from the hyperparameter cell above
# (pool_ratio 0.3 vs 0.5, fc_neurons_init vs fc_neurons); confirm which is intended.
score_model_kwargs = dict(
    irreps_input=irreps_input,
    irreps_emb_init=irreps_node_embedding,
    irreps_sh=irreps_sh,
    fc_neurons_init=[32, 16, 16],
    num_heads=4,
    n_scales=4,
    pool_ratio=0.3,
    dim_mult=[1, 1, 2, 2],
    n_layers=2,
    gnn_radius=2.0,
    cutoff_radius=4.0,
    weight_feature_dim=20,
    query_downsample_ratio=0.7,
    device=device,
    deterministic=False,
    compile_head=compile,
)

# Inference only: place the model on `device` and switch it to eval mode.
score_model = ScoreModel(**score_model_kwargs).to(device).eval()
# optimizer = torch.optim.Adam(list(score_model.parameters()), lr=1e-4, betas=(0.9, 0.98), eps=1e-09, weight_decay=1e-4, amsgrad=True)

# Load demo

In [None]:
loss_fn = torch.nn.MSELoss(reduction='mean')  # NOTE(review): not used in the visible cells.

# Load demonstration sequences in file order. The identity collate_fn keeps each batch
# as a plain list of samples; all batches are flattened into a single list.
trainset = DemoSeqDataset(dataset_dir="demo/test_demo", annotation_file="data.yaml", device=device)
train_dataloader = DataLoader(trainset, shuffle=False, collate_fn=lambda batch: batch)
eval_data = [seq for batch in train_dataloader for seq in batch]

In [None]:
# Checkpoint to restore; set to None to evaluate randomly initialized weights.
checkpoint_file: Optional[str] = "runs/2023_04_20_21-28-29/checkpoint/1911.pt"

if checkpoint_file is not None:
    # map_location places every stored tensor directly on `device`, so loading does not
    # depend on which device the checkpoint was saved from.
    # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    score_model.load_state_dict(checkpoint['score_model_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    steps = checkpoint['steps']
    print(f"Successfully Loaded checkpoint @ epoch: {epoch} (steps: {steps})")
else:
    print("Initialize without loading from checkpoint.")
    epoch = 0
    steps = 0

In [None]:
# Pick the second target-pose demo of the first sequence.
demo_seq: DemoSequence = eval_data[0]
demo: TargetPoseDemo = demo_seq[1]

# Raw observations and ground-truth target poses.
scene_raw, grasp_raw, target_poses_raw = demo.scene_pc, demo.grasp_pc, demo.target_poses

# Model-space versions: point clouds go through the full pipelines, poses are only rescaled.
scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)
target_poses = rescale_fn(target_poses_raw).to(device)
T_target = target_poses.poses

# Reference points: the scene is first expressed in the target-pose frame (via the inverse pose).
scene_in_target_frame = PointCloud.transform_pcd(scene_proc, target_poses.inv())[0]
x_ref, n_neighbors = sample_reference_points(scene_in_target_frame.points, grasp_proc.points, r=3)

# Diffusion time drawn uniformly from [min_time, max_time].
min_time = 0.01 #1e-3
max_time = 0.03
time_ratio = min_time / max_time
time_in = (time_ratio + torch.rand(1, dtype=T_target.dtype, device=T_target.device) * (1 - time_ratio)) * max_time
lin_mult = 10.
std = torch.sqrt(time_in) * lin_mult  # scale for the last three (translational) score components
eps = time_in / 2                     # first three (angular) components are scaled by 2 * sqrt(eps)

# Perturb the target pose; with angular_first=True score vectors are ordered (angular, translational).
T, delta_T, gt_score, gt_score_ref = diffuse_isotropic_se3(T0=T_target, eps=eps, std=std, N=1, angular_first=True, double_precision=True)
T, delta_T, gt_score = (tensor.squeeze(0) for tensor in (T, delta_T, gt_score))
angular_scale = 2 * torch.sqrt(eps)
target_score = gt_score_ref * torch.tensor([angular_scale] * 3 + [std] * 3, device=eps.device, dtype=eps.dtype)

# Single-sample batch: every node is assigned to batch index 0.
key_feature, key_coord = scene_proc.colors, scene_proc.points
key_batch = torch.zeros(len(key_coord), device=device, dtype=torch.long)
query_feature, query_coord = grasp_proc.colors, grasp_proc.points
query_batch = torch.zeros(len(query_coord), device=device, dtype=torch.long)

In [None]:
# score, query, query_info, key_info = score_model(T=T,
#                                                     key_feature=key_feature, key_coord=key_coord, key_batch=key_batch,
#                                                     query_feature=query_feature, query_coord=query_coord, query_batch=query_batch,
#                                                     info_mode='NONE', angular_first= True, time=time_in)
# score_ref = adjoint_inv_tr_isotropic_se3_score(x_ref=-x_ref, score=score, angular_first=True)

In [None]:
# Encode the grasp (query) and scene (key) point clouds once; the denoising loop below
# reuses both results unchanged at every step.
with torch.set_grad_enabled(False):
    query, query_info = score_model._get_query(
        node_feature=query_feature,
        node_coord=query_coord,
        batch=query_batch,
        info_mode='NONE',
    )
    key_gnn_outputs = score_model.key_model.get_gnn_outputs(
        node_feature=key_feature,
        node_coord=key_coord,
        batch=key_batch,
    )

In [None]:
# Annealed Langevin denoising of the diffused pose T. The conditioning time t starts at
# `time_in` and decreases linearly to t_max / N_anneal over N_anneal steps.
t_max = time_in
anneal_mult = 3
N_anneal = 100
# Langevin step size; anneal_mult makes it larger than the per-step decrease of t.
dt = torch.tensor([t_max * anneal_mult / N_anneal], dtype=T.dtype, device=T.device)
print(t_max.item(), dt.item())

T_next = T
for n in tqdm(range(N_anneal)):
    t = t_max - dt/anneal_mult * n

    with torch.no_grad():
        time_emb = score_model.get_time_emb(t)
        score, key_extractor_info = score_model.get_score(T=T_next, query=query, key_gnn_outputs=key_gnn_outputs, time_emb = time_emb, angular_first=True)

        # Assuming the model was trained on targets scaled like `target_score` (score times
        # the noise scale of its conditioning time), undo that scaling with the scale of the
        # *current* time t. The fixed initial time_in would mis-scale the score once t < time_in.
        std = torch.sqrt(t) * lin_mult
        eps = t / 2
        score_scale = torch.tensor([2*torch.sqrt(eps), 2*torch.sqrt(eps), 2*torch.sqrt(eps), std, std, std], device=eps.device, dtype=eps.dtype)
        # Langevin step: drift (dt / 2) * score plus Gaussian noise with variance dt.
        disp = score / score_scale * (0.5 * dt)
        disp = disp + (torch.randn_like(score) * torch.sqrt(dt))

        # Apply the displacement. L linearly maps the angular part disp[..., :3] to a quaternion
        # increment (first-order update, then renormalized). The translational part
        # disp[..., 3:] is rotated by the current orientation q before being added.
        L = T_next.detach()[...,score_model.q_indices] * score_model.q_factor
        q, x = T_next[...,:4], T_next[...,4:]
        dq = torch.einsum('...ij,...j->...i', L, disp[...,:3])
        dx = transforms.quaternion_apply(q, disp[...,3:])
        q_next = transforms.normalize_quaternion(q + dq)
        T_next = torch.cat([q_next, x+dx], dim=-1)

    # Alternative update via the exact SE(3) exponential map:
    # dT = transforms.se3_exp_map(torch.cat([disp[..., 3:], disp[..., :3]], dim=-1))
    # dT = torch.cat([transforms.matrix_to_quaternion(dT[..., :3, :3]), dT[..., :3, 3]], dim=-1)
    # T_next = transforms.multiply_se3(T_next, dT)

In [None]:
def _scene_with_grasp_at(pose: SE3) -> PointCloud:
    """Merge the raw scene with the raw grasp point cloud transformed by `pose` (raw units)."""
    return PointCloud.merge(scene_raw, grasp_raw.transformed(pose)[0])


# Model-space poses are mapped back to raw units with recover_scale_fn before placing the grasp.
target_pose_pcd = _scene_with_grasp_at(target_poses_raw)
diffused_pose_pcd = _scene_with_grasp_at(recover_scale_fn(SE3(T.detach())))
denoised_pose_pcd = _scene_with_grasp_at(recover_scale_fn(SE3(T_next.detach())))

In [None]:
# Render the scene with the grasp placed at the denoised pose (800 x 800 view).
denoised_pose_pcd.show(height=800, width=800)

In [None]:
# Render the scene with the grasp placed at the diffused (noisy) starting pose for comparison.
diffused_pose_pcd.show(height=800, width=800)