In [129]:
import sys
# insert at 1, 0 is the script path (or '' in REPL)
# These top lines are critical for import from another folder
sys.path.insert(1, '/net/capricorn/home/xing/huijing/Segmentation/scripts/vimentin_DIC_segmentation_pipeline/hj_modify_pipe')

from IPython.display import display, clear_output
from ipywidgets.widgets.interaction import show_inline_matplotlib_plots
import ipywidgets as widgets
import random
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt

from skimage.io import imread, imsave
from skimage.segmentation import watershed, clear_border
from skimage.feature import peak_local_max
from skimage.morphology import remove_small_objects, local_maxima, h_maxima, disk, dilation
from skimage.measure import regionprops, label
from skimage.color import label2rgb
from skimage import filters
from skimage.exposure import equalize_adapthist
from skimage import img_as_float32, img_as_ubyte
from skimage.draw import rectangle_perimeter

import shutil
import hj_util

In [130]:
def hmax_watershed(img, h_thresh, small_obj_thresh = 10, mask_thresh=0.15):
    """
    Seeded watershed segmentation of a predicted EDT image.

    img - a 2d array for the predicted edt image.
    h_thresh - float between 0 and 1; the quantile of `img` used as the h value
        for `h_maxima` (seed detection).
    small_obj_thresh - int, passed to `remove_small_objects` to drop tiny labels.
    mask_thresh - float between 0 and 1; the quantile of `img` at or below which
        pixels are excluded from the watershed.

    Returns (labels, seed_labels). If `h_maxima` raises ValueError, both are
    float zero arrays with the shape of `img`.
    """
    # Both thresholds are given as quantiles and converted to intensity values here.
    h_value = np.quantile(img, h_thresh)
    floor_value = np.quantile(img.ravel(), mask_thresh)

    try:
        seeds = h_maxima(img, h_value)
    except ValueError:
        return np.zeros(img.shape), np.zeros(img.shape)

    seed_labels = label(seeds, connectivity=1)
    # Flood the inverted EDT so the seeds (EDT maxima) act as basin minima.
    segmented = watershed(-img, seed_labels, mask=img > floor_value)
    return remove_small_objects(segmented, small_obj_thresh), seed_labels

def folder_watershed_labels(folder, h_thresh, small_obj_thresh = 10, mask_thresh = 0.15):
    """
    Run `hmax_watershed` on every EDT image in a folder.

    folder - folder of edt images; normalized with hj_util.folder_verify.
    h_thresh - float between 0 and 1, quantile used as the h-maxima threshold.
    small_obj_thresh - int, minimum object size kept.
    mask_thresh - float between 0 and 1, quantile below which pixels are masked out.

    Returns the list of label images, in the order hj_util.folder_file_num lists
    the files. The seed labels from `hmax_watershed` are discarded.
    """
    folder = hj_util.folder_verify(folder)
    label_list = [
        hmax_watershed(imread(path), h_thresh, small_obj_thresh, mask_thresh)[0]
        for path in hj_util.folder_file_num(folder)
    ]
    print("---------- Watershed processing is complete. -----------")
    return label_list

In [131]:
# Seed both RNGs. The cell subsampling in generating_img_label_list uses numpy's
# global RNG (np.random.choice), which random.seed alone does not affect.
random.seed(123)
np.random.seed(123)


def list_select_by_index(the_list, index_list):
    """Return a new list with the items of `the_list` at the positions in `index_list`, in that order."""
    picked = []
    for idx in index_list:
        picked.append(the_list[idx])
    return picked

def create_class_folders(class_list, aim_folder):
    """
    Create one subfolder per class name inside `aim_folder`.

    Returns the subfolder paths (the folder_verify-normalized `aim_folder`
    followed by the class name), in `class_list` order.
    """
    base = hj_util.folder_verify(aim_folder)
    created = []
    for class_name in class_list:
        path = base + class_name
        hj_util.create_folder(path)
        created.append(path)
    return created


def _cell_crops(img, lab, env_size):
    """Crop every labeled region of one image; returns parallel lists (img crops, label crops, env crops, areas)."""
    r, c = img.shape
    img_crops, label_crops, env_crops, areas = [], [], [], []
    for prop in regionprops(lab):
        min_r, min_c, max_r, max_c = prop.bbox
        img_crops.append(img[min_r:max_r, min_c:max_c])
        label_crops.append(lab[min_r:max_r, min_c:max_c])
        areas.append(prop.area)

        # Draw a zero-valued rectangle around the bbox on a copy of the image,
        # then cut a window padded by env_size and clipped to the image borders.
        rr, cc = rectangle_perimeter((min_r, min_c), end=(max_r, max_c), shape=img.shape)
        img_local = img.copy()
        img_local[rr, cc] = 0.
        env_crops.append(img_local[max(min_r - env_size, 0):min(max_r + env_size, r),
                                   max(min_c - env_size, 0):min(max_c + env_size, c)])
    return img_crops, label_crops, env_crops, areas


