In [73]:
#
# This is an example of a GAN that generates 3x3 images
# The optimization uses standard alternating training (one discriminator step, then one generator step per iteration)
#

import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline
import os
import shutil
import time

from keras.models import Sequential
from keras.layers import Dense
import tensorflow as tf
from tensorflow.keras import optimizers
import csv

# Seed NumPy's global RNG once at start-up; the per-run seed values drawn in the
# training loop further down are taken from this generator.
np.random.seed(seed=1)

In [74]:
# 
# function to seed NumPy and TensorFlow so that results can be reproduced
#

def set_seed(tmp_seed):
    """Seed NumPy's and TensorFlow's global random generators with one value.

    Call this before building and training the networks so a run can be repeated.

    Args:
        tmp_seed: integer seed handed to ``np.random.seed`` and then to
            ``tf.random.set_seed``.
    """
    for seeder in (np.random.seed, tf.random.set_seed):
        seeder(tmp_seed)

In [75]:
#
# function to create a fresh folder for saving the results of one seed
#

def create_folder(tmp_seed):
    """Create an empty results folder for one seed, wiping any previous one.

    The folder is ``saved_networks/alternating_train/_seed_<tmp_seed>`` relative to
    the current working directory. If it already exists it is deleted first, so
    results of an earlier run with the same seed are discarded.

    Args:
        tmp_seed: seed value appended to the base prefix to name the folder.

    Returns:
        The base prefix ``'saved_networks/alternating_train/_seed_'``; callers append
        ``str(seed)`` to it to reach the folder created here.
    """
    # Base prefix under which all per-seed result folders live
    SAVEPATH = R'saved_networks/alternating_train/_seed_'
    target_dir = os.path.normpath(SAVEPATH + str(tmp_seed))

    # Start from a clean slate for this seed
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)

    return SAVEPATH

In [76]:
#
# This cell is for drawing the 3x3 image
# if the parameter save is set, the images are also saved
#

def view_samples(samples, m, n,save,path,tmp_seed):
    """Draw samples as an m x n grid of 3x3 grayscale images and optionally save it.

    Args:
        samples: iterable of images; each must hold exactly 9 values (it is reshaped
            to 3x3). Only the first ``m * n`` samples are drawn.
        m, n: number of grid rows and columns.
        save: if truthy, write the figure as
            ``<path><tmp_seed>/plots/pdf/generated_images.pdf`` and
            ``<path><tmp_seed>/plots/png/generated_images.png``. The folders
            ``plots/pdf``, ``plots/png``, ``plots/png/image_values`` and
            ``plots/networks`` are created as well (``train`` writes into the latter two).
        path: base path prefix (as returned by ``create_folder``); only used when saving.
        tmp_seed: seed appended to ``path``; only used when saving.

    Returns:
        ``(fig, axes)`` as returned by ``plt.subplots``. The figure is closed before
        returning.
    """
    fig, axes = plt.subplots(figsize=(10, 10), nrows=m, ncols=n, sharey=True, sharex=True)

    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid;
    # np.atleast_1d lets the loop handle every grid size.
    for ax, img in zip(np.atleast_1d(axes).flatten(), samples):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # Pixels are inverted (1 - value) so that, with 'Greys_r', larger values render darker.
        ax.imshow(1 - np.asarray(img).reshape((3, 3)), cmap='Greys_r')

    if save:
        plot_dir = os.path.normpath(path + str(tmp_seed) + R'/plots')
        # Create each sub-folder independently (exist_ok) so a partially existing
        # folder tree does not make savefig fail.
        for sub_dir in ('pdf', 'png', os.path.join('png', 'image_values'), 'networks'):
            os.makedirs(os.path.join(plot_dir, sub_dir), exist_ok=True)
        fig.savefig(os.path.join(plot_dir, 'pdf', 'generated_images.pdf'), bbox_inches='tight')
        fig.savefig(os.path.join(plot_dir, 'png', 'generated_images.png'), bbox_inches='tight')

    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return fig, axes

In [77]:
#
# This cell sets up the real images
# Noise is added to a standard image
#

def initialize_images():
    """Build the pool of 50 noisy 'real' images.

    Every image is the flattened 3x3 standard pattern (0.9 on the diagonal,
    0.1 elsewhere) plus independent Gaussian noise (mean 0, std 0.05) per pixel,
    drawn from NumPy's global RNG one image at a time.

    Returns:
        ndarray of shape (50, 9).
    """
    standard_face = [0.9, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.9]
    faces = np.array([standard_face + np.random.normal(0, 0.05, 9) for _ in range(50)])

    # Draw the first 8 images without saving them (view_samples closes the figure).
    _ = view_samples(faces, 1, 8, False, None, None)
    return faces

In [78]:
#
# This cell defines the discriminator network: a single dense layer with a sigmoid output
# input size: 9
# output size: 1
#

