In [74]:
import tensorflow as tf
import matplotlib as plt

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# Tensorflow addons for instance normalization as described in Improved Unet Paper
import tensorflow_addons as tfa
import os

import pydot



In [98]:
# Constants
# Keyword arguments shared by every tfa.layers.InstanceNormalization layer in this notebook.
# axis=3 is the channel axis, assuming channels-last 4-D tensors (batch, H, W, C).
INSTANCE_NORMALIZATION_ARGS = {
    'axis': 3,                              # axis being normalised (channels)
    'center': True,                         # add a learned offset beta after normalising
    'scale': True,                          # multiply by a learned gamma after normalising
    'beta_initializer': 'random_uniform',
    'gamma_initializer': 'random_uniform',
}
# NOTE(review): gamma starts from 'random_uniform' (Keras default range is [-0.05, 0.05])
# rather than ones, which scales normalised activations close to zero at init — confirm intended.

# Negative-slope coefficient used by every LeakyReLU layer.
LEAKY_ALPHA = 0.01

# TODO: optimiser settings (e.g. learning rate) are not defined in this cell yet.



In [69]:
## Model Creation
def _double_conv(x, filters):
    """Apply two 3x3 ReLU convolutions ('same' padding, He-normal init) with ``filters`` channels."""
    for _ in range(2):
        x = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(x)
    return x


def _up_conv(x, filters):
    """Upsample ``x`` by 2 and apply a 2x2 ReLU convolution with ``filters`` channels."""
    x = UpSampling2D(size=(2,2))(x)
    return Conv2D(filters, 2, activation='relu', padding='same', kernel_initializer='he_normal')(x)


def build_model(input_size = (512,512,3), learning_rate=1e-4):
    """Build and compile the baseline (original) U-Net for binary segmentation.

    Args:
        input_size: (height, width, channels) of the input images. Height and width
            must be divisible by 16 so that the four 2x2 max-pools and 2x upsamplings
            produce matching shapes for the skip concatenations.
        learning_rate: Adam learning rate (default 1e-4, the previously hard-coded value).

    Returns:
        A compiled ``Model`` producing a single-channel sigmoid map at input
        resolution. The model summary is printed as a side effect.
    """
    inputs = Input(input_size)

    # Contracting path: double conv then 2x2 max-pool at each level.
    conv1 = _double_conv(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2,2))(conv1)

    conv2 = _double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2,2))(conv2)

    conv3 = _double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2,2))(conv3)

    conv4 = _double_conv(pool3, 512)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2,2))(drop4)

    # Bottleneck
    conv5 = _double_conv(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path: up-conv, concatenate the matching encoder output, double conv.
    x = drop5
    for skip, filters in ((drop4, 512), (conv3, 256), (conv2, 128), (conv1, 64)):
        up = _up_conv(x, filters)
        x = _double_conv(concatenate([skip, up], axis=3), filters)

    conv10 = Conv2D(1, 1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=conv10)
    model.summary()
    # 'Accuracy' (capitalised) resolves to tf.keras.metrics.Accuracy, which counts exact
    # equality between float sigmoid outputs and labels and so stays near zero;
    # 'binary_accuracy' thresholds predictions at 0.5 as intended for this binary output.
    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='binary_crossentropy',
                  metrics=['binary_accuracy'])

    return model
    

