In [None]:
#
# This is an example of GAN-based handwritten digit generation
# The optimization is done with alternating training
#

import os
import time
import shutil
import numpy as np
import sys
from matplotlib import pyplot as plt
%matplotlib inline

from keras.models import Sequential
from keras.layers import Dense
from keras.utils.layer_utils import count_params  
from tensorflow.keras import optimizers
import tensorflow as tf

import torchvision
from torchvision import datasets, transforms

import torch
from torch import nn

import csv
from scipy.linalg import norm
from scipy.optimize import minimize, LinearConstraint

np.random.seed(1)

In [145]:
# 
# function to seed the random number generators of all imported libraries so results can be repeated
#

def set_seed(tmp_seed):
    """Seed every random number generator this notebook draws from.

    NumPy's global RNG, TensorFlow's global seed and PyTorch's global
    generator all receive the same value, in that order, so a run with the
    same seed can be repeated.

    Args:
        tmp_seed: integer seed shared by all three libraries.
    """
    for seeder in (np.random.seed, tf.random.set_seed, torch.manual_seed):
        seeder(tmp_seed)

In [146]:
#
# function to create the (per-seed) folder for saving results
#

def create_folder(tmp_seed):
    """Create a fresh, empty results directory for one seed.

    The directory is ``<SAVEPATH><tmp_seed>``, relative to the current
    working directory. If it already exists it is deleted first, so results
    of an earlier run with the same seed are discarded.

    Args:
        tmp_seed: seed appended to the path prefix.

    Returns:
        The path prefix ``SAVEPATH`` without the seed; callers append
        ``str(tmp_seed)`` themselves.
    """
    # Set the path the results should be saved in
    SAVEPATH = R'saved_networks/alternating_training/_seed_'
    seed_dir = os.path.normpath(SAVEPATH + str(tmp_seed))

    # wipe earlier results for this seed, then (re)create the folder
    if os.path.exists(seed_dir):
        shutil.rmtree(seed_dir)
    os.makedirs(seed_dir)

    return SAVEPATH

In [None]:
#
# function to prepare the training set
#

