Code for semi-automated segmentation of spheroids in MoNNets.

Author: Raju Tomer.

In [6]:
import skimage as sk
from skimage import io
import skimage.external.tifffile as tff
import os
import numpy as np
from scipy import ndimage as ndi
import scipy as spy
import skimage as sk
import skimage.filters as skf
import skimage.morphology as skm
import skimage.measure as skmes
import skimage.segmentation as sks
import matplotlib.patches as mptch
import scipy.io as sio
import scipy as scpy
import glob
import skimage.transform as skt
import pickle
import scipy.signal as spy_sig
from skimage import transform as tf
import matplotlib.pyplot as plt

In [2]:
#Functions to segment
def segment_nsps(im_use, marker_up=0.5, marker_low=0.5, sz_thres=100, clear_borders=False, vmin=1500, vmax=3000,
                 erosion_size=1):
    """Segment spheroids in a 2-D image, show diagnostic plots, return labels.

    Pipeline: weighted difference of a fine (sigma=2) and a coarse (sigma=20)
    Gaussian blur, negative responses clipped to 0, ISODATA threshold ``th``,
    marker-controlled watershed on the Sobel gradient, hole filling, 3x3
    closing, optional border clearing, removal of small components, optional
    erosion, and final relabelling.

    Parameters
    ----------
    im_use : ndarray
        Image to segment (assumed 2-D single channel, e.g. a max projection -
        TODO confirm for other inputs). Not modified.
    marker_up, marker_low : float
        Pixels with response ``> marker_up * th`` seed the foreground;
        pixels with response ``<= marker_low * th`` seed the background.
    sz_thres : int
        Connected components with ``sz_thres`` pixels or fewer are dropped.
    clear_borders : bool
        If True, objects touching the image border are removed.
    vmin, vmax : float
        Display range used only for the two raw-image panels.
    erosion_size : int
        Side of the square footprint used to erode the cleaned mask. The
        default of 1 is a no-op (kept for backward compatibility).

    Returns
    -------
    ndarray of uint16
        Label image: 0 is background, 1..N are the retained objects.
    """
    # NOTE(review): for integer input, ndi.gaussian_filter returns the input
    # dtype, so both blurs are rounded before weighting. Casting im_use to
    # float first would be more precise but would change results slightly.
    im_seg = np.multiply(ndi.gaussian_filter(im_use, sigma=2), 0.6) - np.multiply(ndi.gaussian_filter(im_use, sigma=20), 0.4)
    im_seg[im_seg < 0] = 0
    th = skf.threshold_isodata(im_seg)
    print(th)

    # Marker image: 1 = background seed, 2 = foreground seed, 0 = left for the
    # watershed to decide. If the two ranges overlap (marker_low > marker_up),
    # foreground wins because it is assigned last.
    elevation_map = sk.filters.sobel(im_seg)
    markers = np.zeros_like(im_seg)
    markers[im_seg <= marker_low * th] = 1
    markers[im_seg > marker_up * th] = 2
    # skimage.morphology.watershed was relocated to skimage.segmentation in
    # newer scikit-image releases; use whichever this installation provides.
    watershed = getattr(sks, 'watershed', None) or skm.watershed
    seg = watershed(elevation_map, markers)
    # Watershed basins carry the marker values 1/2; subtracting 1 gives a 0/1 mask.
    seg = ndi.binary_fill_holes(seg - 1)
    seg = skm.closing(seg, skm.square(3))
    if clear_borders:
        seg = sks.clear_border(seg)

    # Remove connected components with sz_thres pixels or fewer.
    label_objects, _ = ndi.label(seg)
    sizes = np.bincount(label_objects.ravel())
    mask_sizes = sizes > sz_thres
    mask_sizes[0] = 0  # never keep the background component
    seg_cleaned = mask_sizes[label_objects]
    seg_cleaned = skm.erosion(seg_cleaned, skm.square(erosion_size))
    im_lab, _ = ndi.label(seg_cleaned > 0)

    _show_segmentation(im_use, markers, seg, seg_cleaned, im_lab, vmin, vmax)
    return im_lab.astype('uint16')


