In [None]:
import os
# Judging by its name, this flag makes the TorchScript JIT use the NNC fuser instead of nvFuser (TODO confirm).
# It is set here, before `torch` is imported in the next cell, presumably so torch sees it at import time.
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"

In [None]:
from typing import List, Tuple, Optional, Union, Iterable

import plotly.graph_objects as go

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3
from torch.utils.tensorboard import SummaryWriter

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.score_model import ScoreModel
from diffusion_edf.transforms import quaternion_apply, random_quaternions, quaternion_multiply, quaternion_invert, axis_angle_to_quaternion, se3_exp_map, matrix_to_quaternion, se3_log_map, standardize_quaternion
from diffusion_edf.loss import SE3DenoisingDiffusion
from diffusion_edf.utils import sample_reference_points


# Fixed-point tensor printing with 4 decimals (no scientific notation) for readable debug output.
torch.set_printoptions(sci_mode=False, precision=4)

In [None]:
# Point clouds are rescaled by 1/unit_len, voxel-downsampled and color-normalized before entering the
# model; the *_unproc_fn pipelines undo the color normalization and then the rescale.
unit_len = 0.01
# Voxel sizes are written in the raw point-cloud unit and converted to the rescaled unit.
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len


rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)
normalize_color_fn = NormalizeColor(color_mean = torch.tensor([0.5, 0.5, 0.5]), color_std = torch.tensor([0.5, 0.5, 0.5]))
# Inverse normalization expressed as another NormalizeColor: mean' = -mean/std, std' = 1/std.
recover_color_fn = NormalizeColor(color_mean = -normalize_color_fn.color_mean / normalize_color_fn.color_std,
                                  color_std = 1 / normalize_color_fn.color_std)


def _pcd_pipelines(voxel_size):
    """Return (preprocess, postprocess) Compose pipelines sharing the rescale/color transforms above."""
    preprocess = Compose([rescale_fn,
                          Downsample(voxel_size=voxel_size, coord_reduction="average"),
                          normalize_color_fn])
    postprocess = Compose([recover_color_fn, recover_scale_fn])
    return preprocess, postprocess


scene_proc_fn, scene_unproc_fn = _pcd_pipelines(scene_voxel_size)
grasp_proc_fn, grasp_unproc_fn = _pcd_pipelines(grasp_voxel_size)

In [None]:
import math

# Runtime / compilation settings.
device = 'cuda:0'
compile = False  # NOTE(review): shadows the builtin `compile` for the rest of the notebook

# Feature representations (e3nn irreps). The input has 3 scalar channels; the point colors are fed as features.
irreps_input = o3.Irreps('3x0e')
irreps_node_embedding = o3.Irreps('32x0e+16x1e+8x2e')  # larger alternative: '128x0e+64x1e+32x2e'
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')

# Network hyper-parameters.
# NOTE(review): fc_neurons, alpha_drop, proj_drop, drop_path_rate, irreps_mlp_mid and pool_ratio are
#               not passed to ScoreModel in this notebook.
fc_neurons = [128, 64, 64]
num_heads, n_scales = 4, 4
alpha_drop, proj_drop, drop_path_rate = 0.2, 0.0, 0.0
irreps_mlp_mid = 2
pool_ratio = 0.5

In [None]:
# Score network over key (scene) and query (grasp) point clouds. Values that also live in the config
# cell (num_heads, n_scales) are read from there rather than repeated as literals.
score_model = ScoreModel(irreps_input = irreps_input,
                         irreps_emb_init = irreps_node_embedding,
                         irreps_sh = irreps_sh,
                         fc_neurons_init = [32, 16, 16],  # NOTE(review): config `fc_neurons` is [128, 64, 64] and unused - confirm which is intended
                         num_heads = num_heads,
                         n_scales = n_scales,
                         pool_ratio = 0.3,  # NOTE(review): config `pool_ratio` is 0.5 and unused - confirm which is intended
                         dim_mult = [1, 1, 2, 2],  # presumably one entry per scale (n_scales) - TODO confirm
                         n_layers = 2,
                         gnn_radius = 2.0,
                         cutoff_radius = 4.0,
                         weight_feature_dim = 20,
                         query_downsample_ratio = 0.7,
                         device=device,
                         deterministic = False,
                         compile_head = compile)