def _stratified_indices(area_list, num, area_quantile, fractions):
    """Pick indices into area_list stratified into small/mid/large groups by area quantiles."""
    s_thre, l_thre = np.quantile(area_list, area_quantile)
    groups = ([], [], [])
    for idx, area in enumerate(area_list):
        if area < s_thre:
            groups[0].append(idx)
        elif area < l_thre:
            groups[1].append(idx)
        else:
            groups[2].append(idx)

    picked = []
    for members, frac in zip(groups, fractions):
        need = int(frac * num)
        # Subsample only groups larger than their quota; smaller groups are kept whole.
        if len(members) > need:
            members = np.random.choice(members, size=need, replace=False)
        # Force an int dtype: concatenating an empty Python list would otherwise
        # produce float indices that cannot index a list.
        picked.append(np.asarray(members, dtype=int))
    return np.concatenate(picked)


def generating_img_label_list(img_list, label_list, img_selection = None, env_size = 50,
                              num = 500, area_quantile = (0.15, 0.85), fractions = (0.2, 0.6, 0.2)):
    """
    Cut every labeled cell out of a set of images, optionally subsampling by cell size.

    img_list - list of 2d np.array, raw images.
    label_list - list of 2d np.array, label images matching img_list one-to-one
        (cast to int32).
    img_selection - optional list of int, indices of the images to use. None or
        an empty sequence means all images.
    env_size - int, padding in pixels around each cell's bounding box for the
        environment crop (clipped to the image borders).
    num - int, target total number of cells after subsampling. A falsy value
        (None / 0) disables subsampling and keeps every cell.
    area_quantile - pair of quantiles (between 0 and 1) of the cell areas that
        split cells into small / mid / large.
    fractions - fractions of `num` drawn from the small, mid and large groups.

    Returns (img_thread, label_thread, env_list, inx): image crops, label crops
    and environment crops of the kept cells, plus their indices into the full
    list of crops before subsampling.

    Images with a side shorter than 4 px, or an aspect ratio (with +1
    smoothing) above 15, are skipped.

    Raises ValueError if img_list and label_list differ in length.
    """
    if img_selection is not None and len(img_selection) > 0:
        img_list = list_select_by_index(img_list, img_selection)
        label_list = list_select_by_index(label_list, img_selection)

    if len(img_list) != len(label_list):
        raise ValueError("img_list and label_list must have the same length")

    img_thread, label_thread, env_list, area_list = [], [], [], []
    for img, lab in zip(img_list, label_list):
        # Convert per image instead of stacking, so images of different shapes work.
        img = np.asarray(img)
        lab = np.asarray(lab, dtype=np.int32)

        r, c = img.shape
        if (r + 1.) / (c + 1.) > 15. or (c + 1.) / (r + 1.) > 15. or r < 4 or c < 4:
            continue

        crops, lab_crops, envs, areas = _cell_crops(img, lab, env_size)
        img_thread.extend(crops)
        label_thread.extend(lab_crops)
        env_list.extend(envs)
        area_list.extend(areas)

    inx = np.arange(len(img_thread))

    # area_list check: quantiles of an empty list are undefined.
    if num and area_quantile and fractions and area_list:
        inx = _stratified_indices(area_list, num, area_quantile, fractions)
        img_thread = list_select_by_index(img_thread, inx)
        label_thread = list_select_by_index(label_thread, inx)
        env_list = list_select_by_index(env_list, inx)

    print("The total image number is: ", len(img_thread))
    return img_thread, label_thread, env_list, inx

def img_label_show(img_thread, label_thread, env_list, i):
    """
    Plot item `i` as three grayscale panels in a new 6x4-inch figure.

    The panels are, left to right: the image crop, the label crop and the
    environment crop.
    """
    fig = plt.figure(figsize = (6, 4))
    for position, source in enumerate((img_thread, label_thread, env_list), start=1):
        fig.add_subplot(1, 3, position)
        plt.imshow(source[i], cmap = "gray")
    
def make_img_list(img_folder):
    """Read every file listed by hj_util.folder_file_num(img_folder) into a list of arrays, in listing order."""
    return [imread(path) for path in hj_util.folder_file_num(img_folder)]

def single_cell_save(img, i, name, folder):
    """
    Save `img` as <folder><name><i zero-padded to 4 digits>.tif after casting to uint8.
    """
    target_dir = hj_util.folder_verify(folder)
    stem = name + str(i).zfill(4)
    # NOTE(review): astype(np.uint8) wraps or truncates values outside 0..255
    # (e.g. uint16 or float images). Confirm the input dtype/range.
    imsave(target_dir + stem + ".tif", img.astype(np.uint8))

