In [1]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import pickle
from skimage import measure, segmentation, feature
from vis_utils import load_volume, VolumeVisualizer, ColorMapVisualizer
from scipy.ndimage import zoom
from scipy.signal import fftconvolve
from skimage.morphology import skeletonize_3d, binary_dilation, convex_hull_image
from skimage import filters, morphology
from scipy import signal
from skimage.filters import frangi, sato
from PIL import Image
from tqdm import tqdm

from scipy.ndimage import zoom

In [2]:
TREE_NAME = 'P02'

## Loading specimen volume

In [3]:
source_dir = './data/'
files = {path.split('\\')[1]: path for path in sorted(glob.glob(source_dir + '*/*.raw'))}
files

{'P01': './data\\P01\\P01_60um_1612x623x1108.raw',
 'P02': './data\\P02\\P02_60um_1387x778x1149.raw',
 'P03': './data\\P03\\P03_60um_1473x1163x1148.raw',
 'P04': './data\\P04\\P04_60um_1273x466x1045.raw',
 'P05': './data\\P05\\P05_60um_1454x817x1102.raw',
 'P06': './data\\P06\\P06_60um_1425x564x1028.raw',
 'P07': './data\\P07\\P7_60um_1216x692x926.raw',
 'P08': './data\\P08\\P08_60um_1728x927x1149.raw',
 'P09': './data\\P09\\P09_60um_1359x456x1040.raw',
 'P10': './data\\P10\\P10_60um_1339x537x1035.raw',
 'P11': './data\\P11\\P11_60um_1735x595x1150.raw',
 'P12': './data\\P12\\P12_60um_1333x443x864.raw',
 'P13': './data\\P13\\P13_60um_1132x488x877.raw',
 'P14': './data\\P14\\P14_60um_1927x746x1124.raw',
 'P15': './data\\P15\\P15_60um_1318x640x1059.raw',
 'P16': './data\\P16\\P16_60um_1558x687x1084.raw',
 'P17': './data\\P17\\P17_60um_1573x555x968.raw',
 'P18': './data\\P18\\73_60um_1729x854x1143.raw',
 'P19': './data\\P19\\320_60um_1739x553x960.raw',
 'P20': './data\\P20\\333_60um_1762x9

In [4]:
%%time

# scales = {
#     'P01': 0.5,
#     'P04': 0.5,
#     'P05': 0.5,
#     'P12': 0.5,
# }

volume = load_volume(files[TREE_NAME]) #, scale=scales.get(TREE_NAME, 0.5))
print(volume.shape)
# VolumeVisualizer(volume, binary=False).visualize()

(1149, 778, 1387)
Wall time: 2.83 s


## Utility visualisation functions

In [5]:
def visualize_addition(base, base_with_addition):
    base = (base.copy() > 0).astype(np.uint8)
    addition = (base_with_addition > 0).astype(np.uint8)
    addition[base == 1] = 0
    ColorMapVisualizer(base + addition * 4).visualize()
    
def visualize_lsd(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize()
    
def visualize_gradient(lsd_mask):
    ColorMapVisualizer(lsd_mask.astype(np.uint8)).visualize(gradient=True)
    
def visualize_mask_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8), binary=True).visualize()
    
def visualize_mask_non_bin(mask):
    VolumeVisualizer((mask > 0).astype(np.uint8) * 255, binary=False).visualize()
    
def visualize_skeleton(mask, visualize_mask=True, visualize_both_versions=False):
    skeleton = skeletonize_3d((mask > 0).astype(np.uint8))
    if not visualize_mask or visualize_both_versions:
        VolumeVisualizer(skeleton, binary=True).visualize()
    if visualize_mask or visualize_both_versions:
        skeleton = skeleton.astype(np.uint8) * 4
        mask = (mask > 0).astype(np.uint8) * 3
        mask[skeleton != 0] = 0
        ColorMapVisualizer(skeleton + mask).visualize()

def visualize_ultimate(lsd, base_mask):
    visualize_lsd(lsd)
    visualize_mask_non_bin(lsd)
    visualize_addition(base_mask, lsd)
    visualize_skeleton(lsd, visualize_mask=True)

## Thresholding volume to get binary mask

In [6]:
VolumeVisualizer(volume, binary=False).visualize()

In [7]:
thresholds = {
    'P01': 21,
    'P02': 26,
    'P04': 30,
    'P05': 24,
    'P12': 70,
}

mask = volume > thresholds[TREE_NAME]
# volume = None
visualize_mask_bin(mask)

In [8]:
visualize_mask_non_bin(mask)

## Extracting main region 

