In [1]:
import sys
# insert at 1, 0 is the script path (or '' in REPL)
# These top lines are critical for import from another folder
sys.path.insert(0, 'C:/Users/14432/OneDrive/Research/Projects/A549_144hr/src/')
sys.path.insert(1, 'C:/Users/14432/OneDrive/Research/Projects/A549_144hr/src/memes/')

import warnings
warnings.filterwarnings("ignore")

from IPython.display import display, clear_output
from ipywidgets.widgets.interaction import show_inline_matplotlib_plots
import ipywidgets as widgets
import random
import pandas as pd
import numpy as np 
from math import pi
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont, ImageEnhance
from pilutil import toimage

import os
from os.path import basename

from skimage.io import imread, imsave
from skimage.segmentation import watershed, clear_border
from skimage.feature import peak_local_max
from skimage.morphology import remove_small_objects, local_maxima, h_maxima, disk, dilation
from skimage.measure import regionprops, label
from skimage.color import label2rgb
from skimage.exposure import equalize_adapthist
from skimage import filters, img_as_float32, img_as_ubyte
from skimage.draw import rectangle_perimeter

import hj_util

In [2]:
def color_num(labels, font_size=40):
    """Render a label image in colour and write each region's label number on it.

    Parameters
    ----------
    labels : 2-D integer label image (0 = background).
    font_size : int, optional
        Size of the drawn numbers (default 40, the previously hard-coded value).

    Returns
    -------
    PIL.Image.Image
        RGBA image: the ``label2rgb`` colouring with white label numbers composited on top.
    """
    # Colour-code the labels; background (0) is left unpainted.
    label_rgb = label2rgb(labels, bg_label=0)
    base = toimage(label_rgb).convert('RGBA')
    # Transparent layer for the text so it can be alpha-composited over the colours.
    txt = Image.new('RGBA', base.size, (255, 255, 255, 0))
    try:
        fnt = ImageFont.truetype('arial.ttf', font_size)
    except OSError:
        # arial.ttf is not available everywhere; fall back instead of crashing.
        fnt = ImageFont.load_default()
    d = ImageDraw.Draw(txt)
    for region in regionprops(labels):
        cy, cx = (int(c) for c in region.centroid)
        # Use region.label: the centroid of a non-convex region can fall on the
        # background or on a neighbouring label, so labels[cy][cx] may be wrong.
        d.text((cx, cy), str(region.label), font=fnt, fill=(255, 255, 255, 255))
    return Image.alpha_composite(base, txt)

In [3]:
# Vectorised: a single pass over the pixels instead of one full-image scan per object.
def compute_overlap_matrix(seg1,seg2):
    """Pixel overlap between every pair of labels of two label images.

    Parameters
    ----------
    seg1, seg2 : label images of the same shape (0 = background, objects 1..max).

    Returns
    -------
    numpy.ndarray of float, shape (max(seg1), max(seg2))
        Entry ``[i, j]`` is the number of pixels where ``seg1 == i + 1`` and
        ``seg2 == j + 1``.
    """
    seg1 = np.asarray(seg1)
    seg2 = np.asarray(seg2)
    # int() so float-typed label images (e.g. all-zero fallbacks) still give a valid shape
    nb_cell_1 = max(int(np.amax(seg1)), 0) if seg1.size else 0
    nb_cell_2 = max(int(np.amax(seg2)), 0) if seg2.size else 0

    seg_overlap = np.zeros((nb_cell_1, nb_cell_2))
    if nb_cell_1 == 0 or nb_cell_2 == 0:
        return seg_overlap

    # Only pixels that are foreground in both images contribute to an overlap.
    both = (seg1 > 0) & (seg2 > 0)
    rows = seg1[both].astype(np.intp) - 1
    cols = seg2[both].astype(np.intp) - 1
    # Unbuffered add so repeated (row, col) pairs accumulate.
    np.add.at(seg_overlap, (rows, cols), 1)
    return seg_overlap

