In [None]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from open3d.visualization.tensorboard_plugin import summary
from torch.utils.tensorboard import SummaryWriter

from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr
from diffusion_edf import train_utils
from diffusion_edf import preprocess
from diffusion_edf import transforms
from diffusion_edf.feature_extractor import UnetFeatureExtractor
from diffusion_edf.tensor_field import TensorField
from diffusion_edf.radial_func import SinusoidalPositionEmbeddings
from diffusion_edf.equivariant_score_model import ScoreModel
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched

torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
# ---- Run-time switches -------------------------------------------------------
# NOTE(review): `eval` and `compile` shadow Python builtins of the same name.
# They are kept because later cells read these exact names.
device = 'cuda:0'
eval = True
compile = False

# ---- Config file locations ---------------------------------------------------
model_configs_dir = 'configs/test/model_configs.yaml'
eval_configs_dir = 'configs/test/eval_configs.yaml'
task_configs_dir = 'configs/test/task_configs.yaml'

# Parse the three YAML files in a fixed order: model, eval, task.
# NOTE(review): yaml.FullLoader can construct more than plain data types;
# consider yaml.safe_load if these files could ever come from an untrusted source.
_loaded_configs = []
for _config_path in (model_configs_dir, eval_configs_dir, task_configs_dir):
    with open(_config_path) as file:
        _loaded_configs.append(yaml.load(file, Loader=yaml.FullLoader))
model_configs, eval_configs, task_configs = _loaded_configs

# Add a final 'Rescale' preprocessing step whose factor is the reciprocal of
# the task's unit length.
_rescale_step = {
    'name': 'Rescale',
    'kwargs': {'rescale_factor': 1/task_configs['unit_length']},
}
eval_configs['preprocess_config'].append(_rescale_step)

# Load demo

In [None]:
# Instantiate every preprocessing step listed in the eval config: each entry
# names an attribute of `preprocess` and supplies its keyword arguments.
_preprocess_steps = [
    getattr(preprocess, proc['name'])(**proc['kwargs'])
    for proc in eval_configs['preprocess_config']
]
proc_fn = Compose(_preprocess_steps)

# The collate function is selected by task type and applies `proc_fn`.
collate_fn = train_utils.get_collate_fn(
    task=eval_configs['task_type'],
    proc_fn=proc_fn,
)

# Demonstration dataset and its (shuffled) loader.
trainset = DemoSeqDataset(
    dataset_dir=eval_configs['dataset_dir'],
    annotation_file=eval_configs['annotation_file'],
    device=device,
)
train_dataloader = DataLoader(
    trainset,
    batch_size=eval_configs['n_batches'],
    shuffle=True,
    collate_fn=collate_fn,
)

# Load Model

In [None]:
# Build the score model from the YAML model config (non-deterministic mode),
# then move it onto the configured device.
score_model = ScoreModel(deterministic=False, **model_configs)
score_model = score_model.to(device=device)

# Compilation is not supported in this notebook.
if compile:
    raise NotImplementedError

# Switch the model to evaluation mode when requested.
if eval:
    score_model = score_model.eval()

# optimizer = torch.optim.Adam(list(score_model.parameters()), **train_configs['optimizer_kwargs'])

In [None]:
# Optional checkpoint path and the task type, both taken from the eval config.
checkpoint_file: Optional[str] = eval_configs['checkpoint_file']
task_type = eval_configs['task_type']

