In [1]:
#!/usr/bin/env python
# coding: utf-8


import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import os
import pickle 

from helpers.synthesis import *
from helpers.visualization import *

# Project matplotlib style sheet, given as a file path; presumably the notebook
# must be run from the directory that contains science.mplstyle — TODO confirm.
plt.style.use("science.mplstyle")

# Resolution passed to fig.savefig(..., dpi=dpi) for the saved figures below.
dpi = 800

In [2]:
# --- Run configuration -------------------------------------------------------
# Seed of the generative-model run whose classifier outputs are analysed.
gen_seed = 1

# Directory holding the classifier outputs for that generative seed.
scatterplot_dir = f"/global/ml4hep/spss/rrmastandrea/synth_SM_AD/NF_results_wide_seed_{gen_seed}/"

# Bookkeeping handed to load_in_data below.
n_seed = 100      # number of classifier seeds per configuration
num_bkg = 320000  # presumably the number of background events — confirm against load_in_data
num_sig = 20000   # not used in this section of the notebook

# The n_seed classifiers are averaged in consecutive groups of n_avg seeds,
# giving n_trials_post_avg ensembled "trials" per configuration.
n_avg = 10
n_trials_post_avg = n_seed // n_avg


In [3]:
# Load the classifier scores for every injection level and ensemble them.
# load_in_data returns, per synthesis method, a collection indexed by classifier
# seed; the seeds are split into n_trials_post_avg consecutive chunks of n_avg
# seeds and the scores within each chunk are averaged.

num_signals_to_inject = [0, 300, 500, 750, 1000, 1200, 1500, 2000, 2500, 3000]
# Quick-test alternative: num_signals_to_inject = [0, 1500]

synth_ids = ["feta", "cathode", "curtains", "salad"]

# summary_dict_*[nn] -> list (one entry per ensembling chunk) of {synth_id: averaged scores}
summary_dict_bkg = {nn: [] for nn in num_signals_to_inject}
summary_dict_sig = {nn: [] for nn in num_signals_to_inject}

for nn in num_signals_to_inject:
    all_results_bkg, all_results_sig = load_in_data(
        nn, num_bkg, "StandardScale", synth_ids, n_seed, gen_seed, scatterplot_dir
    )

    for ensembling_chunk in range(n_trials_post_avg):
        chunk_start = n_avg * ensembling_chunk
        seeds_to_avg = np.arange(chunk_start, chunk_start + n_avg)

        # Element-wise mean over the seeds in this chunk, separately per method.
        concatenated_results_bkg = {
            iid: np.mean([all_results_bkg[iid][seed_NN] for seed_NN in seeds_to_avg], axis=0)
            for iid in synth_ids
        }
        concatenated_results_sig = {
            iid: np.mean([all_results_sig[iid][seed_NN] for seed_NN in seeds_to_avg], axis=0)
            for iid in synth_ids
        }

        summary_dict_bkg[nn].append(concatenated_results_bkg)
        summary_dict_sig[nn].append(concatenated_results_sig)
        


In [4]:
def get_highest_percentile(scores, num_ids_to_take):
    """Return the indices of the `num_ids_to_take` highest entries of `scores`.

    Parameters
    ----------
    scores : 1-D array-like of float
        Per-event scores (callers pass flattened arrays).
    num_ids_to_take : int
        Number of top-scoring indices to return; values <= 0 give an empty list.

    Returns
    -------
    list of int
        Indices in ascending score order, so the last element is the top-scoring
        event. Ties are broken by index, lower index first, exactly as
        ``sorted(zip(scores, range(len(scores))))`` orders them.
    """
    if num_ids_to_take <= 0:
        # Guard: `order[-0:]` would return every index rather than none.
        return []
    # A stable argsort orders equal scores by ascending index, reproducing the
    # tuple-sort ordering without building and sorting Python tuples (twice).
    order = np.argsort(np.asarray(scores), kind="stable")
    return order[-num_ids_to_take:].tolist()



