# Planar visualization of arteries: demonstration

This notebook provides example code for the article:

*Untangling Vascular Trees for Surgery and Interventional Radiology*, G. Houry, T. Boeken, S. Allassonnière and J. Feydy, MICCAI 2025.

## Imports and loading utility

In [18]:
import skshapes as sks

In [19]:
from backend.viz import plot_graph, plot_branches
from backend.graph_processing import *

In [20]:
from backend.tree_creation import compute_edge_lenghts, compute_downstream, reorder_branches, compute_reference_angles

from backend.recursive_layout.recursive_layout import recursive_layout
from backend.tree_creation import initialize_vascular_tree
from backend.force_directed.force_directed import compute_force

In [21]:
import torch
import scipy

import nibabel as nib
import numpy as np

In [22]:
import pyvista as pv

# Render PyVista plots as static images in the notebook
# (the other available option here would be the interactive "trame" backend).
pv.set_jupyter_backend("static")
# Draw color bars vertically.
pv.global_theme.colorbar_orientation = "vertical"

In [23]:
def load_nii(path: str, normalize: bool = True, crop_multiple: int = 8) -> np.ndarray:
    """Load a NIfTI volume as a float32 NumPy array.

    Parameters
    ----------
    path : str
        Path to a file readable by ``nibabel.load``.
    normalize : bool, default True
        If True, divide the volume by its maximum value (when that maximum is
        positive) and crop the first three axes down to a multiple of
        ``crop_multiple`` voxels.
    crop_multiple : int, default 8
        Size multiple the first three axes are cropped to when ``normalize``
        is True.

    Returns
    -------
    np.ndarray
        The (possibly normalized and cropped) float32 volume.
    """
    image = nib.load(path).get_fdata().astype(np.float32)

    if normalize:
        max_value = image.max()
        # An all-zero volume would turn into NaNs (0 / 0) and an all-negative
        # one would flip sign, so only rescale by a positive maximum.
        if max_value > 0:
            image = image / max_value
        # NOTE(review): cropping only happens when ``normalize`` is True;
        # presumably downstream filters need sizes divisible by 8 — confirm.
        image = image[tuple(slice(0, crop_multiple * (dim // crop_multiple)) for dim in image.shape[:3])]

    return image

## Preprocessing

In this step, we transform a raw CT scan into a tree structure that our algorithm can process.

In [24]:
# Input volume, and the torch device used for image processing (GPU when available).
data_path = "normalized_001.nii.gz"
device = "cpu" if not torch.cuda.is_available() else "cuda"

In [40]:
def cube_offsets(radius, device=None):
    """Return every integer 3D offset in the cube ``[-radius, radius]^3``.

    For an integer ``radius`` the result has shape ``((2 * radius + 1) ** 3, 3)``.
    Rows follow the flattened ``torch.meshgrid(..., indexing='xy')`` order:
    the second column varies slowest and the third column fastest.
    """
    side = torch.arange(-radius, radius + 1, device=device)
    # cartesian_prod enumerates in 'ij' order; swapping the first two columns
    # reproduces the 'xy' meshgrid ordering.
    return torch.cartesian_prod(side, side, side)[:, [1, 0, 2]]

def preprocessing(data_path):
    """Turn a CT volume into a rooted vascular tree graph.

    Steps:
    - Load the volume.
    - Segment the vessels (hysteresis threshold, signed distance transform,
      Frangi filter) and skeletonize them.
    - Convert the skeleton to a graph and estimate a radius and an intensity
      per vertex.
    - Drop low-intensity vertices and reconnect the relevant components to
      the root.
    - Smooth the radii and remove cycles.

    The image is placed on the module-level ``device``.

    Parameters
    ----------
    data_path : str
        Path to a NIfTI file, loaded with ``load_nii``.

    Returns
    -------
    tree_adjmatrix
        Adjacency matrix produced by ``remove_cycles``.
    tree_features : dict
        ``{'pos': ..., 'radius': ...}`` for the vertices kept by
        ``keep_largest_component``.
    root
        Index of the root vertex, as returned by ``keep_largest_component``.
    """
    # Load the original image
    raw_image = sks.Image(load_nii(data_path), dtype=torch.float32, device=device)

    # Segment and skeletonize the vessels of the image using simple filters
    hysteresis_mask = sks.images.filters.hysteresis_threshold(raw_image, low=17 / 255, high=30 / 255)
    signed_radii = sks.images.filters.signed_distance_transform(hysteresis_mask, dilation=3)
    vessel_mask = sks.images.filters.frangi_filter(signed_radii, alpha=0.1, beta=0.5, gamma=0.9, smoothing=1)
    skeleton = sks.images.filters.skeletonize(vessel_mask)

    # Extract the adjacency matrix of the skeleton and the 3D position of the vertices
    pos, adjmatrix = sks.images.filters.skeleton_to_graph(skeleton)

    # Smooth the positions to avoid voxelization artifacts
    pos = laplacian_smoothing(pos.astype(np.float32), adjmatrix.astype(int), iters=10)

    # Estimate the radius at each (rounded) position robustly: take the max of
    # the signed distance over the 7x7x7 cube of offsets around it (radius 3).
    intpos = pos.round().astype(np.int64)
    mask = sks.Mask(indices=intpos, shape=signed_radii.shape, device=signed_radii.device)
    max_radii = signed_radii.masked_convolution(mask=mask, offsets=cube_offsets(radius=3, device=skeleton.device), kernel='max')
    radius = max_radii.values_at(query_indices=torch.tensor(intpos, device=max_radii.device)).clip(min=0.1)

    # Get the voxel intensity of the initial image at the new positions
    intensity = raw_image.values_at(query_indices=intpos)

    # Scale the radii of sub-voxel vessels by their relative intensity, to refine
    # the radius estimate for small vessels. Guard the cases where no radius is
    # below 1 (``.max()`` raises on an empty tensor) or every selected intensity
    # is zero (division by zero).
    small_radii = radius < 1
    if small_radii.any():
        small_intensity = intensity[small_radii]
        peak_intensity = small_intensity.max()
        if peak_intensity > 0:
            radius[small_radii] = radius[small_radii] * small_intensity / peak_intensity

    # Remove the nodes whose intensity is too small (to reduce sources of noise)
    intensity_np = intensity.cpu().numpy()
    keep = intensity_np > 0.05
    adjmatrix = adjmatrix[keep, :][:, keep]
    features = {'pos': pos[keep], 'radius': radius.cpu().numpy()[keep], 'intensity': intensity_np[keep]}

    # Choose the root of the tree
    root = graph_root(adjmatrix, features)

    # Merge the relevant disconnected components to the root
    adjmatrix = merge_components_to_root(adjmatrix, features, root)

    # Smooth the radii along the vessels
    features["radius"] = smooth_val(features["radius"], adjmatrix, iters=5)

    # Remove the topological artifacts of the graph
    adjmatrix, tree_features, root = keep_largest_component(adjmatrix, features, root)
    tree_adjmatrix = remove_cycles(adjmatrix, tree_features, root)

    tree_features = {'pos': tree_features['pos'], 'radius': tree_features['radius']}
    return tree_adjmatrix, tree_features, root

In [41]:
# Run the full preprocessing pipeline on the input volume.
(tree_adjmatrix, tree_features, root) = preprocessing(data_path)

In [None]:
# Show the preprocessed tree in 3D, colored and sized by vessel radius.
plot_graph(
    pos=tree_features['pos'],
    adjmatrix=tree_adjmatrix,
    radius=tree_features['radius'],
    colors=tree_features['radius'],
    cmap="Reds",
).show()

## Embedding algorithm

The embedding algorithm takes as input:

- ``tree_adjmatrix``, a ``scipy.sparse.csr_matrix`` of bools of shape ``(n_vertices, n_vertices)`` encoding the input directed tree;
- ``features``, a dict with keys ``pos`` and ``radius``:
    - ``features['pos']`` is a ``np.ndarray`` of shape ``(n_vertices, 3)`` containing the 3D positions of the tree vertices;
    - ``features['radius']`` is a ``np.ndarray`` of shape ``(n_vertices,)`` containing the radius of the vessels at each vertex.
- ``root``, an integer between ``0`` and ``n_vertices - 1`` (inclusive) specifying the index of the tree root.

In [43]:
# Inspect the sparse adjacency matrix of the tree.
print(str(tree_adjmatrix))

<Compressed Sparse Row sparse matrix of dtype 'bool'
	with 21162 stored elements and shape (21163, 21163)>
  Coords	Values
  (1, 11)	True
  (2, 13)	True
  (3, 2)	True
  (4, 3)	True
  (5, 4)	True
  (6, 5)	True
  (7, 6)	True
  (8, 14)	True
  (9, 0)	True
  (10, 19)	True
  (11, 10)	True
  (12, 1)	True
  (13, 12)	True
  (14, 7)	True
  (15, 8)	True
  (16, 9)	True
  (18, 29)	True
  (19, 18)	True
  (20, 15)	True
  (21, 20)	True
  (22, 21)	True
  (23, 22)	True
  (24, 17)	True
  (25, 16)	True
  (27, 43)	True
  :	:
  (21136, 21141)	True
  (21137, 21143)	True
  (21138, 21144)	True
  (21139, 21133)	True
  (21140, 21131)	True
  (21141, 21146)	True
  (21141, 21147)	True
  (21143, 21148)	True
  (21144, 21149)	True
  (21145, 21140)	True
  (21146, 21145)	True
  (21146, 21151)	True
  (21147, 21142)	True
  (21148, 21152)	True
  (21149, 21150)	True
  (21150, 21154)	True
  (21151, 21156)	True
  (21152, 21153)	True
  (21153, 21157)	True
  (21154, 21155)	True
  (21157, 21158)	True
  (21158, 21159)	True
  (211

In [47]:
# Inspect which features are available, and the shapes of positions and radii.
print(tree_features.keys())
print(*(tree_features[key].shape for key in ("pos", "radius")))

dict_keys(['pos', 'radius'])
(21163, 3) (21163,)


In [48]:
# Index of the tree root.
print(str(root))

12957


In [49]:
def embed_vascular_tree(tree_adjmatrix, features, root):
    """Compute a 2D planar embedding of a rooted 3D vascular tree.

    Parameters
    ----------
    tree_adjmatrix : scipy.sparse.csr_matrix
        Boolean adjacency matrix of the directed tree, shape
        ``(n_vertices, n_vertices)``.
    features : dict
        ``'pos'``: array of shape ``(n_vertices, 3)``; ``'radius'``: array of
        shape ``(n_vertices,)`` (a ``(n_vertices, 1)`` array is also accepted).
        The caller's dict is not modified.
    root : int
        Index of the tree root.

    Returns
    -------
    tree
        The tree object built by ``initialize_vascular_tree``, after the
        recursive layout and force-directed refinement. The 2D embedding is
        stored in ``tree.features["emb"]``.
    """
    # Build a new dict instead of mutating the caller's one: the former in-place
    # ``features["radius"] = features["radius"][:, None]`` added an extra axis
    # on every call, so re-running the cell corrupted ``tree_features``.
    radius = np.asarray(features["radius"])
    if radius.ndim == 1:
        radius = radius[:, None]
    features = {**features, "radius": radius}

    # Initialize the tree data structure
    tree = initialize_vascular_tree(tree_adjmatrix, features, root, pruning=20, internal_pruning=3)

    # Compute edge lengths and the downstream features used below
    # ("downstream_barycenter", "downstream_volume")
    tree = compute_edge_lenghts(tree)
    tree = compute_downstream(tree)

    # Reorder branches, scored by minus the first coordinate of their downstream
    # barycenter, then compute the target angles (the true angular curvatures).
    tree = reorder_branches(tree, scores=-tree.features["downstream_barycenter"][:, 0])
    tree = compute_reference_angles(tree, priority=tree.features["downstream_volume"], smoothing=15)

    # The 'importance' feature determines the space given to each subtessellation at each recursive step
    tree.coarse_features["importance"] = tree.features["downstream_volume"][tree.bifurcations().flatten()]

    # Initialize features for the recursive layout
    # NOTE(review): "bounds" has tree.coarse_size rows but is stored in
    # tree.features rather than tree.coarse_features — confirm this is intended.
    tree.features["bounds"] = np.zeros(shape=(tree.coarse_size, 2))
    tree.features["angle"] = np.zeros(shape=(tree.size, 1))
    tree.features["emb"] = np.zeros(shape=(tree.size, 2))

    # Recursive layout (presumably fills tree.features["emb"]); the returned
    # anchors are not needed here.
    tree, _ = recursive_layout(tree, smoothing=10)
    # Force-directed refinement of the layout
    tree = compute_force(tree, iters=300, alpha=25, beta=5e-5, gamma=0.4, sigma=200, mu=1, momentum=0., clip=0.01)

    return tree

In [32]:
# Compute the planar layout of the preprocessed tree.
tree = embed_vascular_tree(tree_adjmatrix=tree_adjmatrix, features=tree_features, root=root)

The final 2D embeddings are stored in ``tree.features['emb']``, and the initial 3D positions in ``tree.features['pos']``:

In [None]:
# Render the pruned tree at its original 3D positions.
pl = pv.Plotter(border=None, window_size=(1700, 2000))
plot_branches(
    tree.branches(),
    tree.features["pos"],
    pl=pl,
    radius=1.2 * tree.features["radius"] ** 0.9,
    colors=tree.features["radius"],
    cmap="Reds",
    clim=(0, 5),
    show_scalar_bar=False,
)
pl.show()

In [None]:
# Render the same branches at their planar (2D) embedding.
pl = pv.Plotter(border=None, window_size=(4300, 3600))
plot_branches(
    tree.branches(),
    tree.features["emb"],
    pl=pl,
    radius=1.2 * tree.features["radius"] ** 0.9,
    colors=tree.features["radius"],
    cmap="Reds",
    clim=(0, 5),
    show_scalar_bar=False,
)

# Zoom in slightly before displaying.
pl.camera.zoom(1.05)
pl.show()