score_model = score_model.to(device)
# Adam takes the parameter iterator directly; no need to materialize a list.
optimizer = torch.optim.Adam(score_model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-09, weight_decay=1e-4, amsgrad=True)

# Load demo

In [None]:
# SE(3) diffusion helper: produces diffused poses, target scores and diffusion times.
diffusion = SE3DenoisingDiffusion()
# Score regression objective: mean-squared error between predicted and target scores.
loss_fn = torch.nn.MSELoss(reduction='mean')

In [None]:
def _identity_collate(batch):
    """Return the batch list unchanged, so each loader item is a plain list of DemoSequence samples."""
    return batch


trainset = DemoSeqDataset(dataset_dir="demo/test_demo", annotation_file="data.yaml", device=device)
train_dataloader = DataLoader(trainset, shuffle=True, collate_fn=_identity_collate)
writer = SummaryWriter()  # TensorBoard logger for the training loop

In [None]:
# Training schedule and mutable training state (iteration counter and per-iteration loss history).
max_epochs, N_samples = 1000, 10

iter_, loss_list = 0, []

In [None]:
# Sanity-check cell: express the target pose relative to per-reference-point frames, diffuse it there,
# and map both the score and the diffused pose back to the original frame.
demo_seq: DemoSequence = next(iter(train_dataloader))[0]
demo: TargetPoseDemo = demo_seq[1]

scene_raw: PointCloud = demo.scene_pc
grasp_raw: PointCloud = demo.grasp_pc
target_poses: SE3 = demo.target_poses

scene_proc = scene_proc_fn(scene_raw).to(device)
grasp_proc = grasp_proc_fn(grasp_raw).to(device)
target_poses = rescale_fn(target_poses).to(device)

# Reference points are sampled from the scene expressed in the frame of the first target pose, using the
# grasp points and radius r (presumably points near the grasp - TODO confirm sample_reference_points).
x_ref, n_neighbors = sample_reference_points(PointCloud.transform_pcd(scene_proc, target_poses.inv())[0].points, grasp_proc.points, r=3)
# One frame per reference point: orientation (1, 0, 0, 0) (identity, assuming scalar-first quaternions), origin x_ref.
T_ref = SE3.from_orn_and_pos(orns=torch.cat([torch.ones_like(x_ref[:,:1]), torch.zeros_like(x_ref)], dim=-1), positions=x_ref)


T_target = target_poses.poses
# Right-compose the target with the reference frames; `* T_ref.inv()` below undoes exactly this.
T_target_ref = (SE3(T_target) * T_ref).poses
target_score_ref, T_diffused_ref, time_in = diffusion.diffuse(T_target_ref, angular_first=False)


target_score_lin = target_score_ref[..., :3]
# Explicit dim=-1: without it torch.cross picks the first size-3 dimension, which is wrong when there are 3 reference points.
target_score_ang = target_score_ref[..., 3:] + torch.cross(x_ref, target_score_ref[..., :3], dim=-1)  # Ad_{T_{ref}^{-1}}^{T}
target_score = torch.cat([target_score_lin, target_score_ang], dim=-1)
target_score = target_score * torch.sqrt(time_in) # Is it necessary?
T = (SE3(T_diffused_ref) * T_ref.inv()).poses


key_feature = scene_proc.colors
key_coord = scene_proc.points
key_batch = torch.zeros(len(key_coord), device=device, dtype=torch.long)
query_feature = grasp_proc.colors
query_coord = grasp_proc.points
query_batch = torch.zeros(len(query_coord), device=device, dtype=torch.long)