In [None]:
# For each score percentile p, measure the fraction of the top-p events shared
# by pairs of synthesis methods and by all methods together, separately for the
# background-only and the signal-injected samples. A random-ranking baseline
# (what the overlap would be if events were ordered at random) is computed too.

percentiles = [0.005, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

# Fixed seed so the "Random" reference curve is reproducible on Restart & Run All.
random_baseline_seed = 42


def _ascending_ranks(scores):
    """Return each event's rank when `scores` are sorted ascending (0 = lowest).

    Ties are broken by event index (stable sort), the same ordering that
    ``sorted(zip(scores, range(len(scores))))`` produces, so the top ``k`` events
    are exactly those with rank ``>= len(scores) - k``.
    """
    flat_scores = np.asarray(scores).flatten()
    order = np.argsort(flat_scores, kind="stable")
    ranks = np.empty(flat_scores.shape[0], dtype=np.int64)
    ranks[order] = np.arange(flat_scores.shape[0])
    return ranks


def _overlap_fraction(min_ranks, num_ids_to_take):
    """Fraction of the top-`num_ids_to_take` events shared by every ranking compared.

    `min_ranks` is the element-wise minimum of the rank arrays being compared: an
    event lies in every ranking's top k iff its smallest rank is >= n - k.
    Returns NaN when no events are taken instead of dividing by zero.
    """
    if num_ids_to_take <= 0:
        return float("nan")
    num_shared = np.count_nonzero(min_ranks >= min_ranks.shape[0] - num_ids_to_take)
    return float(num_shared) / num_ids_to_take


# --- Random baseline ---------------------------------------------------------
# One random ranking per synthesis method; the event count is taken from the
# background sample (assumed identical across chunks / injection levels).
n_random_events = summary_dict_bkg[num_signals_to_inject[-1]][0]["feta"].shape[0]
rng = np.random.default_rng(random_baseline_seed)
random_ranks = [_ascending_ranks(rng.uniform(size=n_random_events)) for _ in synth_ids]
random_min_ranks_pair = np.minimum(random_ranks[0], random_ranks[1])
random_min_ranks_all = np.minimum.reduce(random_ranks)

random_overlap = [_overlap_fraction(random_min_ranks_pair, int(p * n_random_events)) for p in percentiles]
random_overlap_4 = [_overlap_fraction(random_min_ranks_all, int(p * n_random_events)) for p in percentiles]

# --- Method overlaps ---------------------------------------------------------
id_pairs = [(id_1, id_2) for i, id_1 in enumerate(synth_ids) for id_2 in synth_ids[i + 1:]]
pair_ids = [id_1 + "_" + id_2 for id_1, id_2 in id_pairs]

# all_overlap_*[nn][p] and pairs_overlap_*[pair_id][nn][p] are lists with one
# entry per ensembling chunk, in chunk order.
all_overlap_bkg = {nn: {p: [] for p in percentiles} for nn in num_signals_to_inject}
all_overlap_sig = {nn: {p: [] for p in percentiles} for nn in num_signals_to_inject}
pairs_overlap_bkg = {pid: {nn: {p: [] for p in percentiles} for nn in num_signals_to_inject} for pid in pair_ids}
pairs_overlap_sig = {pid: {nn: {p: [] for p in percentiles} for nn in num_signals_to_inject} for pid in pair_ids}

for nn in num_signals_to_inject:
    print(f"Evaluating num. injected signal events {nn}...")

    for ensembling_chunk in range(n_trials_post_avg):
        for summary_dict, all_overlap, pairs_overlap in (
            (summary_dict_bkg, all_overlap_bkg, pairs_overlap_bkg),
            (summary_dict_sig, all_overlap_sig, pairs_overlap_sig),
        ):
            chunk_scores = summary_dict[nn][ensembling_chunk]
            # As before, the number of events taken is set by the FETA sample size;
            # all methods are assumed to score the same events.
            n_events = chunk_scores["feta"].shape[0]

            # Rank once per chunk; every percentile then needs only a threshold count.
            ranks = {iid: _ascending_ranks(chunk_scores[iid]) for iid in synth_ids}
            pair_min_ranks = {
                pid: np.minimum(ranks[id_1], ranks[id_2])
                for pid, (id_1, id_2) in zip(pair_ids, id_pairs)
            }
            all_min_ranks = np.minimum.reduce([ranks[iid] for iid in synth_ids])

            for p in percentiles:
                num_ids_to_take = int(p * n_events)
                for pid, min_ranks in pair_min_ranks.items():
                    pairs_overlap[pid][nn][p].append(_overlap_fraction(min_ranks, num_ids_to_take))
                all_overlap[nn][p].append(_overlap_fraction(all_min_ranks, num_ids_to_take))

    print()

    



Evaluating percentile 0.005...
Evaluating num. injected signal events 0, ensembling chunk 0
Evaluating num. injected signal events 300, ensembling chunk 0
Evaluating num. injected signal events 500, ensembling chunk 0
Evaluating num. injected signal events 750, ensembling chunk 0
Evaluating num. injected signal events 1000, ensembling chunk 0
Evaluating num. injected signal events 1200, ensembling chunk 0
Evaluating num. injected signal events 1500, ensembling chunk 0
Evaluating num. injected signal events 2000, ensembling chunk 0
Evaluating num. injected signal events 2500, ensembling chunk 0
Evaluating num. injected signal events 3000, ensembling chunk 0
Evaluating num. injected signal events 0, ensembling chunk 1
Evaluating num. injected signal events 300, ensembling chunk 1
Evaluating num. injected signal events 500, ensembling chunk 1
Evaluating num. injected signal events 750, ensembling chunk 1
Evaluating num. injected signal events 1000, ensembling chunk 1
Evaluating num. injec

Evaluating num. injected signal events 3000, ensembling chunk 2
Evaluating num. injected signal events 0, ensembling chunk 3
Evaluating num. injected signal events 300, ensembling chunk 3
Evaluating num. injected signal events 500, ensembling chunk 3
Evaluating num. injected signal events 750, ensembling chunk 3
Evaluating num. injected signal events 1000, ensembling chunk 3
Evaluating num. injected signal events 1200, ensembling chunk 3
Evaluating num. injected signal events 1500, ensembling chunk 3
Evaluating num. injected signal events 2000, ensembling chunk 3
Evaluating num. injected signal events 2500, ensembling chunk 3
Evaluating num. injected signal events 3000, ensembling chunk 3
Evaluating num. injected signal events 0, ensembling chunk 4
Evaluating num. injected signal events 300, ensembling chunk 4
Evaluating num. injected signal events 500, ensembling chunk 4
Evaluating num. injected signal events 750, ensembling chunk 4
Evaluating num. injected signal events 1000, ensembl

Evaluating num. injected signal events 2500, ensembling chunk 5
Evaluating num. injected signal events 3000, ensembling chunk 5
Evaluating num. injected signal events 0, ensembling chunk 6
Evaluating num. injected signal events 300, ensembling chunk 6
Evaluating num. injected signal events 500, ensembling chunk 6
Evaluating num. injected signal events 750, ensembling chunk 6
Evaluating num. injected signal events 1000, ensembling chunk 6
Evaluating num. injected signal events 1200, ensembling chunk 6
Evaluating num. injected signal events 1500, ensembling chunk 6
Evaluating num. injected signal events 2000, ensembling chunk 6
Evaluating num. injected signal events 2500, ensembling chunk 6
Evaluating num. injected signal events 3000, ensembling chunk 6
Evaluating num. injected signal events 0, ensembling chunk 7
Evaluating num. injected signal events 300, ensembling chunk 7
Evaluating num. injected signal events 500, ensembling chunk 7
Evaluating num. injected signal events 750, ensembl

Evaluating num. injected signal events 2000, ensembling chunk 8
Evaluating num. injected signal events 2500, ensembling chunk 8
Evaluating num. injected signal events 3000, ensembling chunk 8
Evaluating num. injected signal events 0, ensembling chunk 9
Evaluating num. injected signal events 300, ensembling chunk 9
Evaluating num. injected signal events 500, ensembling chunk 9
Evaluating num. injected signal events 750, ensembling chunk 9
Evaluating num. injected signal events 1000, ensembling chunk 9
Evaluating num. injected signal events 1200, ensembling chunk 9
Evaluating num. injected signal events 1500, ensembling chunk 9
Evaluating num. injected signal events 2000, ensembling chunk 9
Evaluating num. injected signal events 2500, ensembling chunk 9
Evaluating num. injected signal events 3000, ensembling chunk 9

Evaluating percentile 0.2...
Evaluating num. injected signal events 0, ensembling chunk 0
Evaluating num. injected signal events 300, ensembling chunk 0
Evaluating num. inje

Evaluating num. injected signal events 1000, ensembling chunk 3
Evaluating num. injected signal events 1200, ensembling chunk 3
Evaluating num. injected signal events 1500, ensembling chunk 3
Evaluating num. injected signal events 2000, ensembling chunk 3
Evaluating num. injected signal events 2500, ensembling chunk 3
Evaluating num. injected signal events 3000, ensembling chunk 3
Evaluating num. injected signal events 0, ensembling chunk 4
Evaluating num. injected signal events 300, ensembling chunk 4
Evaluating num. injected signal events 500, ensembling chunk 4
Evaluating num. injected signal events 750, ensembling chunk 4
Evaluating num. injected signal events 1000, ensembling chunk 4
Evaluating num. injected signal events 1200, ensembling chunk 4
Evaluating num. injected signal events 1500, ensembling chunk 4
Evaluating num. injected signal events 2000, ensembling chunk 4
Evaluating num. injected signal events 2500, ensembling chunk 4
Evaluating num. injected signal events 3000, e

Evaluating num. injected signal events 750, ensembling chunk 6
Evaluating num. injected signal events 1000, ensembling chunk 6
Evaluating num. injected signal events 1200, ensembling chunk 6
Evaluating num. injected signal events 1500, ensembling chunk 6
Evaluating num. injected signal events 2000, ensembling chunk 6
Evaluating num. injected signal events 2500, ensembling chunk 6
Evaluating num. injected signal events 3000, ensembling chunk 6
Evaluating num. injected signal events 0, ensembling chunk 7
Evaluating num. injected signal events 300, ensembling chunk 7
Evaluating num. injected signal events 500, ensembling chunk 7
Evaluating num. injected signal events 750, ensembling chunk 7
Evaluating num. injected signal events 1000, ensembling chunk 7
Evaluating num. injected signal events 1200, ensembling chunk 7
Evaluating num. injected signal events 1500, ensembling chunk 7
Evaluating num. injected signal events 2000, ensembling chunk 7
Evaluating num. injected signal events 2500, en

Evaluating num. injected signal events 500, ensembling chunk 9
Evaluating num. injected signal events 750, ensembling chunk 9
Evaluating num. injected signal events 1000, ensembling chunk 9
Evaluating num. injected signal events 1200, ensembling chunk 9
Evaluating num. injected signal events 1500, ensembling chunk 9
Evaluating num. injected signal events 2000, ensembling chunk 9
Evaluating num. injected signal events 2500, ensembling chunk 9
Evaluating num. injected signal events 3000, ensembling chunk 9

Evaluating percentile 0.6...
Evaluating num. injected signal events 0, ensembling chunk 0
Evaluating num. injected signal events 300, ensembling chunk 0
Evaluating num. injected signal events 500, ensembling chunk 0
Evaluating num. injected signal events 750, ensembling chunk 0
Evaluating num. injected signal events 1000, ensembling chunk 0
Evaluating num. injected signal events 1200, ensembling chunk 0
Evaluating num. injected signal events 1500, ensembling chunk 0
Evaluating num. in

Evaluating num. injected signal events 300, ensembling chunk 2
Evaluating num. injected signal events 500, ensembling chunk 2
Evaluating num. injected signal events 750, ensembling chunk 2
Evaluating num. injected signal events 1000, ensembling chunk 2
Evaluating num. injected signal events 1200, ensembling chunk 2
Evaluating num. injected signal events 1500, ensembling chunk 2
Evaluating num. injected signal events 2000, ensembling chunk 2
Evaluating num. injected signal events 2500, ensembling chunk 2
Evaluating num. injected signal events 3000, ensembling chunk 2
Evaluating num. injected signal events 0, ensembling chunk 3
Evaluating num. injected signal events 300, ensembling chunk 3
Evaluating num. injected signal events 500, ensembling chunk 3
Evaluating num. injected signal events 750, ensembling chunk 3
Evaluating num. injected signal events 1000, ensembling chunk 3
Evaluating num. injected signal events 1200, ensembling chunk 3
Evaluating num. injected signal events 1500, ense

In [None]:
# Summarise the spread over ensembling chunks: for each injection level and
# score percentile keep the median and the 16th/84th percentiles (~68% band).
# Each *_median / *_lower / *_upper list is aligned with `percentiles`.

pair_ids = [id_1 + "_" + id_2 for i, id_1 in enumerate(synth_ids) for id_2 in synth_ids[i + 1:]]


def _new_per_injection():
    """Return a fresh {num. injected signal events: []} mapping."""
    return {nn: [] for nn in num_signals_to_inject}


def _append_band(values, median_out, lower_out, upper_out):
    """Append the median, 16th and 84th percentile of `values` to the three lists."""
    median_out.append(np.median(values))
    lower_out.append(np.percentile(values, 16))
    upper_out.append(np.percentile(values, 84))


all_overlap_bkg_median, all_overlap_bkg_lower, all_overlap_bkg_upper = (_new_per_injection() for _ in range(3))
all_overlap_sig_median, all_overlap_sig_lower, all_overlap_sig_upper = (_new_per_injection() for _ in range(3))

pairs_overlap_bkg_median = {pid: _new_per_injection() for pid in pair_ids}
pairs_overlap_bkg_lower = {pid: _new_per_injection() for pid in pair_ids}
pairs_overlap_bkg_upper = {pid: _new_per_injection() for pid in pair_ids}
pairs_overlap_sig_median = {pid: _new_per_injection() for pid in pair_ids}
pairs_overlap_sig_lower = {pid: _new_per_injection() for pid in pair_ids}
pairs_overlap_sig_upper = {pid: _new_per_injection() for pid in pair_ids}

for nn in num_signals_to_inject:
    for p in percentiles:
        _append_band(all_overlap_bkg[nn][p],
                     all_overlap_bkg_median[nn], all_overlap_bkg_lower[nn], all_overlap_bkg_upper[nn])
        _append_band(all_overlap_sig[nn][p],
                     all_overlap_sig_median[nn], all_overlap_sig_lower[nn], all_overlap_sig_upper[nn])

        for pid in pair_ids:
            _append_band(pairs_overlap_bkg[pid][nn][p],
                         pairs_overlap_bkg_median[pid][nn],
                         pairs_overlap_bkg_lower[pid][nn],
                         pairs_overlap_bkg_upper[pid][nn])
            _append_band(pairs_overlap_sig[pid][nn][p],
                         pairs_overlap_sig_median[pid][nn],
                         pairs_overlap_sig_lower[pid][nn],
                         pairs_overlap_sig_upper[pid][nn])
        
        


In [None]:
# Persist the summary statistics so the plotting cells can be re-run without
# repeating the overlap computation. Files go to pickles/<name>.p.
os.makedirs("pickles", exist_ok=True)

objects_to_pickle = {
    "all_overlap_bkg_median": all_overlap_bkg_median,
    "all_overlap_bkg_lower": all_overlap_bkg_lower,
    "all_overlap_bkg_upper": all_overlap_bkg_upper,
    "all_overlap_sig_median": all_overlap_sig_median,
    "all_overlap_sig_lower": all_overlap_sig_lower,
    "all_overlap_sig_upper": all_overlap_sig_upper,
    "random_overlap": random_overlap,
    "random_overlap_4": random_overlap_4,
    "pairs_overlap_bkg_median": pairs_overlap_bkg_median,
    "pairs_overlap_bkg_lower": pairs_overlap_bkg_lower,
    "pairs_overlap_bkg_upper": pairs_overlap_bkg_upper,
    "pairs_overlap_sig_median": pairs_overlap_sig_median,
    "pairs_overlap_sig_lower": pairs_overlap_sig_lower,
    "pairs_overlap_sig_upper": pairs_overlap_sig_upper,
}

for pickle_name, obj in objects_to_pickle.items():
    # `with` guarantees each file is flushed and closed, even if dumping fails.
    with open(f"pickles/{pickle_name}.p", "wb") as pickle_file:
        pickle.dump(obj, pickle_file)





In [None]:
def _load_pickle(pickle_name):
    """Load and return the object stored in pickles/<pickle_name>.p.

    The file is closed as soon as it has been read. NOTE: pickle can execute
    arbitrary code on load — only use this on files written by this notebook.
    """
    with open(f"pickles/{pickle_name}.p", "rb") as pickle_file:
        return pickle.load(pickle_file)


all_overlap_bkg_median = _load_pickle("all_overlap_bkg_median")
all_overlap_bkg_lower = _load_pickle("all_overlap_bkg_lower")
all_overlap_bkg_upper = _load_pickle("all_overlap_bkg_upper")
all_overlap_sig_median = _load_pickle("all_overlap_sig_median")
all_overlap_sig_lower = _load_pickle("all_overlap_sig_lower")
all_overlap_sig_upper = _load_pickle("all_overlap_sig_upper")

# Arrays so the random baselines can be subtracted element-wise from the curves.
random_overlap = np.array(_load_pickle("random_overlap"))
random_overlap_4 = np.array(_load_pickle("random_overlap_4"))

pairs_overlap_bkg_median = _load_pickle("pairs_overlap_bkg_median")
pairs_overlap_bkg_lower = _load_pickle("pairs_overlap_bkg_lower")
pairs_overlap_bkg_upper = _load_pickle("pairs_overlap_bkg_upper")
pairs_overlap_sig_median = _load_pickle("pairs_overlap_sig_median")
pairs_overlap_sig_lower = _load_pickle("pairs_overlap_sig_lower")
pairs_overlap_sig_upper = _load_pickle("pairs_overlap_sig_upper")





## For each signal injection, plot the overlap as a function of the score percentile

In [None]:
percentiles = [0.005, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

fontsize = 30
small_font = 16

colors = ["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7"]

# Plot overlaps relative to the random-ranking baseline instead of absolute values.
plot_residuals = True

l = 3  # base line width

# Injection level to plot; must be one of num_signals_to_inject.
num_to_plot = 0

# Set True to also write the figures to plots/.
save_plots = False

random_overlap = np.array(random_overlap)
random_overlap_4 = np.array(random_overlap_4)


def _plot_overlap_vs_percentile(pairs_median, pairs_lower, pairs_upper,
                                all_median, all_lower, all_upper,
                                sample_label, save_path=None):
    """Plot the fraction of shared top-scoring events versus the score percentile.

    Draws one curve per pair of synthesis methods and one for all methods, for
    the injection level `num_to_plot`. When `plot_residuals` is set, the matching
    random-ranking overlap is subtracted (random_overlap for pairs,
    random_overlap_4 for all methods) and the lower/upper band is shaded.
    `sample_label` is printed on the axes; the figure is saved to `save_path`
    (creating its directory) if given. Returns (fig, ax).
    """
    # Baselines subtracted from every curve: random overlaps in residual mode, none otherwise.
    pair_baseline = random_overlap if plot_residuals else 0.0
    all_baseline = random_overlap_4 if plot_residuals else 0.0

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(percentiles, random_overlap - pair_baseline, label="Random", linewidth=l + 2, color="black")

    pairs = [(id_1, id_2) for i, id_1 in enumerate(synth_ids) for id_2 in synth_ids[i + 1:]]
    for k, (id_1, id_2) in enumerate(pairs):
        pair_id = id_1 + "_" + id_2
        plot_label = id_1.upper() + "-" + id_2.upper()
        ax.plot(percentiles, np.asarray(pairs_median[pair_id][num_to_plot]) - pair_baseline,
                label=plot_label, linewidth=l, color=colors[k])
        if plot_residuals:
            ax.fill_between(percentiles,
                            np.asarray(pairs_lower[pair_id][num_to_plot]) - pair_baseline,
                            np.asarray(pairs_upper[pair_id][num_to_plot]) - pair_baseline,
                            color=colors[k], alpha=0.3)

    ax.plot(percentiles, np.asarray(all_median[num_to_plot]) - all_baseline,
            label="All methods", linewidth=l * 3, color="grey", alpha=0.7)
    if plot_residuals:
        ax.fill_between(percentiles,
                        np.asarray(all_lower[num_to_plot]) - all_baseline,
                        np.asarray(all_upper[num_to_plot]) - all_baseline,
                        color="grey", alpha=0.3)

    ax.legend(fontsize=small_font, bbox_to_anchor=(1.4, 1))
    ax.set_xlabel("Score percentile (top)", fontsize=fontsize)
    ax.set_xticks(ticks=[0, 0.25, 0.5, .75, 1])
    ax.tick_params(axis='both', which='major', labelsize=small_font)
    ax.text(.7, 0.8, f"$n_\\mathrm{{sig}}$ = {num_to_plot}\n{sample_label}",
            fontsize=fontsize, transform=ax.transAxes)
    # Raw string: "\D" in a normal literal is an invalid escape (SyntaxWarning on 3.12+).
    if plot_residuals:
        ax.set_ylabel(r"$\Delta$ fraction of shared events", fontsize=fontsize)
    else:
        ax.set_ylabel("Fraction of shared events", fontsize=fontsize)

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=dpi)
    plt.show()
    return fig, ax


# Background-only sample
fig, ax = _plot_overlap_vs_percentile(
    pairs_overlap_bkg_median, pairs_overlap_bkg_lower, pairs_overlap_bkg_upper,
    all_overlap_bkg_median, all_overlap_bkg_lower, all_overlap_bkg_upper,
    "Background",
    save_path=f"plots/unification_bkg_{num_to_plot}.pdf" if save_plots else None,
)

# Signal-injected sample
fig, ax = _plot_overlap_vs_percentile(
    pairs_overlap_sig_median, pairs_overlap_sig_lower, pairs_overlap_sig_upper,
    all_overlap_sig_median, all_overlap_sig_lower, all_overlap_sig_upper,
    "Signal",
    save_path=f"plots/unification_sig_{num_to_plot}.pdf" if save_plots else None,
)


### For each score percentile, plot the overlap as a function of the number of injected signal events

In [None]:
percentiles = [0.005, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

fontsize = 30
small_font = 16

# Same palette as the percentile-scan figures, repeated so this cell stands alone.
colors = ["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7"]

l = 3  # base line width

target_percentile = 0.05
target_percentile_index = percentiles.index(target_percentile)

id_pairs = [(id_1, id_2) for i, id_1 in enumerate(synth_ids) for id_2 in synth_ids[i + 1:]]
pair_ids = [id_1 + "_" + id_2 for id_1, id_2 in id_pairs]


def _residual_at_target(per_injection, baseline):
    """For each entry of num_signals_to_inject, return the value at the target
    percentile minus the random baseline at that same percentile."""
    return [per_injection[nn][target_percentile_index] - baseline[target_percentile_index]
            for nn in num_signals_to_inject]


# Pairwise overlaps are compared with the two-ranking random baseline ...
loc_overlaps_bkg_median = {pid: _residual_at_target(pairs_overlap_bkg_median[pid], random_overlap) for pid in pair_ids}
loc_overlaps_bkg_lower = {pid: _residual_at_target(pairs_overlap_bkg_lower[pid], random_overlap) for pid in pair_ids}
loc_overlaps_bkg_upper = {pid: _residual_at_target(pairs_overlap_bkg_upper[pid], random_overlap) for pid in pair_ids}
loc_overlaps_sig_median = {pid: _residual_at_target(pairs_overlap_sig_median[pid], random_overlap) for pid in pair_ids}
loc_overlaps_sig_lower = {pid: _residual_at_target(pairs_overlap_sig_lower[pid], random_overlap) for pid in pair_ids}
loc_overlaps_sig_upper = {pid: _residual_at_target(pairs_overlap_sig_upper[pid], random_overlap) for pid in pair_ids}

# ... and the all-methods overlap with the all-rankings random baseline.
loc_all_overlap_bkg_median = _residual_at_target(all_overlap_bkg_median, random_overlap_4)
loc_all_overlap_bkg_lower = _residual_at_target(all_overlap_bkg_lower, random_overlap_4)
loc_all_overlap_bkg_upper = _residual_at_target(all_overlap_bkg_upper, random_overlap_4)
loc_all_overlap_sig_median = _residual_at_target(all_overlap_sig_median, random_overlap_4)
loc_all_overlap_sig_lower = _residual_at_target(all_overlap_sig_lower, random_overlap_4)
loc_all_overlap_sig_upper = _residual_at_target(all_overlap_sig_upper, random_overlap_4)


def _plot_overlap_vs_injection(pairs_median, pairs_lower, pairs_upper,
                               all_median, all_lower, all_upper,
                               sample_label, y_max, text_y, show_legend, save_path):
    """Plot Δ fraction of shared events (at `target_percentile`) versus the
    number of injected signal events, save it to `save_path`, and return (fig, ax).

    With `show_legend=False` every artist is unlabeled and no legend is drawn
    (the background figure carries the shared legend).
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    for k, (pid, (id_1, id_2)) in enumerate(zip(pair_ids, id_pairs)):
        plot_label = (id_1.upper() + "-" + id_2.upper()) if show_legend else "_nolegend_"
        ax.plot(num_signals_to_inject, pairs_median[pid], label=plot_label, linewidth=l, color=colors[k])
        ax.fill_between(num_signals_to_inject, pairs_lower[pid], pairs_upper[pid], alpha=0.3, color=colors[k])

    ax.plot(num_signals_to_inject, all_median, label="All methods" if show_legend else "_nolegend_",
            linewidth=l * 3, color="grey", alpha=0.7)
    ax.fill_between(num_signals_to_inject, all_lower, all_upper, linewidth=l * 3, color="grey", alpha=0.3)

    # Every curve is a residual w.r.t. the random-ranking overlap, so the random
    # reference sits at zero on this Δ axis.
    ax.axhline(0, label="Random" if show_legend else "_nolegend_", linewidth=l + 2, color="black")

    if show_legend:
        ax.legend(fontsize=small_font, bbox_to_anchor=(1.4, 1))
    ax.set_ylim(-.1, y_max)
    ax.set_xlabel("Num. injected signal events", fontsize=fontsize)
    ax.set_ylabel(r"$\Delta$ fraction of shared events", fontsize=fontsize)
    ax.set_xticks(ticks=num_signals_to_inject)
    ax.tick_params(axis='both', which='major', labelsize=small_font)
    # Placed in data coordinates (x = 1.1 injected events, i.e. the left edge).
    ax.text(1.1, text_y, f"Percentile = {target_percentile}\n{sample_label}", fontsize=fontsize)

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    fig.savefig(save_path, dpi=dpi)
    plt.show()
    return fig, ax


# Background-only sample (carries the legend)
fig, ax = _plot_overlap_vs_injection(
    loc_overlaps_bkg_median, loc_overlaps_bkg_lower, loc_overlaps_bkg_upper,
    loc_all_overlap_bkg_median, loc_all_overlap_bkg_lower, loc_all_overlap_bkg_upper,
    "Background", y_max=0.75, text_y=0.55, show_legend=True,
    save_path=f"plots/unification_percentile_{100*target_percentile}_bkg.pdf",
)

# Signal-injected sample
fig, ax = _plot_overlap_vs_injection(
    loc_overlaps_sig_median, loc_overlaps_sig_lower, loc_overlaps_sig_upper,
    loc_all_overlap_sig_median, loc_all_overlap_sig_lower, loc_all_overlap_sig_upper,
    "Signal", y_max=1., text_y=0.75, show_legend=False,
    save_path=f"plots/unification_percentile_{100*target_percentile}_sig.pdf",
)

