In [63]:
import glob
import pickle
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy import signal
from scipy.ndimage import zoom
from scipy.signal import fftconvolve
from skimage import filters, morphology
from skimage import measure, segmentation, feature
from skimage.filters import frangi, sato
from skimage.morphology import skeletonize_3d, binary_dilation
from tqdm import tqdm

from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer

In [80]:
# Specimen identifier: selects the model sub-directory loaded below and the
# per-tree parameter overrides looked up with `.get(TREE_NAME, ...)` later on.
TREE_NAME = 'M18'

## Utility visualisation functions

In [81]:
def visualize_addition(base, base_with_addition):
    """Render `base` together with the voxels `base_with_addition` adds to it.

    Voxels that are non-zero in `base` get label 1, voxels that are non-zero
    only in `base_with_addition` get label 4, everything else is 0. The label
    volume is shown with ColorMapVisualizer.
    """
    in_base = base > 0
    only_added = (base_with_addition > 0) & ~in_base
    labels = in_base.astype(np.uint8) + only_added.astype(np.uint8) * 4
    ColorMapVisualizer(labels).visualize()
    
def visualize_lsd(lsd_mask):
    """Cast `lsd_mask` to uint8 and render it with ColorMapVisualizer (default mode)."""
    label_volume = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(label_volume).visualize()
    
