<a href="https://colab.research.google.com/github/szh141/colab/blob/main/Simple_Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://medium.com/@vipul.sarode007/u-net-unleashed-a-step-by-step-guide-on-implementing-and-training-your-own-segmentation-model-in-a38741776968

Encoder

In [1]:
#Let's create a function for one step of the encoder block, so as to increase the reusability when making custom unets

def encoder_block(filters, inputs):
  """Run one encoder stage: two 3x3 ReLU convolutions followed by 2x2 max pooling.

  Args:
    filters: Number of output filters used by both convolutions.
    inputs: Tensor fed into the first convolution.

  Returns:
    A ``(skip, pooled)`` pair. ``skip`` is the output of the second
    convolution (kept for the skip connection in the decoder) and ``pooled``
    is that output after max pooling (the input to the next encoder block).
  """
  # Both convolutions share exactly the same configuration.
  conv_kwargs = dict(kernel_size = (3,3), padding = 'same', strides = 1, activation = 'relu')
  features = Conv2D(filters, **conv_kwargs)(inputs)
  skip = Conv2D(filters, **conv_kwargs)(features)
  pooled = MaxPooling2D(pool_size = (2,2), padding = 'same')(skip)
  return skip, pooled


Baseline or the bottom of the U-net

Same as the encoder block, but without the MaxPooling step.

In [2]:
#Baseline layer is just a bunch of Convolutional Layers to extract high level features from the downsampled Image
def baseline_layer(filters, inputs):
  """Apply the bottom of the U-Net: two stacked 3x3 ReLU convolutions, no pooling.

  Args:
    filters: Number of output filters used by both convolutions.
    inputs: Tensor coming from the last encoder block.

  Returns:
    The output tensor of the second convolution.
  """
  features = inputs
  # Two identical convolutions applied back to back.
  for _ in range(2):
    features = Conv2D(filters, kernel_size = (3,3), padding = 'same', strides = 1, activation = 'relu')(features)
  return features

Decoder

Concatenate the skip connection

The decoder block has one more input parameter than the encoder block: the skip connection coming from the encoder.

In [None]:
#Decoder Block
def decoder_block(filters, connections, inputs, kernel_size = (3,3)):
  """Run one decoder stage: upsample, merge the skip connection, convolve twice.

  The input is upsampled by a stride-2 transposed convolution with a 2x2
  kernel, concatenated with the encoder skip tensor along the channel axis,
  and refined by two ReLU convolutions.

  Args:
    filters: Number of output filters for the transposed convolution and
      for both refinement convolutions.
    connections: Skip-connection tensor from the matching encoder block.
      It must have the same spatial size as the upsampled input, otherwise
      the concatenation fails.
    inputs: Tensor coming from the previous (deeper) stage.
    kernel_size: Kernel size of the two refinement convolutions. Defaults
      to ``(3,3)``, matching the encoder and baseline convolutions; pass
      ``(2,2)`` to reproduce the layers this block used previously.

  Returns:
    The output tensor of the second refinement convolution.
  """
  # The 2x2 kernel with stride 2 doubles the spatial resolution.
  upsampled = Conv2DTranspose(filters, kernel_size = (2,2), padding = 'same', activation = 'relu', strides = 2)(inputs)
  merged = concatenate([upsampled, connections], axis = -1)
  # The refinement convolutions previously reused the 2x2 kernel of the
  # up-convolution; an even-sized kernel with 'same' padding pads
  # asymmetrically and was inconsistent with the 3x3 kernels used elsewhere.
  x = Conv2D(filters, kernel_size = kernel_size, padding = 'same', activation = 'relu')(merged)
  x = Conv2D(filters, kernel_size = kernel_size, padding = 'same', activation = 'relu')(x)
  return x

U-net model

