In [None]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from edf_interface.data import PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints
from diffusion_edf import train_utils
from diffusion_edf.trainer import DiffusionEdfTrainer
from diffusion_edf.visualize import visualize_pose

# Print tensors with four decimals and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

# Every trainer, model and seed pose in this notebook is placed on `device`.
# Prefer the first GPU, but fall back to the CPU instead of requesting a CUDA
# device that does not exist on this machine.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
    print("CUDA is not available; falling back to CPU.")

# Load Low-resolution Model

In [None]:
# Locations of the low-resolution pick configuration and its saved checkpoint.
configs_root_dir = 'configs/pick_lowres'
train_configs_file = 'train_configs.yaml'
task_configs_file = 'task_configs.yaml'
# Passed as `checkpoint_dir` below even though the value is a path to a .pt file.
lowres_checkpoint = 'example_runs/2023_06_01_21-36-31_Pick_LowRes/checkpoint/200.pt'

# The trainer is used in this notebook as the source of the dataloaders, the
# model and the diffusion schedules; no training loop is run here.
lowres_trainer = DiffusionEdfTrainer(configs_root_dir=configs_root_dir,
                                     train_configs_file=train_configs_file,
                                     task_configs_file=task_configs_file,
                                     device=device)
lowres_trainer._init_dataloaders()

# Build the model from the checkpoint, then switch it to evaluation mode.
lowres_model = lowres_trainer.get_model(checkpoint_dir=lowres_checkpoint,
                                        deterministic=False,
                                        device=device)
lowres_model = lowres_model.eval()

# Warm the score model up (10 warm-up iterations requested) before sampling.
lowres_trainer.warmup_score_model(score_model=lowres_model, n_warmups=10)

# Load Super-resolution Model

In [None]:
# Locations of the high-resolution pick configuration and its saved checkpoint.
# NOTE: these three names are the same ones the low-resolution cell assigns,
# so running this cell overwrites them.
configs_root_dir = 'configs/pick_highres'
train_configs_file = 'train_configs.yaml'
task_configs_file = 'task_configs.yaml'
# Passed as `checkpoint_dir` below even though the value is a path to a .pt file.
highres_checkpoint = 'example_runs/2023_06_01_21-47-51_Pick_HiRes/checkpoint/200.pt'

# The trainer is used in this notebook as the source of the model and the
# diffusion schedules; no training loop is run here.
highres_trainer = DiffusionEdfTrainer(configs_root_dir=configs_root_dir,
                                      train_configs_file=train_configs_file,
                                      task_configs_file=task_configs_file,
                                      device=device)
highres_trainer._init_dataloaders()

# Build the model from the checkpoint, then switch it to evaluation mode.
highres_model = highres_trainer.get_model(checkpoint_dir=highres_checkpoint,
                                          deterministic=False,
                                          device=device)
highres_model = highres_model.eval()

# Warm the score model up (10 warm-up iterations requested) before sampling.
highres_trainer.warmup_score_model(score_model=highres_model, n_warmups=10)

# Initialize Input Data and Initial Pose

In [None]:
################## Load test data ######################

# Materialize the test split and take the batch at index 3 as the demonstration.
# To draw from the training split instead, use:
# dataset = list(lowres_trainer.trainloader)
dataset = list(lowres_trainer.testloader)
demo_batch = dataset[3]

# The remaining cells assume exactly one demonstration per batch.
B = len(demo_batch)
assert B == 1, "Batch training is not supported yet."

# Flatten the batch into scene / grasp inputs (the third return value is unused)
# and wrap each one as a point cloud for plotting.
scene_input, grasp_input, _ = train_utils.flatten_batch(demo_batch=demo_batch)
scene_pcd = PointCloud(
    points=scene_input.x,
    colors=scene_input.f,
)
grasp_pcd = PointCloud(
    points=grasp_input.x,
    colors=grasp_input.f,
)

##################### Initial pose #####################
# Exactly one of the variants below should define T0. The two commented-out
# variants reference `transforms`, which the import cell of this notebook does
# not import.

#### Random pose ####
# T0 = torch.cat([
#     transforms.random_quaternions(1, device=device),
#     torch.distributions.Uniform(scene_input.x[:].min(dim=0).values, scene_input.x[:].max(dim=0).values).sample(sample_shape=(1,))
# ], dim=-1)

#### Pose 1 ####
# NOTE(review): this variant concatenates two 4-component rows, unlike the 4 + 3
# layout of the other two variants -- confirm before enabling it.
# T0 = torch.cat([
#     transforms.random_quaternions(1, device=device),
#     torch.tensor([[math.sqrt(0.5), -math.sqrt(0.5), 0.0, 0.]], device=device),
# ], dim=-1)

#### Pose 2 ####
# Fixed seed pose: a single row of 4 + 3 values. The first four are presumably a
# quaternion (the random-pose variant fills that slot with random quaternions)
# and the last three a position -- confirm against SE3.
seed_rotation = torch.tensor([[1., 0., 0.0, 0.]], device=device)
seed_translation = torch.tensor([[-30., -30., 30.]], device=device)
T0 = torch.cat([seed_rotation, seed_translation], dim=-1)

################## Visualize Diffused Pose #################
# Uncomment to show the grasp point cloud placed at T0 on top of the scene.