In [4]:
#-------calculate the cell fusion -----------------------
def cal_cell_fusion(frame_overlap):
    """Find target objects onto which two or more source objects map.

    Each source object (row of ``frame_overlap``) is assigned to the target
    (column) it overlaps most. Sources with no overlap at all are unassigned.
    A target receiving at least two sources is a fusion event.

    Parameters
    ----------
    frame_overlap : 2-D array, shape (n_source, n_target)
        Overlap areas, e.g. from ``compute_overlap_matrix``.

    Returns
    -------
    postfuse_cells : list of one-element lists ``[target_label]`` (1-based), one per fusion.
    prefuse_group : list of lists of source labels (1-based, ascending) belonging to
        the corresponding entry of ``postfuse_cells``.
    """
    frame_overlap = np.asarray(frame_overlap)
    nb_cell_1, nb_cell_2 = frame_overlap.shape
    if nb_cell_1 == 0 or nb_cell_2 == 0:
        return [], []

    # Best target per source, computed once (argmax picks the first column on ties).
    best_target = np.argmax(frame_overlap, axis=1)
    sources = np.nonzero(np.any(frame_overlap != 0, axis=1))[0]

    frame_fusion = np.zeros(frame_overlap.shape)
    frame_fusion[sources, best_target[sources]] = 1

    # Column sums count how many sources were assigned to each target; >= 2 means fusion.
    S = np.sum(frame_fusion, axis=0)

    postfuse_cells = []
    prefuse_group = []
    for target_idx in np.nonzero(S >= 2)[0]:
        postfuse_cells.append([target_idx + 1])
        prefuse_group.append([src + 1 for src in np.nonzero(frame_fusion[:, target_idx] == 1)[0]])
    return postfuse_cells, prefuse_group

In [5]:
def hmax_watershed(img, h_thresh, small_obj_thres = 10, mask_thres=0.15):
    """
    Marker-controlled watershed of a predicted EDT image seeded by h-maxima.

    img - a 2d array for predicted edt image.
    h_thresh - height passed to ``h_maxima``; only maxima at least this high become seeds.
    small_obj_thres - NOT used by this function (kept for interface compatibility);
                      filter small objects on the returned labels if needed.
    mask_thres - only pixels with ``img > mask_thres`` are segmented.

    Returns (labels, local_hmax_label): the watershed label image and the labelled
    seed image. If ``h_maxima`` raises ValueError, both are all-zero integer arrays,
    with the same integer kind as the normal result so downstream label code works.
    """
    try:
        local_hmax = h_maxima(img, h_thresh)
    except ValueError:
        # Integer zeros: float zeros break regionprops / overlap-matrix shape computation.
        labels = np.zeros(img.shape, dtype=int)
        local_hmax_label = np.zeros(img.shape, dtype=int)
        return labels, local_hmax_label

    local_hmax_label = label(local_hmax, connectivity=1)
    # Flood the inverted EDT from the seeds, restricted to the foreground mask.
    labels = watershed(-img, local_hmax_label, mask=img > mask_thres)

    return labels, local_hmax_label

def folder_watershed_labels(folder, h_thresh, small_obj_thres = 10, mask_thres = 0.15):
    """
    Obtain watershed masks for an edt image folder.

    h_thresh - float between 0 and 1; quantile of each image's non-zero EDT values
               used as the h-maxima height for that image.
    small_obj_thres - int; objects smaller than this many pixels are removed.
    mask_thres - float between 0 and 1; quantile of each image's non-zero EDT values
                 used as the foreground threshold for that image.

    Returns a list with one label image per file (border-touching objects cleared).
    """
    folder = hj_util.folder_verify(folder)
    edt_files = hj_util.folder_file_num(folder)
    label_list = []
    for f in edt_files:
        edt = imread(f)
        edt_flat = edt[edt != 0.0]
        if edt_flat.size == 0:
            # Empty prediction: np.quantile cannot work on an empty array.
            label_list.append(np.zeros(edt.shape, dtype=int))
            continue
        # Per-image absolute thresholds. Use new names so the quantile *levels*
        # passed by the caller are not overwritten for the next file.
        h_value = np.quantile(edt_flat, h_thresh)
        mask_value = np.quantile(edt_flat, mask_thres)
        seg, _ = hmax_watershed(edt, h_value, small_obj_thres, mask_value)
        seg = remove_small_objects(seg, small_obj_thres)
        label_list.append(clear_border(seg))

    print("---------- Watershed processing is complete. -----------")
    return label_list

