In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "2"
import torch
import glob
import math
import numpy as np
from medpy.io import load
import pyvista
import matplotlib
import matplotlib.pyplot as plt
from dataset_vessel3d import build_vessel_data
from monai.data import DataLoader
from torch import nn
from skimage.morphology import skeletonize_3d
from skimage.measure import marching_cubes_lewiner
from scipy.sparse import csr_matrix
from models.detr_transformer_3D import build_detr_transformer
from losses import SetCriterion
from models import build_model
import networkx as nx
# %matplotlib widget
import yaml
import sys
sys.path.append("..")
import json
from models.position_encoding import PositionalEncoding3D
from torch.utils.tensorboard import SummaryWriter
from scipy import ndimage
import open3d as o3d
from inference import relation_infer
from scipy.sparse.csgraph import connected_components
from utils import save_input, save_output
from utils import Bresenham3D
import pyvista as pv
import pdb
%load_ext autoreload
%autoreload 2

In [4]:
def plot_test_sample(image, points, edges):
    meshes = []
    graphs = []
    image = image[:,54:-54,54:-54]
    porder_points = [
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [0, 1, 1],
        [1, 1, 1],
    ]
    border_edges = [
        [0, 1],
        [0, 2],
        [1, 3],
        [2, 3],
        [4, 5],
        [4, 6],
        [5, 7],
        [6, 7],
        [0, 4],
        [1, 5],
        [2, 6],
        [3, 7],
    ]

    # edges = np.concatenate((np.int32(2*np.ones((edges.shape[0],1))), edges), 1)
    # gt_graph = pyvista.PolyData(points)
    # gt_graph.lines = edges.flatten()
    ref_color = [[1, 0, 0] for i in range(len(edges))]
    ref_line_set = o3d.geometry.LineSet(
                                points=o3d.utility.Vector3dVector(points + np.array([0.6, 0, 0])),
                                lines=o3d.utility.Vector2iVector(edges),
                            )
    ref_line_set.colors = o3d.utility.Vector3dVector(ref_color)
    graphs.append(ref_line_set)
    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(points + np.array([0.6, 0, 0]))
    point_cloud.paint_uniform_color([0, 1, 0])
    graphs.append(point_cloud)
    
    ref_color = [[0.2, 0.2, 0.2] for i in range(len(border_edges))]
    ref_line_set = o3d.geometry.LineSet(
                                points=o3d.utility.Vector3dVector(porder_points + np.array([0.6, 0, 0])),
                                lines=o3d.utility.Vector2iVector(border_edges),
                            )
    ref_line_set.colors = o3d.utility.Vector3dVector(ref_color)
    graphs.append(ref_line_set)
    
    verts, faces, norms, vals = marching_cubes_lewiner(image>0.0, level=0)
    verts = verts/np.array(image.shape) 

    mesh = np.concatenate((faces[:,:2], faces[:,1:]), axis=0)
    adjucency = np.zeros((verts.shape[0], verts.shape[0]))
    
    for e in mesh:
        adjucency[e[0], e[1]] =1.0
        adjucency[e[1], e[0]] =1.0
    
    adjucency = np.triu(adjucency)
    mesh = np.array(np.where(np.triu(adjucency)>0)).T
    # mesh = np.concatenate((np.int32(2*np.ones((mesh.shape[0],1))), mesh), 1)
    # gt_mesh = pyvista.PolyData(verts)
    # gt_mesh.lines = mesh.flatten()
    pred_color = [[0, 0, 1] for i in range(len(mesh))]
    pred_line_set = o3d.geometry.LineSet(
                               points=o3d.utility.Vector3dVector(verts- np.array([0.6, 0, 0])),
                               lines=o3d.utility.Vector2iVector(mesh),
                           )
    
    pred_line_set.colors = o3d.utility.Vector3dVector(pred_color)
    meshes.append(pred_line_set)
    
    pred_color = [[0.2, 0.2, 0.2] for i in range(len(border_edges))]
    pred_line_set = o3d.geometry.LineSet(
                               points=o3d.utility.Vector3dVector((porder_points - np.array([0.6, 0, 0]))),
                               lines=o3d.utility.Vector2iVector(border_edges),
                           )
    
    pred_line_set.colors = o3d.utility.Vector3dVector(pred_color)
    meshes.append(pred_line_set)
    
    
    o3d.visualization.draw_geometries(meshes+graphs)
    
    # plotter = pv.Plotter(shape=(1, 2))

    # plotter.subplot(0, 0)
    # plotter.add_mesh(gt_mesh, show_edges=True, color="b")

    # plotter.subplot(0, 1)
    # plotter.add_mesh(gt_graph, show_edges=True, color="r")

    # Display the window
    # plotter.show()
    