if checkpoint_file is not None:
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved from a different (or currently unavailable) device still loads.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    score_model.load_state_dict(checkpoint['score_model_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    steps = checkpoint['steps']
    print(f"Successfully Loaded checkpoint @ epoch: {epoch} (steps: {steps})")
else:
    # No checkpoint configured: keep the freshly initialized weights.
    print("Initialize without loading from checkpoint.")
    epoch = 0
    steps = 0

# Main Loop

In [None]:
# Radius read from the eval config; it is passed as `r` to
# sample_reference_points when reference points are drawn in the main loop.
contact_radius = eval_configs['contact_radius']

In [None]:
# Take the first batch from the loader, diffuse its target pose, and stop.
for demo_batch in train_dataloader:
    # Only single-demonstration batches are handled.
    B = len(demo_batch)
    assert B == 1, "Batch training is not supported yet."

    # flatten_batch yields T_target as (Nbatch, Ngrasps, 7).
    scene_input, grasp_input, T_target = train_utils.flatten_batch(demo_batch=demo_batch)
    # Drop the batch axis: (B=1, N_poses=1, 7) -> (1,7)
    T_target = T_target.squeeze(0)

    # Diffusion time is pinned to 1.0 here. Random-time alternatives kept for reference:
    # min_time = torch.tensor([1e-3], dtype=T_target.dtype, device=T_target.device)
    # max_time = torch.tensor([1.0], dtype=T_target.dtype, device=T_target.device)
    # time_in = (min_time/max_time + torch.rand(1, dtype=T_target.dtype, device=T_target.device) * (1-min_time/max_time))*max_time   # Shape: (1,)
    # #time_in = torch.exp(torch.rand_like(max_time) * (torch.log(max_time)-torch.log(min_time)) + torch.log(min_time))              # Shape: (1,)
    time_in = torch.tensor([1.0], dtype=T_target.dtype, device=T_target.device)

    # Noise parameters derived from the diffusion time; each has shape (1,).
    eps = time_in / 2
    std = torch.sqrt(time_in) * score_model.lin_mult

    # Scene point cloud transformed by the inverse of the target pose.
    _scene_cloud = PointCloud(points=scene_input.x, colors=scene_input.f)
    _scene_cloud_inv = _scene_cloud.transformed(SE3(T_target).inv(), squeeze=True)

    # Draw one reference point using the transformed scene points and the grasp points.
    x_ref, n_neighbors = sample_reference_points(
        src_points=_scene_cloud_inv.points,
        dst_points=grasp_input.x,
        r=contact_radius,
        n_samples=1,
    )

    # Non-batched variant kept for reference:
    # T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = diffuse_isotropic_se3(T0 = T_target, eps=eps, std=std, x_ref=x_ref, double_precision=True)
    _diffused = diffuse_isotropic_se3_batched(
        T0=T_target, eps=eps, std=std, x_ref=x_ref, double_precision=True
    )
    T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = _diffused

    # Remove the second-to-last axis from every output.
    T = T.squeeze(-2)
    delta_T = delta_T.squeeze(-2)
    gt_ang_score = gt_ang_score.squeeze(-2)
    gt_lin_score = gt_lin_score.squeeze(-2)
    gt_ang_score_ref = gt_ang_score_ref.squeeze(-2)
    gt_lin_score_ref = gt_lin_score_ref.squeeze(-2)
    # T: (nT, 7) || delta_T: (nT, 7) || gt_*_score_*: (nT, 3) ||
    # Note that nT = n_samples_x_ref * nT_target  ||   nT_target = 1

    # A single batch is all this notebook needs.
    break

In [None]:
# Run the two feature extractors once, without tracking gradients.
# The results are reused on every step of the sampling loop, where `scene_out`
# is passed as `key_pcd` and `grasp_out` as `query_pcd` to score_model.score_head.
with torch.no_grad():
    # Features extracted from the scene input by the key extractor.
    scene_out: FeaturedPoints = score_model.key_feature_extractor(scene_input)
    # Features extracted from the grasp input by the query extractor.
    grasp_out: FeaturedPoints = score_model.query_feature_extractor(grasp_input)

In [None]:
# ---- Sampling hyper-parameters -----------------------------------------------
N_steps = 500
temp = 1.

# Alternative (log-spaced) time schedule, kept for reference:
# t_schedule = torch.exp(torch.linspace(math.log(1.), math.log(1e-2), N_steps+1, device=device)).unsqueeze(-1)
# t_schedule = torch.cat([t_schedule, t_schedule[-1].expand(100,1)], dim=-2)

# Linear time schedule from 1.0 to 1e-3 with N_steps+1 entries, each of shape (1,).
t_schedule = torch.linspace(1., 1e-3, N_steps+1, device=device).unsqueeze(-1)
# Per-step noise levels: ones for the angular part, lin_mult for the linear part.
noise_schedule_ang = torch.ones_like(t_schedule)
noise_schedule_lin = torch.ones_like(t_schedule) * score_model.lin_mult
# Fixed integration step size and its square root (loop-invariant).
dt = torch.tensor([0.001], dtype=T.dtype, device=T.device)
_sqrt_dt = torch.sqrt(dt)


T_shape = T.shape
# Start the iteration from the diffused pose(s).
T_next = T
for i in tqdm(range(N_steps)):
    t = t_schedule[i]
    noise_level_ang = noise_schedule_ang[i]
    noise_level_lin = noise_schedule_lin[i]

    # Query the score head at the current poses and time, without gradients.
    with torch.no_grad():
        ang_score, lin_score = score_model.score_head(
            Ts=T_next.view(-1,7),
            key_pcd=scene_out,
            query_pcd=grasp_out,
            time=t.repeat(len(T_next)),
        )
    ang_score = ang_score / _sqrt_dt
    lin_score = lin_score / _sqrt_dt

    # Displacement = score * (noise^2 / temp * dt) + gaussian * (noise * sqrt(dt)).
    # The angular noise is drawn before the linear noise.
    _ang_drift_gain = torch.square(noise_level_ang) / temp * dt
    _ang_noise_gain = noise_level_ang * _sqrt_dt
    ang_disp = ang_score * _ang_drift_gain + (torch.randn_like(ang_score) * _ang_noise_gain)

    _lin_drift_gain = torch.square(noise_level_lin) / temp * dt
    _lin_noise_gain = noise_level_lin * _sqrt_dt
    lin_disp = lin_score * _lin_drift_gain + (torch.randn_like(lin_score) * _lin_noise_gain)

    # Split the pose into its first four components (q) and the rest (x).
    q, x = T_next[...,:4], T_next[...,4:]
    # Matrix that maps the angular displacement to a quaternion increment,
    # gathered from the current pose via the model's q_indices / q_factor.
    L = T_next.detach()[...,score_model.q_indices] * score_model.q_factor
    dq = torch.einsum('...ij,...j->...i', L, ang_disp)
    # Linear displacement passed through quaternion_apply with the current q.
    dx = transforms.quaternion_apply(q, lin_disp)

    # First-order update: add the increment, re-normalize the quaternion, re-assemble the pose.
    q_next = transforms.normalize_quaternion(q + dq)
    T_next = torch.cat([q_next, x+dx], dim=-1)

    # Alternative update via the SE(3) exponential map, kept for reference:
    # dT = transforms.se3_exp_map(torch.cat([lin_disp, ang_disp], dim=-1))
    # dT = torch.cat([transforms.matrix_to_quaternion(dT[..., :3, :3]), dT[..., :3, 3]], dim=-1)
    # T_next = transforms.multiply_se3(T_next, dT)

In [None]:
# Assemble the point clouds used for visualization (no gradients needed).
with torch.no_grad():
    scene_pcd = PointCloud(points=scene_input.x, colors=scene_input.f)
    grasp_pcd = PointCloud(points=grasp_input.x, colors=grasp_input.f)

    # Scene merged with the grasp cloud placed at the target pose.
    _grasp_at_target = grasp_pcd.transformed(SE3(T_target), squeeze=True)
    target_pose_pcd = PointCloud.merge(scene_pcd, _grasp_at_target)

    # Scene merged with the grasp cloud at the first diffused pose.
    _grasp_at_diffused = grasp_pcd.transformed(SE3(T))[0]
    diffused_pose_pcd = PointCloud.merge(scene_pcd, _grasp_at_diffused)

    # Scene merged with the grasp cloud at the first pose produced by the sampling loop.
    _grasp_at_denoised = grasp_pcd.transformed(SE3(T_next.detach()))[0]
    denoised_pose_pcd = PointCloud.merge(scene_pcd, _grasp_at_denoised)

    ##### Query Summary #####
    query_point_batch = grasp_out.b.detach()
    # Keep only the query points that belong to batch index 0.
    batch_vis_idx = torch.nonzero(query_point_batch == 0).squeeze(-1)
    query_weight = grasp_out.w.detach()[batch_vis_idx]
    query_points = grasp_out.x.detach()[batch_vis_idx]

    # Each query point is drawn as N_repeat samples on a sphere of radius
    # r_query_ball around it, with a base color scaled by the point's weight.
    N_repeat = 500
    _base_color = torch.tensor([0.01, 1., 1.], device=query_weight.device, dtype=query_weight.dtype)
    query_points_colors = _base_color.expand(N_repeat, 1, 3) * query_weight[None, :, None]
    r_query_ball = 0.5

    # Gaussian directions normalized to unit length, then scaled to the ball radius.
    ball = torch.randn(N_repeat,1,3, device=query_points.device, dtype=query_points.dtype)
    ball = ball/ball.norm(dim=-1, keepdim=True) * r_query_ball

    # Flatten points and colors to (-1, 3).
    query_points = (query_points + ball).reshape(-1,3)
    query_points_colors = query_points_colors.reshape(-1,3)

    #########################
    # Scene points colored by the scene feature weights with the 'magma' colormap.
    scene_attn_pcd = PointCloud(
        points=scene_out.x.detach().cpu(),
        colors=scene_out.w.detach().cpu(),
        cmap='magma',
    )

In [None]:
# Display one of the point clouds prepared in the previous cell.
# Uncomment a line below to view the target / diffused / denoised pose clouds instead.
# target_pose_pcd.show(point_size=2., width=800, height=800)
# diffused_pose_pcd.show(point_size=2., width=800, height=800)
# denoised_pose_pcd.show(point_size=2., width=800, height=800)
# Currently shown: scene points colored by `scene_out.w`.
scene_attn_pcd.show(point_size=3., width=800, height=800)