In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Input, Reshape,ReLU
from tensorflow.keras.layers import Conv2D, MaxPooling2D,Conv2DTranspose, BatchNormalization
from tensorflow_addons.layers import InstanceNormalization

## The network architecture is define as follow:


![;;"](figs/networkArch.PNG "ss")

In [2]:
# dk: 3 by 3 Convolution_InstanceNorm ReLu with k filter and stride of 2
def d_block(input_tensor, k):
    x = input_tensor
    x = Conv2D(k, (3,3), strides=(2, 2))(x)
    x = InstanceNormalization()(x)
    x = ReLU()(x)
    return x

#uk: 3x3 Transposed Convolution with k filter and InstanceNorm with Relu
def u_block(input_tensor, k):
    x = input_tensor
    x = Conv2DTranspose(k, (3,3))(x)
    x = InstanceNormalization()(x)
    x = ReLU()(x)
    return x

In [3]:
# Rk denotes a pre-activation residual block with k filters
# The structure of a pre-activation residual block showing below.
def R_block(input_tensor,n_filters,stage):
    '''if stage>1: # first activation is just after conv1        
        x = layers.BatchNormalization()(input_tensor)        
        x = layers.Activation('relu')(x)    
    else:       '''
    x = input_tensor
    x = Conv2D(n_filters,(1,1))(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(n_filters,(1,1))(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(n_filters,(1,1))(x)
    # define short cut
    x = layers.add([x, input_tensor])
    return x



## Pre-Activative Residual blocks
!["pre_activative_resBlock"](figs/pre_activative_resBlock.png "pre_activative_resBlock")

In [4]:
'''c7s1-64, d128, d256, d512, R512, R512, R512, R512, R512, R512, R512, R512, R512, u256, u128,
u64, c7s1-8.'''
# The GeneratorModel
def Generator(input_shape=(512,512,3), nb_classes=4):
    img_input = tf.keras.layers.Input(shape=input_shape)
    #c7s1-64/
    x = Conv2D(64, (7,7),strides=(1, 1))(img_input)
    #d128, d256, d512
    x = d_block(x,128)
    x = d_block(x, 256)
    x = d_block(x, 512)
    # R512, R512, R512, R512, R512, R512, R512, R512, R512
    for i in range(1,10):
        x = R_block(x, 512,i)
    #u256, u128,u64
  
    x = u_block(x, 256)
    
    x = u_block(x, 128)

    x = u_block(x, 64)
    #c7s1-8
    x = Conv2D(8, (7,7),strides=(1, 1))(x)
    return models.Model(img_input, x, name='rnpa')

In [5]:
model = Generator()
model.summary()

Model: "rnpa"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 506, 506, 64) 9472        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 252, 252, 128 73856       conv2d[0][0]                     
__________________________________________________________________________________________________
instance_normalization (Instanc (None, 252, 252, 128 256         conv2d_1[0][0]                   
_______________________________________________________________________________________________

In [32]:
from tensorflow.keras.datasets import mnist