def plot_val_rel_sample(image, points1, edges1, points2, edges2, attn_map=None, relative_coords=True):
    ref_line_sets = []
    pred_line_sets = []
    porder_points = [
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [1, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [0, 1, 1],
        [1, 1, 1],
    ]
    border_edges = [
        [0, 1],
        [0, 2],
        [1, 3],
        [2, 3],
        [4, 5],
        [4, 6],
        [5, 7],
        [6, 7],
        [0, 4],
        [1, 5],
        [2, 6],
        [3, 7],
    ]

    # edges = np.concatenate((np.int32(2*np.ones((edges.shape[0],1))), edges), 1)
    # gt_graph = pyvista.PolyData(points)
    # gt_graph.lines = edges.flatten()
    ref_color = [[1, 0, 0] for i in range(len(edges1))]
    ref_line_set = o3d.geometry.LineSet(
                                points=o3d.utility.Vector3dVector(points1),
                                lines=o3d.utility.Vector2iVector(edges1),
                            )
    ref_line_set.colors = o3d.utility.Vector3dVector(ref_color)
    ref_line_sets.append(ref_line_set)
    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(points1)
    point_cloud.paint_uniform_color([0, 1, 0])
    ref_line_sets.append(point_cloud)

    ref_color = [[0.2, 0.2, 0.2] for i in range(len(border_edges))]
    ref_line_set = o3d.geometry.LineSet(
                                points=o3d.utility.Vector3dVector(porder_points),
                                lines=o3d.utility.Vector2iVector(border_edges),
                            )
    ref_line_set.colors = o3d.utility.Vector3dVector(ref_color)
    ref_line_sets.append(ref_line_set)
    
    
    ref_color = [[1, 0, 0] for i in range(len(edges2))]
    ref_line_set = o3d.geometry.LineSet(
                                points=o3d.utility.Vector3dVector(points2 + np.array([1.1, 0, 0])),
                                lines=o3d.utility.Vector2iVector(edges2),
                            )
    ref_line_set.colors = o3d.utility.Vector3dVector(ref_color)
    ref_line_sets.append(ref_line_set)
    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(points2 + np.array([1.1, 0, 0]))
    point_cloud.paint_uniform_color([0, 1, 0])
    ref_line_sets.append(point_cloud)

    ref_color = [[0.2, 0.2, 0.2] for i in range(len(border_edges))]
    ref_line_set = o3d.geometry.LineSet(
                                points=o3d.utility.Vector3dVector(porder_points + np.array([1.1, 0, 0])),
                                lines=o3d.utility.Vector2iVector(border_edges),
                            )
    ref_line_set.colors = o3d.utility.Vector3dVector(ref_color)
    ref_line_sets.append(ref_line_set)

    verts, faces, norms, vals = marching_cubes_lewiner(image>0.0, level=0)
    verts = verts/np.array(image.shape) 

    mesh = np.concatenate((faces[:,:2], faces[:,1:]), axis=0)
    adjucency = np.zeros((verts.shape[0], verts.shape[0]))

    for e in mesh:
        adjucency[e[0], e[1]] =1.0
        adjucency[e[1], e[0]] =1.0

    adjucency = np.triu(adjucency)
    mesh = np.array(np.where(np.triu(adjucency)>0)).T
    # mesh = np.concatenate((np.int32(2*np.ones((mesh.shape[0],1))), mesh), 1)
    # gt_mesh = pyvista.PolyData(verts)
    # gt_mesh.lines = mesh.flatten()
    pred_color = [[0, 0, 1] for i in range(len(mesh))]
    pred_line_set = o3d.geometry.LineSet(
                               points=o3d.utility.Vector3dVector(verts- np.array([1.1, 0, 0])),
                               lines=o3d.utility.Vector2iVector(mesh),
                           )

    pred_line_set.colors = o3d.utility.Vector3dVector(pred_color)
    pred_line_sets.append(pred_line_set)

    pred_color = [[0.2, 0.2, 0.2] for i in range(len(border_edges))]
    pred_line_set = o3d.geometry.LineSet(
                               points=o3d.utility.Vector3dVector((porder_points - np.array([1.1, 0, 0]))),
                               lines=o3d.utility.Vector2iVector(border_edges),
                           )

    pred_line_set.colors = o3d.utility.Vector3dVector(pred_color)
    pred_line_sets.append(pred_line_set)

    o3d.visualization.draw_geometries(ref_line_sets+pred_line_sets)


In [5]:
class obj:
    def __init__(self, dict1):
        self.__dict__.update(dict1)
        
def dict2obj(dict1):
    return json.loads(json.dumps(dict1), object_hook=obj)

In [7]:
config_file = "./configs/synth_3D.yaml"

with open(config_file) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

config = dict2obj(config)

In [8]:
data_dir = './data/vessel_data/test_data/'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
dataset = build_vessel_data(config)

In [10]:
def image_graph_collate(batch):
    images = torch.cat([item_ for item in batch for item_ in item[0]], 0).contiguous()
    segs = torch.cat([item_ for item in batch for item_ in item[1]], 0).contiguous()
    points = [item_ for item in batch for item_ in item[2]]
    edges = [item_ for item in batch for item_ in item[3]]
    return [images, segs, points, edges]

In [11]:
mn_dataset_loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=lambda x: tuple([y_.to(device) for y_ in x_] for x_ in image_graph_collate(x)))

