# Informations
The data is available in the CERN [OpenData](http://opendata.cern.ch/record/15012)
The code is based on the FastCaloGAN code which is available on [Zenodo](https://zenodo.org/record/5589623) with the latest development available for ATLAS member on the [FCS git repository](https://gitlab.cern.ch/atlas-simulation-fastcalosim/fastcalogan)

The data and code are also the case for the [#calochallenge](https://github.com/CaloChallenge/homepage)

This example runs on only two samples.

In [2]:
# imports
import os
import numpy as np
import math
import tensorflow as tf
import copy
import sys

sys.path.append('gan_code/')
import DataLoader 
import importlib
importlib.reload(DataLoader)

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import activations
from tensorflow.keras.models import Model
from functools import partial
tf.keras.backend.set_floatx('float32')


In [3]:
fns = {
    256: "http://rgw.fisica.unimi.it/TutorialML-AtlasItalia2022/gan_inputs/pid22_E256_eta_20_25_voxalisation.csv?AWSAccessKeyId=M06HBTUGIKXVXYH1RES6&Signature=bSSGDUvNFuCxHHz3YlCqga0Jq0g%3D&Expires=1829145580",
    512: "http://rgw.fisica.unimi.it/TutorialML-AtlasItalia2022/gan_inputs/pid22_E512_eta_20_25_voxalisation.csv?AWSAccessKeyId=M06HBTUGIKXVXYH1RES6&Signature=Dr3A32ycujBSm4bQ14l%2BHvEf1Ig%3D&Expires=1829145645"
}

# Define GAN architecture
The generator takes as input noise and in this case two values to create a conditional generator. It will output 368 values, using a stack of dense layers.

In [None]:
initializer = tf.keras.initializers.he_uniform()
bias_node = True
noise = layers.Input(shape=(50), name="Noise")
condition = layers.Input(shape=(2), name="mycond")
con = layers.concatenate([noise,condition])
G = layers.Dense(50, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(con)  
G = layers.BatchNormalization()(G)
G = layers.Activation(activations.swish)(G)
G = layers.Dense(100, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(G)
G = layers.BatchNormalization()(G)
G = layers.Activation(activations.swish)(G)
G = layers.Dense(200, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(G)
G = layers.BatchNormalization()(G)
G = layers.Activation(activations.swish)(G)
G = layers.Dense(368, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(G)
G = layers.BatchNormalization()(G)
G = layers.Activation(activations.swish)(G)

generator = Model(inputs=[noise, condition], outputs=G)
generator.build(370)
generator.summary()

## Define the Discriminator network
The discriminator takes as input the values generated by the generator and check if they discriminate if they are real or generated.

In [None]:
initializer = tf.keras.initializers.he_uniform()
bias_node = True

image = layers.Input(shape=(368), name="Image")
d_condition = layers.Input(shape=(2), name="mycond")
d_con = layers.concatenate([image, d_condition])
D = layers.Dense(368, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(d_con)  
D = layers.Activation(activations.relu)(D)
D = layers.Dense(368, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(D)
D = layers.Activation(activations.relu)(D)
D = layers.Dense(368, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(D)
D = layers.Activation(activations.relu)(D)
D = layers.Dense(1, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(D)

discriminator = Model(inputs=[image, d_condition], outputs=D)
discriminator.build(370)
discriminator.summary()


# Train, loss and gradient functions

In [None]:
@tf.function
def gradient_penalty(f, x_real, x_fake, cond_label, batchsize, D):
  alpha = tf.random.uniform([batchsize, 1], minval=0., maxval=1.)

  inter = alpha * x_real + (1-alpha) * x_fake
  with tf.GradientTape() as t:
    t.watch(inter)
    pred = D(inputs=[inter, cond_label])
  grad = t.gradient(pred, [inter])[0]
  
  slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1))
  gp = 0.00001 * tf.reduce_mean((slopes - 1.)**2) #Lambda
  return gp

@tf.function
def D_loss(x_real, cond_label, batchsize, G, D): 
  z = tf.random.normal([batchsize, 50], mean=0.5, stddev=0.5, dtype=tf.dtypes.float32) #batch and latent dim
  x_fake = G(inputs=[z, cond_label])
  D_fake = D(inputs=[x_fake, cond_label])
  D_real = D(inputs=[x_real, cond_label])
  D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + gradient_penalty(f=partial(D, training=True), x_real=x_real, x_fake=x_fake, cond_label=cond_label, batchsize=batchsize, D=D)
  return D_loss, D_fake

@tf.function
def G_loss(D_fake):
  G_loss = -tf.reduce_mean(D_fake)
  return G_loss

def getTrainData_ultimate( n_iteration, batchsize, dgratio, X ,Labels):
  true_batchsize = tf.cast(tf.math.multiply(batchsize, dgratio), tf.int64)
  n_samples = tf.cast(tf.gather(tf.shape(X), 0), tf.int64)
  n_batch = tf.cast(tf.math.floordiv(n_samples, true_batchsize), tf.int64)
  n_shuffles = tf.cast(tf.math.ceil(tf.divide(n_iteration, n_batch)), tf.int64)
  ds = tf.data.Dataset.from_tensor_slices((X, Labels))
  ds = ds.shuffle(buffer_size = n_samples).repeat(n_shuffles).batch(true_batchsize, drop_remainder=True).prefetch(2)
  return iter(ds)

@tf.function
def train_loop(X_trains, cond_labels, batchsize, dgratio, G, D, generator_optimizer, discriminator_optimizer): 
  for i in tf.range(dgratio):
    print("d train: " + str(i))
    with tf.GradientTape() as disc_tape:
      (D_loss_curr, D_fake) = D_loss(tf.gather(X_trains, i), tf.gather(cond_labels, i), batchsize, G, D)
      gradients_of_discriminator = disc_tape.gradient(D_loss_curr, D.trainable_variables)
      discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, D.trainable_variables))    
      
  print("g train")
  last_index = tf.subtract(dgratio, 1)

  with tf.GradientTape() as gen_tape:
    # Need to recompute D_fake, otherwise gen_tape doesn't know the history
    (D_loss_curr, D_fake) = D_loss(tf.gather(X_trains, last_index), tf.gather(cond_labels, last_index), batchsize, G, D)
    G_loss_curr = G_loss(D_fake)
    gradients_of_generator = gen_tape.gradient(G_loss_curr, G.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, G.trainable_variables))
    return D_loss_curr, G_loss_curr

In [None]:
dgratio = 5
batchsize = 128
G_lr = D_lr = 0.0001
G_beta1 = D_beta1 = 0.55
generator_optimizer = tf.optimizers.Adam(learning_rate=G_lr, beta_1=G_beta1)
discriminator_optimizer = tf.optimizers.Adam(learning_rate=D_lr, beta_1=D_beta1)

# Prepare for check pointing
saver = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                            discriminator_optimizer=discriminator_optimizer,
                            generator=generator,
                            discriminator=discriminator)

checkpoint_dir = "checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

In [None]:
dl = DataLoader.DataLoader(fns)

In [None]:
start_iteration = 0 
max_iterations = 500

for iteration in range(start_iteration,max_iterations): 
  change_data = (iteration == start_iteration)
  
  if (change_data == True):
    X, Labels = dl.getAllTrainData(8, 9)
    X = tf.convert_to_tensor(X, dtype=tf.float32)
    Labels = tf.convert_to_tensor(Labels, dtype=tf.float32)
 
    remained_iteration = tf.constant(max_iterations - iteration, dtype=tf.int64)
    ds_iter = getTrainData_ultimate(remained_iteration, batchsize, dgratio, X ,Labels)
    print ("Using "+ str(X.shape[0])+ " events")

  X, Labels = ds_iter.get_next()

  X_feature_size = tf.gather(tf.shape(X), 1)
  Labels_feature_size = tf.gather(tf.shape(Labels), 1)
  X_batch_shape = tf.stack((dgratio, batchsize, X_feature_size), axis=0)
  Labels_batch_shape = tf.stack((dgratio, batchsize, Labels_feature_size), axis=0)

  X_trains    = tf.reshape(X, X_batch_shape)
  cond_labels = tf.reshape(Labels, Labels_batch_shape)  

  #print(X_trains) 
  #print(cond_labels) 
  #print(batchsize) 
  #print(dgratio) 
  #generator.summary() 
  #discriminator.summary() 

  D_loss_curr, G_loss_curr = train_loop(X_trains, cond_labels, batchsize, dgratio, generator, discriminator,  generator_optimizer, discriminator_optimizer)

  if iteration == 0: 
    print("Model and loss values will be saved every 2 iterations." )
  
  if iteration % 2 == 0 and iteration > 0:

    try:
      saver.save(file_prefix = checkpoint_dir+ '/model')
    except:
      print("Something went wrong in saving iteration %s, moving to next one" % (iteration))
      print("exception message ", sys.exc_info()[0])     
    
    print('Iter: {}; D loss: {:.4}; G_loss:  {:.4}'.format(iteration, D_loss_curr, G_loss_curr))
    
        


In [None]:
output_best_checkpoints="best_iteration"
if not os.path.exists(output_best_checkpoints):
  os.makedirs(output_best_checkpoints)

# Evaluation

Import modules and load classes

In [7]:
import numpy as np
import pandas as pd
import math
import argparse 
import matplotlib.pyplot as plt
import shutil
import ctypes
import glob

sys.path.append('gan_code/')
import conditional_wgangp 
import importlib
importlib.reload(conditional_wgangp)
import DataLoader
importlib.reload(DataLoader)

save_plots = True
lparams = {'xoffset' : 0.1, 'yoffset' : 0.27, 'width'   : 0.8, 'height'  : 0.35}
canvases = []   

histos_vox = []
input_files_vox = []

particleName="#gamma"
particle="photons"

output_best_checkpoints="best_iteration"
if not os.path.exists(output_best_checkpoints):
  os.makedirs(output_best_checkpoints)

maxVoxel = 0
midEnergy = 0
step = 50

wgan = conditional_wgangp.WGANGP()
dl = DataLoader.DataLoader(fns)
ekins = dl.ekins

firstPosition = 0


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Noise (InputLayer)             [(None, 50)]         0           []                               
                                                                                                  
 mycond (InputLayer)            [(None, 2)]          0           []                               
                                                                                                  
 concatenate (Concatenate)      (None, 52)           0           ['Noise[0][0]',                  
                                                                  'mycond[0][0]']                 
                                                                                                  
 dense (Dense)                  (None, 50)           2650        ['concatenate[0][0]']        

Now build reference histos 

In [31]:
hist_lim = {
    256: { 40, 450},
    512: { 250, 750}
}
print(hist_lim[256])
min,max=hist_lim[256]
a,b={0, 10}
print(min)
print(a)
print(b)

{40, 450}
40
0
10


In [35]:
%matplotlib inline
import pandas as pd

all_total_E = []

print("Opening vox files")
for index, p in enumerate(ekins):  
  df = pd.read_csv(fns[p], header=None, engine='python', dtype=np.float64)
  #df = pd.read_csv("gan_inputs/pid22_E%d_eta_20_25_voxalisation.csv"%(p), header=None, engine='python', dtype=np.float64)
  df = df.fillna(0)
  data=df.to_numpy()  

  total_E = data.sum(axis=-1)
  all_total_E.append(total_E)

  min, max = hist_lim[p]
  bins = np.linspace(min, max, 30)
  plt.figure(figsize=(6, 6))
  plt.hist(total_E, bins=bins, label='reference', density=True,
           histtype='stepfilled', alpha=0.2, linewidth=2.)
  plt.legend(fontsize=20)
  plt.tight_layout()
  plt.savefig("plot_%d"%(p), dpi=300)
  plt.close()

Opening vox files
256
40
450
512
250
750


In [40]:
import conditional_wgangp
importlib.reload(conditional_wgangp)

nevents=10000
print ("Running from %i to %i in step of %i" %(1, 1000, step))
for iteration in range(1, 1000, step):
  legendPadIndex = 16
  chi2_tot = 0.
  ndf_tot = 0
  input_files_gan = []

  for index, energy in enumerate(ekins):     
    ekin_sample = ekins[index]
    energyArray = np.array([energy] * nevents)
    etaArray = np.zeros(nevents) 
    labels = np.vstack((energyArray, etaArray)).T   

    print(iteration)
    data = wgan.load(iteration, labels, nevents, 'checkpoints/')
    data = data * ekin_sample       #needed for conditional
      
    E_tot = data.numpy().sum(axis=1)
    print(E_tot)

    # Plotting

    plt.figure(figsize=(6, 6))
    plt.hist(all_total_E[index], bins=bins, label='reference', density=True,
            histtype='stepfilled', alpha=0.2, linewidth=2.)
    plt.hist(E_tot, bins=bins, label='reference', density=True,
            histtype='stepfilled', alpha=0.2, linewidth=2.)
    plt.legend(fontsize=20)
    plt.tight_layout()
    plt.savefig("plot_%d_%d"%(p,iteration), dpi=300)
    plt.close()

Running from 1 to 1000 in step of 50
1
checkpoints//model-1
[8291.822 8312.508 8318.231 ... 8306.793 8315.206 8307.19 ]
1
checkpoints//model-1
[16570.602 16573.752 16545.639 ... 16592.484 16622.441 16625.184]
51
checkpoints//model-51
[4731.4326 4742.367  4694.384  ... 4719.254  4775.2524 4786.324 ]
51
checkpoints//model-51
[9545.356 9507.4   9506.665 ... 9540.195 9510.289 9514.639]
101
checkpoints//model-101
[3066.6711 3043.9248 3089.1206 ... 3094.606  3079.217  3060.179 ]
101
checkpoints//model-101
[6168.661  6215.128  6204.4995 ... 6203.5786 6211.6255 6211.8984]
151
checkpoints//model-151
[2561.2444 2520.8828 2522.9497 ... 2542.046  2532.902  2522.7617]
151
checkpoints//model-151
[5144.5635 5156.461  5146.5146 ... 5124.867  5171.4805 5151.82  ]
201
checkpoints//model-201
[2360.1152 2343.952  2388.2036 ... 2360.9512 2384.06   2366.825 ]
201
checkpoints//model-201
[4862.584  4844.9976 4836.1475 ... 4832.8003 4820.1885 4846.0337]
251
checkpoints//model-251
[2289.8772 2317.2617 2286.9712

NotFoundError: Error when restoring from checkpoint or SavedModel at checkpoints//model-501: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for checkpoints//model-501
Please double-check that the path is correct. You may be missing the checkpoint suffix (e.g. the '-1' in 'path/to/ckpt-1').

# GAN training and evaluation with 1M iterations
The GAN learn but with significant fluctuations [plot](https://atlas.web.cern.ch/Atlas/GROUPS/PHYSICS/PAPERS/SIMU-2018-04/fig_09.png)
Animated gif with all energies [gif](https://atlas.web.cern.ch/Atlas/GROUPS/PHYSICS/PUBNOTES/ATL-SOFT-PUB-2020-006/fig_37.png)

# Latest results
![pions](imgs/latest_pions.png)