# Segment the ISIC dataset with the Improved UNet

## Import Modules

In [2]:
import os

import tensorflow as tf
from tensorflow.keras import layers

# Device string for GPU placement (not used within this cell).
tf_device = '/gpu:0'

# Sanity check: report the TensorFlow version and the GPUs TensorFlow can see.
print(tf.__version__)
print(tf.config.list_physical_devices("GPU"))

2.5.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


# Segmentation
## Load Data

## Data Preprocessing
Normalization is critical when data is combined from different institutions and scanners.
Steps:
 1. Normalize each modality of each patient independently.
 2. Clip the resulting images to [-5, 5].
 3. Rescale to [0, 1] (in the brain-MRI setting these steps were written for, the non-brain region is also set to 0).

## Set Up Network Architecture

In [32]:
def make_model(width, height, channels, *, leaky_alpha=1e-2, dropout_rate=0.3):
    """Build an Improved UNet for single-channel (binary) segmentation.

    Encoder: five stages with 16, 32, 64, 128 and 256 filters. The first
    stage keeps the input resolution; each later stage halves it with a
    stride-2 3x3 convolution followed by a residual context module
    (conv -> dropout -> conv, summed with the entry convolution).

    Decoder: three upsampling + localization stages (128, 64, 32 filters)
    that concatenate the matching encoder stage, then a final upsampling
    stage joined with the first encoder stage, a 32-filter convolution and a
    1x1 sigmoid convolution.

    NOTE(review): the deep-supervision segmentation layers described in the
    Improved UNet paper are not built here.

    Args:
        width: Size of the first spatial axis of the input (or ``None`` for a
            variable size). NOTE(review): Keras images are laid out as
            (rows, cols, channels), so this value is used as the row count.
            That is harmless for the square 512x512 call in this notebook, but
            confirm the intended order before using non-square inputs.
        height: Size of the second spatial axis of the input (or ``None``).
        channels: Number of input channels.
        leaky_alpha: Negative slope of every LeakyReLU. The default of 1e-2
            matches the Improved UNet paper.
        dropout_rate: Dropout rate inside each context module.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping inputs of shape
        (batch, width, height, channels) to sigmoid outputs of shape
        (batch, width, height, 1).

    Raises:
        ValueError: If a fixed spatial size is not a positive multiple of 16.
    """
    # Four stride-2 encoder stages are undone by four 2x upsamplings. Each
    # fixed spatial size must therefore be divisible by 2**4, or a decoder
    # tensor will not match the encoder tensor it is concatenated with.
    # ``None`` (variable size) is left for Keras to check at call time.
    for dim_name, dim in (("width", width), ("height", height)):
        if dim is not None and (dim <= 0 or dim % 16):
            raise ValueError(
                f"{dim_name} must be a positive multiple of 16, got {dim}")

    input_layer = layers.Input(shape=(width, height, channels))

    def bn_act(x):
        # BatchNorm then LeakyReLU. The slope is 1e-2 by default: the
        # original literal ``10e-2`` evaluated to 0.1, not 10^-2.
        x = layers.BatchNormalization()(x)
        return layers.LeakyReLU(leaky_alpha)(x)

    def conv_bn_act(x, filter_size):
        # 3x3 "same" convolution followed by BatchNorm + LeakyReLU.
        return bn_act(layers.Conv2D(filter_size, (3, 3), padding="same")(x))

    def down_path(in_layer, filter_size, stride_size):
        # Entry convolution (stride (2, 2) halves the resolution).
        entry = layers.Conv2D(filter_size, (3, 3), stride_size, padding="same")(in_layer)
        x = bn_act(entry)
        # Context module: conv -> dropout -> conv.
        x = conv_bn_act(x, filter_size)
        x = layers.Dropout(dropout_rate)(x)
        x = conv_bn_act(x, filter_size)
        # Residual sum with the raw (pre-BatchNorm) entry convolution output.
        return layers.Add()([entry, x])

    def upsample(in_layer, filter_size):
        # Upsampling module: 2x UpSampling2D then conv/BN/LeakyReLU.
        x = layers.UpSampling2D(size=(2, 2))(in_layer)
        return conv_bn_act(x, filter_size)

    def up_path(in_layer, concat_layer, filter_size):
        x = upsample(in_layer, filter_size)
        # Skip connection from the encoder stage at the same resolution.
        x = layers.Concatenate()([concat_layer, x])
        # Localization module: two conv/BN/LeakyReLU stages.
        x = conv_bn_act(x, filter_size)
        return conv_bn_act(x, filter_size)

    down_1 = down_path(input_layer, 16, (1, 1))
    down_2 = down_path(down_1, 32, (2, 2))
    down_3 = down_path(down_2, 64, (2, 2))
    down_4 = down_path(down_3, 128, (2, 2))
    down_5 = down_path(down_4, 256, (2, 2))

    up_1 = up_path(down_5, down_4, 128)
    up_2 = up_path(up_1, down_3, 64)
    up_3 = up_path(up_2, down_2, 32)

    # Final stage: upsample to full resolution, join the first encoder stage,
    # then a single 32-filter convolution instead of a localization module.
    x = upsample(up_3, 16)
    x = layers.Concatenate()([down_1, x])
    x = conv_bn_act(x, 32)

    # 1x1 convolution to one channel with sigmoid for per-pixel probabilities.
    output_layer = layers.Conv2D(1, (1, 1), activation="sigmoid")(x)
    return tf.keras.Model(inputs=[input_layer], outputs=[output_layer])

# Instantiate the network for 512x512 inputs with 3 channels and print its layer table.
model = make_model(width=512, height=512, channels=3)
model.summary()

Model: "model_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_27 (InputLayer)           [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_453 (Conv2D)             (None, 512, 512, 16) 448         input_27[0][0]                   
__________________________________________________________________________________________________
batch_normalization_443 (BatchN (None, 512, 512, 16) 64          conv2d_453[0][0]                 
__________________________________________________________________________________________________
leaky_re_lu_397 (LeakyReLU)     (None, 512, 512, 16) 0           batch_normalization_443[0][0]    
____________________________________________________________________________________________