<a href="https://colab.research.google.com/github/toanpt74/COLAB_RD/blob/main/Unet3plus.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import keras.regularizers
import tensorflow as tf
import keras as k
from keras.api.keras import activations
from keras.layers import Conv2D, Activation, Input, MaxPooling2D, BatchNormalization, Conv2DTranspose, concatenate, UpSampling2D, MaxPool2D


def conv_block(x, filters=16, kernel_size=(3, 3), strides=(1, 1), padding='same',
               is_bn=True, is_relu=True, n=2):
    """Apply ``n`` stacked Conv2D [-> BatchNormalization] [-> ReLU] units to ``x``.

    Every convolution uses He-normal kernel initialization and an L2 kernel
    penalty of 1e-4.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels of each convolution.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Convolution padding mode.
        is_bn: When True, add BatchNormalization after each convolution.
        is_relu: When True, add a ReLU activation after each convolution.
        n: Number of conv units to stack.

    Returns:
        The output tensor of the last unit (``x`` itself when ``n <= 0``).
    """
    # The loop index is not needed; only the repeat count matters.
    for _ in range(n):
        x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                   padding=padding,
                   kernel_regularizer=keras.regularizers.l2(1e-4),
                   kernel_initializer="he_normal")(x)
        if is_bn:
            x = BatchNormalization()(x)
        if is_relu:
            x = Activation("relu")(x)
    return x

def encoder_block(inputs, filters):
    """Run one encoder stage: a default two-unit ``conv_block`` then 2x2 max pooling.

    Returns:
        ``(features, pooled)`` — the pre-pooling feature map (used as a skip
        source) and its 2x2 max-pooled version (input to the next stage).
    """
    features = conv_block(inputs, filters)
    pooled = MaxPooling2D(pool_size=(2, 2))(features)
    return features, pooled
def decoder_block(inputs, skip_features, num_filters):
    """Upsample ``inputs`` 2x, merge with ``skip_features`` and refine.

    Args:
        inputs: Feature map from the deeper decoder/bridge level.
        skip_features: Encoder feature map at the target resolution.
        num_filters: Channels of the transposed convolution and of the
            following ``conv_block``.

    Returns:
        The refined feature map at the skip connection's resolution.
    """
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)
    # ``concatenate`` (lowercase) is the functional helper: it takes the list
    # of tensors directly. ``concatenate()`` with no arguments raises TypeError.
    x = concatenate([x, skip_features])
    return conv_block(x, num_filters)

def _full_scale_branch(x, channels, down=1, up=1):
    """Resample ``x`` to a decoder stage's resolution and project it to ``channels``.

    Args:
        x: Feature map from an encoder level or a deeper decoder stage.
        channels: Output channels of the single conv/BN/ReLU projection.
        down: Max-pooling factor (1 means no pooling).
        up: Bilinear upsampling factor (1 means no upsampling).

    Returns:
        The resampled, projected feature map.
    """
    if down > 1:
        x = MaxPool2D(pool_size=(down, down))(x)
    if up > 1:
        x = UpSampling2D(size=(up, up), interpolation='bilinear')(x)
    return conv_block(x, channels, n=1)


def _fuse(branches, channels):
    """Concatenate same-resolution branches and fuse them with one conv block."""
    return conv_block(concatenate(branches), channels, n=1)