In [10]:
def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
    labeled = measure.label(binary_mask, connectivity=connectivity)
    region_props = measure.regionprops(labeled)
    
    main_regions = np.zeros(binary_mask.shape)
    bounding_boxes = []
    for props in region_props:
        if props.area >= min_size:
            bounding_boxes.append(props.bbox)
            main_regions = np.logical_or(main_regions, labeled==props.label)
            
    lower_bounds = np.min(bounding_boxes, axis=0)[:3]
    upper_bounds = np.max(bounding_boxes, axis=0)[3:]

    return main_regions[
        lower_bounds[0]:upper_bounds[0],
        lower_bounds[1]:upper_bounds[1],
        lower_bounds[2]:upper_bounds[2],
    ], bounding_boxes

# def get_main_regions(binary_mask, min_size=10_000, connectivity=3):
#     labeled = measure.label(binary_mask, connectivity=connectivity)
#     region_props = measure.regionprops(labeled)
    
#     main_regions_masks = []
#     regions_labels = []
#     bounding_boxes = []
    
#     for props in region_props:
#         if props.area >= min_size:
#             main_regions_masks.append(props.filled_image)
#             regions_labels.append(props.label)
#             bounding_boxes.append(props.bbox)
            
#     return main_regions_masks, regions_labels, bounding_boxes

In [13]:
main_region_min_size = {
    'P01': 30_000,
    'P04': 30_000,
    'P05': 30_000,
    'P12': 30_000,
}

main_regions, bounding_boxes = get_main_regions(mask, min_size=main_region_min_size.get(TREE_NAME, 20_000))
print('number of main regions:', len(bounding_boxes))
mask_main = main_regions
# mask = None
visualize_mask_non_bin(mask_main)

number of main regions: 2


# Creating reconstruction using resolution scaling

In [17]:
def scale_mask(mask, scale, order=0):
    return zoom(mask, scale, order=order)
    
def verify_mask(mask):
    regions_count = np.max(measure.label(mask, connectivity=3))
    return "good mask" if regions_count == 1 else f"scattered mask, number of regions: {regions_count}"

### Creating reconstruction of scaled tree

In [20]:
%%time
scales = {
    'P01': 0.5,
    'P02': 0.7,
    'P04': 0.5,
    'P05': 0.5,
    'P12': 0.5,
}
s_main_region_min_size = {
    'P01': 10_000,
    'P04': 10_000,
    'P05': 10_000,
    'P12': 10_000,
}

s_volume = load_volume(files[TREE_NAME], scale=scales.get(TREE_NAME, 0.5))
print('scaled volume shape', s_volume.shape)
s_mask = s_volume > thresholds[TREE_NAME]
s_volume = None
s_main_regions, s_bounding_boxes = get_main_regions(s_mask, min_size=s_main_region_min_size.get(TREE_NAME, 10_000))
print('number of s main regions:', len(s_bounding_boxes))
s_mask_main = s_main_regions
s_mask = None
s_main_regions = None
verify_mask(s_mask_main)

scaled volume shape (804, 545, 971)
number of s main regions: 2
Wall time: 17.3 s


'scattered mask, number of regions: 2'

In [21]:
visualize_mask_bin(s_mask_main)

In [22]:
def spherical_kernel(outer_radius, thickness=1, filled=True):    
    outer_sphere = morphology.ball(radius=outer_radius)
    if filled:
        return outer_sphere
    
    thickness = min(thickness, outer_radius)
    
    inner_radius = outer_radius - thickness
    inner_sphere = morphology.ball(radius=inner_radius)
    
    begin = outer_radius - inner_radius
    end = begin + inner_sphere.shape[0]
    outer_sphere[begin:end, begin:end, begin:end] -= inner_sphere
    return outer_sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True, fft=True):
    kernel = spherical_kernel(ball_radius, filled=True)
    if fft:
        convolved = fftconvolve(img.astype(dtype), kernel.astype(dtype), mode='same')
    else:
        convolved = signal.convolve(img.astype(dtype), kernel.astype(dtype), mode='same')
    
    if not normalize:
        return convolved
    
    return (convolved / kernel.sum()).astype(np.float16)

def calculate_reconstruction(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.5, iters=1, conv_dtype=np.uint16, fft=True):
    kernel_sizes_maps = []
    mask = mask.astype(np.uint8)
    
    for i in range(iters):
        kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fill_percentage = convolve_with_ball(mask, kernel_size, dtype=conv_dtype, normalize=True, fft=fft)
            
            above_threshold_fill_indices = fill_percentage > fill_threshold
            kernel_size_map[above_threshold_fill_indices] = kernel_size + 1

            mask[above_threshold_fill_indices] = 1
            
            print(f'Iteration {i + 1} kernel {kernel_size} done')

        kernel_sizes_maps.append(kernel_size_map)
        print(f'Iteration {i + 1} ended successfully')

    return kernel_sizes_maps

In [23]:
%%time

s_kernel_sizes = {
    'P01': range(0, 14),
    'P02': range(0, 14),
    'P04': range(0, 13),
    'P05': range(0, 13),
    'P12': range(0, 13),
}

