In [56]:
import os
import xml.etree.ElementTree as ET
import cv2
from skimage import io as skimageio
import numpy as np
from skimage.draw import polygon as skimagedrawpolygon
from skimage.transform import resize as skimageresize
import skimage
from tqdm import tqdm
from imgaug import augmenters as iaa
import random
import keras
from keras.preprocessing import image as image_utils
from keras.models import load_model
from skimage.draw import polygon_perimeter as skimagedrawpolygonperimeter
from keras.layers import Input, Conv2D, Lambda, MaxPooling2D, Conv2DTranspose, concatenate, Dropout, BatchNormalization

from scipy.io import savemat,loadmat

In [2]:
# saving all file names in a single array - easy to call images and annotations

temp_files = os.listdir('/Users/ravitejachunduri/Documents/monuseg/training_data/')
all_files=[]
for file in temp_files:
    if (str.lower(file.strip().split('.')[-1]) in ['tif','tiff']):
        all_files.append(''.join(file.strip().split('.')[0:-1]))

In [3]:
# Images directory path:
im_path = '/Users/ravitejachunduri/Documents/monuseg/training_data'
# XML files directory path:
xml_path = '/Users/ravitejachunduri/Documents/monuseg/training_data'
# target images directory path:
target_path = '/Users/ravitejachunduri/Documents/monuseg/target'
if not os.path.exists(target_path):
    os.makedirs(target_path)
# target images directory path:
target_bw_path = '/Users/ravitejachunduri/Documents/monuseg/target_bw'
if not os.path.exists(target_bw_path):
    os.makedirs(target_bw_path)
# Random color target images directory path:
target_clr_path = '/Users/ravitejachunduri/Documents/monuseg/target_clr'
if not os.path.exists(target_clr_path):
    os.makedirs(target_clr_path)
# 70 percent of orignal size masks directory path:    
target_masks_70p_path = '/Users/ravitejachunduri/Documents/monuseg/target_masks_70p/'    
if not os.path.exists(target_masks_70p_path):
    os.makedirs(target_masks_70p_path)

In [4]:
def loading_images(files_list,im_path,xml_path,save_format1,save_format2,target_images,target_mask_bw,target_mask_clr):
    for i in tqdm(range(len(files_list))):
        xml_file = os.path.join(xml_path,(files_list[i])+'.xml')
        im_file = os.path.join(im_path,(files_list[i])+'.tif')
        
        image = cv2.imread(im_file)
        im_height, im_width = image.shape[0], image.shape[1]
        
        image_mask_true = np.zeros((im_height,im_width), dtype=np.uint8)
        image_mask = np.zeros((im_height,im_width), dtype=np.bool)
        image_mask_clr = np.zeros((im_height,im_width,3), dtype=np.float)
        tree = ET.parse(xml_file)
        for reg in tree.getroot()[0][1][1:]:
            x = []
            y = []
            for ver_k in reg[1][:]:
                x.append(float(ver_k.attrib['X']))
                y.append(float(ver_k.attrib['Y']))
            poly_row, poly_col = skimagedrawpolygon(y, x, (im_height,im_width))
            mask = np.zeros((im_height,im_width), dtype=np.bool)
            mask[poly_row, poly_col] = 1
            image_mask_true = image_mask_true + mask
            image_mask = np.logical_or(image_mask, mask)
            image_mask_clr[:,:,0] = image_mask_clr[:,:,0] + np.random.ranf()*mask
            image_mask_clr[:,:,1] = image_mask_clr[:,:,1] + np.random.ranf()*mask
            image_mask_clr[:,:,2] = image_mask_clr[:,:,2] + np.random.ranf()*mask
        image = image.astype(np.uint8)
        skimageio.imsave(os.path.join(target_images,(files_list[i] + save_format1)), image)    
        image_mask = image_mask*1
        image_mask = image_mask.astype(np.uint8)
        skimageio.imsave(os.path.join(target_mask_bw,(files_list[i] + save_format2)), image_mask, check_contrast=False)    
        image_mask = image_mask*255
        image_mask_clr[image_mask_clr > 1.0] = 1.0    
        image_mask_clr = image_mask_clr.astype(np.uint8)
        skimageio.imsave(os.path.join(target_mask_clr,(files_list[i] + save_format2)), image_mask_clr, check_contrast=False)    
    return
                         
loading_images(all_files,im_path,xml_path,'.png','.png',target_path,target_bw_path,target_clr_path)                        
                         
                         

100%|██████████| 30/30 [04:57<00:00,  9.91s/it]


In [5]:

