In [None]:
import logging

import awkward as ak
import h5py as h5
import numpy as np
import vector

from coffea.hist.plot import clopper_pearson_interval
import matplotlib.pyplot as plt
import mplhep as hep

# Session-wide configuration for this notebook.
# INFO-level logging from the root logger.
logging.basicConfig(level=logging.INFO)
# Apply the mplhep CMS style to every matplotlib figure drawn afterwards.
plt.style.use(hep.style.CMS)
# Register vector's behaviors with awkward.
# NOTE(review): no vector behavior is used in the visible cells — confirm it is still needed.
vector.register_awkward()

In [None]:
# Open the three HDF5 inputs. They are only read, so open them explicitly read-only:
# h5py < 3.0 defaulted to append mode, which opens files writable and silently
# creates a missing file instead of raising.

# test targets (truth)
test_file = "hhh_test.h5"
test_h5 = h5.File(test_file, "r")

# baseline prediction
baseline_file = "pred_baseline.h5"
b_h5 = h5.File(baseline_file, "r")

# SPANet prediction
spanet_file = "pred_v53.h5"
s_h5 = h5.File(spanet_file, "r")

In [None]:
def sel_pred_by_dp_ap(dps, aps, bb_ps, dp_cut, ap_cut):
    """Keep predicted boosted-H assignments whose probabilities pass both cuts.

    Parameters
    ----------
    dps, aps : ak.Array
        Detection and assignment probabilities, aligned element-for-element
        with ``bb_ps``.
    bb_ps : ak.Array
        Predicted fatjet assignment for each boosted-H slot.
    dp_cut, ap_cut : float
        Strict lower thresholds (``>``) on detection / assignment probability.

    Returns
    -------
    ak.Array
        ``bb_ps`` with failing entries removed, so inner lists may shrink.
    """
    # Named `passed` rather than `filter` so the builtin is not shadowed.
    passed = (aps > ap_cut) & (dps > dp_cut)
    # mask -> None for failing entries, then drop the Nones.
    return ak.drop_none(bb_ps.mask[passed])

In [None]:
def sel_target_by_mask(bb_ts, bh_pts, bh_masks):
    """Keep only the target boosted-H slots flagged as present by ``bh_masks``.

    Both ``bb_ts`` (target fatjet indices) and ``bh_pts`` (target H pt) are
    filtered with the same mask, so the two outputs stay aligned.

    Returns
    -------
    tuple of ak.Array
        ``(bb_ts_selected, bh_selected_pts)``.
    """

    def _keep_masked(values):
        # Entries where the mask is False become None and are then removed.
        return ak.drop_none(values.mask[bh_masks])

    return _keep_masked(bb_ts), _keep_masked(bh_pts)

In [None]:
# A pred look up table is nested as
# [event,
#    pred_H,
#       [correct, pred_H_pt]]
def gen_pred_LUT(bb_ps_passed, bb_ts_selected, fj_pts, fj_offset=10):
    """Build the per-prediction lookup table.

    For every predicted boosted H, record whether its assignment matches any
    target boosted H in the same event, along with the pt of the predicted fatjet.

    Parameters
    ----------
    bb_ps_passed : iterable of iterables of int
        Per event, the predicted assignment indices that survived selection.
    bb_ts_selected : iterable of iterables of int
        Per event, the target fatjet indices.
    fj_pts : iterable of indexables
        Per event, fatjet pt values indexed by fatjet index.
    fj_offset : int, default 10
        Predicted indices equal target/fatjet indices plus this offset.
        NOTE(review): presumably the number of input slots placed before the
        fatjets — confirm against the SPANet input layout.

    Returns
    -------
    list
        ``LUT[event][pred_H] == [correct (0/1), pred_H_pt]``.
    """
    LUT = []
    for bb_t_event, bb_p_event, fj_pt_event in zip(
        bb_ts_selected, bb_ps_passed, fj_pts
    ):
        LUT_event = []
        for bb_p in bb_p_event:
            # A prediction is correct if any target H points to the same fatjet.
            correct = int(any(bb_p == bb_t + fj_offset for bb_t in bb_t_event))
            predH_pt = fj_pt_event[bb_p - fj_offset]
            LUT_event.append([correct, predH_pt])
        LUT.append(LUT_event)
    return LUT

