In [None]:
import numpy as np

# package for 3d visualization
from itkwidgets import view                              
from aicssegmentation.core.visual import seg_fluo_side_by_side, single_fluorescent_view, segmentation_quick_view
import matplotlib.pyplot as plt

# package for io
import os
from shutil import rmtree
import skimage
from skimage.io import imread, imsave
from aicsimageio import AICSImage

# function for core algorithm
from aicssegmentation.core.vessel import filament_2d_wrapper
from aicssegmentation.core.pre_processing_utils import intensity_normalization, image_smoothing_gaussian_3d
from skimage.morphology import disk, dilation, erosion, closing, opening, remove_small_objects, remove_small_holes
from skimage.segmentation import watershed
from skimage.measure import label, regionprops
from skimage.filters import difference_of_gaussians as dog_filter

### Load images

In [None]:
# Load the raw membrane z-stack and choose the folder that receives every
# file written by this notebook.
original_membrane_img = imread('./test_membrane_zstack.tif')
output_dir = './segmentation/'
# Later cells write into output_dir (imsave / os.mkdir) but never create it,
# so make sure it exists; exist_ok keeps re-running this cell harmless.
os.makedirs(output_dir, exist_ok=True)
print(original_membrane_img.shape)

### Check original image and decide the range in z

In [None]:
# Restrict the stack to the z-planes of interest; top_z is exclusive because it
# is used as a slice end below.
bottom_z, top_z = 80, 150 # open in ImageJ and find the indices of the bottom and top of the cell layer
# bottom_z, top_z = 0, original_membrane_img.shape[0] # to use the entire z-stack (z is axis 0)

membrane_img = original_membrane_img[bottom_z:top_z]  # slice along the first axis (z)

# Number of z-planes kept; later cells derive the middle plane from it.
num_z = membrane_img.shape[0]

print(membrane_img.shape)
# axis order: z, y, x

### Normalization and smoothing

In [None]:
%%capture
intensity_scaling_param = [3000]
gaussian_smoothing_sigma = 1
mid_z = num_z // 2

# intensity normalization
membrane_img_norm = intensity_normalization(membrane_img, scaling_param=intensity_scaling_param)

# smoothing with 2d gaussian filter slice by slice
membrane_img_smooth = image_smoothing_gaussian_3d(membrane_img_norm, sigma=gaussian_smoothing_sigma)

imsave(output_dir+'smooth_membrane_mid_z.tif', membrane_img_smooth[mid_z])
imsave(output_dir+'smooth_membrane_all_z.tif', membrane_img_smooth)

In [None]:
# Prepare the smoothed stack for display, then open it in the interactive viewer.
smooth_stack_preview = single_fluorescent_view(membrane_img_smooth)
view(smooth_stack_preview)

### optional image processing routines

In [None]:
# extracted = closing(membrane_img_smooth[mid_z], disk(10))
# view(extracted)

In [None]:
# # from skimage.feature import blob_dog, blob_doh, blob_log
# extracted = blob_log(membrane_img_smooth[mid_z], 10, 100)
# view(extracted)

### 2D membrane contour segmentation

In [None]:
# Show the single plane that the 2D contour segmentation below operates on.
mid_plane_smooth = membrane_img_smooth[mid_z]
view(mid_plane_smooth)

In [None]:
# Tuning cutoff_x in f2_param:
#   - a smaller cutoff_x may yield more filaments (dimmer ones are detected too)
#     and a thicker segmentation;
#   - a larger cutoff_x is less permissive: fewer filaments, slimmer segmentation.
f2_param = [[8, 0.4]]

# Run the 2D filament filter on the middle plane only.
membrane_mid_z = filament_2d_wrapper(membrane_img_smooth[mid_z], f2_param)

# The quick view has to be 3D, so add a leading axis; click the projection
# button in the viewer panel to see the plane.
membrane_mid_z_3d = np.expand_dims(membrane_mid_z, axis=0)
view(segmentation_quick_view(membrane_mid_z_3d))

In [None]:
# Drop connected components smaller than 1000 pixels from the mid-plane contour.
# `in_place=False` was the default and the keyword is deprecated in newer
# scikit-image releases, so it is omitted; a new array is still returned and
# membrane_mid_z is left untouched.
membrane_mid_z_filtered = remove_small_objects(membrane_mid_z, min_size=1000)
view(segmentation_quick_view(membrane_mid_z_filtered))

### optional image processing routines

In [None]:
# Thicken the filtered contour with a disk of radius 15.
# The footprint is passed positionally: the keyword was renamed from `selem`
# to `footprint` in later scikit-image versions, and the positional form works
# with both spellings.
dilated = dilation(membrane_mid_z_filtered, disk(15))
view(segmentation_quick_view(dilated))

