In [18]:
import numpy as np                                  # type: ignore
from skimage.measure import label, regionprops      # type: ignore
from skimage.io import imread, imshow               # type: ignore
from skimage.filters import threshold_yen           # type: ignore
import glob
from dataclasses import dataclass
import napari
import os
import time
import myfunctions as mf
from tqdm import tqdm

In [31]:
def mask(image, threshold):
    """Threshold ``image`` and label its connected foreground regions.

    Pixels strictly greater than ``threshold`` are foreground. The boolean
    image is passed to skimage's ``label`` through ``np.vectorize`` with the
    core signature ``'(n,m)->(n,m)'``. As a result, ``label`` runs once on each
    2D plane spanned by the last two axes:

    - a 2D image is labelled once;
    - a stack of images is labelled plane by plane, and labels restart in
      every plane.

    Returns
    -------
    ndarray
        Integer label image with the same shape as ``image``. 0 marks
        background (skimage ``label`` default).
    """
    label_each_plane = np.vectorize(label, signature='(n,m)->(n,m)')
    foreground = image > threshold
    return label_each_plane(foreground)

def remove_small_agglomerates(sequence_mask, smallest_volume):
    """Set to 0 every label whose voxel count is below ``smallest_volume``.

    Parameters
    ----------
    sequence_mask : ndarray of non-negative ints
        Label volume. It is modified in place.
    smallest_volume : int
        Minimum number of voxels a label needs in order to be kept.

    Returns
    -------
    ndarray
        ``sequence_mask`` itself (the same object), with small labels zeroed.
    """
    # ravel() avoids the full copy that flatten() makes.
    volumes = np.bincount(sequence_mask.ravel())
    # Boolean lookup table indexed by label value: True where that label is
    # too small. Indexing it with the label volume selects the voxels to clear
    # in one linear pass, instead of np.isin's sort-based membership test.
    too_small = volumes < smallest_volume
    sequence_mask[too_small[sequence_mask]] = 0
    return sequence_mask

def propagate_3Dlabels(previous_mask, current_mask, forward=True):
    """Relabel ``current_mask`` so its regions carry labels from ``previous_mask``.

    Each non-zero label of ``previous_mask`` is handled in turn. The label of
    ``current_mask`` that overlaps it most (background ignored) is replaced
    everywhere in ``current_mask`` by the previous label.

    Previous labels are visited from smallest to largest area. When a later,
    larger label re-targets a region that was already relabelled, the larger
    previous label is the one that remains.

    Parameters
    ----------
    previous_mask : ndarray of non-negative ints
        Label image of the neighbouring, already-linked slice.
    current_mask : ndarray of non-negative ints
        Label image to relabel, with the same shape as ``previous_mask``. It is
        modified in place.
    forward : bool, default True
        If True, first shift all non-zero labels of ``current_mask`` up by
        ``max(previous_mask)``. Unmatched regions then keep labels that cannot
        clash with those of ``previous_mask``.

    Returns
    -------
    ndarray
        ``current_mask`` itself.
    """
    if forward:
        foreground = current_mask > 0
        current_mask[foreground] = current_mask[foreground] + np.max(previous_mask)

    labels, areas = np.unique(previous_mask, return_counts=True)
    for prev_label in labels[np.argsort(areas)]:
        if prev_label == 0:
            continue
        # Count how often each current label appears under this previous region.
        overlap = np.bincount(current_mask[previous_mask == prev_label])
        if overlap.size > 1:
            overlap[0] = 0  # background never wins the vote
            current_mask[current_mask == np.argmax(overlap)] = prev_label
    return current_mask

def segment3D(sequence, threshold, smallest_volume=100, filtering=True):
    """Label 3D agglomerates in a stack of 2D slices.

    The algorithm has three stages:

    1. Threshold and label every slice along axis 0 in 2D with ``mask``.
    2. Link the labels across slices with ``propagate_3Dlabels``: first a
       forward pass (slice 0 -> last slice), then a backward pass
       (last slice -> slice 0).
    3. Optionally drop small agglomerates with ``remove_small_agglomerates``.

    Parameters
    ----------
    sequence : ndarray
        Stack of images, indexed as ``sequence[i]`` for slice ``i``.
    threshold : scalar
        Voxels strictly greater than this value are foreground.
    smallest_volume : int, default 100
        When ``filtering`` is on, labels with fewer voxels than this are
        removed.
    filtering : bool, default True
        Whether to run ``remove_small_agglomerates``.

    Returns
    -------
    ndarray of int
        Label volume with the same shape as ``sequence``. 0 is background.
    """
    sequence_mask = np.zeros_like(sequence, dtype=int)
    n_slices = sequence.shape[0]
    if n_slices == 0:
        return sequence_mask

    sequence_mask[0] = mask(sequence[0], threshold)
    # max_label is the highest label used in ANY slice processed so far.
    # Fresh labels must be shifted past this value, not merely past the
    # previous slice's maximum (which is all propagate_3Dlabels(forward=True)
    # does). Otherwise a label that ended in an earlier slice could be reused
    # by an unrelated agglomerate appearing later, and the two disjoint
    # objects would end up sharing one 3D label.
    max_label = sequence_mask[0].max()

    # Forward pass: label slice i, shift its labels, link it to slice i-1.
    for i in range(1, n_slices):
        sequence_mask[i] = mask(sequence[i], threshold)
        current = sequence_mask[i]  # view into the volume; edits land in place
        current[current > 0] += max_label
        # forward=False because the offset has already been applied above.
        sequence_mask[i] = propagate_3Dlabels(sequence_mask[i - 1], current, forward=False)
        max_label = max(max_label, current.max())

    # Backward pass: push labels from slice i back onto slice i-1.
    for i in range(n_slices - 1, 0, -1):
        sequence_mask[i - 1] = propagate_3Dlabels(sequence_mask[i], sequence_mask[i - 1], forward=False)

    # Remove agglomerates whose total volume is below smallest_volume.
    if filtering:
        sequence_mask = remove_small_agglomerates(sequence_mask, smallest_volume)
    return sequence_mask

In [32]:
# Intensity threshold for the segmentation cells below (voxels > threshold are foreground).
threshold = 2

# Load time step 150 of experiment 'P28A_FT_H_Exp1' and display it in napari.
sequence = mf.read_3Dsequence('P28A_FT_H_Exp1', time=150)
viewer = napari.Viewer()
_ = viewer.add_image(data=sequence, opacity=0.4)

In [33]:
# Reference segmentation from the myfunctions module. It is stored under its own
# name so that the next cell (the in-notebook segment3D) does not silently
# discard it, and the two results can be compared.
segmented_sequence_mf = mf.segment3D(sequence, threshold)

In [34]:
# 3D segmentation using the segment3D function defined earlier in this notebook.
segmented_sequence = segment3D(sequence, threshold=threshold)

In [None]:
# Overlay the label volume on the image layer already shown in the viewer.
_ = viewer.add_labels(segmented_sequence, opacity=0.8, blending='additive')