In [None]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr
from diffusion_edf import train_utils
from diffusion_edf import transforms
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched
from diffusion_edf.point_attentive_score_model import PointAttentiveScoreModel
from diffusion_edf.trainer import DiffusionEdfTrainer

torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
# --- Experiment configuration ------------------------------------------------
# Config file names are handed to the trainer together with their root folder.
configs_root_dir = 'configs/pick_point_attn'
train_configs_file = 'train_configs.yaml'
task_configs_file = 'task_configs.yaml'
# Checkpoint handed to `trainer.get_model` (the keyword is called
# `checkpoint_dir`, but the value points at a single `.pt` file).
checkpoint_path = 'runs/2023_05_29_18-25-54_StaticQueryPAscore/checkpoint/200.pt'

trainer = DiffusionEdfTrainer(
    configs_root_dir=configs_root_dir,
    train_configs_file=train_configs_file,
    task_configs_file=task_configs_file,
)
device = trainer.device

# Later cells read `trainer.testloader` / `trainer.trainloader`; presumably
# this call is what creates them -- TODO confirm against DiffusionEdfTrainer.
trainer._init_dataloaders()

# Build the score model from the checkpoint and call `.eval()` on it, since
# this notebook only samples from the model and never trains it.
score_model = trainer.get_model(
    checkpoint_dir=checkpoint_path,
    deterministic=False,
    device=trainer.device,
).eval()

# Main Loop

In [None]:
# Materialise the test split and pick a single demo batch to sample for.
# (Use `list(trainer.trainloader)` instead to inspect a training demo.)
dataset = list(trainer.testloader)
demo_batch = dataset[1]
B = len(demo_batch)
assert B == 1, "Batch training is not supported yet."

# Flatten the batch into scene points, grasp points and the target pose(s).
scene_input, grasp_input, T_target = train_utils.flatten_batch(demo_batch=demo_batch)  # T_target: (Nbatch, Ngrasps, 7)
T_target = T_target.squeeze(0)  # (B=1, N_poses=1, 7) -> (1,7)

# Diffusion time is pinned to 1.0 here. Random-time alternative:
# time = train_utils.random_time(min_time=1e-3, max_time=1.0, device=T_target.device)
time = torch.tensor([1.0], dtype=T_target.dtype, device=T_target.device)

# Sample one reference point, then diffuse the target pose for `time`.
x_ref, n_neighbors = train_utils.transform_and_sample_reference_points(
    T_target=T_target,
    scene_points=scene_input,
    grasp_points=grasp_input,
    contact_radius=trainer.contact_radius,
    n_samples_x_ref=1,
)
T0, delta_T, time_in, gt_score, gt_score_ref = train_utils.diffuse_T_target(
    T_target=T_target,
    x_ref=x_ref,
    time=time,
    lin_mult=score_model.lin_mult,
)

# Each ground-truth score is an (angular, linear) pair.
gt_ang_score, gt_lin_score = gt_score
gt_ang_score_ref, gt_lin_score_ref = gt_score_ref

In [None]:
# Encode scene and grasp point clouds once, without tracking gradients; the
# sampling loop reuses these features at every step.
with torch.no_grad():
    scene_features: FeaturedPoints = score_model.key_model(scene_input)
    # The score head takes a list (`key_pcd_multiscale`); here it has one entry.
    scene_out_multiscale: List[FeaturedPoints] = [scene_features]
    grasp_out: FeaturedPoints = score_model.query_model(grasp_input)

In [None]:
# Seed the pose to be denoised with the diffused pose `T0`.
T = T0
T_shape = T.shape

In [None]:
# Reverse-diffusion sampling: starting from the diffused pose `T0`, repeatedly
# query the score model and take a small stochastic step on the 7-D pose
# (quaternion in T[..., :4], position in T[..., 4:]).
N_steps = 1000
temp = 1.  # temperature: divides the score-driven drift term below

# Alternative log-spaced time schedule (would additionally need `import math`):
# t_schedule = torch.exp(torch.linspace(math.log(1.), math.log(1e-2), N_steps+1, device=device)).unsqueeze(-1)
# t_schedule = torch.cat([t_schedule, t_schedule[-1].expand(100,1)], dim=-2)

t_schedule = torch.linspace(1., 1e-3, N_steps+1, device=device).unsqueeze(-1)
noise_schedule_ang = torch.ones_like(t_schedule)
noise_schedule_lin = torch.ones_like(t_schedule) * score_model.lin_mult

# Always restart from the diffused pose. Without this, re-running the cell
# would silently continue from the previously denoised `T` (hidden notebook
# state) instead of sampling afresh from `T0`.
T = T0

# NOTE(review): `dt` is hard-coded and is not derived from the spacing of
# `t_schedule` ((1 - 1e-3) / N_steps); confirm this is intended before
# changing `N_steps`.
dt = torch.tensor([0.001], dtype=T.dtype, device=T.device)
sqrt_dt = torch.sqrt(dt)  # loop-invariant, computed once


