<a href="https://colab.research.google.com/github/sznajder/Notebooks/blob/master/InteractionNetwork_BoostedHiggs_TFKeras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Extracted from Python source code; still needs to be reorganized into Jupyter notebook cells.
# From: https://github.com/thongonary/LEIA-Net/blob/master/interaction.py


In [None]:
"""
Tensorflow implementation of the Interaction networks for the identification of boosted Higgs to bb decays https://arxiv.org/abs/1909.12285 
"""

import os
import itertools
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
from lbn import LBNLayer

class LEIA(models.Model):
    """Interaction network (arXiv:1909.12285) preceded by a Lorentz Boost Network layer.

    The input is first passed through ``LBNLayer``; the interaction network then
    applies three MLPs:

    * ``fr`` - relational MLP applied to every ordered (receiver, sender) node
      pair with receiver != sender, producing ``De`` features per pair;
    * ``fo`` - object MLP applied to each node's features concatenated with the
      sum of the ``fr`` outputs of the pairs it receives, producing ``Do``
      features per node;
    * ``fc`` - classifier producing ``n_targets`` raw (un-activated) outputs.

    Args:
        n_constituents: number of input particles given to the LBN layer (also
            used as its number of rest frames).
        n_targets: size of the returned output vector.
        params: per-node feature count ``P`` the interaction network expects
            after the LBN step (sizes the ``fr``/``fo`` inputs).
        hidden: width of the first hidden layer of each MLP; the second hidden
            layer has ``hidden / 2`` units.
        fr_activation, fo_activation, fc_activation: activation selector for
            each MLP: ``2`` -> SELU, ``1`` -> ELU, anything else -> ReLU.
        De: number of features produced per node pair by ``fr``.
        Do: number of features produced per node by ``fo``.
        sum_O: if True, sum node outputs over nodes before ``fc``; otherwise
            flatten all node outputs.
        debug: print intermediate shapes/values inside ``call``.
    """

    def __init__(self, n_constituents, n_targets, params, hidden, fr_activation=0, fo_activation=0, fc_activation=0, De=8, Do=8, sum_O=True, debug=False):
        super(LEIA, self).__init__()

        # LBN layer for preprocessing; its n_out sets the number of IN nodes.
        self.lbn = LBNLayer(n_particles=n_constituents, n_restframes=n_constituents, boost_mode='pairs')

        self.hidden = int(hidden)
        self.P = params                  # features per node entering the IN
        self.N = self.lbn.lbn.n_out      # number of nodes
        self.Nr = self.N * (self.N - 1)  # number of ordered node pairs (edges)
        self.Dr = 0                      # per-edge attribute size (none used)
        self.De = De
        self.Dx = 0                      # extra per-node input size (none used)
        self.Do = Do
        self.n_targets = n_targets
        self.fr_activation = fr_activation
        self.fo_activation = fo_activation
        self.fc_activation = fc_activation
        self.assign_matrices()
        self.Ra = tf.ones([self.Dr, self.Nr])  # not referenced in call()

        # Relational MLP (fr): input is 2 * P + Dr features per edge.
        self.fr1 = layers.Dense(self.hidden)
        self.fr2 = layers.Dense(int(self.hidden/2))
        self.fr3 = layers.Dense(self.De)

        # Object MLP (fo): input is P + Dx + De features per node.
        self.fo1 = layers.Dense(self.hidden)
        self.fo2 = layers.Dense(int(self.hidden/2))
        self.fo3 = layers.Dense(self.Do)

        # Classifier MLP (fc).
        self.fc1 = layers.Dense(self.hidden)
        self.fc2 = layers.Dense(int(self.hidden/2))
        self.fc3 = layers.Dense(self.n_targets)
        self.sum_O = sum_O
        self.debug = debug

    def build(self, input_shape):
        """Mark the model as built after checking the input rank.

        Sub-layers are not built here.

        Raises:
            ValueError: if ``input_shape`` has fewer than two dimensions.
        """
        if len(input_shape) < 2:
            raise ValueError(f"LEIA expects an input of rank >= 2, got shape {input_shape}")
        self.built = True

    def assign_matrices(self):
        """Build the one-hot receiver (``Rr``) and sender (``Rs``) matrices.

        Both are float32 tensors of shape (N, Nr). Column ``i`` corresponds to
        the i-th ordered pair (receiver r, sender s) with r != s;
        ``Rr[r, i] = 1`` and ``Rs[s, i] = 1``.
        """
        Rr = np.zeros([self.N, self.Nr], dtype=np.float32)
        Rs = np.zeros([self.N, self.Nr], dtype=np.float32)
        receiver_sender_list = [i for i in itertools.product(range(self.N), range(self.N)) if i[0] != i[1]]
        for i, (r, s) in enumerate(receiver_sender_list):
            Rr[r, i] = 1
            Rs[s, i] = 1
        self.Rr = tf.convert_to_tensor(Rr)
        self.Rs = tf.convert_to_tensor(Rs)

    @staticmethod
    def _select_activation(code):
        """Map an activation selector to a TF op: 2 -> SELU, 1 -> ELU, else ReLU."""
        if code == 2:
            return tf.nn.selu
        if code == 1:
            return tf.nn.elu
        return tf.nn.relu

    def call(self, x):
        '''
        Expect input to have shape of (batches, N_particles, N_features).

        Returns a (batches, n_targets) tensor of raw outputs from ``fc3``
        (no softmax applied).
        '''
        ###PF Candidate - PF Candidate###
        if self.debug:
            print("input_shape = {}".format(x.shape))
            print(f"x before lbn : {x[0,0,:]}")
        x = self.lbn(x)  # Already in E, px, py, pz
        if self.debug:
            print(f"x after lbn : {x[0,0,:]}\n")
            print("input_shape after lbn = {}".format(x.shape))
            print("n_outs after lbn = {}".format(self.lbn.lbn.n_out))
        # tmul() right-multiplies by the (N, Nr) matrices, so after this
        # transpose the last axis must be the N nodes.
        # NOTE(review): this assumes the LBN output is (batch, N, P) - confirm.
        x = tf.transpose(x, perm=[0, 2, 1])
        if self.debug: print(f"x after transpose = {x.shape}")
        Orr = self.tmul(x, self.Rr)  # receiver-node features for every edge
        if self.debug: print(f"Orr = {Orr.shape}")
        Ors = self.tmul(x, self.Rs)  # sender-node features for every edge
        if self.debug: print(f"Ors = {Ors.shape}")
        B = tf.concat([Orr, Ors], 1)
        if self.debug:
            print(f"B = {B.shape}")
            print(f"params = {self.P}")

        ### First MLP (fr): per-edge features ###
        fr_act = self._select_activation(self.fr_activation)
        B = tf.transpose(B, perm=[0, 2, 1])
        B = fr_act(self.fr1(tf.reshape(B, [-1, 2 * self.P + self.Dr])))
        if self.debug: print(f"B after fr1 = {B.shape}")
        B = fr_act(self.fr2(B))
        if self.debug: print(f"B after fr2 = {B.shape}")
        E = fr_act(tf.reshape(self.fr3(B), [-1, self.Nr, self.De]))
        if self.debug:
            print(f"E after fr3 = {E.shape}")
            print("E after 1st MLP = {}".format(E.shape))
        E = tf.transpose(E, perm=[0, 2, 1])
        if self.debug:
            print("E after transpose = {}".format(E.shape))
            print("Rr after transpose = {}".format(self.Rr.shape))
        # Multiplying by Rr^T sums each edge's features onto its receiver node.
        Ebar = self.tmul(E, tf.transpose(self.Rr, perm=[1, 0]))
        if self.debug: print("Ebar after tmul = {}".format(Ebar.shape))

        ####Final output matrix for particles###
        C = tf.concat([x, Ebar], 1)
        C = tf.transpose(C, perm=[0, 2, 1])

        ### Second MLP (fo): per-node features ###
        fo_act = self._select_activation(self.fo_activation)
        C = fo_act(self.fo1(tf.reshape(C, [-1, self.P + self.Dx + self.De])))
        C = fo_act(self.fo2(C))
        O = fo_act(tf.reshape(self.fo3(C), [-1, self.N, self.Do]))

        ### Classification MLP (fc) ###
        if self.sum_O:
            O = tf.reduce_sum(O, 1)
            n_flat = self.Do
        else:
            # Flatten all node outputs; uses the node count self.N.
            n_flat = self.Do * self.N
        # Single chooser, so exactly one activation applies (previously
        # fc_activation == 2 fell through into the ReLU branch).
        fc_act = self._select_activation(self.fc_activation)
        h = fc_act(self.fc1(tf.reshape(O, [-1, n_flat])))
        h = fc_act(self.fc2(h))
        return self.fc3(h)

    def tmul(self, x, y):
        """Batched right-multiplication: (I, J, K) x (K, L) -> (I, J, L)."""
        x_shape = tf.shape(x)
        y_shape = tf.shape(y)
        return tf.reshape(tf.matmul(tf.reshape(x, [-1, x_shape[2]]), y), [-1, x_shape[1], y_shape[1]])
        

