In [None]:
# imports
import copy
import importlib
import math
import os
import sys
from functools import partial

import numpy as np
import tensorflow as tf
from tensorflow.keras import activations, layers
from tensorflow.keras.models import Model

# Project-local modules live in gan_code/. Guard the append so re-running this
# cell does not keep adding duplicate entries to sys.path.
if 'gan_code/' not in sys.path:
  sys.path.append('gan_code/')
import DataLoader
importlib.reload(DataLoader)  # pick up edits to DataLoader.py without restarting the kernel

# Default Keras float type for layers/weights created below.
tf.keras.backend.set_floatx('float32')


# Define GAN architecture

In [None]:
# Define the generator network: (noise, condition) -> 368-feature output.
initializer = tf.keras.initializers.he_uniform()
bias_node = True


def _generator_dense_bn_swish(x, units):
  """Append one Dense -> BatchNormalization -> swish stage to tensor `x`."""
  x = layers.Dense(units, use_bias=bias_node, kernel_initializer=initializer,
                   bias_initializer='zeros')(x)
  x = layers.BatchNormalization()(x)
  return layers.Activation(activations.swish)(x)


# Shapes are tuples: `(50)` is just the int 50. 50 is the latent size and must
# match the noise dimension sampled in D_loss.
noise = layers.Input(shape=(50,), name="Noise")
condition = layers.Input(shape=(2,), name="mycond")
con = layers.concatenate([noise, condition])

G = con
for units in (50, 100, 200, 368):
  G = _generator_dense_bn_swish(G, units)

# A functional Model is already built from its Input layers, so no .build() call
# is needed.
generator = Model(inputs=[noise, condition], outputs=G)
generator.summary()



In [None]:
# Define the Discriminator (critic) network: (368-feature sample, condition) -> score.
initializer = tf.keras.initializers.he_uniform()
bias_node = True


def _critic_dense(x, units):
  """Apply a Dense layer with the critic's shared initializer/bias settings."""
  return layers.Dense(units, use_bias=bias_node, kernel_initializer=initializer,
                      bias_initializer='zeros')(x)


# Shapes are tuples: `(368)` is just the int 368.
image = layers.Input(shape=(368,), name="Image")
d_condition = layers.Input(shape=(2,), name="mycond")
d_con = layers.concatenate([image, d_condition])

D = d_con
for _ in range(3):
  D = layers.Activation(activations.relu)(_critic_dense(D, 368))
# Final layer has no activation: the critic outputs an unbounded score.
D = _critic_dense(D, 1)

# A functional Model is already built from its Input layers, so no .build() call
# is needed.
discriminator = Model(inputs=[image, d_condition], outputs=D)
discriminator.summary()


# Training, loss and gradient functions

In [None]:

@tf.function
def gradient_penalty(f, x_real, x_fake, cond_label, batchsize, D, lambda_gp=0.00001):
  """Gradient penalty on random interpolates between real and fake samples.

  Args:
    f: unused. Kept so existing callers that pass it keep working; the critic
      that is evaluated is `D`.
    x_real: batch of real samples. Assumed 2-D (batch, features), because `alpha`
      has shape (batchsize, 1) — TODO generalize if higher-rank inputs appear.
    x_fake: batch of generated samples, same shape as `x_real`.
    cond_label: conditioning labels fed to `D` alongside the interpolates.
    batchsize: number of rows in `x_real` / `x_fake`.
    D: critic model called as D(inputs=[samples, condition]).
    lambda_gp: penalty weight. The default keeps the previously hard-coded 1e-5.

  Returns:
    Scalar tensor: lambda_gp * mean((||dD/dx at x_hat||_2 - 1)^2).
  """
  alpha = tf.random.uniform([batchsize, 1], minval=0., maxval=1.)
  inter = alpha * x_real + (1 - alpha) * x_fake

  with tf.GradientTape() as t:
    t.watch(inter)  # inter is not a Variable, so it must be watched explicitly
    pred = D(inputs=[inter, cond_label], training=True)
  grad = t.gradient(pred, inter)

  # The epsilon keeps sqrt differentiable when a sample's input gradient is
  # exactly zero. The derivative of sqrt at 0 is infinite, which produces NaNs.
  slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1) + 1e-12)
  return lambda_gp * tf.reduce_mean((slopes - 1.) ** 2)

