In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from tensorflow.keras.layers import Dense, Reshape, Conv2D, Input, LeakyReLU, Layer, UpSampling2D,Add,Flatten,AveragePooling2D, Lambda
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from skimage.transform import resize
from tensorflow.keras.optimizers import Adam, RMSprop
from math import sqrt
from numpy import load
from numpy import asarray
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
import math
from tensorflow.keras.applications.vgg19 import VGG19
from functools import partial
from scipy.linalg import sqrtm
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from skimage.transform import resize
from tensorflow.keras.applications.inception_v3 import preprocess_input

In [None]:
# Restrict CUDA-backed libraries in this process to the device with index 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Custom layers (pixel normalization, weighted sum, minibatch stdev)

In [None]:
class PixelNormalization(Layer):
    """Pixel-wise feature normalization.

    Every feature vector along the last axis is divided by the square root of
    the mean of its squared values, so the mean square over that axis becomes
    approximately one. The output has the same shape as the input.
    """

    def __init__(self, **kwargs):
        super(PixelNormalization, self).__init__(**kwargs)

    def call(self, x):
        """Normalize ``x`` along its last axis.

        Args:
            x: input tensor; the last axis is the one that is normalized.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # mean of the squared activations over the last axis
        mean_square = K.mean(x ** 2, axis=-1, keepdims=True)
        # the small epsilon keeps the division finite for all-zero activations
        return x / K.sqrt(mean_square + 1.0e-8)

    def compute_output_shape(self, input_shape):
        """The normalization is element-wise, so the shape is unchanged."""
        return input_shape

In [None]:
class WeightedSum(Add):
    """Blend exactly two inputs as ``(1 - alpha) * inputs[0] + alpha * inputs[1]``.

    ``alpha`` is stored in a backend variable so that it can be changed after
    the model has been built (``update_fadein`` does this with ``K.set_value``).
    With ``alpha == 0`` the output equals the first input, with ``alpha == 1``
    it equals the second input.
    """

    def __init__(self, alpha=0.0, **kwargs):
        """Create the layer.

        Args:
            alpha: initial blend factor (defaults to 0.0).
            **kwargs: forwarded to the ``Add`` base layer.
        """
        super(WeightedSum, self).__init__(**kwargs)
        self.alpha = K.variable(alpha, name='ws_alpha')

    def get_config(self):
        """Return the layer config including the *current* value of ``alpha``.

        The value is read from the backend variable instead of being hard-coded
        to 0.0, so a layer rebuilt from this config starts from the blend
        factor the layer had when the config was produced.
        """
        config = super().get_config().copy()
        config.update({
            'alpha': float(K.get_value(self.alpha))
        })
        return config

    def _merge_function(self, inputs):
        """Compute the weighted sum of the two input tensors."""
        # only a weighted sum of two inputs is supported
        assert (len(inputs) == 2)
        # ((1-a) * input1) + (a * input2)
        output = ((1.0 - self.alpha) * inputs[0]) + (self.alpha * inputs[1])
        return output

In [None]:
class MinibatchStdev(Layer):
    """Append one feature map holding a minibatch standard-deviation statistic.

    The standard deviation of every activation is computed across the batch
    axis (axis 0), the resulting values are averaged into a single scalar, and
    that scalar is tiled into one extra channel that is concatenated to the
    input along the last axis.
    """

    def __init__(self, **kwargs):
        super(MinibatchStdev, self).__init__(**kwargs)

    def call(self, inputs):
        """Return ``inputs`` with the extra statistic channel appended."""
        # per-position mean over the batch axis
        batch_mean = K.mean(inputs, axis=0, keepdims=True)
        # variance over the batch axis
        variance = K.mean(K.square(inputs - batch_mean), axis=0, keepdims=True)
        # the small constant keeps the square root well-behaved at zero variance
        stdev = K.sqrt(variance + 1e-8)
        # collapse all standard deviations into a single value (all axes kept)
        mean_stdev = K.mean(stdev, keepdims=True)
        # repeat the value over the first three axes of the input: one map per sample
        dims = K.shape(inputs)
        stat_map = K.tile(mean_stdev, (dims[0], dims[1], dims[2], 1))
        return K.concatenate([inputs, stat_map], axis=-1)

    def compute_output_shape(self, input_shape):
        """Same shape as the input with the last (channel) dimension plus one."""
        dims = list(input_shape)
        # channels-last is assumed: only the final entry grows
        return tuple(dims[:-1] + [dims[-1] + 1])

# WGAN loss

In [6]:
# calculate wasserstein loss
def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: the mean of the element-wise product of labels and scores.

    In this notebook the labels are +1 for real samples and -1 for generated
    ones (see ``generate_real_samples`` / ``generate_fake_samples``).
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [7]:
def dummy_loss_function(y_true, y_pred):
    """Pass-through loss: ignore ``y_true`` and hand back ``y_pred`` unchanged."""
    prediction = y_pred
    return prediction

# Build generator

In [4]:
def add_generator_block(old_model):
    """Grow a generator by one resolution step.

    The output of the second-to-last layer of ``old_model`` is upsampled and
    fed through two 3x3 conv / pixel-norm / LeakyReLU stages followed by a new
    1x1 output conv with 3 filters.

    Args:
        old_model: generator of the previous resolution. Its last layer is
            reused as the "old" output layer on the upsampled features.

    Returns:
        ``[straight_through, fade_in]``: the grown generator, and a variant
        whose output is a ``WeightedSum`` of the old output layer (applied to
        the upsampled features) and the new output.
    """
    initializer = tf.keras.initializers.RandomNormal(stddev=0.02)
    max_norm = tf.keras.constraints.max_norm(1.0)

    def conv(filters, kernel):
        # every conv of the block shares the initializer and weight constraint
        return Conv2D(filters, kernel, padding='same',
                      kernel_initializer=initializer, kernel_constraint=max_norm)

    # end of the last block = output of the layer before the old output layer
    upsampled = UpSampling2D()(old_model.layers[-2].output)

    # new block: two conv stages at the doubled resolution
    features = upsampled
    for _ in range(2):
        features = conv(128, (3, 3))(features)
        features = PixelNormalization()(features)
        features = LeakyReLU(alpha=0.2)(features)

    # new output layer
    new_image = conv(3, (1, 1))(features)
    straight_through = Model(old_model.input, new_image)

    # old output layer applied to the plain upsampled features
    old_image = old_model.layers[-1](upsampled)
    # blend old and new outputs; the blend factor lives in the WeightedSum layer
    blended = WeightedSum()([old_image, new_image])
    fade_in = Model(old_model.input, blended)

    return [straight_through, fade_in]



def get_loss(args):
    """Sum of mean squared differences between two aligned feature sequences.

    Args:
        args: indexable whose first two entries are the new features and the
            reference (content) features; both are indexed position by position.

    Returns:
        The accumulated loss; plain ``0`` when the first sequence is empty.
    """
    new_feature, content_feature = args[0], args[1]
    return sum(
        K.mean(K.square(new_feature[idx] - content_feature[idx]))
        for idx in range(len(new_feature))
    )


    # define generator models
def define_generator(latent_dim, n_blocks, in_dim=4):
    """Build the generator models for every resolution stage.

    Args:
        latent_dim: length of the latent input vector.
        n_blocks: number of resolution stages (including the base stage).
        in_dim: side length of the base feature map (default 4).

    Returns:
        A list with ``n_blocks`` entries, each a two-element list
        ``[straight_through, fade_in]``. The first entry holds the base model
        twice because the base stage has no fade-in variant.
    """
    initializer = tf.keras.initializers.RandomNormal(stddev=0.02)
    max_norm = tf.keras.constraints.max_norm(1.0)

    def conv(filters, kernel):
        # shared settings for every convolution of the base model
        return Conv2D(filters, kernel, padding='same',
                      kernel_initializer=initializer, kernel_constraint=max_norm)

    # latent vector -> dense -> (in_dim, in_dim, 128) activation maps
    in_latent = Input(shape=(latent_dim,))
    features = Dense(128 * in_dim * in_dim,
                     kernel_initializer=initializer,
                     kernel_constraint=max_norm)(in_latent)
    features = Reshape((in_dim, in_dim, 128))(features)

    # input block: 4x4 conv followed by 3x3 conv, each with pixel norm + LeakyReLU
    for kernel in ((4, 4), (3, 3)):
        features = conv(128, kernel)(features)
        features = PixelNormalization()(features)
        features = LeakyReLU(alpha=0.2)(features)

    # 1x1 conv producing the 3-channel output image
    out_image = conv(3, (1, 1))(features)
    base_model = Model(in_latent, out_image)

    model_list = [[base_model, base_model]]
    # every further stage grows the previous straight-through model
    for _ in range(1, n_blocks):
        previous = model_list[-1][0]
        model_list.append(add_generator_block(previous))
    return model_list


# Build discriminator backbone

In [5]:
def add_discriminator_block(old_model, n_input_layers=3):
    """Grow a discriminator backbone by one resolution step.

    A new input of twice the spatial size is processed by a 1x1 conv, two 3x3
    convs and an average pooling, and the result is fed into the layers of
    ``old_model`` that follow its own input processing.

    Args:
        old_model: discriminator of the previous resolution. Layers 1 and 2
            are used as its input-processing layers (1x1 conv + activation).
        n_input_layers: number of leading layers of ``old_model`` to skip when
            re-using it (input, 1x1 conv and activation by default).

    Returns:
        ``[straight_through, fade_in]``. Only the straight-through model is
        compiled here; the fade-in model blends the old input processing
        (applied to the downsampled image) with the new block through a
        ``WeightedSum`` layer.
    """
    # weight initialization
    init = tf.keras.initializers.RandomNormal(stddev=0.02)
    # weight constraint
    const = tf.keras.constraints.max_norm(1.0)
    # get shape of existing model
    in_shape = list(old_model.input.shape)
    # new input is double the size; the second-to-last entry is used for both
    # spatial sides, i.e. square inputs are assumed
    input_shape = (in_shape[-2]*2, in_shape[-2]*2, in_shape[-1])
    in_image = Input(shape=input_shape)
    # define new input processing layer
    d = Conv2D(128, (1,1), padding='same', kernel_initializer=init, kernel_constraint=const)(in_image)
    d = LeakyReLU(alpha=0.2)(d)
    # define new block
    d = Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = AveragePooling2D()(d)
    block_new = d
    # skip the input, 1x1 and activation for the old model
    for layer in old_model.layers[n_input_layers:]:
        d = layer(d)
    # define straight-through model
    model1 = Model(in_image, d)
    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras optimizers
    model1.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.001))
    # downsample the new larger image
    downsample = AveragePooling2D()(in_image)
    # connect old input processing to downsampled new input
    block_old = old_model.layers[1](downsample)
    block_old = old_model.layers[2](block_old)
    # fade in output of old model input layer with new input
    d = WeightedSum()([block_old, block_new])
    # skip the input, 1x1 and activation for the old model
    for layer in old_model.layers[n_input_layers:]:
        d = layer(d)
    # define fade-in model (left uncompiled, as before)
    model2 = Model(in_image, d)
    return [model1, model2]


    # define the discriminator models for each image resolution
def define_discriminator(n_blocks, input_shape=(4,4,3)):
    """Build the discriminator backbones for every resolution stage.

    Args:
        n_blocks: number of resolution stages (including the base stage).
        input_shape: shape of the base-resolution input image.

    Returns:
        A list with ``n_blocks`` entries, each a two-element list
        ``[straight_through, fade_in]``. The first entry holds the (uncompiled)
        base model twice.
    """
    initializer = tf.keras.initializers.RandomNormal(stddev=0.02)
    max_norm = tf.keras.constraints.max_norm(1.0)

    def conv(kernel):
        # all convolutions of the base model use 128 filters and shared settings
        return Conv2D(128, kernel, padding='same',
                      kernel_initializer=initializer, kernel_constraint=max_norm)

    in_image = Input(shape=input_shape)
    # input processing: 1x1 conv + activation (layers 1 and 2 of the model)
    features = conv((1, 1))(in_image)
    features = LeakyReLU(alpha=0.2)(features)
    # output block: minibatch stdev channel, then 3x3 and 4x4 convs
    features = MinibatchStdev()(features)
    for kernel in ((3, 3), (4, 4)):
        features = conv(kernel)(features)
        features = LeakyReLU(alpha=0.2)(features)
    # single linear score per image
    out_class = Dense(1)(Flatten()(features))
    base_model = Model(in_image, out_class)

    model_list = [[base_model, base_model]]
    # every further stage grows the previous straight-through model
    for _ in range(1, n_blocks):
        previous = model_list[-1][0]
        model_list.append(add_discriminator_block(previous))
    return model_list

# Gradient penalty for the WGAN loss

In [None]:
class RandomWeightedAverage(tf.keras.layers.Layer):
    """Random per-sample interpolation between two batches.

    For every sample a weight ``alpha`` is drawn uniformly from [0, 1) and the
    layer returns ``alpha * inputs[0] + (1 - alpha) * inputs[1]``.
    """

    def __init__(self, batch_size, **kwargs):
        """Create the layer.

        Args:
            batch_size: number of samples per batch; one random weight is
                drawn per sample, so incoming batches must have this size.
            **kwargs: standard layer arguments (``name``, ``dtype``, ...).
                Accepting them is required for ``from_config`` to work, since
                the config returned by ``get_config`` contains these keys.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, inputs, **kwargs):
        """Interpolate between ``inputs[0]`` and ``inputs[1]``."""
        # one weight per sample, broadcast over the remaining three axes
        alpha = tf.random.uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

    def compute_output_shape(self, input_shape):
        """The output has the shape of the first input."""
        return input_shape[0]

    def get_config(self):
        """Return the base layer config extended with ``batch_size``."""
        config = super().get_config().copy()
        config.update({
            'batch_size': self.batch_size,
        })
        return config