def visualize_gradient(lsd_mask):
    """Cast `lsd_mask` to uint8 and render it with ColorMapVisualizer in gradient mode."""
    label_volume = lsd_mask.astype(np.uint8)
    ColorMapVisualizer(label_volume).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    """Render the non-zero voxels of `mask` as a 0/1 uint8 volume in binary mode."""
    binary_volume = np.greater(mask, 0).astype(np.uint8)
    VolumeVisualizer(binary_volume, binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    """Render the non-zero voxels of `mask` as a 0/255 uint8 volume in non-binary mode."""
    intensity_volume = (mask > 0).astype(np.uint8)
    intensity_volume *= 255
    VolumeVisualizer(intensity_volume, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    """Skeletonize the non-zero voxels of `mask` and render the result.

    Args:
        mask: volume whose non-zero voxels are skeletonized.
        visualize_mask: if True, render the skeleton (label 4) overlaid on the
            remaining mask voxels (label 3); if False, render the skeleton alone
            as a binary volume.
        visualize_both_versions: if True, render both of the views above.
    """
    skeleton = skeletonize_3d((mask > 0).astype(np.uint8))
    show_plain = visualize_both_versions or not visualize_mask
    show_overlay = visualize_both_versions or visualize_mask

    if show_plain:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if show_overlay:
        skeleton_labels = skeleton.astype(np.uint8) * 4
        mask_labels = (mask > 0).astype(np.uint8) * 3
        mask_labels[skeleton_labels != 0] = 0
        ColorMapVisualizer(skeleton_labels + mask_labels).visualize()

def visualize_ultimate(lsd, base_mask):
    """Show every available view of `lsd`.

    In order: the colour-mapped volume, the non-binary mask, the voxels `lsd`
    adds over `base_mask`, and the skeleton of `lsd` overlaid on `lsd`.
    """
    for show in (visualize_lsd, visualize_mask_non_bin):
        show(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

In [82]:
def visualize_skeleton_2(skeleton, mask):
    """Overlay a precomputed `skeleton` (label 4) on `mask` (label 3) and render it.

    Mask voxels that coincide with any non-zero skeleton voxel are cleared
    before the skeleton labels are added, so the two label sets never overlap.
    """
    skeleton_voxels = skeleton > 0
    labels = (mask > 0).astype(np.uint8) * 3
    labels[skeleton != 0] = 0
    labels += skeleton_voxels.astype(np.uint8) * 4
    ColorMapVisualizer(labels).visualize()

## Loading specimen reconstruction

In [83]:
source_dir = './data/models/'
# Reconstruction volume of the selected specimen. Later cells only test its
# voxels for being non-zero (presumably a binary mask — confirm).
reconstruction = np.load(f'{source_dir}{TREE_NAME}/reconstruction.npy')

In [51]:
# Quick look at the loaded reconstruction (non-zero voxels rendered as 255).
visualize_mask_non_bin(reconstruction)
# visualize_mask_bin(reconstruction)
# visualize_skeleton(reconstruction)

## Obtaining skeleton

In [84]:
%%time
# Skeletonize (thin) the reconstruction, then binarise the result to a 0/1 uint8 volume.
skeleton = (skeletonize_3d(reconstruction) > 0).astype(np.uint8)

Wall time: 25.7 s


In [71]:
# Skeleton (label 4) overlaid on the reconstruction (label 3).
visualize_skeleton_2(skeleton, reconstruction)

In [12]:
# Skeleton on its own, rendered as a binary volume.
visualize_mask_bin(skeleton)

## Trimming skeleton

In [87]:
def trim_skeleton_once(skeleton, condidate_voxels):
    """Remove one layer of end points ("leaves") from a 3-D skeleton.

    A candidate voxel is removed when it has fewer than two non-zero voxels in
    its 26-neighbourhood. Neighbour counts are always read from the input
    `skeleton`, so all leaves of one pass are removed simultaneously.

    Args:
        skeleton: 3-D array; non-zero voxels belong to the skeleton.
        condidate_voxels: iterable of (x, y, z) coordinates to test, e.g. the
            rows of ``np.argwhere(skeleton)`` or the neighbour list returned by
            a previous call. Candidates that are already zero are skipped.

    Returns:
        (trimmed_skeleton, leaves_neighbours): a uint8 copy of `skeleton` with
        the leaves zeroed, and the (duplicate-free) coordinates of the skeleton
        neighbours of every removed leaf. Only these voxels can turn into
        leaves in the next pass.
    """
    trimmed_skeleton = skeleton.copy()
    leaves_neighbours = []
    seen_neighbours = set()

    for voxel in condidate_voxels:
        x, y, z = (int(c) for c in voxel)
        if not skeleton[x, y, z]:
            continue
        # Clip the 3x3x3 window at the volume border: an index of -1 would
        # silently wrap to the opposite face, and index == size would raise.
        x0, y0, z0 = max(x - 1, 0), max(y - 1, 0), max(z - 1, 0)
        window = skeleton[x0:x + 2, y0:y + 2, z0:z + 2]
        centre = (x - x0, y - y0, z - z0)
        voxel_neighbours = [
            (x0 + wx, y0 + wy, z0 + wz)
            for wx, wy, wz in np.argwhere(window).tolist()
            if (wx, wy, wz) != centre
        ]
        if len(voxel_neighbours) < 2:
            trimmed_skeleton[x, y, z] = 0
            for neighbour in voxel_neighbours:
                if neighbour not in seen_neighbours:
                    seen_neighbours.add(neighbour)
                    leaves_neighbours.append(neighbour)
    return trimmed_skeleton.astype(np.uint8), leaves_neighbours

def trim_skeleton(skeleton, iters):
    """Repeatedly strip end points from `skeleton`.

    Each pass removes every voxel that currently has fewer than two skeleton
    neighbours (see trim_skeleton_once).

    Args:
        skeleton: 3-D array; non-zero voxels belong to the skeleton.
        iters: maximum number of trimming passes; values below 1 return an
            unmodified uint8 copy.

    Returns:
        uint8 copy of `skeleton` after trimming.
    """
    if iters < 1:
        return skeleton.astype(np.uint8)
    trimmed_skeleton, trim_neighbours = trim_skeleton_once(skeleton, np.argwhere(skeleton))
    print('iter 1 done')
    for i in range(1, iters):
        if not trim_neighbours:
            # No removed leaf had a skeleton neighbour, so no neighbour count
            # changed and further passes would remove nothing.
            break
        trimmed_skeleton, trim_neighbours = trim_skeleton_once(trimmed_skeleton, trim_neighbours)
        print(f'iter {i + 1} done')
    return trimmed_skeleton

In [90]:
%%time
# Number of end-point trimming passes per specimen; trees not listed use 30.
iterations = {
    'P01': 15,
    'P05': 75,
    'P12': 24,
    'M18': 3,
}

trimmed_skeleton = trim_skeleton(skeleton, iters=iterations.get(TREE_NAME, 30))

iter 1 done
iter 2 done
iter 3 done
Wall time: 2.44 s


In [91]:
# Trimmed skeleton as label 1; voxels removed by trimming as label 4.
visualize_addition(trimmed_skeleton, skeleton)
# visualize_addition(trimmed_skeleton, reconstruction)

## Estimating thickness along the trimmed skeleton

In [92]:
def spherical_kernel(outer_radius, thickness=1, filled=True):
    """Return a ball-shaped kernel of radius `outer_radius`.

    With ``filled=True`` (default) this is ``morphology.ball(outer_radius)`` and
    `thickness` is ignored. Otherwise a centred ball of radius
    ``outer_radius - thickness`` (thickness capped at `outer_radius`) is
    subtracted, leaving a hollow shell.
    """
    ball = morphology.ball(radius=outer_radius)
    if not filled:
        inner_radius = outer_radius - min(thickness, outer_radius)
        offset = outer_radius - inner_radius
        inner = morphology.ball(radius=inner_radius)
        window = slice(offset, offset + inner.shape[0])
        ball[window, window, window] -= inner
    return ball

In [93]:
# def convolve_with_ball_on_given_voxels(padded_reconstruction, voxels_mask, kernel_radius, normalize=True):
#     result = np.zeros(voxels_mask.shape)
#     kernel = spherical_kernel(kernel_radius, filled=True)
    
#     voxels = np.argwhere(voxels_mask)
#     for voxel in voxels:
#         x, y, z = tuple(voxel + kernel_radius)
        
#         reconstruction_slice = padded_reconstruction[
#             x - kernel_radius : x + kernel_radius + 1,
#             y - kernel_radius : y + kernel_radius + 1,
#             z - kernel_radius : z + kernel_radius + 1
#         ]
        
#         filled_slice = np.logical_and(reconstruction_slice, kernel)
#         result[tuple(voxel)] = np.sum(filled_slice)           
    
#     if not normalize:
#         return result
    
#     return (result / kernel.sum()).astype(np.float16)

In [94]:
# #PSEUDO
# def calculate_radii(skeleton, reconstruction):
#   radii_mask = zeros(skeleton.shape)

#   for voxel in skeleton.voxels:
#     for kernel_radius range(60):
#       kernel = get_sphere_kernel(kernel_radius)
#       volume_slice = get_volume_slice(volume = reconstruction,
#                                       midpoint = voxel,
#                                       size = kernel_radius)
#       covered_size = logical_and(slice, kernel)
#       covered_percentage = covered_size /  kernel.sum()
#       if covered_percentage > 0.85:
#         radii_mask[voxel] = kernel_radius
#       else:
#         break
#   return radii_mask

In [95]:
# def calculate_skeleton_thickness(skeleton, 
#                                  reconstruction, 
#                                  kernel_sizes, 
#                                  fill_threshold):
    
#     kernel_size_map = np.zeros(skeleton.shape)
    
#     for kernel_size in sorted(kernel_sizes):
#         padded_reconstruction = np.pad(reconstruction, kernel_size)
#         fill_percentage = convolve_with_ball_on_given_voxels(padded_reconstruction, skeleton, kernel_size)
#         above_threshold_fill_indices = fill_percentage >= fill_threshold
#         kernel_size_map[above_threshold_fill_indices] = kernel_size + 1
#         print(f'Kernel {kernel_size} done')
        
#     return kernel_size_map

In [96]:
def prepare_kernels(kernel_sizes):
    """Build filled ball kernels for every radius in `kernel_sizes`.

    Returns:
        list of ``(radius, kernel, kernel_voxel_sum)`` tuples in ascending
        radius order.
    """
    radius_kernel_pairs = [(radius, spherical_kernel(radius, filled=True))
                           for radius in sorted(kernel_sizes)]
    return [(radius, kernel, np.sum(kernel)) for radius, kernel in radius_kernel_pairs]

def calculate_skeleton_thickness(skeleton, 
                                 reconstruction, 
                                 kernel_sizes, 
                                 fill_threshold):
    """Estimate a ball radius for every skeleton voxel.

    Balls of increasing radius (from `kernel_sizes`) are centred on each
    non-zero voxel of `skeleton`. A ball is accepted while the fraction of its
    voxels that are non-zero in `reconstruction` is strictly above
    `fill_threshold`. The search stops at the first rejected ball.

    Args:
        skeleton: 3-D array; non-zero voxels are processed.
        reconstruction: 3-D array of the same shape; non-zero voxels count as filled.
        kernel_sizes: iterable of non-negative integer ball radii.
        fill_threshold: filled fraction a ball must exceed to be accepted.

    Returns:
        Integer array shaped like `skeleton`. It holds ``radius + 1`` of the last
        accepted ball at each skeleton voxel and 0 elsewhere; it is also 0 where
        the smallest ball is rejected or `kernel_sizes` is empty.
    """
    # Materialise once: the radii are consumed by prepare_kernels and max().
    kernel_sizes = list(kernel_sizes)
    # Builtin int replaces np.int, which was removed in NumPy 1.24.
    result = np.zeros(skeleton.shape, dtype=int)
    if not kernel_sizes:
        return result

    kernels = prepare_kernels(kernel_sizes)
    max_kernel_radius = max(kernel_sizes)
    # Pad once by the largest radius so every ball slice stays inside the array.
    padded_reconstruction = np.pad(reconstruction, max_kernel_radius)

    for voxel in tqdm(np.argwhere(skeleton)):
        x, y, z = tuple(voxel + max_kernel_radius)
        for kernel_radius, kernel, kernel_sum in kernels:
            reconstruction_slice = padded_reconstruction[
                x - kernel_radius : x + kernel_radius + 1,
                y - kernel_radius : y + kernel_radius + 1,
                z - kernel_radius : z + kernel_radius + 1
            ]
            filled = np.count_nonzero(np.logical_and(reconstruction_slice, kernel))
            if filled / kernel_sum > fill_threshold:
                result[tuple(voxel)] = kernel_radius + 1
            else:
                # Larger balls are not tried after the first rejection
                # (presumably the fill fraction only drops with radius).
                break
    return result

In [97]:
%%time
# Ball radii to try per specimen; trees not listed use range(70).
kernel_sizes = {
    'P01': range(60),
    'P05': range(50),
    'P12': range(30),
}

# Fraction of a ball's voxels that must lie inside the reconstruction (strictly
# exceeded) for the ball to be accepted; trees not listed use 0.85.
fill_thresholds = {
    'P01': 0.85,
}

skeleton_thickness = calculate_skeleton_thickness(trimmed_skeleton,
                                                  reconstruction,
                                                  kernel_sizes.get(TREE_NAME, range(70)),
                                                  fill_thresholds.get(TREE_NAME, 0.85))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25828/25828 [00:03<00:00, 7921.38it/s]

Wall time: 6.82 s





In [98]:
%%time
# Histogram of stored values (radius + 1); the 0 bin also counts every non-skeleton voxel.
np.unique(skeleton_thickness, return_counts=True)

Wall time: 3.75 s


(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]),
 array([225580252,        14,       201,      6990,      7569,      4724,
             2341,       836,      1182,       275,       598,       140,
               82,       334,        63,        41,        86,       128,
               31,        15,        16,        20,        72,        16,
                9,         9,        11,        25], dtype=int64))

In [99]:
# %%time
# kernel_sizes = {
#     'P01': range(50),
#     'P05': range(50),
#     'P12': range(30),
# }

# fill_thresholds = {
#     'P01': 0.8,
#     'P05': 0.8,
#     'P12': 0.8,
# }

# skeleton_thickness = calculate_skeleton_thickness(trimmed_skeleton,
#                                                   reconstruction,
#                                                   kernel_sizes.get(TREE_NAME, range(12)),
#                                                   fill_thresholds.get(TREE_NAME, 0.8))

In [51]:
# Thickness values on the trimmed skeleton, rendered in gradient mode.
visualize_gradient(skeleton_thickness)

## Propagate thickness from trimmed skeleton to the ends

In [100]:
def propagate_thickness_to_trims(trimmed_skeleton, skeleton, skeleton_thickness):
    """Copy thickness values from the trimmed skeleton onto the trimmed-off voxels.

    A breadth-first search is seeded with every non-zero voxel of
    `trimmed_skeleton` and walks through the 26-connected non-zero voxels of
    `skeleton`. Each newly reached voxel inherits the value of the voxel it was
    reached from.

    Args:
        trimmed_skeleton: 3-D array; non-zero voxels are the BFS seeds.
        skeleton: 3-D array of the same shape; non-zero voxels may be reached.
        skeleton_thickness: array of the same shape giving the seed values.

    Returns:
        Integer array shaped like `skeleton`: seed values at seeds, propagated
        values at reached skeleton voxels, 0 elsewhere.
    """
    seeds = trimmed_skeleton > 0
    # Builtin int replaces np.int, which was removed in NumPy 1.24.
    whole_skeleton_thicksness = np.zeros(skeleton.shape, dtype=int)
    whole_skeleton_thicksness[seeds] = skeleton_thickness[seeds]

    # Track visits explicitly rather than via "thickness > 0": a seed with
    # thickness 0 would otherwise never count as visited and re-queue forever.
    visited = seeds.copy()
    size_x, size_y, size_z = skeleton.shape
    offsets = [(dx, dy, dz)
               for dx in (-1, 0, 1)
               for dy in (-1, 0, 1)
               for dz in (-1, 0, 1)
               if (dx, dy, dz) != (0, 0, 0)]

    # deque.popleft is O(1); list.pop(0) would make the search quadratic.
    queue = deque(map(tuple, np.argwhere(seeds).tolist()))
    while queue:
        x, y, z = queue.popleft()
        thickness = whole_skeleton_thicksness[x, y, z]
        for dx, dy, dz in offsets:
            nx, ny, nz = x + dx, y + dy, z + dz
            # Bounds check: -1 would silently wrap to the opposite face and an
            # index equal to the size would raise IndexError.
            if not (0 <= nx < size_x and 0 <= ny < size_y and 0 <= nz < size_z):
                continue
            if visited[nx, ny, nz] or not skeleton[nx, ny, nz]:
                continue
            visited[nx, ny, nz] = True
            whole_skeleton_thicksness[nx, ny, nz] = thickness
            queue.append((nx, ny, nz))

    return whole_skeleton_thicksness

In [101]:
%%time
# Extend the thickness values from the trimmed skeleton onto the voxels removed by trimming.
whole_skeleton_thicksness = propagate_thickness_to_trims(trimmed_skeleton, skeleton, skeleton_thickness)

Wall time: 3.27 s


In [102]:
# visualize_gradient(whole_skeleton_thicksness)

In [103]:
def propagate_thickness(skeleton, skeleton_thicksness, reconstruction):
    """Spread skeleton thickness values over the connected reconstruction.

    A breadth-first search is seeded with every non-zero voxel of `skeleton`
    and walks through the 26-connected non-zero voxels of `reconstruction`.
    Each newly reached voxel inherits the value of the voxel it was reached
    from.

    Args:
        skeleton: 3-D array; non-zero voxels are the BFS seeds.
        skeleton_thicksness: array shaped like `skeleton` giving the seed values.
        reconstruction: 3-D array of the same shape; non-zero voxels may be reached.

    Returns:
        float64 array shaped like `reconstruction`; 0 for unreached voxels.
    """
    seeds = skeleton > 0
    thiccness_map = np.zeros(reconstruction.shape)
    thiccness_map[seeds] = skeleton_thicksness[seeds]

    # Track visits explicitly rather than via "value > 0": a seed with value 0
    # would otherwise never count as visited and re-queue forever.
    visited = seeds.copy()
    size_x, size_y, size_z = reconstruction.shape
    offsets = [(dx, dy, dz)
               for dx in (-1, 0, 1)
               for dy in (-1, 0, 1)
               for dz in (-1, 0, 1)
               if (dx, dy, dz) != (0, 0, 0)]

    # deque.popleft is O(1); list.pop(0) would make the search quadratic.
    queue = deque(map(tuple, np.argwhere(seeds).tolist()))
    while queue:
        x, y, z = queue.popleft()
        thiccness = thiccness_map[x, y, z]
        for dx, dy, dz in offsets:
            nx, ny, nz = x + dx, y + dy, z + dz
            # Bounds check: -1 would silently wrap to the opposite face and an
            # index equal to the size would raise IndexError.
            if not (0 <= nx < size_x and 0 <= ny < size_y and 0 <= nz < size_z):
                continue
            if visited[nx, ny, nz] or not reconstruction[nx, ny, nz]:
                continue
            visited[nx, ny, nz] = True
            thiccness_map[nx, ny, nz] = thiccness
            queue.append((nx, ny, nz))

    return thiccness_map

In [104]:
# %%time
# thickness_map = propagate_thickness(trimmed_skeleton, skeleton_thickness, reconstruction)

In [105]:
# visualize_gradient(thiccness_map)

## Adding leaves to the skeleton

In [106]:
def get_largest_regions(binary_mask, min_size=1000, connectivity=3):
    """Keep only the connected components of `binary_mask` larger than `min_size` voxels.

    Args:
        binary_mask: array labelled with ``measure.label`` (0 is background).
        min_size: a component is kept only if its voxel count is strictly
            greater than this.
        connectivity: neighbourhood passed to ``measure.label``. In skimage's
            convention, 3 on a 3-D input means the full 26-neighbourhood.

    Returns:
        uint8 array of the same shape: 1 inside kept components, 0 elsewhere.
    """
    labeled = measure.label(binary_mask, connectivity=connectivity)
    main_regions_labels = [props.label for props in measure.regionprops(labeled)
                           if props.area > min_size]

    # initial=0 keeps the message working for an empty (size-0) input.
    print(f"found {len(main_regions_labels)} main regions, out of {labeled.max(initial=0)} regions")

    # A single isin pass builds the mask directly, instead of allocating a
    # float64 volume and comparing the whole volume once per kept label.
    return np.isin(labeled, main_regions_labels).astype(np.uint8)

def make_ends_meet(skeleton, trimmed_skeleton, skeleton_thickness, ends_max_radius, min_size=1000):
    """Re-attach thin trimmed-off voxels to the trimmed skeleton and keep its large components.

    "Ends" are voxels that are non-zero in `skeleton`, zero in
    `trimmed_skeleton`, and have `skeleton_thickness` <= `ends_max_radius`.
    They are merged with the trimmed skeleton. Only connected components
    (``connectivity=3``) with more than `min_size` voxels are then kept, via
    get_largest_regions.

    Returns:
        uint8 array: 1 on the kept voxels, 0 elsewhere.
    """
    in_trimmed = trimmed_skeleton > 0
    ends = (skeleton > 0) & ~in_trimmed & (skeleton_thickness <= ends_max_radius)
    # Boolean union instead of uint8 arithmetic. With subtraction, a voxel set
    # only in trimmed_skeleton would wrap to 255 and then to 0 after the sum.
    # Any overlap would yield the value 2, which measure.label treats as a
    # separate region.
    trimmed_with_ends = (in_trimmed | ends).astype(np.uint8)
    return get_largest_regions(trimmed_with_ends, min_size=min_size, connectivity=3)

In [107]:
%%time
# Per-specimen upper bound on the propagated thickness value (stored as radius + 1)
# for a trimmed-off voxel to be re-attached; trees not listed use 7.
ends_max_radius = {
    'P01': 20,
    'P05': 24,
    'P12': 8,
    'M09': 7,
    'M10': 7
}

full_skeleton = make_ends_meet(skeleton, trimmed_skeleton, whole_skeleton_thicksness, 
                               ends_max_radius=ends_max_radius.get(TREE_NAME, 7)) # previously 20; for normal trees use 20

found 1 main regions, out of 2 regions
Wall time: 5.83 s


In [109]:
# Skeleton voxels that did not make it into the final centre line appear as label 4.
visualize_addition(full_skeleton, skeleton) # full_skeleton is green (1)

In [26]:
# wasteskel = skeleton.copy()
# wasteskel[:-40, :, :] = 0
# xdskel = full_skeleton + wasteskel
# visualize_addition(xdskel, skeleton)

In [24]:
# visualize_mask_bin(skeleton[-10:, :, :])

In [31]:
# full_skeleton = xdskel > 0

In [110]:
# Keep thickness values only on voxels of the final centre line; zero elsewhere.
full_skeleton_thickness = whole_skeleton_thicksness * (full_skeleton != 0)

In [29]:
# Thickness values on the final centre line, rendered in gradient mode.
visualize_gradient(full_skeleton_thickness)

## Saving skeleton and thickness map

In [None]:
# Radii are written as uint8: fail loudly rather than let large values wrap on the cast.
assert full_skeleton_thickness.max(initial=0) <= np.iinfo(np.uint8).max, 'radius values do not fit in uint8'
# np.save appends the ".npy" extension to both paths.
np.save(source_dir + TREE_NAME + '/central-line', full_skeleton)
np.save(source_dir + TREE_NAME + '/central-line-radii', full_skeleton_thickness.astype(np.uint8))

In [32]:
# np.save(source_dir + TREE_NAME + '/thickness-map', thiccness_map)