def prepare(batch_size, digit=3, root="../"):
    """Build a shuffled DataLoader over the MNIST training images of one digit.

    Images are converted to tensors and normalised with
    ``Normalize((0.5,), (0.5,))``, i.e. mapped from [0, 1] to [-1, 1], the
    same range as the generator's ``tanh`` output.

    The subset is truncated to a whole number of batches so every batch has
    exactly ``batch_size`` images. This matters because the training loop
    stacks each real batch with label arrays of length ``batch_size``. For
    digit 3 (6131 images) and batch_size 32 this keeps 191 * 32 = 6112
    images, the same as the previously hard-coded value.

    Args:
        batch_size: number of images per batch.
        digit: MNIST class to keep (default 3).
        root: directory holding the MNIST data. It must already be present,
            because ``download=False``.

    Returns:
        ``torch.utils.data.DataLoader`` yielding ``(images, targets)`` batches.

    Raises:
        ValueError: if the selected digit has fewer than ``batch_size`` images.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    train_set = torchvision.datasets.MNIST(
        root=root, train=True, download=False, transform=transform
    )

    # keep only the requested digit
    indices = train_set.targets == digit
    train_set.data, train_set.targets = train_set.data[indices], train_set.targets[indices]

    # Drop the trailing partial batch. This is computed from batch_size
    # instead of the former fixed 6112, which was only valid for batch_size=32.
    n_keep = (len(train_set.targets) // batch_size) * batch_size
    if n_keep == 0:
        raise ValueError(
            f"digit {digit} has fewer than batch_size={batch_size} training images"
        )
    train_set.data = train_set.data[:n_keep]
    train_set.targets = train_set.targets[:n_keep]

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True
    )

    return train_loader

In [148]:
#
# This cell is for drawing the image
# if the parameter save is set, the images are also saved
#

def view_samples(samples, save, path, tmp_seed, epoch):
    """Draw the first 16 samples as a 4x4 grid of 28x28 grey-scale images.

    Args:
        samples: indexable with at least 16 items, each supporting
            ``.reshape(28, 28)``.
        save: if truthy, the figure is also written as PDF and PNG.
        path: results path prefix (as returned by ``create_folder``); the
            seed is appended to it.
        tmp_seed: seed whose results folder is used.
        epoch: number embedded in the file names (formatted with ``%d``).

    When saving, ``<path><seed>/plots/{pdf,png,networks}`` are created if any
    of them is missing. ``networks`` is where ``train`` later saves the models.
    The figure is shown and then closed.
    """
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(samples[i].reshape(28, 28), cmap="gray_r")
        plt.xticks([])
        plt.yticks([])
    if save:
        plots_dir = os.path.normpath(path + str(tmp_seed) + R'/plots')
        # Create each sub-folder independently. Previously they were only
        # created when 'plots' itself was missing, so a partial tree broke savefig.
        for sub in ('pdf', 'png', 'networks'):
            os.makedirs(os.path.join(plots_dir, sub), exist_ok=True)
        plt.savefig(os.path.join(plots_dir, 'pdf', 'generated_images_e%d.pdf' % epoch), bbox_inches='tight')
        plt.savefig(os.path.join(plots_dir, 'png', 'generated_images_e%d.png' % epoch), bbox_inches='tight')
    plt.show()
    plt.close()

In [149]:
#
# This cell defines the discriminator, a fully connected (dense) network
# input shape: (None, 1, 28, 28), flattened to 784
# hidden layer: 512
# output size: 1 (sigmoid)
#

def setup_dis(lr):
    """Build and compile the discriminator.

    Architecture: (1, 28, 28) image -> flattened to 784 -> Dense(512, relu)
    -> Dense(1, sigmoid). It is compiled with binary cross-entropy, plain SGD
    at learning rate ``lr`` and an accuracy metric.

    Args:
        lr: SGD learning rate.

    Returns:
        ``(discriminator, initial_weights)``. ``initial_weights`` is the
        freshly initialised weight list, used by ``train`` to reset the model.
    """
    model = Sequential()
    # flatten the (1, 28, 28) image into a 784-vector
    model.add(tf.keras.layers.Reshape((784,), input_shape=(1, 28, 28)))
    model.add(Dense(512, activation='relu'))
    # single sigmoid unit; trained with label 1 for real and 0 for generated images
    model.add(Dense(1, activation='sigmoid'))

    # snapshot the initial weights (taken before compiling)
    initial_weights = model.get_weights()

    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=optimizers.SGD(learning_rate=lr),
        metrics=["accuracy"],
    )
    return model, initial_weights

In [150]:
#
# This cell defines the generator, a fully connected (dense) network
# input size: 10 (latent vector)
# hidden layer: 512
# output size: 784, reshaped to (None, 1, 28, 28)
#

def setup_gen():
    """Build the (uncompiled) generator.

    Architecture: 10-dim latent vector -> Dense(512, relu) -> Dense(784, tanh)
    -> reshaped to (1, 28, 28). The model is not compiled here; in this
    notebook it is updated through the combined GAN model.

    Returns:
        ``(generator, initial_weights)`` with the freshly initialised weights,
        used by ``train`` to reset the model.
    """
    model = Sequential()
    model.add(Dense(512, input_dim=10, activation='relu'))
    # tanh keeps generated pixel values in [-1, 1]
    model.add(Dense(784, activation='tanh'))
    model.add(tf.keras.layers.Reshape((1, 28, 28)))

    initial_weights = model.get_weights()
    return model, initial_weights

In [151]:
#
# This cell defines the GAN network which combines the Generator and the Discriminator.
# Because the Discriminator must not be updated while we train the Generator, discriminator.trainable is set to False before the GAN is compiled.
#

def setup_gan(discriminator, generator, lr):
    """Chain generator -> discriminator into the model used for generator updates.

    Side effect: sets ``discriminator.trainable = False`` before compiling,
    so fitting the returned model updates only the generator's weights.
    NOTE(review): the discriminator's own updates (through the separate
    compile in setup_dis) rely on Keras keeping the trainable state captured
    at that earlier compile. Confirm this holds for the installed Keras version.

    Args:
        discriminator: discriminator model (mutated: ``trainable`` -> False).
        generator: generator model.
        lr: SGD learning rate for the combined model.

    Returns:
        The compiled combined model (binary cross-entropy, SGD, accuracy).
    """
    # freeze the discriminator inside the combined model
    discriminator.trainable = False

    combined = Sequential()
    for part in (generator, discriminator):
        combined.add(part)

    combined.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=optimizers.SGD(learning_rate=lr),
        metrics=["accuracy"],
    )
    return combined

In [152]:
# 
# prepare training by setting parameters
# one training step per batch; num_epochs (30) passes over the data -> number_calculations = len_trainloader * num_epochs steps
#

# batch and schedule settings
batch_size = 32
num_epochs = 30
# batches per pass: prepare() cuts the digit-3 subset to 6112 images -> 6112 / 32 = 191
len_trainloader = 191
# total number of alternating training steps
number_calculations = num_epochs * len_trainloader

# target labels, one row per sample of a batch
#   1 = "real": real images for the discriminator, and the target the generator is trained towards
#   0 = "fake": generated images for the discriminator
real_samples_labels = np.ones((batch_size, 1))
generated_samples_labels = np.zeros_like(real_samples_labels)

In [153]:
#
#   Train function to train the GAN
#

def _evaluate_and_log(discriminator, generator, GAN, real_samples, latent_space_samples, SAVEPATH, tmp_seed, k):
    """Evaluate both losses on one batch, report them and save a sample grid.

    The discriminator loss is measured on ``real_samples`` (label 1) stacked
    with images generated from ``latent_space_samples`` (label 0). The
    generator loss is the GAN loss with the generated images labelled real.
    Both losses are printed and appended to
    ``<SAVEPATH><tmp_seed>/performance.txt``; the generated images are passed
    to ``view_samples`` with ``save=True``.

    Returns:
        ``(d_loss, g_loss)``
    """
    generated_samples = generator.predict(latent_space_samples, verbose=0)
    all_samples = np.vstack((real_samples, generated_samples))
    all_samples_labels = np.vstack((real_samples_labels, generated_samples_labels))
    d_loss, _ = discriminator.evaluate(all_samples, all_samples_labels, verbose=0)
    g_loss, _ = GAN.evaluate(latent_space_samples, real_samples_labels, verbose=0)

    print(f"Epoch: {k} Loss D.: {d_loss}")
    print(f"Epoch: {k} Loss G.: {g_loss}")
    with open(os.path.normpath(SAVEPATH + str(tmp_seed) + '/performance.txt'), "a+") as outfile:
        # summarize discriminator and generator performance
        outfile.write(f"Epoch: {k} Loss D.: {d_loss}" + "\n")
        outfile.write(f"Epoch: {k} Loss G.: {g_loss}" + "\n")

    view_samples(generated_samples, True, SAVEPATH, tmp_seed, k)
    return d_loss, g_loss


def _set_learning_rate(model, lr):
    """Set the learning rate of a compiled Keras model's optimizer in place."""
    model.optimizer.learning_rate = lr