In [18]:
for images, seg, points, edges in mn_dataset_loader:
    images = torch.stack(images)
    seg = torch.stack(seg)
    print(len(points), images.shape, points[0].shape, len(edges))
    plot_test_sample(seg[0].squeeze().cpu().numpy(), points[0].cpu().numpy(), edges[0].cpu().numpy())

2 torch.Size([2, 1, 64, 64, 64]) torch.Size([27, 3]) 2


ValueError: Input array must be at least 2x2x2.

## Debug Model

In [16]:
model = build_model(config)
model = model.to(device)

In [17]:
ckpt_path = './trained_weights/3D_vessel_checkpoint_epoch=100.pt'
checkpoint = torch.load(ckpt_path, map_location='cpu')
missing_keys, unexpected_keys = model.load_state_dict(checkpoint['net'], strict=False)
unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
if len(missing_keys) > 0:
    print('Missing Keys: {}'.format(missing_keys))
if len(unexpected_keys) > 0:
    print('Unexpected Keys: {}'.format(unexpected_keys))

In [19]:
model.eval()
iteration = 0
for images, seg, points, edges in mn_dataset_loader:
    images = torch.stack(images)-0.5
    seg = torch.stack(seg)

    h, out = model(images)
    out = relation_infer(h.detach(), out, model, config.MODEL.DECODER.OBJ_TOKEN, config.MODEL.DECODER.RLN_TOKEN)

    plot_val_rel_sample(seg[0].squeeze().cpu().numpy(), points[0].cpu().numpy(), edges[0].cpu().numpy(), out['pred_nodes'][0], out['pred_rels'][0])
    
    iteration = iteration+1
    print('Iteration:',iteration)
    if iteration>10:
        break

Iteration: 1
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
Iteration: 10
Iteration: 11
