# **Attention Inception U-Net for 3D Brain Tumor Segmentation**

In [1]:
# importing all libraries

import tensorflow
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv3D, Conv3DTranspose, MaxPooling3D, BatchNormalization, Activation, Concatenate
from tensorflow.keras.models import Model

In [2]:
# Convolutional block for the U-Net model (used in the decoder path)

def conv_block(input_mat, *, num_filters=32, kernel_size=3, batch_norm=True):
    """Apply two (Conv3D -> optional BatchNorm -> LeakyReLU) stages.

    Both convolutions use stride 1 and 'same' padding, so the spatial size of
    ``input_mat`` is preserved and only the channel count changes to
    ``num_filters``.

    Args:
        input_mat: Keras tensor fed to the first Conv3D (assumed channels-last,
            matching the ``axis=-1`` concatenations used elsewhere — confirm).
        num_filters: Number of filters in each Conv3D (default 32, the value
            previously hard-coded).
        kernel_size: Edge length of the cubic convolution kernel (default 3).
        batch_norm: Insert BatchNormalization after each Conv3D. Previously
            this flag existed as an unused local and BN was always applied;
            the default keeps that behaviour.

    Returns:
        Keras tensor with ``num_filters`` channels.
    """
    x = input_mat
    for _ in range(2):
        x = Conv3D(
            num_filters,
            kernel_size=(kernel_size, kernel_size, kernel_size),
            strides=(1, 1, 1),
            padding='same',
        )(x)
        if batch_norm:
            x = BatchNormalization()(x)
        x = Activation('leaky_relu')(x)
    return x

In [3]:
# Inception block for the U-Net model
def inception_block(inputs, *, n_filters=32):
    """Apply a 3D Inception module made of four parallel branches.

    Every convolution uses ReLU and 'same' padding with ``n_filters`` filters:
      1. 1x1x1 conv
      2. 1x1x1 conv -> 3x3x3 conv
      3. 1x1x1 conv -> 5x5x5 conv
      4. 3x3x3 max-pool (stride 1, 'same') -> 1x1x1 conv

    The branch outputs are concatenated on the last axis, giving
    ``4 * n_filters`` channels. All strides are 1 with 'same' padding, so the
    spatial size of ``inputs`` is unchanged.

    Args:
        inputs: Keras tensor (assumed channels-last — confirm).
        n_filters: Filters per convolution (default 32, the value previously
            hard-coded).

    Returns:
        Keras tensor with ``4 * n_filters`` channels.
    """
    def conv(x, size):
        # One ReLU Conv3D with a cubic kernel of edge `size`.
        return Conv3D(n_filters, kernel_size=size, activation='relu', padding='same')(x)

    # Layers are created in the same order as before:
    # 1x1 | 1x1->3x3 | 1x1->5x5 | pool->1x1
    branch_1x1 = conv(inputs, 1)
    branch_3x3 = conv(conv(inputs, 1), 3)
    branch_5x5 = conv(conv(inputs, 1), 5)
    pooled = MaxPooling3D(pool_size=(3, 3, 3), strides=1, padding='same')(inputs)
    branch_pool = conv(pooled, 1)

    return concatenate([branch_1x1, branch_3x3, branch_5x5, branch_pool], axis=-1)

In [4]:
# Attention module for the U-Net model
def attention_block(inputs):
    """Re-weight ``inputs`` with a single-channel spatial attention map.

    Three 1x1x1 ReLU projections of ``inputs`` are computed:
      * ``f`` and ``g``: one channel each. Their element-wise product is the
        per-voxel attention score.
      * ``h``: same channel count as ``inputs``; these are the features that
        get re-weighted.

    The scores are softmax-normalised jointly over all spatial positions
    (axes 1-3 of the 5D Conv3D tensor). ``h`` is then multiplied by the
    resulting map, broadcast across channels.

    A single multi-axis ``Softmax`` is mathematically identical to the former
    flatten -> softmax -> reshape sequence. Unlike that sequence, it does not
    need static spatial dimensions, so it also works when ``Input`` has
    ``None`` spatial dims.

    Args:
        inputs: 5D Keras tensor, channels on the last axis. The channel count
            must be statically known.

    Returns:
        Keras tensor with the same shape as ``inputs``.
    """
    channels = inputs.shape[-1]
    f = Conv3D(1, kernel_size=1, activation='relu', padding='same')(inputs)
    g = Conv3D(1, kernel_size=1, activation='relu', padding='same')(inputs)
    h = Conv3D(channels, kernel_size=1, activation='relu', padding='same')(inputs)

    scores = tf.keras.layers.multiply([f, g])
    # Normalise over depth/height/width jointly (not per axis); the batch and
    # channel axes are excluded.
    attention = tf.keras.layers.Softmax(axis=(1, 2, 3))(scores)

    # NOTE(review): the map sums to 1 over all voxels, so near-uniform scores
    # give weights of about 1/num_voxels (~4.8e-7 for a 128^3 volume). The
    # returned features are therefore strongly attenuated. Consider evaluating
    # a residual connection (h * attention + inputs) or sigmoid gating.
    return tf.keras.layers.multiply([h, attention])

In [5]:
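# Define the 3D Attention Inception U-Net model. Attention + inception blocks are used only in the encoder
# (the bridge also uses an inception block); the decoder uses plain convolutional blocks.
# Developers can extend the attention/inception blocks to the decoder part as their experiments require.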

def attention_inception_unet(input_shape, *, num_classes=4):
    """Build the 3D Attention Inception U-Net.

    Architecture:
      * Encoder: four levels of attention_block -> inception_block ->
        2x2x2 max-pool. Each level's inception output is kept as a skip
        connection.
      * Bridge: 3x3x3 Conv3D (1024 filters, ReLU) -> inception_block.
      * Decoder: four levels of 2x2x2 stride-2 Conv3DTranspose
        (512/256/128/64 filters) -> concatenation with the matching encoder
        skip -> conv_block.
      * Head: 1x1x1 Conv3D with softmax over ``num_classes`` channels.

    Args:
        input_shape: Shape of one sample without the batch axis, channels
            last, e.g. ``(128, 128, 128, 4)``. Every known (non-None) spatial
            dimension must be divisible by 16. There are four 2x poolings with
            'valid' padding, and any remainder makes the decoder/skip shapes
            mismatch.
        num_classes: Channels of the softmax output (default 4, the value
            previously hard-coded).

    Returns:
        An uncompiled ``tf.keras`` ``Model``.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    decoder_filters = (512, 256, 128, 64)  # deepest level first
    depth = len(decoder_filters)
    factor = 2 ** depth

    # Fail early with a clear message instead of an opaque Concatenate error.
    shape = tuple(input_shape)
    for dim in shape[:-1]:
        if dim is not None and dim % factor:
            raise ValueError(
                f"Each spatial dimension of input_shape must be divisible by "
                f"{factor} ({depth} pooling levels); got {shape}"
            )

    inputs = Input(shape=input_shape)

    # Encoder (levels 1..4)
    skips = []
    x = inputs
    for _ in range(depth):
        x = inception_block(attention_block(x))
        skips.append(x)
        x = MaxPooling3D(pool_size=(2, 2, 2))(x)

    # Bridge (level 5)
    x = Conv3D(filters=1024, kernel_size=(3, 3, 3), padding='same', activation='relu')(x)
    x = inception_block(x)

    # Decoder (levels 4..1): upsample, fuse with the matching skip, refine.
    for filters, skip in zip(decoder_filters, reversed(skips)):
        x = Conv3DTranspose(filters=filters, kernel_size=(2, 2, 2),
                            strides=(2, 2, 2), padding='same')(x)
        x = Concatenate(axis=-1)([x, skip])
        x = conv_block(x)

    # Output head: per-voxel class probabilities.
    outputs = Conv3D(filters=num_classes, kernel_size=(1, 1, 1), activation='softmax')(x)

    return Model(inputs=inputs, outputs=outputs)

In [6]:
# Build the model for 128x128x128 volumes with 4 input channels.
# Adjust input_shape to match your data.
input_shape = (128, 128, 128, 4)
model = attention_inception_unet(input_shape)

# Show the symbolic input and output tensors, one per line.
print(model.input, model.output, sep="\n")

KerasTensor(type_spec=TensorSpec(shape=(None, 128, 128, 128, 4), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'")
KerasTensor(type_spec=TensorSpec(shape=(None, 128, 128, 128, 4), dtype=tf.float32, name=None), name='conv3d_51/Softmax:0', description="created by layer 'conv3d_51'")


In [7]:
# Print a layer-by-layer summary of the model: output shapes, parameter counts and layer connections.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 128, 128, 128, 4)]   0         []                            
                                                                                                  
 conv3d (Conv3D)             (None, 128, 128, 128, 1)     5         ['input_1[0][0]']             
                                                                                                  
 conv3d_1 (Conv3D)           (None, 128, 128, 128, 1)     5         ['input_1[0][0]']             
                                                                                                  
 multiply (Multiply)         (None, 128, 128, 128, 1)     0         ['conv3d[0][0]',              
                                                                     'conv3d_1[0][0]']        