for i in tqdm(range(len(all_files))):
    xml_file = os.path.join(xml_path,(all_files[i])+'.xml')
    image_file = os.path.join(im_path,(all_files[i])+'.tif')
    xml_tree = ET.parse(xml_file)
    image = cv2.imread(image_file)
    im_height, im_width = image.shape[0], image.shape[1]
    image_mask_fill = np.zeros((im_height,im_width), dtype=np.uint8)
    image_mask_contour = np.zeros((im_height,im_width), dtype=np.uint8)
    for reg_j in xml_tree.getroot()[0][1][1:]:
        x = []
        y = []
        for ver_k in reg_j[1][:]:
            x.append(float(ver_k.attrib['X']))
            y.append(float(ver_k.attrib['Y']))
        if len(x) < 3:
            continue
        poly_row, poly_col = skimagedrawpolygonperimeter(y, x, (im_height,im_width),clip=True)
        a = np.concatenate((np.expand_dims(poly_col,axis=1), np.expand_dims(poly_row,axis=1)), axis=1)
        b = np.expand_dims(a, axis=0)
        c = np.expand_dims(b, axis=2)
        center = [int(np.mean(y)), int(np.mean(x))]
        min_dist = 1000.0
        for k in range(len(poly_row)):
            if np.sqrt((poly_row[k]-center[0])**2 + (poly_col[k]-center[1])**2) < min_dist:
                min_dist = np.sqrt((poly_row[k]-center[0])**2 + (poly_col[k]-center[1])**2)
        d = np.floor(min_dist*0.7)
        image_mask_center = np.zeros((im_height,im_width), dtype=np.uint8)
        mask = cv2.drawContours(image_mask_center, c, -1, (255, 255, 255), int(d)) 
        poly_row2, poly_col2 = skimagedrawpolygon(y, x, (im_height,im_width))
        mask2 = np.zeros((im_height,im_width), dtype=np.bool)
        mask2[poly_row2, poly_col2] = True
        mask2 = mask2 * 255
        image_mask_contour = image_mask_contour + mask
        image_mask_fill = image_mask_fill + mask2
    image_mask_contour[image_mask_contour > 0] = 255
    image_mask_fill[image_mask_fill > 0] = 255
    f_mask = image_mask_fill - image_mask_contour
    f_mask[f_mask < 0] = 0
    f_mask[f_mask < 0] = 0
    f_mask = f_mask.astype(np.uint8)
    skimageio.imsave(os.path.join(target_masks_70p_path,(all_files[i] + '.png')), f_mask, check_contrast=False)  

100%|██████████| 30/30 [02:24<00:00,  4.82s/it]


In [6]:
im_height = 1000
im_width = 1000
channels = 3
patch_size_image = 125
patch_size_model = 128
height = patch_size_image
width = patch_size_image
stride = 40

In [7]:
original_im_path = '/Users/ravitejachunduri/Documents/monuseg/target'
original_mask_path = '/Users/ravitejachunduri/Documents/monuseg/target_bw'
p70_masks_path = '/Users/ravitejachunduri/Documents/monuseg/target_masks_70p/'

# saving all the images from disk into an array
list_files =[]
for file in os.listdir(original_im_path):
    if (str.lower(file.strip().split('.')[-1]) in ['png']):
        list_files.append(''.join(file.strip().split('.')[0:-1]))
        
original_images = np.zeros((len(list_files),im_height,im_width,channels), dtype=np.uint8)
for im_i in tqdm(range(len(list_files))):
    image_file = os.path.join(original_im_path, (list_files[im_i] + '.png'))
    original_images[im_i, :, :, :3] = skimageio.imread(image_file)        
    
# saving all the masks from disk into an array    

original_masks = np.zeros((len(list_files),im_height,im_width), dtype=np.uint8)
for im_i in tqdm(range(len(list_files))):
    mask_file = os.path.join(original_mask_path, (list_files[im_i] + '.png'))
    original_masks[im_i] = skimageio.imread(mask_file)
    
# saving all the 70 percent masks from disk into an array    

p70_masks = np.zeros((len(list_files),im_height,im_width), dtype=np.uint8)
for im_i in tqdm(range(len(list_files))):
    mask_file_reduced = os.path.join(p70_masks_path, (list_files[im_i] + '.png'))
    p70_masks[im_i] = skimageio.imread(mask_file_reduced)    

100%|██████████| 30/30 [00:02<00:00, 13.05it/s]
100%|██████████| 30/30 [00:00<00:00, 104.09it/s]
100%|██████████| 30/30 [00:00<00:00, 106.07it/s]