In [None]:
# Training loop: per demo, diffuse the target pose, score the diffused pose, and regress the
# sqrt(time)-scaled target score with MSE.
score_model.train()  # the evaluation cells below switch to eval(); make re-running this cell train again
# NOTE(review): the eval cells also call requires_grad_(False); re-enable gradients before re-running this
#               cell after them, otherwise loss.backward() will likely fail because nothing requires grad.
for epoch in range(1, max_epochs+1):
    epoch_losses = []  # losses of this epoch only (loss_list keeps the full history)
    for train_batch in train_dataloader:
        iter_ += 1
        assert len(train_batch) == 1, "Batch training is not supported yet."

        optimizer.zero_grad(set_to_none=True)

        demo_seq: DemoSequence = train_batch[0]
        demo: TargetPoseDemo = demo_seq[1]

        scene_raw: PointCloud = demo.scene_pc
        grasp_raw: PointCloud = demo.grasp_pc
        target_poses: SE3 = demo.target_poses

        scene_proc = scene_proc_fn(scene_raw).to(device)
        grasp_proc = grasp_proc_fn(grasp_raw).to(device)
        target_poses = rescale_fn(target_poses).to(device)

        T_target = target_poses.poses
        # NOTE(review): diffusion time is pinned to 0.01 for every training sample.
        target_score, T, time_in = diffusion.diffuse(T_target, angular_first=False, manual_time=0.01)
        target_score = target_score * torch.sqrt(time_in) # Is it necessary?

        key_feature = scene_proc.colors
        key_coord = scene_proc.points
        key_batch = torch.zeros(len(key_coord), device=device, dtype=torch.long)
        query_feature = grasp_proc.colors
        query_coord = grasp_proc.points
        query_batch = torch.zeros(len(query_coord), device=device, dtype=torch.long)

        score, query, query_info, key_info = score_model(T=T,
                                                         key_feature=key_feature, key_coord=key_coord, key_batch=key_batch,
                                                         query_feature=query_feature, query_coord=query_coord, query_batch=query_batch,
                                                         info_mode='NONE', angular_first= False, time=time_in)

        loss = loss_fn(target_score, score)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_list.append(loss_value)
        epoch_losses.append(loss_value)
        writer.add_scalar("Loss/train", loss_value, iter_)
        # Average over poses so the logged scalar is defined when a demo has more than one target pose.
        writer.add_scalar("score_norm", torch.norm(score.detach(), dim=-1).mean().item(), iter_)
    print(torch.tensor(epoch_losses).mean())

# Eval

In [None]:
from diffusion_edf.pc_utils import get_plotly_fig


def get_raw_pointcloud(**kwargs) -> Tuple[PointCloud, PointCloud]:
    """Load the raw (unprocessed) scene and grasp point clouds of one demo stage.

    Keyword Args:
        dir: required, but currently unused. NOTE(review): the dataset path below is hard-coded to
            "demo/test_demo"; confirm whether `dir` was meant to replace it.
        idx: index of the demo sequence in the dataset.
        pick_or_place: 'pick' selects stage 0 of the sequence, 'place' selects stage 1.

    Returns:
        (scene_pcd, grasp_pcd) exactly as stored in the demo.

    Raises:
        KeyError: if one of the keyword arguments above is missing.
        ValueError: if pick_or_place is neither 'pick' nor 'place'.
    """

    ################### Write your custom codes here ###################
    requested_dir = kwargs['dir']  # required but unused (see docstring)
    idx = kwargs['idx']
    pick_or_place = kwargs['pick_or_place']

    dataset = DemoSeqDataset(dataset_dir="demo/test_demo", annotation_file="data.yaml")
    demo_seq: DemoSequence = dataset[idx]

    for stage_name, stage_index in (('pick', 0), ('place', 1)):
        if pick_or_place == stage_name:
            break
    else:
        raise ValueError(f"Wrong value for pick_or_place argument: {pick_or_place}")
    stage_demo: TargetPoseDemo = demo_seq[stage_index]

    scene_pcd: PointCloud = stage_demo.scene_pc
    grasp_pcd: PointCloud = stage_demo.grasp_pc
    target_pose: SE3 = stage_demo.target_poses  # loaded but not returned
    ####################################################################

    return scene_pcd, grasp_pcd