def train(discriminator, generator, GAN, org_weights_dis,org_weights_gen, SAVEPATH, tmp_seed,train_loader,lr,alpha):
    """Train the GAN by alternating one discriminator step and one generator step.

    Each step draws the next batch of real images and a fresh batch of
    10-dim latent vectors. It updates the discriminator on real (label 1)
    plus generated (label 0) images, then updates the generator through
    ``GAN`` with the generated images labelled real. After every
    ``len_trainloader`` steps (one pass over ``train_loader``) three things
    happen:
      - the losses are printed and logged, and a sample grid is saved;
      - the learning rate of both optimizers is multiplied by ``alpha``;
      - a freshly shuffled pass over the data starts.

    Uses the notebook globals ``batch_size``, ``len_trainloader``,
    ``number_calculations``, ``real_samples_labels`` and
    ``generated_samples_labels``.

    Args:
        discriminator: compiled discriminator (from ``setup_dis``).
        generator: generator (from ``setup_gen``).
        GAN: compiled generator+discriminator stack (from ``setup_gan``).
        org_weights_dis, org_weights_gen: initial weights, restored at the
            start so every call begins from the same initialisation.
        SAVEPATH: results path prefix; the seed is appended to it.
        tmp_seed: seed identifying the results folder.
        train_loader: DataLoader yielding ``(images, labels)`` batches of
            exactly ``batch_size`` images.
        lr: initial learning rate; both optimizers are set to it at the start.
        alpha: multiplicative learning-rate decay applied after each pass.

    Returns:
        ``(error_generator, error_discriminator)``: one-element lists holding
        the final generator and discriminator loss.
    """
    error_discriminator = []
    error_generator = []

    # restart from the same initial weights and learning rate on every call
    generator.set_weights(org_weights_gen)
    discriminator.set_weights(org_weights_dis)
    _set_learning_rate(discriminator, lr)
    _set_learning_rate(GAN, lr)

    # Keep ONE iterator per pass over the data. Calling iter(train_loader)
    # every step reshuffles the loader and returns a random batch, so a pass
    # would not visit each image once.
    batch_iter = iter(train_loader)

    for epoch in range(1, number_calculations + 1):
        try:
            real_samples, _ = next(batch_iter)
        except StopIteration:
            # loader shorter than len_trainloader: start a new shuffled pass
            batch_iter = iter(train_loader)
            real_samples, _ = next(batch_iter)

        latent_space_samples = torch.randn((batch_size, 10)).numpy()

        # discriminator step: real images -> 1, generated images -> 0
        generated_samples = generator.predict(latent_space_samples, verbose=0)
        all_samples = np.vstack((real_samples, generated_samples))
        all_samples_labels = np.vstack((real_samples_labels, generated_samples_labels))
        discriminator.train_on_batch(all_samples, all_samples_labels)

        # generator step: train through the GAN so generated images are classified as real
        GAN.train_on_batch(latent_space_samples, real_samples_labels)

        if epoch % len_trainloader == 0:
            k = epoch / len_trainloader
            _evaluate_and_log(discriminator, generator, GAN, real_samples,
                              latent_space_samples, SAVEPATH, tmp_seed, k)

            # Decay the learning rate AND push it into both optimizers.
            # Updating only the local variable had no effect on training.
            lr = lr * alpha
            _set_learning_rate(discriminator, lr)
            _set_learning_rate(GAN, lr)

            # start the next pass with a freshly shuffled iterator
            batch_iter = iter(train_loader)

    # final evaluation on a new latent batch and the last real batch
    latent_space_samples = torch.randn((batch_size, 10)).numpy()
    k = number_calculations / len_trainloader
    d_loss, g_loss = _evaluate_and_log(discriminator, generator, GAN, real_samples,
                                       latent_space_samples, SAVEPATH, tmp_seed, k)

    error_generator.append(g_loss)
    error_discriminator.append(d_loss)

    # save the trained generator and discriminator
    networks_dir = os.path.normpath(SAVEPATH + str(tmp_seed) + '/plots/networks')
    os.makedirs(networks_dir, exist_ok=True)
    generator.save(os.path.join(networks_dir, 'generator.h5'))
    discriminator.save(os.path.join(networks_dir, 'discriminator.h5'))

    return error_generator, error_discriminator

