In [1]:
import sys
import numpy as np
import json
import os
import pickle
import logging
import matplotlib.pyplot as plt

In [2]:
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.layers import Input,AveragePooling2D, Concatenate, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from skimage.transform import resize
from tensorflow.keras.optimizers import Adam, RMSprop,SGD
from tensorflow.keras.initializers import RandomNormal
from functools import partial

In [3]:
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import img_to_array,array_to_img
import keras
from tensorflow.keras.applications.vgg19 import VGG19

In [4]:
from scipy.linalg import sqrtm
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from skimage.transform import resize

In [5]:
# Restrict TensorFlow to physical GPU 1.
# NOTE(review): the CUDA runtime reads this only when TF first initialises its
# GPU devices, and tensorflow is already imported above; confirm no earlier
# cell has touched a device, otherwise this setting is silently ignored.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})
#strategy = tf.distribute.MirroredStrategy(devices=["/gpu:1", "/gpu:2"])

In [6]:
# gpu = tf.config.experimental.list_physical_devices('GPU')[0]
# tf.config.experimental.set_memory_growth(gpu,True)

In [7]:
# FID

In [8]:
def scale_images(images, new_shape):
    """Resize every image in ``images`` to ``new_shape``.

    Args:
        images: iterable of image arrays.
        new_shape: target shape forwarded to ``skimage.transform.resize``.

    Returns:
        numpy array stacking the resized images in input order.
    """
    # third positional argument of skimage's resize is `order`; 0 = nearest neighbour
    return np.asarray([resize(img, new_shape, 0) for img in images])

def calculate_fid(model, images1, images2):
    """Frechet Inception Distance between two image sets.

    Args:
        model: feature extractor exposing ``predict`` (the pooled InceptionV3
            created below).
        images1: generated images, presumably in [-1, 1]; they are mapped to
            [0, 255] (``x * 127.5 + 127.5``) and resized to 299x299x3 here.
        images2: reference images. NOTE(review): unlike ``images1`` these are
            NOT rescaled or resized here (those lines were disabled), so the
            caller must already supply them in [0, 255] at 299x299x3.

    Returns:
        ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2))`` computed on the
        model activations.
    """
    images1 = scale_images(((images1 * 127.5) + 127.5).astype('float32'), (299, 299, 3))
    prepped1 = preprocess_input(images1)
    prepped2 = preprocess_input(images2)
    act1 = model.predict(prepped1)
    act2 = model.predict(prepped2)

    mu1, mu2 = act1.mean(axis=0), act2.mean(axis=0)
    sigma1, sigma2 = np.cov(act1, rowvar=False), np.cov(act2, rowvar=False)
    mean_term = np.sum((mu1 - mu2) ** 2.0)

    covmean = sqrtm(sigma1.dot(sigma2))
    # sqrtm can return a complex result from numerical error; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return mean_term + np.trace(sigma1 + sigma2 - 2.0 * covmean)
# InceptionV3 as the FID feature extractor: no classification head, global
# average pooling, 299x299x3 inputs (matching calculate_fid's resizing).
model_inceptionv3 = InceptionV3(
    input_shape=(299, 299, 3),
    include_top=False,
    pooling='avg',
)

In [9]:
# Normalize layer

In [10]:
class PixelNormalization(tf.keras.layers.Layer):
    """Scale each position's feature vector to unit root-mean-square.

    Every value is divided by ``sqrt(mean(x**2 over the last axis) + 1e-8)``;
    the small constant keeps the division finite for all-zero vectors.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        rms = K.sqrt(K.mean(x ** 2, axis=-1, keepdims=True) + 1.0e-8)
        return x / rms

In [11]:
class WeightedSum(tf.keras.layers.Add):
    """Fade-in merge layer: ``(1 - alpha) * inputs[0] + alpha * inputs[1]``.

    Takes exactly two inputs. ``alpha`` is stored in a variable so the training
    loop can change it with ``K.set_value(layer.alpha, value)``.

    Args:
        alpha: initial blend factor (default 0.0, i.e. only ``inputs[0]``).
    """

    def __init__(self, alpha=0.0, **kwargs):
        super(WeightedSum, self).__init__(**kwargs)
        # Non-trainable on purpose: alpha is an externally driven schedule.
        # K.variable creates a *trainable* tf.Variable, which Keras auto-tracks
        # as a trainable weight of the layer, so the optimiser would also
        # update it through the (1 - alpha) / alpha terms.
        self.alpha = tf.Variable(alpha, trainable=False, dtype=K.floatx(), name='ws_alpha')

    def get_config(self):
        config = super().get_config().copy()
        # Serialise the current blend factor rather than a hard-coded 0.0 so
        # the config reflects the layer's actual state.
        config.update({
                'alpha' : float(K.get_value(self.alpha))
        })
        return config

    def _merge_function(self, inputs):
        """Weighted sum of exactly two inputs."""
        assert (len(inputs) == 2)
        output = ((1.0 - self.alpha) * inputs[0]) + (self.alpha * inputs[1])
        return output

In [12]:
class MinibatchStdev(tf.keras.layers.Layer):
    """Append one channel holding the average minibatch standard deviation.

    For a channels-last rank-4 input ``(batch, H, W, C)``, the standard
    deviation of every (H, W, C) position is taken across the *batch* axis,
    averaged into a single scalar, tiled to ``(batch, H, W, 1)`` and
    concatenated on the last axis, giving ``(batch, H, W, C + 1)``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        batch_mean = K.mean(inputs, axis=0, keepdims=True)
        variance = K.mean(K.square(inputs - batch_mean), axis=0, keepdims=True)
        # epsilon avoids a blow-up of sqrt (and its gradient) at zero variance
        stddev = K.sqrt(variance + 1e-8)
        avg_stddev = K.mean(stddev, keepdims=True)
        dims = K.shape(inputs)
        stddev_map = K.tile(avg_stddev, (dims[0], dims[1], dims[2], 1))
        return K.concatenate([inputs, stddev_map], axis=-1)

    def compute_output_shape(self, input_shape):
        # Same shape as the input with one extra (last-axis) channel.
        *leading, channels = input_shape
        return tuple(leading) + (channels + 1,)

In [13]:
class RandomWeightedAverage(tf.keras.layers.Layer):
    """Random per-sample interpolation between real and fake images (WGAN-GP).

    ``inputs`` is ``[[real_image, real_cond], [fake_image, fake_cond]]``; only
    the image tensors (element 0 of each pair) are mixed:
    ``alpha * real + (1 - alpha) * fake`` with one ``alpha ~ U[0, 1)`` per
    sample, broadcast over the remaining three axes.

    Args:
        batch_size: nominal batch size, kept for configuration/serialisation.
    """

    def __init__(self, batch_size, **kwargs):
        # Forward kwargs so name/trainable/dtype (e.g. from from_config when a
        # saved model is loaded) are honoured instead of silently dropped.
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, inputs, **kwargs):
        real, fake = inputs[0][0], inputs[1][0]
        # Runtime batch size: identical to self.batch_size for full batches,
        # but still broadcasts correctly for a smaller final batch.
        alpha = tf.random.uniform((tf.shape(real)[0], 1, 1, 1))
        return (alpha * real) + ((1 - alpha) * fake)

    def compute_output_shape(self, input_shape):
        # The output has the shape of the real *image*, not of the
        # [image, cond] pair that input_shape[0] describes.
        return input_shape[0][0]

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'batch_size': self.batch_size,
        })
        return config
    