def setup_dis():
    """Build and compile the discriminator: 9 inputs -> 1 sigmoid output.

    It is compiled with binary cross-entropy, SGD (learning rate 0.1) and an
    accuracy metric, so it can be trained and evaluated on its own.

    Returns:
        ``(discriminator, org_weights_dis)``: the model and its initial weights,
        which ``train`` uses to reset the discriminator before training.
    """
    sgd = optimizers.SGD(learning_rate=0.1)

    discriminator = Sequential()
    discriminator.add(Dense(1, input_dim=9, activation='sigmoid'))
    discriminator.compile(loss="binary_crossentropy", optimizer=sgd, metrics=["accuracy"])

    # Snapshot of the starting weights, used to reset the network after/before training
    return discriminator, discriminator.get_weights()

In [79]:
#
# This cell defines the generator network: a single dense layer with sigmoid outputs
# input size: 1
# output size: 9
#

def setup_gen():
    """Build the generator: 1 latent input -> 9 sigmoid outputs (a flattened 3x3 image).

    The generator is not compiled here; it is trained through the combined GAN model.

    Returns:
        ``(generator, org_weights_gen)``: the model and its initial weights, which
        ``train`` uses to reset the generator before training.
    """
    output_layer = Dense(9, input_dim=1, activation='sigmoid')
    generator = Sequential([output_layer])

    # Snapshot of the starting weights, used to reset the network before training
    return generator, generator.get_weights()

In [80]:
#
# This cell defines the GAN network which combines the Generator and the Discriminator.
# Because the Discriminator must not be updated while the Generator is trained through this combined model, discriminator.trainable is set to False before compiling it
#

def setup_gan(discriminator, generator):
    """Stack generator -> discriminator into the combined model that trains the generator.

    Args:
        discriminator: compiled discriminator from ``setup_dis``; its ``trainable``
            flag is set to False here.
        generator: generator from ``setup_gen``.

    Returns:
        The compiled combined model (binary cross-entropy, SGD lr=0.1, accuracy).
    """
    # Freeze the discriminator *before* compiling the combined model so that
    # GAN.train_on_batch only updates generator weights.
    # NOTE(review): the discriminator was already compiled in setup_dis while trainable;
    # whether it keeps training via its own train_on_batch depends on the Keras version
    # honoring the flag at compile time — TODO confirm for the installed Keras.
    discriminator.trainable = False

    GAN = Sequential([generator, discriminator])
    GAN.compile(loss="binary_crossentropy", optimizer=optimizers.SGD(learning_rate=0.1), metrics=["accuracy"])
    return GAN

In [81]:
# Number of training iterations (each does one discriminator and one generator update)
epochs = 1000

# Number of real images, and separately of fake images, used per iteration
batch_size = 8

# Target labels for binary cross-entropy:
#   0 -> fake images when training the discriminator
#   1 -> real images when training the discriminator, and fake images when
#        training the generator through the combined GAN model
Y_fake = np.zeros(shape=(batch_size, 1))
Y_real = np.ones_like(Y_fake)

In [82]:
#
#   This function computes the evaluation score (used as plot color) of the generated images:
#   the mean, over images, of the summed absolute pixel differences to the standard image
#

def calculate_color(generated_images, reference=None):
    """Score generated images by their mean Manhattan (L1) distance to a reference image.

    Args:
        generated_images: iterable of generator outputs. Each element is indexed with
            ``[0]`` to obtain its 9 pixel values, i.e. each is expected to be a (1, 9)
            batch as returned by ``generator.predict`` on a single latent input.
        reference: the 9 target pixel values. Defaults to the noise-free standard
            image (0.9 on the diagonal, 0.1 elsewhere), the same pattern the real
            training images are built from.

    Returns:
        Mean over images of ``sum(|reference - image|)``. Lower is better; 0 means a
        perfect match. Returns ``nan`` for an empty input.
    """
    if reference is None:
        reference = [0.9, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.9]
    reference = np.asarray(reference, dtype=float)

    images = list(generated_images)
    if not images:
        # np.mean of an empty list would also give nan, but with a RuntimeWarning.
        return float('nan')

    # Stack the first row of each (1, 9) prediction into an (N, 9) matrix.
    pixels = np.array([np.asarray(image)[0] for image in images], dtype=float)
    norms = np.abs(reference - pixels).sum(axis=1)  # Manhattan norm per image
    return np.mean(norms)

In [83]:
#
#   Train function to train the GAN
#

def _performance_file(SAVEPATH, tmp_seed):
    """Return the path of the per-seed text log that losses and accuracy are appended to."""
    return os.path.normpath(SAVEPATH + str(tmp_seed) + '/performance.txt')


def _append_lines(filename, *lines):
    """Append each of ``lines`` followed by a newline to ``filename`` (created if missing)."""
    with open(filename, "a+") as outfile:
        for line in lines:
            outfile.write(line + "\n")


