In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
from skimage import measure, segmentation, feature
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from skimage.morphology import skeletonize_3d, binary_dilation
from skimage import filters, morphology
from scipy import signal
from skimage.filters import frangi, sato
from skimage.draw import line_nd
from PIL import Image
import pickle
from queue import PriorityQueue

In [2]:
# Identifier of the tree being processed. It selects the input/output directory
# under `source_dir` and the per-tree root-selection settings used further down.
TREE_NAME = 'P01'

## Loading skeleton, skeleton thickness and reconstruction

In [3]:
# All inputs are read from ./data/<TREE_NAME>/ as .npy arrays.
source_dir = './data/'
# Skeleton volume; non-zero voxels are treated as skeleton voxels below.
skeleton = np.load(source_dir + TREE_NAME + '/skeleton.npy')
# Per-voxel thickness values, indexed with the same coordinates as `skeleton`.
skeleton_thickness = np.load(source_dir + TREE_NAME + '/skeleton-thickness.npy')
# Only the shape of this volume is used later (stored on the DAG object).
reconstruction = np.load(source_dir + TREE_NAME + '/reconstruction.npy')

FileNotFoundError: [Errno 2] No such file or directory: './data/P01/skeleton.npy'

## Utility visualisation functions

In [4]:
def visualize_addition(base, base_with_addition):
    """Display `base` together with the voxels that only `base_with_addition` has.

    Both volumes are binarised with `> 0`. Voxels of `base` get value 1 and
    voxels present only in `base_with_addition` get value 4 before the result
    is handed to `ColorMapVisualizer`.
    """
    base_bin = (base > 0).astype(np.uint8)
    only_added = (base_with_addition > 0).astype(np.uint8)
    # Anything already present in the base is not an "addition".
    only_added[base_bin == 1] = 0
    combined = base_bin + only_added * 4
    ColorMapVisualizer(combined).visualize()
    
def visualize_lsd(lsd_mask):
    """Show `lsd_mask` (cast to uint8) with `ColorMapVisualizer`."""
    volume = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(volume).visualize()
    
