In [25]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Activation, MaxPool2D, Conv2DTranspose, BatchNormalization, Concatenate, Input
from tensorflow.keras.models import Model

In [28]:
def conv_block(inputs, num_filters):
  """Apply two stacked Conv2D -> BatchNormalization -> ReLU stages.

  Each convolution uses a 3x3 kernel with "same" padding and the default
  stride of 1, so the spatial size of `inputs` is preserved.

  Args:
    inputs: Keras tensor to transform.
    num_filters: number of output channels of both convolutions.

  Returns:
    Keras tensor with `num_filters` channels.
  """
  features = inputs
  # Two identical stages; layers are created in the same order as before.
  for _ in range(2):
    features = Conv2D(num_filters, 3, padding="same")(features)
    features = BatchNormalization()(features)
    features = Activation("relu")(features)
  return features


def encoder_block(inputs, num_filters):
  """One U-Net encoder stage: a conv_block followed by 2x2 max pooling.

  Args:
    inputs: Keras tensor entering this stage.
    num_filters: channel count passed to conv_block.

  Returns:
    A pair (skip, pooled):
      skip   -- conv_block output at full resolution, kept for the decoder's
                skip connection;
      pooled -- `skip` max-pooled with a 2x2 window and stride 2, fed to the
                next (deeper) stage.
  """
  skip = conv_block(inputs, num_filters)
  pooled = MaxPool2D(pool_size=2, strides=2)(skip)
  return skip, pooled

def decoder_block(inputs, skip_features, num_filters):
  """One U-Net decoder stage: upsample, merge the skip connection, refine.

  A 2x2 transposed convolution with stride 2 and "same" padding doubles the
  spatial size of `inputs`. The result is concatenated (last axis, the Keras
  `Concatenate` default) with `skip_features`, then passed through conv_block.

  Args:
    inputs: Keras tensor coming from the deeper stage.
    skip_features: encoder tensor to concatenate with; its spatial size must
      equal the upsampled size of `inputs`.
    num_filters: channels of the transposed convolution and of conv_block.

  Returns:
    Keras tensor with `num_filters` channels.
  """
  upsampled = Conv2DTranspose(
      num_filters, kernel_size=(2, 2), strides=2, padding="same")(inputs)
  merged = Concatenate()([upsampled, skip_features])
  return conv_block(merged, num_filters)

def unet_build(input_shape, num_classes=1, activation="sigmoid"):
  """Build a U-Net segmentation model.

  The network has four encoder stages (64, 128, 256, 512 filters), a
  1024-filter bridge, and four decoder stages mirroring the encoder with skip
  connections. A final 1x1 convolution produces the per-pixel output.

  Args:
    input_shape: shape of one input sample, without the batch dimension,
      e.g. (512, 512, 3). Every known (non-None) spatial dimension must be
      divisible by 16. The encoder halves the spatial size four times, and the
      decoder must upsample back to the exact skip-connection sizes.
    num_classes: number of output channels of the final 1x1 convolution.
      Defaults to 1 (binary segmentation).
    activation: activation of the output layer. Defaults to "sigmoid" for
      binary segmentation; choose the activation that matches the task
      (e.g. "softmax" with num_classes > 1 for multi-class).

  Returns:
    An uncompiled tf.keras Model named "UNET".

  Raises:
    ValueError: if a known spatial dimension of `input_shape` is not divisible
      by 16.
  """
  # Fail early with a clear message instead of a Concatenate shape mismatch
  # deep inside the decoder. Pick the spatial axes according to the Keras
  # image data format (channels_last by default).
  shape = tuple(input_shape)
  if tf.keras.backend.image_data_format() == "channels_first":
    spatial_dims = shape[1:]
  else:
    spatial_dims = shape[:-1]
  for dim in spatial_dims:
    if dim is not None and dim % 16 != 0:
      raise ValueError(
          "Spatial dimensions of input_shape must be divisible by 16 "
          f"(4 pooling levels), got input_shape={shape}")

  inputs = Input(input_shape)

  # encoder side of the unet; s* are the skip connections, p* the pooled maps
  s1, p1 = encoder_block(inputs, 64)
  s2, p2 = encoder_block(p1, 128)
  s3, p3 = encoder_block(p2, 256)
  s4, p4 = encoder_block(p3, 512)

  # bridge
  b1 = conv_block(p4, 1024)

  # decoder side of the unet
  d1 = decoder_block(b1, s4, 512)
  d2 = decoder_block(d1, s3, 256)
  d3 = decoder_block(d2, s2, 128)
  d4 = decoder_block(d3, s1, 64)

  # Output head: the defaults (1 channel, sigmoid) suit binary segmentation;
  # pass num_classes/activation to adapt it to other tasks.
  outputs = Conv2D(num_classes, 1, padding="same", activation=activation)(d4)

  model = Model(inputs, outputs, name="UNET")

  return model
 

In [29]:
if __name__ == "__main__":
  # Build the network for 512x512 RGB inputs and print a layer-by-layer summary.
  input_shape = (512, 512, 3)
  model = unet_build(input_shape=input_shape)
  model.summary()


Model: "UNET"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_61 (Conv2D)              (None, 512, 512, 64) 1792        input_7[0][0]                    
__________________________________________________________________________________________________
batch_normalization_59 (BatchNo (None, 512, 512, 64) 256         conv2d_61[0][0]                  
__________________________________________________________________________________________________
activation_59 (Activation)      (None, 512, 512, 64) 0           batch_normalization_59[0][0]     
_______________________________________________________________________________________________