In [None]:
import os
# Judging by its name, this flag asks PyTorch's JIT to use the NNC fuser rather than NVFuser.
# NOTE(review): it is set ahead of `import torch` below, presumably so PyTorch sees it
# during initialization -- confirm before moving this line.
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr, detach_featured_points
from diffusion_edf import train_utils
from diffusion_edf import transforms
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched
from diffusion_edf.point_attentive_score_model import PointAttentiveScoreModel
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose

# Print tensors with 4 digits of precision and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
# Both trainers read identically named YAML files; only the config folder differs.
train_configs_file = 'train_configs.yaml'
task_configs_file = 'task_configs.yaml'

# Trainer built from the low-resolution config folder.
configs_root_dir = 'configs/pick_lowres'
lowres_trainer = DiffusionEdfTrainer(
    configs_root_dir=configs_root_dir,
    train_configs_file=train_configs_file,
    task_configs_file=task_configs_file,
)

# Trainer built from the high-resolution config folder.
configs_root_dir = 'configs/pick_highres'
highres_trainer = DiffusionEdfTrainer(
    configs_root_dir=configs_root_dir,
    train_configs_file=train_configs_file,
    task_configs_file=task_configs_file,
)

# A single `device` is shared by everything downstream, so the two trainers must agree on it.
assert lowres_trainer.device == highres_trainer.device
device = lowres_trainer.device

# Only the low-res trainer gets its data loaders initialized here; a later cell
# reads `lowres_trainer.testloader`.
lowres_trainer._init_dataloaders()

# Load each model from its "200.pt" checkpoint and call `.eval()` on the result.
lowres_model = lowres_trainer.get_model(
    checkpoint_dir='runs/2023_06_01_13-10-19_Lowres/checkpoint/200.pt',
    deterministic=False,
    device=device,
).eval()
highres_model = highres_trainer.get_model(
    checkpoint_dir='runs/2023_06_01_13-23-18_Highres/checkpoint/200.pt',
    deterministic=False,
    device=device,
).eval()

# Main Loop

In [None]:
# Take the first batch of the low-res trainer's test loader.
# (Swap in `lowres_trainer.trainloader` below to inspect a training demo instead.)
dataset = list(lowres_trainer.testloader)
demo_batch = dataset[0]
B = len(demo_batch)
assert B == 1, "Batch training is not supported yet."

# Flatten the batch into scene / grasp inputs; the third return value is not used here.
scene_input, grasp_input, _ = train_utils.flatten_batch(demo_batch=demo_batch)

# Seed pose for sampling, shape (1, 7): a 4-component quaternion part followed by
# a 3-component position part.
# NOTE(review): [1, 0, 0, 0] is presumably the identity rotation (real-first
# quaternion convention) -- confirm against `transforms`.
# Alternative seeds that were tried:
#   quaternion: transforms.random_quaternions(1, device=device)
#   quaternion: torch.tensor([[math.sqrt(0.5), -math.sqrt(0.5), 0.0, 0.]], device=device)
#   position:   torch.distributions.Uniform(scene_input.x[:].min(dim=0).values,
#                                           scene_input.x[:].max(dim=0).values).sample(sample_shape=(1,))
seed_quaternion = torch.tensor([[1., 0., 0.0, 0.]], device=device)
seed_position = torch.tensor([[-30., -30., 30.]], device=device)
T0 = torch.cat([seed_quaternion, seed_position], dim=-1)

# Point clouds for visualization: `.x` supplies the points and `.f` the colors.
scene_pcd = PointCloud(points=scene_input.x, colors=scene_input.f)
grasp_pcd = PointCloud(points=grasp_input.x, colors=grasp_input.f)

# Optional sanity check: show the grasp placed at the seed pose inside the scene.
# diffused_pose_pcd = PointCloud.merge(scene_pcd, grasp_pcd.transformed(SE3(T0))[0])
# diffused_pose_pcd.show(point_size=2., width=600, height=600)

