In [None]:
import os

# NOTE(review): judging by the name, this asks TorchScript to use the NNC fuser rather
# than nvFuser — TODO confirm against the installed torch version. This cell runs
# before `import torch` in the next cell, presumably so torch sees it at import time.
os.environ.update(PYTORCH_JIT_USE_NNC_NOT_NVFUSER="1")

In [None]:
from typing import List, Tuple, Optional, Union, Iterable

import plotly
import plotly.graph_objects as go
import numpy as np

import torch
from torchvision.transforms import Compose
from diffusion_edf.transforms import quaternion_apply, random_quaternions, quaternion_multiply, quaternion_invert

from e3nn import o3

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos, OutputLog
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.score_model import ScoreModel
from diffusion_edf.pc_utils import get_plotly_fig

In [None]:
# Enable inline plotly rendering in the notebook. `connected=False` is plotly's default:
# the plotly.js library is embedded in the notebook instead of being fetched from a CDN.
plotly.offline.init_notebook_mode(connected=False)

In [None]:
# All geometry is divided by `unit_len` (see `rescale_fn`), so downstream code works in
# units of `unit_len`.
unit_len = 0.01

# Voxel sizes are written in the original units and converted right away into the
# rescaled units that `Downsample` sees (it runs after `rescale_fn` in the pipelines).
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len

# Geometric scaling and its inverse.
rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Colour normalisation and its inverse. Assuming NormalizeColor computes
# (c - color_mean) / color_std, the second transform with mean' = -mean/std and
# std' = 1/std maps normalised colours back: (c' + mean/std) * std = c' * std + mean.
normalize_color_fn = NormalizeColor(
    color_mean=torch.tensor([0.5, 0.5, 0.5]),
    color_std=torch.tensor([0.5, 0.5, 0.5]),
)
recover_color_fn = NormalizeColor(
    color_mean=-normalize_color_fn.color_mean / normalize_color_fn.color_std,
    color_std=1 / normalize_color_fn.color_std,
)


def _make_pointcloud_pipelines(voxel_size):
    """Return ``(proc_fn, unproc_fn)`` for one kind of point cloud.

    ``proc_fn`` rescales, voxel-downsamples at ``voxel_size`` (already in rescaled
    units), then normalises colour. ``unproc_fn`` undoes the colour normalisation and
    the rescaling; the downsampling is not reversed.
    Both share the module-level rescale / colour transforms defined above.
    """
    proc_fn = Compose([rescale_fn,
                       Downsample(voxel_size=voxel_size, coord_reduction="average"),
                       normalize_color_fn])
    unproc_fn = Compose([recover_color_fn, recover_scale_fn])
    return proc_fn, unproc_fn


scene_proc_fn, scene_unproc_fn = _make_pointcloud_pipelines(scene_voxel_size)
grasp_proc_fn, grasp_unproc_fn = _make_pointcloud_pipelines(grasp_voxel_size)

In [None]:
# Load a saved OutputLog (presumably intermediate tensors logged from a ScoreModel run,
# judging by the file name — TODO confirm). Path is relative to the notebook's cwd.
OUTPUT_LOG_PATH = "runs/output_log/score_model.gzip"
output_log = OutputLog.load(OUTPUT_LOG_PATH)

In [None]:
# Regroup the logged tensors into positional tuples. Field order matters: the next cell
# reads key_graph_info by index ([1] coords, [2] batch, [3] scale slices, [4]/[5] edges).
query = (
    output_log.query_weight,
    output_log.query_feature,
    output_log.query_coord,
    output_log.query_batch,
)
query_ext_info = (
    output_log.ext_edge_src_query,
    output_log.ext_edge_dst_query,
)
query_graph_info = (
    output_log.node_feature_query,      # [0]
    output_log.node_coord_query,        # [1]
    output_log.node_batch_query,        # [2]
    output_log.node_scale_slice_query,  # [3]
    output_log.node_edge_src_query,     # [4]
    output_log.node_edge_dst_query,     # [5]
)
key_ext_info = (
    output_log.ext_edge_src_key,
    output_log.ext_edge_dst_key,
)
key_graph_info = (
    output_log.node_feature_key,        # [0]
    output_log.node_coord_key,          # [1]
    output_log.node_batch_key,          # [2]
    output_log.node_scale_slice_key,    # [3]
    output_log.node_edge_src_key,       # [4]
    output_log.node_edge_dst_key,       # [5]
)

In [None]:
# key_graph_info layout: (feature, coord, batch, scale_slice, edge_src, edge_dst).
node_coord, node_batch, node_scale_slice = key_graph_info[1:4]

# Deliberately swapped: the logged src ([4]) becomes edge_dst and the logged dst ([5])
# becomes edge_src ("Because Unet is upside down"). Presumably this makes each edge
# point from the lower scale index to the higher one, the direction the drawing
# helper walks backwards from — TODO confirm.
edge_dst, edge_src = key_graph_info[4:6]

In [None]:
# Invert the edge list into an adjacency list indexed by node id (a list, despite the
# name): dst_to_src_dict[d] collects every src with an edge src -> d, in edge order.
dst_to_src_dict = [[] for _ in range(len(node_coord))]
for src_id, dst_id in zip(edge_src.tolist(), edge_dst.tolist()):
    dst_to_src_dict[dst_id].append(src_id)

