# Information
The data are available on the CERN [Open Data portal](http://opendata.cern.ch/record/15012).
The code is based on the FastCaloGAN code, which is available on [Zenodo](https://zenodo.org/record/5589623); the latest development is available to ATLAS members in the [FCS git repository](https://gitlab.cern.ch/atlas-simulation-fastcalosim/fastcalogan).

The same data and code are also used for the [#calochallenge](https://github.com/CaloChallenge/homepage).

To run the example, download the input files into the `csv_inputs` folder, either from the first link above or from the direct links below:
[256 MeV](http://rgw.fisica.unimi.it/TutorialML-AtlasItalia2022/gan_inputs/pid22_E256_eta_20_25_voxalisation.csv?AWSAccessKeyId=M06HBTUGIKXVXYH1RES6&Signature=bSSGDUvNFuCxHHz3YlCqga0Jq0g%3D&Expires=1829145580)
[512 MeV](http://rgw.fisica.unimi.it/TutorialML-AtlasItalia2022/gan_inputs/pid22_E512_eta_20_25_voxalisation.csv?AWSAccessKeyId=M06HBTUGIKXVXYH1RES6&Signature=Dr3A32ycujBSm4bQ14l%2BHvEf1Ig%3D&Expires=1829145645)

The code is configured to run on only two samples; to run on all samples, change `maxExp` in `DataLoader` to 23.


In [None]:
# imports: standard library, then third-party, then the local gan_code helpers
import copy
import importlib
import math
import os
import sys
from functools import partial

import numpy as np
import tensorflow as tf
from tensorflow.keras import activations, layers
from tensorflow.keras.models import Model

# DataLoader lives in gan_code/; reloading picks up edits to it without
# restarting the kernel.
sys.path.append('gan_code/')
import DataLoader
importlib.reload(DataLoader)

# All layers/weights in this notebook are float32.
tf.keras.backend.set_floatx('float32')


# Define GAN architecture

In [None]:
# Define the generator network: (50-dim noise, 2-dim condition) -> 368 outputs.
initializer = tf.keras.initializers.he_uniform()
bias_node = True
# Keras input shapes are tuples; `(50)` is just the int 50, so use `(50,)`.
noise = layers.Input(shape=(50,), name="Noise")
condition = layers.Input(shape=(2,), name="mycond")
con = layers.concatenate([noise, condition])

# Four identical Dense -> BatchNormalization -> swish stages, widening to the
# 368-wide output (layers are created in the same order as before).
G = con
for units in (50, 100, 200, 368):
  G = layers.Dense(units, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(G)
  G = layers.BatchNormalization()(G)
  G = layers.Activation(activations.swish)(G)

# A functional Model is already built from its Input layers, so no explicit
# build() call is needed (the old build(370) used the discriminator's width).
generator = Model(inputs=[noise, condition], outputs=G)
generator.summary()



In [None]:
# Define the Discriminator (critic) network: (368-wide image, 2-dim condition)
# -> one unbounded score per sample (no output activation, as WGAN requires).
initializer = tf.keras.initializers.he_uniform()
bias_node = True

# Keras input shapes are tuples; `(368)` is just the int 368.
image = layers.Input(shape=(368,), name="Image")
d_condition = layers.Input(shape=(2,), name="mycond")
d_con = layers.concatenate([image, d_condition])

# Three Dense(368) -> ReLU hidden stages, then a linear Dense(1) score.
D = d_con
for _ in range(3):
  D = layers.Dense(368, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(D)
  D = layers.Activation(activations.relu)(D)
D = layers.Dense(1, use_bias=bias_node, kernel_initializer=initializer, bias_initializer='zeros')(D)

# A functional Model is already built from its Input layers; build(370) was a no-op.
discriminator = Model(inputs=[image, d_condition], outputs=D)
discriminator.summary()


# Training, loss and gradient functions

In [None]:

@tf.function
def gradient_penalty(f, x_real, x_fake, cond_label, batchsize, D, lam=0.00001):
  """Return the WGAN-GP gradient penalty of critic ``D``.

  The penalty is ``lam * mean((||dD(x_hat)/dx_hat||_2 - 1)**2)`` evaluated on
  random interpolates ``x_hat = alpha * x_real + (1 - alpha) * x_fake``.

  Args:
    f: Unused; kept so existing callers passing ``f=`` keep working.
    x_real: Batch of real samples. The per-sample ``alpha`` has shape
      ``[batchsize, 1]`` and the norm is taken over axis 1, so a rank-2
      ``(batch, features)`` tensor is assumed.
    x_fake: Batch of generated samples, same shape as ``x_real``.
    cond_label: Conditioning labels fed to ``D`` next to the samples.
    batchsize: Number of samples in the batch (leading size of ``alpha``).
    D: Critic model, called as ``D(inputs=[samples, labels])``.
    lam: Penalty weight; defaults to the value this notebook always used.

  Returns:
    Scalar penalty tensor.
  """
  alpha = tf.random.uniform([batchsize, 1], minval=0., maxval=1.)
  inter = alpha * x_real + (1. - alpha) * x_fake
  with tf.GradientTape() as tape:
    tape.watch(inter)
    pred = D(inputs=[inter, cond_label])
  grad = tape.gradient(pred, inter)

  # The epsilon keeps d(sqrt)/dx finite when a gradient norm is exactly zero
  # (e.g. every ReLU inactive); without it the penalty's gradient is NaN.
  slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1) + 1e-12)
  return lam * tf.reduce_mean((slopes - 1.) ** 2)

@tf.function
def D_loss(x_real, cond_label, batchsize, G, D):
  """Compute the WGAN-GP critic loss for one batch.

  loss = mean(D(G(z))) - mean(D(x_real)) + gradient penalty

  Args:
    x_real: Batch of real samples.
    cond_label: Conditioning labels shared by the real and generated samples.
    batchsize: Number of samples in the batch.
    G: Generator model, called as ``G(inputs=[noise, labels])``.
    D: Critic model, called as ``D(inputs=[samples, labels])``.

  Returns:
    ``(loss, D_fake)``, where ``D_fake`` is the critic output on the freshly
    generated samples; the generator step reuses it for its own loss.
  """
  # 50-dimensional latent noise, matching the generator's "Noise" input.
  z = tf.random.normal([batchsize, 50], mean=0.5, stddev=0.5, dtype=tf.dtypes.float32)
  # training=True: without it Keras runs the generator's BatchNormalization
  # layers in inference mode (moving statistics, never updated) while training.
  x_fake = G(inputs=[z, cond_label], training=True)
  D_fake = D(inputs=[x_fake, cond_label], training=True)
  D_real = D(inputs=[x_real, cond_label], training=True)
  gp = gradient_penalty(f=partial(D, training=True), x_real=x_real, x_fake=x_fake,
                        cond_label=cond_label, batchsize=batchsize, D=D)
  loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + gp
  return loss, D_fake

@tf.function
def G_loss(D_fake):
  """Return the WGAN generator loss: the negated mean critic score on fakes."""
  mean_score = tf.reduce_mean(D_fake)
  return tf.math.negative(mean_score)

def getTrainData_ultimate( n_iteration, batchsize, dgratio, X ,Labels):
  """Build an iterator over shuffled ``(X, Labels)`` batches for training.

  Each element holds ``batchsize * dgratio`` rows, so one element feeds a
  whole training iteration (``dgratio`` critic steps). The data are shuffled
  and repeated enough times to provide at least ``n_iteration`` elements.

  Args:
    n_iteration: Number of training iterations (elements) required.
    batchsize: Per-step batch size.
    dgratio: Critic updates per generator update.
    X: Training samples, one row per sample along axis 0.
    Labels: Conditioning labels, one row per sample of ``X``.

  Returns:
    A Python iterator yielding ``(X_batch, Labels_batch)`` tensor pairs.

  Raises:
    ValueError: If ``X`` has fewer than ``batchsize * dgratio`` rows, so not
      even one full batch can be produced.
  """
  true_batchsize = tf.cast(tf.math.multiply(batchsize, dgratio), tf.int64)
  n_samples = tf.cast(tf.gather(tf.shape(X), 0), tf.int64)
  n_batch = tf.cast(tf.math.floordiv(n_samples, true_batchsize), tf.int64)
  if int(n_batch) < 1:
    # Otherwise n_iteration / 0 -> inf, whose cast to int64 is undefined
    # (typically negative, i.e. "repeat forever"); with drop_remainder=True
    # the iterator could then never yield a batch and get_next() would hang.
    raise ValueError(
        "Need at least batchsize * dgratio = %d samples, got %d"
        % (int(true_batchsize), int(n_samples)))
  n_shuffles = tf.cast(tf.math.ceil(tf.divide(n_iteration, n_batch)), tf.int64)
  ds = tf.data.Dataset.from_tensor_slices((X, Labels))
  # A buffer as large as the dataset gives a full shuffle on every pass.
  ds = ds.shuffle(buffer_size=n_samples).repeat(n_shuffles).batch(true_batchsize, drop_remainder=True).prefetch(2)
  return iter(ds)

@tf.function
def train_loop(X_trains, cond_labels, batchsize, dgratio, G, D, generator_optimizer, discriminator_optimizer):
  """Run one WGAN-GP iteration: ``dgratio`` critic updates, then one generator update.

  Args:
    X_trains: Real samples split into ``dgratio`` sub-batches along axis 0
      (the caller reshapes to ``(dgratio, batchsize, features)``).
    cond_labels: Conditioning labels, split the same way as ``X_trains``.
    batchsize: Samples per sub-batch.
    dgratio: Number of critic updates per generator update.
    G, D: Generator and critic models.
    generator_optimizer, discriminator_optimizer: Optimizers for G and D.

  Returns:
    ``(D_loss_curr, G_loss_curr)`` computed during the generator step.
  """
  for i in tf.range(dgratio):
    with tf.GradientTape() as disc_tape:
      D_loss_curr, _ = D_loss(tf.gather(X_trains, i), tf.gather(cond_labels, i), batchsize, G, D)
    # Take gradients and apply them after leaving the tape context so these
    # operations are not themselves recorded on the tape.
    gradients_of_discriminator = disc_tape.gradient(D_loss_curr, D.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, D.trainable_variables))

  last_index = dgratio - 1
  with tf.GradientTape() as gen_tape:
    # D_fake must be recomputed under gen_tape, otherwise the tape has no
    # record of how it depends on G's weights.
    D_loss_curr, D_fake = D_loss(tf.gather(X_trains, last_index), tf.gather(cond_labels, last_index), batchsize, G, D)
    G_loss_curr = G_loss(D_fake)
  gradients_of_generator = gen_tape.gradient(G_loss_curr, G.trainable_variables)
  generator_optimizer.apply_gradients(zip(gradients_of_generator, G.trainable_variables))
  return D_loss_curr, G_loss_curr

In [None]:
# Training hyper-parameters and optimizers.
dgratio = 5          # critic (discriminator) updates per generator update
batchsize = 128
G_lr = D_lr = 0.0001
G_beta1 = D_beta1 = 0.55
generator_optimizer = tf.optimizers.Adam(learning_rate=G_lr, beta_1=G_beta1)
discriminator_optimizer = tf.optimizers.Adam(learning_rate=D_lr, beta_1=D_beta1)

# Prepare for check pointing
saver = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                            discriminator_optimizer=discriminator_optimizer,
                            generator=generator,
                            discriminator=discriminator)

checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

print ('training started')
dl = DataLoader.DataLoader()

start_iteration = 0
max_iterations = 1000
save_every = 2  # checkpoint / log period, in iterations

for iteration in range(start_iteration, max_iterations):
  change_data = (iteration == start_iteration)

  if change_data:
    # Load the samples selected by getAllTrainData(8, 9) once, and build an
    # iterator that yields enough batches for all remaining iterations.
    X_all, Labels_all = dl.getAllTrainData(8, 9)
    X_all = tf.convert_to_tensor(X_all, dtype=tf.float32)
    Labels_all = tf.convert_to_tensor(Labels_all, dtype=tf.float32)

    remained_iteration = tf.constant(max_iterations - iteration, dtype=tf.int64)
    ds_iter = getTrainData_ultimate(remained_iteration, batchsize, dgratio, X_all, Labels_all)
    print ("Using "+ str(X_all.shape[0])+ " events")

  # Each element holds dgratio * batchsize rows; split it into one sub-batch
  # per critic step: (dgratio, batchsize, n_features).
  X, Labels = ds_iter.get_next()
  X_trains = tf.reshape(X, (dgratio, batchsize, -1))
  cond_labels = tf.reshape(Labels, (dgratio, batchsize, -1))

  D_loss_curr, G_loss_curr = train_loop(X_trains, cond_labels, batchsize, dgratio, generator, discriminator, generator_optimizer, discriminator_optimizer)

  if iteration == 0:
    print("Model and loss values will be saved every %i iterations." % save_every)

  if iteration % save_every == 0 and iteration > 0:
    try:
      saver.save(file_prefix = checkpoint_dir+ '/model')
    except Exception as err:
      # Best effort: a failed save must not stop training, but Ctrl-C
      # (KeyboardInterrupt) is no longer swallowed and the message is kept.
      print("Something went wrong in saving iteration %s, moving to next one" % (iteration))
      print("exception message ", repr(err))

    print('Iter: {}; D loss: {:.4}; G_loss:  {:.4}'.format(iteration, float(D_loss_curr), float(G_loss_curr)))
    
        


# All together in a class

In [None]:
# Train through the WGANGP class from gan_code/ (presumably a packaged version
# of the loop above; see conditional_wgangp.py for the details).
import importlib
import sys

sys.path.append('gan_code/')
import conditional_wgangp
importlib.reload(conditional_wgangp)

gan = conditional_wgangp.WGANGP()
gan.train()

# Evaluation

In [None]:
#!/usr/bin/env python3
# Evaluation: for every `step`-th checkpoint, compare the total-energy
# distribution of GAN samples with Geant4 at each energy, compute chi2/ndf and
# keep the plots and number of the iteration with the lowest value.
import numpy as np
import math
import argparse
import os,sys,ctypes
import ROOT
import shutil
import ctypes
import glob
import traceback

ROOT.gROOT.SetBatch(True)

sys.path.append('gan_code/')
import conditional_wgangp
import importlib
importlib.reload(conditional_wgangp)
import DataLoader
importlib.reload(DataLoader)

save_plots = True
lparams = {'xoffset' : 0.1, 'yoffset' : 0.27, 'width'   : 0.8, 'height'  : 0.35}
canvases = []

histos_vox = []
nevents_vox = []
input_files_vox = []

# Sample identifiers: they reproduce the input names
# rootFiles/pid22_E<energy>_eta_20_25.root and label the output files.
pid = 22
eta_min = 20
eta_max = 25

minFactor = 3
maxFactor = 3
particleName="#gamma"
particle="photons"

output_best_checkpoints="best_iteration"
os.makedirs(output_best_checkpoints, exist_ok=True)
best_chi2_dir = os.path.join(output_best_checkpoints, "chi2")
os.makedirs(best_chi2_dir, exist_ok=True)

maxVoxel = 0
midEnergy = 0
step = 50
max_iteration = 1000

wgan = conditional_wgangp.WGANGP()
dl = DataLoader.DataLoader()
ekins = dl.ekins

firstPosition = 0


def MakeLegend(params):
  """Return an empty, borderless, transparent TLegend from an lparams-style dict (NDC)."""
  x1 = params['xoffset']
  y1 = params['yoffset']
  leg = ROOT.TLegend(x1, y1, x1 + params['width'], y1 + params['height'])
  leg.SetBorderSize(0)
  leg.SetFillStyle(0)
  return leg


def _save_canvas(canvas, directory, name):
  """Save `canvas` as directory/name.png, .eps and .pdf."""
  for ext in ("png", "eps", "pdf"):
    canvas.SaveAs("%s/%s.%s" % (directory, name, ext))


print("Opening vox files")
for index, energy in enumerate(ekins):
  input_file_vox = ('rootFiles/pid%s_E%s_eta_%s_%s.root' % (pid, energy, eta_min, eta_max))
  print(" Opening file: " + input_file_vox)
  infile_vox = ROOT.TFile.Open(input_file_vox, 'read')
  if not infile_vox or infile_vox.IsZombie():
    raise OSError("Cannot open " + input_file_vox)
  input_files_vox.append(infile_vox)
  tree = infile_vox.Get('rootTree')
  # Number of GAN events generated per energy: chosen here to match the
  # Geant4 statistics so both histograms have comparable uncertainties.
  nevents_vox.append(int(tree.GetEntries()))

  h = ROOT.TH1F("h","",100,0,energy*2)
  tree.Draw("etot>>h","","off")
  xmax = h.GetBinCenter(h.GetMaximumBin())
  minX = max(0, xmax-minFactor*h.GetRMS())
  maxX = xmax+maxFactor*h.GetRMS()
  print("min "+ str(minX) + " max " + str(maxX))

  h_vox = ROOT.TH1F("h_vox","",30,minX/1000,maxX/1000)
  tree.Draw("etot/1000>>h_vox","","off")
  if h_vox.GetEntries() > 0:
    h_vox.Scale(1/h_vox.GetEntries())
  histos_vox.append(h_vox)

chi2File = "%s/chi2_%s_%s_%s.txt" % (output_best_checkpoints, pid, eta_min, eta_max)
bestChi2File = "%s/epoch_best_chi2_%s_%s_%s.txt" % (best_chi2_dir, pid, eta_min, eta_max)
# Start a fresh chi2 history: lines left by an earlier run would otherwise take
# part in the best-iteration comparison below.
if os.path.exists(chi2File):
  os.remove(chi2File)

print ("Running from %i to %i in step of %i" %(0, max_iteration, step))
for iteration in range(0, max_iteration, step):
  try:
    histos_gan = []
    # Unique name per iteration: TCanvas deletes an existing canvas with the
    # same name, which would invalidate the ones kept in `canvases`.
    canvas = ROOT.TCanvas('canvas_h_%i' % iteration, 'Total Energy comparison plots', 900, 900)
    canvas.Divide(4,4)
    legendPadIndex = 16

    canvases.append(canvas)

    chi2_tot = 0.
    ndf_tot = 0

    for index, energy in enumerate(ekins):
      ekin_sample = ekins[index]
      nevents = nevents_vox[index]
      energyArray = np.array([energy] * nevents)
      etaArray = np.zeros(nevents)
      labels = np.vstack((energyArray, etaArray)).T

      data = wgan.load(iteration, labels, nevents, 'checkpoints')
      data = data * ekin_sample       #needed for conditional

      h_vox = histos_vox[index]
      h_gan = ROOT.TH1F("h_gan_%i_%s" % (iteration, energy), "", 30, h_vox.GetXaxis().GetXmin(), h_vox.GetXaxis().GetXmax())

      # np.asarray accepts both eager tensors and numpy arrays.
      E_tot = np.asarray(data).sum(axis=1)
      for e in E_tot:
        h_gan.Fill(e/1000)

      if h_gan.GetEntries() > 0:
        h_gan.Scale(1/h_gan.GetEntries())
      h_gan.SetLineColor(ROOT.kRed)
      h_gan.SetLineStyle(7)
      m = [h_vox.GetBinContent(h_vox.GetMaximumBin()), h_gan.GetBinContent(h_gan.GetMaximumBin())]
      h_vox.GetYaxis().SetRangeUser(0, max(m) * 1.25)
      histos_gan.append(h_gan)
      h_vox.GetYaxis().SetTitle("Entries")

      xAxisTitle = "Energy [GeV]"
      h_vox.GetXaxis().SetTitle(xAxisTitle)
      h_vox.GetXaxis().SetNdivisions(506)
      # Chi2TestX fills chi2/ndf/igood by reference; "WW": both histograms weighted.
      chi2 = ctypes.c_double(0.)
      ndf = ctypes.c_int(0)
      igood = ctypes.c_int(0)
      histos_vox[index].Chi2TestX(h_gan, chi2, ndf, igood, "WW")
      ndf = ndf.value
      chi2 = chi2.value
      chi2_tot += chi2
      ndf_tot += ndf

      if ndf != 0:
        print("Iteration %s Energy %s : chi2/ndf = %.1f / %i = %.1f\n" % (iteration, energy, chi2, ndf, chi2/ndf))

      # Plotting
      canvas.cd(index+1)
      histos_vox[index].Draw("HIST")
      histos_gan[index].Draw("HIST same")

      # Legend box
      if energy > 1024:
        energy_legend = str(round(energy/1000,1)) + " GeV"
      else:
        energy_legend = str(energy) + " MeV"
      t = ROOT.TLatex()
      t.SetNDC()
      t.SetTextFont(42)
      t.SetTextSize(0.1)
      t.DrawLatex(0.2, 0.83, energy_legend)

    # Total Energy chi2
    chi2_o_ndf = chi2_tot / ndf_tot if ndf_tot > 0 else 0.
    print("Iteration %s Total Energy : chi2/ndf = %.1f / %i = %.3f\n" % (iteration, chi2_tot, ndf_tot, chi2_o_ndf))
    if chi2_o_ndf > 0:
      with open(chi2File, 'a') as f:
        f.write("%s %.3f\n" % (iteration, chi2_o_ndf))
    else:
      print("Something went wrong, chi2 will not be written. Chi2/ndf is %f " % (chi2_o_ndf))
      print(E_tot)
      continue

    # Legend box particle
    leg = MakeLegend(lparams)
    leg.SetTextFont(42)
    leg.SetTextSize(0.1)
    canvas.cd(legendPadIndex)
    leg.AddEntry(h_vox,"Geant4","l") #Geant4
    leg.AddEntry(h_gan,"GAN","l")  #WGAN-GP
    leg.Draw()
    legend = (particleName + ", " + '{:.2f}'.format(eta_min/100) + "<|#eta|<" + '{:.2f}'.format(eta_max/100))
    if hasattr(ROOT, "ATLAS_LABEL_BIG"):
      ROOT.ATLAS_LABEL_BIG(0.1, 0.9, ROOT.kBlack, legend)
    else:
      # ATLAS style macros are not loaded: draw the sample label as plain text.
      label = ROOT.TLatex()
      label.SetNDC()
      label.SetTextFont(42)
      label.SetTextSize(0.1)
      label.DrawLatex(0.1, 0.9, legend)

    # Legend box Epoch & chi2
    t = ROOT.TLatex()
    t.SetNDC()
    t.SetTextFont(42)
    t.SetTextSize(0.1)
    t.DrawLatex(0.1, 0.18, "Iter: %s" % (iteration))
    t.DrawLatex(0.1, 0.07, "#scale[0.8]{#chi^{2}/NDF = %.0f/%i = %.1f}" % (chi2_tot, ndf_tot, chi2_o_ndf))

    # Copy best iteration plots. ndmin=2 keeps a one-line file 2-D so the
    # unpacked chi2 column is always an array.
    _, chi2_o_ndf_list = np.loadtxt(chi2File, delimiter=' ', unpack=True, ndmin=2)

    if round(chi2_o_ndf,3) <= np.amin(chi2_o_ndf_list):
      print ("Better chi2, creating plots")
      _save_canvas(canvas, output_best_checkpoints, "Plot_comparison_tot_energy")

      print("Epoch with lowest chi2/ndf is %s with a value of %.3f" % (iteration, chi2_o_ndf))
      # Now save best iteration number to file
      with open(bestChi2File, 'w') as f:
        f.write("%s %.3f\n" % (iteration, chi2_o_ndf))

    if save_plots:
      _save_canvas(canvas, output_best_checkpoints, "Plot_comparison_tot_energy_%i" % (iteration))

  except Exception:
    print("Something went wrong in iteration %s, moving to next one" % (iteration))
    traceback.print_exc()



# GAN training and evaluation with 1M iterations
The GAN learns, but with significant fluctuations ([plot](https://atlas.web.cern.ch/Atlas/GROUPS/PHYSICS/PAPERS/SIMU-2018-04/fig_09.png)).
An animated GIF with all energies is also available ([gif](https://atlas.web.cern.ch/Atlas/GROUPS/PHYSICS/PUBNOTES/ATL-SOFT-PUB-2020-006/fig_37.png)).

# Latest results
![pions](imgs/latest_pions.png)