def visualize(scene_pcd: PointCloud, grasp_pcd: PointCloud, poses: SE3, query: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, show_sample_points: bool = False):
    """Build plotly figures for a grasp point cloud and its placements in a scene.

    Args:
        scene_pcd: scene point cloud, drawn in the placement figure.
        grasp_pcd: grasp point cloud; drawn alone in the grasp figure and, transformed by each pose,
            in the placement figure.
        poses: placement poses; each gets one trace named `pose_{i}` and one slider step that shows it.
        query: optional (points, attention) pair drawn on the grasp figure, with opacity equal to the
            attention divided by its maximum.
        show_sample_points: if True, also draw `poses.points` as markers in the placement figure.

    Returns:
        (fig_grasp, fig_sample). When `poses` is empty, no slider is added to fig_sample.
    """
    # ---- Grasp figure ----
    grasp_geometry = [grasp_pcd.plotly(point_size=1.0, name="grasp")]
    if query is not None:
        query_points, query_attention = query
        query_pl = PointCloud.points_to_plotly(pcd=query_points, point_size=15.0,
                                               opacity=query_attention / query_attention.max())
        grasp_geometry.append(query_pl)
    fig_grasp = get_plotly_fig("Grasp").add_traces(grasp_geometry)

    # ---- Placement figure: scene plus one transformed grasp per pose ----
    pose_names = [f'pose_{i}' for i in range(len(poses))]
    placement_geometry = [scene_pcd.plotly(point_size=1.0, name='scene')]
    transformed_grasp_pcd = grasp_pcd.transformed(poses)
    for i, pose_name in enumerate(pose_names):
        placement_geometry.append(transformed_grasp_pcd[i].plotly(point_size=1.0, name=pose_name))
    if show_sample_points:
        sample_pl = PointCloud.points_to_plotly(pcd=poses.points, point_size=7.0, colors=[0.2, 0.5, 0.8], name='sample_points')
        placement_geometry.append(sample_pl)
    fig_sample = get_plotly_fig("Sampled Placement").add_traces(placement_geometry)

    # Pose traces start hidden; every other trace keeps its own visibility.
    # Exact name membership (not a prefix test) also tolerates traces whose name is None.
    pose_name_set = set(pose_names)
    trace_index = {}
    base_visibility = []
    for i, trace in enumerate(fig_sample.data):
        trace_index[trace.name] = i
        if trace.name in pose_name_set:
            trace.visible = False
            base_visibility.append(False)
        else:
            base_visibility.append(trace.visible)

    # One slider step per pose: the base visibility with only that pose's trace switched on.
    steps = []
    for i, pose_name in enumerate(pose_names):
        step_visibility = base_visibility.copy()
        step_visibility[trace_index[pose_name]] = True
        steps.append(dict(
            method="update",
            args=[{"visible": step_visibility},
                  {"title": "Visualizing pose_" + str(i)}],  # layout attribute
        ))

    if steps:
        fig_sample.update_layout(sliders=[dict(
            active=0,
            currentvalue={"prefix": "Pose: "},
            pad={"t": 50},
            steps=steps,
        )])
        # Match the initially active slider step.
        fig_sample.data[trace_index[pose_names[0]]].visible = True

    return fig_grasp, fig_sample

In [None]:
# Switch to inference: eval-mode layers and frozen parameters (both calls return the module itself).
score_model = score_model.eval().requires_grad_(False)

In [None]:
# Evaluation sample: the first batch of the (shuffled) training loader, stage index 1 of the sequence.
train_batch = next(iter(train_dataloader))

demo_seq: DemoSequence = train_batch[0]
demo: TargetPoseDemo = demo_seq[1]

scene_raw: PointCloud = demo.scene_pc
grasp_raw: PointCloud = demo.grasp_pc
target_poses: SE3 = demo.target_poses

scene_proc, grasp_proc = scene_proc_fn(scene_raw).to(device), grasp_proc_fn(grasp_raw).to(device)
target_poses = rescale_fn(target_poses).to(device)

# Diffuse at a fixed time of 0.03. Unlike the training loop, target_score is NOT scaled by sqrt(time_in) here.
T_target = target_poses.poses
target_score, T, time_in = diffusion.diffuse(T_target, angular_first=False, manual_time=0.03)

# Model inputs: scene = keys, grasp = queries; every point belongs to batch 0.
key_feature, key_coord = scene_proc.colors, scene_proc.points
key_batch = torch.zeros(len(key_coord), device=device, dtype=torch.long)
query_feature, query_coord = grasp_proc.colors, grasp_proc.points
query_batch = torch.zeros(len(query_coord), device=device, dtype=torch.long)