## Training

In [None]:
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import datetime
# setGPU is optional and imported only for its import-time side effects
# (presumably selecting a GPU - confirm). The previous probe used the `imp`
# module, which was removed in Python 3.12; a plain EAFP import works everywhere.
try:
    import setGPU  # noqa: F401
except ImportError:
    pass
import glob
import sys
import tqdm
import argparse
import pathlib
import tensorflow as tf
from tensorflow.keras import layers, models
from interaction import LEIA
from data import H5Data
import copy
import h5py

os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
# Known data locations, checked in order.
if os.path.isdir('/data/shared/hls-fml/'):
    test_path = '/data/shared/hls-fml/NEWDATA/'
    train_path = '/data/shared/hls-fml/NEWDATA/'
elif os.path.isdir('/eos/project/d/dshep/hls-fml/'):
    test_path = '/eos/project/d/dshep/hls-fml/'
    train_path = '/eos/project/d/dshep/hls-fml/'
else:
    # Fail clearly here instead of with a NameError on train_path below.
    raise FileNotFoundError(
        "No data directory found; expected '/data/shared/hls-fml/' "
        "or '/eos/project/d/dshep/hls-fml/'"
    )

N = 100 # number of particles
n_targets = 5 # number of classes
n_features = 4 # number of features per particles
save_path = 'models/8/'
best_path = save_path + '/best/'
batch_size = 256
n_epochs = 100

