In [1]:
import glob
import pickle
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy import signal
from scipy.ndimage import zoom
from scipy.signal import fftconvolve
from skimage import filters, morphology
from skimage import measure, segmentation, feature
from skimage.filters import frangi, sato
from skimage.morphology import skeletonize_3d, binary_dilation
from tqdm import tqdm

from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer

In [2]:
# Specimen ID: selects ./data/<TREE_NAME>/ and the per-tree parameter overrides used below.
TREE_NAME = "P07"

## Utility visualisation functions

In [3]:
def visualize_addition(base, base_with_addition):
    """Render ``base`` together with what ``base_with_addition`` adds on top of it.

    Labels handed to ColorMapVisualizer: 1 = voxel non-zero in ``base``,
    4 = voxel non-zero only in ``base_with_addition``, 0 = neither.
    """
    base_labels = (base > 0).astype(np.uint8)
    added = (base_with_addition > 0).astype(np.uint8)
    added[base_labels == 1] = 0  # overlap is drawn as base
    added *= 4
    ColorMapVisualizer(base_labels + added).visualize()
    
def visualize_lsd(lsd_mask):
    """Render ``lsd_mask`` (cast to uint8) with ColorMapVisualizer's default colouring."""
    as_uint8 = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_uint8).visualize()
    