In [None]:
# A target look up table is nested as
# [event,
#    target_H,
#       [retrieved, targetH_pt]]
def gen_target_LUT(bb_ps_passed, bb_ts_selected, targetH_pts, fj_offset=10):
    """Build the per-target lookup table.

    For every target boosted H, record whether any surviving prediction in the
    same event points to its fatjet, along with the target H pt.

    Parameters
    ----------
    bb_ps_passed : iterable of iterables of int
        Per event, the predicted assignment indices that survived selection.
    bb_ts_selected : iterable of iterables of int
        Per event, the target fatjet indices.
    targetH_pts : iterable of iterables of float
        Per event, target H pt values aligned with ``bb_ts_selected``.
    fj_offset : int, default 10
        Predicted indices equal target fatjet indices plus this offset.

    Returns
    -------
    list
        ``LUT[event][target_H] == [retrieved (0/1), targetH_pt]``.
    """
    LUT = []
    for bb_t_event, bb_p_event, targetH_pts_event in zip(
        bb_ts_selected, bb_ps_passed, targetH_pts
    ):
        LUT_event = []
        # A target H counts as retrieved if any prediction points to the same fatjet.
        for bb_t, targetH_pt in zip(bb_t_event, targetH_pts_event):
            retrieved = int(any(bb_p == bb_t + fj_offset for bb_p in bb_p_event))
            LUT_event.append([retrieved, targetH_pt])
        LUT.append(LUT_event)
    return LUT

In [None]:
# generate pred/target LUTs
# each pred entry is [recoH correct or not, reco H pt]
# each target entry is [targetH retrieved or not, target H pt]
def _stack_bh_field(h5file, field, groups=("bh1", "bh2", "bh3")):
    """Read ``TARGETS/<group>/<field>`` for each group and stack them as columns.

    Each dataset is flattened, so the result has shape (n, len(groups)), with
    columns in ``groups`` order. Presumably n is the number of events.
    """
    columns = [
        np.asarray(h5file["TARGETS"][group][field]).reshape(-1) for group in groups
    ]
    return ak.Array(np.stack(columns, axis=1))


def parse_pred_w_target(testfile, predfile, dp_cut=0.5, ap_cut=1 / 13):
    """Match boosted-H predictions to targets and build both lookup tables.

    Parameters
    ----------
    testfile : h5py.File-like
        Truth file with ``TARGETS/bh{1,2,3}/{bb,pt,mask}`` and
        ``INPUTS/BoostedJets/fj_pt``.
    predfile : h5py.File-like
        Prediction file with
        ``TARGETS/bh{1,2,3}/{bb,detection_probability,assignment_probability}``.
    dp_cut, ap_cut : float
        Strict lower cuts on detection / assignment probability for predictions.

    Returns
    -------
    tuple
        ``(LUT_pred, LUT_target)`` as built by ``gen_pred_LUT`` / ``gen_target_LUT``.
    """
    # Targets: fatjet index, pt, and presence mask of each true boosted H.
    # Only the boosted (bh*) targets are read; resolved h* targets are not used here.
    bb_ts = _stack_bh_field(testfile, "bb")
    bh_pts = _stack_bh_field(testfile, "pt")
    bh_masks = _stack_bh_field(testfile, "mask")

    # Predictions: assigned index plus detection / assignment probabilities.
    bb_ps = _stack_bh_field(predfile, "bb")
    dps = _stack_bh_field(predfile, "detection_probability")
    aps = _stack_bh_field(predfile, "assignment_probability")

    # Fatjet pt, indexed by fatjet index, used for the predicted-H pt.
    fj_pts = ak.Array(np.array(testfile["INPUTS"]["BoostedJets"]["fj_pt"]))

    # Select existing targets and predictions passing the probability cuts.
    bb_ts_selected, targetH_selected_pts = sel_target_by_mask(bb_ts, bh_pts, bh_masks)
    bb_ps_selected = sel_pred_by_dp_ap(dps, aps, bb_ps, dp_cut, ap_cut)

    # correct/retrieved LUTs for predictions and targets respectively.
    LUT_pred = gen_pred_LUT(bb_ps_selected, bb_ts_selected, fj_pts)
    LUT_target = gen_target_LUT(bb_ps_selected, bb_ts_selected, targetH_selected_pts)

    return LUT_pred, LUT_target