In [None]:
def reverse_diffusion(T, score, timestep, angular_first = True):
    """Take one deterministic (flow) reverse step on poses T along `score`.

    Args:
        T: poses laid out as [quaternion (4), position (3)] along the last dimension.
        score: 6-d score along the last dimension; [linear, angular] order when angular_first is
            False, [angular, linear] when True. It is reordered to [linear, angular] before se3_exp_map.
        timestep: step size multiplied into the score.
        angular_first: order of the incoming score components (see above).

    Returns:
        (T_new, reverse_T): the updated poses, and the 7-d increment [quaternion, position] that was
        right-composed onto T (q_new = q * dq, x_new = q.apply(dx) + x).

    NOTE(review): the increment's translation is read from column 3 of se3_exp_map's output;
    confirm this matches the matrix convention of diffusion_edf.transforms.se3_exp_map.
    A Langevin variant (half step plus sqrt(timestep) Gaussian noise) and left-composition were
    previously sketched as alternatives.
    """
    twist = torch.cat([score[..., 3:], score[..., :3]], dim=-1) if angular_first else score
    twist = twist * timestep  # flow step

    step_matrix = se3_exp_map(twist)
    step_quat = standardize_quaternion(matrix_to_quaternion(step_matrix[..., :3, :3]))
    reverse_T = torch.cat([step_quat, step_matrix[..., :3, -1]], dim=-1)

    base_quat, base_pos = T[..., :4], T[..., 4:]
    new_quat = quaternion_multiply(base_quat, reverse_T[..., :4])
    new_pos = quaternion_apply(base_quat, reverse_T[..., 4:]) + base_pos

    return torch.cat([new_quat, new_pos], dim=-1), reverse_T

In [None]:
# Single forward pass on the evaluation sample, using the same call convention as training:
# linear-first score ordering and the diffusion time passed explicitly.
with torch.no_grad():
    score, query, query_info, key_info = score_model(T=T,
                                                key_feature=key_feature, key_coord=key_coord, key_batch=key_batch,
                                                query_feature=query_feature, query_coord=query_coord, query_batch=query_batch,
                                                info_mode='NONE', angular_first=False, time=time_in)

    # Training regresses target_score * sqrt(time_in), while the eval target_score is unscaled;
    # compare at the same scale the model was trained on.
    loss = loss_fn(target_score * torch.sqrt(time_in), score)

In [None]:
# Undo the sqrt(time) scaling the model was trained with, giving the raw score for reverse_diffusion.
# torch.sqrt (as in training) also works when time_in holds more than one element, unlike math.sqrt.
# NOTE: not idempotent - re-running this cell divides `score` again; re-run the forward-pass cell first.
score = score / torch.sqrt(time_in)

In [None]:
# One reverse step from the diffused pose T using the model's (unscaled) score.
# Pass `target_score` instead of `score` to check the step with the ground-truth score.
T_rev, dT = reverse_diffusion(T=T, score=score, timestep=time_in, angular_first=False)

In [None]:
# Ground-truth target pose (first pose only), converted back to the raw point-cloud scale.
vis_pose = recover_scale_fn(SE3(T_target))[:1]
fig_grasp, fig_sample = visualize(scene_raw, grasp_raw, vis_pose)
fig_sample

In [None]:
# Diffused (noisy) pose T (first pose only), converted back to the raw point-cloud scale.
vis_pose = recover_scale_fn(SE3(T))[:1]
fig_grasp, fig_sample = visualize(scene_raw, grasp_raw, vis_pose)
fig_sample

In [None]:
# Ground-truth score scaled by the step size, i.e. the twist reverse_diffusion would apply with target_score.
time_in * target_score

In [None]:
# Pose after one reverse step (T_rev, first pose only), converted back to the raw point-cloud scale.
vis_pose = recover_scale_fn(SE3(T_rev))[:1]
fig_grasp, fig_sample = visualize(scene_pcd=scene_raw, grasp_pcd=grasp_raw, poses=vis_pose)
# Display the figure (it was previously built but never shown).
fig_sample

In [None]:
# NOTE(review): appears to be a leftover scratch computation with no connection to the cells above; consider removing.
math.sin(0.32)