def _evaluate_losses(discriminator, GAN, X_real, X_fake, Z):
    """Return ``(d_loss, g_loss)`` for a real batch and the fake batch generated from ``Z``.

    d_loss: discriminator loss with real images labelled 1 and fake images labelled 0.
    g_loss: combined-model loss on ``Z`` with every fake image labelled 1 (real).
    Uses the module-level label arrays ``Y_real`` / ``Y_fake``.
    """
    X = np.vstack((X_real, X_fake))
    Y = np.vstack((Y_real, Y_fake))
    d_loss, _ = discriminator.evaluate(X, Y, verbose=0)
    g_loss, _ = GAN.evaluate(Z, Y_real, verbose=0)
    return d_loss, g_loss


def train(discriminator, generator, GAN, org_weights_dis,org_weights_gen, SAVEPATH, tmp_seed,faces):
    """Train the GAN with alternating updates and save its results for one seed.

    Each of the ``epochs`` iterations (module-level constant) performs:
      1. one discriminator update on ``batch_size`` real images (label 1) sampled
         from ``faces`` and ``batch_size`` generated images (label 0);
      2. one generator update through ``GAN`` with the generated images labelled 1.
    Losses are printed and appended to ``<SAVEPATH><tmp_seed>/performance.txt``
    every 100 iterations and once more after training.

    Args:
        discriminator, generator, GAN: models from setup_dis / setup_gen / setup_gan.
        org_weights_dis, org_weights_gen: initial weights; both networks are reset
            to them before training starts.
        SAVEPATH: base path prefix returned by create_folder.
        tmp_seed: seed of this run; appended to SAVEPATH to form the result folder.
        faces: pool of real images (one flattened 3x3 image per row).

    Returns:
        ``(error_generator, error_discriminator, new_color)``: one-element lists with
        the final generator loss, the final discriminator loss, and the
        calculate_color score of 4 freshly generated images.

    Side effects: writes performance.txt, the generated-image plots and values, and
    generator.h5 / discriminator.h5 under the seed's folder.
    """
    perf_file = _performance_file(SAVEPATH, tmp_seed)
    st = time.time()

    # Reset both networks so every training run starts from the same initial weights.
    generator.set_weights(org_weights_gen)
    discriminator.set_weights(org_weights_dis)

    for k in range(epochs):
        # Latent inputs: one scalar per image, drawn uniformly from [0, 1).
        Z = np.random.uniform(size=batch_size)
        X_fake = generator.predict(Z, verbose=0)

        # Pick batch_size distinct real images from the pool.
        X_real = np.random.permutation(faces)[:batch_size]

        # Train the discriminator for one step on real (1) and fake (0) images.
        X, Y = np.vstack((X_real, X_fake)), np.vstack((Y_real, Y_fake))
        discriminator.train_on_batch(X, Y)

        # Train the generator for one step: the fake images are labelled "real" so
        # the generator learns to fool the (frozen) discriminator inside GAN.
        GAN.train_on_batch(Z, Y_real)

        if k % 100 == 0:
            # Regenerate the fakes from the same Z so the logged losses reflect the
            # weights after this iteration's updates.
            X_fake = generator.predict(Z, verbose=0)
            d_loss, g_loss = _evaluate_losses(discriminator, GAN, X_real, X_fake, Z)
            loss_string = '>>Loss iteration:%d, d=%.4f, g=%.4f' % (k, d_loss, g_loss)
            print(loss_string)
            _append_lines(perf_file, loss_string)

    # Final evaluation on a fresh batch.
    Z = np.random.uniform(size=batch_size)
    X_fake = generator.predict(Z, verbose=0)
    X_real = np.random.permutation(faces)[:batch_size]
    d_loss, g_loss = _evaluate_losses(discriminator, GAN, X_real, X_fake, Z)

    # Reported as the last iteration index; using epochs - 1 instead of the loop
    # variable also works when epochs == 0.
    loss_string = '>>Loss iteration:%d, d=%.4f, g=%.4f' % (epochs - 1, d_loss, g_loss)
    print(loss_string)
    _append_lines(perf_file, loss_string)

    # Discriminator accuracy on the final real and fake batches separately.
    _, acc_real = discriminator.evaluate(X_real, Y_real, verbose=0)
    _, acc_fake = discriminator.evaluate(X_fake, Y_fake, verbose=0)

    elapsed_time = time.time() - st

    acc_string = '>Accuracy discriminator: real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100)
    time_string = '>>>>>>>Execution time: %.3f seconds<<<<<<<' %np.round(elapsed_time,3)
    print(acc_string)
    print(time_string)
    # The trailing empty line separates runs in the log.
    _append_lines(perf_file, acc_string, time_string, "")

    # Generate 4 sample images. Latent inputs come from U[0, 1), the distribution
    # used during training; N(0, 1) inputs would include negative values the
    # generator never saw.
    generated_images = [
        generator.predict(np.random.uniform(size=(1, 1)), verbose=0) for _ in range(4)
    ]
    # Saving also creates the plots/png/image_values and plots/networks folders used below.
    _ = view_samples(generated_images, 1, 4, True, SAVEPATH, tmp_seed)

    new_color = [calculate_color(generated_images)]

    # Save the raw pixel values of the generated images (one array per line).
    values_file = os.path.normpath(SAVEPATH + str(tmp_seed) + R'/plots/png/image_values/generated_images.txt')
    _append_lines(values_file, *(str(image) for image in generated_images))

    # Save both trained networks.
    generator.save(os.path.normpath(SAVEPATH + str(tmp_seed) + '/plots/networks/generator.h5'))
    discriminator.save(os.path.normpath(SAVEPATH + str(tmp_seed) + '/plots/networks/discriminator.h5'))

    return [g_loss], [d_loss], new_color

