In [19]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Input, Reshape,ReLU,LeakyReLU,Concatenate
from tensorflow.keras.layers import Conv2D, MaxPooling2D,Conv2DTranspose, BatchNormalization
from tensorflow_addons.layers import InstanceNormalization

## The network architecture is defined as follows:


![Generator network architecture](figs/networkArch.PNG "Generator network architecture")

In [2]:
# dk: 3x3 Convolution-InstanceNorm-ReLU layer with k filters and stride 2.
def d_block(input_tensor, k, padding='same'):
    """Downsampling block ``dk``: 3x3 Conv (stride 2) -> InstanceNorm -> ReLU.

    Args:
        input_tensor: Keras tensor to downsample.
        k: Number of convolution filters (output channels).
        padding: ``Conv2D`` padding mode. With ``'same'`` (the default) the
            spatial size becomes ``ceil(size / 2)``, so a 512x512 map becomes
            256x256. ``'valid'`` instead gives ``(size - 3) // 2 + 1``, which
            loses pixels and breaks symmetry with the upsampling blocks.

    Returns:
        The ReLU-activated output tensor with ``k`` channels.
    """
    x = Conv2D(k, (3, 3), strides=(2, 2), padding=padding)(input_tensor)
    x = InstanceNormalization()(x)
    return ReLU()(x)

# uk: 3x3 transposed Convolution-InstanceNorm-ReLU layer with k filters and
# stride 2 (fractional stride 1/2), i.e. an upsampling block.
def u_block(input_tensor, k, strides=(2, 2), padding='same'):
    """Upsampling block ``uk``: 3x3 Conv2DTranspose -> InstanceNorm -> ReLU.

    Args:
        input_tensor: Keras tensor to upsample.
        k: Number of filters (output channels).
        strides: Transposed-convolution strides. The default ``(2, 2)`` doubles
            height and width, mirroring one ``d_block``. With strides of 1 the
            layer only grows the map by ``kernel_size - 1`` pixels and does not
            upsample.
        padding: ``Conv2DTranspose`` padding mode. ``'same'`` makes the output
            size exactly ``size * stride``.

    Returns:
        The ReLU-activated output tensor with ``k`` channels.
    """
    x = Conv2DTranspose(k, (3, 3), strides=strides, padding=padding)(input_tensor)
    x = InstanceNormalization()(x)
    return ReLU()(x)

