In [23]:
import logging

import awkward as ak
import h5py as h5
import numpy as np
# NOTE(review): `vector` is not used in any of the visible cells.
import vector

# Emit INFO-level log records in the notebook output.
logging.basicConfig(level=logging.INFO)

# Binomial (Clopper-Pearson) interval, used for the per-bin error bars in calc_eff.
from hist.intervals import clopper_pearson_interval

import matplotlib.pyplot as plt
import mplhep as hep

# Apply mplhep's CMS style to every figure drawn in this notebook.
plt.style.use(hep.style.CMS)

In [None]:
# Open every input read-only ("r"): nothing in this notebook writes to them,
# and before h5py 3.0 the implicit default mode was "a" (read/write, create if
# missing), which could silently create an empty file on a mistyped path
# instead of raising.
# NOTE(review): the handles stay open for the lifetime of the kernel because
# later cells read from them.

# read test target file
test_file = "hhh_test.h5"
test_h5 = h5.File(test_file, "r")

# read baseline prediction
baseline_file = "pred_baseline.h5"
b_h5 = h5.File(baseline_file, "r")

# read spanet prediction
spanet_file = "pred_v53.h5"
s_h5 = h5.File(spanet_file, "r")

In [None]:
def sel_pred_by_dp_ap(dps, aps, bqq_ps, dp_cut, ap_cut):
    """Keep only the predicted bqq assignments that pass both probability cuts.

    A prediction survives when its detection probability is strictly greater
    than ``dp_cut`` AND its assignment probability is strictly greater than
    ``ap_cut``. Failing entries are masked to None and then dropped, so each
    event retains only its passing predictions.

    Parameters
    ----------
    dps, aps : ak.Array
        Detection / assignment probabilities, aligned entry-by-entry with ``bqq_ps``.
    bqq_ps : ak.Array
        Predicted bqq indices.
    dp_cut, ap_cut : float
        Exclusive lower thresholds.

    Returns
    -------
    ak.Array
        ``bqq_ps`` with every failing entry removed.
    """
    passes_cuts = (dps > dp_cut) & (aps > ap_cut)
    return ak.drop_none(bqq_ps.mask[passes_cuts])

In [None]:
def sel_target_by_mask(bqq_ts, FBt_pts, FBt_masks):
    """Keep only the target entries whose ``FBt_masks`` flag is True.

    The same mask is applied to the target bqq indices and to the target pTs,
    so the two returned arrays stay aligned entry by entry.

    Parameters
    ----------
    bqq_ts : ak.Array
        Target bqq indices.
    FBt_pts : ak.Array
        Target pTs, same layout as ``bqq_ts``.
    FBt_masks : ak.Array of bool
        Per-target flag; False entries are removed from both outputs.

    Returns
    -------
    (ak.Array, ak.Array)
        The selected target indices and the matching target pTs.
    """
    def _keep_flagged(values):
        # Turn unflagged entries into None, then remove them.
        return ak.drop_none(values.mask[FBt_masks])

    return _keep_flagged(bqq_ts), _keep_flagged(FBt_pts)

In [None]:
# A prediction look-up table has the layout
#   LUT[event][prediction] == [correct, predicted_pt]
def gen_pred_LUT(bqq_ps_passed, bqq_ts_selected, vfj_pts):
    """Flag each surviving prediction as correct/incorrect and attach its pT.

    A prediction ``p`` is correct (flag 1) when some target index ``t`` in the
    same event satisfies ``p == t + 10``; otherwise the flag is 0. Its pT is
    read as ``vfj_pts[event][p - 10]``.
    NOTE(review): the +/-10 offset presumably maps prediction indices onto the
    fat-jet collection - TODO confirm against the input ordering.

    Parameters
    ----------
    bqq_ps_passed : iterable of per-event sequences
        Predicted indices that passed the probability cuts.
    bqq_ts_selected : iterable of per-event sequences
        Target indices that passed the target mask.
    vfj_pts : iterable of per-event sequences
        Per-event fat-jet pT values.

    Returns
    -------
    list
        One list per event, holding ``[correct, pt]`` pairs per prediction.
    """
    table = []
    for targets, predictions, fatjet_pts in zip(bqq_ts_selected, bqq_ps_passed, vfj_pts):
        rows = []
        for pred in predictions:
            pred_pt = fatjet_pts[pred - 10]
            is_correct = any(pred == target + 10 for target in targets)
            rows.append([1 if is_correct else 0, pred_pt])
        table.append(rows)
    return table

In [None]:
# A target look-up table has the layout
#   LUT[event][target] == [retrieved, target_pt]
def gen_target_LUT(bqq_ps_passed, bqq_ts_selected, FBt_pts):
    """Flag each selected target as retrieved/missed and attach its pT.

    A target ``t`` is retrieved (flag 1) when some prediction ``p`` in the same
    event satisfies ``p == t + 10``; otherwise the flag is 0. Its pT is the
    entry of ``FBt_pts`` at the same position as ``t``.

    Parameters
    ----------
    bqq_ps_passed : iterable of per-event sequences
        Predicted indices that passed the probability cuts.
    bqq_ts_selected : iterable of per-event sequences
        Target indices that passed the target mask.
    FBt_pts : iterable of per-event sequences
        Target pTs aligned with ``bqq_ts_selected``.

    Returns
    -------
    list
        One list per event, holding ``[retrieved, pt]`` pairs per target.
    """
    table = []
    for targets, predictions, target_pts in zip(bqq_ts_selected, bqq_ps_passed, FBt_pts):
        table.append(
            [
                [int(any(pred == target + 10 for pred in predictions)), target_pts[pos]]
                for pos, target in enumerate(targets)
            ]
        )
    return table

In [None]:
# Generate the prediction / target look-up tables (LUTs).
# Each prediction entry is [reco top correct or not, reco top pt];
# each target entry is [target top retrieved or not, target top pt].
def _stack_target_field(h5_file, target_names, field):
    """Read ``TARGETS/<name>/<field>`` for each name and stack them as columns.

    Every dataset is flattened into a single column (``reshape(-1, 1)``), so
    the result has one row per event and one column per entry of
    ``target_names``, wrapped as an ``ak.Array``.
    """
    columns = [
        np.array(h5_file["TARGETS"][name][field]).reshape(-1, 1)
        for name in target_names
    ]
    return ak.Array(np.concatenate(columns, axis=1))


def parse_pred_w_target(
    testfile, predfile, dp_cut=0.5, ap_cut=1 / 13, target_names=("FBt1", "FBt2")
):
    """Match predicted assignments against the truth and build both LUTs.

    Parameters
    ----------
    testfile : h5py.File-like
        Provides ``TARGETS/<name>/{pt, mask, bqq}`` and
        ``INPUTS/VeryBoostedJets/vfj_pt``.
    predfile : h5py.File-like
        Provides ``TARGETS/<name>/{bqq, detection_probability,
        assignment_probability}``.
    dp_cut : float
        A prediction is kept only if its detection probability is > dp_cut.
    ap_cut : float
        A prediction is kept only if its assignment probability is > ap_cut.
    target_names : sequence of str
        Target groups to read, one column each. Defaults to the two targets
        this notebook originally hard-coded. Must be non-empty.

    Returns
    -------
    (LUT_pred, LUT_target)
        Nested lists as built by ``gen_pred_LUT`` and ``gen_target_LUT``.
    """
    # Truth information from the test file.
    FBt_pts = _stack_target_field(testfile, target_names, "pt")
    FBt_masks = _stack_target_field(testfile, target_names, "mask")
    bqq_ts = _stack_target_field(testfile, target_names, "bqq")

    # Network outputs from the prediction file.
    bqq_ps = _stack_target_field(predfile, target_names, "bqq")
    dps = _stack_target_field(predfile, target_names, "detection_probability")
    aps = _stack_target_field(predfile, target_names, "assignment_probability")

    # Per-event fat-jet pT, looked up by predicted index in gen_pred_LUT.
    vfj_pts = ak.Array(np.array(testfile["INPUTS"]["VeryBoostedJets"]["vfj_pt"]))

    # Keep only existing targets and predictions passing both probability cuts.
    bqq_ts_selected, FBt_selected_pts = sel_target_by_mask(bqq_ts, FBt_pts, FBt_masks)
    bqq_ps_selected = sel_pred_by_dp_ap(dps, aps, bqq_ps, dp_cut, ap_cut)

    # correct/retrieved LUTs for predictions/targets respectively.
    LUT_pred = gen_pred_LUT(bqq_ps_selected, bqq_ts_selected, vfj_pts)
    LUT_target = gen_target_LUT(bqq_ps_selected, bqq_ts_selected, FBt_selected_pts)

    return LUT_pred, LUT_target

In [None]:
# SPANet LUTs with the default cuts (dp_cut=0.5, ap_cut=1/13).
LUT_pred_spanet, LUT_target_spanet = parse_pred_w_target(testfile=test_h5, predfile=s_h5)

In [None]:
# Baseline LUTs with the default cuts (dp_cut=0.5, ap_cut=1/13).
LUT_pred_baseline, LUT_target_baseline = parse_pred_w_target(testfile=test_h5, predfile=b_h5)

In [None]:
# Fraction of LUT entries whose flag is 1, in bins of pT.
# On a prediction LUT this is the matching purity (share of reconstructed tops
# that are correct); on a target LUT it is the matching efficiency (share of
# true tops that were retrieved). If bins=None, all data go in a single bin.
def calc_eff(LUT_pred, bins=None):
    """Per-bin mean of the 0/1 flag of a LUT, with Clopper-Pearson errors.

    Parameters
    ----------
    LUT_pred : list of per-event lists of [flag, pt]
        Output of ``gen_pred_LUT`` or ``gen_target_LUT``.
    bins : array-like or None
        Monotonically increasing bin edges applied to ``pt``. Entries outside
        ``[bins[0], bins[-1])`` are ignored. ``None`` uses a single bin that
        holds every finite pt.

    Returns
    -------
    means : np.ndarray, shape (len(bins) - 1,)
        Fraction of flagged entries per bin; NaN for an empty bin.
    errs : np.ndarray, shape (2, len(bins) - 1)
        Absolute distances from ``means`` to the lower/upper Clopper-Pearson
        bounds, usable directly as matplotlib ``yerr``; NaN for an empty bin.
    """
    # Flatten the events into one (n_entries, 2) array of [flag, pt];
    # reshape keeps the 2-column shape even when the LUT has no entries.
    entries = np.asarray(
        [entry for event in LUT_pred for entry in event], dtype=float
    ).reshape(-1, 2)
    flags, pts = entries[:, 0], entries[:, 1]

    if bins is None:
        bins = [-np.inf, np.inf]
    bins = np.asarray(bins, dtype=float)
    n_bins = bins.size - 1

    # np.digitize returns i for bins[i-1] <= pt < bins[i]; keep only 1..n_bins.
    bin_idx = np.digitize(pts, bins)
    in_range = (bin_idx >= 1) & (bin_idx <= n_bins)
    bin_pos = bin_idx[in_range] - 1

    denom = np.bincount(bin_pos, minlength=n_bins)
    num = np.bincount(bin_pos, weights=flags[in_range], minlength=n_bins)

    # Empty bins give 0/0 -> NaN (warning silenced) rather than a missing
    # (None) value, so the result stays a plain float array.
    with np.errstate(divide="ignore", invalid="ignore"):
        means = num / denom

    errs = np.abs(clopper_pearson_interval(num=num, denom=denom) - means)

    return means, errs

In [None]:
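# # calculate the fraction of target tops correctly found by the model
# # (unused: identical to calc_eff, which the plotting cell applies to the target
# # LUTs and labels "Matching efficiency")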
# def calc_pur(LUT_target, bins):
#     targetHs = [targetH for event in LUT_target for targetH in event]
#     targetHs = np.array(targetHs)

#     targetHs_inds = np.digitize(targetHs[:, 1], bins)

#     correctTruth_per_bin = []
#     for bin_i in range(1, len(bins)):
#         correctTruth_per_bin.append(targetHs[:, 0][targetHs_inds == bin_i])
#     correctTruth_per_bin = ak.Array(correctTruth_per_bin)

#     means = ak.mean(correctTruth_per_bin, axis=-1)

#     errs = np.abs(
#         clopper_pearson_interval(
#             num=ak.sum(correctTruth_per_bin, axis=-1),
#             denom=ak.num(correctTruth_per_bin, axis=-1),
#         )
#         - means
#     )

#     return means, errs

In [None]:
# pT bin edges in GeV: 200, 300, ..., 900 (seven 100-GeV-wide bins).
bins = np.arange(200, 1000, 100)
# Midpoint of each bin: x position of the error-bar markers.
bin_centers = list((bins[:-1] + bins[1:]) / 2)
# Half the bin width, repeated for every bin.
# NOTE(review): xerr is not used by the plotting cell below.
xerr = np.full(bins.size - 1, (bins[1] - bins[0]) / 2)

In [None]:
# Tighter detection-probability threshold used for the comparison plots.
dp_cut = 0.85

# Rebuild the LUTs of both models with the tighter cut.
LUT_pred_spanet, LUT_target_spanet = parse_pred_w_target(test_h5, s_h5, dp_cut=dp_cut)
LUT_pred_baseline, LUT_target_baseline = parse_pred_w_target(
    test_h5, b_h5, dp_cut=dp_cut
)

# Prediction LUTs -> share of reconstructed tops that are correct (left panel).
eff_s, efferr_s = calc_eff(LUT_pred_spanet, bins)
eff_b, efferr_b = calc_eff(LUT_pred_baseline, bins)
# Target LUTs -> share of true tops that were retrieved (right panel).
pur_s, purerr_s = calc_eff(LUT_target_spanet, bins)
pur_b, purerr_b = calc_eff(LUT_target_baseline, bins)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# (axis, SPANet values, baseline values, x label, y label) for each panel.
_panels = [
    (ax[0], (eff_s, efferr_s), (eff_b, efferr_b),
     r"Reconstructed t $p_T$ [GeV]", r"Matching purity"),
    (ax[1], (pur_s, purerr_s), (pur_b, purerr_b),
     r"True t $p_T$ [GeV]", r"Matching efficiency"),
]
for _panel, (_y_s, _yerr_s), (_y_b, _yerr_b), _xlabel, _ylabel in _panels:
    _panel.errorbar(
        x=bin_centers, y=_y_s, yerr=_yerr_s, fmt="o", capsize=5, label="SPANet"
    )
    _panel.errorbar(
        x=bin_centers, y=_y_b, yerr=_yerr_b, fmt="x", capsize=5, label="Baseline"
    )
    _panel.set(xlabel=_xlabel, ylabel=_ylabel)
    _panel.set_ylim(0, 1)
    _panel.set_xlim(300, 900)

# Only the left panel carries a legend.
ax[0].legend()
plt.tight_layout()
plt.savefig("spanet_baseline_pur_eff.pdf")