In [93]:
def context_module(input, out_filter, dropout_rate=0.3):
    """Residual context module of the improved U-Net context (encoder) pathway.

    Two Conv2D(3x3) -> InstanceNormalization -> LeakyReLU blocks separated by dropout,
    with the module input added element-wise onto the result.

    Args:
        input: 4-D feature tensor. Its channel count must equal ``out_filter``,
            because it is summed with the module output by ``Add``.
        out_filter: number of filters for both convolutions.
        dropout_rate: fraction of units dropped between the two conv blocks
            (default 0.3, the value previously hard-coded).

    Returns:
        Tensor with the same shape as ``input``.
    """
    # First convolution block
    c1 = Conv2D(filters=out_filter, kernel_size=(3,3), padding='same')(input)
    c2 = tfa.layers.InstanceNormalization(**INSTANCE_NORMALIZATION_ARGS)(c1)
    c3 = LeakyReLU(alpha=LEAKY_ALPHA)(c2)

    # Dropout between the two convolution blocks
    c4 = Dropout(dropout_rate)(c3)

    # Second convolution block
    c5 = Conv2D(filters=out_filter, kernel_size=(3,3), padding='same')(c4)
    c6 = tfa.layers.InstanceNormalization(**INSTANCE_NORMALIZATION_ARGS)(c5)
    c7 = LeakyReLU(alpha=LEAKY_ALPHA)(c6)

    # Residual add. NOTE(review): the blocks above are post-activation
    # (conv -> norm -> act); the Improved U-Net paper describes a pre-activation
    # residual block (norm -> act -> conv) — confirm which is intended.
    c8 = Add()([input, c7])

    return c8

# Localization module: recombines features after concatenation and reduces the number of
# feature maps to save memory.
def localization_module(input, out_filter):
    """Refine concatenated decoder features with a 3x3 conv block followed by a 1x1 conv block.

    Each block is Conv2D -> InstanceNormalization -> LeakyReLU, both with ``out_filter`` filters.

    Args:
        input: feature tensor, e.g. the concatenation of a skip connection and upsampled features.
        out_filter: number of filters for both convolutions.

    Returns:
        Tensor with ``out_filter`` channels and the same spatial size as ``input``.
    """
    features = input
    # The 3x3 conv mixes neighbouring positions; the 1x1 conv then recombines channels only.
    for kernel in ((3, 3), (1, 1)):
        features = Conv2D(filters=out_filter, kernel_size=kernel, padding='same')(features)
        features = tfa.layers.InstanceNormalization(**INSTANCE_NORMALIZATION_ARGS)(features)
        features = LeakyReLU(alpha=LEAKY_ALPHA)(features)
    return features

# Upsampling module: moves features from a coarser level of the U-Net up to the next
# finer spatial resolution.
def upsampling_module(input, out_filter):
    """Upscale ``input`` by 2 in each spatial dimension, then apply a 3x3 conv block.

    Args:
        input: feature tensor from a coarser level.
        out_filter: number of filters of the convolution applied after upscaling.

    Returns:
        Tensor with ``out_filter`` channels and twice the height and width of ``input``.
    """
    # Fixed (non-learned) upscaling first, then a learned conv -> norm -> activation.
    upscaled = UpSampling2D(size=(2, 2))(input)
    refined = Conv2D(filters=out_filter, kernel_size=(3,3), padding='same')(upscaled)
    normalised = tfa.layers.InstanceNormalization(**INSTANCE_NORMALIZATION_ARGS)(refined)
    return LeakyReLU(alpha=LEAKY_ALPHA)(normalised)

# Context connector: a stride-2 conv block placed between context modules. It halves the
# spatial resolution of the feature maps so the next level can use more filters.
def context_connector(input, out_filter):
    """Downsample ``input`` by 2 with a strided 3x3 Conv2D -> InstanceNorm -> LeakyReLU block.

    Args:
        input: feature tensor from the previous (finer) context level.
        out_filter: number of output channels.

    Returns:
        Tensor with ``out_filter`` channels and half the spatial size of ``input``
        (rounded up, because of 'same' padding).
    """
    stages = (
        Conv2D(filters=out_filter, kernel_size=(3,3), strides=2, padding='same'),
        tfa.layers.InstanceNormalization(**INSTANCE_NORMALIZATION_ARGS),
        LeakyReLU(alpha=LEAKY_ALPHA),
    )
    result = input
    for stage in stages:
        result = stage(result)
    return result