def gradient_penalty(y_true, y_pred, discriminator,gradient_penalty_weight):
    """Weighted WGAN-GP gradient penalty evaluated at ``y_pred``.

    ``y_pred`` is the batch of interpolated images (rank 4, given the
    reduction over axes 1-3). ``y_true`` is unused (presumably kept for a
    Keras-loss-style signature).

    Returns:
        ``gradient_penalty_weight * mean((||d discriminator / d y_pred||_2 - 1) ** 2)``.

    NOTE(review): the discriminators built in this notebook take
    ``[image, conditional_vector]``, but this calls ``discriminator`` with a
    single tensor; confirm before re-enabling (its use is commented out in
    ``build_discriminator_model``).
    """
    with tf.GradientTape() as tape:
        tape.watch(y_pred)
        critic_out = discriminator(y_pred, training=True)

    grad = tape.gradient(critic_out, [y_pred])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]))
    penalty = tf.reduce_mean((slopes - 1.0) ** 2)
    return penalty * gradient_penalty_weight

In [14]:
class GradientPenalty(tf.keras.layers.Layer):
    """Per-sample gradient-norm deviation ``||d target / d wrt||_2 - 1``.

    ``inputs`` is ``(target, wrt)``; the output has shape ``(batch, 1)``.
    Fitting this output to zeros with an MSE loss yields the WGAN-GP penalty
    ``mean((||grad|| - 1) ** 2)``.

    NOTE(review): ``K.gradients`` only works in graph mode (e.g. inside a
    compiled train step); running this layer eagerly will fail.
    """

    def call(self, inputs):
        (target, wrt) = inputs
        grad = K.gradients(target, wrt)[0]
        sq_norm = K.sum(K.batch_flatten(K.square(grad)), axis=1, keepdims=True)
        # d sqrt(x)/dx is infinite at x = 0; without the epsilon a zero gradient
        # would turn into NaN weight updates during backprop.
        return K.sqrt(sq_norm + 1e-12) - 1

    def compute_output_shape(self, input_shapes):
        return (input_shapes[1][0], 1)

In [15]:
# Build Perceptual loss

In [16]:
from tensorflow.keras.applications.vgg19 import VGG19
class Cut_VGG19:
    """
    Keras VGG19 (ImageNet weights) with selected layers exposed as outputs.

    Used as a frozen feature extractor for the perceptual loss.

    Args:
        patch_size: integer; the model input is (patch_size, patch_size, 3).
        layers_to_extract: list of integer indices into ``VGG19(...).layers``
            whose outputs become the model outputs.

    Attributes:
        model: multi-output Keras model with the selected layers as outputs,
            or None when ``layers_to_extract`` is empty.
    """

    def __init__(self, patch_size, layers_to_extract):
        self.patch_size = patch_size
        self.input_shape = (patch_size,) * 2 + (3,)
        self.layers_to_extract = layers_to_extract
        # Always define the attribute so callers get None rather than an
        # AttributeError when no layers are requested.
        self.model = None

        if len(self.layers_to_extract) > 0:
            self._cut_vgg()

    def _cut_vgg(self):
        """
        Load pre-trained VGG19, freeze it, and build ``self.model`` whose
        outputs are the layers at ``self.layers_to_extract``.
        """
        vgg = VGG19(weights='imagenet', include_top=False, input_shape=self.input_shape)
        vgg.trainable = False
        outputs = [vgg.layers[i].output for i in self.layers_to_extract]
        self.model = Model([vgg.input], outputs)
# Frozen VGG19 feature extractor for 128x128 inputs (outputs of layers 5 and 9),
# read as a global by WGANGP.get_perceptual_loss.
feature_extraction = Cut_VGG19(patch_size=128, layers_to_extract=[5, 9])
feature_extraction.model.trainable = False

In [17]:
def update_fadein(models, step, n_steps):
    """Set the blend factor of every WeightedSum layer in ``models``.

    alpha ramps linearly from 0 at ``step == 0`` to 1 at ``step == n_steps - 1``.

    Args:
        models: iterable of Keras models; each model's ``layers`` are scanned.
        step: current step index.
        n_steps: total number of fade-in steps.
    """
    # With a single step there is nothing to interpolate: treat it as the final
    # step (alpha = 1) instead of dividing by zero.
    alpha = step / float(n_steps - 1) if n_steps > 1 else 1.0
    for model in models:
        for layer in model.layers:
            if isinstance(layer, WeightedSum):
                K.set_value(layer.alpha, alpha)

