<a href="https://colab.research.google.com/github/wuba2010work/CWP/blob/main/ResUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate, Input, ZeroPadding2D
from tensorflow.keras.models import Model

def batchnorm_relu(inputs):
    x = BatchNormalization()(inputs)
    x = Activation('relu')(x)
    return x

""" Conv layer """
def residual_block(inputs, num_filters, strides=1):
    x = batchnorm_relu(inputs)
    x = Conv2D(num_filters, (3, 3), padding="same", strides=strides)(x)
    x = batchnorm_relu(x)
    x = Conv2D(num_filters, (3, 3), padding="same", strides=1)(x)

    """ shortcut """
    s = Conv2D(num_filters, (1, 1), padding="same", strides=strides)(inputs)

    """ Addition """
    x = x + s
    return x

def decoder_block(inputs, skip_features, num_filters):
    x = UpSampling2D((2, 2))(inputs) # corrected inputs here.
    x = Concatenate()([x, skip_features])
    x = residual_block(x, num_filters)
    return x

def build_resunet(input_shape):
    inputs = Input(shape=input_shape)

    """ Encoder 1 """
    x = Conv2D(64, (3, 3), padding="same", strides=1)(inputs)
    x = batchnorm_relu(x)
    x = Conv2D(64, (3, 3), padding="same", strides=1)(x)
    s = Conv2D(64, (1, 1), padding="same", strides=1)(inputs)
    s1 = x + s
    print(s1.shape)

    """ Encoder 2 and 3 """
    s2 = residual_block(s1, 128, strides=2)
    s3 = residual_block(s2, 256, strides=2)

    """ Bridge """
    b = residual_block(s3, 512, strides=2)

    """ Decoder 1, 2 & 3 """
    x = decoder_block(b, s3, 256)
    x = decoder_block(x, s2, 128)
    x = decoder_block(x, s1, 64)

    """ Classify layer """
    outputs = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)

    """ Model """
    model = Model(inputs, outputs)
    return model

if __name__ == "__main__":
    input_shape = (256, 256, 3)
    model = build_resunet(input_shape)
    model.summary()

(None, 256, 256, 64)