# glob() order is filesystem-dependent; sort so the train/val split is reproducible.
files = sorted(glob.glob(train_path + "/jetImage*_{}p*.h5".format(N)))
num_files = len(files)
n_val_files = int(num_files * 0.2)
files_val = files[:n_val_files] # take first 20% for validation
files_train = files[n_val_files:] # take rest for training
files_trial = files[n_val_files:int(num_files*0.3)] # small subset for --trial runs
data_train = H5Data(batch_size = batch_size,
                    cache = None,
                    preloading=0,
                    features_name='jetConstituentList',
                    labels_name='jets',
                    spectators_name=None)
data_val = H5Data(batch_size = batch_size,
                  cache = None,
                  preloading=0,
                  features_name='jetConstituentList',
                  labels_name='jets',
                  spectators_name=None)

            
# Define loss function
def loss(model, x, y):
    """Categorical cross-entropy of ``model(x)`` against the one-hot targets ``y``.

    The model outputs are treated as logits (``from_logits=True``).
    """
    criterion = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    logits = model(x)
    return criterion(y_true=y, y_pred=logits)

def normalize(files, feature_name='jetConstituentList'):
    """Return the per-feature mean and std computed from the first file in ``files``.

    The dataset ``feature_name`` of ``files[0]`` is read fully and flattened over
    every axis except the last, so one statistic is produced per feature column.

    Args:
        files: sequence of HDF5 file paths; only ``files[0]`` is read.
        feature_name: name of the dataset to sample.

    Returns:
        (mean, std): float32 arrays with one entry per feature (last axis).

    Raises:
        ValueError: if ``files`` is empty.
    """
    if not files:
        raise ValueError("normalize() needs at least one file to sample from")
    # Read the dataset into memory inside a context manager so the HDF5
    # handle is always closed (it previously leaked).
    with h5py.File(files[0], "r") as trial_file:
        sample = trial_file[feature_name][()]
    print(f"Getting mean and std from a sample of {len(sample)} events")
    sample = np.reshape(sample, [-1, sample.shape[-1]])
    mean = np.mean(sample, axis=0).astype(np.float32)
    std = np.std(sample, axis=0).astype(np.float32)
    print(f"Mean = {mean}")
    print(f"Std = {std}")
    return mean, std