In [None]:
# Morphological closing (disk of radius 10) of the dilated contour.
closed = closing(dilated, disk(10))
# Display the closing result itself; the cell previously re-displayed
# `dilated`, so the effect of this step was never shown.
view(segmentation_quick_view(closed))

### create seed for watershed using the automatically generated contour

In [None]:
# Everything that is not contour is a seed candidate: invert the closed contour
# and give each connected region its own integer label.
inverted_contour = np.invert(closed)
seed_z = label(inverted_contour)

# Save for inspection in ImageJ.
# NOTE: the uint8 cast wraps around for label values above 255.
auto_seed_path = output_dir+'auto_seed.tif'
imsave(auto_seed_path, seed_z.astype(np.uint8))

In [None]:
# Inspect auto_seed.tif in ImageJ, then list here the label values that are
# really background so they can be reset to the default background value 0.
background_labels = [1, 3]
seed_z[np.isin(seed_z, background_labels)] = 0

# Overwrite the saved seed image with the corrected labels.
imsave(output_dir+'auto_seed.tif', seed_z.astype(np.uint8))

In [None]:
# Correct the seed image in ImageJ and save it as manual_seed.tif before
# running this cell.
seed_path = output_dir+'manual_seed.tif'
# seed_path = output_dir+'auto_seed.tif' # if no correction is needed

# Re-label the loaded seed image and add a leading z axis so the plane can be
# stacked with other planes later.
manual_seed = np.expand_dims(label(imread(seed_path)), axis=0)
view(manual_seed[0])

In [None]:
# A single all-background plane with the same shape/dtype as the seed plane
# (kept as a notebook-level name for any later use).
empty_seed = np.zeros(manual_seed.shape, dtype=manual_seed.dtype)

# Option 1: only the middle plane is used as seed.
# The stack is allocated once and the seed plane assigned into it, instead of
# growing an array with np.concatenate inside a loop (which re-copied the whole
# stack on every iteration). This also places the seed when mid_z == 0, a case
# the previous loop silently filled with an empty plane.
zstack_seed = np.zeros((num_z,) + manual_seed.shape[1:], dtype=manual_seed.dtype)
zstack_seed[mid_z] = manual_seed[0]

# Option 2: the middle plane seed is copied to all other planes to make a
# cylindrical seed.
# zstack_seed = np.repeat(manual_seed, num_z, axis=0)

zstack_seed.shape

In [None]:
# Persist the 3D seed stack as 8-bit TIFF
# (the uint8 cast wraps around for label values above 255).
final_seed_path = output_dir+'final_seed_3d.tif'
imsave(final_seed_path, zstack_seed.astype(np.uint8))

### use the final seed to run 3D watershed

In [None]:
# 3D watershed on the smoothed membrane stack, seeded by zstack_seed.
# This step takes some time.
watershed_options = dict(
    markers=zstack_seed,
    watershed_line=True,
    connectivity=1,
)
watershed_mask = watershed(membrane_img_smooth, **watershed_options)

In [None]:
# Write the raw watershed labels to disk.
watershed_path = output_dir+'watershed_mask.tif'
imsave(watershed_path, watershed_mask)

# The viewer gets a re-labelled copy (connected components of the mask);
# NOTE(review): label ids shown here may therefore differ from the ids stored
# in watershed_mask -- confirm before picking labels from this view.
relabelled_preview = label(watershed_mask)
view(relabelled_preview)

In [None]:
# Highest label value in the watershed result.
# ndarray.max() reduces inside NumPy; the builtin max() over a raveled array
# iterated over every voxel in Python. int() keeps num_label a plain Python
# integer for range() and str() below and in later cells.
num_label = int(watershed_mask.max())
print('Labels are 0 to '+str(num_label))

#### manually check each label and pick labels

In [None]:
# Change label_num and re-run this cell to inspect one watershed label at a time.
label_num = 1
mask = watershed_mask == label_num
print('Label '+str(label_num))
# `membrane_view` is not imported or defined anywhere in this notebook, so the
# call raised NameError on a fresh kernel; use the imported quick-view helper
# for binary/label masks instead.
viewer_label = view(segmentation_quick_view(mask))
viewer_label

In [None]:
# Save a middle-plane preview PNG for every watershed label (0..num_label).
labels_dir = output_dir+'individual_labels/'

# Start from an empty folder so previews from an earlier run cannot linger.
if os.path.isdir(labels_dir):
    rmtree(labels_dir)
os.mkdir(labels_dir)

# Same plane index the original per-mask expression len(mask)//2 selected.
preview_z = len(watershed_mask)//2

for label_num in range(num_label+1):
    mask = watershed_mask == label_num
    # Convert the boolean plane to 8-bit (0/255) explicitly, as the final-label
    # export does; writing a bool array straight to PNG relies on imsave's
    # implicit conversion and emits a warning.
    imsave(labels_dir+'label_'+str(label_num)+'.png', skimage.img_as_ubyte(mask[preview_z]))

