In [91]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
"""
2D Residual U-Net for lung nodule segmentation.
Adapted from: https://github.com/DuFanXin/deep_residual_unet/blob/master/res_unet.py
"""


def res_block(x, nb_filters, strides):
    """Pre-activation residual unit with a projected shortcut.

    The residual branch applies BatchNorm -> ReLU -> 3x3 Conv twice, using
    ``nb_filters[0]``/``strides[0]`` for the first convolution and
    ``nb_filters[1]``/``strides[1]`` for the second. The shortcut branch is a
    1x1 convolution with ``nb_filters[1]`` filters and ``strides[0]``,
    followed by BatchNorm. Both branches are summed element-wise.

    Args:
        x: Input tensor of the block.
        nb_filters: Indexable with two filter counts (first conv, second conv).
        strides: Indexable with two strides (first conv, second conv).

    Returns:
        The sum of the shortcut branch and the residual branch.
    """

    def _bn_relu_conv(tensor, filters, stride):
        # One pre-activation step: normalize, activate, then convolve.
        normalized = BatchNormalization()(tensor)
        activated = Activation(activation='relu')(normalized)
        conv = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', strides=stride)
        return conv(activated)

    residual = _bn_relu_conv(x, nb_filters[0], strides[0])
    residual = _bn_relu_conv(residual, nb_filters[1], strides[1])

    # 1x1 projection so the identity path matches the residual's channel count.
    projection = Conv2D(nb_filters[1], kernel_size=(1, 1), strides=strides[0])
    identity = BatchNormalization()(projection(x))

    return add([identity, residual])


def encoder(x):
    """Build the contracting path and collect the tensors fed to the decoder.

    A stem (3x3 Conv -> BatchNorm -> ReLU -> 3x3 Conv, summed with a 1x1
    projected shortcut of ``x``) produces the first skip tensor. Four stages
    follow; each stacks a number of ``res_block`` units with stride 1 and ends
    with a 2x2 max-pooling whose output is stored as the next skip tensor.

    Args:
        x: Input tensor of the network.

    Returns:
        A list of five tensors: the stem output followed by the pooled output
        of every stage, in shallow-to-deep order.
    """
    # (filters, number of residual units) for each of the four stages.
    stage_plan = ((64, 3), (128, 4), (256, 6), (512, 3))
    unit_strides = [(1, 1), (1, 1)]

    stem = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(x)
    stem = BatchNormalization()(stem)
    stem = Activation(activation='relu')(stem)
    stem = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(stem)

    projected = Conv2D(filters=64, kernel_size=(1, 1), strides=(1, 1))(x)
    projected = BatchNormalization()(projected)

    features = add([projected, stem])
    # The un-pooled stem output is the first branch to the decoder.
    skips = [features]

    for width, depth in stage_plan:
        for _ in range(depth):
            features = res_block(features, [width, width], unit_strides)
        features = MaxPooling2D(pool_size=(2, 2))(features)
        skips.append(features)

    return skips


def decoder(x, from_encoder):
    """Build the expanding path from the bridge output back to full resolution.

    Each stage upsamples by 2, concatenates the matching encoder tensor along
    the channel axis and applies one ``res_block``.

    ``encoder`` in this file returns five tensors, each at half the spatial
    size of the previous one, and the bridge in ``deep_residual_unet`` halves
    the deepest one again. Five upsampling stages are therefore required, one
    per encoder tensor, starting with ``from_encoder[4]``. The previous code
    duplicated the ``from_encoder[3]`` stage, restarted the second copy from
    ``x`` (discarding the first), and never used ``from_encoder[4]``; the
    upsampled bridge then did not match the spatial size of
    ``from_encoder[3]``.

    Args:
        x: Bridge tensor, at half the spatial size of ``from_encoder[4]``.
        from_encoder: Sequence of five skip tensors in shallow-to-deep order.

    Returns:
        Tensor at the spatial size of ``from_encoder[0]`` with 64 channels.
    """
    # (index of the skip tensor, filters of the residual unit), deep to shallow.
    # NOTE(review): 512 filters for the deepest stage is an assumption chosen
    # to match the channel count of from_encoder[4]; the other widths are the
    # ones the original stages used.
    stage_plan = ((4, 512), (3, 256), (2, 256), (1, 128), (0, 64))

    main_path = x
    for skip_index, width in stage_plan:
        main_path = UpSampling2D(size=(2, 2))(main_path)
        main_path = concatenate([main_path, from_encoder[skip_index]], axis=3)
        main_path = res_block(main_path, [width, width], [(1, 1), (1, 1)])

    return main_path


