In [7]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Activation, ReLU
from tensorflow.keras.layers import BatchNormalization, Conv2DTranspose, Concatenate
from tensorflow.keras.models import Model, Sequential

To implement U-Net, we use three reusable building blocks, as shown in the architecture diagram: the convolution-operation block, the encoder block, and the decoder block. Each is written once as a function and then called repeatedly with different filter counts.

In [16]:
def convolution_operations(input_data, filters=64):
    """Apply two stacked (Conv2D 3x3 -> BatchNormalization -> ReLU) stages.

    Args:
        input_data: Keras tensor to transform.
        filters: number of output channels of each Conv2D layer.

    Returns:
        The Keras tensor produced by the second ReLU. Both convolutions use
        ``padding="same"`` (and Conv2D's default stride of 1), so the output
        keeps the input's height and width; only the channel count changes.
    """
    x = input_data
    for _ in range(2):
        # "same" padding zero-pads the borders so the 3x3 conv preserves H and W.
        x = Conv2D(filters, kernel_size=(3, 3), padding="same")(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)
    return x

In [9]:
def encoder(entered_input, filters=64):
    """One contracting (encoder) stage of the U-Net.

    Args:
        entered_input: Keras tensor coming from the previous stage (or the model input).
        filters: channel count passed to ``convolution_operations``.

    Returns:
        A ``(features, pooled)`` pair:
          * ``features`` is the full-resolution double-convolution output, kept
            so the matching decoder stage can use it as a skip connection;
          * ``pooled`` is ``features`` after max pooling with stride 2 (Keras'
            default 2x2 window), i.e. half the height and width, fed to the next stage.
    """
    features = convolution_operations(entered_input, filters)
    pooled = MaxPooling2D(strides=(2, 2))(features)
    return features, pooled

In [10]:
def decoder(entered_input, skip, filters=64):
    """One expanding (decoder) stage of the U-Net.

    Args:
        entered_input: Keras tensor from the deeper stage (lower resolution).
        skip: the encoder ``features`` tensor at the target resolution.
        filters: channels for the transposed conv and the following double conv.

    Returns:
        Keras tensor: ``entered_input`` upsampled 2x by a stride-2 transposed
        convolution, concatenated with ``skip`` along the last axis (Concatenate's
        default), then passed through ``convolution_operations``.
    """
    upsampled = Conv2DTranspose(filters, (2, 2), strides=2, padding="same")(entered_input)
    # Upsampled features come first, then the skip features, matching the original channel order.
    merged = Concatenate()([upsampled, skip])
    return convolution_operations(merged, filters)

In [14]:
def UNET(Image_size, base_filters=64, num_classes=1, depth=4):
    """Build a U-Net segmentation model with the Keras functional API.

    With the defaults this builds exactly the original network: four encoder
    stages with 64/128/256/512 filters, a 1024-filter bottleneck, four mirrored
    decoder stages, and a 1x1 sigmoid convolution producing one output channel.

    Args:
        Image_size: input shape without the batch axis, channels last,
            e.g. ``(512, 512, 1)``. Spatial entries may be ``None``.
        base_filters: filters in the first encoder stage; each deeper stage
            doubles it, and the bottleneck uses ``base_filters * 2**depth``.
        num_classes: number of output channels. ``1`` gives a sigmoid map
            (original behavior); larger values give a per-pixel softmax.
        depth: number of encoder/decoder stages (pooling steps).

    Returns:
        An uncompiled ``keras.Model`` whose output has the same height and
        width as the input and ``num_classes`` channels.

    Raises:
        ValueError: if ``base_filters``, ``num_classes`` or ``depth`` is < 1,
            or if a known spatial dimension is not divisible by ``2**depth``.
            Without that check, MaxPooling2D floors odd sizes, and the decoder's
            Concatenate then fails with an opaque shape-mismatch error.
    """
    if base_filters < 1 or num_classes < 1 or depth < 1:
        raise ValueError(
            f"base_filters, num_classes and depth must be >= 1, got "
            f"{base_filters}, {num_classes}, {depth}"
        )

    # Every pooling step halves H and W; the transposed convs double them back.
    # Skip connections only line up if each spatial size divides evenly.
    factor = 2 ** depth
    for dim in tuple(Image_size)[:-1]:  # last entry is channels (channels-last layout)
        if dim is not None and dim % factor != 0:
            raise ValueError(
                f"Spatial dimensions of Image_size={tuple(Image_size)} must be "
                f"divisible by 2**depth={factor}"
            )

    inputs = Input(Image_size)

    # Contracting path: remember each stage's full-resolution features for the skips.
    skips = []
    x = inputs
    for level in range(depth):
        skip, x = encoder(x, base_filters * 2 ** level)
        skips.append(skip)

    # Bottleneck at the lowest resolution.
    x = convolution_operations(x, base_filters * 2 ** depth)

    # Expanding path: pair each stage with its mirrored encoder skip, deepest first.
    for level in reversed(range(depth)):
        x = decoder(x, skips[level], base_filters * 2 ** level)

    activation = "sigmoid" if num_classes == 1 else "softmax"
    outputs = Conv2D(num_classes, 1, padding="same", activation=activation)(x)

    return Model(inputs, outputs)

In [18]:
# Single-channel 512x512 input, given without the batch axis
# (assumes Keras' default channels-last layout: height, width, channels).
input_shape = (512, 512, 1)
model = UNET(Image_size=input_shape)
model.summary()