In [None]:
def gradient_penalty(y_true, y_pred, discriminator, gradient_penalty_weight):
    """Gradient-penalty loss term.

    Args:
        y_true: unused; present only to match the Keras loss signature.
        y_pred: batch of images the penalty is evaluated on (in this notebook
            the interpolated samples output by ``build_discriminator_model``).
        discriminator: model whose output is differentiated w.r.t. ``y_pred``.
        gradient_penalty_weight: factor the penalty is multiplied with.

    Returns:
        ``mean((||grad|| - 1)^2) * gradient_penalty_weight`` where the norm is
        taken over axes 1, 2 and 3 of the gradient.
    """
    samples = y_pred
    with tf.GradientTape() as tape:
        # the samples are not variables, so they must be watched explicitly
        tape.watch(samples)
        scores = discriminator(samples, training=True)

    # gradient of the discriminator output with respect to the samples
    (grads,) = tape.gradient(scores, [samples])
    # per-sample L2 norm of the gradient
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    penalty = tf.reduce_mean((grad_norm - 1.0) ** 2)
    return penalty * gradient_penalty_weight

# Build discriminator

In [None]:
def build_discriminator_model(discriminator, batch_size, gradient_penalty_weight=None):
    """Wrap a discriminator backbone into a trainable WGAN-GP model.

    The returned model takes ``[real_images, fake_images]`` and has three
    outputs: the backbone's score for the real images, its score for the fake
    images, and a random interpolation between the two batches. The first two
    outputs are trained with ``wasserstein_loss``; the third is passed to
    ``gradient_penalty``.

    Args:
        discriminator: backbone model; it is nested inside the returned model.
        batch_size: number of samples per input batch (needed by
            ``RandomWeightedAverage``).
        gradient_penalty_weight: weight of the gradient penalty. When ``None``
            the module-level ``GRADIENT_PENALTY_WEIGHT`` is used if it exists;
            otherwise 10.0 (the value used in the WGAN-GP paper) is used, so
            the function no longer fails with a NameError when the constant
            has not been defined.

    Returns:
        The compiled three-output Keras model.
    """
    if gradient_penalty_weight is None:
        gradient_penalty_weight = globals().get('GRADIENT_PENALTY_WEIGHT', 10.0)

    real_input = discriminator.input
    shape = real_input.shape[1:]
    fake_input = Input(shape = shape)

    discriminator_output_from_generator = discriminator(fake_input)
    discriminator_output_from_real_samples = discriminator(real_input)

    averaged_samples = RandomWeightedAverage(batch_size = batch_size)([real_input,
                                            fake_input])
    partial_gp_loss = partial(gradient_penalty,
                          discriminator = discriminator,
                          gradient_penalty_weight=gradient_penalty_weight)
    # Functions need names or Keras will throw an error
    partial_gp_loss.__name__ = 'gradient_penalty'

    discriminator_model = Model(inputs = [real_input,fake_input],
                  outputs = [discriminator_output_from_real_samples,discriminator_output_from_generator,averaged_samples])

    discriminator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),
                            loss=[wasserstein_loss,
                                  wasserstein_loss,
                                   partial_gp_loss])
    return discriminator_model