In [None]:
# set the number of initializations (independent training runs)
number_seeds = 10

all_errors_gen = []
all_errors_dis = []
all_colors = []
start_alltime = time.time()

# Train the GAN once per seed. Each seed is drawn from NumPy's global RNG, which
# set_seed then re-seeds, so the sequence of seeds is reproducible from the
# initial global seed.
for i in range(number_seeds):
    seed = np.random.randint(99999999)
    set_seed(seed)

    path = create_folder(seed)
    faces = initialize_images()

    dis, org_dis = setup_dis()
    gen, org_gen = setup_gen()
    gan = setup_gan(dis, gen)

    error_generator, error_discriminator, new_color = train(dis, gen, gan, org_dis, org_gen, path, seed, faces)
    all_errors_gen.extend(error_generator)
    all_errors_dis.extend(error_discriminator)
    all_colors.extend(new_color)

# Parent folder holding all per-seed folders; the summary files below are written there.
results_dir = os.path.dirname(os.path.normpath(path + str(seed)))

computation_time = time.time() - start_alltime
time_string = '>>>>>>>Complete computation time: %.3f seconds<<<<<<<' % np.round(computation_time, 3)
print(time_string)

# Bucket the image scores: <= 1 great (green), <= 2 good (yellow), otherwise bad (red).
# A nan score fails both comparisons and is therefore counted as bad.
n_green = sum(1 for val in all_colors if val <= 1)
n_yellow = sum(1 for val in all_colors if 1 < val <= 2)
n_red = len(all_colors) - n_green - n_yellow

# Append the image counts and total computation time to the summary log.
with open(os.path.join(results_dir, 'performance.txt'), "a+") as outfile:
    outfile.write("Number of great images:" + str(n_green) + "\n")
    outfile.write("Number of good images:" + str(n_yellow) + "\n")
    outfile.write("Number of bad images:" + str(n_red) + "\n")
    outfile.write(time_string + "\n" + "\n")


def plot_losses_over_seeds(errors_gen, errors_dis, colors, out_file, title, log_scale=False):
    """Scatter final generator vs. discriminator loss per seed, coloured by image score, and save it.

    Args:
        errors_gen, errors_dis: final generator / discriminator loss of each seed.
        colors: calculate_color score of each seed (0 = best, clipped to [0, 3] for colouring).
        out_file: path of the output figure (format inferred from the extension).
        title: figure title.
        log_scale: if True, both axes use a symmetric log scale.
    """
    fig, ax = plt.subplots()
    points = ax.scatter(errors_gen, errors_dis, c=colors, s=20, cmap='RdYlGn_r', vmin=0, vmax=3)
    fig.colorbar(points, ax=ax, label="image evaluation")
    if log_scale:
        ax.set_xscale('symlog')
        ax.set_yscale('symlog')
    ax.set_title(title)
    ax.set_xlabel('Generator_loss', fontsize=18)
    ax.set_ylabel('Discriminator_loss', fontsize=16)
    fig.savefig(out_file)
    plt.close(fig)


plot_losses_over_seeds(all_errors_gen, all_errors_dis, all_colors,
                       os.path.join(results_dir, 'GAN_error_over_seeds.pdf'),
                       "GAN loss over seeds")
plot_losses_over_seeds(all_errors_gen, all_errors_dis, all_colors,
                       os.path.join(results_dir, 'GAN_error_over_seeds_logscale.pdf'),
                       "GAN loss over seeds in log scale", log_scale=True)

# Save the errors and scores as three CSV rows for later reuse.
# newline='' is required by the csv module; without it, rows are separated by blank lines on Windows.
with open(os.path.join(results_dir, 'saved_errors.csv'), 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(all_errors_gen)
    writer.writerow(all_errors_dis)
    writer.writerow(all_colors)