In [1]:
#
# This is an example for a GAN handwritten digits generation
# The optimization is done with the accelerated multiobjective gradient method
# The momentum factor here is (k-1)/(k+1000), i.e. 1000 is used instead of 2
#

import os
import time
import shutil
import numpy as np
import sys
from matplotlib import pyplot as plt
%matplotlib inline

from keras.models import Sequential
from keras.layers import Dense
from keras.utils.layer_utils import count_params  
from tensorflow.keras import optimizers
import tensorflow as tf

import torchvision
from torchvision import datasets, transforms

import torch
from torch import nn

import csv
from scipy.linalg import norm
from scipy.optimize import minimize, LinearConstraint

# seed numpy's global RNG so that the per-run seeds drawn later with
# np.random.randint in the training cell are reproducible
np.random.seed(1)

In [2]:
# 
# function to seed numpy, tensorflow and torch so that results can be reproduced
#

def set_seed(tmp_seed):
    """Seed NumPy, TensorFlow and PyTorch with one and the same value.

    Args:
        tmp_seed: integer seed handed unchanged to every library.
    """
    # same order as before: numpy, tensorflow, torch
    for seeder in (np.random.seed, tf.random.set_seed, torch.manual_seed):
        seeder(tmp_seed)

In [3]:
#
# function to create the folder for saving results (an existing folder for the same seed is deleted first)
#

def create_folder(tmp_seed):
    """Create a fresh, empty results folder for one seed.

    If a folder for this seed already exists it is removed first, so results
    of an earlier run with the same seed are discarded.

    Args:
        tmp_seed: seed whose string form is appended to the base path.

    Returns:
        The base path prefix (without the seed); callers append ``str(seed)``.
    """
    # Set the path the results should be saved in
    SAVEPATH = R'saved_networks/Kons1_E0/_seed_'
    seed_dir = os.path.normpath(SAVEPATH + str(tmp_seed))

    # wipe previous results for this seed, then (re)create the folder
    if os.path.exists(seed_dir):
        shutil.rmtree(seed_dir)
    os.makedirs(seed_dir)

    return SAVEPATH

In [4]:
#
# function to prepare the training set
#