In [6]:
def generate_single_cell_img(img,seg,img_num,obj_num):
    """Crop the intensity image of one segmented object.

    ``seg == obj_num`` is split into connected components and the first one
    (component label 1 in skimage's raster order) is kept. Pixels outside that
    component are zeroed, and the result is cropped to the component's bounding box.

    img     - 2-D intensity image, same shape as ``seg``.
    seg     - 2-D label image.
    img_num - unused; kept for interface compatibility.
    obj_num - label value of the object to crop.

    Raises ValueError if ``obj_num`` does not occur in ``seg``.
    """
    components = label(seg == obj_num)
    keep = components == 1
    if not keep.any():
        raise ValueError('object ' + str(obj_num) + ' not found in segmentation')
    rows = np.flatnonzero(keep.any(axis=1))
    cols = np.flatnonzero(keep.any(axis=0))
    # Mask with a 0/1 image of the kept component. Multiplying by the raw component
    # labels scaled pixels of a 2nd/3rd component inside the box by 2/3.
    single_cell_img = keep.astype(components.dtype) * img
    return single_cell_img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

In [7]:
def generate_single_cell_label(img,seg,img_num,obj_num):
    """Crop the label image of one segmented object.

    Same selection as ``generate_single_cell_img``: the first connected component
    (label 1) of ``seg == obj_num`` is kept and cropped to its bounding box. The crop
    holds ``obj_num`` on that component and 0 elsewhere.

    img     - unused; kept for interface compatibility.
    seg     - 2-D label image.
    img_num - unused; kept for interface compatibility.
    obj_num - label value of the object to crop.

    Raises ValueError if ``obj_num`` does not occur in ``seg``.
    """
    components = label(seg == obj_num)
    keep = components == 1
    if not keep.any():
        raise ValueError('object ' + str(obj_num) + ' not found in segmentation')
    rows = np.flatnonzero(keep.any(axis=1))
    cols = np.flatnonzero(keep.any(axis=0))
    # 0/1 mask; multiplying by raw component labels produced values obj_num * k.
    single_cell_label = keep.astype(components.dtype) * seg
    return single_cell_label[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

In [8]:
def generate_single_cell_env(img,seg,img_num,obj_num,env_size):
    """Crop a neighbourhood of one object with a frame drawn around the object.

    The first connected component of ``seg == obj_num`` defines a bounding box. A
    rectangle perimeter (``rectangle_perimeter``) around that box is painted, on a
    copy of ``img``, with the minimum intensity of the neighbourhood window. The
    window is the box grown by ``env_size`` pixels on every side and clipped to
    the image.

    img_num is unused; kept for interface compatibility.
    """
    components = label(seg == obj_num)
    first = [rp for rp in regionprops(components) if rp.label == 1][0]
    min_r, min_c, max_r, max_c = first.bbox
    n_rows, n_cols = img.shape

    outline_r, outline_c = rectangle_perimeter((min_r, min_c), end=(max_r, max_c), shape=img.shape)
    framed = img.copy()

    # Neighbourhood window, clipped to the image borders.
    top = max(min_r - env_size, 0)
    bottom = min(max_r + env_size, n_rows)
    left = max(min_c - env_size, 0)
    right = min(max_c + env_size, n_cols)

    window = (slice(top, bottom), slice(left, right))
    framed[outline_r, outline_c] = np.min(framed[window])
    return framed[window]

In [9]:
def list_select_by_index(the_list, index_list):
    """Return a new list containing ``the_list[k]`` for each ``k`` in ``index_list``, in order."""
    return list(map(the_list.__getitem__, index_list))

def color_num(labels, font_size=40):
    """Render a label image in colour and write each region's label number on it.

    Parameters
    ----------
    labels : 2-D integer label image (0 = background).
    font_size : int, optional
        Size of the drawn numbers (default 40, the previously hard-coded value).

    Returns
    -------
    PIL.Image.Image
        RGBA image: the ``label2rgb`` colouring with white label numbers composited on top.
    """
    # Colour-code the labels; background (0) is left unpainted.
    label_rgb = label2rgb(labels, bg_label=0)
    base = toimage(label_rgb).convert('RGBA')
    # Transparent layer for the text so it can be alpha-composited over the colours.
    txt = Image.new('RGBA', base.size, (255, 255, 255, 0))
    try:
        fnt = ImageFont.truetype('arial.ttf', font_size)
    except OSError:
        # arial.ttf is not available everywhere; fall back instead of crashing.
        fnt = ImageFont.load_default()
    d = ImageDraw.Draw(txt)
    for region in regionprops(labels):
        cy, cx = (int(c) for c in region.centroid)
        # Use region.label: the centroid of a non-convex region can fall on the
        # background or on a neighbouring label, so labels[cy][cx] may be wrong.
        d.text((cx, cy), str(region.label), font=fnt, fill=(255, 255, 255, 255))
    return Image.alpha_composite(base, txt)

def generating_img_label_list(img_list, img_name_list, reg_list, \
                              small_obj_thres, frag_size_thres, \
                              low_h,high_h,mask_thres, \
                              crop_path, frag_path, \
                              env_size = 50):
    """Detect fusion events between a low-h and a high-h watershed and export crops.

    For every frame, the EDT prediction is segmented twice with ``hmax_watershed``
    (h = ``low_h`` and h = ``high_h``). A high-h object that is the best-overlap target
    of two or more low-h objects (``cal_cell_fusion``) is a fused cell ("fc"); those
    low-h objects are its parts ("fp"). For the fused cell and each part, this function
    saves a PNG crop and collects the crop, its label crop and an environment crop.

    Parameters
    ----------
    img_list, reg_list : sequences of 2-D arrays (raw image / EDT prediction), one per frame.
    img_name_list : file names matching ``img_list``. The position is the digits after
        ``_XY``; the frame is the 4 characters after ``_T``.
    small_obj_thres, low_h, high_h, mask_thres : forwarded to ``hmax_watershed``.
    frag_size_thres : parts with fewer foreground pixels are saved to ``frag_path``
        instead of ``crop_path``.
    crop_path, frag_path : output directories for the PNG crops.
    env_size : margin (pixels) around each object for the environment crop.

    Returns
    -------
    img_filename_list : 1-D str array of crop names (without extension).
    img_thread, label_thread, env_thread : 1-D object arrays of crops (crops differ in size).
    img_stat_list : int array of shape (n, 3) with rows [position, frame, object label].
    """
    import re

    def _as_object_array(items):
        # Crops differ in shape: build a 1-D object array explicitly, because
        # np.array() on ragged lists raises in NumPy >= 1.24.
        arr = np.empty(len(items), dtype=object)
        for k, item in enumerate(items):
            arr[k] = item
        return arr

    img_filename_list = []
    img_thread = []
    label_thread = []
    env_thread = []
    img_stat_list = []

    def _collect(img, seg, img_num, obj_num, tag, pos, fr, allow_frag):
        # Save one object's crop as PNG and append its crop / label / env / stats.
        crop = generate_single_cell_img(img, seg, img_num, obj_num)
        png = toimage(crop, high=np.max(crop), low=np.min(crop), mode='I')
        is_frag = allow_frag and np.sum(crop > 0) < frag_size_thres
        out_dir = frag_path if is_frag else crop_path
        img_filename = 'XY' + pos + '_' + tag + '_fr' + str(fr).zfill(4) + \
                       '_obj' + str(obj_num).zfill(3)
        png.save(out_dir + '/' + img_filename + '.png')
        img_filename_list.append(img_filename)
        img_thread.append(crop)
        label_thread.append(generate_single_cell_label(img, seg, img_num, obj_num))
        env_thread.append(generate_single_cell_env(img, seg, img_num, obj_num, env_size))
        img_stat_list.append([int(pos), int(fr), int(obj_num)])

    for i in range(len(img_list)):
        img_num = i + 1
        img = img_list[i]
        reg = reg_list[i]

        # splitext instead of [:-4] so extensions like '.tiff' are handled.
        img_name = os.path.splitext(img_name_list[i])[0]
        xy_part = img_name.split('_XY')[1]
        pos_match = re.match(r'\d+', xy_part)  # multi-digit positions (XY10, ...)
        pos = pos_match.group(0) if pos_match else xy_part[:1]
        fr = img_name.split('_T')[1][:4]

        low_h_seg, _ = hmax_watershed(reg, h_thresh=low_h, small_obj_thres=small_obj_thres, mask_thres=mask_thres)
        high_h_seg, _ = hmax_watershed(reg, h_thresh=high_h, small_obj_thres=small_obj_thres, mask_thres=mask_thres)
        seg_overlap = compute_overlap_matrix(low_h_seg, high_h_seg)
        fuse_cells, fuse_group = cal_cell_fusion(seg_overlap)

        for fused, parts in zip(fuse_cells, fuse_group):
            _collect(img, high_h_seg, img_num, fused[0], 'fc', pos, fr, allow_frag=False)
            for part in parts:
                _collect(img, low_h_seg, img_num, part, 'fp', pos, fr, allow_frag=True)

    img_filename_list = np.array(img_filename_list, dtype=str)
    img_thread = _as_object_array(img_thread)
    label_thread = _as_object_array(label_thread)
    env_thread = _as_object_array(env_thread)
    # reshape keeps a 2-D (0, 3) array when nothing was found, so [:, 1] still works.
    img_stat_list = np.array(img_stat_list, dtype=int).reshape(-1, 3)

    print("The total image number is: ", len(img_thread))
    return img_filename_list, img_thread, label_thread, env_thread, img_stat_list
    
def frame_img_label_show(env_list,img_thread,label_thread,img_stat_list,fr_list,fr_idx,img_filename_list,
                         figsize=(15, 5)):
    """Plot every crop of frame ``fr_list[fr_idx]`` as one figure of three panels.

    The panels are, left to right: environment crop (titled with the crop name),
    cell crop, and label crop, all in grayscale. Crops are selected where column 1
    (frame) of ``img_stat_list`` equals ``fr_list[fr_idx]``. ``env_list``,
    ``img_thread``, ``label_thread`` and ``img_filename_list`` are indexed with the
    same boolean mask, so they are expected to be numpy arrays aligned with
    ``img_stat_list``.

    figsize - size of each figure. The default (15, 5) suits the 1x3 layout; the old
              hard-coded (5, 15) produced tall, narrow panels.
    """
    stats = np.asarray(img_stat_list)
    if stats.ndim != 2 or stats.shape[0] == 0:
        return  # nothing was collected; nothing to show

    list_idx = stats[:, 1] == fr_list[fr_idx]
    fr_img_filename_list = np.asarray(img_filename_list)[list_idx]
    fr_env_list = env_list[list_idx]
    fr_img_thread = img_thread[list_idx]
    fr_label_thread = label_thread[list_idx]

    for name, env, crop, lab in zip(fr_img_filename_list, fr_env_list, fr_img_thread, fr_label_thread):
        fig, axes = plt.subplots(1, 3, figsize=figsize)
        for ax, panel in zip(axes, (env, crop, lab)):
            ax.imshow(panel, cmap="gray")
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
        axes[0].set_title(name)
    
def make_img_list(img_path):
    """Read every file listed by ``hj_util.folder_file_num(img_path)`` with ``imread``, keeping that order."""
    return [imread(file_path) for file_path in hj_util.folder_file_num(img_path)]

def make_name_list(img_path):
    """Base names (no directory) of the files listed by ``hj_util.folder_file_num(img_path)``, in order."""
    return list(map(basename, hj_util.folder_file_num(img_path)))

def make_fr_list(img_path):
    """Frame numbers of the files in ``img_path`` as a numpy array.

    The frame number is parsed as the 4 characters following ``_T`` in each file's base name.
    """
    frames = [int(basename(file_path).split('_T')[1][:4])
              for file_path in hj_util.folder_file_num(img_path)]
    return np.array(frames)

def single_cell_save(img, folder, pos, fr, i):
    """Save ``img`` as uint32 TIFF named ``XY<pos>_fr<fr>_cr<i+1, 3-digit>.tif`` in ``folder``."""
    target_dir = hj_util.folder_verify(folder)
    file_name = f"XY{pos!s}_fr{fr!s}_cr{str(i + 1).zfill(3)}.tif"
    imsave(target_dir + file_name, img.astype(np.uint32))

In [10]:
dat_dir = 'C:/Users/14432/OneDrive/Research/Projects/A549_144hr/data/'
img_path = dat_dir + 'subsamps/XY3/seg_train-III/img/'
reg_path = dat_dir + 'subsamps/XY3/seg_train-III/edt/'
crop_path = dat_dir + 'train/icnn_seg/unclassified/'
frag_path = crop_path + 'frag/'
hj_util.create_folder(frag_path)

# Remove crops from previous runs: regular files only, so sub-folders (e.g. frag/) survive.
for out_dir in (crop_path, frag_path):
    for f in os.listdir(out_dir):
        f_path = os.path.join(out_dir, f)
        if os.path.isfile(f_path):
            os.remove(f_path)

start_idx = 0
end_idx = 5
low_h = 0.05
high_h = 0.8
small_obj_thres = 1500
mask_thres = 0.025
frag_size_thres = 2500

img_list = make_img_list(img_path)[start_idx:end_idx]
img_name_list = make_name_list(img_path)[start_idx:end_idx]
reg_list = make_img_list(reg_path)[start_idx:end_idx]
fr_list = make_fr_list(img_path)[start_idx:end_idx]
# Number of frames actually loaded; the folder may hold fewer than end_idx - start_idx files.
n_fr = len(fr_list)

img_filename_list, img_thread, label_thread, env_thread, img_stat_list = generating_img_label_list(
    img_list, img_name_list, reg_list, \
    small_obj_thres, frag_size_thres, \
    low_h, high_h, mask_thres, \
    crop_path, frag_path)



C:/Users/14432/OneDrive/Research/Projects/A549_144hr/data/train/icnn_seg/unclassified/frag/ folder is freshly created. 

C:/Users/14432/OneDrive/Research/Projects/A549_144hr/data/subsamps/XY3/seg_train-III/img/ has 5 files
C:/Users/14432/OneDrive/Research/Projects/A549_144hr/data/subsamps/XY3/seg_train-III/img/ has 5 files
C:/Users/14432/OneDrive/Research/Projects/A549_144hr/data/subsamps/XY3/seg_train-III/edt/ has 5 files
C:/Users/14432/OneDrive/Research/Projects/A549_144hr/data/subsamps/XY3/seg_train-III/img/ has 5 files
The total image number is:  86


In [12]:
b1 = widgets.Button(
    description='next frame',
    disabled=False,
    button_style='info',
    tooltip='Click me',
    icon='check')

def init_globals():
    """Reset the viewer to the first frame."""
    global fr_idx, crop_num
    fr_idx = 0
    crop_num = 0  # not read in this cell; kept for compatibility

def _frame_position(frame_idx):
    """XY position recorded for the crops of frame ``frame_idx``, or '?' if the frame has none."""
    # Look the position up among this frame's rows; img_stat_list is indexed by crop, not by frame.
    stats = np.asarray(img_stat_list)
    if stats.ndim == 2 and stats.shape[0] > 0:
        rows = stats[stats[:, 1] == fr_list[frame_idx]]
        if rows.shape[0] > 0:
            return str(rows[0][0])
    return '?'

def _show_frame(frame_idx):
    """Render the header line and all crops of frame ``frame_idx`` into the ``out`` widget."""
    # Never index past the frames that were actually loaded.
    n_frames = min(n_fr, len(fr_list))
    with out:
        clear_output()
        if frame_idx < n_frames:
            print('XY ' + _frame_position(frame_idx) + ' fr ' + str(fr_list[frame_idx]) + \
                  ' (' + str(frame_idx + 1) + '/' + str(n_frames) + ')')
            frame_img_label_show(env_thread, img_thread, label_thread, img_stat_list,
                                 fr_list, frame_idx, img_filename_list)
            show_inline_matplotlib_plots()
        else:
            print("This is the end of pic queue.")

def on_b1_clicked(b):
    """Button callback: advance to the next frame and redraw it."""
    global fr_idx
    fr_idx = fr_idx + 1
    _show_frame(fr_idx)

display(b1)
out = widgets.Output()
display(out)

init_globals()
_show_frame(fr_idx)

b1.on_click(on_b1_clicked)

Button(button_style='info', description='next frame', icon='check', style=ButtonStyle(), tooltip='Click me')

Output()