def main(args):
    """Train an LEIA model and validate it after every epoch.

    Uses the module-level configuration (``N``, ``n_targets``, ``n_features``,
    ``batch_size``, ``n_epochs``, the ``files_*`` lists and the ``data_train`` /
    ``data_val`` generators). Per-epoch loss and accuracy are printed and written
    to TensorBoard under ``logs/gradient_tape/<timestamp>/{train,test}``.

    Args:
        args: parsed command-line options; uses ``trial``, ``hidden``, ``De``
            and ``Do``.
    """
    if args.trial:
        data_train.set_file_names(files_trial)
    else:
        data_train.set_file_names(files_train)
    data_val.set_file_names(files_val)

    n_val = data_val.count_data()
    n_train = data_train.count_data()

    print("val data:", n_val)
    print("train data:", n_train)

    net_args = (N, n_targets, n_features, args.hidden)
    # Pass De/Do so the trained network matches the one evaluate() rebuilds
    # (with args.De/args.Do) before loading weights.
    net_kwargs = {"fr_activation": 0, "fo_activation": 0, "fc_activation": 0,
                  "De": args.De, "Do": args.Do}

    gnn = LEIA(*net_args, **net_kwargs)
    gnn.build(input_shape=(None, N, n_features))

    # gnn.summary() # Doesn't work. Seems like a common TF 2.0 issue:
    # https://github.com/tensorflow/tensorflow/issues/22963
    # https://stackoverflow.com/questions/58182032/you-tried-to-call-count-params-on-but-the-layer-isnt-built-tensorflow-2-0

    # Create the optimizer once: Adam accumulates per-variable moment estimates
    # across steps, and re-creating it every batch discarded that state.
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

    def prepare(sub_X, sub_Y):
        """Cast a raw batch to float32 and select the model inputs and targets."""
        # Column order [3, 0, 1, 2]; presumably yields (E, px, py, pz) for the
        # LBN layer - confirm against the dataset schema.
        inputs = sub_X.astype(np.float32)[:, :, [3, 0, 1, 2]]
        # Label columns -6..-2; assumed to be the n_targets one-hot classes - TODO confirm.
        targets = sub_Y.astype(np.float32)[:, -6:-1]
        return inputs, targets

    def grad(model, input_par, targets):
        """Return (loss, gradients w.r.t. model.trainable_variables) for one batch."""
        with tf.GradientTape() as tape:
            loss_value = loss(model, input_par, targets)
        return loss_value, tape.gradient(loss_value, model.trainable_variables)

    #### Start training ####

    # Keep results for plotting
    train_loss_results = []
    train_accuracy_results = []
    val_loss_results = []
    val_accuracy_results = []

    # Log directory for Tensorboard
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
    test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
    pathlib.Path(train_log_dir).mkdir(parents=True, exist_ok=True)
    pathlib.Path(test_log_dir).mkdir(parents=True, exist_ok=True)

    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    # Load mean and std for normalization.
    # NOTE: currently unused because input normalisation is disabled below.
    mean, std = normalize(files_train)

    # Metrics are reset at the end of every epoch, so they are created once.
    epoch_loss_avg = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    epoch_accuracy = tf.keras.metrics.CategoricalAccuracy('train_accuracy')
    val_epoch_loss_avg = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    val_epoch_accuracy = tf.keras.metrics.CategoricalAccuracy('test_accuracy')

    best_loss = float("inf")
    for epoch in range(n_epochs):

        # Training
        for sub_X, sub_Y in tqdm.tqdm(data_train.generate_data(), total=n_train / batch_size):
            # Normalisation disabled: ((sub_X.astype(np.float32) - mean)/std)[:,:,[3,0,1,2]]
            training, target = prepare(sub_X, sub_Y)

            # Compute loss and gradients, then update the weights
            loss_value, grads = grad(gnn, training, target)
            optimizer.apply_gradients(zip(grads, gnn.trainable_variables))

            # Track progress
            epoch_loss_avg(loss_value)  # Add current batch loss
            # Compare predicted label to actual label
            epoch_accuracy(target, tf.nn.softmax(gnn(training)))

        # Validation
        for sub_X, sub_Y in tqdm.tqdm(data_val.generate_data(), total=n_val / batch_size):
            training, target = prepare(sub_X, sub_Y)

            # Compute the loss
            loss_value = loss(gnn, training, target)

            # Track progress
            val_epoch_loss_avg(loss_value)
            val_epoch_accuracy(target, tf.nn.softmax(gnn(training)))

        # End epoch
        train_loss_results.append(epoch_loss_avg.result())
        train_accuracy_results.append(epoch_accuracy.result())
        val_loss_results.append(val_epoch_loss_avg.result())
        val_accuracy_results.append(val_epoch_accuracy.result())

        # Track the best validation loss (starts at +inf so any finite loss counts)
        val_loss = float(val_epoch_loss_avg.result())
        if val_loss < best_loss:
            best_loss = val_loss

            # NOTE(review): saving is disabled, so evaluate() has nothing to
            # load from best_path until these lines are re-enabled.
            #pathlib.Path(best_path).mkdir(parents=True, exist_ok=True)
            #gnn.save_weights(best_path, save_format='tf')

        # Logs for tensorboard
        with train_summary_writer.as_default():
            tf.summary.scalar('loss', epoch_loss_avg.result(), step=epoch)
            tf.summary.scalar('accuracy', epoch_accuracy.result(), step=epoch)
        with test_summary_writer.as_default():
            tf.summary.scalar('loss', val_epoch_loss_avg.result(), step=epoch)
            tf.summary.scalar('accuracy', val_epoch_accuracy.result(), step=epoch)

        template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.2f}%, Test Loss: {:.4f}, Test Accuracy: {:.2f}%'
        print(template.format(epoch+1,
                              epoch_loss_avg.result(),
                              epoch_accuracy.result()*100,
                              val_epoch_loss_avg.result(),
                              val_epoch_accuracy.result()*100))

        # Reset metrics every epoch
        epoch_loss_avg.reset_states()
        val_epoch_loss_avg.reset_states()
        epoch_accuracy.reset_states()
        val_epoch_accuracy.reset_states()

    # Save the model after training (disabled)