In [3]:
# Rk denotes a pre-activation residual block with k filters.
# Its structure is shown in the residual-block figure below.
def R_block(input_tensor, n_filters, stage):
    """Pre-activation residual block ``Rk`` (BatchNorm -> ReLU -> Conv ordering).

    Args:
        input_tensor: Keras tensor entering the block.
        n_filters: Number of filters in every convolution (``k`` in ``Rk``).
        stage: 1-based index of this block within the residual stack.
            - ``stage == 1``: the leading BatchNorm/ReLU is skipped. In
              ``Generator`` the first block follows a ``d_block``, which
              already ends in a ReLU.
            - ``stage > 1``: the incoming tensor is a raw residual sum, so it
              is normalised and activated before the first convolution.

    Returns:
        ``shortcut + F(input_tensor)``. The shortcut is the identity when the
        input already has ``n_filters`` channels; otherwise it is a 1x1
        projection, so the addition is always shape-compatible.
    """
    x = input_tensor
    if stage > 1:
        # Pre-activation of the previous block's residual sum.
        x = BatchNormalization()(x)
        x = ReLU()(x)
    # NOTE(review): all three convolutions are 1x1. A bottleneck block usually
    # has a 3x3 middle convolution; confirm against the figure.
    x = Conv2D(n_filters, (1, 1))(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(n_filters, (1, 1))(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(n_filters, (1, 1))(x)

    # Shortcut: identity when channels already match, 1x1 projection otherwise.
    shortcut = input_tensor
    in_channels = input_tensor.shape[-1]
    if in_channels is not None and in_channels != n_filters:
        shortcut = Conv2D(n_filters, (1, 1))(input_tensor)
    return layers.add([x, shortcut])



## Pre-Activation Residual Blocks
![Pre-activation residual block](figs/pre_activative_resBlock.png "Pre-activation residual block")

In [4]:
# The generator model. Layer spec:
#   c7s1-64, d128, d256, d512, R512 x 9, u256, u128, u64, c7s1-8
def Generator(input_shape=(512,512,3), nb_classes=4, n_res_blocks=9):
    """Build the encoder / residual / decoder generator as a Keras ``Model``.

    Layer spec: c7s1-64, d128, d256, d512, R512 (x ``n_res_blocks``),
    u256, u128, u64, c7s1-8.

    Args:
        input_shape: ``(H, W, C)`` of the input image.
        nb_classes: Currently unused by the body; kept so existing callers
            keep working. NOTE(review): confirm whether it was meant to set
            the output channel count (the final conv uses 8).
        n_res_blocks: Number of ``R512`` residual blocks (the spec uses 9).

    Returns:
        A ``tf.keras.Model`` named ``'rnpa'`` mapping ``input_shape`` to an
        8-channel output.
    """
    img_input = tf.keras.layers.Input(shape=input_shape)
    # c7s1-64: 'same' padding keeps H and W (no 6-pixel crop).
    x = Conv2D(64, (7, 7), strides=(1, 1), padding='same')(img_input)
    # d128, d256, d512
    for k in (128, 256, 512):
        x = d_block(x, k)
    # R512 x n_res_blocks; R_block expects a 1-based stage index.
    for stage in range(1, n_res_blocks + 1):
        x = R_block(x, 512, stage)
    # u256, u128, u64
    for k in (256, 128, 64):
        x = u_block(x, k)
    # c7s1-8
    x = Conv2D(8, (7, 7), strides=(1, 1), padding='same')(x)
    return models.Model(img_input, x, name='rnpa')

In [2]:
# Build the generator with its default 512x512x3 input and print its layer table.
model = Generator(input_shape=(512, 512, 3), nb_classes=4)
model.summary()


'model = Generator()\nmodel.summary()'

## The Discriminator
![Discriminator architecture](figs/Dis.png)

In [21]:
def Discriminator(img_shape, target_shape=None):
    """Build and compile a conditional patch discriminator.

    The source and target images are concatenated along the channel axis and
    passed through these layers:
      - C64, C128, C256, C512: 4x4 convs, stride 2; C64 has no InstanceNorm.
      - A 4x4 stride-1 conv with 512 filters.
      - A final 4x4 conv to one channel with a sigmoid. This gives one score
        per spatial location ("patch output").

    Args:
        img_shape: ``(H, W, C)`` of the source image input.
        target_shape: Shape of the target image input. Defaults to
            ``img_shape``. Pass a different channel count when the target has
            a different number of channels than the source; for example,
            ``Generator`` ends in an 8-channel conv. H and W must match
            ``img_shape`` for the concatenation to work.

    Returns:
        ``Model([src, target], patch_out)``, compiled with binary
        cross-entropy, ``Adam(learning_rate=0.0002, beta_1=0.5)`` and
        ``loss_weights=[0.5]``.
    """
    if target_shape is None:
        target_shape = img_shape
    # Weight initialization: normal with stddev 0.02.
    init = RandomNormal(stddev=0.02)
    in_src_image = Input(shape=img_shape)
    in_target_image = Input(shape=target_shape)
    d = Concatenate()([in_src_image, in_target_image])

    # C64 (no normalization), C128, C256, C512: each halves H and W.
    for n_filters, use_norm in ((64, False), (128, True), (256, True), (512, True)):
        d = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same',
                   kernel_initializer=init)(d)
        if use_norm:
            d = InstanceNormalization()(d)
        d = LeakyReLU(alpha=0.2)(d)

    # Second-to-last layer: stride 1, spatial size unchanged.
    d = Conv2D(512, (4, 4), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # Patch output: one sigmoid score per spatial location.
    d = Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)

    model = models.Model([in_src_image, in_target_image], patch_out)
    # loss_weights=[0.5] halves the loss; presumably to slow discriminator
    # learning relative to the generator (pix2pix convention) — confirm.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model

In [22]:
# Discriminator over pairs of 512x512, 3-channel images; print its layer table.
imge_shape = (512, 512, 3)
model = Discriminator(img_shape=imge_shape)
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
input_12 (InputLayer)           [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 512, 512, 6)  0           input_11[0][0]                   
                                                                 input_12[0][0]                   
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 256, 256, 64) 6208        concatenate_4[0][0]        