In [None]:
# set the number of initializations
number_seeds = 3
# the results folder (and `path`/`seed` below) only exist after at least one run
if number_seeds < 1:
    raise ValueError("number_seeds must be at least 1")

all_errors_gen = []
all_errors_dis = []
start_alltime = time.time()

# setup learning rate and decrease factor
lr = 0.01
alpha = 0.99


def plot_losses_over_seeds(errors_gen, errors_dis, out_file, title, log_scale=False):
    """Scatter final generator loss (x) against discriminator loss (y), one point per seed, and save it.

    The points use a single fixed colour, so no colormap or colorbar is
    drawn. Passing cmap/vmin/vmax together with color= is ignored by
    matplotlib, and the colorbar it produced carried no information.
    """
    plt.scatter(errors_gen, errors_dis, color='red', s=20)
    if log_scale:
        # symlog is also defined for zero losses
        plt.xscale('symlog')
        plt.yscale('symlog')
    plt.title(title)
    plt.xlabel('Generator_loss', fontsize=18)
    plt.ylabel('Discriminator_loss', fontsize=16)
    plt.savefig(out_file)
    plt.close()


# train GAN for each initialization
for _ in range(number_seeds):
    # generate and set seed (seeds chain deterministically from np.random.seed(1))
    seed = np.random.randint(99999999)
    set_seed(seed)

    # create path
    path = create_folder(seed)

    # setup images for this seed
    train_loader = prepare(batch_size)

    # setup networks (and their initial weights) for this seed
    dis, org_dis = setup_dis(lr)
    gen, org_gen = setup_gen()
    gan = setup_gan(dis, gen, lr)

    # train the GAN for this specific initialization
    error_generator, error_discriminator = train(dis, gen, gan, org_dis, org_gen, path, seed, train_loader, lr, alpha)
    all_errors_gen.extend(error_generator)
    all_errors_dis.extend(error_discriminator)

# folder shared by all seeds (parent of the per-seed folders)
results_root = os.path.normpath(path + str(seed) + '/..')

# get the execution time of all seeds
end_alltime = time.time()
computation_time = end_alltime - start_alltime
time_string = '>>>>>>>Complete computatiom time: %.3f seconds<<<<<<<' %np.round(computation_time,3)
print(time_string)

# save computation time
with open(os.path.join(results_root, 'performance.txt'), "a+") as outfile:
    outfile.write(time_string+"\n"+"\n")

# plot and save the losses for all seeds, linear and log scale
plot_losses_over_seeds(all_errors_gen, all_errors_dis,
                       os.path.join(results_root, 'GAN_error_over_seeds.pdf'),
                       "GAN loss over seeds")
plot_losses_over_seeds(all_errors_gen, all_errors_dis,
                       os.path.join(results_root, 'GAN_error_over_seeds_logscale.pdf'),
                       "GAN loss over seeds in log scale", log_scale=True)

# save the errors into a csv file to reuse them (row 1: generator, row 2: discriminator)
# newline='' as required by the csv module, avoiding blank lines on Windows
with open(os.path.join(results_root, 'saved_errors.csv'), 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(all_errors_gen)
    writer.writerow(all_errors_dis)