In [None]:
def draw_recursive_graph(node_coord, dst_to_src_dict, scale_slice,
                         target_dst_idx, point_size=None, scale_z_disp=None):
    """Build plotly traces of every edge reachable downward from one top-scale node.

    Nodes are grouped into scales by ``scale_slice``: scale ``k`` owns the contiguous
    node-index range ``[scale_slice[k], scale_slice[k+1])``. Starting from
    ``target_dst_idx`` in the top scale, the walk follows ``dst_to_src_dict[dst]`` one
    scale at a time down to scale 0 and collects every edge it crosses. Scale ``k`` is
    drawn shifted up by ``k * scale_z_disp`` along z, so the levels appear stacked.

    Args:
        node_coord: per-node coordinates indexed as ``node_coord[i, 0..2]``; anything
            ``np.asarray`` accepts (e.g. a CPU torch tensor or a numpy array).
        dst_to_src_dict: indexable by node id; ``dst_to_src_dict[d]`` lists the ids of
            the source nodes with an edge into ``d``.
        scale_slice: scale boundaries (number of scales + 1 entries), convertible to
            ``int`` element-wise.
        target_dst_idx: id of the starting node; must lie in the top scale,
            ``[scale_slice[-2], scale_slice[-1])``.
        point_size: marker size per scale, indexed by scale number. Defaults to 1 for
            every scale.
        scale_z_disp: z gap between consecutive scales. Defaults to 1.5x the z-extent
            of ``node_coord``.

    Returns:
        list of ``go.Scatter3d``, in this order:
          - the scale-0 node trace;
          - one node trace per scale, from the top scale down to scale 1;
          - one edge trace per scale pair, top first.
        Edge opacity is ``0.2 ** depth`` (depth 0 at the top), so edges fade as the
        walk descends.

    Raises:
        AssertionError: if ``target_dst_idx`` is not inside the top scale.
    """
    # Work on a float numpy copy so offsets and slicing never mix tensors and ndarrays.
    coords = np.asarray(node_coord, dtype=float)
    bounds = [int(b) for b in scale_slice]
    target_dst_idx = int(target_dst_idx)

    assert bounds[-2] <= target_dst_idx < bounds[-1], (
        f"target_dst_idx={target_dst_idx} is not in the top scale "
        f"[{bounds[-2]}, {bounds[-1]})"
    )

    # Highest *source* scale; the loop pairs source scale s with destination scale s+1.
    top_scale = len(bounds) - 3

    if scale_z_disp is None:
        scale_z_disp = float((coords[..., -1].max() - coords[..., -1].min()) * 1.5)
    if point_size is None:
        point_size = [1] * (top_scale + 3)

    def _node_trace(idx, z_offset, size, name):
        # One black marker per node id in `idx`, lifted by `z_offset`; hover shows ids.
        return go.Scatter3d(x=coords[idx, 0],
                            y=coords[idx, 1],
                            z=coords[idx, 2] + z_offset,
                            mode='markers',
                            name=name,
                            marker=dict(symbol='circle',
                                        size=size,
                                        color='rgb(0,0,0)'),
                            text=idx,
                            hoverinfo='text')

    src_node_plots, dst_node_plots, edge_plots = [], [], []
    frontier = {target_dst_idx}

    for s in range(top_scale, -1, -1):
        src_idx = np.arange(bounds[s], bounds[s + 1])
        dst_idx = np.arange(bounds[s + 1], bounds[s + 2])
        src_z = scale_z_disp * s
        dst_z = scale_z_disp * (s + 1)

        # Draw every edge into the current frontier; its sources become the frontier
        # for the next (lower) scale. `None` breaks the polyline between segments.
        xe, ye, ze = [], [], []
        next_frontier = set()
        for dst in sorted(frontier):  # sorted -> deterministic trace data order
            for src in dst_to_src_dict[dst]:
                xe += [float(coords[src, 0]), float(coords[dst, 0]), None]
                ye += [float(coords[src, 1]), float(coords[dst, 1]), None]
                ze += [float(coords[src, 2]) + src_z, float(coords[dst, 2]) + dst_z, None]
                next_frontier.add(src)
        frontier = next_frontier

        # Source nodes are drawn only for the bottom scale; every other scale is
        # already drawn as the destination scale of the pair below it.
        if s == 0:
            src_node_plots.append(_node_trace(src_idx, src_z, point_size[s], 'src_nodes'))
        dst_node_plots.append(_node_trace(dst_idx, dst_z, point_size[s + 1], 'dst_nodes'))

        edge_plots.append(go.Scatter3d(
            x=xe,
            y=ye,
            z=ze,
            mode='lines',
            line=dict(color=f'rgba(255,0,0, {0.2 ** (top_scale - s)})', width=1),
            hoverinfo='none',
        ))

    return src_node_plots + dst_node_plots + edge_plots

In [None]:
# Inspect the per-scale node index boundaries. The top scale spans
# [node_scale_slice[-2], node_scale_slice[-1]); the target_dst_idx passed to
# draw_recursive_graph in the next cell must be >= node_scale_slice[-2].
node_scale_slice

In [None]:
# Visualise everything reachable downward from one top-scale node.
TARGET_DST_IDX = 5050            # must lie in the top-scale range shown by the cell above
SCALE_Z_DISP = 40.               # vertical gap between stacked scales
POINT_SIZE = [1, 1.5, 2, 2.5, 3.]  # marker size per scale, indexed by scale number

traces = draw_recursive_graph(
    node_coord=node_coord,
    dst_to_src_dict=dst_to_src_dict,
    scale_slice=node_scale_slice,
    target_dst_idx=TARGET_DST_IDX,
    scale_z_disp=SCALE_Z_DISP,
    point_size=POINT_SIZE,
)

In [None]:
# Assemble all traces into one large 3D figure and display it as the cell output.
fig = go.Figure(data=traces)
fig.update_layout(width=1600, height=1200)
fig