def validate_input(x, target_size=512):
    """Pad a 4-D tensor so both spatial axes equal ``target_size``.

    The tensor is returned unchanged when axes 1 and 2 already have the
    target size. Otherwise each spatial axis is padded independently, with
    the extra pixel of an odd difference going to the trailing side, so the
    result is always exactly ``target_size`` on both axes.

    Reflection padding is used when possible. ``tf.pad`` in ``REFLECT`` mode
    requires every padding amount to be smaller than the axis it pads; when
    that does not hold (for example 128 -> 512 needs 192 per side) the
    function falls back to zero padding instead of building a graph that
    fails at run time.

    Args:
        x: Tensor whose axes 1 and 2 are the spatial axes and have statically
            known sizes.
        target_size: Required size of both spatial axes.

    Returns:
        ``x`` itself if no padding is needed, else the padded tensor.

    Raises:
        ValueError: If ``x`` is not 4-D, a spatial size is unknown, or a
            spatial size exceeds ``target_size``.
    """
    spatial = x.shape.as_list()[1:-1]
    if spatial == [target_size, target_size]:
        return x

    if len(spatial) != 2 or None in spatial:
        raise ValueError(
            'validate_input expects a 4-D tensor with known spatial sizes, '
            'got spatial dims %r' % (spatial,))
    if max(spatial) > target_size:
        raise ValueError(
            'spatial dims %r exceed the target size %d' % (spatial, target_size))

    # One [before, after] pair per spatial axis; batch and channels untouched.
    spatial_pads = []
    for size in spatial:
        missing = target_size - size
        before = missing // 2
        spatial_pads.append([before, missing - before])

    can_reflect = all(
        max(pair) < size for pair, size in zip(spatial_pads, spatial))
    mode = 'REFLECT' if can_reflect else 'CONSTANT'

    padding = tf.constant([[0, 0]] + spatial_pads + [[0, 0]])
    # constant_values only takes effect in the CONSTANT fallback.
    return tf.pad(x, paddings=padding, mode=mode, constant_values=0)

def validate_output(x, expected_size=324, target_size=512):
    """Return ``x`` if it is ``expected_size`` square, else pad to ``target_size``.

    NOTE(review): the pass-through size (324) and the padding target (512)
    differ in the original code and are kept as defaults here; confirm which
    output size is actually intended.

    Each spatial axis is padded independently, with the extra pixel of an odd
    difference going to the trailing side. Reflection padding is used when
    every padding amount is smaller than the axis it pads (a requirement of
    ``tf.pad`` in ``REFLECT`` mode); otherwise zero padding is used.

    Args:
        x: Tensor whose axes 1 and 2 are the spatial axes and have statically
            known sizes.
        expected_size: Spatial size for which ``x`` is returned unchanged.
        target_size: Spatial size every other input is padded to.

    Returns:
        ``x`` itself if both spatial axes equal ``expected_size``, else the
        padded tensor.

    Raises:
        ValueError: If ``x`` is not 4-D, a spatial size is unknown, or a
            spatial size exceeds ``target_size``.
    """
    spatial = x.shape.as_list()[1:-1]
    if spatial == [expected_size, expected_size]:
        return x

    if len(spatial) != 2 or None in spatial:
        raise ValueError(
            'validate_output expects a 4-D tensor with known spatial sizes, '
            'got spatial dims %r' % (spatial,))
    if max(spatial) > target_size:
        raise ValueError(
            'spatial dims %r exceed the target size %d' % (spatial, target_size))

    # One [before, after] pair per spatial axis; batch and channels untouched.
    spatial_pads = []
    for size in spatial:
        missing = target_size - size
        before = missing // 2
        spatial_pads.append([before, missing - before])

    can_reflect = all(
        max(pair) < size for pair, size in zip(spatial_pads, spatial))
    mode = 'REFLECT' if can_reflect else 'CONSTANT'

    padding = tf.constant([[0, 0]] + spatial_pads + [[0, 0]])
    # constant_values only takes effect in the CONSTANT fallback.
    return tf.pad(x, paddings=padding, mode=mode, constant_values=0)
    
def deep_residual_unet(input_shape=(128, 128, 3), num_classes=3):
    """Assemble the residual U-Net as a Keras ``Model``.

    The input is passed through ``validate_input`` and ``encoder``; the
    deepest encoder tensor goes through a bridge (BatchNorm -> ReLU -> 3x3
    Conv with 1024 filters and stride 2) and is then expanded by ``decoder``
    using all encoder tensors as skip connections. A final 1x1 convolution
    with a softmax activation yields per-pixel class scores.

    Args:
        input_shape: Shape of one input sample, without the batch axis.
        num_classes: Number of filters of the final softmax convolution.
            Softmax over a single channel is constant, so use at least 2.

    Returns:
        A ``Model`` mapping the input tensor to the softmax output.
    """
    inputs = Input(shape=input_shape)

    to_decoder = encoder(validate_input(x=inputs))

    # Bridge between the contracting and the expanding path.
    path = BatchNormalization()(to_decoder[-1])
    path = Activation(activation='relu')(path)
    path = Conv2D(filters=1024, kernel_size=(3, 3), padding='same', strides=(2, 2))(path)

    path = decoder(path, from_encoder=to_decoder)

    output = Conv2D(filters=num_classes, kernel_size=(1, 1), activation='softmax')(path)

    return Model(inputs=[inputs], outputs=[output])



In [92]:
model = deep_residual_unet()
model.summary()

Model: "model_19"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_45 (InputLayer)           [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
tf.compat.v1.pad_21 (TFOpLambda (None, 512, 512, 3)  0           input_45[0][0]                   
__________________________________________________________________________________________________
conv2d_784 (Conv2D)             (None, 512, 512, 64) 1792        tf.compat.v1.pad_21[0][0]        
__________________________________________________________________________________________________
batch_normalization_734 (BatchN (None, 512, 512, 64) 256         conv2d_784[0][0]                 
___________________________________________________________________________________________