s_number_of_iterations = {
    'P01': 7,
    'P02': 5,
    'P04': 2,
    'P05': 3,
    'P12': 3,
}

s_recos = calculate_reconstruction(s_mask_main, 
                                   kernel_sizes=s_kernel_sizes.get(TREE_NAME, range(0, 13)), 
                                   iters=s_number_of_iterations.get(TREE_NAME, 3))

Iteration 1 kernel 0 done
Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 kernel 13 done
Iteration 1 ended successfully
Iteration 2 kernel 0 done
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 kernel 13 done
Iteration 2 ended successfully
Iteration 3 kernel 0 done
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7

In [24]:
s_reco = s_recos[-1] > 0
verify_mask(s_reco)
# visualize_mask_non_bin(s_reco)
# visualize_addition(s_mask_main, s_reco)

'good mask'

### Scaling up and trimming reconstruction of scaled tree

In [25]:
def upscale_reconstruction(s_reco, scale):
    upscale = 1 / scale
    u_reco = scale_mask(s_reco, upscale, order=3)
    return u_reco

In [26]:
%%time
u_reco = upscale_reconstruction(s_reco, scale=scales.get(TREE_NAME, 0.5))

Wall time: 4min 52s


In [27]:
shape = np.min([mask_main.shape, u_reco.shape], axis=0)
print('mask_main shape', mask_main.shape, 'u_reco shape', u_reco.shape)
print('final shape', shape)
mask_main_r = mask_main[:shape[0], :shape[1], :shape[2]]
u_reco_r = u_reco[:shape[0], :shape[1], :shape[2]]

mask_main shape (1121, 719, 1260) u_reco shape (1120, 720, 1259)
final shape [1120  719 1259]


In [28]:
joint_reco = np.logical_or(mask_main_r, u_reco_r)
verify_mask(joint_reco)
main_regions, bounding_boxes = get_main_regions(joint_reco, min_size=main_region_min_size.get(TREE_NAME, 20_000))
print('number of main regions:', len(bounding_boxes))
joint_reco_main = main_regions
mask = None
main_regions = None
new_joint_reco = np.logical_or(mask_main_r, joint_reco_main)
# visualize_addition(mask_main_r, new_joint_reco)

number of main regions: 1


In [29]:
np.sum(joint_reco), np.sum(joint_reco_main), np.sum(new_joint_reco)

(24107828, 24107169, 24107169)

In [30]:
# def is_any_neighbour_in_mask(voxel, mask):
#     for margin in [[0, 0, 0], [-1, 0, 0], [1, 0, 0], [0, -1, 0], [0, 1, 0], [0, 0, -1], [0, 0, 1]]:
#         neighbour = voxel + margin
#         if (np.zeros(3) <= neighbour).all() and (neighbour < mask.shape).all() and mask[tuple(neighbour)]:
#             return True
#     return False
    

# def trim_upscaled_reco(joint_reco, mask_main_r, min_removed=1000):
#     reco = joint_reco.copy()
#     removed_count = min_removed + 1
#     while(removed_count >= min_removed):
#         removed_count = 0
#         convolution = convolve_with_ball(reco, 1)
#         edge_voxels = np.argwhere(np.logical_and(convolution < 1, convolution > 0))
#         print(f'{edge_voxels.shape[0]} edge voxels found')
#         for edge_voxel in edge_voxels:
#             if not is_any_neighbour_in_mask(edge_voxel, mask_main_r):
#                 removed_count += 1
#                 reco[tuple(edge_voxel)] = 0
#         print('removed', removed_count, 'vertices,', 'finished' if removed_count < min_removed else 'continuing')
#     return reco

In [31]:
# new_joint_reco.sum()

In [32]:
# %%time
# trimmed_reco = trim_upscaled_reco(new_joint_reco, mask_main_r, min_removed=1000)

In [33]:
trimmed_reco = new_joint_reco

### Fine tuning full res reconstruction

In [34]:
%%time

ft_kernel_sizes = {
    'P01': range(0, 28),
    'P02': range(0, 28),
    'P04': range(0, 26),
    'P05': range(0, 26),
    'P12': range(0, 26),
}

ft_number_of_iterations = {
    'P01': 1,
    'P02': 1,
    'P04': 1,
    'P05': 1,
    'P12': 1,
}

ft_recos = calculate_reconstruction(new_joint_reco, 
                                    kernel_sizes=ft_kernel_sizes.get(TREE_NAME, range(0, 25)), 
                                    iters=ft_number_of_iterations.get(TREE_NAME, 1))

Iteration 1 kernel 0 done
Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 kernel 13 done
Iteration 1 kernel 14 done
Iteration 1 kernel 15 done
Iteration 1 kernel 16 done
Iteration 1 kernel 17 done
Iteration 1 kernel 18 done
Iteration 1 kernel 19 done
Iteration 1 kernel 20 done
Iteration 1 kernel 21 done
Iteration 1 kernel 22 done
Iteration 1 kernel 23 done
Iteration 1 kernel 24 done
Iteration 1 kernel 25 done
Iteration 1 kernel 26 done
Iteration 1 kernel 27 done
Iteration 1 ended successfully
Wall time: 59min 58s


