In [6]:
import numpy as np
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc, roc_auc_score
from scipy.special import softmax

import uproot
import glob

In [None]:

# Class names in the column order of the prediction files. Column 0 (QCD) is
# used as the background class for the per-class rejections.
classes = ['QCD', 'Hbb', 'Hcc', 'Hgg', 'H4q', 'Hqql', 'Zqq', 'Wqq', 'Tbqq', 'Tbl']
n_classes = len(classes)
label_list = [f'label_{cls}' for cls in classes]
score_list = [f'score_label_{cls}' for cls in classes]


model = "MorePairAttnParT"
pred_pattern = f"../models/{model}_full_pred_*.root"
# glob() order is filesystem-dependent; sort so the concatenated event order
# is reproducible across machines and re-runs.
pred_files = sorted(glob.glob(pred_pattern))
if not pred_files:
    # Fail loudly here instead of with an opaque np.concatenate error below.
    raise FileNotFoundError(f"No prediction files match {pred_pattern!r}")

# One record array per file, holding only the label and score branches.
arrays = []
for file_name in pred_files:
    with uproot.open(file_name) as f:
        arrays.append(f["Events"].arrays(label_list + score_list))

# Branch name -> values concatenated over all files (in sorted-file order).
concat_arrays = {
    key: np.concatenate([arr[key].to_numpy() for arr in arrays])
    for key in label_list + score_list
}


# Stack per-class branches column-wise; column j corresponds to classes[j].
y_prob = np.stack([concat_arrays[key] for key in score_list], axis=1)
labels = np.stack([concat_arrays[key] for key in label_list], axis=1).astype(int)


In [None]:
# ---- Overall multi-class metrics ----
# `labels` is one-hot; collapse it (and the scores) to integer class indices.
predicted_labels = np.argmax(y_prob, axis=1)
true_labels = np.argmax(labels, axis=1)

# Pass 1-D class indices. With a 2-D one-hot y_true, sklearn treats the task as
# multilabel and silently ignores `multi_class`, returning a macro
# one-vs-rest AUC instead of the intended one-vs-one AUC. The multiclass path
# validates that each row of y_prob sums to 1 and raises otherwise.
overall_roc_auc = roc_auc_score(
    true_labels,
    y_prob,
    average='macro',
    multi_class='ovo',
    labels=np.arange(n_classes),
)
accuracy = accuracy_score(true_labels, predicted_labels)

print(f'Overall ROC AUC = {overall_roc_auc:.4f}, Accuracy = {accuracy:.4f}')


# ---- Per-class QCD rejection ----
# Signal-vs-QCD discriminant s_i / (s_i + s_QCD) for each class i >= 1.
# Column 0 keeps the raw QCD score so column indices still line up with `classes`.
scores = y_prob[:, 1:] / (y_prob[:, :1] + y_prob[:, 1:])
scores = np.concatenate((y_prob[:, :1], scores), axis=1)

# Signal efficiency (TPR) at which the rejection 1/FPR is quoted, keyed by
# class name. Classes not listed use DEFAULT_SIGNAL_EFF.
DEFAULT_SIGNAL_EFF = 0.5
SIGNAL_EFF = {'Hqql': 0.99, 'Tbl': 0.995}

rejections = []

for i in range(1, n_classes):
    percent = SIGNAL_EFF.get(classes[i], DEFAULT_SIGNAL_EFF)

    # Restrict to QCD events plus events of signal class i.
    mask = (labels[:, 0] == 1) | (labels[:, i] == 1)
    filtered_labels = labels[mask]
    filtered_scores = scores[mask]

    binary_labels = (filtered_labels[:, i] == 1).astype(int)
    binary_scores = filtered_scores[:, i]

    n_signal = int(binary_labels.sum())
    if n_signal == 0 or n_signal == binary_labels.size:
        # An ROC curve needs both signal and background; report "undefined"
        # rather than a spurious inf from a NaN TPR.
        rejection = np.nan
    else:
        fpr, tpr, thresholds = roc_curve(binary_labels, binary_scores)

        # ROC point whose TPR is closest to the target efficiency.
        idx = np.abs(tpr - percent).argmin()

        # Zero FPR at that point means no background passes: infinite rejection.
        rejection = 1 / fpr[idx] if fpr[idx] != 0 else np.inf

    rejections.append(rejection)

    print(f'Rejection at {percent*100}% for {label_list[i]}: {rejection}')