def UNet_3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, filters=(64, 128, 256, 512, 1024)):
    """Build a UNet 3+ segmentation model with full-scale skip connections.

    The encoder has four conv + 2x2-pool stages plus a bridge, giving five
    resolution levels (H, H/2, H/4, H/8, H/16). Each decoder stage d4..d1
    receives one branch from every encoder level and every deeper decoder
    stage. Each branch is resampled to the stage's resolution and projected
    to ``filters[0]`` channels. The five branches are then fused into
    ``5 * filters[0]`` channels.

    Args:
        INPUT_SHAPE: Shape of one input image, assumed channels-last
            ``(height, width, channels)`` (concatenation is on the last axis).
            Height and width must be divisible by 16 or be ``None``.
        OUTPUT_CHANNELS: Number of output channels. 1 gives a sigmoid output;
            any other value gives a softmax over the last axis.
        filters: Channel widths of the five encoder levels.

    Returns:
        An uncompiled Keras ``Model`` named ``'UNet_3Plus'``.

    Raises:
        ValueError: If ``filters`` does not have exactly five entries, or a
            spatial dimension is not divisible by 16. Any other size would
            produce mismatched shapes at a decoder concatenation.
    """
    filters = tuple(filters)
    if len(filters) != 5:
        raise ValueError(f"filters must have 5 entries, got {len(filters)}")
    for dim in tuple(INPUT_SHAPE)[:2]:
        if dim is not None and dim % 16:
            raise ValueError(
                f"spatial dimensions of INPUT_SHAPE must be divisible by 16, "
                f"got {tuple(INPUT_SHAPE)}")

    inputs = Input(shape=INPUT_SHAPE, name='input_shape')

    # Encoder (comments give spatial size relative to the input height H).
    e1, p1 = encoder_block(inputs, filters=filters[0])  # H
    e2, p2 = encoder_block(p1, filters=filters[1])      # H/2
    e3, p3 = encoder_block(p2, filters=filters[2])      # H/4
    e4, p4 = encoder_block(p3, filters=filters[3])      # H/8
    e5 = conv_block(p4, filters[4])                     # H/16 (bridge)

    # Decoder: every branch is projected to cat_channels; each stage fuses
    # the five concatenated branches to upsample_channels.
    cat_channels = filters[0]
    upsample_channels = len(filters) * cat_channels

    d4 = _fuse([  # H/8
        _full_scale_branch(e1, cat_channels, down=8),
        _full_scale_branch(e2, cat_channels, down=4),
        _full_scale_branch(e3, cat_channels, down=2),
        _full_scale_branch(e4, cat_channels),
        _full_scale_branch(e5, cat_channels, up=2),
    ], upsample_channels)

    d3 = _fuse([  # H/4
        _full_scale_branch(e1, cat_channels, down=4),
        _full_scale_branch(e2, cat_channels, down=2),
        _full_scale_branch(e3, cat_channels),
        _full_scale_branch(d4, cat_channels, up=2),
        _full_scale_branch(e5, cat_channels, up=4),
    ], upsample_channels)

    d2 = _fuse([  # H/2
        _full_scale_branch(e1, cat_channels, down=2),
        _full_scale_branch(e2, cat_channels),
        _full_scale_branch(d3, cat_channels, up=2),
        _full_scale_branch(d4, cat_channels, up=4),
        _full_scale_branch(e5, cat_channels, up=8),
    ], upsample_channels)

    d1 = _fuse([  # H
        _full_scale_branch(e1, cat_channels),
        _full_scale_branch(d2, cat_channels, up=2),
        _full_scale_branch(d3, cat_channels, up=4),
        _full_scale_branch(d4, cat_channels, up=8),
        _full_scale_branch(e5, cat_channels, up=16),
    ], upsample_channels)

    # Output head: plain 3x3 conv (no BN/ReLU), then the output activation.
    # An Activation layer (rather than calling the raw activation function
    # on a symbolic tensor) keeps the graph made of ordinary Keras layers.
    d = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
    output = Activation('sigmoid' if OUTPUT_CHANNELS == 1 else 'softmax')(d)

    # Use the same Keras package as the layers above: mixing keras layers
    # with tf.keras.Model breaks when the two resolve to different packages.
    return k.Model(inputs=inputs, outputs=output, name='UNet_3Plus')

# Demo: build the default configuration and print a per-layer summary.
INPUT_SHAPE = [256, 256, 1]  # height, width, channels
OUTPUT_CHANNELS = 1          # a single channel selects the sigmoid output
model = UNet_3Plus(INPUT_SHAPE=INPUT_SHAPE, OUTPUT_CHANNELS=OUTPUT_CHANNELS)
model.summary()