def _show_segmentation(im_use, markers, seg, seg_cleaned, im_lab, vmin, vmax):
    """Plot the raw image, markers and mask stages of segment_nsps in a 2x3 grid.

    Panel 0 gets a red bounding box per object of ``im_lab``; panel 3 ('seg')
    gets each object's label number at its centroid.
    """
    _, axes = plt.subplots(ncols=3, nrows=2, figsize=(12, 7), sharex=True, sharey=True)
    ax = axes.ravel()

    # The raw image is shown twice: with (panel 0) and without (panel 1) boxes.
    for a in ax[:2]:
        a.imshow(im_use, cmap=plt.cm.gray, interpolation='nearest', vmin=vmin, vmax=vmax)
        a.set_title('im_max')
    for a, img, title in ((ax[2], markers, 'markers'), (ax[3], seg, 'seg'), (ax[4], seg_cleaned, 'seg_cleaned')):
        a.imshow(img, cmap=plt.cm.gray, interpolation='nearest')
        a.set_title(title)
    ax[5].imshow(im_lab, interpolation='nearest')
    ax[5].set_title('im_lab')

    for region in skmes.regionprops(im_lab):
        minr, minc, maxr, maxc = region.bbox
        rect = mptch.Rectangle((minc, minr), maxc - minc, maxr - minr, fill=False, edgecolor='red', linewidth=2)
        ax[0].add_patch(rect)
        row, col = region.centroid
        # int() works whether centroid entries are numpy or plain Python floats.
        ax[3].text(int(col), int(row), region['label'], bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 5})
    plt.show()

def run_segmentation(from_key, till_key, dict_in_fn, dict_im_lab,
                    clear_borders=True, marker_up=0.97, marker_low=0.97, sz_thres=80, vmin=1500, vmax=5000):
    """Run segment_nsps on every image whose key lies in [from_key, till_key].

    Parameters
    ----------
    from_key, till_key : int
        Inclusive key range. Keys absent from ``dict_in_fn`` are never
        visited, so ``till_key`` may exceed the largest key.
    dict_in_fn : dict
        key -> path of the image to segment (read with tifffile).
    dict_im_lab : dict
        Updated in place: ``dict_im_lab[key]`` receives the label image.
        Keys outside the range are left untouched.
    clear_borders, marker_up, marker_low, sz_thres, vmin, vmax
        Forwarded unchanged to ``segment_nsps``.
    """
    for key, fn in dict_in_fn.items():
        if not (from_key <= key <= till_key):
            continue
        print('###################')
        print('dict_key =', key)
        print(fn)
        im_max = tff.imread(fn)
        print(im_max.shape)
        # segment_nsps does not modify its input and returns a fresh array,
        # so no defensive copies are needed.
        dict_im_lab[key] = segment_nsps(im_use=im_max, clear_borders=clear_borders,
                                        marker_up=marker_up, marker_low=marker_low,
                                        sz_thres=sz_thres, vmin=vmin, vmax=vmax)

def remove_labels(to_remove, d_key, dict_im_lab, dict_in_fn, vmin=1500, vmax=3000):
    """Erase selected objects from a stored label image, relabel, and plot.

    Parameters
    ----------
    to_remove : int or sequence of int
        Label values to erase (as printed on the label plots).
    d_key : hashable
        Key into ``dict_im_lab`` and ``dict_in_fn``.
    dict_im_lab : dict
        Label images. Read only; the caller must store the returned image.
    dict_in_fn : dict
        Paths of the raw images, used for display only.
    vmin, vmax : float
        Display range for the two raw-image panels.

    Returns
    -------
    ndarray
        Output of ``ndi.label``. The remaining objects are renumbered
        consecutively from 1, so label IDs shown before this call are no
        longer valid.
    """
    im_l = np.copy(dict_im_lab[d_key])
    # One vectorised pass instead of a Python loop per label; also accepts a single int.
    im_l[np.isin(im_l, to_remove)] = 0
    im_l, _ = ndi.label(im_l > 0)

    _, axes = plt.subplots(ncols=2, nrows=2, figsize=(12, 7), sharex=True, sharey=True)
    ax = axes.ravel()
    im_m = tff.imread(dict_in_fn[d_key])
    # Raw image twice: with (panel 0) and without (panel 1) bounding boxes.
    for a in ax[:2]:
        a.imshow(im_m, cmap=plt.cm.gray, interpolation='nearest', vmin=vmin, vmax=vmax)
        a.set_title('im_max')

    ax[2].imshow(im_l > 0, interpolation='nearest')
    ax[2].set_title('fixed seg')

    ax[3].imshow(im_l, interpolation='nearest')
    ax[3].set_title('fixed im_lab')

    for region in skmes.regionprops(im_l):
        minr, minc, maxr, maxc = region.bbox
        rect = mptch.Rectangle((minc, minr), maxc - minc, maxr - minr, fill=False, edgecolor='red', linewidth=2)
        ax[0].add_patch(rect)
        row, col = region.centroid
        # int() works whether centroid entries are numpy or plain Python floats.
        ax[2].text(int(col), int(row), region['label'], bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 5})
    plt.show()
    return im_l

