In [1]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import random

import cv2
from skimage import io
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label

from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from tqdm import tqdm
import skimage.morphology as morph
from scipy import ndimage as ndi
from scipy.stats import itemfreq
from sklearn.externals import joblib
import pylab
from util import plots

# Archive read by load_data(): holds 'images', 'masks' and 'contours' arrays.
data_path = 'data/prediction_data.npz'
# Archive written by post_processing(): the three inputs plus 'result'.
data_dst_path = 'data/prediction_data_final_with_post.npz'
# NOTE(review): not referenced in this cell; kept in case other cells use it.
batch_size = 1

def load_data(path=None):
    """Load the prediction arrays from an ``.npz`` archive.

    Args:
        path: archive to read; defaults to the module-level ``data_path``.

    Returns:
        Tuple ``(imgs, masks, contours)``: the arrays stored under the
        ``'images'``, ``'masks'`` and ``'contours'`` keys.
    """
    if path is None:
        path = data_path
    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # indexing it reads each array fully into memory, so it is safe to close
    # the archive (via the context manager) before returning the arrays.
    with np.load(path) as data:
        return data['images'], data['masks'], data['contours']

def morphological_method_simple(image, mask, contour):
    """Cut the Otsu-binarised mask along a thinned, Otsu-binarised contour.

    Args:
        image: input image; only passed through to ``plots`` for display.
        mask: predicted mask map with a singleton axis 2 (squeezed away).
        contour: predicted contour map, same layout as ``mask``.

    Returns:
        Integer array with the singleton axis 2 restored: 1 where the
        binarised mask is set and the eroded contour is not, 0 elsewhere.
    """
    mask_2d = np.squeeze(mask, axis=2)
    contour_2d = np.squeeze(contour, axis=2)

    mask_on, contour_on = binary_mask(mask_2d, contour_2d)
    thin_contour = morph.binary_erosion(contour_on)

    # Foreground pixels survive unless they lie on the thinned contour
    # (De Morgan form of "not (background or contour)").
    keep = np.logical_and(mask_on, np.logical_not(thin_contour))
    result = np.expand_dims(np.where(keep, 1, 0), axis=2)

    # `plots` comes from the project's util module; presumably a display helper.
    plots([image, mask_on, contour_2d, result])
    return result

def binary_mask(mask, contour):
    """Binarise a mask map and a contour map with independent Otsu thresholds.

    Args:
        mask: predicted mask map (any shape ``threshold_otsu`` accepts).
        contour: predicted contour map.

    Returns:
        Tuple ``(mask_binary, contour_binary)`` of boolean arrays, each True
        where the input is strictly above its own Otsu threshold.
    """
    mask_level = threshold_otsu(mask)
    contour_level = threshold_otsu(contour)
    return mask > mask_level, contour > contour_level


def post_processing(images, masks, contours, mode=1, dst_path=None):
    """Turn predicted mask/contour stacks into final masks and save them.

    Args:
        images: stack of input images, indexed per image as ``images[i]``.
        masks: stack of predicted mask maps; its shape is also the shape of
            the saved ``result`` array.
        contours: stack of predicted contour maps, indexed like ``masks``.
        mode: 0 only Otsu-binarises each mask (contour ignored) and shows it
            with ``plots``; any other value (default 1) runs
            ``morphological_method`` to split regions using the contour.
        dst_path: output ``.npz`` path; defaults to module-level ``data_dst_path``.

    Side effects:
        Writes a compressed archive with keys ``images``, ``masks``,
        ``contours`` and ``result`` (float array, same shape as ``masks``).
        In mode 0, images whose mean intensity exceeds 100 are inverted
        (``255 - img``) before display and saving; this is done on a copy so
        the caller's ``images`` array is left unchanged.
    """
    if dst_path is None:
        dst_path = data_dst_path
    num_img = masks.shape[0]
    masks_final = np.zeros(masks.shape)
    if mode == 0:
        # Invert on a private copy; previously the caller's array was overwritten.
        images = np.array(images, copy=True)
        for i in range(num_img):
            mean_intensity = np.mean(images[i])
            print(mean_intensity)
            if mean_intensity > 100:
                images[i] = 255 - images[i]
            mask_binary, _ = binary_mask(masks[i], contours[i])
            masks_final[i] = mask_binary
            plots([images[i], mask_binary])
    else:
        for i in range(num_img):
            masks_final[i] = morphological_method(images[i], masks[i], contours[i])
    np.savez_compressed(dst_path, images=images, masks=masks,
                        contours=contours, result=masks_final)

def morphological_method(image, mask_pred, contour_pred, contour_thresh=0.5, min_area=5):
    """Split a predicted foreground mask into labelled regions via the contour map.

    Steps:
      1. Binarise the mask with Otsu and the contour with ``contour_thresh``,
         then thin the contour with a binary erosion.
      2. Fill holes in the mask, remove the thinned contour from it and fill
         holes again; the remaining connected components are region cores.
      3. Grow each core by a ``disk(2)`` dilation, keeping only the grown
         pixels that lie on the contour, and paint it into the label image.

    Args:
        image: input image; not used by the computation.
        mask_pred: predicted mask map with a singleton axis 2 (squeezed away).
        contour_pred: predicted contour map, same layout as ``mask_pred``.
        contour_thresh: fixed threshold applied to ``contour_pred``.
        min_area: grown regions with fewer pixels than this are discarded.

    Returns:
        int32 label array with the singleton axis 2 restored. 0 is background;
        each kept region carries its ``ndi.label`` index, so indices can have
        gaps where small regions were dropped, and where grown regions overlap
        the higher index wins.
    """
    mask_pred = np.squeeze(mask_pred, axis=2)
    contour_pred = np.squeeze(contour_pred, axis=2)

    # Mask: data-driven Otsu threshold; contour: fixed threshold.
    mask_binary = mask_pred > threshold_otsu(mask_pred)
    contour_binary = contour_pred > contour_thresh
    # Thin the contour band before using it to separate touching regions.
    contour_binary = morph.binary_erosion(contour_binary, morph.disk(1.5))

    # Carve the thinned contour out of the hole-filled mask so touching objects
    # become separate connected components.
    mask_filled = ndi.binary_fill_holes(mask_binary)
    cores = np.logical_and(mask_filled, np.logical_not(contour_binary))
    cores = ndi.binary_fill_holes(cores)
    mask_label, num_cell_mask = ndi.label(cores)

    label_overlap = np.zeros(mask_label.shape)
    grow_el = morph.disk(2)
    # ndi.label numbers components 1..num_cell_mask inclusive; the previous
    # range(1, num_cell_mask) silently dropped the last component.
    for i in range(1, num_cell_mask + 1):
        region_core = mask_label == i
        region_grown = morph.binary_dilation(region_core, grow_el)
        # Re-attach only the contour pixels next to this core.
        region = np.logical_or(region_core, np.logical_and(region_grown, contour_binary))
        if np.sum(region) < min_area:
            continue
        # Later (higher-index) regions overwrite earlier ones where they overlap.
        label_overlap = np.where(region, i, label_overlap)
    print(label_overlap.shape)
    return np.expand_dims(label_overlap, axis=2).astype(np.int32)
    
    
# Script entry: load predictions, post-process with the contour-based method
# and save the result. Guarded so importing this module has no side effects;
# in a notebook __name__ is "__main__", so the cell still runs as before.
if __name__ == "__main__":
    imgs, masks, contours = load_data()
    post_processing(imgs, masks, contours, mode=1)
    print("generate final prediction")



(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
(384, 384)
generate final prediction
