##  U-net based with Boundary loss for COVID detection
* This kernel uses an attention structure (attention blocks on the U-Net skip connections)

* The boundary loss follows the reference implementation in this [GitHub repository](https://github.com/LIVIAETS/boundary-loss)

# * Let's see how it performs!!

# Read NIfTI (.nii) files

In [None]:
import glob
import pandas  as pd
import numpy   as np
import nibabel as nib
import matplotlib.pyplot as plt
import pickle
import cv2
import tensorflow as tf

In [None]:
# Read and examine metadata
# The 'ct_scan' and 'infection_mask' columns are later passed to read_nii()
# as file paths, one row per case.
data = pd.read_csv('../input/covid19-ct-scans/metadata.csv')
data.sample(5)
data.head(3)

In [None]:
def read_nii(filepath):
    """Load a .nii file and return its voxel data as a NumPy array.

    The volume is rotated by 90 degrees in the plane of its first two axes
    (``np.rot90`` defaults) because this dataset needs that rotation.
    """
    volume = np.array(nib.load(filepath).get_fdata())
    return np.rot90(volume)

In [None]:
# Read sample: CT volume and infection mask of the metadata row labelled 2
sample_ct = read_nii(data.loc[2,'ct_scan'])
sample_mask= read_nii(data.loc[2, 'infection_mask'])

In [None]:
# Shape of the sample volume and the distinct label values in its mask
print (sample_ct.shape, np.unique(sample_mask))

# Check whether the values are already expressed in Hounsfield Units (HU)

In [None]:

# Histogram of the voxel values of a single slice (index 1 on the last axis)
imgs_to_process = sample_ct[...,1]

plt.hist(imgs_to_process.flatten(), bins=50, color='c')
plt.xlabel("Hounsfield Units (HU)")
plt.ylabel("Frequency")
plt.show()

# No need to transform

![Hu scale](https://www.researchgate.net/profile/M_Kholief/publication/306033192/figure/fig2/AS:613926819610632@1523382968429/The-Hounsfield-scale-of-CT-numbers.png)

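# Since the lung region of interest lies roughly in the HU interval [-600, 400] (the window used in the code below), we can clip values outside it

Check whether the infected regions are still visible after this windowing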

In [None]:
# Clamp the sample volume to the [-600, 400] window (new array, input untouched)
sample_ct_windowed = np.clip(sample_ct, -600, 400)

fig = plt.figure(figsize=(18, 15))

# --- slice 100 along the last axis ---
plt.subplot(1, 6, 1)
plt.imshow(sample_ct[..., 100], cmap='bone')
plt.title('original CT')

plt.subplot(1, 6, 2)
plt.imshow(sample_mask[..., 100], cmap='nipy_spectral')
plt.title('original infection mask')

# --- slice 100 along the second axis, rotated for display ---
plt.subplot(1, 6, 3)
plt.imshow(np.rot90(sample_ct[:, 100, :], 1), cmap='bone')
plt.title('original CT')

plt.subplot(1, 6, 4)
# Mask first, then the windowed CT drawn half-transparent on top of it
plt.imshow(np.rot90(sample_mask[:, 100, :], 1), cmap='nipy_spectral')
plt.imshow(np.rot90(sample_ct_windowed[:, 100, :], 1), alpha=0.5, cmap="bone")
plt.title('original infection mask')

# --- slice 100 along the first axis, rotated for display ---
plt.subplot(1, 6, 5)
plt.imshow(np.rot90(sample_ct[100], 1), cmap='bone')
plt.title('original CT')

plt.subplot(1, 6, 6)
plt.imshow(np.rot90(sample_mask[100], 1), cmap='nipy_spectral')
plt.imshow(np.rot90(sample_ct_windowed[100], 1), alpha=0.5, cmap="bone")
plt.title('original infection mask')


# Not filtered, good!

# Now we have two choices: feed the network lung-masked CT images, or use the CT images as they are

# This time I chose to use the CT images directly

In [None]:
CT = []
Mask = []
img_size = 128
# NOTE: `max` shadows the builtin of the same name. The name is kept because
# the normalization cell further down reads it.
max = 0


# Append the resized slices of every case to two flat lists
for ct_path, mask_path in zip(data['ct_scan'], data['infection_mask']):
    ct = read_nii(ct_path)
    mask = read_nii(mask_path)

    # Track the largest intensity seen over all volumes
    ct_max = np.max(ct)
    if max < ct_max:
        max = ct_max

    # Resize each slice (last axis) to img_size x img_size and add a channel axis
    for slice_idx in range(ct.shape[2]):
        ct_img = cv2.resize(ct[..., slice_idx], dsize=(img_size, img_size),
                            interpolation=cv2.INTER_AREA).astype('float64')

        # Label maps must not be averaged: INTER_AREA blends neighbouring
        # labels and the uint8 cast then truncates partially covered pixels
        # to 0, eroding the mask. Nearest-neighbour keeps original labels.
        mask_img = cv2.resize(mask[..., slice_idx], dsize=(img_size, img_size),
                              interpolation=cv2.INTER_NEAREST).astype('uint8')

        CT.append(ct_img[..., np.newaxis])
        Mask.append(mask_img[..., np.newaxis])
        

In [None]:
# Stack the per-slice lists into single arrays
CT = np.array(CT)

Mask = np.array(Mask)
# Distinct label values that survived the resize
print (np.unique(Mask))

In [None]:
# One entry per slice, each img_size x img_size with a trailing channel axis
print (CT.shape)

# Show data image

In [None]:
fig = plt.figure(figsize = (18,15))

# Left: resized CT slice number 100
plt.subplot(1,2,1)
plt.imshow(CT[100][...,0], cmap = 'bone')
plt.title('original CT')

# Right: the same slice with its mask drawn half-transparent on top
plt.subplot(1,2,2)
plt.imshow(CT[100][...,0], cmap = 'bone')
plt.imshow(Mask[100][...,0],alpha = 0.5, cmap = "nipy_spectral")
plt.title('original infection mask')


# --------------------------------------------------

# Normalizing the pixel values to the range [0, 1] before training is a good idea

In [None]:
# Robust intensity window: the 0.5th and 99.5th percentiles of all voxels.
# The earlier bounds (0.5*max and 99.5*max) scaled the global maximum rather
# than taking percentiles, so the result never covered the [0, 1] range.
mins = np.percentile(CT, 0.5)
maxs = np.percentile(CT, 99.5)
# Clip outliers to the window, then rescale linearly to [0, 1]
norm_data = (np.clip(CT, mins, maxs) - mins) / (maxs - mins)

In [None]:
plt.figure(figsize = (9,9))

# Visual sanity check of one normalized slice
plt.imshow(norm_data[100][...,0], cmap = 'bone')

In [None]:
# Distinct values after and before normalization (sorts every voxel of both arrays)
print (np.unique(norm_data), np.unique(CT))

# Split into training and validation groups

In [None]:
from sklearn.model_selection import train_test_split

# A fixed seed makes the split reproducible, so the validation metrics and
# the fixed test index plotted at the end refer to the same samples each run.
# NOTE(review): the split is per slice, so slices of one case can land in
# both sets; a per-case split would give a stricter validation estimate.
CT_train, CT_test, Mask_train, Mask_test = train_test_split(
    norm_data, Mask, test_size=0.1, random_state=42)

# Build Attention Unet
Here we use a slight deviation on the U-Net standard

In [None]:
class attention_unet():
    """Builder for a two-stage attention U-Net with a single-channel input.

    ``build_unet`` chains two structurally identical encoder/decoder stages.
    The first runs on the 2x max-pooled input and ends in a one-channel
    sigmoid map; that map is upsampled by 2 and is the only input of the
    second stage, whose one-channel sigmoid map is the model output.

    ``layers`` and ``Model`` are module-level names looked up when
    ``build_unet`` is called, so they must be imported before that call.

    NOTE(review): each stage pools four times and the first stage works at
    half resolution, so input sides that are multiples of 32 look necessary
    for the skip connections and the output size to line up - confirm before
    using other sizes.
    """

    def __init__(self, img_rows=128, img_cols=128):
        """Store the input size and the base filter counts.

        ``df`` is the base filter count of the encoder, ``uf`` of the decoder.
        """
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.img_shape = (self.img_rows, self.img_cols, 1)
        self.df = 64
        self.uf = 64

    def build_unet(self):
        """Assemble and return the (uncompiled) Keras model."""

        def conv2d(layer_input, filters, dropout_rate=0, bn=False):
            """Two 3x3 same-padded conv + ReLU layers, optional BN and dropout."""
            d = layer_input
            for _ in range(2):
                d = layers.Conv2D(filters, kernel_size=(3, 3), strides=(1, 1),
                                  padding='same')(d)
                if bn:
                    d = layers.BatchNormalization()(d)
                d = layers.Activation('relu')(d)

            if dropout_rate:
                d = layers.Dropout(dropout_rate)(d)

            return d

        def deconv2d(layer_input, filters, bn=False):
            """2x upsampling followed by a 3x3 conv, optional BN, and ReLU."""
            u = layers.UpSampling2D((2, 2))(layer_input)
            u = layers.Conv2D(filters, kernel_size=(3, 3), strides=(1, 1),
                              padding='same')(u)
            if bn:
                u = layers.BatchNormalization()(u)
            u = layers.Activation('relu')(u)

            return u

        def attention_block(F_g, F_l, F_int, bn=False):
            """Scale the skip features ``F_l`` by a sigmoid map derived from
            ``F_g`` and ``F_l`` (1x1 convs, sum, ReLU, 1x1 conv, sigmoid)."""
            g = layers.Conv2D(F_int, kernel_size=(1, 1), strides=(1, 1),
                              padding='valid')(F_g)
            if bn:
                g = layers.BatchNormalization()(g)
            x = layers.Conv2D(F_int, kernel_size=(1, 1), strides=(1, 1),
                              padding='valid')(F_l)
            if bn:
                x = layers.BatchNormalization()(x)

            psi = layers.Add()([g, x])
            psi = layers.Activation('relu')(psi)

            psi = layers.Conv2D(1, kernel_size=(1, 1), strides=(1, 1),
                                padding='valid')(psi)
            if bn:
                psi = layers.BatchNormalization()(psi)
            psi = layers.Activation('sigmoid')(psi)

            return layers.Multiply()([F_l, psi])

        def unet_stage(stage_input):
            """One attention U-Net: 4 pooled encoder levels, a bottleneck, 4
            attention-gated decoder levels, and a 1-channel sigmoid head."""
            # (filters, dropout rate, batch norm) per encoder level;
            # the first level deliberately has no batch normalization.
            encoder_levels = (
                (self.df, 0, False),
                (self.df * 2, 0, True),
                (self.df * 4, 0, True),
                (self.df * 8, 0.5, True),
            )
            skips = []
            x = stage_input
            for filters, rate, use_bn in encoder_levels:
                x = conv2d(x, filters, dropout_rate=rate, bn=use_bn)
                skips.append(x)
                x = layers.MaxPooling2D((2, 2))(x)

            # Bottleneck
            x = conv2d(x, self.df * 16, dropout_rate=0.5, bn=True)

            # Decoder: deepest skip connection first
            decoder_filters = (self.uf * 8, self.uf * 4, self.uf * 2, self.uf)
            for skip, filters in zip(reversed(skips), decoder_filters):
                up = deconv2d(x, filters, bn=True)
                gated = attention_block(up, skip, filters, bn=True)
                merged = layers.Concatenate()([up, gated])
                x = conv2d(merged, filters)

            return layers.Conv2D(1, kernel_size=(1, 1), strides=(1, 1),
                                 activation='sigmoid')(x)

        inputs = layers.Input(shape=self.img_shape)

        # Stage 1 works on the 2x downsampled input
        coarse_map = unet_stage(layers.MaxPooling2D((2, 2))(inputs))

        # Stage 2 receives only the upsampled stage-1 map
        outputs2 = unet_stage(deconv2d(coarse_map, self.uf, bn=True))

        model = Model(inputs=inputs, outputs=outputs2)

        return model



# Define batch-normalization and residual helper blocks

In [None]:
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
import keras.layers as layers
from keras.models import Model



# batchnormalization
def BatchActivate(x):
    """Apply batch normalization followed by a ReLU activation to ``x``."""
    normalized = BatchNormalization()(x)
    return Activation('relu')(normalized)
# block
def convolution_block(x, filters, size, strides=(1,1), padding='same', activation=True):
    """Apply a Conv2D layer, optionally followed by batch norm + ReLU.

    Args:
        x: input tensor.
        filters: number of convolution filters.
        size: kernel size passed to ``Conv2D``.
        strides: convolution strides.
        padding: padding mode passed to ``Conv2D``.
        activation: when truthy, ``BatchActivate`` is applied after the
            convolution.

    Returns:
        The output tensor.
    """
    x = Conv2D(filters, size, strides=strides, padding=padding)(x)
    # Plain truthiness test instead of comparing against True
    if activation:
        x = BatchActivate(x)
    return x
# residual_block
def residual_block(blockInput, num_filters=16, batch_activate = False):
    """Pre-activated residual unit: BN+ReLU, two 3x3 conv blocks, then a skip sum.

    The second convolution has no activation. When ``batch_activate`` is
    set, BN+ReLU is applied once more after the addition.
    NOTE(review): the addition presumably requires ``blockInput`` to already
    have ``num_filters`` channels - confirm at the call sites.
    """
    branch = BatchActivate(blockInput)
    branch = convolution_block(branch, num_filters, (3,3))
    branch = convolution_block(branch, num_filters, (3,3), activation=False)
    summed = Add()([branch, blockInput])
    if not batch_activate:
        return summed
    return BatchActivate(summed)


# Loss functions

In [None]:
from keras.losses import binary_crossentropy
from keras import backend as K

def dice_loss(y_true, y_pred):
    """Return 1 minus the smoothed Dice score over the flattened batch."""
    smooth = 1.
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return 1. - (2. * overlap + smooth) / (total + smooth)

def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus the Dice loss."""
    bce_term = binary_crossentropy(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return bce_term + dice_term

def bce_logdice_loss(y_true, y_pred):
    """Binary cross-entropy minus the log of the Dice score (1 - dice_loss)."""
    bce_term = binary_crossentropy(y_true, y_pred)
    dice_score = 1. - dice_loss(y_true, y_pred)
    return bce_term - K.log(dice_score)

def weighted_bce_loss(y_true, y_pred, weight):
    """Per-element weighted binary cross-entropy, normalized by the weight sum.

    Probabilities are clipped away from 0 and 1, converted back to logits,
    and the cross-entropy is evaluated in its logit form.
    """
    epsilon = 1e-7
    clipped = K.clip(y_pred, epsilon, 1. - epsilon)
    logits = K.log(clipped / (1. - clipped))
    # Cross-entropy written in terms of logits x and targets y:
    #   x * (1 - y) + log(1 + exp(-|x|)) + max(-x, 0)
    per_element = (logits * (1. - y_true) +
                   K.log(1. + K.exp(-K.abs(logits))) +
                   K.maximum(-logits, 0.))
    weighted = weight * per_element
    return K.sum(weighted) / K.sum(weight)

def weighted_dice_loss(y_true, y_pred, weight):
    """Return 1 minus the smoothed Dice score with per-element weights."""
    smooth = 1.
    weighted_overlap = K.sum(weight * (y_true * y_pred))
    weighted_total = K.sum(weight * y_true) + K.sum(weight * y_pred)
    score = (2. * weighted_overlap + smooth) / (weighted_total + smooth)
    # The sum over the score is kept exactly as in the earlier formulation
    return 1. - K.sum(score)

def weighted_bce_dice_loss(y_true, y_pred):
    """Weighted binary cross-entropy plus the (unweighted) Dice loss.

    Weights depend on the average of ``y_true`` in a 50x50 window around
    each pixel: they are largest where that average is 0.5 and decay
    exponentially as it moves towards 0 or 1.
    """
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    # Windowed average of the mask, unit stride with 'same' padding
    local_mean = K.pool2d(
        y_true, pool_size=(50, 50), strides=(1, 1), padding='same', pool_mode='avg')
    raw_weight = 5. * K.exp(-5. * K.abs(local_mean - 0.5))
    # Rescale so the weights sum to the element count (sum of a ones tensor)
    element_count = K.sum(K.ones_like(local_mean))
    weight = raw_weight * (element_count / K.sum(raw_weight))
    bce_term = weighted_bce_loss(y_true, y_pred, weight)
    return bce_term + dice_loss(y_true, y_pred)

In [None]:
from keras.optimizers import Adam

def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice score per sample (summed over axes 1-3), averaged over axis 0."""
    reduce_axes = [1, 2, 3]
    overlap = K.sum(y_true * y_pred, axis=reduce_axes)
    total = K.sum(y_true, axis=reduce_axes) + K.sum(y_pred, axis=reduce_axes)
    per_sample = (2. * overlap + smooth) / (total + smooth)
    return K.mean(per_sample, axis=0)

# Implementation of "Boundary loss for highly unbalanced segmentation"

In [None]:
from scipy.ndimage import distance_transform_edt as distance


def calc_dist_map(seg):
    """Return the signed Euclidean distance map of a binary segmentation.

    Background elements hold their distance to the nearest foreground
    element (positive). Foreground elements hold minus (distance to the
    nearest background element, less one), so the outermost foreground
    layer is 0. If ``seg`` has no foreground, zeros are returned.

    Args:
        seg: array whose non-zero entries mark the foreground.

    Returns:
        Array with the same shape as ``seg``.
    """
    res = np.zeros_like(seg)
    # Builtin bool: the `np.bool` alias was removed in NumPy 1.24
    posmask = seg.astype(bool)

    if posmask.any():
        negmask = ~posmask
        res = distance(negmask) * negmask - (distance(posmask) - 1) * posmask

    return res


def calc_dist_map_batch(y_true):
    """Compute ``calc_dist_map`` for each item of ``y_true`` and stack the
    results into one float32 array."""
    dist_maps = [calc_dist_map(sample) for sample in y_true]
    return np.array(dist_maps).astype(np.float32)


def surface_loss_keras(y_true, y_pred):
    """Boundary (surface) loss: mean of the prediction times the signed
    distance map of the ground truth.

    The distance maps are computed in NumPy by ``calc_dist_map_batch`` and
    wrapped as a TensorFlow op.
    """
    # `tf.py_func` only exists in TensorFlow 1.x. `tf.numpy_function` is its
    # successor and also hands NumPy arrays to the wrapped function; fall
    # back to `tf.py_func` on releases that predate it.
    numpy_call = getattr(tf, 'numpy_function', None) or tf.py_func
    y_true_dist_map = numpy_call(func=calc_dist_map_batch,
                                 inp=[y_true],
                                 Tout=tf.float32)
    weighted = y_pred * y_true_dist_map
    return K.mean(weighted)



In [None]:
from keras.callbacks import ModelCheckpoint, Callback


class AlphaScheduler(Callback):
    """Callback that updates a backend variable at the end of every epoch.

    Args:
        alpha: backend variable holding the current value.
        update_fn: callable mapping the current value to the next one.
    """

    def __init__(self, alpha, update_fn):
        # Was named `init`, so it never ran as the constructor and the
        # attributes read in on_epoch_end were never set.
        super(AlphaScheduler, self).__init__()
        self.alpha = alpha
        self.update_fn = update_fn

    def on_epoch_end(self, epoch, logs=None):
        """Compute the next value and write it back into the variable."""
        updated_alpha = self.update_fn(K.get_value(self.alpha))
        # Without this write-back the computed value was silently discarded
        K.set_value(self.alpha, updated_alpha)

# Mixing weight handed to gl_sl_wrapper at compile time; starts at 1
alpha = K.variable(1, dtype='float32')

def update_alpha(value):
    """Lower ``value`` by 0.01 and clamp the result to the range [0.01, 1]."""
    step, lower, upper = 0.01, 0.01, 1
    return np.clip(value - step, lower, upper)


## Define Loss function
 
# We should consider both the boundary loss and the weighted binary cross-entropy + Dice loss

In [None]:
def gl_sl_wrapper(alpha):
    """Return a loss that blends the region loss and the surface loss.

    The returned function computes
    ``alpha * weighted_bce_dice_loss + (1 - alpha) * surface_loss_keras``.
    """
    def gl_sl(y_true, y_pred):
        region_term = alpha * weighted_bce_dice_loss(y_true, y_pred)
        boundary_term = (1 - alpha) * surface_loss_keras(y_true, y_pred)
        return region_term + boundary_term
    return gl_sl

## Set Training Check Point

In [None]:
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau
weight_path="{}_weights.best.hdf5".format('model')

# Keep only the weights of the epoch with the best validation Dice score
checkpoint = ModelCheckpoint(weight_path, monitor='val_dice_coef', verbose=1,
                             save_best_only=True, mode='max', save_weights_only=True)

# Halve the learning rate when the validation Dice score stalls.
# `min_delta` replaces the deprecated `epsilon` keyword (same threshold).
reduceLROnPlat = ReduceLROnPlateau(monitor='val_dice_coef', factor=0.5,
                                   patience=3,
                                   verbose=1, mode='max', min_delta=0.0001, cooldown=2, min_lr=1e-6)
early = EarlyStopping(monitor="val_dice_coef",
                      mode="max",
                      patience=15) # probably needs to be more patient, but kaggle time is limited
callbacks_list = [checkpoint, early, reduceLROnPlat]

## Compile model

In [None]:
# Build the network with its default 128x128 single-channel input
Net=attention_unet()
unet=Net.build_unet()

# Loss: alpha-weighted mix of the region loss and the boundary loss
# (see gl_sl_wrapper); Dice score and binary accuracy are tracked as metrics.
unet.compile(loss=gl_sl_wrapper(alpha),
             optimizer=Adam(1e-4),
             metrics=[dice_coef, 'binary_accuracy'])

unet.summary()

## Start Training

In [None]:
EPOCHS = 100
BS = 16

from keras.preprocessing.image import ImageDataGenerator

# Augmentation settings shared by the image and the mask generator
aug_args = dict(rotation_range=20, zoom_range=0.15,
    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
    horizontal_flip=True, fill_mode="nearest")

# construct the training image generator for data augmentation
aug = ImageDataGenerator(**aug_args)
# `aug.flow(x, y)` transforms only x, so feeding the masks as y leaves them
# unaligned with the warped images. A second generator with the same
# settings and the same seed applies identical transforms to the masks.
mask_aug = ImageDataGenerator(**aug_args)
AUG_SEED = 42


def paired_augmented_batches(images, masks, batch_size, seed):
    """Yield (image_batch, mask_batch) pairs that received the same random
    transforms, endlessly."""
    image_flow = aug.flow(images, batch_size=batch_size, seed=seed)
    mask_flow = mask_aug.flow(masks, batch_size=batch_size, seed=seed)
    while True:
        image_batch = next(image_flow)
        mask_batch = next(mask_flow)
        # Warping interpolates the masks; threshold them back to {0, 1}
        # (assumes binary masks, matching the one-channel sigmoid output)
        yield image_batch, (mask_batch > 0.5).astype('float32')


# train the network
# `callbacks_list` also enables early stopping and learning-rate reduction,
# which were configured but never passed in. The result is bound to both
# `H` and `history` because later cells refer to the training history.
history = H = unet.fit_generator(
    paired_augmented_batches(CT_train, Mask_train, BS, AUG_SEED),
    validation_data=(CT_test, Mask_test), steps_per_epoch=len(CT_train) // BS,
    epochs=EPOCHS, verbose=1, shuffle=True, callbacks=callbacks_list)

In [None]:

# Restore the checkpointed best weights, then save the full model
unet.load_weights(weight_path)
unet.save('model.h5')


## Plot loss history

In [None]:
# `H` is the History object returned by the training cell; the name
# `history` used here before was never assigned.
plt.plot(H.history['loss'])
plt.plot(H.history['val_loss'])
plt.title('Loss vs Epochs')
plt.xlabel('Epochs')
plt.ylabel('Dice loss')
plt.legend(['Train', 'Val'], loc = 'upper left')
plt.show()


# Run the test data

In [None]:
# Predict masks for the whole held-out set, then inspect sample 180
predicted = unet.predict(CT_test)
fig = plt.figure(figsize = (18,15))

# Left: CT slice alone
plt.subplot(1,3,1)
plt.imshow(CT_test[180][...,0], cmap = 'bone')
plt.title('original CT image')

# Middle: CT slice with the ground-truth mask drawn on top
plt.subplot(1,3,2)
plt.imshow(CT_test[180][...,0], cmap = 'bone')
plt.imshow(Mask_test[180][...,0],alpha = 0.5, cmap = "nipy_spectral")
plt.title('original infection mask')

# Right: CT slice with the predicted (sigmoid) map drawn on top
plt.subplot(1,3,3)
plt.imshow(CT_test[180][...,0], cmap = 'bone')
plt.imshow(predicted[180][...,0],alpha = 0.5,cmap = "nipy_spectral")
plt.title('predicted infection mask')