@tf.function
def D_loss(x_real, cond_label, batchsize, G, D, latent_dim=50):
  """Compute the WGAN-GP critic loss for one mini-batch.

  Samples noise, generates fake samples with G (conditioned on `cond_label`),
  scores real and fake samples with D, and adds the gradient penalty.

  Args:
    x_real: batch of real samples.
    cond_label: conditioning labels, used for both the real and the fake samples.
    batchsize: number of noise vectors to draw; should equal the batch size of x_real.
    G: generator model called as G(inputs=[noise, condition]).
    D: critic model called as D(inputs=[samples, condition]).
    latent_dim: noise dimension. Must match the generator's noise Input size.

  Returns:
    (loss, D_fake): the scalar critic loss, and the critic scores of the
    generated batch (train_loop builds the generator loss from D_fake).
  """
  # Noise ~ Normal(mean=0.5, stddev=0.5), shape (batchsize, latent_dim).
  z = tf.random.normal([batchsize, latent_dim], mean=0.5, stddev=0.5, dtype=tf.dtypes.float32)
  # training=True is required in a custom training loop. Without it Keras runs
  # the generator's BatchNormalization layers in inference mode: they apply the
  # (never-updated) moving statistics instead of batch statistics.
  x_fake = G(inputs=[z, cond_label], training=True)
  D_fake = D(inputs=[x_fake, cond_label], training=True)
  D_real = D(inputs=[x_real, cond_label], training=True)
  penalty = gradient_penalty(f=partial(D, training=True), x_real=x_real, x_fake=x_fake,
                             cond_label=cond_label, batchsize=batchsize, D=D)
  loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + penalty
  return loss, D_fake

@tf.function
def G_loss(D_fake):
  """Generator loss: the negated mean critic score of the generated batch."""
  mean_fake_score = tf.reduce_mean(D_fake)
  return tf.negative(mean_fake_score)

def getTrainData_ultimate(n_iteration, batchsize, dgratio, X, Labels):
  """Build an iterator of shuffled combined batches covering the remaining run.

  Each element is an (X, Labels) pair with `batchsize * dgratio` rows. The
  training loop splits each element into `dgratio` critic mini-batches. The data
  is reshuffled on every pass (shuffle before repeat) and repeated
  ceil(n_iteration / full_batches_per_pass) times, so at least `n_iteration`
  elements are available. Batches may straddle pass boundaries; only the final
  incomplete batch is dropped.

  Args:
    n_iteration: number of training iterations still to run (int or scalar tensor).
    batchsize: samples per critic step.
    dgratio: critic steps per training iteration.
    X: training samples; the first axis indexes samples.
    Labels: conditioning labels aligned with X along the first axis.

  Returns:
    An iterator over the batched tf.data.Dataset (supports get_next()).

  Raises:
    ValueError: if the combined batch size is not positive, or if X has fewer
      rows than one combined batch. Either case would otherwise make the repeat
      count a division by zero.
  """
  true_batchsize = int(batchsize) * int(dgratio)
  if true_batchsize <= 0:
    raise ValueError("batchsize * dgratio must be positive, got %d" % true_batchsize)

  n_samples = len(X)
  n_batch = n_samples // true_batchsize
  if n_batch == 0:
    raise ValueError(
        "need at least batchsize * dgratio = %d samples, got %d" % (true_batchsize, n_samples))
  n_shuffles = math.ceil(int(n_iteration) / n_batch)

  ds = tf.data.Dataset.from_tensor_slices((X, Labels))
  ds = (ds.shuffle(buffer_size=n_samples)
          .repeat(n_shuffles)
          .batch(true_batchsize, drop_remainder=True)
          .prefetch(2))
  return iter(ds)

@tf.function
def train_loop(X_trains, cond_labels, batchsize, dgratio, G, D, generator_optimizer, discriminator_optimizer):
  """Run one training iteration: `dgratio` critic updates, then one generator update.

  Args:
    X_trains: real samples stacked per critic step. Slice i is selected with
      tf.gather along axis 0; the driver cell reshapes them to
      (dgratio, batchsize, features).
    cond_labels: conditioning labels stacked the same way as X_trains.
    batchsize: samples per critic step.
    dgratio: number of critic updates per generator update.
    G: generator model.
    D: critic model.
    generator_optimizer: optimizer applied to G.trainable_variables.
    discriminator_optimizer: optimizer applied to D.trainable_variables.

  Returns:
    (D_loss_curr, G_loss_curr), both computed in the generator step on the last
    critic mini-batch, i.e. after all critic updates of this iteration.
  """
  for i in tf.range(dgratio):
    with tf.GradientTape() as disc_tape:
      D_loss_curr, _ = D_loss(tf.gather(X_trains, i), tf.gather(cond_labels, i), batchsize, G, D)
    # Take gradients after leaving the tape context, so the gradient computation
    # itself is not recorded on the tape.
    gradients_of_discriminator = disc_tape.gradient(D_loss_curr, D.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, D.trainable_variables))

  last_index = dgratio - 1
  with tf.GradientTape() as gen_tape:
    # D_fake is recomputed under gen_tape so the tape records its dependence on
    # G's weights. D_loss also recomputes the critic loss, which is only reported.
    D_loss_curr, D_fake = D_loss(tf.gather(X_trains, last_index), tf.gather(cond_labels, last_index), batchsize, G, D)
    G_loss_curr = G_loss(D_fake)
  gradients_of_generator = gen_tape.gradient(G_loss_curr, G.trainable_variables)
  generator_optimizer.apply_gradients(zip(gradients_of_generator, G.trainable_variables))
  return D_loss_curr, G_loss_curr