In [8]:
num_images_original = int(len(list_files) * ((np.ceil((im_height - height)/ stride)+1) **2))
augment_clahe_per_image = 50 
augment_gblur_per_image = 50
augment_aff_rotate_per_image = 50

num_images_augmented = (augment_clahe_per_image + augment_gblur_per_image + augment_aff_rotate_per_image ) * len(list_files)
num_images = num_images_original + num_images_augmented

images = np.zeros((num_images, patch_size_image, patch_size_image, channels), dtype=np.uint8)
masks = np.zeros((num_images, patch_size_image, patch_size_image,1), dtype=np.bool)
masks_reduced = np.zeros((num_images, patch_size_image, patch_size_image,1), dtype=np.bool)

counter = 0
for im_i in tqdm(range(len(list_files))):
    image = original_images[im_i]
    mask = original_masks[im_i]
    mask_reduced = p70_masks[im_i]
    for i in list(range(0, im_height-height, stride))+[im_height-height]:
        for j in list(range(0, im_width-height, stride))+[im_height-height]:
            images[counter,:,:,:] = image[i:i+height, j:j+width, :]
            masks[counter,:,:,0] = mask[i:i+height, j:j+width]
            masks_reduced[counter,:,:,0] = mask_reduced[i:i+height, j:j+width]
            counter += 1

100%|██████████| 30/30 [00:00<00:00, 37.72it/s]


In [9]:
def augment_clahe(img,angled):
    #img_aug = skimage.exposure.equalize_adapthist(img, clip_limit=0.03)
    augmenter = iaa.AllChannelsCLAHE()
    img_aug = augmenter.augment_image(img)
    return img_aug

def augment_gblur(img,angled):
    augmenter = iaa.GaussianBlur(1.0)
    img_aug = augmenter.augment_image(img)
    return img_aug

def augment_aff_rotate(img,angled):
    augmenter = iaa.Affine(rotate=angled, mode='reflect')
    img_aug = augmenter.augment_image(img)
    return img_aug

def aug_generator(original_images,original_masks,p70_masks,images,masks,masks_reduced,counter,list_files,
                  im_height,im_width,patch_size_image,aug_type_fn,aug_img_ct_per_image):
    for im_i in tqdm(range(len(list_files))):
        image = original_images[im_i]
        mask = original_masks[im_i]
        mask_reduced = p70_masks[im_i]
        y = random.sample(range(0, im_height-patch_size_image), aug_img_ct_per_image)
        x = random.sample(range(0, im_width-patch_size_image), aug_img_ct_per_image)
        for i in range(aug_img_ct_per_image):
            angle = random.choice([90,180,270])
            image_patch = image[y[i]:y[i]+patch_size_image, x[i]:x[i]+patch_size_image, :]
            mask_patch = mask[y[i]:y[i]+patch_size_image, x[i]:x[i]+patch_size_image]
            mask_patch_reduced = mask_reduced[y[i]:y[i]+patch_size_image, x[i]:x[i]+patch_size_image]
            images[counter, :, :, :] = aug_type_fn(image_patch,angle)
            if aug_type_fn in [augment_aff_rotate]:
                masks[counter, :, :, 0] = aug_type_fn(mask_patch,angle)
                masks_reduced[counter, :, :, 0] = aug_type_fn(mask_patch_reduced,angle)
            else:
                masks[counter, :, :, 0] = mask_patch
                masks_reduced[counter, :, :, 0] = mask_patch_reduced
            counter += 1
    print(counter)
    return images,masks,masks_reduced,counter

In [10]:
images,masks,masks_reduced,counter = aug_generator(original_images,original_masks,p70_masks,images,masks,masks_reduced,
                                                  counter,list_files,im_height,im_width,patch_size_image,augment_clahe,augment_clahe_per_image)
images,masks,masks_reduced,counter = aug_generator(original_images,original_masks,p70_masks,images,masks,masks_reduced,
                                                  counter,list_files,im_height,im_width,patch_size_image,augment_gblur,augment_gblur_per_image)
images,masks,masks_reduced,counter = aug_generator(original_images,original_masks,p70_masks,images,masks,masks_reduced,
                                                  counter,list_files,im_height,im_width,patch_size_image,augment_aff_rotate,augment_aff_rotate_per_image)



100%|██████████| 30/30 [00:01<00:00, 17.54it/s]
 13%|█▎        | 4/30 [00:00<00:00, 38.99it/s]

17370


100%|██████████| 30/30 [00:00<00:00, 40.13it/s]
  3%|▎         | 1/30 [00:00<00:04,  6.55it/s]

