In [21]:
import napari
import numpy as np
from numpy.lib.format import open_memmap
import myfunctions as mf
import os
import matplotlib.pyplot as plt
from skimage.io import imshow
from skimage.measure import regionprops, label
import seaborn as sns
import pandas as pd
import time as clock
from tqdm import tqdm

In [110]:
def remove_long_agglomerates(sequence_mask, max_length=50):
    slices = range(sequence_mask.shape[0])
    unique_labels, counts = np.unique(sequence_mask, return_counts=True)
    unique_labels = unique_labels[np.argsort(counts)][::-1]
    labels = [np.unique(sequence_mask[z]) for z in slices]
    for label in unique_labels[2:]:
        if sum([label in labels[z] for z in slices]) > max_length:
            sequence_mask[sequence_mask == label] = 0
    return sequence_mask

In [119]:
def propagate_labels(previous_mask, current_mask):
    max_label = np.max(previous_mask)
    current_mask[current_mask > 0] += max_label
    unique_labels, label_counts = np.unique(previous_mask, return_counts=True)
    ordered_labels = unique_labels[np.argsort(label_counts)]
    update = dict()

    tic_cycle = clock.perf_counter()
    for previous_mask_label in ordered_labels[:-1]:
        current_mask_labels = current_mask[previous_mask == previous_mask_label]
        if np.any(current_mask_labels):
            for current_slice_label in np.unique(current_mask_labels):
                if current_slice_label:
                    update[current_slice_label] = previous_mask_label
    print(f"cycle time: {clock.perf_counter() - tic_cycle:0.6f} seconds")
    tic_cycle = clock.perf_counter()
    print(update)
    for current_slice_label, previous_mask_label in update.items():
        current_mask[current_mask == current_slice_label] = previous_mask_label
    print(f"cycle time: {clock.perf_counter() - tic_cycle:0.6f} seconds")

    new_labels = np.unique(current_mask[current_mask > np.max(previous_mask)])
    for i, new_label in enumerate(new_labels):
        current_mask[current_mask == new_label] = max_label + i + 1
    return current_mask

In [73]:
def old_propagate_labels(previous_mask, current_mask, propagation_threshold=0):
    max_label = np.max(previous_mask)
    current_mask[current_mask > 0] += max_label
    unique_labels, label_counts = np.unique(previous_mask, return_counts=True)
    ordered_labels = unique_labels[np.argsort(label_counts)]

    tic_cycle = clock.perf_counter()
    for previous_mask_label in ordered_labels[1:]:
        current_mask_labels = current_mask[previous_mask == previous_mask_label]
        if np.any(current_mask_labels):
            bincount = np.bincount(current_mask_labels)
            bincount[0] = 0
            labels_to_change = np.where(bincount > propagation_threshold)[0]
            for current_slice_label in labels_to_change:
                current_mask[current_mask == current_slice_label] = previous_mask_label
    print(f"cycle time: {clock.perf_counter() - tic_cycle:0.6f} seconds")

    new_labels = np.unique(current_mask[current_mask > np.max(previous_mask)])
    for i, new_label in enumerate(new_labels):
        current_mask[current_mask == new_label] = max_label + i + 1
    return current_mask

In [107]:
def segment3D(volume, threshold, smallest_volume=10, filtering=True):
    volume_mask = np.zeros_like(volume, dtype=np.ushort)
    volume_mask[0,:,:] = mf.mask(volume[0,:,:], threshold)
    # masking of current slice and forward propagation from the first slice
    for i in range(1, volume.shape[0]):
        volume_mask[i,:,:] = mf.mask(volume[i,:,:], threshold)
        volume_mask[i,:,:] = mf.propagate_labels(volume_mask[i-1,:,:], volume_mask[i,:,:], forward=True)
    # backward propagation from the last slice
    for i in range(volume_mask.shape[0]-1, 0, -1):
        volume_mask[i-1,:,:] = mf.propagate_labels(volume_mask[i,:,:], volume_mask[i-1,:,:], forward=False)
    # removal of the agglomerates with volume smaller than smallest_volume and of the agglomerates that are not present in neighboring slices
    if filtering:
        volume_mask = mf.remove_isolated_agglomerates(volume_mask)
        volume_mask = mf.remove_small_agglomerates(volume_mask, smallest_volume)
        volume_mask = remove_long_agglomerates(volume_mask)
    return volume_mask

In [75]:
OS = 'MacOS'
exp = mf.exp_list()[0]