In [35]:
reco = ft_recos[-1]

In [36]:
# %%time
# # second iter

# ft_recos2 = calculate_reconstruction(reco, 
#                                     kernel_sizes=ft_kernel_sizes.get(TREE_NAME, range(0, 25)), 
#                                     iters=1)

In [37]:
# reco2 = ft_recos2[-1]

In [38]:
# Verify and save reconstruction

In [39]:
# np.sum(reco != reco2)

In [40]:
# visualize_skeleton(reco2)

In [41]:
# visualize_mask_non_bin(reco)
# visualize_skeleton(reco)

In [42]:
# visualize_addition(mask_main_r, reco)

In [43]:
padded_reco = np.pad(reco, 1) # padding reconstruction to avoid padding later
np.save(source_dir + TREE_NAME + '/reco', padded_reco)

In [None]:
# pray for it to work

In [74]:
old_reco = np.load(source_dir + TREE_NAME + '/reconstruction.npy')

In [79]:
reco.shape

(1096, 571, 1546)

In [78]:
old_reco.shape

(1098, 573, 1560)

In [82]:
old_reco[1:-1, 1:-1, 7:-7].shape

(1096, 571, 1546)

In [83]:
np.sum(reco != old_reco[1:-1, 1:-1, 7:-7])

24223410

In [84]:
visualize_mask_non_bin(old_reco)

In [85]:
visualize_skeleton(old_reco)

In [87]:
visualize_addition(np.pad(mask_main, 1), old_reco)

In [86]:
mask_main.shape

(1096, 571, 1558)

In [89]:
%%time
old_skel = skeletonize_3d((old_reco > 0).astype(np.uint8))
skel = skeletonize_3d((reco > 0).astype(np.uint8))

Wall time: 6min 19s


In [91]:
skel.shape

(1096, 571, 1546)

In [92]:
old_skel.shape

(1098, 573, 1560)

In [93]:
skel.sum()

157283

In [94]:
old_skel.sum()

153510

In [113]:
skel[1:-1, 1:-1, 2:-1].sum()

157283

In [114]:
np.sum(skel != old_skel[1:-1, 1:-1, 1:-13])

279506

In [115]:
visualize_addition(old_skel[1:-1, 1:-1, 1:-13], skel)

In [116]:
visualize_mask_bin(skel)

In [22]:
scales = {
    'P01': 0.5,
    'P04': 0.5,
    'P05': 0.5,
    'P12': 0.5,
}

scaled_mask = scale_mask(mask_main, scale=scales.get(TREE_NAME, 0.5))
verify_mask(scaled_mask)

'scattered mask!!! number of regions: 767'

In [18]:
visualize_mask_non_bin(scaled_mask)

In [29]:
%%time

kernel_sizes = {
    'P01': range(0, 14),
    'P04': range(0, 13),
    'P05': range(0, 13),
    'P12': range(0, 13),
}

number_of_iterations = {
    'P01': 5,
    'P04': 1,
    'P05': 2,
    'P12': 2,
}

low_res_lsd_trees_2 = annihilate_jemiolas(scaled_mask, 
                                        kernel_sizes=kernel_sizes.get(TREE_NAME, range(0, 13)), 
                                        iters=number_of_iterations.get(TREE_NAME, 3),
                                        fft=False)

Iteration 1 kernel 0 done
Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 kernel 13 done
Iteration 1 ended successfully
Iteration 2 kernel 0 done
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 kernel 8 done
Iteration 2 kernel 9 done
Iteration 2 kernel 10 done
Iteration 2 kernel 11 done
Iteration 2 kernel 12 done
Iteration 2 kernel 13 done
Iteration 2 ended successfully
Iteration 3 kernel 0 done
Iteration 3 kernel 1 done
Iteration 3 kernel 2 done
Iteration 3 kernel 3 done
Iteration 3 kernel 4 done
Iteration 3 kernel 5 done
Iteration 3 kernel 6 done
Iteration 3 kernel 7

In [32]:
low_res_reco = low_res_lsd_trees[-1]
reco_2 = low_res_lsd_trees_2[-1]
np.unique(low_res_reco == reco_2)

array([ True])

In [31]:
visualize_mask_non_bin(low_res_reco)

In [33]:
visualize_addition(scaled_mask, low_res_reco)

In [68]:
xd = r * np.logical_not(e)

In [70]:
verify_mask(xd)

'scattered mask!!! number of regions: 2131'

In [53]:
visualize_addition(mask, upscaled_r)

In [55]:
verify_mask(mask)

'good mask'

## Filling holes in the mask to get rid of mistletoes in skeleton

