In [None]:
import os
# Export the JIT fuser flag in this first cell, i.e. before the later cell that imports torch.
# NOTE(review): judging by its name, the flag selects NNC over NVFuser — confirm against the PyTorch docs.
os.environ.update({"PYTORCH_JIT_USE_NNC_NOT_NVFUSER": "1"})

In [None]:
from typing import List, Tuple, Optional, Union, Iterable

import plotly
import plotly.graph_objects as go
import numpy as np

import torch
from torchvision.transforms import Compose
from diffusion_edf.transforms import quaternion_apply, random_quaternions, quaternion_multiply, quaternion_invert

from e3nn import o3

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos, OutputLog
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.score_model import ScoreModel
from diffusion_edf.pc_utils import get_plotly_fig

In [None]:
# Switch plotly into offline notebook mode before any figure is built in the cells below.
# NOTE(review): presumably needed so the final `fig` renders inline in the notebook — confirm for the plotly version in use.
plotly.offline.init_notebook_mode()

In [None]:
# --- Pre-/post-processing configuration -------------------------------------
# Point clouds are rescaled by 1/unit_len, so the voxel sizes (given in the
# original length unit) are divided by unit_len to express them in rescaled units.
unit_len = 0.01
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len

# Forward and inverse coordinate scaling.
rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Forward colour normalisation, plus a second NormalizeColor parameterised with
# (-mean/std, 1/std) to undo it.
# NOTE(review): this is the exact inverse only if NormalizeColor computes (c - mean) / std — confirm.
normalize_color_fn = NormalizeColor(
    color_mean=torch.tensor([0.5, 0.5, 0.5]),
    color_std=torch.tensor([0.5, 0.5, 0.5]),
)
recover_color_fn = NormalizeColor(
    color_mean=-normalize_color_fn.color_mean / normalize_color_fn.color_std,
    color_std=1 / normalize_color_fn.color_std,
)


def _build_proc_fn(voxel_size):
    """Compose rescale -> voxel downsample (averaged coordinates) -> colour normalisation."""
    downsample_fn = Downsample(voxel_size=voxel_size, coord_reduction="average")
    return Compose([rescale_fn, downsample_fn, normalize_color_fn])


def _build_unproc_fn():
    """Compose colour recovery -> scale recovery."""
    return Compose([recover_color_fn, recover_scale_fn])


scene_proc_fn = _build_proc_fn(scene_voxel_size)
scene_unproc_fn = _build_unproc_fn()
grasp_proc_fn = _build_proc_fn(grasp_voxel_size)
grasp_unproc_fn = _build_unproc_fn()

In [None]:
# Load a previously saved OutputLog; every tensor visualised below is read from this object.
# The path has no leading slash, so it is a relative path.
# NOTE(review): presumably resolved against the kernel's working directory — confirm in OutputLog.load.
output_log = OutputLog.load("runs/output_log/score_model.gzip")

In [None]:
# Attribute-name stems shared by the query-side and key-side entries of the log.
_EXT_EDGE_STEMS = ("ext_edge_src", "ext_edge_dst")
_GRAPH_STEMS = ("node_feature", "node_coord", "node_batch",
                "node_scale_slice", "node_edge_src", "node_edge_dst")


def _log_fields(stems, side):
    """Return ``output_log.<stem>_<side>`` for each stem, as a tuple in stem order."""
    return tuple(getattr(output_log, f"{stem}_{side}") for stem in stems)


# Query tensors, then the extra-edge and multiscale-graph tuples for the query and key sides.
query = tuple(getattr(output_log, f"query_{part}") for part in ("weight", "feature", "coord", "batch"))
query_ext_info = _log_fields(_EXT_EDGE_STEMS, "query")
query_graph_info = _log_fields(_GRAPH_STEMS, "query")
key_ext_info = _log_fields(_EXT_EDGE_STEMS, "key")
key_graph_info = _log_fields(_GRAPH_STEMS, "key")

In [None]:
# key_graph_info is laid out as (feature, coord, batch, scale_slice, edge_src, edge_dst).
# Entry 0 (feature) is not needed here. The two edge entries are deliberately bound in
# swapped order (logged src -> edge_dst, logged dst -> edge_src) because Unet is upside down.
node_coord, node_batch, node_scale_slice, edge_dst, edge_src = key_graph_info[1:6]