In [138]:
# Inputs: folder of predicted EDT images and folder of the matching raw images.
edt_folder = '/net/capricorn/home/xing/huijing/Segmentation/data/Incucyte_data/06-21-21-B1_02_crop_part_30_output/edt'
img_folder = '/net/capricorn/home/xing/huijing/Segmentation/data/Incucyte_data/06-21-21-B1_02_crop_part_30'

# Segment each EDT image: h-maxima threshold at the 0.95 quantile, mask at the 0.15 quantile.
label_list = folder_watershed_labels(edt_folder, h_thresh=0.95, small_obj_thresh=10, mask_thresh=0.15)
# NOTE(review): assumes both folders list their files in matching order. TODO confirm.
img_list = make_img_list(img_folder)

# num=None disables size-stratified subsampling, so every segmented cell is kept.
img_thread, label_thread, env_list, inx = generating_img_label_list(img_list, label_list, num=None)

/net/capricorn/home/xing/huijing/Segmentation/data/Incucyte_data/06-21-21-B1_02_crop_part_30_output/edt/ has 29 files
---------- Watershed processing is complete. -----------
/net/capricorn/home/xing/huijing/Segmentation/data/Incucyte_data/06-21-21-B1_02_crop_part_30/ has 29 files
The total image number is:  530


In [140]:
# Labeling buttons, all styled alike. b4 and b5 share the label 'Ignore This Pic'.
_button_descriptions = ('other', 'apoptosis', 'mitosis', 'Ignore This Pic', 'Ignore This Pic')
b1, b2, b3, b4, b5 = [
    widgets.Button(
        description=text,
        disabled=False,
        button_style='info',
        tooltip='Click me',
        icon='check',
    )
    for text in _button_descriptions
]

# One output subfolder per class under aim_folder; `name` prefixes the saved crop filenames.
aim_folder = "./testing"
class_list = ['other', 'apoptosis', 'mitosis']
folder_list = create_class_folders(class_list, aim_folder)
name = "dld1"

    
def init_global_i():
    """Reset the module-level queue cursor `i` to the first item."""
    globals()["i"] = 0
    

def _show_current():
    """Redraw `out` for the item at the global cursor `i`, or print the end-of-queue message."""
    with out:
        clear_output()
        print(i)
        if i < len(img_thread):
            img_label_show(img_thread, label_thread, env_list, i)
            show_inline_matplotlib_plots()
        else:
            print("This is the end of pic queue.")


def _advance(folder=None):
    """
    Optionally save the current crop into `folder`, advance the cursor, and redraw.

    Once the queue is exhausted this only redraws. Extra clicks therefore no
    longer raise IndexError on img_thread[i] or push `i` further past the end.
    """
    global i
    if i < len(img_thread):
        if folder is not None:
            single_cell_save(img_thread[i], i, name, folder)
        i = i + 1
    _show_current()


def on_b1_clicked(b): # other, normal cell button
    _advance(folder_list[0])


def on_b2_clicked(b): # apoptosis button
    _advance(folder_list[1])


def on_b3_clicked(b): # mitosis button
    _advance(folder_list[2])


def on_b4_clicked(b): # skip without saving
    _advance()


def on_b5_clicked(b): # The ignore: skip without saving
    _advance()

./testing/other/ folder is freshly created. 

./testing/apoptosis/ folder is freshly created. 

./testing/mitosis/ folder is freshly created. 



In [141]:
# Show the labeling buttons and the output area the click handlers redraw into.
display(b1, b2, b3, b4, b5)

out = widgets.Output()
display(out)

# Start the queue from the first crop.
init_global_i()

# Initial view. Guard against an empty queue, where img_thread[0] would raise IndexError.
with out:
    clear_output()
    if i < len(img_thread):
        img_label_show(img_thread, label_thread, env_list, i)
        show_inline_matplotlib_plots()
    else:
        print("This is the end of pic queue.")

b1.on_click(on_b1_clicked)
b2.on_click(on_b2_clicked)
b3.on_click(on_b3_clicked)
b4.on_click(on_b4_clicked)
b5.on_click(on_b5_clicked)

Button(button_style='info', description='other', icon='check', style=ButtonStyle(), tooltip='Click me')

Button(button_style='info', description='apoptosis', icon='check', style=ButtonStyle(), tooltip='Click me')

Button(button_style='info', description='mitosis', icon='check', style=ButtonStyle(), tooltip='Click me')

Button(button_style='info', description='Ignore This Pic', icon='check', style=ButtonStyle(), tooltip='Click m…

Button(button_style='info', description='Ignore This Pic', icon='check', style=ButtonStyle(), tooltip='Click m…

Output()