# Segmentation: U-Net and MS-D

In this notebook, we build two different models for binary segmentation of histological images. 

We use the [Warwick QU gland segmentation](https://warwick.ac.uk/fac/sci/dcs/research/tia/glascontest/) (GlaS) data set consisting of H&E stained histological images from the GlaS@MICCAI 2015 competition. 

In [1]:
# URL of Warwick QU GlaS data set.
# will be downloaded and extracted to ~/.keras if not existent
URL = 'https://warwick.ac.uk/fac/sci/dcs/research/tia/glascontest/download/warwick_qu_dataset_released_2016_07_08.zip'

# size of images for prediction
HEIGHT, WIDTH = 320, 448

# size of image patches for training 
## note: must be divisible by 32 for UNet 
PATCH_HEIGHT, PATCH_WIDTH = 128,128

In [2]:
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np
import pandas as pd
import os, glob, sys

# Utility functions

In [3]:
def sort_human(strings):
    def atoi(text):
        return int(text) if text.isdigit() else text
    
    def human_keys(text):
        import re
        return [ atoi(c) for c in re.split('(\d+)', text) ]
    
    strings.sort(key=human_keys)
    return strings

def import_image(filename, scale=True, expand=None, verbose=True):
    """
    Read an image file.

    Args:
        filename (str): a file name (path)
        scale (bool): should the returned image be scaled to float32 [0,1]
        expand (int): expands the specified axis of the image (for keras with channels last this should be -1)
        verbose (bool): print information about the image

    Returns:
        (numpy.array) returns an numpy array containing the image
    """
    from skimage.io import imread, imsave

    if(filename.endswith('.tif')):
        image=imread(filename, plugin='tifffile')
    else:
        image=imread(filename)
    if scale:
        if image.dtype==np.bool:
            image=image.astype('float32')
        elif image.dtype==np.int8:
            image=(image.astype('float32')+(2**8/2))/(2**8-1)
        elif image.dtype==np.int16:
            image=(image.astype('float32')+(2**16/2))/(2**16-1)
        elif image.dtype==np.uint8:
            image=image.astype('float32')/(2**8-1)
        elif image.dtype==np.uint16:
            image=image.astype('float32')/(2**16-1)
        else:
            warnings.warn('Scaling for dtype {} is not yet implemented!'.format(image.dtype))
    if expand!=None:
        image=np.expand_dims(image,axis=expand)
    if verbose:
        image_info(image)
    return image

def draw_model(model):
    try: 
        import graphviz
        from IPython.display import display,SVG
        from keras.utils.vis_utils import model_to_dot
        return SVG(model_to_dot(model, show_shapes=False, show_layer_names=False, rankdir='LR').create(prog='dot', format='svg'))
    except:
        raise ImportError('Graphviz not installed.')

# Get data

- Download data set

In [5]:
from pathlib import Path
home = str(Path.home())
data_folder = os.path.join(home, '/Users/aimachine/Documents/DeepLearningDatasets/Warwick QU Dataset (Released 2016_07_08)')

if not os.path.exists(data_folder):
    print('Downloading Warwick dataset')
    from tensorflow.keras.utils import get_file 
    dataset = get_file(
          fname=os.path.basename(URL),
          origin=URL, 
          extract=True)
else:
    print('Warwick dataset found at {}'.format(data_folder) )
    print('Not downloading...')

Downloading Warwick dataset
Downloading data from https://warwick.ac.uk/fac/sci/dcs/research/tia/glascontest/download/warwick_qu_dataset_released_2016_07_08.zip


- Read image files and masks

In [12]:
fns_images = sort_human(glob.glob(os.path.join(data_folder, 'train_*.bmp')))
fns_masks  = sort_human(glob.glob(os.path.join(data_folder, 'train_*_anno.bmp')))
# remove _anno.bmp from path_images
fns_images = sort_human(list(set(fns_images) - set(fns_masks)))
fns_images[:10]

[]

In [8]:
from skimage import transform

images=[]
for fn in fns_images:
    image = import_image(fn, scale=True, verbose=False)
    image = transform.resize(image, (HEIGHT, WIDTH), mode='constant')
    #print(fn, image.shape)
    images.append(image)
images = np.array(images)
#print(images.shape, images.min(), images.max())

masks=[]
for fn in fns_masks:
    basename = os.path.splitext(os.path.basename(fn))[0].split('_anno')[0]
    image = import_image(fn, verbose=False)
    image = transform.resize(image, (HEIGHT, WIDTH), order=0, mode='constant')
    mask = np.array(image>0, np.uint8)
    mask = mask[..., np.newaxis]
    masks.append(mask)
    
masks = np.array(masks)

print(masks.shape)

print(len(np.unique(masks)))

(0,)
0


In [10]:
images.shape

(0,)

- data sample patches 

In [None]:
def sample_2Dtiles(image, annotation, tile_shape=(32,32), samples=10):
    sample_im=[]
    sample_mask=[]
    for i in range(samples):
        x = np.random.randint(0,image.shape[0]-tile_shape[0]-1)
        y = np.random.randint(0,image.shape[1]-tile_shape[1]-1)
        sample_im.append(   image[x:x+tile_shape[0],y:y+tile_shape[1]])
        sample_mask.append( annotation[x:x+tile_shape[0],y:y+tile_shape[1]])
    return np.array(sample_im), np.array(sample_mask)

In [None]:
X = []
Y = []
for im, mask in zip(images, masks):
    x,y = sample_2Dtiles(im, mask, tile_shape=(PATCH_HEIGHT, PATCH_WIDTH), samples=10)
    X.append(x)
    Y.append(y)
    #print(x.shape)
X = np.array(X)
X = X.reshape((X.shape[0]*X.shape[1], X.shape[2], X.shape[3], X.shape[4]))
Y = np.array(Y)
Y = Y.reshape((Y.shape[0]*Y.shape[1], Y.shape[2], Y.shape[3], Y.shape[4]))
print(X.shape, Y.shape)

(850, 128, 128, 3) (850, 128, 128, 1)


In [None]:
from sklearn.model_selection import train_test_split
x_validation, x_train, y_validation, y_train = train_test_split(X, Y, test_size=.8)
print('X training: {}'.format(x_train.shape))
print('X validation: {}'.format(x_validation.shape))
print('Y training: {}'.format(y_train.shape))
print('Y validation: {}'.format(y_validation.shape))

X training: (680, 128, 128, 3)
X validation: (170, 128, 128, 3)
Y training: (680, 128, 128, 1)
Y validation: (170, 128, 128, 1)


# Models

In [None]:
def UNet(n_input_channels, n_output_channels):
    from keras.layers import Input, Dropout, UpSampling2D, MaxPooling2D, BatchNormalization, Conv2D, Concatenate
    from keras.models import Model

    inputs = Input((None, None, n_input_channels))
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = BatchNormalization()(conv1)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    conv1 = BatchNormalization()(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = BatchNormalization()(conv2)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    conv2 = BatchNormalization()(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = BatchNormalization()(conv3)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    conv3 = BatchNormalization()(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = BatchNormalization()(conv4)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    conv4 = BatchNormalization()(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = BatchNormalization()(conv5)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    conv5 = BatchNormalization()(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
    merge6 = Concatenate(axis=-1)([conv4,up6])
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = BatchNormalization()(conv6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
    conv6 = BatchNormalization()(conv6)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = Concatenate(axis=-1)([conv3,up7])
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = BatchNormalization()(conv7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
    conv7 = BatchNormalization()(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = Concatenate(axis=-1)([conv2,up8])
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = BatchNormalization()(conv8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
    conv8 = BatchNormalization()(conv8)

    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = Concatenate(axis=-1)([conv1,up9])
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = BatchNormalization()(conv9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = BatchNormalization()(conv9)
    conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = BatchNormalization()(conv9)
    conv10 = Conv2D(n_output_channels, 1, activation = 'sigmoid')(conv9)
    
    return Model(inputs = inputs, outputs = conv10)



In [None]:
def MSD(width, depth, n_input_channels=1, n_output_channels=1, drop_out=0.0, batch_norm=False):
    from keras.layers import Input, merge, Dropout, BatchNormalization, Conv2D, Concatenate
    from keras.models import Model
    
    def convolution(n_filters, dilation, inputs, drop_out, name=None):
        if len(inputs) > 1:
            i = Concatenate()(inputs)
        else:
            i = inputs[0]
        c = Conv2D(filters=n_filters, dilation_rate=(dilation, dilation),
                                        kernel_size=(3,3), strides=(1,1), padding='same', 
                                        activation='relu', use_bias=True, name=name)(i)
        if batch_norm:
            c = BatchNormalization()(c)
            
        if drop_out:
            c = Dropout(rate=drop_out)(c)
            
        return c

    input_image = Input(shape=(None, None, n_input_channels), name='input')
    
    if batch_norm:
        bn = BatchNormalization(name='batchnorm')(input_image)
        inputs_layer = [bn]
    else:
        inputs_layer = [input_image]
    
    all_inputs = inputs_layer.copy()
    for i in range(depth):
        outputs_layer = []
        for j in range(width): 
            # dilation rate = 1-10 in alternating cycles
            s_ij = ((i*width + j) % 10) + 1
            name = 'layer_{}_stride_{}'.format(i, s_ij)
            outputs_filter = convolution(n_filters=1, dilation=s_ij, inputs=inputs_layer, drop_out=drop_out, name=name)
            outputs_layer.append(outputs_filter)
            all_inputs.append(outputs_filter)
        inputs_layer.extend( outputs_layer )
        
    c = Concatenate()(all_inputs)
    o = Conv2D(filters=n_output_channels, kernel_size=(1,1), padding='same', activation='sigmoid')(c)
    
    model = Model(inputs=input_image, outputs=o)
        
    return model

In [None]:
model_unet = UNet(n_input_channels=3, n_output_channels=1)
print('UNet: {:,} params'.format(model_unet.count_params()))

model_msd = MSD(n_input_channels=3, n_output_channels=1, 
                width=2, depth=40, drop_out=0.25, batch_norm=True)
print('MSD:  {:,} params'.format(model_msd.count_params()))

print('UNet has {:,} times as many params as MSD!'.format( int(model_unet.count_params()/model_msd.count_params())))

UNet: 31,056,397 params
MSD:  30,736 params
UNet has 1,010 times as many params as MSD!


In [None]:
#draw_model(model_unet)
#draw_model(model_msd)

# Loss function

We use the [Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) as a measure of similarity between the predicted binary mask and the ground truth mask. This is equivalent the [F1 score](https://en.wikipedia.org/wiki/F1_score) for binary classification. 

The loss is simply the negative of the dice coefficient.

In [None]:
import keras.backend as K
smooth=1.0
def dice_coef(y_true, y_pred):
    print(y_true.shape, 'shape')
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)

# Train models

- define callback to plot progress

In [None]:
from IPython.display import clear_output
from keras.callbacks import Callback
class PlotLearning(Callback):

    def on_train_begin(self, logs={}):
        self.i = 0
        self.x = []
        self.loss = []
        self.val_loss = []
        self.acc = []
        self.val_acc = []
        self.fig = plt.figure()
        
        self.logs = []

    def on_epoch_end(self, epoch, logs={}):
        
        self.logs.append(logs)
        self.x.append(self.i)
        self.loss.append(logs.get('loss'))
        self.acc.append(logs.get('dice_coef'))
        
        self.val_loss.append(logs.get('val_loss'))        
        self.val_acc.append(logs.get('val_dice_coef'))
        
        self.i += 1
        f, ax = plt.subplots(1, 2, figsize=(12,4), sharex=True)
        ax = ax.flatten()
        clear_output(wait=True)
        
        ax[0].plot(self.x, self.loss, label="loss", lw=2)
        ax[0].plot(self.x, self.val_loss, label="val loss")
        #ax[0].set_ylim(bottom=0.)
        ax[0].legend()
        ax[0].grid(True)
        
        ax[1].plot(self.x, self.acc, label="Dice coef", lw=2)
        ax[1].plot(self.x, self.val_acc, label="val Dice coef")
        #ax[1].set_ylim(bottom=0.)
        ax[1].legend()
        ax[1].grid(True)
        
        plt.show();
        
plotLoss = PlotLearning()

In [None]:
batchsize = 8
epochs = 40
reload = False

- Train UNet

In [None]:
%%time
model_unet.compile(optimizer='adam', loss=dice_coef_loss, metrics=[dice_coef])

history_unet = model_unet.fit(x_train, y_train, batch_size=batchsize, epochs=epochs, 
                    verbose=1, shuffle=True, 
                    validation_data=(x_validation, y_validation),
                    callbacks=[plotLoss])



(?, ?, ?, ?) shape
(?, ?, ?, ?) shape
Train on 680 samples, validate on 170 samples
Epoch 1/40

KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

- Train MSD

In [None]:
%%time
model_msd.compile(optimizer='adam', loss=dice_coef_loss, metrics=[dice_coef])

history_msd = model_msd.fit(x_train, y_train, batch_size=batchsize, epochs=epochs, 
                    verbose=1, shuffle=True, 
                    validation_data=(x_validation, y_validation),
                    callbacks=[plotLoss])


# Predict

- some utility function to predict sample and plot predictions as overlays

In [None]:
def predict(model, num_samples=10, indices=None):
    import time
    tic = time.time()
    if indices is not None:
        samples_ind = indices
    else:
        samples_ind = np.random.randint(0, len(images), num_samples)
    
    sample_images = images[samples_ind]
    sample_masks = masks[samples_ind]
    predictions = model.predict(sample_images, batch_size=4)
    toc = time.time()
    print('Time per image = {:.4f} sec'.format((toc-tic) / num_samples))
    return sample_images, sample_masks, predictions, samples_ind

def images_overlay(im1, im2, threshold=1e-2, title=None, alpha=0.25, cmap='Reds_r', ax=None, figsize=(5,5)):
    
    if ax is None:
        fig, ax = plt.subplots(1,1,figsize=figsize)
        
    im1 = np.squeeze(im1)
    im2 = np.squeeze(im2)
    
    ax.imshow(im1)

    mask = np.ma.masked_less(im2, threshold)
    ax.contourf(mask >  threshold, alpha=alpha, cmap=cmap)
    ax.contour(im2, levels=[threshold], alpha=1.0, cmap=cmap)
    
    ax.axis('off')
    
    if title is not None:
        ax.set_title(title, fontsize=14)
        
def plot_prediction(x, y, y_pred_unet, y_pred_msd, sample):
    
    if sample >= len(y_pred_unet):
        raise ValueError('sample {} does not exist, max = {}'.format(sample, len(y_pred)-1))
    
    fig, ax = plt.subplots(1,3,figsize=(10,5))
    images_overlay(x[sample], y_pred_unet[sample], ax=ax[0], title='prediction UNet')
    images_overlay(x[sample], y_pred_msd[sample],  ax=ax[1], title='prediction MSD')
    images_overlay(x[sample], y[sample], ax=ax[2], title='annotation')
    

In [None]:
%%time
num_samples = 10
x, y, y_pred_unet, indices = predict(model_unet, num_samples=num_samples)
x, y, y_pred_msd, _        = predict(model_msd, indices=indices)

In [None]:
_ = [plot_prediction(x, y, y_pred_unet, y_pred_msd, sample=i) for i in range(num_samples)]