In [None]:
import os
# Written to the process environment before `torch` is imported further down.
# NOTE(review): judging by its name, this flag makes the TorchScript JIT use the
# NNC fuser instead of NVFuser — confirm against the installed PyTorch version,
# and keep this line above `import torch` unless verified that order is irrelevant.
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr, detach_featured_points
from diffusion_edf import train_utils
from diffusion_edf import transforms
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched
from diffusion_edf.point_attentive_score_model import PointAttentiveScoreModel
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose

# Print tensors with 4 digits of precision and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
# --- Experiment configuration -------------------------------------------------
# The trainer is constructed from a config directory plus two YAML file names.
configs_root_dir = 'configs/pick_point_attn'
train_configs_file = 'train_configs.yaml'
task_configs_file = 'task_configs.yaml'

trainer = DiffusionEdfTrainer(
    configs_root_dir=configs_root_dir,
    train_configs_file=train_configs_file,
    task_configs_file=task_configs_file,
)
device = trainer.device

# Later cells read `trainer.trainloader`.
# NOTE(review): presumably that attribute is created by this call — confirm.
trainer._init_dataloaders()

# --- Score model --------------------------------------------------------------
# Load the model from a saved checkpoint and switch it to evaluation mode.
checkpoint_path = 'runs/2023_05_30_17-05-29_MultiphaseTraining/checkpoint/200.pt'
score_model = trainer.get_model(
    checkpoint_dir=checkpoint_path,
    deterministic=False,
    device=trainer.device,
)
score_model = score_model.eval()

# Main Loop

In [None]:
# --- Pick one demonstration ---------------------------------------------------
# Materialize every batch of the training loader, then take the batch at index 1.
# (Use the commented line instead to draw from the test loader.)
dataset = list(trainer.trainloader)
# dataset = list(trainer.testloader)
demo_batch = dataset[1]
B = len(demo_batch)
assert B == 1, "Batch training is not supported yet."


# --- Alternative initialization (disabled) ------------------------------------
# Starts from the demonstrated target pose diffused at `trainer.t_max` instead of
# a random pose. Kept for reference.
# scene_input, grasp_input, T_target = train_utils.flatten_batch(demo_batch=demo_batch) # T_target: (Nbatch, Ngrasps, 7)
# T_target = T_target.squeeze(0) # (B=1, N_poses=1, 7) -> (1,7) 

# time = torch.tensor([trainer.t_max], dtype=T_target.dtype, device=T_target.device)
# print(f"Diffusion time: {time.item()}")
# x_ref, n_neighbors = train_utils.transform_and_sample_reference_points(T_target=T_target,
#                                                                        scene_points=scene_input,
#                                                                        grasp_points=grasp_input,
#                                                                        contact_radius=trainer.contact_radius,
#                                                                        n_samples_x_ref=1)
# T0, delta_T, time_in, gt_score, gt_score_ref = train_utils.diffuse_T_target(T_target=T_target, 
#                                                                             x_ref=x_ref, 
#                                                                             time=time, 
#                                                                             lin_mult=score_model.lin_mult)
# (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = gt_score, gt_score_ref

# --- Random initial pose ------------------------------------------------------
# The third value returned by flatten_batch is not needed here and is discarded.
scene_input, grasp_input, _ = train_utils.flatten_batch(demo_batch=demo_batch)

# T0 is laid out as [quaternion (4), position (3)]: one random quaternion followed
# by one position drawn uniformly between the per-coordinate min and max of the
# scene points.
init_quat = transforms.random_quaternions(1, device=device)
scene_xyz = scene_input.x
bbox_low = scene_xyz.min(dim=0).values
bbox_high = scene_xyz.max(dim=0).values
init_pos = torch.distributions.Uniform(bbox_low, bbox_high).sample(sample_shape=(1,))
T0 = torch.cat([init_quat, init_pos], dim=-1)


# --- Point clouds for visualization -------------------------------------------
# Point features (`f`) are passed as the colors.
scene_pcd = PointCloud(points=scene_input.x, colors=scene_input.f)
grasp_pcd = PointCloud(points=grasp_input.x, colors=grasp_input.f)
# Optional preview of the grasp placed at T0 inside the scene:
# diffused_pose_pcd = PointCloud.merge(
#     scene_pcd,
#     grasp_pcd.transformed(SE3(T0))[0],
# )
# diffused_pose_pcd.show(point_size=2., width=600, height=600)

In [None]:
# Run the key (scene) and query (grasp) models once, without tracking gradients;
# the sampling loop reuses these outputs at every step.
with torch.no_grad():
    scene_features = score_model.key_model(scene_input)
    scene_out_multiscale: List[FeaturedPoints] = [scene_features]
    grasp_out: FeaturedPoints = score_model.query_model(grasp_input)