18870


100%|██████████| 30/30 [00:04<00:00,  6.64it/s]

20370





In [11]:
# resizing images and normalizing the tissue images between - 1 to 1

images_resize = np.zeros((len(images),patch_size_model,patch_size_model,channels))
for i in tqdm(range(len(images))):
    images_resize[i, :, :, :] = skimageresize(images[i,:,:,:], (patch_size_model, patch_size_model), mode='constant', cval=0, preserve_range=True)
    
images_resize = ((images_resize * (2/255))-1)

masks_resize = np.zeros((len(masks),patch_size_model,patch_size_model,1))
for i in tqdm(range(len(masks))):
    masks_resize[i, :, :,0] = skimageresize(masks[i,:,:,0], (patch_size_model, patch_size_model), mode='constant', cval=0, preserve_range=True)
    
masks_reduced_resize = np.zeros((len(masks_reduced),patch_size_model,patch_size_model,1))
for i in tqdm(range(len(masks_reduced))):
    masks_reduced_resize[i, :, :,0] = skimageresize(masks_reduced[i,:,:,0], (patch_size_model, patch_size_model), mode='constant', cval=0, preserve_range=True)    
    

100%|██████████| 20370/20370 [00:54<00:00, 375.21it/s]
100%|██████████| 20370/20370 [00:27<00:00, 736.90it/s]
100%|██████████| 20370/20370 [00:28<00:00, 705.06it/s]


In [12]:
def unet(input_size):
    inputs = keras.engine.Input((input_size[0], input_size[1], input_size[2]), dtype=np.float32) # height, width, channels
    conv_1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
    conv_1 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv_1)
    conv_1 = BatchNormalization()(conv_1)
    pool_1 = MaxPooling2D((2, 2))(conv_1)
    
    conv_2 = Conv2D(32, (3, 3), activation='relu', padding='same')(pool_1)
    conv_2 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv_2)
    conv_2 = BatchNormalization()(conv_2)
    pool_2 = MaxPooling2D((2, 2))(conv_2)
    
    conv_3 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool_2)
    conv_3 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv_3)
    conv_3 = BatchNormalization()(conv_3)
    pool_3 = MaxPooling2D((2, 2))(conv_3)
    
    conv_4 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool_3)
    conv_4 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv_4)
    conv_4 = BatchNormalization()(conv_4)
    pool_4 = MaxPooling2D((2, 2))(conv_4)
    
    conv_5 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool_4)
    dropout = Dropout(0.2)(conv_5)
    conv_5 = Conv2D(256, (3, 3), activation='relu', padding='same')(dropout)
    conv_5 = BatchNormalization()(conv_5)
    
    up_6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv_5)
    up_6 = concatenate([up_6, conv_4], axis=3)
    conv_6 = Conv2D(128, (3, 3), activation='relu', padding='same')(up_6)
    conv_6 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv_6)
    conv_6 = BatchNormalization()(conv_6)
    
    up_7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv_6)
    up_7 = concatenate([up_7, conv_3], axis=3)
    conv_7 = Conv2D(64, (3, 3), activation='relu', padding='same')(up_7)
    conv_7 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv_7)
    conv_7 = BatchNormalization()(conv_7)
    
    up_8 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv_7)
    up_8 = concatenate([up_8, conv_2], axis=3)
    conv_8 = Conv2D(32, (3, 3), activation='relu', padding='same')(up_8)
    conv_8 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv_8)
    conv_8 = BatchNormalization()(conv_8)
    
    up_9 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(conv_8)
    up_9 = concatenate([up_9, conv_1], axis=3)
    conv_9 = Conv2D(16, (3, 3), activation='relu', padding='same')(up_9)
    conv_9 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv_9)
    conv_9 = BatchNormalization()(conv_9)

    dropout9 = Dropout(0.1)(conv_9)
    conv_10 = Conv2D(1, (1, 1), activation='sigmoid')(dropout9)
    print(conv_10.shape)
    return keras.models.Model(inputs=[inputs], outputs=[conv_10])

def create_callbacks(loc):
    print('Creating callbacks..')
    early = keras.callbacks.EarlyStopping(patience=5, verbose=1)
    checkpoint = keras.callbacks.ModelCheckpoint(loc, verbose=1, save_best_only=True)
    return [early, checkpoint]

def model_create(unet_fn,input_size,callback_loc):
    model = unet_fn(input_size)
    optimizer = keras.optimizers.Adam(lr=0.01)
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    callbacks = create_callbacks(callback_loc)
    return model,callbacks