#    pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
#    gnn.save_weights(save_path, save_format='tf')

def evaluate(args):
    """Rebuild LEIA, load weights from ``best_path`` and report validation metrics.

    Prints the mean loss and categorical accuracy over the validation files.

    Args:
        args: parsed command-line options; uses ``hidden``, ``De`` and ``Do``.
    """
    model = LEIA(
        N, n_targets, n_features, args.hidden,
        fr_activation=0, fo_activation=0, fc_activation=0,
        De=args.De, Do=args.Do,
    )
    model.build(input_shape=(None, N, n_features))
    model.load_weights(best_path)

    data_val.set_file_names(files_val)
    n_val = data_val.count_data()

    loss_metric = tf.keras.metrics.Mean('loss', dtype=tf.float32)
    accuracy_metric = tf.keras.metrics.CategoricalAccuracy('accuracy')

    # Mean/std are computed but not applied (normalisation is disabled below).
    mean, std = normalize(files_train)

    batches = tqdm.tqdm(data_val.generate_data(), total=n_val / batch_size)
    for batch_x, batch_y in batches:
        # Normalised variant: ((batch_x.astype(np.float32) - mean)/std)[:,:,[3,0,1,2]]
        inputs = batch_x.astype(np.float32)[:, :, [3, 0, 1, 2]]
        labels = batch_y.astype(np.float32)[:, -6:-1]

        loss_metric(loss(model, inputs, labels))
        accuracy_metric(labels, tf.nn.softmax(model(inputs)))

    summary = 'Loss: {}, Accuracy: {}%'.format(loss_metric.result(),
                                               accuracy_metric.result() * 100)
    print(summary)

if __name__ == "__main__":
    # Command-line entry point: train by default, or only evaluate with --evaluate.
    parser = argparse.ArgumentParser()

    # Integer hyper-parameters: (flag, dest, default, help)
    int_options = (
        ("--hidden", 'hidden', 128, "hidden parameter"),
        ("--De", 'De', 64, "De parameter"),
        ("--Do", 'Do', 64, "Do parameter"),
    )
    for flag, dest, default, help_text in int_options:
        parser.add_argument(flag, type=int, action='store', dest=dest, default=default, help=help_text)

    # Boolean switches: (flag, dest, help)
    switch_options = (
        ("--evaluate", 'evaluate', "only run evaluation"),
        ("--trial", 'trial', "train on a smaller sample"),
    )
    for flag, dest, help_text in switch_options:
        parser.add_argument(flag, action='store_true', dest=dest, default=False, help=help_text)

    args = parser.parse_args()
    entry_point = evaluate if args.evaluate else main
    entry_point(args)
    