# Nuclei Segmentation with the NB (Nuclei Boundary) Model

## Initial setup

In [1]:
import os
import glob
import skimage.io as io
import numpy as np

from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as K

KeyboardInterrupt: 

## Configuration

In [None]:
# Input images are square: image_size x image_size pixels (i.e. 256*256).
image_size = 256
batch_size, epochs = 8, 2

# Data augmentation: random flips used by the training generator.
horizontal_flip = vertical_flip = True

## Image generator

In [None]:
# RGB colour used for each segmentation class in the label images;
# the row index in COLOR_DICT is the class id.
background = [255,0,0]
boundary = [0,255,0]
inside = [0,0,255]

COLOR_DICT = np.array([background, boundary, inside])
num_class = len(COLOR_DICT)  # 3: background, boundary, inside


def adjustData(img, mask, target_size):
    """Scale an image batch and one-hot encode its colour-coded mask batch.

    Args:
        img: batch of images; divided by 255 (assumes 0-255 pixel values).
        mask: batch of colour masks indexed as (batch, height, width, channel).
            Only the first three channels are compared, so RGBA masks work too.
        target_size: (height, width) the batch was resized to. It is kept for
            backward compatibility only. The output shape is taken from
            ``mask`` itself, so the two can never disagree.

    Returns:
        Tuple ``(img / 255, onehot)``. ``onehot`` is a uint8 array of shape
        ``mask.shape[:3] + (num_class,)``. Channel ``i`` is 1 where the mask
        pixel equals ``COLOR_DICT[i]``. A pixel matching no class colour gets
        an all-zero vector, so it contributes nothing to a categorical
        cross-entropy loss.
    """
    img = img / 255
    # Compare every pixel with every class colour in one broadcast:
    # (B, H, W, 1, 3) == (num_class, 3) -> (B, H, W, num_class, 3).
    # A class matches only when all three channels are equal.
    matches = np.all(mask[..., np.newaxis, :3] == COLOR_DICT, axis=-1)
    onehot = matches.astype(np.uint8)
    return (img, onehot)

def trainGenerator(batch_size, train_path, image_folder, mask_folder, aug_dict,
                   image_color_mode="rgb", mask_color_mode="rgb",
                   image_save_prefix="image", mask_save_prefix="mask",
                   save_to_dir=None, target_size=(256, 256), seed=1):
    """Yield ``adjustData``-processed (image_batch, onehot_mask_batch) pairs.

    Two ``ImageDataGenerator`` instances are built from the same ``aug_dict``.
    Each streams one sub-folder of ``train_path`` with the same ``seed``, so
    image and mask batches are drawn and augmented in the same order.
    NOTE(review): the pairing relies on both folders listing matching file
    names in the same order; confirm this for the dataset.
    """
    image_augmenter = ImageDataGenerator(**aug_dict)
    mask_augmenter = ImageDataGenerator(**aug_dict)

    # Options shared by both directory streams.
    common = {
        'class_mode': None,
        'target_size': target_size,
        'batch_size': batch_size,
        'save_to_dir': save_to_dir,
        'seed': seed,
    }
    image_stream = image_augmenter.flow_from_directory(
        train_path,
        classes=[image_folder],
        color_mode=image_color_mode,
        save_prefix=image_save_prefix,
        **common)
    mask_stream = mask_augmenter.flow_from_directory(
        train_path,
        classes=[mask_folder],
        color_mode=mask_color_mode,
        save_prefix=mask_save_prefix,
        **common)

    for image_batch, mask_batch in zip(image_stream, mask_stream):
        scaled_images, onehot_masks = adjustData(image_batch, mask_batch, target_size)
        yield (scaled_images, onehot_masks)

In [None]:
# Data augmentation: random flips for training; validation data is not augmented.
data_gen_args = {'horizontal_flip': horizontal_flip,
                 'vertical_flip': vertical_flip}
val_gen_args = {'horizontal_flip': False,
                'vertical_flip': False}

