# TransResUNet: Fully Convolutional Model for Lungs Segmentation

In [1]:
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.applications.vgg16 import VGG16

# Residual Path

In [2]:
def res_path(inputs, filter_size, path_number):
    """Build a residual skip path made of stacked conv blocks.

    Each block runs a 3x3 and a 1x1 convolution (both ReLU, 'same' padding)
    over the same tensor and adds the two results. Blocks are chained, so
    the output of one block is the input of the next.

    Args:
        inputs: Keras tensor to feed into the first block.
        filter_size: number of filters used by every convolution in the path.
        path_number: index of the skip connection. Lower indices get a
            deeper path: >= 4 -> 1 block, 3 -> 2 blocks, 2 -> 3 blocks,
            <= 1 -> 4 blocks.

    Returns:
        The Keras tensor produced by the last block.
    """
    def block(x, fl):
        # Both branches read the block's own input `x` so that successive
        # blocks actually stack instead of all re-reading `inputs`.
        conv3 = Conv2D(fl, (3, 3), padding='same', activation="relu")(x)
        conv1 = Conv2D(fl, (1, 1), padding='same', activation="relu")(x)
        return Add()([conv3, conv1])

    # One block is always applied; one more is added for each threshold
    # that path_number does not exceed.
    extra_blocks = sum(path_number <= threshold for threshold in (3, 2, 1))

    cnn = inputs
    for _ in range(1 + extra_blocks):
        cnn = block(cnn, filter_size)

    return cnn

# Decoder Block

In [3]:
def decoder_block(inputs, mid_channels, out_channels):
    """Upsample `inputs` by 2x, then refine with two 3x3 ReLU convolutions.

    Args:
        inputs: Keras tensor to decode.
        mid_channels: filter count of the first convolution.
        out_channels: filter count of the second convolution.

    Returns:
        The Keras tensor produced by the second convolution.
    """
    def conv3x3(channels):
        # Shared configuration for both convolutions of the block.
        return Conv2D(
            channels,
            3,
            activation='relu',
            padding='same',
            kernel_initializer='he_normal',
            data_format='channels_last',
        )

    upsampled = UpSampling2D(size=(2, 2))(inputs)
    hidden = conv3x3(mid_channels)(upsampled)
    return conv3x3(out_channels)(hidden)


In [4]:
def TransResUNet(input_size=(512, 512, 1), output_channels=1):
    """Build the TransResUNet segmentation model.

    The first three convolutional blocks of an ImageNet-initialised VGG16
    act as the encoder. Their outputs reach the decoder through `res_path`
    skip connections and are merged with upsampled decoder features.

    Args:
        input_size: (height, width, channels) of the input image. Inputs
            with fewer than three channels are projected to three channels
            with a 1x1 convolution before entering the VGG16 layers. The
            encoder pools by 2 three times and the decoder upsamples by 2,
            so height and width should be divisible by 8 for the skip
            concatenations to line up.
        output_channels: number of output maps. With 1, the output is a
            sigmoid map; with more than 1, it is a per-pixel log-softmax
            over the channel axis.

    Returns:
        An uncompiled Keras `Model` mapping the input image to the output
        map(s).
    """
    inputs = Input(input_size)
    if input_size[-1] < 3:
        # Learnable 1x1 projection up to the three channels VGG16 expects.
        x = Conv2D(3, 1)(inputs)
        # Keep height and width as given so non-square inputs work.
        input_shape = (input_size[0], input_size[1], 3)
    else:
        x = inputs
        input_shape = input_size

    encoder = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)

    # first encoder block
    enc1 = encoder.get_layer(name='block1_conv1')(x)
    enc1 = encoder.get_layer(name='block1_conv2')(enc1)
    # second encoder block
    enc2 = MaxPooling2D(pool_size=(2, 2))(enc1)
    enc2 = encoder.get_layer(name='block2_conv1')(enc2)
    enc2 = encoder.get_layer(name='block2_conv2')(enc2)
    # third encoder block
    enc3 = MaxPooling2D(pool_size=(2, 2))(enc2)
    enc3 = encoder.get_layer(name='block3_conv1')(enc3)
    enc3 = encoder.get_layer(name='block3_conv2')(enc3)
    enc3 = encoder.get_layer(name='block3_conv3')(enc3)

    # center block: pool once more, then decode back to enc3's resolution
    center = MaxPooling2D(pool_size=(2, 2))(enc3)
    center = decoder_block(center, 512, 256)

    # Decoder block corresponding to third encoder
    res_path3 = res_path(enc3, 128, 3)
    dec3 = concatenate([res_path3, center], axis=3)
    dec3 = decoder_block(dec3, 256, 64)
    # Decoder block corresponding to second encoder
    res_path2 = res_path(enc2, 64, 2)
    dec2 = concatenate([res_path2, dec3], axis=3)
    dec2 = decoder_block(dec2, 128, 64)
    # Final block: concatenation with the first encoder's features
    res_path1 = res_path(enc1, 32, 1)
    dec1 = concatenate([res_path1, dec2], axis=3)
    dec1 = Conv2D(32, 3, padding='same', kernel_initializer='he_normal')(dec1)
    dec1 = ReLU()(dec1)

    # Output: always project to `output_channels` maps first, so the
    # multi-class head normalises over classes rather than over the 32
    # decoder feature channels.
    out = Conv2D(output_channels, 1)(dec1)
    if output_channels > 1:
        out = Lambda(lambda t: tf.nn.log_softmax(t, axis=3))(out)
    else:
        out = Activation('sigmoid')(out)

    model = Model(inputs=[inputs], outputs=[out])
    return model

# Model Summary

In [5]:
# Build the network with its default arguments (512x512 single-channel input,
# one output channel) and print the layer-by-layer summary. Constructing the
# model instantiates VGG16 with weights='imagenet'.
model = TransResUNet()
model.summary()

Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 3)  6           input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 512, 512, 64) 1792        conv2d[0][0]                     
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 512, 512, 