In [18]:
class Sampling(tf.keras.layers.Layer):
    """Reparameterisation-trick sampler (registered as a custom object when the
    encoder/decoder models are loaded).

    Given ``inputs = (z_mean, z_log_var)``, returns
    ``z_mean + exp(0.5 * z_log_var) * eps`` where ``eps`` is standard normal
    noise of shape ``(batch, dim)`` taken from ``tf.shape(z_mean)``.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        mean_shape = tf.shape(z_mean)
        eps = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        std = tf.exp(0.5 * z_log_var)
        return z_mean + std * eps

In [19]:
class WGANGP():
    def __init__(self, pre_train=False, **kwargs):
        """Configure the conditional progressive WGAN-GP and build every model.

        Args:
            pre_train: if True, reload the saved ``[normal, fadein]`` generator
                and discriminator pairs for every resolution from
                ``./model_save/model_condproGAN``; otherwise build them from
                scratch (RandomNormal(0, 0.02) initialisation).
            **kwargs: accepted but currently unused.

        Side effects: loads the encoder/decoder models from
        ``../VAE_model/model_VAE/``, builds and compiles the critic-training
        and composite models, and configures root logging to append to
        ``./logs/cond_pro_gan.log`` (creating ``./logs`` if necessary).
        """
        self.input_dim = (256,256,3)
        self.optimiser = 'rmsprop'
        self.z_dim = 256
        # one batch size per resolution level; its length also fixes the number
        # of generator/discriminator blocks (4, 8, ..., 128 px)
        self.n_batch = [128,64,32,16,8,4]
        self.pre_train = pre_train
        self.pre_fid = 500
        ################ Encoder / decoder models ########################
        self.encoder_model = tf.keras.models.load_model(
            "../VAE_model/model_VAE/encoder_model_256_2layer.h5",
            custom_objects={"Sampling": Sampling})
        self.decoder_model = tf.keras.models.load_model(
            "../VAE_model/model_VAE/decoder_model_256_2layer.h5",
            custom_objects={"Sampling": Sampling})
        ############## Conditional branch (discriminator) #################
        self.conditional_input = (self.z_dim,)
        self.conditional_initial_dense_layer_size = (4,4,32)
        self.conditional_batch_norm_momentum =  None
        self.conditional_activation = 'leaky_relu'

        ################ Generator Model #########################
        self.generator_initial_dense_layer_size = (4,4,256)
        self.generator_upsample = [1,1,1]
        self.generator_conv_filters = [128,128,3]
        self.generator_conv_kernel_size = [4,3,1]
        self.generator_conv_strides = [2,2,2]
        self.generator_batch_norm_momentum =  None
        self.generator_activation = 'leaky_relu'
        self.generator_dropout_rate = None
        self.generator_learning_rate = 2e-3
        self.gen_n_blocks = len(self.n_batch)

        ################ Discriminator Model ###########################
        self.discriminator_input_shape = (4,4,3)
        self.discriminator_conv_filters = [128,128,128]
        self.discriminator_conv_kernel_size = [1,3,4]
        self.discriminator_conv_strides = [1,1,1]
        self.discriminator_batch_norm_momentum = None
        self.discriminator_activation = 'leaky_relu'
        self.discriminator_dropout_rate = None
        self.discriminator_learning_rate = 2e-3
        self.disc_n_blocks = len(self.n_batch)
        ############################################
        self.n_layers_discriminator = len(self.discriminator_conv_filters)
        self.n_layers_generator = len(self.generator_conv_filters)

        ###############################################
        self.d_losses = []
        self.g_losses = []
        self.epoch = 2010
        ###############################################
        if self.pre_train:
            self.weight_init = None
            # presumably the WGAN-GP gradient-penalty weight (lambda)
            self.grad_weight = 10
            self.const = tf.keras.constraints.max_norm(1.0)
            self.g_models, self.d_models = self._load_pretrained_models()
        else:
            self.weight_init = RandomNormal(mean=0., stddev=0.02)
            self.grad_weight = 10
            self.const = tf.keras.constraints.max_norm(1.0)
            self.g_models = self._build_generator()
            self.d_models = self._build_discriminator()

        self._build_adversarial()
        LOG_DIR = "./logs/cond_pro_gan.log"
        # basicConfig opens the log file immediately and raises
        # FileNotFoundError when ./logs does not exist, so create it first.
        os.makedirs(os.path.dirname(LOG_DIR), exist_ok=True)
        logging.basicConfig(filename=LOG_DIR,
                    level=logging.DEBUG,
                    format="[%(asctime)s] [%(name)s] [%(message)s]",
                    filemode="a")
        logging.getLogger("tensorflow").setLevel(logging.ERROR)
        logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR)

    def _load_pretrained_models(self, model_dir="./model_save/model_condproGAN"):
        """Load the saved ``[normal, fadein]`` generator and discriminator pairs.

        Level ``i`` (one per entry of ``self.n_batch``) is read from
        ``<model_dir>/{g,d}_{normal,fadein}_<4 * 2**i>.h5``.

        Returns:
            ``(g_models, d_models)``: lists with one ``[normal, fadein]`` pair per level.
        """
        g_custom = {'PixelNormalization': PixelNormalization,
                    'MinibatchStdev': MinibatchStdev,
                    'WeightedSum': WeightedSum}
        d_custom = {'AveragePooling2D': AveragePooling2D,
                    'MinibatchStdev': MinibatchStdev,
                    'RandomWeightedAverage': RandomWeightedAverage,
                    'WeightedSum': WeightedSum,
                    'wasserstein': self.wasserstein}

        def load(kind, size, custom):
            path = model_dir + "/" + kind + "_" + str(size) + ".h5"
            return tf.keras.models.load_model(path, custom_objects=custom)

        g_models, d_models = [], []
        for i in range(0, len(self.n_batch)):
            output_shape = 4*np.power(2,i)
            print(output_shape)
            g_normal = load("g_normal", output_shape, g_custom)
            g_fadein = load("g_fadein", output_shape, g_custom)
            d_normal = load("d_normal", output_shape, d_custom)
            d_fadein = load("d_fadein", output_shape, d_custom)
            g_models.append([g_normal, g_fadein])
            d_models.append([d_normal, d_fadein])
        return g_models, d_models
    ####################### Loss ###########################
    def wasserstein(self, y_true, y_pred):
        """Wasserstein loss ``-mean(y_true * y_pred)``.

        With +1 / -1 targets (as used in ``train_discriminator``) this pushes the
        scores of +1-labelled samples up and of -1-labelled samples down.
        """
        signed_scores = y_true * y_pred
        return -K.mean(signed_scores)
    
    def get_perceptual_loss(self,y_true,y_pred):
        """VGG perceptual loss plus 100x pixel-wise MSE.

        Features come from the module-level ``feature_extraction`` model. For
        extracted layer ``i`` the mean squared feature difference is weighted
        by ``[1/16, 1/8][i]``.

        NOTE(review): only two weights are defined, so extracting more than two
        layers would fail when indexing ``layer_weights``.
        """
        target_features = feature_extraction.model(y_true)
        predicted_features = feature_extraction.model(y_pred)
        layer_weights = tf.constant([1/16,1/8], dtype = tf.float32)

        feature_term = 0
        for idx in range(len(predicted_features)):
            diff = predicted_features[idx] - target_features[idx]
            feature_term += layer_weights[idx] * K.mean(K.square(diff))

        pixel_term = tf.reduce_mean(tf.keras.losses.mean_squared_error(y_true, y_pred))
        return feature_term + 100 * pixel_term
    
    def dummy_loss_function(self,y_true, y_pred):
        return y_pred
    
    
    ################# Activation layer #####################                                                                
    def get_activation(self, activation):
        """Return a new activation layer for ``activation``.

        ``'leaky_relu'`` maps to ``LeakyReLU(alpha=0.2)``; any other value is
        forwarded to ``Activation``.
        """
        if activation != 'leaky_relu':
            return Activation(activation)
        return LeakyReLU(alpha=0.2)
    
    
    ####################################################################
    #################### Build Generator Model #########################
    ####################################################################
    
    def _build_generator(self):
        """Build the base generator and grow it through every resolution.

        The base model maps a ``(z_dim,)`` vector through a Dense layer,
        reshapes to ``generator_initial_dense_layer_size`` and applies
        ``n_layers_generator`` same-padded convolutions; every conv except the
        last is followed by PixelNormalization (+ optional BatchNorm) and the
        configured activation.

        Returns:
            list of ``[normal_model, fadein_model]`` pairs, one per block
            (``self.gen_n_blocks``); the first pair holds the base model twice,
            later pairs come from ``add_generator_block``.
        """
        #with strategy.scope():
        z_in = Input(shape=(self.z_dim,), name='generator_input')
        h = Dense(np.prod(self.generator_initial_dense_layer_size),
                  kernel_initializer=self.weight_init)(z_in)
        if self.generator_batch_norm_momentum:
            h = BatchNormalization(momentum=self.generator_batch_norm_momentum)(h)
        h = self.get_activation(self.generator_activation)(h)
        h = Reshape(self.generator_initial_dense_layer_size)(h)
        if self.generator_dropout_rate:
            h = Dropout(rate=self.generator_dropout_rate)(h)

        last_idx = self.n_layers_generator - 1
        for idx in range(self.n_layers_generator):
            k = self.generator_conv_kernel_size[idx]
            h = Conv2D(filters=self.generator_conv_filters[idx],
                       kernel_size=(k, k),
                       padding='same',
                       kernel_initializer=self.weight_init,
                       kernel_constraint=self.const)(h)
            if idx == last_idx:
                # output conv: no normalisation or activation
                continue
            h = PixelNormalization()(h)
            if self.generator_batch_norm_momentum:
                h = BatchNormalization(momentum=self.generator_batch_norm_momentum)(h)
            h = self.get_activation(self.generator_activation)(h)

        base = Model(z_in, h, name='Generator')

        # Level 0 has no fade-in variant, so the base model fills both slots.
        pairs = [[base, base]]
        for level in range(1, self.gen_n_blocks):
            # grow from the previous level's straight-through model
            pairs.append(self.add_generator_block(pairs[level - 1][0], level))
        return pairs
    #################### Add Generator Block ################
    def add_generator_block(self,old_model, model_indx):
        """Grow generator ``old_model`` by one 2x-upsampling block.

        ``old_model.layers[-2]`` is taken as the end of the last feature block
        and ``old_model.layers[-1]`` as its image-output conv.

        Returns:
            ``[normal, fadein]``: ``normal`` ends in a new 1x1 conv to 3
            channels; ``fadein`` blends the old output conv applied to the
            upsampled features with the new output via ``WeightedSum``.
        """
        init = tf.keras.initializers.RandomNormal(stddev=0.02)
        const = tf.keras.constraints.max_norm(1.0)

        #with strategy.scope():
        upsampled = UpSampling2D()(old_model.layers[-2].output)

        features = upsampled
        for _ in range(2):
            features = Conv2D(128, (3,3), padding='same',
                              kernel_initializer=init, kernel_constraint=const)(features)
            features = PixelNormalization()(features)
            features = LeakyReLU(alpha=0.2)(features)
        new_rgb = Conv2D(3, (1,1), padding='same',
                         kernel_initializer=init, kernel_constraint=const)(features)
        straight = Model(old_model.input, new_rgb, name="generator_normal_" + str(model_indx))

        # Re-use (share) the old output conv on the upsampled features.
        old_rgb = old_model.layers[-1](upsampled)
        blended = WeightedSum()([old_rgb, new_rgb])
        fadein = Model(old_model.input, blended, name="generator_fadein_" + str(model_indx))
        return [straight, fadein]
        
    ####################################################################
    #################### Build Discriminator Model #####################
    ####################################################################
    def _build_discriminator(self):
        """Build the base conditional critic and grow it through every resolution.

        The base model takes ``[image, conditional_vector]`` with image shape
        ``self.discriminator_input_shape`` and vector shape
        ``self.conditional_input``. The vector goes through a Dense layer, is
        reshaped to ``self.conditional_initial_dense_layer_size`` and is
        concatenated with the image before the conv stack.

        Returns:
            list of ``[normal_model, fadein_model]`` pairs, one per block
            (``self.disc_n_blocks``); the first pair holds the base model twice,
            later pairs come from ``add_discriminator_block``.
        """
        #with strategy.scope():
        discriminator_input = Input(shape=self.discriminator_input_shape, name='discriminator_input')
        ############## Conditional Layer ####################
        conditional_input = Input(shape= self.conditional_input, name='conditional_input')
        x_cond = conditional_input
        x_cond = Dense(np.prod(self.conditional_initial_dense_layer_size), kernel_initializer = self.weight_init)(x_cond)
        if self.conditional_batch_norm_momentum:
            x_cond = BatchNormalization(momentum = self.conditional_batch_norm_momentum)(x_cond)       
        x_cond = self.get_activation(self.conditional_activation)(x_cond)
        x_cond = Reshape(self.conditional_initial_dense_layer_size)(x_cond)
       
        # Image and conditional feature map are stacked channel-wise
        # (Concatenate defaults to the last axis).
        # NOTE(review): add_discriminator_block re-uses this model's layers by
        # position (layers[6], layers[7] and layers[8:]); adding, removing or
        # reordering layers here changes what those indices select.
        x= Concatenate()([discriminator_input,x_cond])
        #x = discriminator_input

        for i in range(self.n_layers_discriminator):
            x = Conv2D(
                filters = self.discriminator_conv_filters[i]
                , kernel_size = (self.discriminator_conv_kernel_size[i],self.discriminator_conv_kernel_size[i])
                , strides = self.discriminator_conv_strides[i]
                , padding = 'same'
                , kernel_initializer = self.weight_init
                , kernel_constraint= self.const
                )(x)

            # optional BatchNorm is never applied after the first conv
            if self.discriminator_batch_norm_momentum and i > 0:
                x = BatchNormalization(momentum = self.discriminator_batch_norm_momentum)(x)
            x = self.get_activation(self.discriminator_activation)(x)
            if self.discriminator_dropout_rate:
                x = Dropout(rate = self.discriminator_dropout_rate)(x)
            ############### Concatenate ##########################
#             if i == 0:
#                 x = Concatenate()([x,x_cond])
            # minibatch-stddev channel after every conv block except the last
            if i <self.n_layers_discriminator-1:
                x = MinibatchStdev()(x)

        x = Flatten()(x)

        # single linear output: an unbounded critic score
        discriminator_output = Dense(1, activation=None
        , kernel_initializer = self.weight_init
        )(x)

        discriminator_model = Model([discriminator_input,conditional_input], discriminator_output,name="Discriminator")

        model_list = list()

        # level 0 has no fade-in variant, so the base model fills both slots
        model_list.append([discriminator_model, discriminator_model])
        # create submodels
        for i in range(1, self.disc_n_blocks):
            # get prior model without the fade-on
            old_model = model_list[i - 1][0]

            # create new model for next resolution
            models = self.add_discriminator_block(old_model, i)
            model_list.append(models)
        return model_list
    #################### Add Discriminator Block ################
    def add_discriminator_block(self,old_model,model_indx, n_input_layers=8):
        """Grow critic ``old_model`` to accept images of twice its resolution.

        Args:
            old_model: previous level's straight-through critic with inputs
                ``[image, conditional_vector]``.
            model_indx: resolution level ``i``; the new conditional branch is
                reshaped to ``conditional_initial_dense_layer_size`` with its
                first two dimensions scaled by ``2**i``.
            n_input_layers: number of leading ``old_model.layers`` skipped when
                re-using the old model's tail.

        Returns:
            ``[normal, fadein]``:
              - ``normal``: new input block (1x1 conv, two 3x3 convs, average
                pooling) followed by ``old_model.layers[n_input_layers:]``;
                compiled here with the Wasserstein loss and RMSprop.
              - ``fadein``: the concatenated [image, cond] input is average-pooled,
                passed through ``old_model.layers[6]`` and ``[7]``, blended with
                the new block via ``WeightedSum``, then through the same tail.
                Not compiled here.

        NOTE(review): the indices 6, 7 and ``n_input_layers=8`` depend on the
        exact layer ordering Keras assigns to ``old_model`` (presumably the old
        input 1x1 conv and its activation); verify with ``old_model.summary()``
        if any builder changes.
        """
        # weight initialization
        init = tf.keras.initializers.RandomNormal(stddev=0.02)
        # weight constraint
        const = tf.keras.constraints.max_norm(1.0)
        # get shape of existing model
        #with strategy.scope():
        in_shape = list(old_model.input[0].shape)

        # define new input shape as double the size
        # (in_shape[-2] is used for both height and width: square images assumed)
        input_shape = (in_shape[-2]*2, in_shape[-2]*2, in_shape[-1])
        # define new conditional layer
        conditional_input = Input(shape= self.conditional_input)
        x_cond = conditional_input
        cond_initial_dense = (self.conditional_initial_dense_layer_size[0]*np.power(2,model_indx), \
                              self.conditional_initial_dense_layer_size[1]*np.power(2,model_indx), \
                              self.conditional_initial_dense_layer_size[2])

        x_cond = Dense(np.prod(cond_initial_dense), kernel_initializer = init)(x_cond)
        if self.conditional_batch_norm_momentum:
            x_cond = BatchNormalization(momentum = self.conditional_batch_norm_momentum)(x_cond)
        x_cond = self.get_activation(self.conditional_activation)(x_cond)
        x_cond = Reshape(cond_initial_dense)(x_cond)

        in_image = Input(shape=input_shape)
        ############ Concatenate
        d = Concatenate()([in_image,x_cond])

        # define new input processing layer
        d = Conv2D(128, (1,1), padding='same', kernel_initializer=init, kernel_constraint=const)(d)
        d = LeakyReLU(alpha=0.2)(d)
        # define new block
        d = Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(d)
        d = LeakyReLU(alpha=0.2)(d)
        d = Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(d)
        d = LeakyReLU(alpha=0.2)(d)
        d = AveragePooling2D()(d)

        #d = RoiPoolingConv(2)(d)
        block_new = d
        # skip the input, 1x1 and activation for the old model
        for i in range(n_input_layers, len(old_model.layers)):
            d = old_model.layers[i](d)
        # define straight-through model

        model1 = Model([in_image,conditional_input], d, name = "discriminator_normal_"+str(model_indx))
        # compile model (`learning_rate` is the canonical keyword; the `lr`
        # alias is deprecated in tf.keras and rejected by newer Keras)
        model1.compile(loss=self.wasserstein, optimizer=RMSprop(learning_rate=self.discriminator_learning_rate))
        ###################################################################

        d = Concatenate()([in_image,x_cond])
        downsample = AveragePooling2D()(d)
        # connect old input processing to downsampled new input

        block_old = old_model.layers[6](downsample)
        block_old = old_model.layers[7](block_old)

        # fade in output of old model input layer with new input
        d = WeightedSum()([block_old, block_new])

        # skip the input, 1x1 and activation for the old model
        for i in range(n_input_layers, len(old_model.layers)):
            d = old_model.layers[i](d)
        # define straight-through model
        model2 = Model([in_image,conditional_input], d, name= "discriminator_fadein_"+str(model_indx))
        return [model1, model2]
   

    #################### Optimize #########################

    def get_opti(self, lr):
        """Return a fresh optimiser chosen by ``self.optimiser``.

        ``'adam'`` -> Adam(beta_1=0.5); ``'rmsprop'`` -> RMSprop; anything
        else -> Adam with default betas. All use learning rate ``lr``.
        """
        # `learning_rate` is the canonical keyword; the `lr` alias is deprecated
        # in tf.keras and rejected by newer Keras optimizers.
        if self.optimiser == 'adam':
            return Adam(learning_rate=lr, beta_1=0.5)
        if self.optimiser == 'rmsprop':
            return RMSprop(learning_rate=lr)
        return Adam(learning_rate=lr)


    def set_trainable(self, m, val):
        m.trainable = val
        for l in m.layers:
            l.trainable = val
            
    def update_fadein(self,models, step, n_steps):
        # calculate current alpha (linear from 0 to 1)
        alpha = step / float(n_steps - 1)
        # update the alpha for each model
        for model in models:
            for layer in model.layers:
                if isinstance(layer, WeightedSum):
                    K.set_value(layer.alpha, alpha)
    ####################################################################
    #################### Build Adversarial Model #########################
    ####################################################################
    def build_discriminator_model(self,discriminator, batch_size):
        """Wrap ``discriminator`` in a compiled WGAN-GP critic-training model.

        Args:
            discriminator: critic with inputs ``[image, conditional_vector]``.
            batch_size: batch size handed to ``RandomWeightedAverage``.

        Returns:
            Model with inputs ``[[real_image, real_cond], [fake_image, fake_cond]]``
            and outputs ``[D(real), D(fake), gp]``, where ``gp`` is the per-sample
            ``||dD/d(interpolated image)||_2 - 1`` from ``GradientPenalty``.
            Compiled with the Wasserstein loss on the first two outputs and MSE
            on ``gp`` (fitted to zeros, i.e. ``mean((||g|| - 1)^2)``), the latter
            weighted by ``self.grad_weight``.
        """
        #with strategy.scope():
        real_input = discriminator.input[0]
        real_conditional_input = discriminator.input[1]

        shape = real_input.shape[1:]
        shape_conditional = real_conditional_input.shape[1]
        fake_input = Input(shape = shape)
        fake_conditional_input = Input(shape=(shape_conditional,))

        discriminator_output_from_generator = discriminator([fake_input,fake_conditional_input])
        discriminator_output_from_real_samples = discriminator([real_input,real_conditional_input])

        averaged_samples = RandomWeightedAverage(batch_size = batch_size)([[real_input,real_conditional_input], \
                                                [fake_input,fake_conditional_input]])

        # the interpolated image is scored together with the *real* conditional vector
        validity_interpolated = discriminator([averaged_samples,real_conditional_input])

        gp = GradientPenalty()([validity_interpolated, averaged_samples])

        discriminator_model = Model(inputs = [[real_input,real_conditional_input], \
                                              [fake_input,fake_conditional_input]], \
                      outputs = [discriminator_output_from_real_samples,discriminator_output_from_generator,gp])

        # The penalty used to be added with weight 1 even though
        # self.grad_weight (10, the usual WGAN-GP lambda) is configured;
        # apply it through loss_weights.
        discriminator_model.compile(optimizer=Adam(self.discriminator_learning_rate, beta_1=0.5, beta_2=0.9),
                                loss=[self.wasserstein,
                                      self.wasserstein,
                                        "mse"],
                                loss_weights=[1.0, 1.0, self.grad_weight])
        return discriminator_model
    
    def define_composite(self,discriminators, generators):
        """Build the composite generator-training models, one pair per level.

        For every level ``i`` the normal and fade-in generators are each stacked
        on the matching discriminator (fed ``[generated_image, latent_vector]``).
        The discriminator is frozen while the composite is compiled, since Keras
        captures ``trainable`` at compile time. It is unfrozen afterwards so its
        own training model stays trainable.

        At the last level the composite also outputs the generated image and
        adds ``get_perceptual_loss`` on it; other levels use only the
        Wasserstein loss.

        Returns:
            list of ``[normal_composite, fadein_composite]`` pairs.
        """
        model_list = []
        last_level = len(discriminators) - 1
        #with strategy.scope():
        for i in range(len(discriminators)):
            model_list.append([
                self._compose_gan(discriminators[i][j], generators[i][j], i == last_level)
                for j in (0, 1)
            ])
        return model_list

    def _compose_gan(self, discriminator, generator, with_perceptual):
        """Stack ``generator`` on a frozen ``discriminator`` and compile the result."""
        discriminator.trainable = False
        latent = generator.input
        generated = generator.output
        score = discriminator([generated, latent])
        # `learning_rate` is the canonical keyword; `lr` is a deprecated alias.
        optimizer = RMSprop(learning_rate=self.generator_learning_rate)
        if with_perceptual:
            model = Model(latent, [score, generated])
            model.compile(loss=[self.wasserstein, self.get_perceptual_loss], optimizer=optimizer)
        else:
            model = Model(latent, score)
            model.compile(loss=self.wasserstein, optimizer=optimizer)
        discriminator.trainable = True
        return model
    ####################################################################
    #################### Build Adversarial Model #########################
    ####################################################################
    def _build_adversarial(self):
        """Wrap every (normal, fade-in) discriminator pair and build the composite models.

        For each resolution level ``idx``, both discriminators in ``self.d_models[idx]``
        are passed to ``build_discriminator_model`` together with that level's batch size
        (``int(self.n_batch[idx])``). The resulting ``[normal, fadein]`` pairs are stored
        in ``self.disc_models``. Afterwards ``self.gan_models`` is set to the output of
        ``define_composite(self.d_models, self.g_models)``.
        """
        self.disc_models = []
        # Discriminator training wrappers, one pair per resolution level.
        for idx, d_pair in enumerate(self.d_models):
            level_batch = int(self.n_batch[idx])
            wrapped_normal = self.build_discriminator_model(d_pair[0], level_batch)
            wrapped_fadein = self.build_discriminator_model(d_pair[1], level_batch)
            self.disc_models.append([wrapped_normal, wrapped_fadein])

        # Generator -> frozen discriminator composites used for generator updates.
        self.gan_models = self.define_composite(self.d_models, self.g_models)
    
    def train_discriminator(self,g_model,d_model, x_train,Y_train, batch_size):
        """Run one discriminator update on a single batch.

        The conditioning batch ``x_train`` is encoded with ``self.encoder_model`` (its
        third prediction output is used as the latent code). ``g_model`` decodes that
        code into fake images. ``d_model`` is then trained on the inputs
        ``[[real, code], [fake, code]]`` with targets ``[valid, fake, dummy]``.

        Args:
            g_model: generator for the current resolution level.
            d_model: compiled discriminator-training wrapper taking the nested inputs
                above and three targets.
            x_train: batch of conditioning inputs.
            Y_train: batch of real target images at this level's resolution.
            batch_size: number of samples in the batch. It is cast to ``int``, so numpy
                or float batch sizes are accepted.

        Returns:
            Whatever ``d_model.train_on_batch`` returns (the loss values).
        """
        n = int(batch_size)
        # Wasserstein-style labels: +1 for real samples, -1 for generated ones.
        valid = np.ones((n, 1), dtype=np.float32)
        fake = -np.ones((n, 1), dtype=np.float32)
        # Placeholder target for the gradient-penalty output. Presumably its loss
        # ignores y_true; confirm in build_discriminator_model.
        dummy = np.zeros((n, 1), dtype=np.float32)

        latent_code = self.encoder_model.predict(x_train)[2]
        gen_imgs = g_model.predict(latent_code)

        # Both real and fake images are paired with the same latent code (the condition).
        return d_model.train_on_batch(
            [[Y_train, latent_code], [gen_imgs, latent_code]],
            [valid, fake, dummy],
        )

    def train_generator(self,gan_model,x_train,Y_train, batch_size):
        """Run one generator update through a composite (generator -> discriminator) model.

        ``x_train`` is encoded with ``self.encoder_model`` (third prediction output) and
        the latent code is fed to ``gan_model``. Composite models with several outputs
        (discriminator score plus generated image, used for the perceptual loss) are
        trained against ``[valid, Y_train]``. Single-output models are trained against
        ``valid`` only.

        Args:
            gan_model: compiled composite model for the current resolution level.
            x_train: batch of conditioning inputs.
            Y_train: batch of real target images, used only by multi-output composites.
            batch_size: number of samples in the batch (cast to ``int``).

        Returns:
            Whatever ``gan_model.train_on_batch`` returns.
        """
        valid = np.ones((int(batch_size), 1), dtype=np.float32)
        latent_code = self.encoder_model.predict(x_train)[2]

        # isinstance (not type(...) == list) also accepts list subclasses and tuples
        # that Keras may use for multi-output models.
        if isinstance(gan_model.output, (list, tuple)):
            return gan_model.train_on_batch(latent_code, [valid, Y_train])
        return gan_model.train_on_batch(latent_code, valid)
    
    ########################################################
    ########################################################
    def scaled_data(self,data,input_shape):
        """Resize each image in ``data`` to ``input_shape`` and map pixels via (x - 127.5) / 127.5.

        Args:
            data: iterable of image arrays accepted by ``array_to_img``.
            input_shape: target ``(height, width, channels)``. Only height and width
                are used for resizing.

        Returns:
            float32 array stacking the resized, rescaled images.

        NOTE(review): ``array_to_img`` rescales each image to [0, 255] by default
        (``scale=True``), so every image is min-max stretched individually before the
        mapping to [-1, 1]. Confirm that per-image normalisation is intended.
        """
        # PIL's Image.resize expects (width, height), whereas input_shape is
        # (height, width, channels). The order matters for non-square targets.
        target_size = (int(input_shape[1]), int(input_shape[0]))
        resized = [img_to_array(array_to_img(img).resize(size=target_size)) for img in data]
        scaled = (np.asarray(resized) - 127.5) / 127.5
        return scaled.astype('float32')
    
    def shuffle_data_batch(self,array_X,array_Y,batch_size):
        """Shuffle two aligned arrays with one shared permutation from the global NumPy RNG.

        Row k of both results comes from the same source row, so (X, Y) pairs stay
        matched. ``batch_size`` is accepted for call-site symmetry but not used.

        Returns:
            ``(shuffled_X, shuffled_Y)``.
        """
        order = np.random.permutation(array_X.shape[0])
        return array_X[order], array_Y[order]
    
    def shuffle_data_batch_with_latent(self,array_latent,array_X,array_Y,batch_size):
        """Shuffle latent codes, inputs and targets with one shared random permutation.

        The permutation (global NumPy RNG) has length ``array_X.shape[0]`` and is applied
        to all three arrays, keeping their rows aligned. ``batch_size`` is unused.

        Returns:
            ``(shuffled_latent, shuffled_X, shuffled_Y)``.
        """
        perm = np.random.permutation(array_X.shape[0])
        return tuple(arr[perm] for arr in (array_latent, array_X, array_Y))
    
    def split_into_chunks(self,l, n):
        """Lazily yield consecutive row slices of ``l`` containing ``n`` rows each.

        The last slice holds the remainder and may be shorter than ``n``. Nothing is
        evaluated until the first item is requested.
        """
        yield from (l[start:start + n] for start in range(0, l.shape[0], n))
    
    def scale_all_shape(self,data):
        """Return one rescaled copy of ``data`` per resolution level.

        There are ``len(self.n_batch)`` levels. Level ``lvl`` is square with side
        ``discriminator_input_shape[0] * 2**lvl`` and 3 channels. Each copy is produced
        by ``self.scaled_data``.
        """
        base = self.discriminator_input_shape[0]
        level_shapes = [
            (base * np.power(2, lvl),) * 2 + (3,)
            for lvl in range(len(self.n_batch))
        ]
        return [self.scaled_data(data, shape) for shape in level_shapes]
    
    def train_epochs(self,g_model, d_model, gan_model, x_train, Y_train,n_critic=5,fadein=False):
        """Train one resolution level for up to ``self.n_steps`` steps.

        Each step does the following:
        - If ``fadein`` is true, calls ``update_fadein([g_model, d_model, gan_model],
          step, self.n_steps)``. That function is defined outside this chunk and
          presumably advances the fade-in blend.
        - Pulls the next batch from the ``x_train`` / ``Y_train`` iterators.
        - Runs ``n_critic`` discriminator updates, then one generator update.
        - Appends the losses to ``self.d_losses`` / ``self.g_losses``.

        Training stops early, without raising, if either iterator is exhausted.

        Returns:
            ``(d_loss, g_loss)`` from the last completed step, or ``None`` for a loss
            that was never computed (no step ran, or ``n_critic`` is 0).
        """
        d_loss, g_loss = None, None
        for step in range(self.n_steps):
            if fadein:
                update_fadein([g_model, d_model, gan_model], step, self.n_steps)
            x = next(x_train, None)
            Y = next(Y_train, None)
            # The batch iterators may run dry (e.g. if shared with a previous pass);
            # stop instead of letting StopIteration escape.
            if x is None or Y is None:
                break
            for _ in range(n_critic):
                d_loss = self.train_discriminator(g_model, d_model, x, Y, self.batch_size)
            g_loss = self.train_generator(gan_model, x, Y, self.batch_size)
            # Loss history for plotting progress.
            self.d_losses.append(d_loss)
            self.g_losses.append(g_loss)
        # Return only after the loop so every step of the pass is trained.
        return d_loss, g_loss
    def train(self,x_train,Y_train,x_val,Y_val, batch_size, epochs, run_folder, print_every_n_batches = 10, n_critic=2):
        """Progressively train every resolution level for ``epochs`` epochs.

        For each epoch and each level i, the training pairs are reshuffled. The fade-in
        models and then the normal models of that level are trained via ``train_epochs``.
        Each pass gets its own batch iterators (batch size ``self.n_batch[i]``).

        Every ``print_every_n_batches`` epochs the following happens:
        - The highest-resolution generators are sampled via ``sample_images``.
        - Losses and FID values are logged and printed.
        - If the normal-model FID improved, all generators and raw discriminators are
          saved under ``./model_save/model_condproGAN``.

        Args:
            x_train, Y_train: conditioning inputs and target images for training.
            x_val, Y_val: validation inputs/targets passed to ``sample_images``.
            batch_size: NOTE(review): unused; batch sizes come from ``self.n_batch``.
            epochs: number of epochs to run, starting from ``self.epoch``.
            run_folder: folder passed to ``sample_images`` for the sample grids.
            print_every_n_batches: sampling/logging interval, in epochs.
            n_critic: discriminator updates per generator update.
        """
        # NOTE(review): 299x299 copy of every training target, meant for FID. FID looks
        # disabled in this notebook and this copy is very memory-heavy, so confirm it
        # is needed.
        self.data_fid = ((Y_train*127.5)+127.5).astype('float32')
        self.data_fid = scale_images(self.data_fid, (299,299,3))
        # One resized copy of the targets per resolution level.
        Y_train_reshape = self.scale_all_shape(Y_train)

        # Bound up front so logging cannot hit a NameError (e.g. no resolution levels).
        d_loss = g_loss = d_loss_fadein = g_loss_fadein = None
        for epoch in range(self.epoch, self.epoch + epochs):
            for i in range(len(self.g_models)):
                self.batch_size = self.n_batch[i]
                x_shuffled, Y_shuffled = self.shuffle_data_batch(x_train, Y_train_reshape[i], self.batch_size)
                self.n_steps = int(x_train.shape[0]/self.batch_size)-1
                [g_normal, g_fadein] = self.g_models[i]
                [d_normal, d_fadein] = self.disc_models[i]
                [gan_normal, gan_fadein] = self.gan_models[i]

                # Fresh iterators for each pass: with a shared generator, the normal pass
                # would only see the batches the fade-in pass had not consumed.
                d_loss_fadein, g_loss_fadein = self.train_epochs(
                    g_fadein, d_fadein, gan_fadein,
                    self.split_into_chunks(x_shuffled, self.batch_size),
                    self.split_into_chunks(Y_shuffled, self.batch_size),
                    n_critic, True)
                d_loss, g_loss = self.train_epochs(
                    g_normal, d_normal, gan_normal,
                    self.split_into_chunks(x_shuffled, self.batch_size),
                    self.split_into_chunks(Y_shuffled, self.batch_size),
                    n_critic, False)

            if epoch % print_every_n_batches == 0:
                # Sample (and score) the highest-resolution generators: [-1][1] is the
                # fade-in model, [-1][0] the normal one.
                fid_fadein = self.sample_images(x_val,Y_val,run_folder,self.g_models[-1][1], fadein = True)
                fid_normal = self.sample_images(x_val,Y_val,run_folder,self.g_models[-1][0], fadein = False)
                logging.info(json.dumps({'epoch':epoch,'d_fadein_loss':d_loss_fadein,'g_fadein_loss':g_loss_fadein,'fid_fadein':fid_fadein,
                                         'd_normal_loss':d_loss,'g_normal_loss':g_loss,'fid_normal':fid_normal,
                                         'fid_min':self.pre_fid}))
                print (epoch,"=======Fadein====",d_loss_fadein, g_loss_fadein,fid_fadein,fid_normal)
                # NOTE(review): sample_images as written returns self.pre_fid unchanged,
                # so this condition never holds and no checkpoint is written until FID
                # computation is re-enabled.
                if fid_normal < self.pre_fid:
                    self.pre_fid = fid_normal
                    save_dir = "./model_save/model_condproGAN"
                    # model.save() does not create missing parent directories.
                    os.makedirs(save_dir, exist_ok=True)
                    for level in range(len(self.g_models)):
                        [g_normal, g_fadein] = self.g_models[level]
                        # Saves the raw discriminators, not the training wrappers.
                        [d_normal, d_fadein] = self.d_models[level]
                        size = str(g_normal.output_shape[1])
                        g_normal.save(save_dir + "/g_normal_" + size + ".h5")
                        g_fadein.save(save_dir + "/g_fadein_" + size + ".h5")
                        d_normal.save(save_dir + "/d_normal_" + size + ".h5")
                        d_fadein.save(save_dir + "/d_fadein_" + size + ".h5")
            self.epoch+=1

    def sample_images(self,x_val,Y_val, run_folder, g_model, fadein = True):
        """Save a grid of generated vs. real validation images when ``self.epoch % 40 == 0``.

        The validation inputs are encoded (third output of ``encoder_model.predict``)
        and decoded by ``g_model``. Then 16 random validation samples are plotted as an
        8 x 4 grid of (generated, real) column pairs, after mapping [-1, 1] to [0, 1]
        and clipping. The grid is saved to
        ``<run_folder>/images_latent/sample_{fadein|normal}_<size>_<epoch>.png``.

        Args:
            x_val: validation conditioning inputs.
            Y_val: matching real validation images (needs at least 16 samples).
            run_folder: base folder for the ``images_latent`` output directory.
            g_model: generator to sample from.
            fadein: selects the ``sample_fadein`` vs ``sample_normal`` file prefix.

        Returns:
            ``self.pre_fid`` unchanged. FID computation is currently disabled.
        """
        fid = self.pre_fid
        if self.epoch % 40 != 0:
            return fid

        r, c = 8, 4
        output_shape = g_model.output_shape[1:]
        latent_code = self.encoder_model.predict(x_val)[2]
        gen_imgs = g_model.predict(latent_code)
        # TODO: FID is disabled; it was computed as
        # calculate_fid(model_inceptionv3, gen_imgs, self.data_fid).

        # r rows x (c / 2) column pairs -> r * c / 2 distinct validation samples.
        n_samples = (r * c) // 2
        indx = np.random.choice(Y_val.shape[0], n_samples, replace=False)
        face_real = np.clip(0.5 * (Y_val[indx] + 1), 0, 1)
        gen_imgs = np.clip(0.5 * (gen_imgs[indx] + 1), 0, 1)

        fig, axs = plt.subplots(r, c, figsize=(15,30))
        cnt = 0
        for i in range(r):
            for j in range(c // 2):
                axs[i, 2*j].imshow(np.squeeze(gen_imgs[cnt]))
                axs[i, 2*j].axis('off')
                axs[i, 2*j+1].imshow(np.squeeze(face_real[cnt]))
                axs[i, 2*j+1].axis('off')
                cnt += 1

        out_dir = os.path.join(run_folder, "images_latent")
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        prefix = "sample_fadein" if fadein else "sample_normal"
        fig.savefig(os.path.join(out_dir, "%s_%d_%d.png" % (prefix, output_shape[0], self.epoch)))
        # Close exactly this figure so repeated sampling does not accumulate open
        # figures, and unrelated figures are left alone.
        plt.close(fig)
        return fid

In [20]:
# Instantiate the GAN wrapper. What pre_train=True does is defined in WGANGP.__init__,
# which lies outside this chunk.
# NOTE(review): the numbers printed below (4 ... 128) appear to be the per-level
# resolutions reported by the constructor; confirm.
GAN = WGANGP(pre_train=True)

4
8
16
32
64
128


In [21]:
# Data preprocess

In [22]:
def augment_data_face_fingerprint(data_face,data_finger, num_augment_percent=0.8):
    """Append augmented copies of a random subset of (face, fingerprint) pairs.

    A random ``num_augment_percent`` fraction of the pairs is selected without
    replacement. Faces get random horizontal flips; fingerprints get random shifts,
    shear and rotation. Both Keras iterators use ``shuffle=False``, so the k-th
    augmented face still pairs with the k-th augmented fingerprint.

    Args:
        data_face: 4-D array (N, H, W, C) of face images.
        data_finger: 4-D array (N, H, W, C) of fingerprints aligned with ``data_face``.
        num_augment_percent: fraction of N to augment (0.8 -> 80%).

    Returns:
        ``(faces, fingers)``: the original arrays followed by the augmented copies.
        If no samples are selected, copies of the inputs are returned.

    Raises:
        ValueError: if the two arrays hold different numbers of samples.
    """
    if data_face.shape[0] != data_finger.shape[0]:
        raise ValueError("data_face and data_finger must have the same number of samples")
    num_data_augment = int(data_face.shape[0]*num_augment_percent)
    if num_data_augment == 0:
        # Nothing to augment, so skip the Keras generators entirely.
        return data_face.copy(), data_finger.copy()

    index_data = np.random.choice(data_face.shape[0],num_data_augment,replace=False)
    # Fancy indexing gathers the selected pairs in the same order for both arrays.
    data_aug_face = data_face[index_data]
    data_aug_finger = data_finger[index_data]

    # create image data augmentation generator
    datagen_face = ImageDataGenerator(horizontal_flip=True)
    datagen_face.fit(data_aug_face)

    datagen_finger = ImageDataGenerator(horizontal_flip=False,height_shift_range=0.1,width_shift_range=0.1,shear_range=0.5,rotation_range=20)
    datagen_finger.fit(data_aug_finger)

    # One batch covering every selected sample; shuffle=False keeps pairs aligned.
    it_face = datagen_face.flow(data_aug_face,batch_size=num_data_augment,shuffle=False)
    it_finger = datagen_finger.flow(data_aug_finger,batch_size=num_data_augment,shuffle=False)

    results_face = np.concatenate([data_face,it_face.next()],axis=0)
    results_finger = np.concatenate([data_finger,it_finger.next()],axis=0)
    return results_face,results_finger

def augment_data_only_fingerprint(data_finger, num_augment_percent = 0.5):
    """Return a copy of ``data_finger`` with a random subset replaced by augmented versions.

    A random ``num_augment_percent`` fraction of the samples (without replacement) is
    passed through random shifts, shear and rotation and written back at the same
    positions. The total number of samples is unchanged.

    Args:
        data_finger: 4-D array (N, H, W, C) of fingerprint images.
        num_augment_percent: fraction of N to augment (0.5 -> 50%).

    Returns:
        A new array. ``data_finger`` itself is not modified.
    """
    num_data = data_finger.shape[0]
    num_data_augment = int(num_data*num_augment_percent)
    # Copy first so the caller's array is not overwritten in place.
    result_finger = data_finger.copy()
    if num_data_augment == 0:
        return result_finger

    index_data = np.random.choice(num_data,num_data_augment,replace=False)
    data_aug_finger = data_finger[index_data]

    datagen_finger = ImageDataGenerator(horizontal_flip=False,height_shift_range=0.2,width_shift_range=0.2,shear_range=0.1,rotation_range=20)
    datagen_finger.fit(data_aug_finger)
    # shuffle=False keeps row k of the augmented batch aligned with index_data[k].
    it_finger = datagen_finger.flow(data_aug_finger,batch_size=num_data_augment,shuffle=False)
    result_finger[index_data] = it_finger.next()
    return result_finger


# Load data train with latent code

In [23]:
DATA_DIR = './data_train/data_get_latent/'


def _load_pickle(file_name):
    """Return the object unpickled from ``DATA_DIR/file_name``.

    NOTE: pickle can execute arbitrary code while loading; only load trusted files.
    """
    with open(os.path.join(DATA_DIR, file_name), "rb") as handle:
        return pickle.load(handle)


# Training split: faces and fingerprints.
data_train_face = _load_pickle("data_face_2856_train.pkl")
data_train_finger = _load_pickle("data_fingerprint_2856_train.pkl")

# Validation split: faces and fingerprints.
data_val_face = _load_pickle("data_face_2856_val.pkl")
data_val_finger = _load_pickle("data_fingerprint_2856_val.pkl")
    

In [24]:
# Append augmented copies of a random 80% (the function's default) of the training pairs;
# the validation split is left untouched.
# NOTE(review): this overwrites the loaded arrays, so re-running this cell augments the
# already-augmented set again. Re-run the loading cell first.
data_train_face, data_train_finger = augment_data_face_fingerprint(data_train_face,data_train_finger)

In [25]:
# Face and fingerprint arrays are paired by index, so each split must have matching counts.
assert data_train_face.shape[0] == data_train_finger.shape[0], "train face/fingerprint count mismatch"
assert data_val_face.shape[0] == data_val_finger.shape[0], "val face/fingerprint count mismatch"
data_train_face.shape, data_train_finger.shape, data_val_face.shape, data_val_finger.shape

((7711, 256, 256, 3),
 (7711, 256, 256, 1),
 (715, 256, 256, 3),
 (715, 256, 256, 1))

In [26]:
def binarize_fingerprints(fingerprints, threshold=0.5):
    """Scale ``fingerprints`` by their global maximum and threshold to {0.0, 1.0} float32.

    Values strictly greater than ``threshold`` (after scaling) become 1.0, all others
    0.0. An array whose maximum is 0 maps to all zeros without a 0/0 division.
    """
    peak = np.max(fingerprints)
    if peak == 0:
        return np.zeros(fingerprints.shape, dtype='float32')
    return np.where(fingerprints / peak > threshold, 1.0, 0.0).astype('float32')


# NOTE(review): each split is scaled by its own maximum, not the training maximum.
# That is fine when both peak at the same value; confirm otherwise.
data_train_finger = binarize_fingerprints(data_train_finger)
data_val_finger = binarize_fingerprints(data_val_finger)

In [1]:
# Sanity check: value range and dtype of the validation faces, then one sample shown
# after mapping 0.5 * (x + 1) (i.e. [-1, 1] -> [0, 1]).
print(data_val_face.min(), data_val_face.max(), data_val_face.dtype)
fig, ax = plt.subplots()
ax.imshow(0.5 * (data_val_face[150] + 1))
plt.show()

In [28]:
# Kept disabled: the loaded face arrays are presumably already scaled to [-1, 1].
# Confirm before enabling, otherwise the faces would be rescaled twice.
# data_train_face = ((data_train_face-127.5)/127.5).astype('float32')
# data_val_face = ((data_val_face-127.5)/127.5).astype('float32')

In [29]:
# Short names for the training call; these are references to the same arrays, not copies.
train_face, val_face, train_finger, val_finger = (
    data_train_face, data_val_face, data_train_finger, data_val_finger)

# Training GAN

In [None]:
# Train the conditional GAN to map fingerprints (x) to faces (Y). Validation sample grids
# go to ./model_save/images_latent.
# NOTE(review): GAN.train does not use its batch_size argument; per-resolution batch sizes
# come from GAN.n_batch. print_every_n_batches is an interval in epochs.
GAN.train(
    train_finger  
    , train_face
    , val_finger
    , val_face
    , batch_size = 4
    , epochs = 10000
    , run_folder = './model_save'
    , print_every_n_batches = 10
    , n_critic = 2
)



### 