In [102]:
def improved_unet(input_size = (512,512,3)):
    """Build the (uncompiled) Improved U-Net for single-channel segmentation.

    The context pathway has five resolution levels (16, 32, 64, 128, 256 filters).
    Each level below the first starts with a stride-2 ``context_connector`` fed by the
    previous level's ``context_module`` output. The localization pathway upsamples back
    up, concatenating the matching context output at each level. Segmentation branches
    from levels 3, 2 and 1 are summed before the final 1x1 sigmoid convolution.

    Args:
        input_size: (height, width, channels) of the input images. Height and width must
            be divisible by 16, so that the four stride-2 connectors and 2x upsamplings
            give matching shapes for every concatenation and addition.

    Returns:
        An uncompiled ``Model`` producing a 1-channel sigmoid map at input resolution.
    """
    inputs = Input(shape=input_size)

    # ---- Context pathway ----
    # Level 1 (full resolution): initial conv block, then a context module.
    x1 = Conv2D(filters=16, kernel_size=(3,3), padding='same')(inputs)
    x2 = tfa.layers.InstanceNormalization(**INSTANCE_NORMALIZATION_ARGS)(x1)
    x3 = LeakyReLU(alpha=LEAKY_ALPHA)(x2)
    x4 = context_module(x3, 16)

    # Level 2 (1/2 resolution)
    x5 = context_connector(x4, 32)
    x6 = context_module(x5, 32)

    # Level 3 (1/4 resolution). Must be fed by x6 (the level-2 context module output),
    # like every other level; feeding x5 would skip the level-2 context module on the
    # downward path.
    x7 = context_connector(x6, 64)
    x8 = context_module(x7, 64)

    # Level 4 (1/8 resolution)
    x9 = context_connector(x8, 128)
    x10 = context_module(x9, 128)

    # Level 5 (1/16 resolution, bottleneck)
    x11 = context_connector(x10, 256)
    x12 = context_module(x11, 256)

    # ---- Localization pathway ----
    # Level 5 -> 4
    x13 = upsampling_module(x12, 128)

    # Level 4
    x14 = Concatenate()([x10, x13])
    x15 = localization_module(x14, 128)
    x16 = upsampling_module(x15, 64)

    # Level 3
    x17 = Concatenate()([x8, x16])
    x18 = localization_module(x17, 64)  # source of segmentation branch 1
    x19 = upsampling_module(x18, 32)

    # Level 3 segmentation branch, brought up to level-2 resolution with 32 channels.
    # NOTE(review): the Improved U-Net paper derives segmentation maps with 1x1 conv
    # "segmentation layers" and plain upscaling; here a sigmoid is applied to the
    # feature maps and a learnable upsampling_module is used instead — confirm intended.
    seg1 = Activation('sigmoid')(x18)
    seg1 = upsampling_module(seg1, 32)

    # Level 2
    x20 = Concatenate()([x6, x19])
    x21 = localization_module(x20, 32)  # source of segmentation branch 2
    x22 = upsampling_module(x21, 16)

    # Level 2 segmentation: sum with branch 1, then bring to full resolution.
    seg2 = Activation('sigmoid')(x21)
    seg3 = Add()([seg1, seg2])
    seg3 = upsampling_module(seg3, 32)

    # Level 1
    x23 = Concatenate()([x4, x22])
    x24 = Conv2D(filters=32, kernel_size=(3,3), padding='same')(x23)
    x25 = tfa.layers.InstanceNormalization(**INSTANCE_NORMALIZATION_ARGS)(x24)
    x26 = LeakyReLU(alpha=LEAKY_ALPHA)(x25)

    # Level 1 segmentation: sum with the upsampled deeper branches.
    seg4 = Activation('sigmoid')(x26)
    seg_final = Add()([seg3, seg4])

    # Output: 1x1 conv down to a single sigmoid channel.
    output = Conv2D(filters=1, kernel_size=(1,1), activation='sigmoid', padding='same')(seg_final)
    uNet = Model(inputs=inputs, outputs=output)

    return uNet
    

In [105]:
# Build the Improved U-Net at its default 512x512x3 input size.
model = improved_unet(input_size=(512, 512, 3))

# Draw the architecture with tensor shapes. plot_model needs the `pydot` package AND the
# Graphviz system binaries; when they are missing it only reports an install message
# (as in this cell's recorded output). `model.summary()` is a text-only alternative.
tf.keras.utils.plot_model(model=model, show_shapes=True)


('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')