hypervolume = open_memmap(os.path.join(mf.OS_path(exp, OS), 'hypervolume.npy'), mode='r')
hypervolume_mask = open_memmap(os.path.join(mf.OS_path(exp, OS), 'hypervolume_mask.npy'), mode='r')
print(hypervolume_mask.shape)
print(hypervolume.shape)

end_time=220
skip180=True
smallest_3Dvolume=20
filtering3D=True

(55, 270, 500, 500)
(55, 270, 500, 500)


In [111]:
# defining the time steps for the current experiment
start_time = mf.exp_start_time()[mf.exp_list().index(exp)]
time_steps = range(start_time, end_time+1, 2) if skip180 else range(start_time, end_time+1)
time_index = range(len(time_steps))
# dealing with the first volume
previous_volume = hypervolume[30]
# evaluating the threshold on the first volume
threshold = mf.find_threshold(previous_volume)  
# segmenting the first volume
previous_mask = segment3D(previous_volume, threshold, smallest_volume=smallest_3Dvolume, filtering=filtering3D)
# reassigning the labels after the filtering
for new_label, old_label in enumerate(np.unique(previous_mask)):
    if old_label:
        previous_mask[previous_mask == old_label] = new_label


Finding threshold...
Threshold=1.88 found in 1.15 s



In [120]:
# segmenting the remaining volumes and propagating the labels from previous volumes
time = 31
current_volume = hypervolume[time]
tic = clock.perf_counter()
current_mask = segment3D(current_volume, threshold, smallest_volume=smallest_3Dvolume, filtering=filtering3D)
toc = clock.perf_counter()
print(f'3D segmentation completed in {toc-tic:.6f} seconds\n\n')
tic = clock.perf_counter()
current_mask = propagate_labels(previous_mask, current_mask)
toc = clock.perf_counter()
print(f'Propagation completed in {toc-tic:.6f} seconds')

3D segmentation completed in 11.538247 seconds


cycle time: 27.995493 seconds
{4562: 1011, 6527: 1705, 3422: 528, 4556: 1005, 2773: 328, 2184: 105, 3930: 765, 6778: 1801, 4170: 861, 2111: 79, 3805: 710, 6736: 1780, 5027: 1203, 5808: 1429, 2557: 295, 4301: 905, 6562: 1722, 4417: 950, 3097: 430, 5184: 1224, 2322: 116, 4411: 948, 4049: 807, 4974: 1159, 6516: 1704, 6062: 1516, 3499: 599, 5865: 1465, 2868: 362, 6765: 1793, 6028: 1541, 6489: 1693, 5494: 1338, 2193: 102, 3386: 560, 5767: 1422, 3202: 524, 3897: 753, 4799: 1093, 5234: 1275, 6298: 1598, 2933: 388, 6617: 1747, 1904: 14, 5824: 1434, 6534: 1707, 3859: 704, 4802: 1078, 4914: 1134, 4857: 1118, 4547: 999, 2924: 382, 4538: 997, 4434: 953, 2567: 247, 4589: 1021, 5168: 1213, 5823: 1433, 6470: 1684, 6818: 1810, 6101: 1536, 3265: 443, 4111: 837, 3685: 667, 5685: 1402, 1980: 35, 2794: 336, 6349: 1643, 2393: 185, 5165: 1257, 3538: 610, 4798: 1091, 2797: 337, 2708: 318, 6265: 1597, 5255: 1260, 5405: 1310, 3953: 775, 5117: 1217, 4708: 1083, 6

In [121]:
viewer = napari.Viewer()

daje = np.zeros((2, previous_mask.shape[0], previous_mask.shape[1], previous_mask.shape[2]), dtype=np.ushort)
daje[0] = previous_mask
daje[1] = current_mask
images = [viewer.add_image(hypervolume[30:32], name='Volume', opacity=0.4)]
labels = [viewer.add_labels(daje, name='Labels', blending='additive', opacity=0.8)]
labels2 = [viewer.add_labels(hypervolume_mask[30:32], name='Labels2', blending='additive', opacity=0.8)]

settings = napari.settings.get_settings()
settings.application.playback_fps = 5
viewer.dims.current_step = (0, 0)

In [115]:
viewer = napari.Viewer()
images = [viewer.add_image(hypervolume[31], name='Volume', opacity=0.4)]
labels = [viewer.add_labels(current_mask, name='Labels', blending='additive', opacity=0.8)]
labels2 = [viewer.add_labels(hypervolume_mask[31], name='Labels2', blending='additive', opacity=0.8)]
settings = napari.settings.get_settings()
settings.application.playback_fps = 5
viewer.dims.current_step = (0, 0)