In [None]:
# Training hyper-parameters
dgratio = 5          # critic updates per generator update
batchsize = 128      # samples per critic mini-batch
G_lr = D_lr = 0.0001
G_beta1 = D_beta1 = 0.55
generator_optimizer = tf.optimizers.Adam(learning_rate=G_lr, beta_1=G_beta1)
discriminator_optimizer = tf.optimizers.Adam(learning_rate=D_lr, beta_1=D_beta1)

# Prepare for check pointing
saver = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                            discriminator_optimizer=discriminator_optimizer,
                            generator=generator,
                            discriminator=discriminator)

checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, 'model')

print('training started')
dl = DataLoader.DataLoader()

start_iteration = 0
max_iterations = 1000

for iteration in range(start_iteration, max_iterations):
  if iteration == start_iteration:
    # Load the full training set once, then build an iterator that yields enough
    # combined (dgratio * batchsize) batches for the remaining iterations.
    # NOTE(review): the meaning of the (8, 9) arguments is defined by DataLoader — confirm.
    X, Labels = dl.getAllTrainData(8, 9)
    X = tf.convert_to_tensor(X, dtype=tf.float32)
    Labels = tf.convert_to_tensor(Labels, dtype=tf.float32)

    remained_iteration = tf.constant(max_iterations - iteration, dtype=tf.int64)
    ds_iter = getTrainData_ultimate(remained_iteration, batchsize, dgratio, X, Labels)
    print("Using " + str(X.shape[0]) + " events")

  X_batch, Labels_batch = ds_iter.get_next()

  # Split the combined batch into dgratio critic mini-batches. The iterator drops
  # incomplete batches, so the row count is exactly dgratio * batchsize and -1
  # recovers the feature size.
  X_trains = tf.reshape(X_batch, (dgratio, batchsize, -1))
  cond_labels = tf.reshape(Labels_batch, (dgratio, batchsize, -1))

  D_loss_curr, G_loss_curr = train_loop(X_trains, cond_labels, batchsize, dgratio, generator, discriminator, generator_optimizer, discriminator_optimizer)

  if iteration == 0:
    print("Model and loss values will be saved every 2 iterations.")

  if iteration % 2 == 0 and iteration > 0:
    try:
      saver.save(file_prefix=checkpoint_prefix)
    except Exception as exc:
      # Best effort: a failed checkpoint should not stop training. Catch
      # Exception rather than everything, so KeyboardInterrupt still stops the loop.
      print("Something went wrong in saving iteration %s, moving to next one" % (iteration))
      print("exception message ", repr(exc))

    print('Iter: {}; D loss: {:.4}; G_loss:  {:.4}'.format(iteration, float(D_loss_curr), float(G_loss_curr)))
        


In [30]:
import importlib
import sys

# Make the project modules in gan_code/ importable. The guard keeps re-runs from
# adding duplicate sys.path entries.
if 'gan_code/' not in sys.path:
  sys.path.append('gan_code/')
import conditional_wgangp
importlib.reload(conditional_wgangp)  # pick up edits without restarting the kernel

gan = conditional_wgangp.WGANGP()
gan.train()

Model: "model_30"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Noise (InputLayer)             [(None, 50)]         0           []                               
                                                                                                  
 mycond (InputLayer)            [(None, 2)]          0           []                               
                                                                                                  
 concatenate_30 (Concatenate)   (None, 52)           0           ['Noise[0][0]',                  
                                                                  'mycond[0][0]']                 
                                                                                                  
 dense_120 (Dense)              (None, 50)           2650        ['concatenate_30[0][0]']  

KeyboardInterrupt: 