# In [None]:  (notebook cell marker left over from export)
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
from keras.callbacks import Callback

class IntervalEvaluation(Callback):
    """Keras callback tracking signal efficiency at a fixed background rate.

    Every ``period`` epochs the model is evaluated on the validation set (and
    on the training set, if one was given).  The metric is the fraction of
    signal events (label > 0.5) whose score is at or above the threshold that
    lets roughly a fraction ``eff_rate`` of the background events
    (label < 0.5) pass.

    The model is written to ``filename`` at the first evaluation, and
    afterwards whenever more than ``min_epoch`` evaluations have been recorded
    and the new validation efficiency is at least as large as the best value
    recorded from evaluation index ``min_epoch`` onward.  Training is halted
    when the patience counter exceeds ``patience`` or when epoch
    ``max_epochs - 1`` is reached.

    NOTE(review): the patience counter and the ``min_epoch`` slice are counted
    in *evaluations*, which equals epochs only when ``period == 1`` -- confirm
    this is the intended behaviour for ``period > 1``.

    Parameters
    ----------
    training_data : tuple
        ``(x_train, y_train)``.  An empty tuple (the default) or ``None``
        disables the training-set metric.
    validation_data : tuple
        ``(x_val, y_val)``.  Required; a ``ValueError`` is raised if empty.
    verbose : int
        Non-zero: print the metric after each evaluation.  Greater than 1:
        additionally save plots every ``plot_period`` epochs.  The final
        summary plot is produced whenever ``verbose > -1``.
    filename : str
        Path the best model is saved to.  Plots are written next to it, with
        the extension replaced by ``_losseffplots.png``.
    period : int
        Epoch frequency for checking performance.
    min_epoch : int
        Number of epochs to wait before tracking/saving the best model.
    avg_length : int
        Window length of the moving average of the metric.
    eff_rate : float
        Background fraction used to define the score threshold.
    patience : int
        Evaluations to wait since the last improvement; ``<= 0`` disables it.
    plot_period : int
        How frequently (in epochs) plots are saved when ``verbose > 1``.
    batch_size : int
        Batch size passed to ``model.predict``.
    max_epochs : int
        Training is stopped at epoch ``max_epochs - 1``.
    """

    def __init__(self,
                 training_data=(),
                 validation_data=(),
                 verbose=0,
                 filename='checkpoint_best.h5',
                 period=1,
                 min_epoch=10,
                 avg_length=5,
                 eff_rate=0.01,
                 patience=300,
                 plot_period=1,
                 batch_size=20000,
                 max_epochs=700):
        # Initialise the base Callback state; the original skipped this.
        super(IntervalEvaluation, self).__init__()

        # The default training_data=() used to fail on tuple unpacking; an
        # empty value now simply switches the training-set metric off.
        training_data = () if training_data is None else tuple(training_data)
        if training_data:
            self.x_train, self.y_train = training_data
        else:
            self.x_train = self.y_train = None

        # The validation set drives checkpointing and early stopping, so it
        # cannot be omitted.  Still a ValueError, but with a usable message.
        validation_data = () if validation_data is None else tuple(validation_data)
        if not validation_data:
            raise ValueError(
                "validation_data must be a (x_val, y_val) pair; got an empty value")
        self.x_val, self.y_val = validation_data

        self.verbose = verbose
        self.filename = filename
        self.period = period
        self.min_epoch = min_epoch
        self.avg_length = avg_length
        self.eff_rate = eff_rate
        self.patience = patience
        self.plot_period = plot_period
        self.batch_size = batch_size
        self.max_epochs = max_epochs

    def on_train_begin(self, logs=None):
        """Reset all metric histories and the patience counter."""
        self.effs_val = []
        self.effs_val_avg = []
        self.effs_train = []
        self.effs_train_avg = []
        self.loss = []
        self.val_loss = []
        self.n_wait = 0

    def on_epoch_end(self, epoch, logs=None):
        """Record losses, evaluate the metric, checkpoint and maybe stop.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index supplied by Keras.
        logs : dict or None
            Metrics for the epoch; ``'loss'`` and ``'val_loss'`` are read.
            A missing key is recorded as NaN instead of raising ``KeyError``.
        """
        logs = logs or {}
        self.loss.append(logs.get('loss', float('nan')))
        self.val_loss.append(logs.get('val_loss', float('nan')))

        if epoch % self.period == 0:
            scores_val = self.model.predict(self.x_val, batch_size=self.batch_size)
            sig_eff = self._signal_efficiency(scores_val, self.y_val, self.eff_rate)

            if epoch > self.min_epoch:
                # Increase patience timer by one evaluation.
                self.n_wait += 1

            # If this model is the best so far, save it and reset the timer.
            # The comparison uses the history *before* the new value is added.
            if not self.effs_val:
                self.model.save(self.filename)
            elif len(self.effs_val) > self.min_epoch:
                if sig_eff >= np.max(self.effs_val[self.min_epoch:]):
                    self.model.save(self.filename)
                    self.n_wait = 0

            self._record(self.effs_val, self.effs_val_avg, sig_eff)
            if self.verbose:
                print("sig eff = ", sig_eff)

            if self._has_training_data():
                scores_train = self.model.predict(
                    self.x_train, batch_size=self.batch_size)
                train_eff = self._signal_efficiency(
                    scores_train, self.y_train, self.eff_rate)
                self._record(self.effs_train, self.effs_train_avg, train_eff)
                if self.verbose:
                    print("sig eff train = ", train_eff)

            # If we have waited too long with no improvement (or hit the
            # epoch limit), write the summary plot and halt training.
            out_of_patience = self.patience > 0 and self.n_wait > self.patience
            if out_of_patience or epoch == self.max_epochs - 1:
                if self.verbose > -1:
                    plt.close('all')
                    print("Saving fig:", self._plot_path())
                    self._save_plots(show=True)
                print("Training ends at epoch: %d" % epoch)
                self.model.stop_training = True
                # Return instead of overwriting self.verbose (as the original
                # did) so the periodic plot is skipped without losing the
                # user's verbosity setting.
                return

        if self.verbose > 1 and epoch % self.plot_period == 0:
            self._save_plots(show=False)

    @staticmethod
    def _signal_efficiency(scores, labels, eff_rate):
        """Return the signal efficiency at background fraction ``eff_rate``.

        The threshold is the score found at index ``int(eff_rate * n_bg)`` of
        the background scores sorted in descending order (index clamped to the
        valid range).  The result is the fraction of signal scores that are
        greater than or equal to that threshold.  NaN is returned when either
        class has no events.
        """
        scores = np.asarray(scores).ravel()
        labels = np.asarray(labels).ravel()

        background = np.sort(scores[labels < 0.5])[::-1]  # descending
        signal = np.sort(scores[labels > 0.5])            # ascending
        if len(background) == 0 or len(signal) == 0:
            return float('nan')

        # Clamp so eff_rate >= 1 (or a negative value) cannot index outside
        # the background array.
        cut = int(eff_rate * len(background))
        cut = min(max(cut, 0), len(background) - 1)
        thresh = background[cut]

        # searchsorted (side='left') counts signal scores strictly below thresh.
        return 1.0 - 1.0 * np.searchsorted(signal, thresh) / len(signal)

    def _record(self, history, averages, value):
        """Append ``value`` to ``history`` and its moving average to ``averages``."""
        history.append(value)
        # A slice longer than the list returns the whole list, so the early
        # epochs average over everything recorded so far.
        averages.append(np.mean(history[-self.avg_length:]))

    def _has_training_data(self):
        """Return True if a non-empty training set was supplied."""
        return self.x_train is not None and len(self.x_train) > 0

    def _plot_path(self):
        """Return the plot file path derived from the model filename."""
        # splitext copes with any extension length (.h5, .hdf5, .keras, none).
        return os.path.splitext(self.filename)[0] + "_losseffplots.png"

    def _save_plots(self, show=False):
        """Plot efficiency and loss histories, save them, then close the figure.

        Validation curves use colour C1 and training curves C0; moving
        averages are dashed and only drawn when ``avg_length > 1``.
        """
        fig = plt.figure(figsize=(14, 5))

        plt.subplot(1, 2, 1)
        plt.plot(self.effs_val, color='C1')
        plt.plot(self.effs_train, color='C0')
        if self.avg_length > 1:
            plt.plot(self.effs_val_avg, color='C1', linestyle='--')
            plt.plot(self.effs_train_avg, color='C0', linestyle='--')
        # Positional argument: the 'b' keyword was renamed to 'visible' in
        # matplotlib 3.5 and is rejected by newer releases.
        plt.grid(True)

        plt.subplot(1, 2, 2)
        plt.plot(self.val_loss, color='C1')
        plt.plot(self.loss, color='C0')
        plt.grid(True)

        plt.savefig(self._plot_path())
        if show:
            plt.show()
        # Close explicitly so figures do not pile up over many epochs.
        plt.close(fig)