In [3]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Input, Concatenate
from tensorflow.keras.models import Model

### **Conv Block**

In [4]:
def conv_block(inputs, num_filters):
  """Apply two (3x3 Conv2D -> BatchNormalization -> ReLU) stages in sequence.

  Args:
    inputs: Keras tensor to transform.
    num_filters: number of output channels of each Conv2D layer.

  Returns:
    Keras tensor with `num_filters` channels. The spatial size is unchanged
    because both convolutions use padding="same" with the default stride.
  """
  x = inputs
  for _ in range(2):
    # Each stage must consume the previous stage's output (`x`), not the
    # block's `inputs`; otherwise the first stage is dropped from the graph.
    x = Conv2D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
  return x

In [5]:
x = Input((256, 256, 3))
y = conv_block(x, 32)

print(y.shape)
# padding="same" keeps 256x256; the channel count becomes num_filters.
assert tuple(y.shape) == (None, 256, 256, 32), y.shape

(None, 256, 256, 32)


### **Encoder Block**

In [6]:
def encoder_block(inputs, num_filters):
  """One U-Net encoder stage.

  Runs conv_block on `inputs`, then 2x2 max-pools the result.

  Returns:
    (skip, pooled): the full-resolution conv_block features, which
    build_unet passes to the matching decoder stage, and the pooled copy
    that feeds the next, deeper stage.
  """
  skip = conv_block(inputs, num_filters)
  pooled = MaxPool2D((2, 2))(skip)
  return skip, pooled

In [7]:
x = Input((256, 256, 3))
s, p = encoder_block(x, 32)

print(s.shape, p.shape)
# The skip output keeps full resolution; MaxPool2D((2, 2)) halves height/width.
assert tuple(s.shape) == (None, 256, 256, 32), s.shape
assert tuple(p.shape) == (None, 128, 128, 32), p.shape

(None, 256, 256, 32) (None, 128, 128, 32)


### **Decoder Block**

In [8]:
def decoder_block(inputs, skip, num_filters):
  """One U-Net decoder stage: upsample, merge with a skip tensor, refine.

  A 2x2 transposed convolution with stride 2 doubles the height and width
  of `inputs`. The result is concatenated with `skip` along the last axis,
  so `skip` must have that doubled spatial size. The merged tensor is then
  passed through conv_block.

  Returns:
    Keras tensor with `num_filters` channels.
  """
  upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
  merged = Concatenate()([upsampled, skip])
  return conv_block(merged, num_filters)


In [9]:
x = Input((256, 256, 3))
s = Input((512, 512, 3))
y = decoder_block(x, s, 32)
print(y.shape)
# Upsampling doubles 256 -> 512 to match the skip tensor; channels = num_filters.
assert tuple(y.shape) == (None, 512, 512, 32), y.shape

(None, 512, 512, 32)


### **UNET**

In [20]:
def build_unet(input_shape=(256, 256, 3)):
  """Build a 4-level U-Net with a single-channel sigmoid output.

  Args:
    input_shape: (height, width, channels) of the input images. Height and
      width should be divisible by 16 (one 2x2 pooling per encoder stage).
      Otherwise an upsampled decoder tensor will not match the spatial size
      of its skip connection and Concatenate fails.

  Returns:
    An uncompiled keras Model named "UNET". It maps `input_shape` to a
    (height, width, 1) map with values in [0, 1].
    NOTE(review): the single sigmoid channel assumes a binary-mask target;
    multi-class segmentation would need one channel per class with softmax.
  """
  inputs = Input(input_shape)

  # Encoder: each stage keeps its full-resolution features as a skip tensor
  # and passes a 2x-downsampled copy to the next stage.
  s1, p1 = encoder_block(inputs, 64)
  s2, p2 = encoder_block(p1, 128)
  s3, p3 = encoder_block(p2, 256)
  s4, p4 = encoder_block(p3, 512)

  # Bridge between the deepest encoder stage and the first decoder stage.
  b1 = conv_block(p4, 1024)

  # Decoder: upsample and merge with the encoder skip of matching size.
  d1 = decoder_block(b1, s4, 512)
  d2 = decoder_block(d1, s3, 256)
  d3 = decoder_block(d2, s2, 128)
  d4 = decoder_block(d3, s1, 64)

  # 1x1 conv down to one channel. Sigmoid bounds the output to [0, 1];
  # the previous ReLU was unbounded and could collapse to all-zero outputs.
  outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

  return Model(inputs, outputs, name="UNET")

### **Run the Model**

In [21]:
# (height, width, channels); height and width should be divisible by 16.
input_shape = (256, 256, 3)
model = build_unet(input_shape)
# Sanity check: one output channel at the full input resolution.
assert tuple(model.output_shape) == (None, input_shape[0], input_shape[1], 1), model.output_shape

In [22]:
# Print the layer-by-layer architecture, output shapes and parameter counts.
model.summary()