def calc_mean_f(im, im_lab):
    """Background-subtracted mean intensity of every labelled object, per frame.

    Parameters
    ----------
    im : ndarray
        Image stack indexed as ``im[frame, row, col]``.
    im_lab : ndarray
        Label image matching ``im.shape[1:]``; 0 marks background.

    Returns
    -------
    ndarray of float, shape (n_labels, n_frames)
        Row j is the trace of the j-th nonzero label in ascending order. In
        each frame the mean of the background pixels is subtracted and
        negative values are clipped to 0 before averaging per object. If
        ``im_lab`` has no background pixels, nothing is subtracted.
    """
    index = np.unique(im_lab)
    # Drop only the background label. Slicing off the first unique value
    # would discard a real object when the image has no background pixels.
    index = index[index != 0]
    n_frames = im.shape[0]
    f_NSp = np.zeros((index.shape[0], n_frames), dtype=float)
    if index.size == 0:
        return f_NSp

    bg_mask = im_lab == 0  # loop-invariant
    has_bg = bool(bg_mask.any())
    for i in range(n_frames):
        frame = im[i]
        # Without background pixels the mean would be NaN; subtract nothing instead.
        bg = np.mean(frame[bg_mask]) if has_bg else 0
        t = frame - bg  # new array; `im` itself is never modified
        t[t < 0] = 0
        f_NSp[:, i] = ndi.mean(t, labels=im_lab, index=index)
    return f_NSp


In [3]:
# Per-image bookkeeping, all keyed by the same integer "key":
#   dict_im_lab - labelled image data
#   dict_in_fn  - input (max projection) file paths
#   dict_out_fn - output file paths
dict_im_lab, dict_in_fn, dict_out_fn = dict(), dict(), dict()

In [13]:
# Read all the file names (Max projection images generated by GenMaxProj.ipynb) into dict
# Main directory containing all DIV folders
rootdir = r'E:\Raju\MoNNet\BT'
i = -1
for root2, dirs2, files2 in os.walk(rootdir):
    # os.walk visits sub-folders in file-system order; sorting them in place
    # makes the key numbering reproducible (AT keys are later matched to BT keys).
    dirs2.sort()
    for fn in sorted(files2):
        if 'MAX_' in fn:
            i = i + 1
            print('###################')
            print('dict_key =', i)
            print(fn)
            # Join with the folder the file was found in (root2), not rootdir,
            # so files inside DIV sub-folders get correct paths.
            dict_in_fn[i] = os.path.join(root2, fn)
            t = fn.replace('MAX_', '')
            dict_out_fn[i] = os.path.join(root2, 'im_lab_' + t)
print(len(dict_in_fn))

###################
dict_key = 0
MAX_211222_D1WB_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 1
MAX_211222_D1WC_bs_DIV15_crp_moco_4xRed.tif
###################
dict_key = 2
MAX_211222_D2WA_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 3
MAX_211222_D2WD_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 4
MAX_211222_D3WA_bs_DIV15_crp_Red4x.tif
###################
dict_key = 5
MAX_211222_D3WD_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 6
MAX_211222_D4WA_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 7
MAX_211222_D4WB_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 8
MAX_211222_D4WD_bs_DIV15_moco_crp_Red4x.tif
###################
dict_key = 9
MAX_211222_D5WB_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 10
MAX_211222_D5WC_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 11
MAX_211222_D6WA_bs_DIV15_crp_moco_Red4x.tif
###################
dict_key = 12
MAX_211222_D6WD_bs_DIV15_crp_moco_Red4x.tif
#

In [None]:
### Segment every key in [from_key, till_key], then manually validate the plots / adjust parameters
from_key = 0
till_key = 15  # keys absent from dict_in_fn are never visited, so this may exceed the last key
run_segmentation(
    from_key,
    till_key,
    dict_in_fn,
    dict_im_lab,
    clear_borders=True,
    marker_up=1.,
    marker_low=.8,
    sz_thres=85,
    vmin=1500,
    vmax=19000,
)

# Key whose label image is inspected / edited below.
d_key = from_key

## Uncomment to delete unwanted objects from dict_im_lab[d_key]
## (label numbers are those printed on the 'seg' panel of the plots above).
# to_remove = (1,3,22,4,13,18,19,21,20)
# vmax=5000
# dict_im_lab[d_key] = remove_labels(to_remove, d_key, dict_im_lab, dict_in_fn, vmax=vmax)