In [None]:
# Run the low-res key/query models on the scene and grasp inputs without tracking gradients.
# The key model's output is wrapped in a one-element list.
with torch.no_grad():
    scene_out_multiscale: List[FeaturedPoints] = [lowres_model.key_model(scene_input)]
    grasp_out: FeaturedPoints = lowres_model.query_model(grasp_input)

# When the first entry carries per-point weights `w`, show them as colors on its points.
lowres_scene_out = scene_out_multiscale[0]
if lowres_scene_out.w is None:
    scene_attn_pcd = None
else:
    scene_attn_pcd = PointCloud(
        points=lowres_scene_out.x.detach().cpu(),
        colors=lowres_scene_out.w.detach().cpu(),
        cmap='magma',
    )
    scene_attn_pcd.show(point_size=6., width=600, height=600)

In [None]:
# Sample poses with the low-res model, starting from the seed pose T0.
# `N_steps` and `timesteps` are passed as two-element lists, matching the two-element
# lists used for the high-res sampling call.
# NOTE(review): presumably one entry per diffusion schedule -- confirm against
# `lowres_trainer.diffusion_schedules`.
Ts_lowres = lowres_model.sample(
    T_seed=T0.clone().detach(),  # pass a detached copy so T0 itself is left untouched
    scene_pcd_multiscale=scene_out_multiscale,
    grasp_pcd=grasp_out,
    diffusion_schedules=lowres_trainer.diffusion_schedules,
    N_steps=[500, 500],
    timesteps=[0.01, 0.01],
    add_noise=True,
    temperature=1.,
)

# Keep the last entry along dim 0 (slice, not index, so the leading dimension is retained);
# it seeds the high-res sampling stage.
T_lowres = Ts_lowres[-1:,:]

In [None]:
# Run the high-res key/query models on the same scene and grasp inputs without tracking gradients.
# NOTE(review): unlike the low-res cell, the key model's output is NOT wrapped in a list
# here, although it is annotated as List[FeaturedPoints] -- confirm which form
# `highres_model.key_model` actually returns.
with torch.no_grad():
    scene_out_multiscale: List[FeaturedPoints] = highres_model.key_model(scene_input)
    grasp_out: FeaturedPoints = highres_model.query_model(grasp_input)

# When the first entry carries per-point weights `w`, show them as colors on its points.
highres_scene_out = scene_out_multiscale[0]
if highres_scene_out.w is None:
    scene_attn_pcd = None
else:
    scene_attn_pcd = PointCloud(
        points=highres_scene_out.x.detach().cpu(),
        colors=highres_scene_out.w.detach().cpu(),
        cmap='magma',
    )
    scene_attn_pcd.show(point_size=6., width=600, height=600)

In [None]:
# Continue sampling with the high-res model, seeded by the final low-res pose.
Ts_highres = highres_model.sample(
    T_seed=T_lowres.clone().detach(),  # detached copy of the low-res result
    scene_pcd_multiscale=scene_out_multiscale,
    grasp_pcd=grasp_out,
    diffusion_schedules=highres_trainer.diffusion_schedules,
    N_steps=[500, 500],
    timesteps=[0.01, 0.01],
    add_noise=True,
    temperature=1.,
)

# Last entry along dim 0, sliced so the leading dimension is retained.
T_highres = Ts_highres[-1:, :]

In [None]:
# Stack the low-res and high-res sampling results along dim 0.
Ts = torch.cat([Ts_lowres, Ts_highres], dim=0)

# Visualize every 10th pose, cast to float.
subsampled_poses = SE3(Ts[::10].float())
# NOTE(review): presumably [min, max] plot limits for the x, y and z axes -- confirm in visualize_pose.
axis_ranges = torch.tensor([[-40., 40.], [-40., 40.], [0., 40.]])

fig_grasp, fig_sample = visualize_pose(
    scene_pcd,
    grasp_pcd,
    poses=subsampled_poses,
    point_size=3.0,
    width=800,
    height=800,
    ranges=axis_ranges,
)
fig_sample.show()