In [9]:
def spherical_kernel(outer_radius, thickness=1, filled=True):    
    outer_sphere = morphology.ball(radius=outer_radius)
    if filled:
        return outer_sphere
    
    thickness = min(thickness, outer_radius)
    
    inner_radius = outer_radius - thickness
    inner_sphere = morphology.ball(radius=inner_radius)
    
    begin = outer_radius - inner_radius
    end = begin + inner_sphere.shape[0]
    outer_sphere[begin:end, begin:end, begin:end] -= inner_sphere
    return outer_sphere

def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True):
    kernel = spherical_kernel(ball_radius, filled=True)
    convolved = fftconvolve(img.astype(dtype), kernel.astype(dtype), mode='same')
    
    if not normalize:
        return convolved
    
    return (convolved / kernel.sum()).astype(np.float16)

def annihilate_jemiolas(mask, kernel_sizes=[10, 9, 8], fill_threshold=0.5, iters=1, conv_dtype=np.uint16):
    kernel_sizes_maps = []
    mask = mask.astype(np.uint8)
    
    for i in range(iters):
#         kernel_sizes_maps = []
        kernel_size_map = np.zeros(mask.shape, dtype=np.uint8)

        for kernel_size in kernel_sizes:
            fill_percentage = convolve_with_ball(mask, kernel_size, dtype=conv_dtype, normalize=True)
            
            above_threshold_fill_indices = fill_percentage > fill_threshold
            kernel_size_map[above_threshold_fill_indices] = kernel_size + 1

            mask[above_threshold_fill_indices] = 1
            
            print(f'Iteration {i + 1} kernel {kernel_size} done')

        kernel_sizes_maps.append(kernel_size_map)
        print(f'Iteration {i + 1} ended successfully')

    return kernel_sizes_maps

# TODO

na zmniejszonej rozdzielczosci pusc konwolucje duzymi kernelami zeby uzyskac te pewne punkty w srodku, a potem zrob taki fine tuning juz na duzej rozdzielczosci

In [11]:
# from scipy.spatial import ConvexHull, Delaunay

# specimen_voxels = np.argwhere(mask_main)
# specimen_voxels.shape

# hull = ConvexHull(specimen_voxels)
# hull.simplices.shape

# triangles = [specimen_voxels[simplex] for simplex in hull.simplices]

# extreme_voxels = np.array([specimen_voxels[i] for i in np.unique(hull.simplices.flatten())])
# inside_point = np.round(np.mean(extreme_voxels, axis=0)).astype(np.int)

# def tmp_vis(mask, extreme_voxels, mark_radius):
#     mask2 = mask.copy()
#     for v in extreme_voxels:
#         x, y, z = tuple(v)
#         mask2[x - mark_radius: x + mark_radius, 
#               y - mark_radius: y + mark_radius, 
#               z - mark_radius: z + mark_radius] = 4
#     visualize_lsd(mask2)
    
# tmp_vis(mask_main, extreme_voxels, 5)
# tmp_vis(mask_main, [inside_point], 5)

# normals_per_triangle = []
# for triangle in triangles:
#     p0, p1, p2 = tuple(triangle)
#     normal = np.cross(p1 - p0, p2 - p0)
#     if np.dot(p0 - inside_point, normal) < 0:
#         normal *= -1
#     normals_per_triangle.append(normal)
    
# all_points = np.argwhere(np.ones(mask_main.shape))

In [12]:
# %%time
# convex_hull_mask = np.zeros(mask_main.shape)
# for point in all_points:
#     inside = True
#     for tri, normal in zip(triangles, normals_per_triangle):
#         if np.dot(point - tri[0], normal) > 0:
#             inside = False
#             break
#     if inside:
#         convex_hull_mask[tuple(point)] = 1

In [13]:
## bad idea :c

In [14]:
max_kernel_size = 28

In [15]:
%%time
convolution = convolve_with_ball(mask_main, max_kernel_size, dtype=np.int)

Wall time: 2min 2s


In [16]:
candidate_mask = (convolution > 0).astype(np.int) - mask_main
candidate_vertices = np.argwhere(candidate_mask)
candidate_vertices.shape

(235902399, 3)

In [18]:
convolution_hull = (convolution > 0)

In [19]:
visualize_mask_non_bin(convolution_hull)

In [38]:
def prepare_kernels(kernel_sizes):
    kernels = []
    for kernel_radius in sorted(kernel_sizes):
        kernel = spherical_kernel(kernel_radius, filled=True)
        kernels.append((kernel_radius, kernel, np.sum(kernel)))
    return kernels

def prepare_kernels2(kernel_sizes):
    kernels = []
    max_kernel_radius = max(kernel_sizes)
    for kernel_radius in sorted(kernel_sizes):
        kernel = spherical_kernel(kernel_radius, filled=True)
        kernel_sum = np.sum(kernel)
        kernel = np.pad(kernel, max_kernel_radius - kernel_radius)
        kernels.append((kernel_radius, kernel, kernel_sum))
    return kernels