# Build GAN model (generator + frozen discriminator)

In [8]:
def define_composite(discriminators, generators):
    """Build the generator-training models (generator followed by discriminator).

    For every resolution stage and for both variants (index 0: straight-through,
    index 1: fade-in) the discriminator is marked as not trainable and stacked
    on top of the matching generator.

    Args:
        discriminators: list of ``[straight_through, fade_in]`` discriminators.
        generators: list of ``[straight_through, fade_in]`` generators, aligned
            with ``discriminators``.

    Returns:
        A list of ``[straight_through, fade_in]`` composite models, each
        compiled with ``wasserstein_loss`` and RMSprop.
    """
    model_list = []
    for i in range(len(discriminators)):
        stage_models = []
        for variant in range(2):
            critic = discriminators[i][variant]
            generator = generators[i][variant]
            # the discriminator is frozen inside the composite model
            critic.trainable = False
            composite = Model(generator.input, critic(generator.output))
            # `learning_rate` is the supported keyword; the old `lr` alias is
            # deprecated and rejected by newer Keras optimizers
            composite.compile(loss=wasserstein_loss,
                              optimizer=RMSprop(learning_rate=0.001))
            stage_models.append(composite)
        model_list.append(stage_models)
    return model_list

# Process data

In [14]:
def load_real_samples(dataset):
    """Convert ``dataset`` to a float32 array scaled from [0, 255] to [-1, 1]."""
    pixels = np.array(dataset).astype('float32')
    # 0 maps to -1.0, 127.5 to 0.0 and 255 to 1.0
    return (pixels - 127.5) / 127.5