def visualize_gradient(lsd_mask):
    """Render ``lsd_mask`` (cast to uint8) with ColorMapVisualizer's gradient colouring.

    Note: values outside 0..255 wrap during the uint8 cast.
    """
    as_uint8 = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(as_uint8).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    """Render the non-zero voxels of ``mask`` as a binary (0/1) volume."""
    binary_volume = (mask > 0).astype(np.uint8)
    VolumeVisualizer(binary_volume, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Render the non-zero voxels of ``mask`` as a 0/255 uint8 volume in non-binary mode."""
    intensity = (mask > 0).astype(np.uint8)
    intensity *= 255
    VolumeVisualizer(intensity, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonise ``mask`` and render the result.

    - ``visualize_mask=False``: skeleton alone, as a binary volume.
    - ``visualize_mask=True``: skeleton (label 4) overlaid on the mask (label 3).
    - ``visualize_both_versions=True``: both views, skeleton-only first.
    """
    skeleton = skeletonize_3d((mask > 0).astype(np.uint8))
    show_skeleton_only = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_skeleton_only:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if show_overlay:
        skeleton_layer = skeleton.astype(np.uint8) * 4
        mask_layer = (mask > 0).astype(np.uint8) * 3
        mask_layer[skeleton_layer != 0] = 0  # skeleton wins where both overlap
        ColorMapVisualizer(skeleton_layer + mask_layer).visualize()

def visualize_ultimate(lsd, base_mask):
    """Show the standard views of ``lsd`` in sequence.

    Order: colour map, 0/255 volume, addition over ``base_mask``, skeleton overlay.
    """
    for single_view in (visualize_lsd, visualize_mask_non_bin):
        single_view(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

In [4]:
def visualize_skeleton_2(skeleton, mask):
    """Overlay a precomputed ``skeleton`` on ``mask`` and render the label volume.

    Labels: 4 = skeleton voxel (value > 0), 3 = mask voxel not on the skeleton,
    0 = background.
    """
    labels = (mask > 0).astype(np.uint8)
    labels *= 3
    labels[skeleton != 0] = 0
    labels += (skeleton > 0).astype(np.uint8) * 4
    ColorMapVisualizer(labels).visualize()

## Loading specimen reconstruction

In [5]:
source_dir = './data/'
reconstruction_path = f'{source_dir}{TREE_NAME}/reconstruction.npy'
reconstruction = np.load(reconstruction_path)
reconstruction.shape

(838, 638, 1120)

In [6]:
# Tree IDs considered usable; only the keys are used below (the numeric values are not read in this notebook).
_good_tree_values = {
    'P01': 21,    'P02': 28,    'P03': 160,    'P04': 30,
    'P05': 24,    'P06': 22,    'P07': 100,    'P09': 70,
    'P10': 50,    'P11': 50,    'P12': 74,    'P13': 45,
    'P14': 42,    'P15': 95,    'P16': 25,    'P17': 80,
    'P18': 85,    'P19': 50,    'P20': 95,    'P21': 45,
    'P23': 60,    'P24': 80,    'P25': 70,    'P26': 130,
    'P27': 70,    'P28': 25,    'P30': 80,    'P31': 50,
    'P32': 48,    'P33': 65,
}
good_names = [tree_id for tree_id in _good_tree_values]
# good_names

In [7]:
# for name in good_names:
# #     reko = np.load(source_dir + name + '/reconstruction.npy')
#     central_line = np.load(source_dir + name + '/central-line.npy')
#     print(name)
# #     print(reko.shape)
# #     print(central_line.shape)
# #     visualize_mask_non_bin(reko)
# #     visualize_skeleton_2(central_line, reko)
#     visualize_mask_bin(central_line)

In [8]:
visualize_mask_non_bin(mask=reconstruction)

In [6]:
# reco = reconstruction[:-80, :, :]
# reco[-1, :, :] = 0
# visualize_mask_bin(reco)
# # np.save(source_dir + TREE_NAME + '/reconstruction', reco)

## Obtaining skeleton

In [9]:
%%time
skeleton = (skeletonize_3d(reconstruction) > 0).astype(np.uint8)

Wall time: 1min 39s


In [12]:
visualize_skeleton_2(skeleton=skeleton, mask=reconstruction)

## Trimming skeleton

In [23]:
def _skeleton_neighbours(skeleton, voxel):
    """Return the in-bounds 26-connected neighbours of ``voxel`` that are set in ``skeleton``."""
    x, y, z = voxel
    size_x, size_y, size_z = skeleton.shape
    neighbours = []
    # Clipping the ranges to the volume avoids IndexError on the upper faces and
    # the silent wrap-around that negative indices caused on the lower faces.
    for nx in range(max(x - 1, 0), min(x + 2, size_x)):
        for ny in range(max(y - 1, 0), min(y + 2, size_y)):
            for nz in range(max(z - 1, 0), min(z + 2, size_z)):
                if (nx, ny, nz) != (x, y, z) and skeleton[nx, ny, nz]:
                    neighbours.append((nx, ny, nz))
    return neighbours


def trim_skeleton_once(skeleton, condidate_voxels):
    """Remove one layer of leaf voxels from a 3-D skeleton.

    A candidate is a leaf when fewer than two of its 26 neighbours are set in
    ``skeleton``. Every candidate is judged against the unmodified input, so all
    leaves of this layer are removed simultaneously.

    Parameters
    ----------
    skeleton : ndarray
        3-D array; non-zero voxels belong to the skeleton.
    condidate_voxels : iterable
        (x, y, z) coordinates to test (e.g. rows of ``np.argwhere``). Duplicates
        and coordinates that are not set in ``skeleton`` are skipped.

    Returns
    -------
    trimmed_skeleton : ndarray of uint8
        Copy of ``skeleton`` with this layer's leaves cleared.
    leaves_neighbours : list of tuple
        Unique skeleton voxels adjacent to a removed leaf. These are the only
        voxels that can become leaves in the next pass.
    """
    trimmed_skeleton = skeleton.copy()
    leaves_neighbours = {}  # dict as an insertion-ordered set
    checked = set()

    for voxel in condidate_voxels:
        voxel = tuple(int(coord) for coord in voxel)
        if voxel in checked:
            continue
        checked.add(voxel)
        if not skeleton[voxel]:
            # Not a skeleton voxel (e.g. already trimmed), so it cannot be a leaf.
            continue
        voxel_neighbours = _skeleton_neighbours(skeleton, voxel)
        if len(voxel_neighbours) < 2:
            trimmed_skeleton[voxel] = 0
            leaves_neighbours.update(dict.fromkeys(voxel_neighbours))
    return trimmed_skeleton.astype(np.uint8), list(leaves_neighbours)


def trim_skeleton(skeleton, iters):
    """Prune branch tips by repeatedly removing leaf voxels.

    The first pass scans every skeleton voxel. Later passes only re-examine the
    neighbours of voxels removed in the previous pass. At least one pass always
    runs, even when ``iters`` < 1.

    Returns the trimmed skeleton as a uint8 array; ``skeleton`` is not modified.
    """
    trimmed_skeleton = skeleton
    candidates = np.argwhere(skeleton)
    for i in range(max(iters, 1)):
        trimmed_skeleton, candidates = trim_skeleton_once(trimmed_skeleton, candidates)
        print(f'iter {i + 1} done')
    return trimmed_skeleton

In [24]:
%%time
iterations = {
    'P01': 15,
    'P02': 65,
    'P04': 15,
    'P05': 75,
    'P11': 45,
    'P12': 24,
    'P13': 20,
    'P15': 45,
    'P16': 15,
    'P21': 40,
    'P26': 45,
    'P32': 40,
    'P33': 40,
}

trimmed_skeleton = trim_skeleton(skeleton, iters=iterations.get(TREE_NAME, 30))

iter 1 done
iter 2 done
iter 3 done
iter 4 done
iter 5 done
iter 6 done
iter 7 done
iter 8 done
iter 9 done
iter 10 done
iter 11 done
iter 12 done
iter 13 done
iter 14 done
iter 15 done
iter 16 done
iter 17 done
iter 18 done
iter 19 done
iter 20 done
iter 21 done
iter 22 done
iter 23 done
iter 24 done
iter 25 done
iter 26 done
iter 27 done
iter 28 done
iter 29 done
iter 30 done
Wall time: 19.4 s


In [25]:
visualize_addition(base=trimmed_skeleton, base_with_addition=skeleton)
# Alternative view against the full volume: visualize_addition(trimmed_skeleton, reconstruction)

## Propagating thickness

In [26]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a ball structuring element of ``outer_radius``, optionally hollowed.

    With ``filled=False``, a concentric ball of radius
    ``outer_radius - min(thickness, outer_radius)`` is subtracted from the
    centre, leaving a shell.
    """
    kernel = morphology.ball(radius=outer_radius)
    if not filled:
        inner_radius = outer_radius - min(thickness, outer_radius)
        core = morphology.ball(radius=inner_radius)
        # Offset that centres the smaller ball inside the larger one.
        start = outer_radius - inner_radius
        stop = start + core.shape[0]
        kernel[start:stop, start:stop, start:stop] -= core
    return kernel

def prepare_kernels(kernel_sizes):
    """Build filled ball kernels in ascending radius order.

    Returns a list of ``(radius, kernel, number_of_voxels_in_kernel)`` tuples.
    """
    def _kernel_entry(radius):
        ball = spherical_kernel(radius, filled=True)
        return radius, ball, np.sum(ball)

    return [_kernel_entry(radius) for radius in sorted(kernel_sizes)]

def calculate_skeleton_thickness(skeleton, 
                                 reconstruction, 
                                 kernel_sizes, 
                                 fill_threshold):
    """Estimate a local radius for every skeleton voxel by growing balls.

    For each non-zero voxel of ``skeleton``, balls of increasing radius (from
    ``kernel_sizes``, ascending) are centred on it. For each ball, the fraction
    of its voxels that are non-zero in ``reconstruction`` is measured. Growth
    stops at the first radius whose fill fraction is not strictly greater than
    ``fill_threshold``. The stored value is ``last accepted radius + 1``.

    Parameters
    ----------
    skeleton : ndarray
        3-D array; non-zero voxels are evaluated.
    reconstruction : ndarray
        3-D array indexed like ``skeleton``; non-zero voxels count as filled.
    kernel_sizes : iterable of int
        Candidate ball radii, in any order. Must be non-empty.
    fill_threshold : float
        Fill fraction a ball must strictly exceed to be accepted.

    Returns
    -------
    ndarray of int
        Same shape as ``skeleton``; 0 wherever no radius was accepted
        (including every non-skeleton voxel).
    """
    # Materialise once so one-shot iterables (e.g. generators) work for both uses below.
    kernel_sizes = list(kernel_sizes)
    kernels = prepare_kernels(kernel_sizes)
    max_kernel_radius = max(kernel_sizes)
    # Zero-pad so every ball fits; voxels outside the volume count as empty.
    # NOTE: this copies the whole reconstruction, which is costly for large volumes.
    padded_reconstruction = np.pad(reconstruction, max_kernel_radius)
    skeleton_voxels = np.argwhere(skeleton)
    # `np.int` (an alias of builtin `int`) was removed in NumPy 1.24.
    result = np.zeros(skeleton.shape, dtype=int)

    for voxel in tqdm(skeleton_voxels):
        x, y, z = tuple(voxel + max_kernel_radius)
        for kernel_radius, kernel, kernel_sum in kernels:
            reconstruction_slice = padded_reconstruction[
                x - kernel_radius : x + kernel_radius + 1,
                y - kernel_radius : y + kernel_radius + 1,
                z - kernel_radius : z + kernel_radius + 1
            ]
            filled_voxels = np.count_nonzero(np.logical_and(reconstruction_slice, kernel))
            fill_factor = filled_voxels / kernel_sum
            if fill_factor > fill_threshold:
                result[tuple(voxel)] = kernel_radius + 1
            else:
                # Larger balls are not tried once one fails.
                break
    return result

In [27]:
%%time
kernel_sizes = {
    'P01': range(70),
}

fill_thresholds = {
    'P01': 0.85,
}

skeleton_thickness = calculate_skeleton_thickness(trimmed_skeleton,
                                                  reconstruction,
                                                  kernel_sizes.get(TREE_NAME, range(70)),
                                                  fill_thresholds.get(TREE_NAME, 0.85))

100%|██████████████████████████████████████████████████████████████████████████| 51657/51657 [00:10<00:00, 4879.11it/s]

Wall time: 16.1 s





In [28]:
%%time
np.unique(skeleton_thickness, return_counts=True)

Wall time: 10.5 s


(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43]),
 array([598749623,      2258,      4485,     11193,      7892,      6922,
             4150,      3025,      2348,      2378,      1434,      1019,
              776,       744,       634,       529,       302,       295,
              279,       174,        84,       110,        71,       114,
               60,        31,       110,       112,        28,        11,
                5,        11,         6,         6,         6,         4,
                4,         4,         4,         6,         4,         6,
               15,         8], dtype=int64))

In [29]:
visualize_gradient(lsd_mask=skeleton_thickness)

## Propagate thickness from trimmed skeleton to the ends

In [30]:
def propagate_thickness_to_trims(trimmed_skeleton, skeleton, skeleton_thickness):
    """Copy thickness values from the trimmed skeleton onto the pruned branches.

    Breadth-first flood fill over the 26-connected non-zero voxels of
    ``skeleton``, seeded with every non-zero voxel of ``trimmed_skeleton``
    (which keep their own ``skeleton_thickness`` values). Each newly reached
    voxel inherits the thickness of the voxel it was reached from.

    Returns
    -------
    ndarray of int
        Same shape as ``skeleton``; 0 for voxels that were never reached.
    """
    seeds = trimmed_skeleton > 0
    thickness_map = np.zeros(skeleton.shape, dtype=int)  # `np.int` was removed in NumPy 1.24
    thickness_map[seeds] = skeleton_thickness[seeds]
    # Track visits separately from the values: previously "visited" meant
    # thickness > 0, so a zero-thickness seed re-enqueued itself forever.
    visited = seeds.copy()
    size_x, size_y, size_z = skeleton.shape

    # deque gives O(1) pops from the front (list.pop(0) made the BFS quadratic).
    queue = deque(map(tuple, np.argwhere(seeds).tolist()))
    while queue:
        x, y, z = queue.popleft()
        thickness = thickness_map[x, y, z]
        # Ranges are clipped to the volume: no IndexError on the upper faces,
        # no wrap-around to the opposite face on the lower ones.
        for nx in range(max(x - 1, 0), min(x + 2, size_x)):
            for ny in range(max(y - 1, 0), min(y + 2, size_y)):
                for nz in range(max(z - 1, 0), min(z + 2, size_z)):
                    if visited[nx, ny, nz] or not skeleton[nx, ny, nz]:
                        continue
                    visited[nx, ny, nz] = True
                    thickness_map[nx, ny, nz] = thickness
                    queue.append((nx, ny, nz))

    return thickness_map

In [31]:
%%time
whole_skeleton_thicksness = propagate_thickness_to_trims(trimmed_skeleton, skeleton, skeleton_thickness)

Wall time: 9.05 s


## Adding leaves to the skeleton

In [32]:
def get_largest_regions(binary_mask, min_size=1000, connectivity=3):
    """Keep only the connected components of ``binary_mask`` larger than ``min_size`` voxels.

    Parameters
    ----------
    binary_mask : ndarray
        Non-zero voxels are foreground.
    min_size : int
        Components with ``area > min_size`` are kept.
    connectivity : int
        Passed to ``skimage.measure.label``.

    Returns
    -------
    ndarray of uint8
        1 on voxels of the kept components, 0 elsewhere.
    """
    labeled = measure.label(binary_mask, connectivity=connectivity)
    main_regions_labels = [props.label
                           for props in measure.regionprops(labeled)
                           if props.area > min_size]

    print(f"found {len(main_regions_labels)} main regions, out of {np.max(labeled)} regions")

    # One vectorised membership test, built straight into a uint8 mask. Previously this
    # was a float64 full-volume buffer (8x larger) plus one full-volume comparison per region.
    return np.isin(labeled, main_regions_labels).astype(np.uint8)

def make_ends_meet(skeleton, trimmed_skeleton, skeleton_thickness, ends_max_radius):
    """Re-attach pruned skeleton voxels whose thickness is at most ``ends_max_radius``.

    The trimmed skeleton plus those restored voxels is passed through
    ``get_largest_regions`` (26-connectivity, default ``min_size``) to drop small
    disconnected pieces.

    Returns a uint8 0/1 volume.
    """
    # Boolean logic instead of uint8 arithmetic: `skeleton - trimmed_skeleton`
    # wrapped to 255 (and then 1 + 255 -> 0) wherever trimmed_skeleton was set
    # but skeleton was not.
    ends = (skeleton > 0) & (trimmed_skeleton == 0) & (skeleton_thickness <= ends_max_radius)
    trimmed_with_ends = ((trimmed_skeleton > 0) | ends).astype(np.uint8)
    return get_largest_regions(trimmed_with_ends, connectivity=3)

In [33]:
# visualize_gradient(whole_skeleton_thicksness)

In [34]:
%%time
ends_max_radius = {
#     'P01': 20,
    'P02': 30,
    'P05': 30,
    'P07': 15,
    'P10': 7,    
    'P12': 10,
    'P13': 10,
    'P14': 15,
    'P15': 17,
    'P18': 15,
    'P19': 15,
    'P21': 17,
    'P23': 17,
    'P24': 17,
    'P25': 17,
    'P26': 13,
    'P27': 10,
    'P28': 14,
    'P30': 15,
    'P30': 18,
    'P32': 22,
    'P33': 17,
}

full_skeleton = make_ends_meet(skeleton, trimmed_skeleton, whole_skeleton_thicksness, 
                               ends_max_radius=ends_max_radius.get(TREE_NAME, 20))

found 1 main regions, out of 1 regions
Wall time: 15.9 s


In [35]:
visualize_addition(base=full_skeleton, base_with_addition=skeleton)  # full_skeleton = label 1 (green); voxels trimmed away = label 4 (purple)

In [36]:
# Keep the propagated thickness only on voxels that survived in full_skeleton.
full_skeleton_thickness = np.where(full_skeleton == 0, 0, whole_skeleton_thicksness)

In [37]:
# visualize_gradient(full_skeleton_thickness)

In [27]:
visualize_gradient(lsd_mask=full_skeleton_thickness)

## Saving skeleton and thickness map

In [28]:
# Radii are stored as uint8: fail loudly instead of silently wrapping out-of-range values.
uint8_range = np.iinfo(np.uint8)
assert uint8_range.min <= full_skeleton_thickness.min() and full_skeleton_thickness.max() <= uint8_range.max, \
    'thickness values do not fit in uint8'
tree_dir = f'{source_dir}{TREE_NAME}/'
np.save(tree_dir + 'central-line', full_skeleton)
np.save(tree_dir + 'central-line-radii', full_skeleton_thickness.astype(np.uint8))