for i in tqdm(range(len(t_schedule)-1)):
    t = t_schedule[i]
    noise_level_ang = noise_schedule_ang[i]
    noise_level_lin = noise_schedule_lin[i]

    # Score evaluation is inference only, so no autograd graph is needed.
    with torch.no_grad():
        (ang_score, lin_score) = score_model.score_head(Ts=T.view(-1,7),
                                                        key_pcd_multiscale=scene_out_multiscale,
                                                        query_pcd=grasp_out,
                                                        time = t.repeat(len(T)))
    # NOTE(review): the model output is rescaled by 1/sqrt(dt), not by a
    # function of the current time `t` -- confirm this scaling is intended.
    ang_score = ang_score / sqrt_dt
    lin_score = lin_score / sqrt_dt

    # Displacement = drift (score * sigma^2 / temp * dt) + noise (sigma * sqrt(dt) * N(0, I)).
    # The angular noise is drawn before the linear noise; keep this order so
    # seeded runs reproduce the same random stream.
    ang_disp = ang_score * (torch.square(noise_level_ang) / temp * dt) + (torch.randn_like(ang_score) * (noise_level_ang * sqrt_dt))
    lin_disp = lin_score * (torch.square(noise_level_lin) / temp * dt) + (torch.randn_like(lin_score) * (noise_level_lin * sqrt_dt))

    # `L` maps the angular displacement to a quaternion increment; presumably
    # `q_indices` / `q_factor` select and sign the current quaternion
    # components -- TODO confirm against the score model.
    L = T.detach()[...,score_model.q_indices] * score_model.q_factor
    q, x = T[...,:4], T[...,4:]
    dq = torch.einsum('...ij,...j->...i', L, ang_disp)
    # The linear displacement is passed through `quaternion_apply` with the
    # current orientation before being added to the position.
    dx = transforms.quaternion_apply(q, lin_disp)
    # First-order quaternion update, then re-normalise.
    q = transforms.normalize_quaternion(q + dq)
    T = torch.cat([q, x+dx], dim=-1)

    # Alternative exponential-map update (unused):
    # dT = transforms.se3_exp_map(torch.cat([lin_disp, ang_disp], dim=-1))
    # dT = torch.cat([transforms.matrix_to_quaternion(dT[..., :3, :3]), dT[..., :3, 3]], dim=-1)
    # T_next = transforms.multiply_se3(T_next, dT)

In [None]:
# Build the point clouds that the display cells below render.
with torch.no_grad():
    scene_pcd = PointCloud(points=scene_input.x, colors=scene_input.f)
    grasp_pcd = PointCloud(points=grasp_input.x, colors=grasp_input.f)

    # Scene merged with the grasp cloud placed at the target / diffused /
    # denoised pose respectively.
    target_pose_pcd = PointCloud.merge(scene_pcd, grasp_pcd.transformed(SE3(T_target), squeeze=True))
    diffused_pose_pcd = PointCloud.merge(scene_pcd, grasp_pcd.transformed(SE3(T0))[0])
    denoised_pose_pcd = PointCloud.merge(scene_pcd, grasp_pcd.transformed(SE3(T.detach()))[0])

    ##### Query Summary #####
    # Keep only the query points whose batch index is 0.
    query_point_batch = grasp_out.b.detach()
    batch_vis_idx = (query_point_batch == 0).nonzero().squeeze(-1)
    query_weight = grasp_out.w.detach()[batch_vis_idx]
    query_points = grasp_out.x.detach()[batch_vis_idx]

    N_repeat = 500
    r_query_ball = 0.5

    # Base colour scaled by each query weight, repeated N_repeat times per
    # query point and flattened to (N_repeat * N_query, 3).
    base_color = torch.tensor([0.01, 1., 1.], device=query_weight.device, dtype=query_weight.dtype)
    query_points_colors = (base_color.expand(N_repeat, 1, 3) * query_weight[None, :, None]).reshape(-1,3)

    # N_repeat random directions, normalised onto a sphere of radius
    # `r_query_ball`, used to offset every query point.
    ball = torch.randn(N_repeat,1,3, device=query_points.device, dtype=query_points.dtype)
    ball = ball/ball.norm(dim=-1, keepdim=True) * r_query_ball
    query_points = (query_points + ball).reshape(-1,3)
    # `query_points` / `query_points_colors` are not merged into any point
    # cloud in this cell.

    #########################
    # Scene points coloured by the key model's per-point weight `w`.
    scene_attn_src = scene_out_multiscale[0]
    scene_attn_pcd = PointCloud(
        points=scene_attn_src.x.detach().cpu(),
        colors=scene_attn_src.w.detach().cpu(),
        cmap='magma',
    )

In [None]:
# Render the scene with the grasp cloud at the denoised (sampled) pose.
# The commented lines are alternative views; uncomment one to swap the display.
# target_pose_pcd.show(point_size=2., width=800, height=800)
# diffused_pose_pcd.show(point_size=2., width=800, height=800)
denoised_pose_pcd.show(point_size=2., width=600, height=600)
# scene_attn_pcd.show(point_size=6., width=800, height=800)

In [None]:
# Render the scene points coloured by the key model's per-point weight.
scene_attn_pcd.show(point_size=6., width=600, height=600)

In [None]:
# Render the scene with the grasp cloud at the diffused starting pose T0,
# i.e. the pose the sampling loop started from.
# target_pose_pcd.show(point_size=2., width=800, height=800)
diffused_pose_pcd.show(point_size=2., width=600, height=600)
# denoised_pose_pcd.show(point_size=2., width=800, height=800)
# scene_attn_pcd.show(point_size=6., width=800, height=800)