# Generators for training data (data/train_image + data/train_label) and
# validation data (data/test_image + data/test_label, one image per batch).
trainGene = trainGenerator(
    batch_size=batch_size,
    train_path='data',
    image_folder='train_image',
    mask_folder='train_label',
    aug_dict=data_gen_args,
    target_size=(image_size, image_size),
    save_to_dir=None,
)
valGene = trainGenerator(
    batch_size=1,
    train_path='data',
    image_folder='test_image',
    mask_folder='test_label',
    aug_dict=val_gen_args,
    target_size=(image_size, image_size),
    save_to_dir=None,
)

## Define Model

In [None]:
def conv_bn(filters,x):
    """Apply two successive 3x3 Conv2D -> BatchNormalization -> ReLU stages.

    Both convolutions use ``filters`` output channels, 'same' padding and
    He-normal initialisation, so the spatial size of ``x`` is preserved.
    """
    out = x
    for _ in range(2):
        out = Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(out)
        out = BatchNormalization()(out)
        out = Activation('relu')(out)
    return out

def unet(input_size = (256,256,3), num_class = 3):
    """Build an (uncompiled) U-Net for per-pixel multi-class segmentation.

    Architecture:
    - Encoder: four ``conv_bn`` stages with 64/128/256/512 filters, each
      followed by 2x2 max pooling.
    - Bottleneck: a 1024-filter ``conv_bn`` stage.
    - Decoder: four stages that each upsample 2x, apply a 2x2 ReLU conv,
      concatenate the matching encoder output (skip connection) and apply
      ``conv_bn``.
    - Head: a 1x1 softmax convolution.

    Args:
        input_size: (height, width, channels) of the input. Height and width
            should be divisible by 16 (four 2x2 poolings). Otherwise the
            upsampled tensors will not line up with the skip connections.
        num_class: number of output channels of the softmax head. The default
            of 3 keeps the original behaviour (background/boundary/inside).

    Returns:
        A keras ``Model`` mapping ``input_size`` images to per-pixel class
        probabilities with ``num_class`` channels.
    """
    encoder_filters = (64, 128, 256, 512)

    inputs = Input(input_size)

    # Contracting path: remember each stage's output for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = conv_bn(filters, x)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = conv_bn(1024, x)

    # Expanding path: walk the encoder stages in reverse order.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        up = UpSampling2D(size=(2, 2))(x)
        up = Conv2D(filters, 2, activation='relu', padding='same',
                    kernel_initializer='he_normal')(up)
        x = concatenate([skip, up], axis=3)
        x = conv_bn(filters, x)

    outputs = Conv2D(num_class, 1, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)

    return model

In [None]:
model = unet(input_size = (image_size, image_size, 3))
model.summary()
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# releases no longer accept.
model.compile(optimizer = Adam(learning_rate = 1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])

# Integer division yields 0 when there are fewer training images than
# batch_size, which Keras rejects; always run at least one step.
steps_per_epoch = max(1, len(glob.glob("data/train_image/*")) // batch_size)
# valGene uses batch_size=1, so one step per validation image.
validation_steps = len(glob.glob("data/test_image/*"))

# Validation data is supplied, so keep the checkpoint with the best
# *validation* loss. Training loss almost always decreases and would just
# save the latest weights.
model_checkpoint = ModelCheckpoint('unet.h5', monitor='val_loss', verbose=1, save_best_only=True)
# NOTE(review): fit_generator is deprecated in tf.keras (model.fit accepts
# generators there); it is kept for compatibility with standalone Keras 2.x.
history = model.fit_generator(trainGene,
                              validation_data = valGene,
                              steps_per_epoch = steps_per_epoch,
                              epochs = epochs,
                              validation_steps = validation_steps,
                              callbacks=[model_checkpoint])

## Plot learning curve

In [None]:
# matplotlib is not imported in the setup cell, so import it here.
import matplotlib.pyplot as plt

# One x value per recorded epoch, derived from the history. A hard-coded
# range breaks as soon as `epochs` changes.
epoch_axis = range(1, len(history.history['accuracy']) + 1)

fig = plt.figure()
plt.plot(epoch_axis, history.history['accuracy'], label='training')
plt.plot(epoch_axis, history.history['val_accuracy'], label='validation')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
# Uncomment to also write the figure to disk:
#fig.savefig('accuracy.png')