#### extract desired labels and check

In [None]:
# Enter the watershed labels to DISCARD (label 0 is listed as well); the next
# cell sets every voxel carrying one of these labels to 0.
invalid_labels = [11, 0, 1, 15, 17, 25, 26, 27]
# valid_labels = []  # alternative: list the labels to keep instead

In [None]:
# Work on a copy so watershed_mask stays available for re-selection.
checked_mask = watershed_mask.copy()

# Clear every rejected label in one vectorised pass instead of building a
# full-size boolean mask per label inside a Python loop.
checked_mask[np.isin(watershed_mask, invalid_labels)] = 0
# Alternative: keep only an explicit list of accepted labels.
# checked_mask[~np.isin(watershed_mask, valid_labels)] = 0

# sort and relabel from 1 to N with 0 as bg
checked_mask = label(checked_mask)
# NOTE: the uint8 cast wraps around for label values above 255.
imsave(output_dir+'checked_mask_3d.tif', checked_mask.astype('uint8'))
view(checked_mask)

In [None]:
# filtered_mask = checked_mask.copy()

# ## do closing to close holes if needed, will cause index problem
# for z in range(len(checked_mask)):
#     filtered_mask[z] = closing(checked_mask[z], disk(10))

# final_mask = label(filtered_mask)
# imsave(output_dir+'final_mask.tif', final_mask.astype('uint8'))
# view(final_mask)

In [None]:
# No extra filtering is applied: the checked mask is the final mask.
final_mask = checked_mask
# Highest label value == number of cells, since labels run 1..N after the
# relabelling step. ndarray.max() replaces the builtin max() over a raveled
# array, which iterated over every voxel in Python.
num_cell = int(final_mask.max())
num_cell

In [None]:
# Export a middle-plane preview PNG for each selected cell.
final_labels_dir = output_dir+'final_labels/'

# Recreate the folder from scratch on every run.
if os.path.isdir(final_labels_dir):
    rmtree(final_labels_dir)
os.mkdir(final_labels_dir)

for cell_num in range(1, num_cell+1):
    mask = (final_mask == cell_num)
    mid_plane_png = skimage.img_as_ubyte(mask[mid_z])
    png_path = final_labels_dir+'cell_'+str(cell_num)+'.png'
    imsave(png_path, mid_plane_png)

### crop mitochondria signal for each segmented cell

In [None]:
## reload data if closed kernel
# checked_mask = imread(output_dir+'checked_mask_3d.tif')
# num_cell = max(checked_mask.ravel())

In [None]:
# Crop the mitochondria signal of every segmented cell, for each time frame.
# Output layout: <output_dir>/cell_NN/frame_F/frame_F.tif

# Number of mito_frame_<i>.tif files to process.
num_frames = 3

# Load every frame once, restricted to the same z-range as the membrane stack.
# The frames do not depend on the cell, so reading them inside the cell loop
# repeated identical disk reads for every cell.
mito_frames = [
    imread('./mito_frame_'+str(frame)+'.tif')[bottom_z:top_z]
    for frame in range(num_frames)
]

for n in range(1, num_cell+1): # label 0 is background

    # make dir for each cell (zero-padded to two digits: cell_01 ... cell_10 ...)
    cell_dir = output_dir+'cell_'+str(n).zfill(2)+'/'
    print('Start cropping cell number '+str(n))

    # remove old dir (be careful)
    if os.path.isdir(cell_dir):
        rmtree(cell_dir)
        print('All old files are removed')
    os.mkdir(cell_dir)

    mask = (checked_mask == n)

    # The bounding box depends only on the mask, so compute it once per cell
    # rather than once per frame. z is not cropped: its range was already
    # chosen through bottom_z/top_z.
    _, miny, minx, _, maxy, maxx = regionprops(label(mask))[0].bbox
    cropped_mask = mask[:, miny:maxy, minx:maxx]

    count = 0

    for frame, mito_img in enumerate(mito_frames):
        # Crop first, then clear the signal outside of the cell; this copies
        # only the bounding box instead of the whole stack.
        cropped_mito_cell = mito_img[:, miny:maxy, minx:maxx].copy()
        cropped_mito_cell[~cropped_mask] = 0

        ## mitochondria smoothing (optional)
        #cropped_mito_cell = image_smoothing_gaussian_3d(cropped_mito_cell, sigma=1)
        #cropped_mito_cell = dog_filter(cropped_mito_cell, gaussian_smoothing_sigma, 0.5)

        # make dir for each frame; cell_dir was just recreated empty, so the
        # frame folder cannot exist yet and needs no cleanup
        frame_dir = cell_dir+'frame_'+str(frame)+'/'
        os.mkdir(frame_dir)

        count += 1

        imsave(frame_dir+'frame_'+str(frame)+'.tif', cropped_mito_cell)

    print('Done cropping '+str(count)+' frames')