In [10]:
def generate_real_samples(dataset, n_samples):
    """Draw ``n_samples`` random rows (with replacement) from ``dataset``.

    Returns:
        ``(X, y)``: the selected samples and an ``(n_samples, 1)`` array of
        ones used as the labels of real samples.
    """
    # random indices into the first axis of the dataset
    picks = randint(0, dataset.shape[0], n_samples)
    return dataset[picks], ones((n_samples, 1))

In [11]:
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate ``n_samples`` images with ``generator``.

    Returns:
        ``(X, y)``: the output of ``generator.predict`` on freshly drawn latent
        points and an ``(n_samples, 1)`` array of -1 used as the labels of
        generated samples.
    """
    latent_batch = generate_latent_points(latent_dim, n_samples)
    images = generator.predict(latent_batch)
    labels = -ones((n_samples, 1))
    return images, labels

In [12]:
def generate_latent_points(latent_dim, n_samples):
    """Return an ``(n_samples, latent_dim)`` array of standard-normal samples."""
    # draw all values as one flat vector, then shape it into a batch
    flat = randn(latent_dim * n_samples)
    return flat.reshape(n_samples, latent_dim)

# Update alpha in fadein model

In [None]:
def _iter_nested_layers(model):
    """Yield every layer of ``model``, descending into layers that own layers."""
    for layer in model.layers:
        yield layer
        # a model used as a layer exposes its own `layers`; look inside it too
        if hasattr(layer, 'layers'):
            yield from _iter_nested_layers(layer)


def update_fadein(models, step, n_steps):
    """Set the blend factor of every ``WeightedSum`` layer for the current step.

    ``alpha`` grows linearly from 0 at ``step == 0`` to 1 at
    ``step == n_steps - 1``.

    Nested models are searched as well. This matters for the models returned
    by ``build_discriminator_model``: the discriminator backbone is a single
    nested layer there, so a scan of only the top-level layers would never
    reach its ``WeightedSum`` and the discriminator would not fade in.

    Args:
        models: iterable of models exposing a ``layers`` attribute.
        step: index of the current training step (0-based).
        n_steps: total number of steps of the fade-in phase. With a single
            step (or fewer) alpha is set to 1.0 instead of dividing by zero.
    """
    # calculate current alpha (linear from 0 to 1)
    alpha = 1.0 if n_steps <= 1 else step / float(n_steps - 1)
    # update the alpha for each model
    for model in models:
        for layer in _iter_nested_layers(model):
            if isinstance(layer, WeightedSum):
                K.set_value(layer.alpha, alpha)

# Save model and plot images

In [13]:
def summarize_performance(dir,status, g_model, latent_dim, n_samples = 25, shape = (4,4)):
    """Save ``g_model`` and, for generators, a grid of generated images.

    Args:
        dir: output directory prefix including the trailing slash (for example
            ``'g_model/'``). It is created when it does not exist. The value
            ``'g_model/'`` additionally triggers the sample plot.
        status: tag used in the file names (``'tuned'`` / ``'faded'`` in this
            notebook).
        g_model: model to save; must be a generator when ``dir`` is
            ``'g_model/'``.
        latent_dim: latent vector length, used only for the sample plot.
        n_samples: number of images to generate for the plot.
        shape: either a full model output shape ``(batch, height, width,
            channels)`` or a plain ``(height, width)`` pair. The default
            2-tuple used to raise an IndexError.
    """
    gen_shape = shape
    print(gen_shape)
    # four or more entries: Keras output shape with a leading batch axis
    if len(gen_shape) >= 4:
        height, width = gen_shape[1], gen_shape[2]
    else:
        height, width = gen_shape[0], gen_shape[1]
    name = '%03dx%03d-%s' % (height, width, status)

    # make sure the target directory exists before the model is written
    if dir:
        os.makedirs(dir, exist_ok=True)
    filename2 = dir + 'model_%s.h5' % (name)

    if dir == 'g_model/':
        # generate images
        X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
        # normalize pixel values to the range [0,1]
        X = (X - X.min()) / (X.max() - X.min())
        # plot as many images as fit into the square grid
        square = int(sqrt(n_samples))
        for i in range(square * square):
            plt.subplot(square, square, 1 + i)
            plt.axis('off')
            plt.imshow(X[i])
        # save plot to file (written to the working directory, as before)
        filename1 = 'plot_%s.png' % (name)
        plt.savefig(filename1)
        plt.close()
        # save the generator model
        g_model.save(filename2)
        print('>Saved: %s and %s' % (filename1, filename2))
    else:
        g_model.save(filename2)
        print('>Saved: %s' % ( filename2))

# Train for one stage

In [16]:
def train_epochs(g_model, d_model, gan_model, dataset, latent_dim, n_epochs, n_batch, writer,fadein=False):
    """Train one stage (one resolution, either fade-in or straight-through).

    Each step updates the discriminator model twice on fresh half batches of
    real and generated images and then updates the generator once through
    ``gan_model`` on a full batch of latent points.

    Args:
        g_model: generator used to produce fake samples.
        d_model: three-output discriminator model; it is trained on
            ``[X_real, X_fake]`` with targets ``[y_real, y_fake, dummy]``.
        gan_model: composite model trained on latent points.
        dataset: array of real images at the resolution of this stage.
        latent_dim: length of the latent vectors.
        n_epochs: number of passes over ``dataset``.
        n_batch: batch size; discriminator updates use half of it.
        writer: ``tf.summary`` file writer used for the loss scalars.
        fadein: when true, the ``WeightedSum`` alphas are updated every step
            and no summaries are written.
    """
    # number of batches per epoch and total number of training iterations
    bat_per_epo = int(dataset.shape[0] / n_batch)
    n_steps = bat_per_epo * n_epochs
    half_batch = int(n_batch / 2)

    # targets that never change between steps are created once
    dummy = np.zeros((half_batch, 1))    # placeholder target for the gradient-penalty output
    y_real2 = ones((n_batch, 1))         # generator-update targets

    for i in range(n_steps):
        # update alpha for all WeightedSum layers when fading in new blocks
        if fadein:
            update_fadein([g_model, d_model, gan_model], i, n_steps)

        # two discriminator updates per generator update
        for _ in range(2):
            X_real, y_real = generate_real_samples(dataset, half_batch)
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            d_loss = d_model.train_on_batch([X_real, X_fake], [y_real, y_fake, dummy])

        # update the generator via the discriminator's error
        z_input = generate_latent_points(latent_dim, n_batch)
        g_loss = gan_model.train_on_batch(z_input, y_real2)

        # write to tensorboard (straight-through stages only);
        # d_loss[0] is the first entry of the returned loss list
        if i % 500 == 0 and not fadein:
            with writer.as_default():
                tf.summary.scalar('d_loss', d_loss[0], step=i)
                tf.summary.scalar('g_loss', g_loss, step=i)
                writer.flush()

        if i % 1000 == 0:
            print('>%d, d=%.3f, g=%.3f' % (i+1, d_loss[0], g_loss))
#             print("fid_score:", fid_score)
    


# Train function

In [18]:
def train(g_models, d_models, gan_models, dataset, latent_dim, e_norm, e_fadein, n_batch):
    """Run the full progressive-growing schedule.

    The base stage is trained first; every following stage is trained in its
    fade-in variant and then in its straight-through variant. After each
    training run the generator, discriminator and composite model are saved
    through ``summarize_performance``.

    Args:
        g_models, d_models, gan_models: per-stage ``[straight_through, fade_in]``
            model pairs.
        dataset: full-resolution real images; rescaled for every stage.
        latent_dim: length of the latent vectors.
        e_norm: epochs per stage for straight-through training.
        e_fadein: epochs per stage for fade-in training (index 0 is unused).
        n_batch: batch size per stage.

    NOTE(review): ``scale_dataset`` is not defined in the visible part of this
    notebook - confirm it is defined before this function is called.
    """
    def run_and_save(generator, critic, composite, data, epochs, batch, log_writer, fading, tag, out_shape):
        # train one variant of a stage, then save its three models in fixed order
        train_epochs(generator, critic, composite, data, latent_dim, epochs, batch, log_writer, fading)
        for folder, model in (('g_model/', generator), ('d_model/', critic), ('gan_model/', composite)):
            summarize_performance(folder, tag, model, latent_dim, shape = out_shape)

    # ---- base stage ----
    g_normal, d_normal, gan_normal = g_models[0][0], d_models[0][0], gan_models[0][0]
    gen_shape = g_normal.output_shape
    # scale dataset to the output size of the base generator
    scaled_data = scale_dataset(dataset, gen_shape[1:])
    print('Scaled Data', scaled_data.shape)
    writer = tf.summary.create_file_writer(logdir = 'log_1/log_0')
    run_and_save(g_normal, d_normal, gan_normal, scaled_data,
                 e_norm[0], n_batch[0], writer, False, 'tuned', gen_shape)

    # ---- every further level of growth ----
    for i in range(1, len(g_models)):
        writer = tf.summary.create_file_writer(logdir = 'log_1/log_{}'.format(i))
        # drop the reference to the previous stage's data before rescaling
        scaled_data = []

        g_normal, g_fadein = g_models[i]
        d_normal, d_fadein = d_models[i]
        gan_normal, gan_fadein = gan_models[i]

        gen_shape = g_normal.output_shape
        scaled_data = scale_dataset(dataset, gen_shape[1:])
        print('Scaled Data', scaled_data.shape)

        # fade-in variant first, then the straight-through variant
        run_and_save(g_fadein, d_fadein, gan_fadein, scaled_data,
                     e_fadein[i], n_batch[i], writer, True, 'faded', gen_shape)
        run_and_save(g_normal, d_normal, gan_normal, scaled_data,
                     e_norm[i], n_batch[i], writer, False, 'tuned', gen_shape)


# Training

In [1]:
latent_dims = 100  # dimension of latent vector

# generator and discriminator backbones: six resolution stages, starting at 4x4
g_models = define_generator(latent_dim=latent_dims, n_blocks=6, in_dim=4)
d_models = define_discriminator(n_blocks=6, input_shape=(4, 4, 3))

# n_batch = [1024,1024, 1024, 512, 128, 64]
n_batch = [64, 64, 32, 16, 8, 8]  # batch size per stage

# trainable discriminator models: both variants of every stage are wrapped,
# each fed with half a batch of real and half a batch of generated images
dis_model = []
for i in range(len(d_models)):
    dis_model.append([
        build_discriminator_model(backbone, int(n_batch[i] / 2))
        for backbone in d_models[i]
    ])

# composite models (generator + frozen discriminator) for the generator updates
gan_models = define_composite(d_models, g_models)

# load dataset from Celeb face dataset (replace with custom dataset)
dataset = load_real_samples(all_faces)

# epochs per stage; the same schedule is used for straight-through and fade-in
n_epochs = [10, 16, 16, 20, 20, 20]

train(g_models, dis_model, gan_models, dataset, latent_dims, n_epochs, n_epochs, n_batch)


NameError: name 'define_generator' is not defined