In [None]:
# Adjacency list over every node: edges[d] collects the source-node indices of all
# edges whose destination is node d (after the src/dst swap done above).
# The index tensors are converted to plain Python lists once; iterating the tensors
# directly would build a 0-d tensor per element and need an .item() call for each.
edges = [[] for _ in range(len(node_coord))]
for src, dst in zip(edge_src.tolist(), edge_dst.tolist()):
    edges[dst].append(src)

In [None]:
# Scale levels to visualise: src_scale and dst_scale index into node_scale_slice in the next cell.
# NOTE(review): n_scales is not read by any visible cell; presumably the number of scale levels in the log — confirm.
n_scales = 4
src_scale, dst_scale = 0, 1

In [None]:
def _scale_node_indices(scale):
    """Return the global node indices in [node_scale_slice[scale], node_scale_slice[scale + 1])."""
    first, last = node_scale_slice[scale], node_scale_slice[scale + 1]
    return torch.arange(first, last)


# Global indices of the nodes at each chosen scale, and their coordinates.
src_idx = _scale_node_indices(src_scale)
dst_idx = _scale_node_indices(dst_scale)
src_node_coord, dst_node_coord = node_coord[src_idx], node_coord[dst_idx]

# node_plot = PointCloud.points_to_plotly(node_coord)
# edge_src = multiscale_edge_src[torch.logical_and((multiscale_edge_src >= node_scale_slice[src_scale]), (multiscale_edge_src < node_scale_slice[src_scale+1])).nonzero().squeeze(-1)] - node_scale_slice[src_scale]
# edge_dst = multiscale_edge_dst[torch.logical_and((multiscale_edge_dst >= node_scale_slice[dst_scale]), (multiscale_edge_dst < node_scale_slice[dst_scale+1])).nonzero().squeeze(-1)] - node_scale_slice[dst_scale]
# edges = [[] for _ in node_coord]
# for src, dst in zip(edge_src, node_edge_dst):
#     edges[dst.item()].append(src.item())

In [None]:
# Display the source-node indices recorded for destination node 4000 (the same index used as dst_i below).
edges[4000]

In [None]:
dst_i = 4000      # global index of the destination node whose incoming edges are drawn
dst_z_disp = 50   # z offset added to the destination side of the plot

# Destination endpoint, shared by every segment; only its z value is shifted by dst_z_disp.
dst_x = node_coord[dst_i, 0].item()
dst_y = node_coord[dst_i, 1].item()
dst_z = node_coord[dst_i, 2].item() + dst_z_disp

# One (source, destination, None) triple per edge; the trailing None separates the
# segments within each coordinate list.
xe, ye, ze = [], [], []
for src_i in edges[dst_i]:
    xe.extend((node_coord[src_i, 0].item(), dst_x, None))
    ye.extend((node_coord[src_i, 1].item(), dst_y, None))
    ze.extend((node_coord[src_i, 2].item(), dst_z, None))

In [None]:
def _node_trace(coords, labels, name, z_offset=None):
    """Build a black, size-1 marker trace for ``coords`` with ``labels`` as hover text.

    When ``z_offset`` is given it is added to the z column before plotting.
    """
    z_values = coords[:, 2].numpy()
    if z_offset is not None:
        z_values = z_values + z_offset
    return go.Scatter3d(
        x=coords[:, 0].numpy(),
        y=coords[:, 1].numpy(),
        z=z_values,
        mode='markers',
        name=name,
        marker=dict(symbol='circle', size=1, color='rgb(0,0,0)'),
        text=labels,
        hoverinfo='text',
    )


# Source nodes at their own coordinates; destination nodes shifted along z by dst_z_disp.
src_node_plot = _node_trace(src_node_coord, src_idx, 'src_nodes')
dst_node_plot = _node_trace(dst_node_coord, dst_idx, 'dst_nodes', z_offset=dst_z_disp)

# Grey line segments for the edges of the selected destination node (no hover text).
edges_plot = go.Scatter3d(
    x=xe,
    y=ye,
    z=ze,
    mode='lines',
    line=dict(color='rgb(125,125,125)', width=1),
    hoverinfo='none',
)

In [None]:
# Assemble the figure from the three traces (source nodes, destination nodes, edges);
# the bare `fig` on the last line makes the notebook display it.
fig = go.Figure()
fig.add_traces([src_node_plot, dst_node_plot, edges_plot])
fig