In [13]:
# Unet 1 
model1,callbacks1 = model_create(unet,(128,128,3),'/Users/ravitejachunduri/Documents/monuseg/models/keras_unet1_cp_10ep_128.h5')
model1.summary()
unet1 = model1.fit(images_resize, masks_resize, batch_size=25, epochs = 10, verbose=2, callbacks=callbacks1, shuffle=True, validation_split=0.1)
model1.save('/Users/ravitejachunduri/Documents/monuseg/models/keras_unet1_model_10ep_128.h5')

(None, 128, 128, 1)
Creating callbacks..
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 16) 448         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 128, 128, 16) 2320        conv2d_1[0][0]                   
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 16) 64          conv2d_2[0][0]                   
___________________________________________________

Train on 18333 samples, validate on 2037 samples
Epoch 1/10
 - 1241s - loss: 0.2803 - accuracy: 0.8506 - val_loss: 0.2853 - val_accuracy: 0.8459

Epoch 00001: val_loss improved from inf to 0.28528, saving model to /Users/ravitejachunduri/Documents/monuseg/models/keras_unet1_cp_10ep_128.h5
Epoch 2/10
 - 1246s - loss: 0.2266 - accuracy: 0.8732 - val_loss: 0.2629 - val_accuracy: 0.8752

Epoch 00002: val_loss improved from 0.28528 to 0.26294, saving model to /Users/ravitejachunduri/Documents/monuseg/models/keras_unet1_cp_10ep_128.h5
Epoch 3/10
 - 1260s - loss: 0.1964 - accuracy: 0.8858 - val_loss: 0.1991 - val_accuracy: 0.8834

Epoch 00003: val_loss improved from 0.26294 to 0.19912, saving model to /Users/ravitejachunduri/Documents/monuseg/models/keras_unet1_cp_10ep_128.h5
Epoch 4/10
 - 1267s - loss: 0.1840 - accuracy: 0.8907 - val_loss: 0.2240 - val_accuracy: 0.8782

Epoch 00004: val_loss did not improve from 0.19912
Epoch 5/10
 - 1265s - loss: 0.1754 - accuracy: 0.8940 - val_loss: 0.1911

In [14]:
# unet 2 
model2,callbacks2 = model_create(unet,(128,128,3),'/Users/ravitejachunduri/Documents/monuseg/models/keras_unet2_cp_10ep_128.h5')
model2.summary()
unet2 = model2.fit(images_resize, masks_reduced_resize, batch_size=25, epochs = 10, verbose=2, callbacks=callbacks2, shuffle=True, validation_split=0.1)
model2.save('/Users/ravitejachunduri/Documents/monuseg/models/keras_unet2_model_10ep_128.h5')

(None, 128, 128, 1)
Creating callbacks..
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 128, 128, 16) 448         input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 128, 128, 16) 2320        conv2d_20[0][0]                  
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 128, 128, 16) 64          conv2d_21[0][0]                  
___________________________________________________

Train on 18333 samples, validate on 2037 samples
Epoch 1/10
 - 1302s - loss: 0.2020 - accuracy: 0.8942 - val_loss: 0.1891 - val_accuracy: 0.9038

Epoch 00001: val_loss improved from inf to 0.18911, saving model to /Users/ravitejachunduri/Documents/monuseg/models/keras_unet2_cp_10ep_128.h5
Epoch 2/10
 - 1288s - loss: 0.1483 - accuracy: 0.9152 - val_loss: 0.1656 - val_accuracy: 0.9062

Epoch 00002: val_loss improved from 0.18911 to 0.16565, saving model to /Users/ravitejachunduri/Documents/monuseg/models/keras_unet2_cp_10ep_128.h5
Epoch 3/10
 - 1288s - loss: 0.1370 - accuracy: 0.9195 - val_loss: 0.1473 - val_accuracy: 0.9162

Epoch 00003: val_loss improved from 0.16565 to 0.14728, saving model to /Users/ravitejachunduri/Documents/monuseg/models/keras_unet2_cp_10ep_128.h5
Epoch 4/10
 - 1284s - loss: 0.1305 - accuracy: 0.9219 - val_loss: 0.1550 - val_accuracy: 0.9148

Epoch 00004: val_loss did not improve from 0.14728
Epoch 5/10
 - 1279s - loss: 0.1253 - accuracy: 0.9239 - val_loss: 0.1429