In [None]:
# SPANet LUTs with the default cuts of parse_pred_w_target.
# NOTE(review): these are recomputed with dp_cut=0.85 in the plotting cell before anything reads them.
LUT_pred_spanet, LUT_target_spanet = parse_pred_w_target(testfile=test_h5, predfile=s_h5)

In [None]:
# Baseline LUTs with the default cuts of parse_pred_w_target.
# NOTE(review): these are recomputed with dp_cut=0.85 in the plotting cell before anything reads them.
LUT_pred_baseline, LUT_target_baseline = parse_pred_w_target(testfile=test_h5, predfile=b_h5)

In [None]:
# calculate the matched fraction of predicted H per pT bin
# if bins=None, put all data in a single bin
def calc_eff(LUT_pred, bins):
    """Per-bin fraction of predicted H that are correctly matched, with errors.

    NOTE: this consumes the prediction LUT (from ``gen_pred_LUT``), and the
    plotting cell labels its output "Matching purity" despite the function name.

    Parameters
    ----------
    LUT_pred : nested sequence
        ``LUT_pred[event][pred_H] == [correct (0/1), pred_H_pt]``.
    bins : array-like of increasing bin edges, or None
        Entries are binned by pred_H_pt with ``np.digitize``. Entries below
        ``bins[0]`` or at/above ``bins[-1]`` are dropped. ``None`` puts every
        entry into a single bin.

    Returns
    -------
    means : np.ndarray
        Fraction of correct entries per bin; nan for empty bins.
    errs : np.ndarray
        ``|clopper_pearson_interval(num, denom) - means|``, presumably
        (2, n_bins) lower/upper distances as used for ``errorbar(yerr=...)``.
    """
    # Flatten [event][pred_H][correct, pt] to (n_pred_H, 2). The reshape keeps
    # the 2-column layout even when no prediction survived the cuts.
    predHs = np.asarray(
        [predH for event in LUT_pred for predH in event], dtype=float
    ).reshape(-1, 2)
    correct, pts = predHs[:, 0], predHs[:, 1]

    if bins is None:
        # One bin holding every entry.
        num = np.array([correct.sum()])
        denom = np.array([correct.size], dtype=float)
    else:
        bins = np.asarray(bins)
        n_bins = len(bins) - 1
        # digitize gives i with bins[i-1] <= pt < bins[i]; 0 and len(bins) are out of range.
        inds = np.digitize(pts, bins)
        in_range = (inds >= 1) & (inds <= n_bins)
        num = np.bincount(
            inds[in_range] - 1, weights=correct[in_range], minlength=n_bins
        )
        denom = np.bincount(inds[in_range] - 1, minlength=n_bins).astype(float)

    # Empty bins give nan instead of a None entry that breaks the arithmetic below.
    with np.errstate(invalid="ignore", divide="ignore"):
        means = num / denom

    errs = np.abs(clopper_pearson_interval(num=num, denom=denom) - means)

    return means, errs