In [33]:
# key -> (n_objects, n_frames) array of background-subtracted F values (BT)
dict_f_NsP = dict()

In [None]:
# Calculate F values
# dict_out_fn[key] is <dir>/im_lab_<name>; the raw time-series stack is <dir>/<name>.
for key, out_fn in dict_out_fn.items():
    fn = out_fn
    try:
        print('key:', key, ', FN:', out_fn)
        out_dir, out_name = os.path.split(out_fn)
        # Portable replacement for the Windows-only '\\im_lab_' string replace.
        fn = os.path.join(out_dir, out_name.replace('im_lab_', '', 1))
        im = tff.imread(fn)
        dict_f_NsP[key] = calc_mean_f(im, dict_im_lab[key])
    except Exception as err:
        # Best-effort batch: report the cause (missing file, key not yet
        # segmented, shape mismatch, ...) and continue with the next key.
        # KeyboardInterrupt is no longer swallowed.
        print('Error in Key: ', key, '  FN: ', fn, '  ->', repr(err))

In [35]:
##save files
# Each dict goes to its own pickle; `with` closes the file even if pickling fails.
# NOTE: only unpickle these files from trusted locations (pickle can execute code).
for var_fn, obj in (
    (r'X:\People\Raju\Data\dict_f_NsP_BT.pickle', dict_f_NsP),
    (r'X:\People\Raju\Data\dict_im_lab_BT.pickle', dict_im_lab),
    (r'X:\People\Raju\Data\dict_in_fn_BT.pickle', dict_in_fn),
):
    with open(var_fn, "wb") as pickle_out:
        pickle.dump(obj, pickle_out)

In [36]:
# After-treatment (AT) images, aligned to BT with the Registration notebook:
#   dict_in_fn_AT  - input (max projection) file paths
#   dict_out_fn_AT - output file paths
dict_in_fn_AT, dict_out_fn_AT = dict(), dict()

In [None]:
# Read all the file names into dict

# Main directory containing all DIV folders
rootdir = r'E:\Raju\MoNNet\AT_aln'
i = -1
for root2, dirs2, files2 in os.walk(rootdir):
    # Sort sub-folders in place so key numbering is reproducible and follows
    # the same ordering rule as the BT listing (AT keys are matched to BT keys).
    dirs2.sort()
    for fn in sorted(files2):
        if 'MAX_' in fn:
            i = i + 1
            print('###################')
            print('dict_key =', i)
            print(fn)
            # Join with the folder the file was found in (root2), not rootdir.
            dict_in_fn_AT[i] = os.path.join(root2, fn)
            t = fn.replace('MAX_', '')
            dict_out_fn_AT[i] = os.path.join(root2, 'im_lab_' + t)
print(len(dict_in_fn_AT))

In [38]:
dict_f_NsP_AT = dict()  # key -> (n_objects, n_frames) F values for AT images

In [None]:
# Calculate F values
# The BT label images (dict_im_lab) are reused for the AT stacks. This assumes
# the AT images were registered to BT and AT key i is the same sample as BT
# key i (both come from sorted file listings) - TODO confirm per run.
for key, out_fn in dict_out_fn_AT.items():
    fn = out_fn
    try:
        print('key:', key, ', FN:', out_fn)
        out_dir, out_name = os.path.split(out_fn)
        # Portable replacement for the Windows-only '\\im_lab_' string replace.
        fn = os.path.join(out_dir, out_name.replace('im_lab_', '', 1))
        im = tff.imread(fn)
        dict_f_NsP_AT[key] = calc_mean_f(im, dict_im_lab[key])
    except Exception as err:
        # Best-effort batch: report the cause and continue with the next key.
        print('Error in Key: ', key, '  FN: ', fn, '  ->', repr(err))

In [None]:
n_keys_AT = len(dict_f_NsP_AT)  # keys whose AT F values were extracted successfully
print(n_keys_AT)

In [41]:
# Each dict goes to its own pickle; `with` closes the file even if pickling fails.
# NOTE: only unpickle these files from trusted locations (pickle can execute code).
for var_fn, obj in (
    (r'X:\People\Raju\Data\dict_f_NsP_AT.pickle', dict_f_NsP_AT),
    (r'X:\People\Raju\Data\dict_in_fn_AT.pickle', dict_in_fn_AT),
):
    with open(var_fn, "wb") as pickle_out:
        pickle.dump(obj, pickle_out)