# diffused_pose_pcd = PointCloud.merge(
#     scene_pcd,
#     grasp_pcd.transformed(SE3(T0))[0],
# )
# diffused_pose_pcd.show(point_size=2., width=600, height=600)

# Run Low-resolution Diffusion Model

### Feature Extraction

In [None]:
#################### Feature extraction #####################
scene_out_multiscale: List[FeaturedPoints]
grasp_out: FeaturedPoints

# Run the low-resolution model's key (scene) and query (grasp) encoders with
# gradient tracking disabled. The call order is kept as-is because the model
# was built with deterministic=False.
with torch.no_grad():
    scene_out_multiscale = lowres_model.get_key_pcd_multiscale(scene_input)
    grasp_out = lowres_model.get_query_pcd(grasp_input)


################ Visualize Scene Attention Map #####################
# Uncomment to plot the first scale's points coloured by their `w` field.

# scene_attn_pcd = PointCloud(points=scene_out_multiscale[0].x.detach().cpu(), 
#                             colors=scene_out_multiscale[0].w.detach().cpu(),
#                             cmap='magma')
# scene_attn_pcd.show(point_size=6., width=600, height=600)


### Sampling

In [None]:
#################### Sample #####################
with torch.no_grad():
    # Seed the sampler with a detached copy so T0 itself is left untouched.
    lowres_seed = T0.clone().detach()

    # N_steps and timesteps both carry two entries -- presumably one per
    # diffusion schedule; confirm against the signature of `sample`.
    Ts_lowres = lowres_model.sample(
        T_seed=lowres_seed,
        scene_pcd_multiscale=scene_out_multiscale,
        grasp_pcd=grasp_out,
        diffusion_schedules=lowres_trainer.diffusion_schedules,
        N_steps=[500, 500],
        timesteps=[0.02, 0.02],
        temperature=1.,
    )

    # Last row of the returned poses, sliced so the leading dimension is kept;
    # it becomes the seed of the super-resolution sampler.
    T_lowres = Ts_lowres[-1:, :]

# Run Super-resolution Diffusion Model

### Feature Extraction

In [None]:
# NOTE: these are the same two names the low-resolution feature-extraction cell
# assigns. After this cell they hold the high-resolution model's features, so
# re-run the low-resolution extraction before re-running low-resolution sampling.
scene_out_multiscale: List[FeaturedPoints]
grasp_out: FeaturedPoints

# Run the high-resolution model's key (scene) and query (grasp) encoders with
# gradient tracking disabled.
with torch.no_grad():
    scene_out_multiscale = highres_model.get_key_pcd_multiscale(scene_input)
    grasp_out = highres_model.get_query_pcd(grasp_input)

### Sampling

In [None]:
with torch.no_grad():
    # Start from a detached copy of the final low-resolution pose.
    highres_seed = T_lowres.clone().detach()

    # N_steps and timesteps both carry two entries -- presumably one per
    # diffusion schedule; confirm against the signature of `sample`.
    Ts_highres = highres_model.sample(
        T_seed=highres_seed,
        scene_pcd_multiscale=scene_out_multiscale,
        grasp_pcd=grasp_out,
        diffusion_schedules=highres_trainer.diffusion_schedules,
        N_steps=[500, 1000],
        timesteps=[0.02, 0.05],
        temperature=1.,
    )

    # Last row of the returned poses, sliced so the leading dimension is kept.
    T_highres = Ts_highres[-1:, :]

# Optimize

In [None]:
# with torch.no_grad():
#     scene_out_multiscale: List[FeaturedPoints] = highres_model.get_key_pcd_multiscale(scene_input)
#     grasp_out: FeaturedPoints = highres_model.get_query_pcd(grasp_input)
    
#     Ts_optim = highres_model.sample(T_seed=T_highres.clone().detach(),
#                                     scene_pcd_multiscale=scene_out_multiscale,
#                                     grasp_pcd=grasp_out,
#                                     diffusion_schedules=[[0.005, 0.003]],
#                                     N_steps=[1000],
#                                     timesteps=[0.05],
#                                     temperature=1.,
#                                     ang_noise_mult=0.,
#                                     lin_noise_mult=0.,)

# Visualize Result

### Sampling Trajectory

In [None]:
# Uncomment to reload the visualization module and rebind `visualize_pose`
# without restarting the kernel:
# import importlib
# import diffusion_edf.visualize
# importlib.reload(diffusion_edf.visualize)
# visualize_pose = diffusion_edf.visualize.visualize_pose

# Join the low- and high-resolution pose sequences along dim 0, cast to float32.
Ts = torch.cat([Ts_lowres, Ts_highres], dim=0).float()
# With the optimization cell enabled, use this instead:
# Ts = torch.cat([Ts_lowres, Ts_highres, Ts_optim], dim=0).float()

# Keep every 10th pose and always append the final one.
Ts_visualize = torch.cat([Ts[::10], Ts[-1:]], dim=0)

poses_visualize = SE3(Ts_visualize)
# Three [min, max] pairs -- presumably the x / y / z plot limits; confirm
# against `visualize_pose`.
plot_ranges = torch.tensor([[-40., 40.], [-40., 40.], [0., 40.]])

fig_grasp, fig_sample = visualize_pose(
    scene_pcd,
    grasp_pcd,
    poses=poses_visualize,
    point_size=3.0,
    width=800,
    height=800,
    ranges=plot_ranges,
)
fig_sample.show()