In [None]:
# calculate the retrieved fraction of target H per pT bin
# if bins=None, put all data in a single bin
def calc_pur(LUT_target, bins):
    """Per-bin fraction of target H that are retrieved, with errors.

    NOTE: this consumes the target LUT (from ``gen_target_LUT``), and the
    plotting cell labels its output "Matching efficiency" despite the function name.

    Parameters
    ----------
    LUT_target : nested sequence
        ``LUT_target[event][target_H] == [retrieved (0/1), targetH_pt]``.
    bins : array-like of increasing bin edges, or None
        Entries are binned by targetH_pt with ``np.digitize``. Entries below
        ``bins[0]`` or at/above ``bins[-1]`` are dropped. ``None`` puts every
        entry into a single bin.

    Returns
    -------
    means : np.ndarray
        Fraction of retrieved entries per bin; nan for empty bins.
    errs : np.ndarray
        ``|clopper_pearson_interval(num, denom) - means|``, presumably
        (2, n_bins) lower/upper distances as used for ``errorbar(yerr=...)``.
    """
    # Flatten [event][target_H][retrieved, pt] to (n_target_H, 2). The reshape
    # keeps the 2-column layout even for an empty LUT.
    targetHs = np.asarray(
        [targetH for event in LUT_target for targetH in event], dtype=float
    ).reshape(-1, 2)
    retrieved, pts = targetHs[:, 0], targetHs[:, 1]

    if bins is None:
        # One bin holding every entry.
        num = np.array([retrieved.sum()])
        denom = np.array([retrieved.size], dtype=float)
    else:
        bins = np.asarray(bins)
        n_bins = len(bins) - 1
        # digitize gives i with bins[i-1] <= pt < bins[i]; 0 and len(bins) are out of range.
        inds = np.digitize(pts, bins)
        in_range = (inds >= 1) & (inds <= n_bins)
        num = np.bincount(
            inds[in_range] - 1, weights=retrieved[in_range], minlength=n_bins
        )
        denom = np.bincount(inds[in_range] - 1, minlength=n_bins).astype(float)

    # Empty bins give nan instead of a None entry that breaks the arithmetic below.
    with np.errstate(invalid="ignore", divide="ignore"):
        means = num / denom

    errs = np.abs(clopper_pearson_interval(num=num, denom=denom) - means)

    return means, errs

In [None]:
# H pT binning in GeV shared by both panels: edges 200, 300, ..., 900 -> 7 bins.
bins = np.arange(200, 1000, 100)
# Midpoint of each bin, used as the x position of every point.
bin_centers = [(lo + hi) / 2 for lo, hi in zip(bins[:-1], bins[1:])]
# Half-width of each bin for horizontal error bars. np.diff keeps this correct
# if the edges are ever made non-uniform.
xerr = np.diff(bins) / 2

In [None]:
# Recompute the LUTs with a tighter detection-probability cut for the comparison plot.
dp_cut = 0.85
LUT_pred_spanet, LUT_target_spanet = parse_pred_w_target(test_h5, s_h5, dp_cut=dp_cut)
LUT_pred_baseline, LUT_target_baseline = parse_pred_w_target(
    test_h5, b_h5, dp_cut=dp_cut
)

# calc_eff is fed the prediction LUTs (fraction of reconstructed H that are matched)
# and calc_pur the target LUTs (fraction of true H retrieved). The panels below label
# these "Matching purity" and "Matching efficiency" respectively, so eff_* is drawn
# on the purity panel and pur_* on the efficiency panel.
eff_s, efferr_s = calc_eff(LUT_pred_spanet, bins)
eff_b, efferr_b = calc_eff(LUT_pred_baseline, bins)
pur_s, purerr_s = calc_pur(LUT_target_spanet, bins)
pur_b, purerr_b = calc_pur(LUT_target_baseline, bins)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# One errorbar per (panel, model); xerr shows each bin's half-width.
for panel, y, yerr, marker, label in [
    (ax[0], eff_s, efferr_s, "o", "SPANet"),
    (ax[0], eff_b, efferr_b, "x", "Baseline"),
    (ax[1], pur_s, purerr_s, "o", "SPANet"),
    (ax[1], pur_b, purerr_b, "x", "Baseline"),
]:
    panel.errorbar(
        x=bin_centers, y=y, xerr=xerr, yerr=yerr, fmt=marker, capsize=5, label=label
    )

ax[0].set(xlabel=r"Reconstructed H $p_T$ [GeV]", ylabel=r"Matching purity")
ax[1].set(xlabel=r"True H $p_T$ [GeV]", ylabel=r"Matching efficiency")
for panel in ax:
    panel.set_ylim(0, 1)
    # NOTE(review): with bins starting at 200 GeV, the first bin lies outside
    # this x range — confirm that hiding it is intended.
    panel.set_xlim(300, 900)

# Both panels use the same markers per model, so one legend suffices.
ax[0].legend()
fig.tight_layout()
fig.savefig("spanet_baseline_pur_eff.pdf")