def prepare(batch_size, digit=3, root="../"):
    """Build a shuffled DataLoader over the MNIST training images of one digit.

    Images are mapped to [-1, 1] (ToTensor gives [0, 1], Normalize(0.5, 0.5)
    rescales), the same range as the generator's tanh output.  The data set is
    truncated to a multiple of ``batch_size`` so every batch is full;
    calculate_gradients reshapes each batch to exactly ``batch_size`` samples.

    Args:
        batch_size: number of images per batch.
        digit: MNIST class to keep (default 3, as before).
        root: directory that already contains the MNIST data (download=False).

    Returns:
        torch.utils.data.DataLoader yielding (images, labels), shuffle=True.

    Raises:
        ValueError: if fewer than ``batch_size`` images of ``digit`` exist.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    train_set = torchvision.datasets.MNIST(
        root=root, train=True, download=False, transform=transform
    )

    # keep only the requested digit
    indices = train_set.targets == digit
    train_set.data, train_set.targets = train_set.data[indices], train_set.targets[indices]

    # Drop the incomplete last batch. For digit 3 there are 6131 images, which
    # with batch_size=32 gives 191 full batches = 6112 images (the value that
    # used to be hard-coded here).
    n_full = (len(train_set.targets) // batch_size) * batch_size
    if n_full == 0:
        raise ValueError(
            "not enough images of digit %r for one batch of size %d" % (digit, batch_size)
        )
    train_set.data = train_set.data[:n_full]
    train_set.targets = train_set.targets[:n_full]

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True
    )

    return train_loader

In [5]:
#
# This cell is for drawing the image
# if the parameter save is set, the images are also saved
#

def view_samples(samples,save,path,tmp_seed,epoch):
    """Show the first 16 samples as a 4x4 grid of 28x28 grey-scale images.

    Args:
        samples: indexable with at least 16 entries, each reshapeable to
            28x28 (e.g. generator output of shape (N, 1, 28, 28)).
        save: if truthy, also write the figure as PDF and PNG.
        path: base path prefix (as returned by create_folder); the seed is
            appended to it.
        tmp_seed: seed whose folder receives the plots.
        epoch: number used in the file names (formatted with %d).
    """
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(samples[i].reshape(28, 28), cmap="gray_r")
        plt.xticks([])
        plt.yticks([])
    if save:
        plots_dir = os.path.normpath(path + str(tmp_seed) + R'/plots')
        # Create each sub-folder independently: a partially existing plots/
        # tree (e.g. left by an interrupted run) used to make savefig fail.
        # 'networks' is where train() saves the models.
        for sub_dir in ('pdf', 'png', 'networks'):
            os.makedirs(os.path.join(plots_dir, sub_dir), exist_ok=True)
        plt.savefig(os.path.join(plots_dir, 'pdf', 'generated_images_e%d.pdf' % epoch), bbox_inches='tight')
        plt.savefig(os.path.join(plots_dir, 'png', 'generated_images_e%d.png' % epoch), bbox_inches='tight')
    plt.show()
    plt.close()

In [6]:
#
# This cell defines the discriminator, a fully connected network:
# input size: 784 (shape to (None,1,28,28))
# hidden layer: 512
# output size: 1
#

def setup_dis():
    """Build and compile the discriminator (784 -> 512 ReLU -> 1 sigmoid).

    Returns:
        (discriminator, org_weights_dis): the compiled Sequential model and
        the weights it had right after construction, which train() restores
        before every training run.
    """
    discriminator = Sequential([
        tf.keras.layers.Reshape((784,), input_shape=(1, 28, 28)),
        Dense(512, activation='relu'),
        Dense(1, activation='sigmoid'),
    ])
    # snapshot the initial weights (taken before compiling)
    org_weights_dis = discriminator.get_weights()
    discriminator.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=optimizers.SGD(learning_rate=0.01),
        metrics=["accuracy"],
    )
    return discriminator, org_weights_dis

In [7]:
#
# This cell defines the generator, a fully connected network:
# input size: 10
# hidden layer: 512
# output size: 784 (shape to (None,1,28,28))
#

def setup_gen():
    """Build the (uncompiled) generator: 10 -> 512 ReLU -> 784 tanh -> (1, 28, 28).

    Returns:
        (generator, org_weights_gen): the Sequential model and a snapshot of
        its freshly initialised weights.
    """
    generator = Sequential([
        Dense(512, input_dim=10, activation='relu'),
        Dense(784, activation='tanh'),
        tf.keras.layers.Reshape((1, 28, 28)),
    ])
    org_weights_gen = generator.get_weights()
    return generator, org_weights_gen

In [8]:
#
# This cell defines the GAN network which combines the Generator and the Discriminator.
# Because the Discriminator should not be trained when the Generator is trained, discriminator.trainable is set to False before the GAN is compiled.
#

def setup_gan(discriminator, generator):
    """Chain generator and discriminator into one compiled model.

    Side effect: sets ``discriminator.trainable = False`` on the caller's
    model before the combined model is compiled.

    Returns:
        The compiled Sequential model ``generator -> discriminator``.
    """
    discriminator.trainable = False
    GAN = Sequential([generator, discriminator])
    GAN.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=optimizers.SGD(learning_rate=0.01),
        metrics=["accuracy"],
    )
    return GAN

In [9]:
#
# Objective of the direction-finding subproblem: || h * sum_i(x_i * nabla f_i(y^k)) - ((k-1)/(k+1000)) * (x^k - x^(k-1)) ||^2
#

def f(x,params):
    """Objective of the direction-finding subproblem.

    Computes ``|| h * (x[0] * A[0] + x[1] * A[1]) - fac ||**2``.

    Args:
        x: length-2 weight vector over the two rows of A.
        params: sequence ``(A, fac, h)``: A has (at least) two rows, fac is
            array-like with one entry per column of A, h is the step size.

    Returns:
        The squared Euclidean norm of the residual.
    """
    A, fac, h = params
    combined = x[0] * A[0, :] + x[1] * A[1, :]
    residual = h * combined - fac
    return norm(residual) ** 2

In [10]:
#
# This is the function which call minimize function with the constraints.
# The constraints are xi >= 0 and sum(xi) = 1
#

def minimization(A, fac, h):
    """Solve the 2-variable direction-finding subproblem with SLSQP.

    Minimises ``f(x, (A, fac, h))`` subject to ``x_i >= 0`` and
    ``x_0 + x_1 == 1``.

    Args:
        A: array with two rows (discriminator / generator gradient rows).
        fac: momentum term, array-like with one entry per column of A.
        h: step size (learning rate).

    Returns:
        scipy.optimize.OptimizeResult; ``res.x`` holds the two weights.
    """
    # Convert the (very long) momentum term to an ndarray once. Passing a
    # list made every objective evaluation inside SLSQP -- including the
    # finite-difference gradient evaluations -- convert it again.
    fac = np.asarray(fac)

    # initial guess (1, 1); it violates sum(x) = 1, SLSQP projects from there
    initial_guess = np.ones(2)

    # identity matrix: I @ x >= lb encodes x_i >= 0
    B = np.eye(2)
    lb = np.array([0, 0])

    # one constraint for x_i >= 0 and one for sum(x_i) = 1
    const = [LinearConstraint(B, lb, np.inf), {'type': 'eq', 'fun': lambda x: np.sum(x) - 1.0}]

    # f takes a single `params` argument, so pass it wrapped in a 1-tuple
    res = minimize(f, initial_guess, method='SLSQP', args=((A, fac, h),), constraints=const, options={'ftol': 1e-6, 'disp': False})  # set 'disp': True for solver output

    return res

In [11]:
#
# This function calculates the gradients for the minimization problem
# For the jacobian Matrix here the off diagonal is set to 0
# The result is A = [[ED_D,0],[0,EG_G]]
#

def calculate_gradients(gan_model,gen,d_model, z, x_real, y_real,y_fake,batch_size,EG_wD,ED_wG):
    """Build the 2-row gradient matrix used by the multiobjective step.

    Row 0 is the gradient of the discriminator loss w.r.t. the discriminator
    weights followed by ``ED_wG``; row 1 is ``EG_wD`` followed by the gradient
    of the generator loss w.r.t. the generator weights.  Per-layer gradients
    are flattened in ``trainable_weights`` order.  (The training cell of this
    notebook passes zero vectors for ``EG_wD`` / ``ED_wG``, i.e. the
    off-diagonal blocks are zero.)

    Args:
        gan_model: combined generator -> discriminator model.
        gen: generator whose trainable weights are differentiated.
        d_model: discriminator; its ``trainable`` flag is switched on while
            the gradients are taken and switched off again afterwards.
        z: latent batch (stacked with np.vstack).
        x_real: real images, reshaped here to (batch_size, 1, 28, 28).
        y_real, y_fake: labels for real and generated samples (callers pass
            ones and zeros of shape (batch_size, 1)).
        batch_size: number of samples in the batch.
        EG_wD: 1-D block placed before the generator gradient in row 1.
        ED_wG: 1-D block placed after the discriminator gradient in row 0.

    Returns:
        np.ndarray of shape (2, len(ED_wD) + len(ED_wG)), dtype float64 when
        the padding blocks are float64.
    """
    z = np.vstack(z)
    x_real = np.vstack(x_real).reshape(batch_size,1,28,28)

    bce = tf.keras.losses.BinaryCrossentropy()

    # The discriminator is frozen inside the GAN; make it trainable so that
    # d_model.trainable_weights is non-empty for the gradient computation.
    d_model.trainable = True

    with tf.GradientTape() as tape_g, tf.GradientTape() as tape_d:
        # generator loss: generated images should be classified as real
        gan_pred = gan_model(z)
        loss_g = bce(y_real, gan_pred)

        # discriminator loss on the real batch followed by the generated batch
        dis_pred = d_model(x_real)
        X = tf.reshape(tf.stack((dis_pred, gan_pred)), shape=(2*batch_size, 1))
        Y = tf.reshape(tf.stack((y_real, y_fake)), shape=(2*batch_size, 1))
        loss_d = bce(Y, X)

    EG_G = tape_g.gradient(loss_g, gen.trainable_weights)
    ED_D = tape_d.gradient(loss_d, d_model.trainable_weights)

    d_model.trainable = False

    # Flatten with np.concatenate instead of extending Python lists element
    # by element (each network has hundreds of thousands of parameters and
    # this runs every training iteration).
    ED_wD = np.concatenate([grad.numpy().reshape(-1) for grad in ED_D])
    EG_wG = np.concatenate([grad.numpy().reshape(-1) for grad in EG_G])

    # row 0: [ED_wD, ED_wG]   row 1: [EG_wD, EG_wG]
    ED = np.concatenate((ED_wD, ED_wG))
    EG = np.concatenate((EG_wD, EG_wG))

    return np.stack((ED, EG))

In [12]:
# 
# prepare training by setting parameters
# each of the num_epochs = 30 epochs runs len_trainloader = 191 batches -> number_calculations = num_epochs * len_trainloader iterations in total
#

batch_size = 32
num_epochs = 30
# number of full batches per epoch (6112 images / 32 per batch)
len_trainloader = 191
number_calculations = num_epochs * len_trainloader

# Targets for the binary cross-entropy losses:
#   0 -> generated images, as seen by the discriminator
#   1 -> real images for the discriminator, and the generator's goal for its fakes
real_samples_labels = np.full((batch_size, 1), 1.0)
generated_samples_labels = np.full((batch_size, 1), 0.0)

In [13]:
#
#   Train function to train the GAN
#

def _extrapolate(weights, prev_weights, beta):
    """Per layer, return the look-ahead point y = x + beta * (x - x_prev) as new arrays."""
    return [w + (beta * (w - w_prev)) for w, w_prev in zip(weights, prev_weights)]


def _momentum_term(weights, prev_weights, beta):
    """Flatten beta * (x - x_prev) over all layers into one 1-D array."""
    return np.concatenate(
        [beta * (w.reshape(-1) - w_prev.reshape(-1)) for w, w_prev in zip(weights, prev_weights)]
    )


def _apply_step(model, v_part, learning_rate):
    """Set model weights to w - learning_rate * v_part, consuming v_part layer by layer."""
    weights = model.get_weights()
    offset = 0
    for i, w in enumerate(weights):
        n = w.size
        weights[i] = w - (learning_rate * v_part[offset:offset + n]).reshape(w.shape)
        offset += n
    model.set_weights(weights)


def _evaluate_and_log(discriminator, generator, GAN, latent_space_samples, real_samples, SAVEPATH, tmp_seed, epoch):
    """Evaluate D and G losses on one batch, print them and append them to performance.txt.

    Returns:
        (d_loss, g_loss, generated_samples)
    """
    latent = np.vstack(latent_space_samples)
    generated_samples = generator.predict(latent, verbose=0)
    all_samples = np.vstack((real_samples, generated_samples))
    all_samples_labels = np.vstack((real_samples_labels, generated_samples_labels))
    d_loss, _ = discriminator.evaluate(all_samples, all_samples_labels, verbose=0)
    g_loss, _ = GAN.evaluate(latent, real_samples_labels, verbose=0)

    print(f"Epoch: {epoch} Loss D.: {d_loss}")
    print(f"Epoch: {epoch} Loss G.: {g_loss}")
    with open(os.path.normpath(SAVEPATH + str(tmp_seed) + '/performance.txt'), "a+") as outfile:
        # summarize discriminator and generator performance
        outfile.write(f"Epoch: {epoch} Loss D.: {d_loss}" + "\n")
        outfile.write(f"Epoch: {epoch} Loss G.: {g_loss}" + "\n")

    return d_loss, g_loss, generated_samples


def train(discriminator, generator, GAN, org_weights_dis,org_weights_gen, SAVEPATH, tmp_seed,train_loader,EG_wD,ED_wG):
    """Train the GAN with the accelerated multiobjective gradient method.

    Iteration k (beta_k = (k-1)/(k+1000), h = learning rate):
      1. y^k = x^k + beta_k * (x^k - x^(k-1)) for both networks;
      2. gradient rows A at y^k (calculate_gradients);
      3. theta = argmin over the simplex of
         || h * (theta_0 A_0 + theta_1 A_1) - beta_k * (x^k - x^(k-1)) ||^2;
      4. x^(k+1) = y^k - h * (theta_0 A_0 + theta_1 A_1).
    Every len_trainloader iterations (one epoch) the losses are logged,
    samples are plotted, the learning rate is multiplied by 0.99 and a new
    pass over the (reshuffled) loader starts.  number_calculations
    iterations are run in total.

    Args:
        discriminator, generator, GAN: models from setup_dis / setup_gen / setup_gan.
        org_weights_dis, org_weights_gen: initial weights, restored at the
            start so every call starts from the same initialisation.
        SAVEPATH: base path prefix; ``str(tmp_seed)`` is appended to it.
        tmp_seed: seed, used here only to build output paths.
        train_loader: DataLoader yielding batches of ``batch_size`` real images.
        EG_wD, ED_wG: off-diagonal blocks passed to calculate_gradients.

    Returns:
        (error_generator, error_discriminator): one-element lists with the
        final generator and discriminator loss.
    """
    # setup learning rate and decrease factor
    learning_rate = 0.01
    alpha = 0.99

    # lists to plot the error after training
    error_discriminator = []
    error_generator = []

    # reset weights for each training
    generator.set_weights(org_weights_gen)
    discriminator.set_weights(org_weights_dis)

    # at the start x^(k-1) = x^k = current weights
    x_gk_m1 = generator.get_weights()
    x_dk_m1 = discriminator.get_weights()

    # number of discriminator parameters, used to split the combined direction
    # v; trainable_weights is only populated while the discriminator is trainable
    discriminator.trainable = True
    n_dis = count_params(discriminator.trainable_weights)
    discriminator.trainable = False

    # Keep ONE iterator per epoch. Calling next(iter(train_loader)) every step
    # built a fresh shuffled iterator and returned its first batch, i.e. batches
    # were sampled with replacement instead of stepping through the epoch.
    data_iter = iter(train_loader)

    k = 1
    # run a fixed number of iterations (number_calculations)
    while True:
        try:
            real_samples, _ = next(data_iter)
        except StopIteration:
            # loader exhausted before the epoch counter: start a new pass
            data_iter = iter(train_loader)
            real_samples, _ = next(data_iter)

        # latent inputs for the generator
        latent_space_samples = torch.randn((batch_size, 10))

        # x^k for the generator and the discriminator
        x_gk = generator.get_weights()
        x_dk = discriminator.get_weights()

        beta = (k - 1) / (k + 1000)

        # evaluate gradients at the look-ahead point y^k
        discriminator.set_weights(_extrapolate(x_dk, x_dk_m1, beta))
        generator.set_weights(_extrapolate(x_gk, x_gk_m1, beta))

        A = calculate_gradients(GAN, generator, discriminator, latent_space_samples, real_samples, real_samples_labels, generated_samples_labels, batch_size, EG_wD, ED_wG)

        # momentum term beta * (x^k - x^(k-1)) in the same [D, G] order as A's columns
        factor = np.concatenate(
            (_momentum_term(x_dk, x_dk_m1, beta), _momentum_term(x_gk, x_gk_m1, beta))
        )

        theta = minimization(A, factor, learning_rate)

        # combined direction; the first n_dis entries belong to the discriminator
        v = theta.x[0] * A[0, :] + theta.x[1] * A[1, :]
        v_dis, v_gen = v[:n_dis], v[n_dis:]

        # x^(k-1) <- x^k for the next iteration
        x_gk_m1 = x_gk
        x_dk_m1 = x_dk

        # x^(k+1) = y^k - h * v (the models currently hold y^k)
        _apply_step(discriminator, v_dis, learning_rate)
        _apply_step(generator, v_gen, learning_rate)

        if k >= number_calculations:
            break

        # end of an epoch: log losses, plot samples, decay learning rate
        if k % len_trainloader == 0:
            epoch = k / len_trainloader
            _, _, generated_samples = _evaluate_and_log(discriminator, generator, GAN, latent_space_samples, real_samples, SAVEPATH, tmp_seed, epoch)
            view_samples(generated_samples, True, SAVEPATH, tmp_seed, epoch)

            learning_rate = learning_rate * alpha

            # start the next epoch with a freshly shuffled pass over the data
            data_iter = iter(train_loader)

        k += 1

    # final evaluation. The loop breaks at k == number_calculations, which
    # completes epoch k / len_trainloader; using (k - 1) here gave e.g. 29.99,
    # and '%d' then overwrote the epoch-29 image files.
    latent_space_samples = torch.randn((batch_size, 10))
    epoch = k / len_trainloader
    d_loss, g_loss, generated_samples = _evaluate_and_log(discriminator, generator, GAN, latent_space_samples, real_samples, SAVEPATH, tmp_seed, epoch)

    # losses for plotting over seeds
    error_generator.append(g_loss)
    error_discriminator.append(d_loss)

    view_samples(generated_samples, True, SAVEPATH, tmp_seed, epoch)

    # save the generator and discriminator
    generator.save(os.path.normpath(SAVEPATH + str(tmp_seed) + '/plots/networks/generator.h5'))
    discriminator.save(os.path.normpath(SAVEPATH + str(tmp_seed) + '/plots/networks/discriminator.h5'))

    return error_generator, error_discriminator

In [None]:
# set the number of initializations
number_seeds = 3

all_errors_gen = []
all_errors_dis = []
start_alltime = time.time()

# train GAN for each initialization
for i in range(number_seeds):
    # draw a seed from numpy's global stream and seed all libraries with it
    seed = np.random.randint(99999999)
    set_seed(seed)

    # create (and empty) the result folder for this seed
    path = create_folder(seed)

    # setup images for this seed
    train_loader = prepare(batch_size)

    # setup networks for this seed
    dis, org_dis = setup_dis()
    gen, org_gen = setup_gen()
    gan = setup_gan(dis, gen)

    # zero off-diagonal blocks of the gradient matrix (generator loss w.r.t.
    # D weights, discriminator loss w.r.t. G weights); trainable_weights of
    # the discriminator is only populated while it is trainable
    dis.trainable = True
    EG_wD = np.zeros((count_params(dis.trainable_weights), 1)).reshape(-1)
    ED_wG = np.zeros((count_params(gen.trainable_weights), 1)).reshape(-1)
    dis.trainable = False

    # train the GAN for this initialization
    error_generator, error_discriminator = train(dis, gen, gan, org_dis, org_gen, path, seed, train_loader, EG_wD, ED_wG)
    all_errors_gen.extend(error_generator)
    all_errors_dis.extend(error_discriminator)

# computation time of all seeds
end_alltime = time.time()
computation_time = end_alltime - start_alltime
time_string = '>>>>>>>Complete computation time: %.3f seconds<<<<<<<' % np.round(computation_time, 3)
print(time_string)

# all seed folders share one parent directory; the summaries go there
results_dir = os.path.dirname(os.path.normpath(path + str(seed)))

with open(os.path.join(results_dir, 'performance.txt'), "a+") as outfile:
    outfile.write(time_string + "\n" + "\n")


def _plot_seed_losses(gen_losses, dis_losses, out_file, title, log_scale=False):
    """Scatter final generator vs. discriminator loss of every seed and save it to out_file.

    The points use one fixed colour, so no colormap/colorbar is drawn (the
    former cmap/vmin/vmax were ignored by matplotlib and the colorbar was not
    tied to any data).
    """
    fig, ax = plt.subplots()
    ax.scatter(gen_losses, dis_losses, color='red', s=20)
    if log_scale:
        ax.set_xscale('symlog')
        ax.set_yscale('symlog')
    ax.set_title(title)
    ax.set_xlabel('Generator_loss', fontsize=18)
    ax.set_ylabel('Discriminator_loss', fontsize=16)
    fig.savefig(out_file)
    plt.close(fig)


# plot and save the losses for all seeds (linear and symlog axes)
_plot_seed_losses(all_errors_gen, all_errors_dis,
                  os.path.join(results_dir, 'GAN_error_over_seeds.pdf'),
                  "GAN loss over seeds")
_plot_seed_losses(all_errors_gen, all_errors_dis,
                  os.path.join(results_dir, 'GAN_error_over_seeds_logscale.pdf'),
                  "GAN loss over seeds in log scale", log_scale=True)

# Save the errors into a csv file to reuse them. The handle must NOT be named
# `f`: that rebinds the global objective function f used by minimization(),
# so re-running this cell would crash. newline='' is required by the csv module.
with open(os.path.join(results_dir, 'saved_errors.csv'), 'w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(all_errors_gen)
    writer.writerow(all_errors_dis)