def try_fill_voxels(candidate_voxels, 
                    kernels,
                    max_kernel_radius,
                    reconstruction,
                    fill_threshold):
    
    result = reconstruction.copy()
    padded_reconstruction = np.pad(reconstruction, max_kernel_radius)
    
    for voxel in tqdm(candidate_voxels):
        x, y, z = tuple(voxel + max_kernel_radius)
        reconstruction_slice = padded_reconstruction[
                x - max_kernel_radius : x + max_kernel_radius + 1,
                y - max_kernel_radius : y + max_kernel_radius + 1,
                z - max_kernel_radius : z + max_kernel_radius + 1
            ]
        for kernel_radius, kernel, kernel_sum in kernels:
            fill_factor = np.sum(np.logical_and(reconstruction_slice, kernel)) / kernel_sum
            if fill_factor > fill_threshold:
                result[tuple(voxel)] = 1
                break
    return result

def create_reconstruction(mask,
                          hull_mask,
                          kernel_sizes,
                          fill_threshold,
                          iters=2,
                          save_history=False):
    kernels = prepare_kernels2(kernel_sizes)
    max_kernel_radius = np.max(kernel_sizes)
    reconstruction = mask
    history = []
    
    for i in range(iters):
        candidate_mask = hull_mask - reconstruction
        candidate_voxels = np.argwhere(candidate_mask)
        
        reconstruction = try_fill_voxels(candidate_voxels,
                                         kernels,
                                         max_kernel_radius,
                                         reconstruction,
                                         fill_threshold)
        if save_history:
            history.append(reconstruction)
        print(f'iteration {i} done')
    return reconstruction, history

In [39]:
%%time
reconstruction, history = create_reconstruction(mask=mask_main,
                                                hull_mask=convolution_hull,
                                                kernel_sizes=list(range(28)),
                                                fill_threshold=0.5,
                                                iters=1)

  0%|                                                                    | 3155/235902399 [00:31<651:58:52, 100.51it/s]


KeyboardInterrupt: 

In [40]:
%%time
reconstruction, history = create_reconstruction(mask=mask_main,
                                                hull_mask=convolution_hull,
                                                kernel_sizes=list(range(28)),
                                                fill_threshold=0.5,
                                                iters=1)

  0%|                                                                    | 3125/235902399 [00:30<647:49:42, 101.15it/s]


KeyboardInterrupt: 

In [124]:
np.argwhere(mask_main).shape

(13215738, 3)

In [128]:
spherical_kernel(1)

array([[[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]],

       [[0, 1, 0],
        [1, 1, 1],
        [0, 1, 0]],

       [[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]]], dtype=uint8)

In [24]:
%%time
colvolution_with_1 = convolve_with_ball(mask_main, 1)

Wall time: 1min 48s


In [25]:
borderline_voxels_mask = np.logical_and(mask_main > 0, colvolution_with_1 < 1)
borderline_voxels = np.argwhere(borderline_voxels_mask)

In [26]:
borderline_voxels.shape

(4277192, 3)

In [27]:
borderline_voxels = [tuple(v) for v in borderline_voxels]

In [None]:
# %%time
# thresshold = np.ceil(spherical_kernel(1).sum() / 2)

# reconstruction = mask_main.copy()
# new_reconstruction = np.pad((reconstruction > 0) * 10, 1)

# new_borderline_voxels = []

# for borderline_voxel in tqdm(borderline_voxels):
#     neighbours_in_reconstruction = 0
#     x, y, z = borderline_voxel
#     for dx, dy, dz in [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1)]:
#         neighbour_x = x + dx + 1
#         neighbour_y = y + dy + 1
#         neighbour_z = z + dz + 1
#         if(new_reconstruction[neighbour_x, neighbour_y, neighbour_z] >= thresshold):
#             neighbours_in_reconstruction += 1
#             continue
#         else:
#             new_reconstruction[neighbour_x, neighbour_y, neighbour_z] += 1
#             if new_reconstruction[neighbour_x, neighbour_y, neighbour_z] >= thresshold:
#                 new_borderline_voxels.append((neighbour_x, neighbour_y, neighbour_z))
#     if neighbours_in_reconstruction < 6:
#         new_borderline_voxels.append(borderline_voxel)
# new_reconstruction = (new_reconstruction > thresshold)[1:-1, 1:-1, 1:-1]