# Display the key-model output points, colored by their `w` attribute.
encoded_scene = scene_out_multiscale[0]
scene_attn_pcd = PointCloud(
    points=encoded_scene.x.detach().cpu(),
    colors=encoded_scene.w.detach().cpu(),
    cmap='magma',
)
scene_attn_pcd.show(point_size=6., width=600, height=600)

In [None]:
# Langevin-style sampling of a pose, starting from T0 and annealing the diffusion
# time through each phase of `trainer.diffusion_schedules`.
# A pose is stored as [quaternion (4), position (3)].
T = T0.clone().detach()
T_shape = T.shape

N_steps_per_phase = [500, 500, 500, 500]  # number of updates per schedule phase
steps_per_record = 10                     # keep every 10th pose for visualization
temp = 0.5                                # divides the score (drift) term below

Ts = [T.clone().detach()]  # recorded trajectory, starting with the initial pose
steps = 0
# zip() stops at the shorter of the two sequences, so phases beyond
# len(N_steps_per_phase) are skipped.
for n_steps, schedule in zip(N_steps_per_phase, trainer.diffusion_schedules):
    # Diffusion times visited in this phase, from schedule[0] to schedule[1].
    t_schedule = torch.linspace(schedule[0], schedule[1], n_steps, device=device)
    # step_schedule = t_schedule/t_schedule.max() * 0.1
    # Constant integration step for the phase: 0.001 * (largest time of the phase).
    step_schedule = torch.ones_like(t_schedule) * 0.001 * t_schedule.max()
    t_schedule = t_schedule.unsqueeze(-1)  # (n_steps,) -> (n_steps, 1)

    for i in tqdm(range(len(t_schedule))):
        t = t_schedule[i]
        # FIX: the step size comes from step_schedule. Previously this read
        # `t_schedule[i]`, which made the step equal to the diffusion time itself
        # and left step_schedule unused.
        dt = step_schedule[i]
        with torch.no_grad():
            (ang_score, lin_score) = score_model.score_head(Ts=T.view(-1,7), 
                                                            key_pcd_multiscale=scene_out_multiscale,
                                                            query_pcd=grasp_out,
                                                            time = t.repeat(len(T)))
        # Rescale the model outputs by 1/sqrt(t); the linear part is additionally
        # divided by the model's `lin_mult`.
        ang_score = ang_score / torch.sqrt(t)
        lin_score = lin_score / torch.sqrt(t) / score_model.lin_mult

        # Drift (score * dt / (2 * temp)) plus Gaussian noise with std sqrt(dt).
        ang_disp = ang_score * dt / (2*temp) + (torch.randn_like(ang_score) * torch.sqrt(dt))
        lin_disp = lin_score * dt / (2*temp) + (torch.randn_like(lin_score) * torch.sqrt(dt))

        # NOTE(review): L looks like the matrix that turns an angular displacement
        # into a quaternion increment (built from the pose entries selected by
        # `q_indices`, scaled by `q_factor`) — confirm against the score model.
        L = T.detach()[...,score_model.q_indices] * score_model.q_factor
        q, x = T[...,:4], T[...,4:]
        dq = torch.einsum('...ij,...j->...i', L, ang_disp)
        # The linear displacement is rotated by the current orientation before
        # being added to the position.
        dx = transforms.quaternion_apply(q, lin_disp)
        q = transforms.normalize_quaternion(q + dq)
        T = torch.cat([q, x+dx], dim=-1)

        # Alternative update via the SE(3) exponential map (disabled):
        # dT = transforms.se3_exp_map(torch.cat([lin_disp, ang_disp], dim=-1))
        # dT = torch.cat([transforms.matrix_to_quaternion(dT[..., :3, :3]), dT[..., :3, 3]], dim=-1)
        # T = transforms.multiply_se3(T, dT)
        steps += 1
        if steps % steps_per_record == 0:
            Ts.append(T.clone().detach())

# Always record the final pose. When the total step count is a multiple of
# steps_per_record this repeats the last recorded entry.
Ts.append(T.clone().detach())

In [None]:
# Development aid (disabled): reload the visualization module after editing it.
# import importlib
# import diffusion_edf.visualize
# importlib.reload(diffusion_edf.visualize)
# visualize_pose = diffusion_edf.visualize.visualize_pose

# Stack every recorded pose along dim 0 and wrap the result as SE3 for plotting.
pose_trajectory = SE3(torch.cat(Ts, dim=0).detach())
fig_grasp, fig_sample = visualize_pose(
    scene_pcd,
    grasp_pcd,
    poses=pose_trajectory,
    point_size=3.0,
    width=1000,
    height=1000,
)
fig_sample.show()