In [2]:
import os
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model
import numpy as np

In [3]:
def crop_image(in_image, out_image):
    """Center-crop ``in_image`` spatially so it matches ``out_image``.

    Used for the U-Net skip connections. Because the convolutions are
    unpadded, the encoder feature map (``in_image``) is larger than the
    upsampled decoder map (``out_image``) it is concatenated with.

    Args:
        in_image: 4-D tensor/array indexed as (batch, height, width, channels).
        out_image: tensor/array whose axes 1 and 2 give the target height
            and width.

    Returns:
        ``in_image`` sliced to ``out_image``'s height and width. The batch and
        channel axes are left untouched.
    """
    # Height and width are handled independently so non-square maps crop correctly.
    in_h, in_w = in_image.shape[1], in_image.shape[2]
    out_h, out_w = out_image.shape[1], out_image.shape[2]

    # Offsets are floored, so an odd size difference leaves the extra pixel at
    # the bottom/right. Slicing by the target size (instead of the symmetric
    # `delta:size-delta`) guarantees the result matches out_image exactly.
    top = (in_h - out_h) // 2
    left = (in_w - out_w) // 2

    return in_image[:, top:top + out_h, left:left + out_w, :]

In [4]:
def down_conv(a, dropout_rate=0.1):
    """Return the two-convolution block used on both U-Net paths.

    Each 3x3 convolution uses Keras' default 'valid' padding, so the block
    shrinks every spatial dimension by 4 pixels (e.g. 572 -> 568 for the
    first block in the model summary).

    Args:
        a: number of filters for both convolutions.
        dropout_rate: fraction of activations dropped after the second
            convolution (default 0.1, the previously hard-coded value).

    Returns:
        An uncompiled ``tf.keras.models.Sequential`` that can be called on a
        4-D tensor.
    """
    return tf.keras.models.Sequential([
        # Both convolutions need a non-linearity. Without one on the first,
        # conv -> conv collapses into a single linear map and the block loses depth.
        tf.keras.layers.Conv2D(a, (3, 3), activation='relu'),
        tf.keras.layers.Conv2D(a, (3, 3), activation='relu'),
        tf.keras.layers.Dropout(dropout_rate),
    ])

In [5]:
def up_conv(f, a, b):
    """One expanding-path step of the U-Net.

    Args:
        f: number of filters for the transposed convolution and for the
            two-convolution block that follows it.
        a: tensor coming up from the deeper level. It is upsampled by a
            stride-2 transposed convolution.
        b: encoder feature map from the matching level (the skip connection).
            It is center-cropped to the upsampled size.

    Returns:
        The result of ``down_conv(f)`` applied to the channel-axis (axis 3)
        concatenation of the upsampled ``a`` and the cropped ``b``.
    """
    upsample = tf.keras.layers.Conv2DTranspose(f, (3, 3), strides=(2, 2), padding="same")
    upsampled = upsample(a)

    skip = crop_image(b, upsampled)

    merged = tf.keras.layers.Concatenate(axis=3)([upsampled, skip])

    refine = down_conv(f)
    return refine(merged)

In [6]:
def Unet(input_shape=(572, 572, 3), out_channels=3, base_filters=64, depth=4):
    """Build and compile a U-Net: contracting path, bottleneck, expanding path.

    The defaults reproduce the original hard-coded network:
    - a 572x572x3 input;
    - encoder widths 64/128/256/512;
    - a 1024-filter bottleneck;
    - a 3-channel sigmoid output (388x388 spatially, because of the unpadded convolutions).

    NOTE: since the convolutions are unpadded, not every input size works.
    Each encoder feature map should have an even height/width before pooling.
    Otherwise the cropped skip and the upsampled tensor may not line up.

    Args:
        input_shape: (height, width, channels) of the input images.
        out_channels: number of channels of the final 1x1 convolution.
        base_filters: filters at the first level; each deeper level doubles it.
        depth: number of max-pool down-sampling steps.

    Returns:
        A ``tf.keras.Model`` named "Unet", compiled with Adam and MSE loss.
    """
    img_inputs = tf.keras.Input(shape=input_shape)

    # Contracting path: keep each level's features for the skip connections.
    skips = []
    x = img_inputs
    for level in range(depth):
        x = down_conv(base_filters * 2 ** level)(x)
        skips.append(x)
        x = tf.keras.layers.MaxPooling2D(2, 2)(x)

    # Bottleneck.
    x = down_conv(base_filters * 2 ** depth)(x)

    # Expanding path: mirror the encoder, deepest skip first.
    for level in reversed(range(depth)):
        x = up_conv(base_filters * 2 ** level, x, skips[level])

    # NOTE(review): sigmoid + MSE is kept from the original. For class-label
    # segmentation masks a cross-entropy loss is more usual; confirm the target.
    output = tf.keras.layers.Conv2D(out_channels, (1, 1), activation='sigmoid')(x)

    model = tf.keras.Model(inputs=img_inputs, outputs=output, name="Unet")
    model.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
    return model

In [7]:
# Build and compile the U-Net with its default 572x572x3 input.
model = Unet()

In [8]:
# Widen the table: at the default width the "Output Shape" column is cut off
# (shapes such as (None, 280, 280, 128) lose their closing characters).
model.summary(line_length=120)

Model: "Unet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 572, 572, 3) 0                                            
__________________________________________________________________________________________________
sequential (Sequential)         (None, 568, 568, 64) 38720       input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 284, 284, 64) 0           sequential[0][0]                 
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 280, 280, 128 221440      max_pooling2d[0][0]              
_______________________________________________________________________________________________

In [9]:
# Drawing the graph needs the optional pydot and graphviz packages. Without
# them Keras reports "Failed to import pydot" and no diagram is produced.
tf.keras.utils.plot_model(model, to_file="my_first_model.png")


Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.