In [42]:
def create_reconstruction(mask, starting_borderline_voxels, iters=5):
    reconstruction = (mask > 0) * 10
    thresshold = np.floor(spherical_kernel(1).sum() / 2)
    borderline_voxels = starting_borderline_voxels.copy()

    for it in range(iters):
        new_borderline_voxels = []
        for borderline_voxel in tqdm(borderline_voxels):
            neighbours_in_reconstruction = 0
            x, y, z = borderline_voxel
            for dx, dy, dz in [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1)]:
                neighbour_x = x + dx + 1
                neighbour_y = y + dy + 1
                neighbour_z = z + dz + 1
                try:
                    reconstruction[neighbour_x, neighbour_y, neighbour_z]
                except:
                    neighbours_in_reconstruction += 1
                    continue
                if(reconstruction[neighbour_x, neighbour_y, neighbour_z] >= thresshold):
                    neighbours_in_reconstruction += 1
                    continue
                else:
                    reconstruction[neighbour_x, neighbour_y, neighbour_z] += 1
                    if reconstruction[neighbour_x, neighbour_y, neighbour_z] >= thresshold:
                        new_borderline_voxels.append((neighbour_x, neighbour_y, neighbour_z))
            if neighbours_in_reconstruction < 6:
                new_borderline_voxels.append(borderline_voxel)
        reconstruction = (reconstruction >= thresshold) * 10
        borderline_voxels = new_borderline_voxels
        print(f'Iteration {it} done')
    return reconstruction > 0

In [62]:
%%time
# colvolution_with_1 = convolve_with_ball(mask_main, 1)
borderline_voxels_mask = np.logical_and(mask_main > 0, colvolution_with_1 < 1)
borderline_voxels = np.argwhere(borderline_voxels_mask)
borderline_voxels = list(borderline_voxels)

Wall time: 7.64 s


In [95]:
# temporary
import warnings
warnings.filterwarnings("ignore")

In [113]:
def create_reconstruction_2(mask, starting_borderline_voxels, iters=5):
    reconstruction = mask > 0
    thresshold = np.floor(spherical_kernel(1).sum() / 2)
    borderline_voxels = starting_borderline_voxels.copy()
    
    padded_shape = (np.array(mask_main.shape) + np.array([2, 2, 2])).astype(np.int)

    for it in range(iters):       
        borderline_voxels_array = np.array(borderline_voxels)        
        votings = np.zeros(padded_shape)
        voted_for_reco_counter = np.zeros(len(borderline_voxels))
        for column in [0, 1, 2]:
            for value in [-1, 1]:
                voted_voxels = borderline_voxels_array.copy()
                voted_voxels[:, column] += value
                votings[list(voted_voxels.T)] += 1
                print('.', end='')
                voted_for_reco_counter += np.pad(reconstruction, 1, constant_values=1)[list(voted_voxels.T)]
                
        chosen_new_voxels_mask = votings[1:-1, 1:-1, 1:-1] >= thresshold
        chosen_new_voxels_mask[reconstruction > 0] = 0
        still_borderline = voted_for_reco_counter < 6
        borderline_voxels = list(borderline_voxels_array[still_borderline]) + list(np.argwhere(chosen_new_voxels_mask))
#         borderline_voxels = borderline_voxels + list(np.argwhere(chosen_new_voxels_mask))
        reconstruction = (reconstruction + chosen_new_voxels_mask) > 0
        print(f'iter {it + 1} done')
        
    return reconstruction

In [107]:
%%time
recco = create_reconstruction_2(mask=mask_main,
                                starting_borderline_voxels=borderline_voxels,
                                iters=3)

4277192
......4639389
iter 1 done
4639389
......6259520
iter 2 done
6259520
......7414711
iter 3 done
Wall time: 54 s


In [112]:
%%time
reco = create_reconstruction(mask=mask_main,
                             starting_borderline_voxels=borderline_voxels,
                             iters=2)

100%|█████████████████████████████████████████████████████████████████████| 4277192/4277192 [02:35<00:00, 27520.24it/s]
  0%|                                                                        | 4911/4637967 [00:00<03:02, 25433.86it/s]

Iteration 0 done


100%|█████████████████████████████████████████████████████████████████████| 4637967/4637967 [02:44<00:00, 28221.10it/s]


Iteration 1 done
Wall time: 5min 28s


In [114]:
visualize_mask_non_bin(recco)

In [115]:
%%time
recco = create_reconstruction_2(mask=mask_main,
                                starting_borderline_voxels=borderline_voxels,
                                iters=300)

......iter 1 done
......iter 2 done
......iter 3 done
......iter 4 done
......iter 5 done
......iter 6 done
......iter 7 done
......iter 8 done
......iter 9 done
......iter 10 done
......iter 11 done
......iter 12 done
......iter 13 done
......iter 14 done
......iter 15 done
......iter 16 done
......iter 17 done
......iter 18 done
......iter 19 done
......iter 20 done
......iter 21 done
......iter 22 done
......iter 23 done
......iter 24 done
......iter 25 done
......iter 26 done
......iter 27 done
......iter 28 done
......iter 29 done
......iter 30 done
......iter 31 done
......iter 32 done
......iter 33 done
......iter 34 done
......iter 35 done
......iter 36 done
......iter 37 done
......iter 38 done
......iter 39 done
......iter 40 done
......iter 41 done
......iter 42 done
......iter 43 done
......iter 44 done
......iter 45 done
......iter 46 done
......iter 47 done
......iter 48 done
......iter 49 done
......iter 50 done
......iter 51 done
......iter 52 done
......iter 53 done
..