In [25]:
# convert the unet 1 model to work for input size of 1024X1024
train_unet1 = load_model('/Users/ravitejachunduri/Documents/monuseg/models/keras_unet1_model_10ep_128.h5')
unet_main1,callbacks_main1 = model_create(unet,(1024,1024,3),'/Users/ravitejachunduri/Documents/monuseg/models/keras_unet_main1_10ep_1024.h5')
unet_main1.summary()
unet_main1.set_weights(train_unet1.get_weights())

(None, 1024, 1024, 1)
Creating callbacks..
Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            (None, 1024, 1024, 3 0                                            
__________________________________________________________________________________________________
conv2d_77 (Conv2D)              (None, 1024, 1024, 1 448         input_5[0][0]                    
__________________________________________________________________________________________________
conv2d_78 (Conv2D)              (None, 1024, 1024, 1 2320        conv2d_77[0][0]                  
__________________________________________________________________________________________________
batch_normalization_37 (BatchNo (None, 1024, 1024, 1 64          conv2d_78[0][0]                  
_________________________________________________

In [26]:
# convert the unet 2 model to work for input size of 1024X1024
train_unet2 = load_model('/Users/ravitejachunduri/Documents/monuseg/models/keras_unet2_model_10ep_128.h5')
unet_main2,callbacks_main2 = model_create(unet,(1024,1024,3),'/Users/ravitejachunduri/Documents/monuseg/models/keras_unet_main2_10ep_1024.h5')
unet_main2.summary()
unet_main2.set_weights(train_unet2.get_weights())

(None, 1024, 1024, 1)
Creating callbacks..
Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            (None, 1024, 1024, 3 0                                            
__________________________________________________________________________________________________
conv2d_96 (Conv2D)              (None, 1024, 1024, 1 448         input_6[0][0]                    
__________________________________________________________________________________________________
conv2d_97 (Conv2D)              (None, 1024, 1024, 1 2320        conv2d_96[0][0]                  
__________________________________________________________________________________________________
batch_normalization_46 (BatchNo (None, 1024, 1024, 1 64          conv2d_97[0][0]                  
_________________________________________________

In [18]:
# test images pre-processing

test_files = os.listdir('/Users/ravitejachunduri/Documents/monuseg/Test_data/')
test_all_files=[]
for file in test_files:
    if (str.lower(file.strip().split('.')[-1]) in ['tif','tiff']):
        test_all_files.append(''.join(file.strip().split('.')[0:-1]))

# Test Images directory path:
test_path = '/Users/ravitejachunduri/Documents/monuseg/Test_data'

# target test images directory path:
test_images_path = '/Users/ravitejachunduri/Documents/monuseg/test_images'
if not os.path.exists(test_images_path):
    os.makedirs(test_images_path)
# target test masks directory path for masks :
test_bw_path = '/Users/ravitejachunduri/Documents/monuseg/test_bw'
if not os.path.exists(test_bw_path):
    os.makedirs(test_bw_path)
# Random color target test masks directory path:
test_clr_path = '/Users/ravitejachunduri/Documents/monuseg/test_clr'
if not os.path.exists(test_clr_path):
    os.makedirs(test_clr_path)

loading_images(test_all_files,test_path,test_path,'.png','.png',test_images_path,test_bw_path,test_clr_path)                                



100%|██████████| 14/14 [01:57<00:00,  8.36s/it]


In [19]:
# resizing images and normalizing the tissue images between - 1 to 1

original_test = np.zeros((len(test_all_files),im_height,im_width,channels), dtype=np.uint8)
for im_i in tqdm(range(len(test_all_files))):
    image_file = os.path.join(test_images_path, (test_all_files[im_i] + '.png'))
    original_test[im_i, :, :, :3] = skimageio.imread(image_file)  
    
original_test_resize = np.zeros((len(original_test),1024,1024,channels))
for i in tqdm(range(len(original_test))):
    original_test_resize[i, :, :, :] = skimageresize(original_test[i,:,:,:], (1024, 1024), mode='constant', cval=0, preserve_range=True)
    
original_test_resize = ((original_test_resize * (2/255))-1)



100%|██████████| 14/14 [00:01<00:00, 13.08it/s]
100%|██████████| 14/14 [00:01<00:00,  9.94it/s]


In [33]:
# TTA augmentation

predicted_masks1 = np.zeros((len(test_all_files), 1000, 1000), dtype=np.float32)
predicted_masks2 = np.zeros((len(test_all_files), 1000, 1000), dtype=np.float32)

prediction_path1 = '/Users/ravitejachunduri/Documents/monuseg/predictions_unet1'
if not os.path.exists(prediction_path1):
    os.makedirs(prediction_path1)
prediction_path2 = '/Users/ravitejachunduri/Documents/monuseg/predictions_unet2'
if not os.path.exists(prediction_path2):
    os.makedirs(prediction_path2)    
    
def tta_preds(files_list,test_images,predicted_masks,model):    
    for im_i in tqdm(range(len(files_list))):
        input_image = test_images[im_i]
        output = model.predict(np.expand_dims(input_image, axis=0))[0, :, :, 0]
        #
        aug1 = iaa.Fliplr(1.0)
        input2 = aug1.augment_image(input_image)
        output2 = model.predict(np.expand_dims(input2, axis=0))[0, :, :, 0]
        output2 = aug1.augment_image(output2)
        #
        aug2 = iaa.Flipud(1.0)
        input3 = aug2.augment_image(input_image)
        output3 = model.predict(np.expand_dims(input3, axis=0))[0, :, :, 0]
        output3 = aug2.augment_image(output3)
        #
        input4 = aug1.augment_image(aug2.augment_image(input_image))
        output4 = model.predict(np.expand_dims(input4, axis=0))[0, :, :, 0]
        output4 = aug1.augment_image(aug2.augment_image(output4))
        #
        aug4 = iaa.Affine(rotate=90)
        aug4_rev = iaa.Affine(rotate=-90)
        input5 = aug4.augment_image(input_image)
        output5 = model.predict(np.expand_dims(input5, axis=0))[0, :, :, 0]
        output5 = aug4_rev.augment_image(output5)
        #
        aug5 = iaa.Affine(rotate=270)
        aug5_rev = iaa.Affine(rotate=-270)
        input6 = aug5.augment_image(input_image)
        output6 = model.predict(np.expand_dims(input6, axis=0))[0, :, :, 0]
        output6 = aug5_rev.augment_image(output6)
        #
        aug6 = iaa.Sequential([iaa.Affine(rotate=270), iaa.Fliplr(1.0)])
        aug6_rev = iaa.Sequential([iaa.Fliplr(1.0), iaa.Affine(rotate=-270)])
        input7 = aug6.augment_image(input_image)
        output7 = model.predict(np.expand_dims(input7, axis=0))[0, :, :, 0]
        output7 = aug6_rev.augment_image(output7)
        #
        aug7 = iaa.Sequential([iaa.Affine(rotate=90), iaa.Fliplr(1.0)])
        aug7_rev = iaa.Sequential([iaa.Fliplr(1.0), iaa.Affine(rotate=-90)])
        input8 = aug7.augment_image(input_image)
        output8 = model.predict(np.expand_dims(input8, axis=0))[0, :, :, 0]
        output8 = aug7_rev.augment_image(output8) 
        #
        output_mask = (output + output2 + output3 + output4 + output5 + output6 + output7 + output8) / 8
        output_mask = skimageresize(output_mask, (1000, 1000), mode='constant', cval=0, preserve_range=True)
        temp1 = output_mask > 0.45
    
        predicted_masks[im_i] = temp1.astype(np.float32)
        
    return predicted_masks
    
predicted_masks_unet1 = tta_preds(test_all_files,original_test_resize,predicted_masks1,unet_main1)
predicted_masks_unet2 = tta_preds(test_all_files,original_test_resize,predicted_masks2,unet_main2)

100%|██████████| 14/14 [01:53<00:00,  8.10s/it]
100%|██████████| 14/14 [01:50<00:00,  7.92s/it]


In [39]:
for im_i in tqdm(range(len(test_all_files))):
    skimageio.imsave(os.path.join(prediction_path1,test_all_files[im_i]+'.png'), predicted_masks_unet1[im_i].astype(np.uint8),check_contrast=False)
    
for im_i in tqdm(range(len(test_all_files))):
    skimageio.imsave(os.path.join(prediction_path2,test_all_files[im_i]+'.png'), predicted_masks_unet2[im_i].astype(np.uint8),check_contrast=False)


100%|██████████| 14/14 [00:06<00:00,  2.08it/s]
100%|██████████| 14/14 [00:06<00:00,  2.31it/s]


In [53]:
# watershed
from scipy import ndimage
from skimage.feature import peak_local_max
from skimage.morphology import watershed

pred_files_path = os.listdir('/Users/ravitejachunduri/Documents/monuseg/predictions_unet1/')
pred_files = []
for file in tqdm(pred_files_path):
    if file.strip().split('.')[-1] in ['png']:
        pred_files.append(''.join(file.strip().split('.')[0:-1])) 

final_preds_path = '/Users/ravitejachunduri/Documents/monuseg/pred_mats_final'
if not os.path.exists(final_preds_path):
    os.makedirs(final_preds_path)
gt_path = '/Users/ravitejachunduri/Documents/monuseg/gt_final'
if not os.path.exists(gt_path):
    os.makedirs(gt_path)    
    
for filename in tqdm(pred_files):
    image1 = skimageio.imread(os.path.join('/Users/ravitejachunduri/Documents/monuseg/predictions_unet1/',filename+'.png'))
    image2 = skimageio.imread(os.path.join('/Users/ravitejachunduri/Documents/monuseg/predictions_unet2/',filename+'.png'))
    gtimage = skimageio.imread(os.path.join(test_bw_path,filename+'.png'))
    
    D = ndimage.distance_transform_edt(image2)
    localMax = peak_local_max(D, indices=False, min_distance=0, labels=image2)
    markers, n = ndimage.label(localMax, structure=np.ones((3, 3)))
    labels = watershed(-D, markers, mask=image1)
    savemat(os.path.join(final_preds_path, filename + '.mat'), {'predicted_map':labels})
    
    labels2, n2 = ndimage.label(gtimage)
    savemat(os.path.join(gt_path, filename + '.mat'), {'gt_map':labels2})

print ('Prediction mats are created')    

100%|██████████| 14/14 [00:00<00:00, 29656.69it/s]
100%|██████████| 14/14 [00:04<00:00,  3.37it/s]

Prediction mats are created





In [75]:
# calculation of performance matric - Aggregated Jaccard Index

def calculate_aji_new(mask1, mask2):
    # Mask1 >> Prediction mask
    # Mask2 >> Ground truth mask
    labels1 = mask1['predicted_map']
    #n1 = max(np.unique(labels1))
    labels2 = mask2['gt_map']
    #n2 = max(np.unique(labels2))
    gt_map = labels2
    predicted_map = labels1
    gt_list = np.unique(labels2)
    gt_list = gt_list[1:]
    ngt = len(gt_list)
    pr_list = np.unique(labels1)
    pr_list = pr_list[1:]
    pr_list = np.concatenate([np.expand_dims(pr_list, axis=1), np.zeros(shape=(len(pr_list), 1), dtype=np.int8)], axis=1)
    npredicted = len(pr_list[:,1])
    overall_correct_count = 0
    union_pixel_count = 0
    i = len(gt_list)
    while len(gt_list)>0:
        gt = 1*(gt_map == gt_list[i-1])
        predicted_match = np.multiply(gt, predicted_map)
        if np.count_nonzero(predicted_match) == 0:
            union_pixel_count += np.count_nonzero(gt)
            gt_list = np.delete(gt_list, i-1, 0)
            i = len(gt_list)
        else:
            predicted_nuc_index = np.unique(predicted_match)
            predicted_nuc_index = predicted_nuc_index[1:]
            JI = 0
            for j in range(len(predicted_nuc_index)):
                matched = 1*(predicted_map == predicted_nuc_index[j])
                nJI = np.count_nonzero(gt & matched) / np.count_nonzero(gt | matched)
                if nJI > JI:
                    best_match = predicted_nuc_index[j]
                    JI = nJI
            predicted_nuclei = 1*(predicted_map == best_match)
            overall_correct_count += np.count_nonzero(gt & predicted_nuclei)
            union_pixel_count += np.count_nonzero(gt | predicted_nuclei)
            gt_list = np.delete(gt_list, i-1, 0)
            i = len(gt_list)
            index = np.argwhere(pr_list[:,0] == best_match)
            pr_list[index, 1] += 1
    unused_nuclei_list = np.argwhere(pr_list[:,1] == 0)
    for k in range(len(unused_nuclei_list)):
        unused_nuclei = 1*(predicted_map == pr_list[unused_nuclei_list[k], 0])
        union_pixel_count += np.count_nonzero(unused_nuclei)
    aji = overall_correct_count * 1.0 / union_pixel_count
    return aji

In [82]:
agg_aji =[]
for i in tqdm(range(len(test_all_files))):
    pred_mask = loadmat(os.path.join(final_preds_path,test_all_files[i]+'.mat'))
    gt_mask = loadmat(os.path.join(gt_path,test_all_files[i]+'.mat'))
    aji_val = calculate_aji_new(pred_mask, gt_mask)
    agg_aji.append(aji_val)

print('The aggregated jaccard index for all images:{}'.format(sum(agg_aji)/len(agg_aji)))    

100%|██████████| 14/14 [03:59<00:00, 17.11s/it]

The aggregated jaccard index for all images:0.637433139509428



