<a href="https://colab.research.google.com/github/ramzes2/UNet-Overview/blob/master/UNet_overview.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **U-Net: Convolutional Networks for Biomedical Image Segmentation**
Original article: [https://arxiv.org/abs/1505.04597](https://arxiv.org/abs/1505.04597)

The training data comes from the
[Konica-Minolta Pathological Image Segmentation Challenge](https://community.topcoder.com/longcontest/?module=ViewProblemStatement&rd=16950&pm=14622).


![alt text](https://cdn-images-1.medium.com/max/1600/1*SNvD04dEFIDwNAqSXLQC_g.jpeg)

## Data downloading and extraction

In [0]:
# -nc (no-clobber): skip the download when Pathological.zip already exists, so
# re-running this cell does not pile up Pathological.zip.1, .2, ... copies.
!wget -nc https://www.dropbox.com/s/w0zajni9ny1w8nd/Pathological.zip

In [0]:
# Recreate the target directory from scratch so a re-run never mixes stale and
# freshly extracted files.
!rm -rf data
!mkdir data
# -q: suppress the per-file log, which would otherwise fill the cell output
# with one line for every extracted image and mask.
!unzip -q Pathological.zip -d data/

In [0]:
# Long listing of the directory the archive was extracted into.
!ls -l data

In [0]:
# Print the contents of the two sub-directories read by the loading code
# further down: data/images (inputs) and data/truth (masks).
!echo "Images:"
!ls data/images
!echo "Masks:"
!ls data/truth



---


## Data overview

In [0]:
import os

TRAIN_IMAGES = 'data/images'
TRAIN_MASKS = 'data/truth'

# Every '.tif' file in TRAIN_IMAGES (visited in sorted name order) is paired
# with a mask path built from the same base name: '<base>_mask.png' inside
# TRAIN_MASKS. The 4-character '.tif' suffix is stripped to get the base name.
tif_names = [name for name in sorted(os.listdir(TRAIN_IMAGES)) if name.endswith('.tif')]

file_names = [
  (f'{TRAIN_IMAGES}/{name}', f'{TRAIN_MASKS}/{name[:-4]}_mask.png')
  for name in tif_names
]

print('Total images:', len(file_names))

In [0]:
from tqdm import tqdm
import cv2

ORIG_IMG_WIDTH = 500
ORIG_IMG_HEIGHT = 500
# Padding added on each side: (512 - 500)//2 = 6 pixels, i.e. 500 -> 512.
BORDER = (512 - ORIG_IMG_WIDTH)//2

Images = []
Masks = []

for img_file, mask_file in tqdm(file_names):
  image = cv2.imread(img_file)
  mask = cv2.imread(mask_file)

  # cv2.imread signals a missing/unreadable file by returning None instead of
  # raising; fail here with the offending path rather than with an obscure
  # error inside cvtColor/copyMakeBorder.
  if image is None:
    raise FileNotFoundError(f'Could not read image file: {img_file}')
  if mask is None:
    raise FileNotFoundError(f'Could not read mask file: {mask_file}')

  # OpenCV loads colour images as BGR; convert to RGB for display/training.
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

  # Pad both image and mask by reflecting the content at the borders.
  image = cv2.copyMakeBorder(image, BORDER, BORDER, BORDER, BORDER, cv2.BORDER_REFLECT)
  # A pixel is foreground where channel 1 of the padded mask equals 255.
  mask = cv2.copyMakeBorder(mask, BORDER, BORDER, BORDER, BORDER, cv2.BORDER_REFLECT)[:, :, 1] == 255
  # Add a trailing channel axis: (H, W) -> (H, W, 1).
  mask = mask.reshape(mask.shape[0], mask.shape[1], -1)

  Images.append(image)
  Masks.append(mask)

In [0]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams["axes.grid"] = False

f, ax = plt.subplots(4, 4, figsize=(16, 16))

# Each grid row shows two (image, mask) pairs side by side, so the 4x4 grid
# displays the first 8 samples.
img_idx = 0
for row, col in ((r, c) for r in range(4) for c in range(2)):
  image_ax = ax[row, 2*col]
  mask_ax = ax[row, 2*col + 1]
  image_ax.imshow(Images[img_idx])
  mask_ax.imshow(Masks[img_idx].squeeze(), cmap='gray')
  img_idx += 1


In [0]:
from sklearn.model_selection import train_test_split
import numpy as np

# Stack the per-image arrays; images are converted to float32 and scaled by 1/255.
Images = np.asarray(Images, dtype=np.float32) / 255
Masks = np.array(Masks)

# First hold out 30% of the samples, then cut that hold-out part again
# (test_size=0.33) into validation and test sets. Both splits use a fixed seed.
trainX, holdoutX, trainY, holdoutY = train_test_split(
  Images, Masks, test_size=0.3, random_state=42)
valX, testX, valY, testY = train_test_split(
  holdoutX, holdoutY, test_size=0.33, random_state=42)

for caption, subset in (('Train images:', trainX),
                        ('Validation images:', valX),
                        ('Test images:', testX)):
  print(caption, len(subset))



---


## U-Net architecture
![U-Net architecture diagram](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

In [0]:
from keras import models, layers, activations

def getUnetModel(input_shape=(512, 512, 3), filters_cnt=16):
  """Build an uncompiled U-Net that outputs a one-channel sigmoid map.

  Encoder: four stages of (conv block -> 2x max pooling); the number of filters
  doubles at every stage, starting from ``filters_cnt``. A bottleneck conv
  block with ``filters_cnt*16`` filters follows. Decoder: four stages of
  (2x upsampling -> concatenation with the encoder output of the same
  resolution -> conv block), halving the filters back down to ``filters_cnt``.
  A final 1x1 convolution with a sigmoid activation yields the output.

  Args:
    input_shape: ``(height, width, channels)`` of the input tensor. Height and
      width must be divisible by 16 (four 2x poolings); otherwise the
      upsampled tensors no longer match the skip connections in size.
    filters_cnt: number of filters in the first encoder block.

  Returns:
    A ``keras.models.Model`` mapping the input to an output with one channel.

  Raises:
    ValueError: if a known spatial size in ``input_shape`` is not a multiple
      of 16.
  """
  levels = 4  # number of pooling / upsampling stages

  for size in input_shape[:2]:
    if size is not None and size % (2**levels) != 0:
      raise ValueError(
          f'Spatial dimensions must be divisible by {2**levels}, got {tuple(input_shape)}')

  def unetBlock(x, filters, skip_conn=None):
    """Two 3x3 'same' convolutions with ReLU, optionally preceded by a skip concatenation."""
    if skip_conn is not None:
      x = layers.Concatenate() ([skip_conn, x])

    for _ in range(2):
      x = layers.Conv2D(filters=filters, kernel_size=3, padding='same') (x)
      x = layers.Activation('relu')(x)
    return x

  inp = layers.Input(shape=tuple(input_shape))

  # Contracting path; every block output is kept for the skip connections.
  x = inp
  encoder_outputs = []
  for level in range(levels):
    encoded = unetBlock(x, filters_cnt * 2**level)
    encoder_outputs.append(encoded)
    x = layers.MaxPooling2D(2) (encoded)

  # Bottleneck at the lowest resolution.
  x = unetBlock(x, filters_cnt * 2**levels)

  # Expanding path: deepest encoder output is consumed first.
  for level in reversed(range(levels)):
    x = layers.UpSampling2D(2) (x)
    x = unetBlock(x, filters_cnt * 2**level, encoder_outputs[level])

  output = layers.Conv2D(filters=1, kernel_size=1, padding='same', activation='sigmoid') (x)

  model = models.Model(inputs=inp, outputs=output)

  return model

In [0]:
unet = getUnetModel()
# summary() prints the layer table itself and returns None, so wrapping it in
# print() only appends a stray "None" line to the output.
unet.summary()

In [0]:
from keras import optimizers, losses

# The learning rate (5e-4) is passed positionally on purpose: the keyword was
# renamed from `lr` to `learning_rate` between Keras releases, while the first
# positional argument of Adam is the learning rate under either name.
unet.compile(optimizer=optimizers.Adam(5e-4),
              loss='binary_crossentropy',
              metrics=['accuracy'])

In [0]:
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

# Training hyper-parameters shared by the fit calls below.
epochs = 40
batch_size = 16

# Callbacks reused for every training run: stop after 25 epochs without
# improvement, lower the learning rate after 15 epochs without improvement.
early_stopping = EarlyStopping(patience=25, verbose=True)
lr_on_plateau = ReduceLROnPlateau(patience=15, verbose=True)
base_callbacks = [early_stopping, lr_on_plateau]

In [0]:
# validation_data is passed as an (x, y) tuple, the form Keras documents; a
# list can be interpreted as several model inputs by some Keras versions.
# ModelCheckpoint keeps only the weights of the best epoch in 'model_unet.h5'.
unet.fit(x=trainX, y=trainY, batch_size=batch_size, validation_data=(valX, valY), epochs=epochs,
        callbacks=[ModelCheckpoint('model_unet.h5', save_best_only=True, save_weights_only=True, verbose=True)] + base_callbacks)

In [0]:
def visualizePrediction(X, y_true, y_pred, samples, thr=0.5):
  """Plot inputs, ground truth, predictions and an error overlay.

  One row is drawn for each of the first ``samples`` items, with four panels:
  input image, ground-truth mask, raw prediction, and an RGB overlay where
  green marks pixels set in both masks, red marks pixels set only in the
  ground truth and blue marks pixels set only in the thresholded prediction.

  Args:
    X: input images, indexable by sample; displayed with value range [0, 1].
    y_true: ground-truth masks, indexable by sample.
    y_pred: predicted masks, indexable by sample.
    samples: number of leading samples to draw (one row each).
    thr: a predicted pixel counts as foreground when its value is >= ``thr``.
  """
  # figsize is (width, height): four panels across, one row per sample.
  # squeeze=False keeps `ax` two-dimensional even when samples == 1.
  f, ax = plt.subplots(samples, 4, figsize=(16, 4*samples), squeeze=False)

  for i in range(samples):
    # Built-in bool instead of the np.bool alias, which recent NumPy removed.
    gt = np.squeeze(y_true[i]).astype(bool)
    mask = np.squeeze(y_pred[i]) >= thr

    # Size the overlay from the mask itself instead of assuming 512x512.
    vis = np.zeros(gt.shape + (3,), dtype=np.float32)
    vis[:, :, 1] = np.logical_and(mask, gt)
    vis[:, :, 0] = np.logical_and(gt, np.logical_not(mask))
    vis[:, :, 2] = np.logical_and(np.logical_not(gt), mask)

    ax[i, 0].imshow(X[i], vmin=0, vmax=1)
    ax[i, 1].imshow(np.squeeze(y_true[i]), cmap='gray', vmin=0, vmax=1)
    ax[i, 2].imshow(np.squeeze(y_pred[i]), cmap='gray', vmin=0, vmax=1)
    ax[i, 3].imshow(vis)

In [0]:
# Load the weights stored in 'model_unet.h5' and predict the validation set.
unet.load_weights('model_unet.h5')
p = unet.predict(valX, verbose=True)

# Show the first four validation samples.
visualizePrediction(X=valX, y_true=valY, y_pred=p, samples=4)

In [0]:
from sklearn.metrics import accuracy_score

def calc_accuracy(y_true, y_pred):
  """Pixel accuracy measured on the un-padded image area.

  ``BORDER`` pixels are cropped from every side of axes 1 and 2 of both
  arrays, predictions are rounded to 0/1, and the fraction of matching pixels
  is returned.

  Args:
    y_true: ground-truth masks with the sample axis first.
    y_pred: predicted masks laid out like ``y_true``.

  Returns:
    The accuracy computed by ``sklearn.metrics.accuracy_score`` over all
    remaining pixels.
  """
  def crop(batch):
    # Explicit end indices: a plain BORDER:-BORDER slice would select nothing
    # when BORDER == 0.
    return batch[:, BORDER:batch.shape[1] - BORDER, BORDER:batch.shape[2] - BORDER]

  y_true = crop(y_true)
  y_pred = np.round(crop(y_pred))

  return accuracy_score(y_true.astype(np.uint32).flatten(), y_pred.astype(np.uint32).flatten())

In [0]:
# Accuracy of the BCE-trained model on the train, validation and test splits
# (kept as a tuple in that order).
unet_acc = tuple(
  calc_accuracy(targets, unet.predict(inputs))
  for inputs, targets in ((trainX, trainY), (valX, valY), (testX, testY))
)

for caption, value in zip(('Train accuracy:', 'Validation accuracy:', 'Test accuracy:'), unet_acc):
  print(caption, value)

---
## How can we improve the accuracy?

1.   Augmentations
> * [albumentations](https://github.com/albu/albumentations)
> * [imgaug](https://github.com/aleju/imgaug)
2.   Batch normalization
3.   Dropout
4.   Custom loss function
5.   Test Time Augmentation




## Custom loss function

In [0]:
from keras import backend as K
import numpy as np

def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Both tensors are flattened; the result is
    ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
    ``smooth`` is added to numerator and denominator alike.
    """
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)

    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth

    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Negated Dice coefficient, so that minimising the loss raises the Dice score."""
    score = dice_coef(y_true, y_pred)
    return -score
  
def bce_dice_loss(y_true, y_pred):
  """Sum of Keras binary cross-entropy and the (negated) Dice loss."""
  bce_term = losses.binary_crossentropy(y_true, y_pred)
  dice_term = dice_coef_loss(y_true, y_pred)
  return bce_term + dice_term

In [0]:
unet2 = getUnetModel()
# The learning rate is passed positionally: the keyword was renamed from `lr`
# to `learning_rate` between Keras releases, while the first positional
# argument of Adam is the learning rate under either name.
unet2.compile(optimizer=optimizers.Adam(5e-4),
              loss=bce_dice_loss,
              metrics=['accuracy', 'binary_crossentropy'])
# Start from the weights stored in 'model_unet.h5' rather than from a random
# initialisation.
unet2.load_weights('model_unet.h5')

# validation_data is passed as an (x, y) tuple, the form Keras documents.
unet2.fit(x=trainX, y=trainY, batch_size=batch_size, validation_data=(valX, valY), epochs=20,
        callbacks=[ModelCheckpoint('model_unet2.h5', save_best_only=True, save_weights_only=True, verbose=True)] + base_callbacks)

In [0]:
# Load the weights stored in 'model_unet2.h5' before evaluating.
unet2.load_weights('model_unet2.h5')

# Accuracy on the train, validation and test splits (tuple in that order).
unet2_acc = tuple(
  calc_accuracy(targets, unet2.predict(inputs))
  for inputs, targets in ((trainX, trainY), (valX, valY), (testX, testY))
)

for caption, value in zip(('Train accuracy:', 'Validation accuracy:', 'Test accuracy:'), unet2_acc):
  print(caption, value)



---


## Test Time Augmentation

In [0]:
def tta_transform(X, hor_flip, ver_flip):
  """Flip every sample of ``X`` for test-time augmentation.

  Args:
    X: sequence of samples, each with at least two dimensions.
    hor_flip: if truthy, mirror each sample left-right (``np.fliplr``).
    ver_flip: if truthy, mirror each sample up-down (``np.flipud``).

  Returns:
    A new array stacking the (possibly flipped) samples.
  """
  def flip_sample(sample):
    flipped = np.fliplr(sample) if hor_flip else sample
    return np.flipud(flipped) if ver_flip else flipped

  return np.array([flip_sample(sample) for sample in X])

def tta_restore(X, hor_flip, ver_flip):
  """Undo the flips applied for test-time augmentation.

  The up-down flip is applied first and the left-right flip second, each only
  when its flag is truthy.

  Args:
    X: sequence of samples, each with at least two dimensions.
    hor_flip: whether a left-right flip has to be undone.
    ver_flip: whether an up-down flip has to be undone.

  Returns:
    A new array stacking the restored samples.
  """
  undo_steps = [
    flip for flip, applied in ((np.flipud, ver_flip), (np.fliplr, hor_flip)) if applied
  ]

  restored = []
  for sample in X:
    for undo in undo_steps:
      sample = undo(sample)
    restored.append(sample)

  return np.array(restored)

def predict_tta(model, X):
  """Average the model's predictions over the four flip combinations.

  For every combination of horizontal/vertical flip (including no flip) the
  inputs are transformed, predicted, and the prediction is flipped back before
  being accumulated.

  Args:
    model: object exposing ``predict``.
    X: array of inputs; its first three dimensions define the output size.

  Returns:
    float32 array of shape ``(X.shape[0], X.shape[1], X.shape[2], 1)`` holding
    the mean of the four restored predictions.
  """
  n_samples, height, width = X.shape[0], X.shape[1], X.shape[2]
  accumulated = np.zeros((n_samples, height, width, 1), dtype=np.float32)

  flip_combos = [(hor_flip, ver_flip) for hor_flip in range(2) for ver_flip in range(2)]
  for hor_flip, ver_flip in flip_combos:
    flipped_pred = model.predict(tta_transform(X, hor_flip, ver_flip))
    accumulated += tta_restore(flipped_pred, hor_flip, ver_flip)

  return accumulated/4

In [0]:
ACCURACY_CAPTIONS = ('Train accuracy:', 'Validation accuracy:', 'Test accuracy:')

# --- Model trained with binary cross-entropy ---
train_predict_tta, val_predict_tta, test_predict_tta = [
  predict_tta(unet, inputs) for inputs in (trainX, valX, testX)
]

print('Unet binary_crossentropy loss:')
unet_tta_acc = tuple(
  calc_accuracy(targets, prediction)
  for targets, prediction in zip((trainY, valY, testY),
                                 (train_predict_tta, val_predict_tta, test_predict_tta))
)

for caption, value in zip(ACCURACY_CAPTIONS, unet_tta_acc):
  print(caption, value)

# --- Model trained with BCE + Dice ---
train_predict_tta2, val_predict_tta2, test_predict_tta2 = [
  predict_tta(unet2, inputs) for inputs in (trainX, valX, testX)
]

print('')

print('Unet BCE+Dice loss:')
unet2_tta_acc = tuple(
  calc_accuracy(targets, prediction)
  for targets, prediction in zip((trainY, valY, testY),
                                 (train_predict_tta2, val_predict_tta2, test_predict_tta2))
)

for caption, value in zip(ACCURACY_CAPTIONS, unet2_tta_acc):
  print(caption, value)

## Final results

In [0]:
import pandas as pd

# One row per method, one column per split (train, validation, test).
method_names = ['U-Net BCE', 'U-Net BCE+Dice', 'U-Net BCE TTA', 'U-Net BCE+Dice TTA']
acc = np.array([unet_acc, unet2_acc, unet_tta_acc, unet2_tta_acc])

train_col, val_col, test_col = acc[:, 0], acc[:, 1], acc[:, 2]

scores = pd.DataFrame({
  'method:': method_names,
  'train acc': train_col,
  'val acc': val_col,
  'test acc': test_col,
})

scores


In [0]:
# First four validation samples with the TTA predictions of the BCE+Dice model.
visualizePrediction(X=valX, y_true=valY, y_pred=val_predict_tta2, samples=4)