In [2]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Input, Reshape,ReLU,LeakyReLU,Concatenate
from tensorflow.keras.layers import Conv2D, MaxPooling2D,Conv2DTranspose, BatchNormalization
from tensorflow_addons.layers import InstanceNormalization

## The network architecture is define as follow:


![;;"](figs/networkArch.PNG "ss")

In [3]:
# dk: 3 by 3 Convolution_InstanceNorm ReLu with k filter and stride of 2
def d_block(input_tensor, k):
    x = input_tensor
    x = Conv2D(k, (3,3), strides=(2, 2))(x)
    x = InstanceNormalization()(x)
    x = ReLU()(x)
    return x

#uk: 3x3 Transposed Convolution with k filter and InstanceNorm with Relu
def u_block(input_tensor, k):
    x = input_tensor
    x = Conv2DTranspose(k, (3,3))(x)
    x = InstanceNormalization()(x)
    x = ReLU()(x)
    return x

In [4]:
# Rk denotes a pre-activation residual block with k filters
# The structure of a pre-activation residual block showing below.
def R_block(input_tensor,n_filters,stage):
    '''if stage>1: # first activation is just after conv1        
        x = layers.BatchNormalization()(input_tensor)        
        x = layers.Activation('relu')(x)    
    else:       '''
    x = input_tensor
    x = Conv2D(n_filters,(1,1))(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(n_filters,(1,1))(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(n_filters,(1,1))(x)
    # define short cut
    x = layers.add([x, input_tensor])
    return x



## Pre-Activative Residual blocks
!["pre_activative_resBlock"](figs/pre_activative_resBlock.png "pre_activative_resBlock")

In [32]:
'''c7s1-64, d128, d256, d512, R512, R512, R512, R512, R512, R512, R512, R512, R512, u256, u128,
u64, c7s1-8.'''
# The GeneratorModel
def Generator(input_shape=(512,512,3), nb_classes=4):
    img_input = tf.keras.layers.Input(shape=input_shape)
    #c7s1-64/
    x = Conv2D(64, (7,7),strides=(1, 1))(img_input)
    #d128, d256, d512
    x = d_block(x,128)
    x = d_block(x, 256)
    x = d_block(x, 512)
    # R512, R512, R512, R512, R512, R512, R512, R512, R512
    for i in range(1,10):
        x = R_block(x, 512,i)
    #u256, u128,u64
  
    x = u_block(x, 256)
    
    x = u_block(x, 128)

    x = u_block(x, 64)
    #c7s1-8
    x = Conv2D(8, (7,7),strides=(1, 1))(x)
    x = Activation('sigmoid')(x)
    return models.Model(img_input, x, name='rnpa')

In [33]:
model = Generator()
model.summary()


Model: "rnpa"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_44 (InputLayer)           [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_494 (Conv2D)             (None, 506, 506, 64) 9472        input_44[0][0]                   
__________________________________________________________________________________________________
conv2d_495 (Conv2D)             (None, 252, 252, 128 73856       conv2d_494[0][0]                 
__________________________________________________________________________________________________
instance_normalization_130 (Ins (None, 252, 252, 128 256         conv2d_495[0][0]                 
_______________________________________________________________________________________________

## The Discriminator
![](figs/Dis.png)

In [7]:
def Discriminator(img_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_src_image = Input(shape=img_shape)
    # target image input
    in_target_image = Input(shape=img_shape)
    merged = Concatenate()([in_src_image, in_target_image])
    
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = models.Model([in_src_image, in_target_image], patch_out)
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='mse', optimizer=opt, loss_weights=[0.5])
    return model

In [27]:
imge_shape = (256,256,3)
model = Discriminator(imge_shape)
model.summary()

Model: "model_10"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_34 (InputLayer)           [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
input_35 (InputLayer)           [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
concatenate_10 (Concatenate)    (None, 256, 256, 6)  0           input_34[0][0]                   
                                                                 input_35[0][0]                   
__________________________________________________________________________________________________
conv2d_412 (Conv2D)             (None, 128, 128, 64) 6208        concatenate_10[0][0]      

In [40]:
# define the combined generator and discriminator model, for updating the generator
def define_gan(generator, discriminator,img):
    # make weights in the discriminator not trainable
    for layer in discriminator.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False
    # connect them
    img_shape = img.shape
    in_src = Input(shape=img_shape)
    # add generator
    gen_out = generator(in_src)
    # where the Generator output the fake paramaters
    # We will calculate the MAE between the fake paramaters and the real parameters that will be the loss Lp.
    # apply the fake paramaters and the real parameters to the Origin image I to generate 2 new images P_fake and P_real
    # Also, calculate the rendering loss.(No idea yet)
    # pass the Origin image I and  P_fake to the Discriminator D1. 
    # Calculate the feature loss Lf and adversarial loss La
    # pass the Origin image I and  P_real to the Discriminator D1. 
    # Calculate the feature loss Lf and adversarial loss La.
    # pass the Origin image I and  P_fake in half resolution to the Discriminator D2.
    # Calculate the feature loss Lf2 and adversarial loss La2
    # pass the Origin image I and in half resolution P_real to the Discriminator D2. 
    # Calculate the feature loss Lf2 and adversarial loss La2.
    
    # add the discriminator
    dis_out1 = discriminator([in_src, in_src])
    # compile model with L2 loss
    model.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5))
    return model

In [42]:
img_shape = (512,512,3)
dis = Discriminator(img_shape)
gen = Generator()
gan = define_gan(gen,dis,img_shape)

In [44]:
gan.summary()

Model: "rnpa"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_44 (InputLayer)           [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_494 (Conv2D)             (None, 506, 506, 64) 9472        input_44[0][0]                   
__________________________________________________________________________________________________
conv2d_495 (Conv2D)             (None, 252, 252, 128 73856       conv2d_494[0][0]                 
__________________________________________________________________________________________________
instance_normalization_130 (Ins (None, 252, 252, 128 256         conv2d_495[0][0]                 
_______________________________________________________________________________________________