def visualize_gradient(lsd_mask):
    """Show `lsd_mask` (cast to uint8) with `ColorMapVisualizer` in gradient mode."""
    volume = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(volume).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    """Binarise `mask` with `> 0` and show it with a binary `VolumeVisualizer`."""
    binary_mask = (mask > 0).astype(np.uint8)
    VolumeVisualizer(binary_mask, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Binarise `mask` with `> 0`, scale it to 0/255 and show it non-binary."""
    scaled_mask = (mask > 0).astype(np.uint8) * 255
    VolumeVisualizer(scaled_mask, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonise `mask` and display the result.

    Args:
        mask: volume; voxels `> 0` are skeletonised with `skeletonize_3d`.
        visualize_mask: when true, show the skeleton overlaid on the mask
            (skeleton = 4, remaining mask = 3); when false, show the bare
            skeleton instead.
        visualize_both_versions: when true, show both views regardless of
            `visualize_mask`.
    """
    skeleton_only = skeletonize_3d((mask > 0).astype(np.uint8))

    show_plain = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_plain:
        VolumeVisualizer(skeleton_only, binary=True).visualize()

    if show_overlay:
        skeleton_layer = skeleton_only.astype(np.uint8) * 4
        mask_layer = (mask > 0).astype(np.uint8) * 3
        # Skeleton voxels take precedence over the mask colour.
        mask_layer[skeleton_layer != 0] = 0
        ColorMapVisualizer(skeleton_layer + mask_layer).visualize()

def visualize_ultimate(lsd, base_mask):
    """Run every visualisation for `lsd` in turn.

    Shows, in order: the colour-mapped volume, the non-binary mask view, the
    voxels `lsd` adds on top of `base_mask`, and the skeleton overlaid on `lsd`.
    """
    for show in (visualize_lsd, visualize_mask_non_bin):
        show(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

## Resolving nodes mask

### Resolving leaves mask

In [5]:
def trim_skeleton(skeleton):
    """Return a copy of a 3D skeleton without its end voxels.

    A skeleton voxel is kept only when more than one of its 26 neighbours is
    also a skeleton voxel (`> 0`). Voxels with zero or one neighbour are dropped.

    Positions outside the volume count as empty. The skeleton is zero-padded
    before neighbours are looked up, so voxels on the volume border neither wrap
    around to the opposite face nor raise `IndexError`.

    Args:
        skeleton: 3D array; non-zero entries are skeleton voxels.

    Returns:
        `np.uint8` array of the same shape with 1 at every kept voxel.
    """
    skeleton = np.asarray(skeleton)
    padded = np.pad(skeleton > 0, 1)
    # +1 translates original coordinates into the padded volume.
    centers = np.argwhere(skeleton) + 1

    # Count neighbours only at skeleton voxels instead of over the whole volume.
    neighbours_count = np.zeros(len(centers), dtype=np.uint8)
    for offset in np.ndindex(3, 3, 3):
        dx, dy, dz = (o - 1 for o in offset)
        if dx == dy == dz == 0:
            continue
        neighbours_count += padded[centers[:, 0] + dx,
                                   centers[:, 1] + dy,
                                   centers[:, 2] + dz]

    new_skeleton = np.zeros(skeleton.shape, dtype=np.uint8)
    kept = centers[neighbours_count > 1] - 1
    new_skeleton[tuple(kept.T)] = 1
    return new_skeleton


def mark_leaves(skeleton):
    """Return the skeleton voxels that `trim_skeleton` removes.

    The result is `skeleton - trim_skeleton(skeleton)`, i.e. what is left of
    the skeleton once every voxel with more than one neighbour is subtracted.
    """
    return skeleton - trim_skeleton(skeleton)

In [6]:
%%time
# Skeleton end voxels (those with at most one skeleton neighbour).
leaves_mask = mark_leaves(skeleton)

Wall time: 23.3 s


### Resolving bifurcations mask

In [7]:
def mark_bifurcation_regions(skeleton):
    """Mark skeleton voxels that have more than two skeleton neighbours.

    Neighbours are counted in the 26-neighbourhood on a zero-padded copy, so
    border voxels are handled like any other. Counting is done only at the
    skeleton voxels with vectorised look-ups.

    Args:
        skeleton: 3D array; entries `> 0` are skeleton voxels.

    Returns:
        `np.uint8` array of the same shape with 1 at every voxel that has
        more than two neighbours and 0 elsewhere.
    """
    occupied = np.asarray(skeleton) > 0
    padded = np.pad(occupied, 1)
    # +1 translates original coordinates into the padded volume.
    centers = np.argwhere(occupied) + 1

    neighbours_count = np.zeros(len(centers), dtype=np.uint8)
    for offset in np.ndindex(3, 3, 3):
        dx, dy, dz = (o - 1 for o in offset)
        if dx == dy == dz == 0:
            continue
        neighbours_count += padded[centers[:, 0] + dx,
                                   centers[:, 1] + dy,
                                   centers[:, 2] + dz]

    bifurcations = np.zeros(occupied.shape, dtype=np.uint8)
    hits = centers[neighbours_count > 2] - 1
    bifurcations[tuple(hits.T)] = 1
    return bifurcations


def mark_nodes(skeleton):
    """Return the sum of the bifurcation mask and the leaves mask of `skeleton`."""
    return mark_bifurcation_regions(skeleton) + mark_leaves(skeleton)

In [8]:
%%time
# Voxels with more than two skeleton neighbours.
bifurcations_mask = mark_bifurcation_regions(skeleton)
# Graph nodes are the union of bifurcation and leaf voxels, binarised to 0/1.
nodes_mask = ((bifurcations_mask + leaves_mask) > 0).astype(np.uint8)

Wall time: 15.3 s


## Constructing graph

In [9]:
class Node:
    """Graph vertex identified by its `coords`.

    Attributes are `coords` (also used for hashing and for telling the two
    ends of an edge apart), `edges` (list of attached edges) and `data`
    (free-form metadata reachable through `node[key]`).
    """

    def __init__(self, coords):
        self.coords = coords
        self.edges = []
        self.data = {}

    def add_edge(self, edge):
        """Attach `edge` to this node."""
        self.edges.append(edge)

    def get_neighbours(self):
        """Return, for every attached edge, the endpoint whose coords differ from ours.

        If `node_a` has the same coords as this node, `node_b` is returned.
        """
        neighbours = []
        for edge in self.edges:
            if edge.node_a.coords != self.coords:
                neighbours.append(edge.node_a)
            else:
                neighbours.append(edge.node_b)
        return neighbours

    def copy_without_edges(self):
        """Return a new node with the same coords and no edges.

        The copy refers to the very same `data` dict as this node (no copy is
        made), so metadata written through either one is visible on both.
        """
        duplicate = Node(self.coords)
        duplicate.data = self.data
        return duplicate

    def __setitem__(self, key, value):
        self.data[key] = value

    def __getitem__(self, key):
        return self.data[key]

    def __hash__(self):
        return hash(self.coords)

    def __repr__(self):
        return 'Node ' + str(self.coords)
        
        
class Edge:
    """Connection between two nodes with free-form metadata.

    `node_a` and `node_b` are the endpoints; `data` holds metadata reachable
    through `edge[key]`.
    """

    def __init__(self, node_a, node_b):
        self.data = {}
        self.node_a, self.node_b = node_a, node_b

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value

    def __repr__(self):
        return f'Edge {self.node_a.coords} -> {self.node_b.coords}'

In [10]:
def construct_graph(skeleton, nodes_mask, skeleton_thickness):
    """Build graph edges from a skeleton and its node mask.

    Every connected region of `nodes_mask` becomes one `Node`, with coords and
    initial thickness taken from the region's first voxel. Every connected
    region of `skeleton - nodes_mask` is an edge candidate. A candidate that
    touches exactly two nodes becomes an `Edge`; any other candidate is
    reported and returned separately.

    Args:
        skeleton: 3D skeleton volume.
        nodes_mask: volume of the same shape marking node voxels.
        skeleton_thickness: volume indexed with voxel coordinates, giving the
            thickness stored on each node.

    Returns:
        `(edges, bad_edges)`: a list of `Edge` objects carrying a `'voxels'`
        entry, and a list of voxel-coordinate arrays of the rejected regions.
    """
    # Local import keeps this cell self-contained.
    from collections import deque

    nodes_labels = measure.label(nodes_mask)
    nodes_props = measure.regionprops(nodes_labels)
    print('nodes found (regions on nodes mask):', nodes_labels.max())
    voxel_to_node = dict()

    for props in nodes_props:
        if props.label < 1:
            continue

        first_voxel = tuple(props.coords[0])
        node = Node(first_voxel)
        node['voxels'] = props.coords
        node['thickness'] = skeleton_thickness[first_voxel]

        for c in props.coords:
            voxel_to_node[tuple(c)] = node

    edges_mask = skeleton - nodes_mask
    edges_labels = measure.label(edges_mask > 0)
    print('edges found:', edges_labels.max())

    # `np.bool` was removed from NumPy; the builtin `bool` is the same dtype.
    visited = np.zeros(skeleton.shape, dtype=bool)
    size_x, size_y, size_z = skeleton.shape
    neighbour_offsets = [
        (dx, dy, dz)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dz in (-1, 0, 1)
        if (dx, dy, dz) != (0, 0, 0)
    ]

    def find_touching_nodes(source_voxel):
        """Flood-fill one edge region and collect the nodes adjacent to it."""
        touching_nodes = set()
        visited[tuple(source_voxel)] = True
        queue = deque([source_voxel])
        while queue:
            x, y, z = queue.popleft()

            for dx, dy, dz in neighbour_offsets:
                neighbour_x = x + dx
                neighbour_y = y + dy
                neighbour_z = z + dz

                # Skip positions outside the volume: negative indices would
                # wrap around and too-large ones would raise IndexError.
                if not (0 <= neighbour_x < size_x
                        and 0 <= neighbour_y < size_y
                        and 0 <= neighbour_z < size_z):
                    continue

                if visited[neighbour_x, neighbour_y, neighbour_z]:
                    continue

                potential_node = voxel_to_node.get((neighbour_x, neighbour_y, neighbour_z))
                if potential_node is not None:
                    touching_nodes.add(potential_node)

                if edges_mask[neighbour_x, neighbour_y, neighbour_z] == 1:
                    queue.append((neighbour_x, neighbour_y, neighbour_z))
                    visited[neighbour_x, neighbour_y, neighbour_z] = True
        return list(touching_nodes)

    edges_props = measure.regionprops(edges_labels)

    edges = []
    bad_edges = []
    for props in edges_props:
        edge_voxel = props.coords[0]
        touching_nodes = find_touching_nodes(edge_voxel)
        if len(touching_nodes) != 2:
            print(f'bad edge found! touching nodes count: {len(touching_nodes)}')
            bad_edges.append(props.coords)
            continue

        edge = Edge(touching_nodes[0], touching_nodes[1])
        edge['voxels'] = props.coords
        edges.append(edge)
    return edges, bad_edges

In [15]:
%%time
# Build the raw graph. `bad_edges` collects edge regions that do not touch
# exactly two nodes; they are left out of `edges`.
edges, bad_edges = construct_graph(skeleton, nodes_mask, skeleton_thickness)
print("Number of bad edged found:", len(bad_edges))

nodes found (regions on nodes mask): 7717
edges found: 7931
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touching nodes count: 1
bad edge found! touc

## Cleaning DAG

### Finding root node

In [17]:
def convert_to_nodes_list(edges):
    """Rebuild the graph described by `edges` and return its nodes.

    Every endpoint is replaced by an edge-less copy (sharing the original's
    `data`), and every edge by a new `Edge` between the copies that reuses the
    original edge's `data` dict and is registered on both of its endpoints.

    Returns:
        list of the copied nodes, one per distinct endpoint.
    """
    copies = {}
    for edge in edges:
        for endpoint in (edge.node_a, edge.node_b):
            copies[endpoint] = endpoint.copy_without_edges()

    for edge in edges:
        copy_a = copies[edge.node_a]
        copy_b = copies[edge.node_b]
        rebuilt = Edge(copy_a, copy_b)
        rebuilt.data = edge.data
        copy_a.add_edge(rebuilt)
        copy_b.add_edge(rebuilt)

    return list(copies.values())

def find_tree_root_candidates(nodes, root_degree, thichness_tolerance):
    """Return the thickest nodes that have exactly `root_degree` edges.

    Args:
        nodes: nodes exposing `.edges` and a `'thickness'` entry.
        root_degree: required number of edges of a root candidate.
        thichness_tolerance: a node qualifies when its thickness is at least
            the maximum thickness (among nodes of the right degree) minus
            this value.

    Returns:
        Matching nodes in their original order; an empty list when no node
        has the requested degree.
    """
    proper_degree_nodes = [node for node in nodes if len(node.edges) == root_degree]
    if not proper_degree_nodes:
        # Nothing to take the maximum of.
        return []

    root_thickness = max(node['thickness'] for node in proper_degree_nodes)
    # Subtract in float so an unsigned NumPy thickness value cannot wrap
    # around when the tolerance exceeds it.
    threshold = float(root_thickness) - thichness_tolerance
    return [node for node in proper_degree_nodes if node['thickness'] >= threshold]

def visualize_root(root, skeleton, mark_radius=2):
    """Show the skeleton with a cube of value 4 drawn around each voxel of `root`.

    Args:
        root: node whose `'voxels'` are highlighted.
        skeleton: skeleton volume; it is copied, not modified.
        mark_radius: half-width of the cube; the cube spans
            `[c - mark_radius, c + mark_radius]` on every axis and is clipped
            to the volume.
    """
    visualisation = skeleton.astype(np.uint8)
    for v in root['voxels']:
        # Clamp the lower bound at 0: a negative slice start would be read as
        # "from the end" and mark the wrong region or nothing at all. The upper
        # bound is clipped by slicing itself.
        region = tuple(
            slice(max(int(c) - mark_radius, 0), int(c) + mark_radius + 1)
            for c in v
        )
        visualisation[region] = 4
    visualize_lsd(visualisation)

In [19]:
%%time

# Per-tree settings for picking the root node. Trees missing from a dict fall
# back to the default given in the corresponding `.get(...)` call below.

# Number of edges the root node is expected to have.
roots_degrees = {
    'P01': 1,
    'P05': 1,
    'P12': 1,
}

# How far below the maximum thickness a node may be and still be a candidate.
root_thickness_tolerance = {
    'P01': 0,
    'P05': 9,
    'P12': 0,
}

nodes = convert_to_nodes_list(edges)
root_candidates = find_tree_root_candidates(nodes, roots_degrees.get(TREE_NAME, 1), 
                                            root_thickness_tolerance.get(TREE_NAME, 0))
print(f'found {len(root_candidates)} root candidate(s)')

# Index into `root_candidates` of the node to use as the root.
# NOTE(review): these indices depend on the order of `root_candidates` —
# re-check them with the visualisation below whenever the inputs change.
candidates_indices = {
    'P01': 0,
    'P05': 6,
    'P12': 0,
}

root = root_candidates[candidates_indices.get(TREE_NAME, 0)]
visualize_root(root, skeleton, 5) # verify whether the proper node was selected

found 1 root candidate(s)
Wall time: 1min 50s


### Removing cycles (obtaining DAG)

In [20]:
def remove_dag_cycles(root): # TODO: reconsider the minus sign in front of node['thickness']
    """Build a cycle-free copy of the graph reachable from `root`.

    Nodes are taken from a priority queue keyed by `(-thickness, counter)`, so
    thicker nodes come out first and ties are broken by insertion order. The
    first time a node is taken out it is copied and attached to exactly one
    parent; later occurrences of the same node are skipped, which is what
    drops the cycle-closing edges.

    Args:
        root: node of the original graph to start from.

    Returns:
        Copy of `root` whose `edges` lead to copies of the other nodes. New
        edges are registered on the parent only, always as
        `Edge(parent, child)`, and reuse the `data` dict of the original edge.
    """
    # Monotonic tie-breaker so the queue never has to compare Node objects.
    counter = 0
    
    new_root = root.copy_without_edges()
    # coords -> already attached original nodes that neighbour the node at coords
    coords_to_old_parents = {}
    # coords -> copy of the node in the new tree (doubles as the "visited" set)
    coords_to_new_node = { new_root.coords: new_root }
    
    queue = PriorityQueue()
    for node in root.get_neighbours():
        coords_to_old_parents[node.coords] = [root]
        queue.put(((-node['thickness'], counter), node))
        counter += 1
        
    while not queue.empty():
        _, node = queue.get()
        
        # Already attached through an earlier queue entry.
        if coords_to_new_node.get(node.coords) is not None:
            continue
        
        # Among the parents recorded so far, pick the first one with the
        # smallest thickness and the first edge connecting it to `node`.
        parent_candidates = coords_to_old_parents[node.coords]
        proper_parent_thickness = min([p['thickness'] for p in parent_candidates])
        proper_parent = [p for p in parent_candidates if p['thickness'] == proper_parent_thickness][0]
        edge_from_parent = [e for e in proper_parent.edges if e.node_a == node or e.node_b == node][0]
        
        new_node = node.copy_without_edges()
        new_parent = coords_to_new_node[proper_parent.coords]
        new_edge = Edge(new_parent, new_node)
        new_edge.data = edge_from_parent.data
        new_parent.add_edge(new_edge)
        
        coords_to_new_node[new_node.coords] = new_node
        
        # Offer this node as a parent to all of its neighbours; neighbours
        # that are already attached are filtered out when they are dequeued.
        for neighbour in node.get_neighbours():
            parents = coords_to_old_parents.get(neighbour.coords, [])
            coords_to_old_parents[neighbour.coords] = parents + [node]
            queue.put(((-neighbour['thickness'], counter), neighbour))
            counter += 1
            
    return new_root

In [21]:
%%time

# Cycle-free copy of the graph, rooted at the selected root node.
clean_root = remove_dag_cycles(root)

Wall time: 331 ms


### Removing redundant nodes and edges

In [22]:
def merge_edges(a, b, node_a, node_b):
    """Return a new edge `node_a -> node_b` that replaces edges `a` and `b`.

    The new edge gets a copy of `a`'s metadata. Its `'voxels'` are the voxels
    of `a`, then the voxels of the node between the two edges (`b.node_a`),
    then the voxels of `b`.

    Neither `a` nor `b` is modified.
    """
    new_edge = Edge(node_a, node_b)
    # Copy instead of aliasing `a.data`: assigning 'voxels' below would
    # otherwise overwrite the voxels of `a` and of every edge sharing that dict.
    new_edge.data = dict(a.data)
    new_edge['voxels'] = np.concatenate([a['voxels'], b.node_a['voxels'], b['voxels']])
    return new_edge


def remove_dag_redundant_nodes(root):
    """Return a copy of the tree under `root` without pass-through nodes.

    The tree is rebuilt recursively. A rebuilt child with exactly one outgoing
    edge is dropped: the edge leading to it and its single outgoing edge are
    merged with `merge_edges`. Any other child is attached with a new edge
    that reuses the original edge's `data`.
    """
    pruned_root = root.copy_without_edges()

    for outgoing in root.edges:
        child = remove_dag_redundant_nodes(outgoing.node_b)

        if len(child.edges) == 1:
            # Skip over `child`, connecting straight to its only successor.
            continuation = child.edges[0]
            replacement = merge_edges(outgoing, continuation, pruned_root, continuation.node_b)
        else:
            replacement = Edge(pruned_root, child)
            replacement.data = outgoing.data

        pruned_root.add_edge(replacement)

    return pruned_root

In [23]:
%%time

# Collapse nodes that have a single outgoing edge.
# NOTE: this rebinds `clean_root`, so re-running the cell works on its own
# previous output rather than on the result of `remove_dag_cycles`.
clean_root = remove_dag_redundant_nodes(clean_root)

Wall time: 52.9 ms


### Obtaining clean nodes and edges

In [24]:
def get_nodes_with_dfs(root):
    """Return `root` and all nodes below it in depth-first pre-order.

    Children are reached through `edge.node_b`. An edge whose `node_a` is not
    the node it is stored on is printed.
    """
    collected = [root]
    for edge in root.edges:
        if edge.node_a != root:
            print(edge)

        collected.extend(get_nodes_with_dfs(edge.node_b))

    return collected


def get_edges_with_dfs(root):
    """Return all edges below `root` in depth-first pre-order.

    Each edge of a node is listed before the edges found under its `node_b`.
    """
    found = []
    for edge in root.edges:
        found.append(edge)
        found.extend(get_edges_with_dfs(edge.node_b))

    return found

In [25]:
# Flatten the cleaned tree into node and edge lists (depth-first order).
clean_nodes = get_nodes_with_dfs(clean_root)
clean_edges = get_edges_with_dfs(clean_root)

print(f'# of nodes: {len(clean_nodes)}, # of edges: {len(clean_edges)}')

# of nodes: 7337, # of edges: 7336


## Populating graph with basic metadata

### Reordering edges voxels

In [26]:
def reorder_edges_voxels(edge):
    """Reorder `edge['voxels']` by breadth-first distance from the edge's start node.

    The search starts at the first voxel of `edge.node_a` and walks through
    26-connected voxels belonging either to that node or to the edge. The
    edge's voxels are then stored, as a list of tuples, in the order in which
    they were reached. Edge voxels that cannot be reached this way are not
    included in the result.

    The edge is modified in place; nothing is returned.
    """
    node_voxels = [tuple(voxel) for voxel in edge.node_a['voxels']]
    edge_voxels = [tuple(voxel) for voxel in edge['voxels']]

    # Sets make the membership tests in the inner loop constant-time.
    node_voxel_set = set(node_voxels)
    allowed_voxels = node_voxel_set.union(edge_voxels)

    neighbour_offsets = [
        (dx, dy, dz)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        for dz in (-1, 0, 1)
        if (dx, dy, dz) != (0, 0, 0)
    ]

    start = node_voxels[0]
    # `sorted_voxels` doubles as the BFS queue: `head` is the next voxel to
    # expand and newly discovered voxels are appended at the end.
    sorted_voxels = [start]
    reached = {start}
    head = 0

    while head < len(sorted_voxels):
        x, y, z = sorted_voxels[head]
        head += 1

        for dx, dy, dz in neighbour_offsets:
            neighbour = (x + dx, y + dy, z + dz)

            if neighbour in reached or neighbour not in allowed_voxels:
                continue

            reached.add(neighbour)
            sorted_voxels.append(neighbour)

    edge['voxels'] = [voxel for voxel in sorted_voxels if voxel not in node_voxel_set]
    
def fix_edges_voxels(root):
    """Apply `reorder_edges_voxels` to every edge of the tree under `root`."""
    for edge in get_edges_with_dfs(root):
        reorder_edges_voxels(edge)

In [27]:
%%time
# Order every edge's voxels starting from the edge's first node, then print
# the root's voxels and the first edge's voxels as a sanity check.
fix_edges_voxels(clean_root)
print(clean_root['voxels'])
print(clean_root.edges[0]['voxels'])

[[1057  153  665]]
[(1056, 153, 666), (1055, 153, 666), (1054, 153, 666), (1053, 153, 666), (1052, 153, 666), (1051, 153, 667), (1050, 153, 667), (1049, 153, 667), (1048, 153, 667), (1047, 153, 667), (1046, 153, 668), (1045, 153, 668), (1044, 153, 668), (1043, 153, 668), (1042, 153, 668), (1041, 154, 668), (1040, 154, 668), (1039, 154, 669), (1038, 155, 668), (1037, 155, 669), (1036, 155, 669), (1035, 156, 669), (1034, 156, 669), (1033, 155, 670), (1032, 156, 670), (1031, 156, 670), (1030, 156, 670), (1029, 156, 670), (1028, 156, 671), (1027, 156, 671), (1026, 156, 672), (1025, 156, 672), (1024, 156, 672), (1023, 156, 673), (1022, 156, 673), (1021, 156, 674), (1020, 156, 674), (1019, 156, 675), (1018, 155, 675), (1017, 156, 675), (1016, 156, 676), (1015, 155, 676), (1014, 156, 676), (1013, 156, 676), (1012, 157, 676), (1011, 157, 677), (1010, 156, 678), (1009, 157, 677), (1008, 157, 678), (1007, 157, 678), (1006, 158, 678), (1005, 157, 679), (1004, 158, 679), (1003, 158, 680), (1002, 1

### Edges and nodes thickness

In [30]:
def fix_nodes_thickness(root, skeleton_thickness):
    """Set each node's `'thickness'` to the mean thickness over its voxels.

    Args:
        root: root of the tree whose nodes are updated in place.
        skeleton_thickness: volume indexed with voxel coordinates.
    """
    for node in get_nodes_with_dfs(root):
        samples = [skeleton_thickness[tuple(voxel)] for voxel in node['voxels']]
        node['thickness'] = np.mean(samples)
    

def add_edges_thickness(root, skeleton_thickness):
    """Store per-voxel and mean thickness on every edge of the tree.

    For each edge, `'thickness_list'` becomes an array with the thickness at
    each of the edge's voxels (in voxel order) and `'mean_thickness'` their
    mean.

    Args:
        root: root of the tree whose edges are updated in place.
        skeleton_thickness: volume indexed with voxel coordinates.
    """
    for edge in get_edges_with_dfs(root):
        samples = [skeleton_thickness[tuple(voxel)] for voxel in edge['voxels']]
        edge['thickness_list'] = np.array(samples)
        edge['mean_thickness'] = np.mean(samples)

In [31]:
%%time
# Replace node thickness with the mean over each node's voxels and attach
# thickness metadata to the edges; the last line displays one edge's profile.
fix_nodes_thickness(clean_root, skeleton_thickness)
add_edges_thickness(clean_root, skeleton_thickness)
clean_root.edges[0]['thickness_list']

Wall time: 361 ms


array([45, 45, 45, 45, 46, 46, 46, 46, 46, 46, 47, 47, 47, 47, 47, 47, 47,
       48, 48, 48, 48, 48, 48, 48, 49, 49, 49, 49, 49, 49, 50, 50, 50, 50,
       50, 50, 51, 51, 51, 51, 51, 51, 51, 51, 52, 52, 52, 52, 52, 52, 52,
       52, 52, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53,
       53, 53, 53, 53, 53, 53])

### Centroids and edges lengths

In [32]:
def set_nodes_centroids(root):
    """Set each node's `'centroid'` to the mean coordinate of its voxels."""
    for node in get_nodes_with_dfs(root):
        node['centroid'] = np.mean(node['voxels'], axis=0)

        
def calculate_edge_length(edge, chunk_length=1):
    """Return the length of the polyline that approximates `edge`.

    The edge's voxels are split, in order, into chunks of `chunk_length`
    voxels (the last chunk may be shorter) and each chunk is replaced by its
    mean position. The polyline runs from the centroid of `edge.node_a`
    through the chunk means to the centroid of `edge.node_b`; its segment
    lengths are summed.

    An edge without voxels yields the distance between the two centroids.

    Args:
        edge: edge with a `'voxels'` entry whose endpoints have a `'centroid'`.
        chunk_length: number of consecutive voxels averaged into one point.

    Returns:
        Total polyline length as a float.
    """
    # reshape(-1, 3) keeps an empty voxel list two-dimensional.
    voxels = np.asarray(edge['voxels'], dtype=float).reshape(-1, 3)

    # Pad with NaN rows to a multiple of `chunk_length`; nanmean ignores them.
    needed_nans = (chunk_length - (len(voxels) % chunk_length)) % chunk_length
    if needed_nans:
        voxels = np.concatenate([voxels, np.full((needed_nans, 3), np.nan)])

    chunked_voxels = voxels.reshape(len(voxels) // chunk_length, chunk_length, 3)
    edge_centroids = np.nanmean(chunked_voxels, axis=1)

    starting_centroid = np.asarray(edge.node_a['centroid'])
    ending_centroid = np.asarray(edge.node_b['centroid'])

    centroids = np.concatenate([
        starting_centroid[np.newaxis, ...],
        edge_centroids,
        ending_centroid[np.newaxis, ...]
    ])

    squared_diffs = np.diff(centroids, axis=0) ** 2
    lengths = np.sqrt(np.sum(squared_diffs, axis=1))
    return np.sum(lengths)
    

def set_edges_length(root, chunk_length=1):
    """Store `calculate_edge_length(edge, chunk_length)` as `'length'` on every edge."""
    for edge in get_edges_with_dfs(root):
        edge['length'] = calculate_edge_length(edge, chunk_length)

In [33]:
%%time
# Centroids must be set first: edge lengths are measured between them.
set_nodes_centroids(clean_root)
# Edge voxels are averaged in chunks of 2 before measuring the length.
set_edges_length(clean_root, 2)
print(clean_root.edges[0]['length'])

82.3454181093773
Wall time: 768 ms


## Creating DAG object

In [36]:
class DAG:
    """Container bundling a rooted tree with flat node/edge lists and metadata.

    `nodes` and `edges` are collected from `root` once, at construction time.
    `volume_shape` is stored as given and `data` holds free-form metadata
    reachable through `dag[key]`.
    """

    def __init__(self, root, volume_shape):
        self.volume_shape = volume_shape
        self.data = {}
        self.root = root
        self.nodes = get_nodes_with_dfs(root)
        self.edges = get_edges_with_dfs(root)

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value
        

def save_dag(dag, filename):
    """Pickle `dag` into the file `filename`, overwriting it if it exists."""
    with open(filename, 'wb') as sink:
        sink.write(pickle.dumps(dag))
        

def load_dag(filename):
    """Load and return the object pickled in the file `filename`.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    with open(filename, 'rb') as source:
        return pickle.loads(source.read())

In [38]:
# Wrap the cleaned tree; the reconstruction's shape is stored as `volume_shape`.
dag = DAG(clean_root, reconstruction.shape)

## DAG visualization

In [39]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a ball-shaped kernel, optionally hollowed out into a shell.

    Args:
        outer_radius: radius passed to `morphology.ball` for the outer ball.
        thickness: shell thickness, capped at `outer_radius`; only used when
            `filled` is false.
        filled: when true the full ball is returned; otherwise a centred ball
            of radius `outer_radius - thickness` is subtracted from it.
    """
    ball = morphology.ball(radius=outer_radius)

    if not filled:
        shell_thickness = min(thickness, outer_radius)
        inner_radius = outer_radius - shell_thickness
        core = morphology.ball(radius=inner_radius)

        # Offset that centres the inner ball inside the outer one.
        start = outer_radius - inner_radius
        stop = start + core.shape[0]
        ball[start:stop, start:stop, start:stop] -= core

    return ball


def print_kernels(image, nodes, value):
    """Return a copy of `image` with a filled ball drawn at every node.

    Each ball is centred on `node.coords` with radius `int(node['thickness'])`.
    Balls are clipped at the volume border and every voxel covered by at
    least one ball is set to `value`.

    Args:
        image: 3D volume; it is copied, not modified.
        nodes: sequence of nodes exposing `.coords` and a `'thickness'` entry.
        value: value written into the covered voxels.

    Returns:
        The modified copy of `image` (unchanged copy when `nodes` is empty).
    """
    image = image.copy()
    if len(nodes) == 0:
        return image

    radii = [int(node['thickness']) for node in nodes]
    max_kernel_radius = max(radii)

    # Work on a margin-extended boolean canvas so balls near the border can
    # be written without clipping each slice individually.
    padded_shape = tuple(size + 2 * max_kernel_radius for size in image.shape)
    kernels_image = np.zeros(padded_shape, dtype=bool)

    kernels = {}  # radius -> boolean ball, built on first use
    for node, kernel_radius in zip(nodes, radii):
        if kernel_radius not in kernels:
            kernels[kernel_radius] = spherical_kernel(kernel_radius) > 0
        kernel = kernels[kernel_radius]

        x, y, z = (int(coord) + max_kernel_radius for coord in node.coords)
        mask_slice = kernels_image[
            x-kernel_radius:x+kernel_radius + 1,
            y-kernel_radius:y+kernel_radius + 1,
            z-kernel_radius:z+kernel_radius + 1
        ]
        mask_slice |= kernel

    # Crop the margin with explicit end indices: `r:-r` would yield an empty
    # array when the largest radius is 0.
    inner = tuple(
        slice(max_kernel_radius, max_kernel_radius + size) for size in image.shape
    )
    image[kernels_image[inner]] = value
    return image


def draw_nodes(image, nodes, value=2):
    """Return a copy of `image` with `value` drawn as a ball at every node (see `print_kernels`)."""
    return print_kernels(image, nodes, value)

    
def draw_edges(image, edges, value='mean_thickness', interpolate=True):
    """Return a copy of `image` with every edge drawn into it.

    Args:
        image: volume to draw on; it is copied, not modified.
        edges: edges exposing `node_a`/`node_b` (with `.coords`) and a
            `'voxels'` entry.
        value: either the name of an edge metadata entry whose value is used
            as the fill value for that edge, or a constant fill value.
        interpolate: when true each edge is drawn as a straight line between
            the coords of its two nodes (`line_nd`); otherwise the edge's own
            voxels are filled.
    """
    image = image.copy()
    # isinstance also accepts str subclasses (e.g. numpy string scalars).
    value_is_key = isinstance(value, str)

    for edge in edges:
        fill_value = edge[value] if value_is_key else value

        if interpolate:
            image[line_nd(edge.node_a.coords, edge.node_b.coords)] = fill_value
        else:
            for v in edge['voxels']:
                image[tuple(v)] = fill_value

    return image

def draw_central_line(image, dag):
    """Return a copy of `image` with all edge and node voxels of `dag` set to 1."""
    line_image = draw_edges(image, dag.edges, value=1, interpolate=False)

    for node in dag.nodes:
        for voxel in node['voxels']:
            line_image[tuple(voxel)] = 1

    return line_image

In [41]:
# Nodes drawn as balls with value 25, edges as straight lines coloured by
# their mean thickness.
visualization = np.zeros(skeleton.shape)
visualization = draw_nodes(visualization, dag.nodes, 25)
visualization = draw_edges(visualization, dag.edges, value='mean_thickness')
visualize_gradient(visualization)

In [42]:
# Edges drawn as straight lines coloured by their length.
# NOTE: `visualize_gradient` casts the volume to uint8 before display.
visualization = np.zeros(skeleton.shape)
visualization = draw_edges(visualization, dag.edges, value='length')
visualize_gradient(visualization)

In [43]:
# Compare the voxels kept in the DAG with the full skeleton: skeleton voxels
# missing from the DAG show up as the "addition".
central_line = draw_central_line(np.zeros(skeleton.shape), dag)
visualize_addition(central_line, skeleton)

## Saving dag

In [44]:
# Persist the DAG next to the tree's input files.
save_dag(dag, source_dir + TREE_NAME + '/dag.pkl')