In [None]:
def unet(input_shape = (224,224,1)):
  """Build the U-Net model for binary (per-pixel) segmentation.

  The network has four encoder blocks (64, 128, 256, 512 filters), a
  1024-filter baseline, four decoder blocks mirroring the encoder, and a
  1x1 sigmoid convolution producing a single output channel.

  Args:
    input_shape: ``(height, width, channels)`` of the input images.
      Defaults to ``(224, 224, 1)``. Height and width, when known, must be
      positive multiples of 16: the four pooling steps each halve the
      resolution (rounding up) while each decoder step exactly doubles it,
      so other sizes make the skip-connection concatenation fail.

  Returns:
    An uncompiled Keras ``Model`` named ``'Unet'``.

  Raises:
    ValueError: If ``input_shape`` does not have three entries, or a known
      height/width is not a positive multiple of 16.
  """
  # Validate up front so a bad size fails with a clear message instead of a
  # shape-mismatch error deep inside the decoder.
  input_shape = tuple(input_shape)
  if len(input_shape) != 3:
    raise ValueError(
        f"input_shape must be (height, width, channels), got {input_shape}")
  num_pooling_steps = 4  # one per encoder block below
  multiple = 2 ** num_pooling_steps
  for size in input_shape[:2]:
    if size is not None and (size <= 0 or size % multiple != 0):
      raise ValueError(
          f"height and width must be positive multiples of {multiple}, "
          f"got input_shape={input_shape}")

  #Defining the input layer and specifying the shape of the images
  inputs = Input(shape = input_shape)

  #defining the encoder
  s1, p1 = encoder_block(64, inputs = inputs)
  s2, p2 = encoder_block(128, inputs = p1)
  s3, p3 = encoder_block(256, inputs = p2)
  s4, p4 = encoder_block(512, inputs = p3)

  #Setting up the baseline
  baseline = baseline_layer(1024, p4)

  #Defining the entire decoder; each block consumes the skip tensor of the
  #encoder block at the same resolution
  d1 = decoder_block(512, s4, baseline)
  d2 = decoder_block(256, s3, d1)
  d3 = decoder_block(128, s2, d2)
  d4 = decoder_block(64, s1, d3)

  #Setting up the output function for binary classification of pixels
  outputs = Conv2D(1, 1, activation = 'sigmoid')(d4)

  #Finalizing the model
  model = Model(inputs = inputs, outputs = outputs, name = 'Unet')

  return model

Dice coefficient definition

In [None]:
# Setting dice coefficient to evaluate our model
def dice_coeff(y_true, y_pred, smooth = 1):
    """Smoothed Dice coefficient between a ground-truth and a predicted mask.

    Each sample is flattened and its overlap is computed over all of its
    pixels: ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    Summing over only the last axis, as was done before, yields a per-pixel
    ratio for single-channel masks rather than a mask-level overlap.

    Args:
        y_true: Ground-truth mask tensor; its first axis is the batch axis
            and it must hold as many values per sample as ``y_pred``. It is
            cast to the dtype of ``y_pred``.
        y_pred: Predicted mask tensor with the batch on the first axis.
        smooth: Additive smoothing term that keeps the ratio defined when
            both masks are empty.

    Returns:
        A tensor of shape ``(batch,)`` with one Dice value per sample. When
        used as a Keras metric these values are averaged.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    if not y_pred.dtype.is_floating:
        y_pred = tf.cast(y_pred, tf.float32)
    # Masks are often stored as integers/booleans; multiplying them with float
    # predictions requires a common dtype.
    y_true = tf.cast(y_true, y_pred.dtype)

    # Collapse every non-batch axis so the sums run over the whole mask.
    batch_size = tf.shape(y_pred)[0]
    true_flat = tf.reshape(y_true, (batch_size, -1))
    pred_flat = tf.reshape(y_pred, (batch_size, -1))

    intersection = tf.reduce_sum(true_flat * pred_flat, axis = -1)
    # Sum of both mask "areas" (the Dice denominator, not a set union).
    total = tf.reduce_sum(true_flat, axis = -1) + tf.reduce_sum(pred_flat, axis = -1)
    return (2.0 * intersection + smooth) / (total + smooth)

import libraries and training

In [None]:
# this is from the following post for categorizing images, not for U-net
#https://pub.aimind.so/understanding-the-learning-mechanism-of-convolutional-neural-networks-19a0568df252

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Input, MaxPooling2D, concatenate # Dense, Flatten, not need for U-net
from tensorflow.keras.models import Model
#from tensorflow.keras.datasets import mnist
#from tensorflow.keras.models import Sequential
#from tensorflow.keras.utils import to_categorical
#from sklearn.metrics import confusion_matrix

In [None]:
# Build the model under its own name. Rebinding `unet` to the model would
# shadow the builder function, so this cell would fail when run a second time.
unet_model = unet()
unet_model.compile(loss = 'binary_crossentropy',
                   optimizer = 'adam',
                   metrics = ['accuracy', dice_coeff])

#Defining early stopping to regularize the model and prevent overfitting
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 3, restore_best_weights = True)

# NOTE(review): train_data and val_data are not defined in the visible part of
# this notebook; presumably they are datasets/generators yielding
# (image, mask) batches -- confirm they are created before this cell runs.
#Training the model with 50 epochs (it will stop training in between because of early stopping)
# validation_data is passed directly; wrapping it in a one-element list is unnecessary.
unet_history = unet_model.fit(train_data, validation_data = val_data,
                              epochs = 50, callbacks = [early_stopping])