In [116]:
rr = np.pad(recco, 1) # padding reconstruction to avoid padding later
np.save(source_dir + TREE_NAME + '/reconstruction-new', rr)

In [33]:
visualize_mask_non_bin(reco)

In [35]:
%%time
skeleton = skeletonize_3d(reco).astype(np.uint8)

Wall time: 7min 43s


In [37]:
visualize_mask_bin(skeleton)

In [None]:
visualize_mask_non_bin(reco)

In [139]:
spherical_kernel(1)

array([[[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]],

       [[0, 1, 0],
        [1, 1, 1],
        [0, 1, 0]],

       [[0, 0, 0],
        [0, 1, 0],
        [0, 0, 0]]], dtype=uint8)

In [18]:
%%time
lsd_trees = annihilate_jemiolas(mask_main, kernel_sizes=[10], iters=1)

Iteration 1 kernel 10 done
Iteration 1 ended successfully
Wall time: 2min 5s


In [25]:
def convolve_with_ball(img, ball_radius, dtype=np.uint16, normalize=True):
    kernel = spherical_kernel(ball_radius, filled=True)
    convolved = signal.convolve(img.astype(dtype), kernel.astype(dtype), mode='same')
    
    if not normalize:
        return convolved
    
    return (convolved / kernel.sum()).astype(np.float16)

In [26]:
%%time
lsd_trees = annihilate_jemiolas(mask_main, kernel_sizes=[10], iters=1)

Iteration 1 kernel 10 done
Iteration 1 ended successfully
Wall time: 2min 45s


In [17]:
7*27*130 / 60

409.5

In [10]:
%%time

kernel_sizes = {
    'P01': range(0, 28),
    'P04': range(0, 26),
    'P05': range(0, 26),
    'P12': range(0, 26),
}

number_of_iterations = {
    'P01': 7,
    'P04': 2,
    'P05': 3,
    'P12': 3,
}

lsd_trees = annihilate_jemiolas(mask_main, 
                                kernel_sizes=kernel_sizes.get(TREE_NAME, range(0, 13)), 
                                iters=number_of_iterations.get(TREE_NAME, 3))

Iteration 1 kernel 0 done
Iteration 1 kernel 1 done
Iteration 1 kernel 2 done
Iteration 1 kernel 3 done
Iteration 1 kernel 4 done
Iteration 1 kernel 5 done
Iteration 1 kernel 6 done
Iteration 1 kernel 7 done
Iteration 1 kernel 8 done
Iteration 1 kernel 9 done
Iteration 1 kernel 10 done
Iteration 1 kernel 11 done
Iteration 1 kernel 12 done
Iteration 1 kernel 13 done
Iteration 1 kernel 14 done
Iteration 1 kernel 15 done
Iteration 1 kernel 16 done
Iteration 1 kernel 17 done
Iteration 1 kernel 18 done
Iteration 1 kernel 19 done
Iteration 1 kernel 20 done
Iteration 1 kernel 21 done
Iteration 1 kernel 22 done
Iteration 1 kernel 23 done
Iteration 1 kernel 24 done
Iteration 1 kernel 25 done
Iteration 1 kernel 26 done
Iteration 1 kernel 27 done
Iteration 1 ended successfully
Iteration 2 kernel 0 done
Iteration 2 kernel 1 done
Iteration 2 kernel 2 done
Iteration 2 kernel 3 done
Iteration 2 kernel 4 done
Iteration 2 kernel 5 done
Iteration 2 kernel 6 done
Iteration 2 kernel 7 done
Iteration 2 ker

## Verifying obtained reconstruction 

In [None]:
# if the reconstruction looks bad try using results of previous iterations
# if the skelecon still has mistletoes try increasing number of iterations
reconstruction = (lsd_trees[-1] > 0).astype(np.uint8)
visualize_mask_non_bin(reconstruction) # check for holes
visualize_skeleton(reconstruction) # check for mistletoes
visualize_addition(mask_main, reconstruction) # check for anomalies

## Saving the reconstruction

In [12]:
reconstruction = np.pad(reconstruction, 1) # padding reconstruction to avoid padding later
np.save(source_dir + TREE_NAME + '/reconstruction', reconstruction)

In [None]:
%%time
visualize_skeleton(reconstruction)

In [27]:
source_dir = './data/'
reconstruction = np.load(source_dir + TREE_NAME + '/reconstruction.npy')
# visualize_skeleton(reconstruction)

In [37]:
visualize_addition(mask_main, reconstruction[1:-1, 1:-1, 1:-1])

KeyboardInterrupt: 