In [None]:
import logging

import awkward as ak
import h5py as h5
import numpy as np
import vector
vector.register_awkward()

logging.basicConfig(level=logging.INFO)

from hist.intervals import clopper_pearson_interval

import matplotlib.pyplot as plt
import mplhep as hep

plt.style.use(hep.style.CMS)

# NOTE(review): N_AK5_JETS is not referenced by any cell visible in this notebook;
# presumably the number of AK5 jets per event -- confirm it is still needed.
N_AK5_JETS = 10
# Detection-probability threshold forwarded as `dp_cut` to parse_pred_w_target
# for both the SPANet and the baseline predictions.
dp_cut = 0.85

In [None]:
# All three files are only ever read in this notebook, so open them explicitly
# read-only: the default mode of h5.File depends on the h5py version, and older
# versions could create an empty file when the path is missing.

# read test target file
test_file = "hhh_test.h5"
test_h5 = h5.File(test_file, "r")

# read baseline prediction
baseline_file = "pred_baseline.h5"
b_h5 = h5.File(baseline_file, "r")

# read spanet prediction
spanet_file = "pred_v53.h5"
s_h5 = h5.File(spanet_file, "r")

In [None]:
def sel_pred_by_dp_ap(dps, aps, b_ps, q1_ps, q2_ps, dp_cut, ap_cut):
    """Keep only the predicted tops that pass both probability thresholds.

    Parameters
    ----------
    dps, aps : detection / assignment probabilities, one entry per predicted top.
    b_ps, q1_ps, q2_ps : predicted b, q1 and q2 jet entries, laid out like
        ``dps``/``aps`` and exposing the awkward ``.mask`` accessor.
    dp_cut, ap_cut : thresholds; a candidate survives only when its detection
        probability is strictly above ``dp_cut`` AND its assignment probability
        is strictly above ``ap_cut``.

    Returns
    -------
    (b_ps_passed, q1_ps_passed, q2_ps_passed) with the failing candidates removed.
    """
    # combined requirement: assignment AND detection probability above threshold
    keep = (aps > ap_cut) & (dps > dp_cut)

    def _surviving(jet_idx):
        # .mask turns rejected candidates into None, drop_none then removes them
        return ak.drop_none(jet_idx.mask[keep])

    return _surviving(b_ps), _surviving(q1_ps), _surviving(q2_ps)

In [None]:
def sel_target_by_mask(b_ts, q1_ts, q2_ts, FRt_pts, FRt_masks):
    """Keep only the target tops whose entry in ``FRt_masks`` is true.

    Parameters
    ----------
    b_ts, q1_ts, q2_ts : target b, q1 and q2 jet entries, one per target top.
    FRt_pts : target top pT values, laid out like the index arrays.
    FRt_masks : boolean array with the same layout; false entries are discarded.

    Returns
    -------
    (b_ts_selected, q1_ts_selected, q2_ts_selected, FRt_selected_pts)
    """

    def _masked(per_top):
        # entries with a false mask become None and are then dropped
        return ak.drop_none(per_top.mask[FRt_masks])

    return _masked(b_ts), _masked(q1_ts), _masked(q2_ts), _masked(FRt_pts)

In [None]:
# A pred look up table is in shape
# [event,
#    pred_t,
#       [correct, predTop_pt]]
def gen_pred_LUT(b_ps_selected, q1_ps_selected, q2_ps_selected, b_ts_selected, q1_ts_selected, q2_ts_selected, jet4mom):
    """Build the per-prediction look-up table for full {b, q1, q2} top assignments.

    Parameters
    ----------
    b_ps_selected, q1_ps_selected, q2_ps_selected : per-event sequences of the
        predicted b / q1 / q2 jet indices (one entry per predicted top).
    b_ts_selected, q1_ts_selected, q2_ts_selected : per-event sequences of the
        target b / q1 / q2 jet indices (one entry per target top).
    jet4mom : per-event sequence of jet four-momenta, indexable by jet index;
        the elements must support ``+`` and expose ``.pt``.

    Returns
    -------
    list
        ``LUT[event][pred_top] == [correct, predTop_pt]`` where ``correct`` is 1
        when some target top in the same event has exactly the same
        (b, q1, q2) indices (q1 and q2 are NOT interchangeable) and 0 otherwise,
        and ``predTop_pt`` is the pT of the summed b + q1 + q2 four-momenta.
    """
    LUT = []
    # for each event -- the unpacking order must follow the order of the zipped
    # arguments: predictions first, then targets, then the jets
    for b_p_event, q1_p_event, q2_p_event, b_t_event, q1_t_event, q2_t_event, jet4mom_event in zip(
        b_ps_selected, q1_ps_selected, q2_ps_selected, b_ts_selected, q1_ts_selected, q2_ts_selected, jet4mom
    ):
        # materialise the target triplets once; they are scanned for every prediction
        target_triplets = list(zip(b_t_event, q1_t_event, q2_t_event))

        # for each predicted {b,q1,q2} assignment, check if any FRt has the same {b,q1,q2} assignment
        LUT_event = []
        for b_p, q1_p, q2_p in zip(b_p_event, q1_p_event, q2_p_event):
            # reconstructed top = b + q1 + q2 (all three distinct jets)
            predTop_pt = (jet4mom_event[b_p] + jet4mom_event[q1_p] + jet4mom_event[q2_p]).pt
            correct = int(
                any(
                    b_p == b_t and q1_p == q1_t and q2_p == q2_t
                    for b_t, q1_t, q2_t in target_triplets
                )
            )
            LUT_event.append([correct, predTop_pt])
        LUT.append(LUT_event)
    return LUT

# A pred look up table is in shape
# [event,
#    pred_q,
#       [correct, jet_pt]]
def gen_pred_qLUT(q_ps_selected, q_ts_selected, pts):
    """Build the per-prediction look-up table for a single jet slot (b, q1 or q2).

    Parameters
    ----------
    q_ps_selected : per-event sequences of predicted jet indices for the slot.
    q_ts_selected : per-event sequences of target jet indices for the slot.
    pts : per-event sequences of jet pT, indexable by jet index.

    Returns
    -------
    list
        ``qLUT[event][pred] == [correct, jet_pt]`` where ``correct`` is 1 when the
        predicted index equals any target index of the same event (0 otherwise)
        and ``jet_pt`` is the pT of the predicted jet.
    """
    qLUT = []
    # for each event -- unpack in the same order as the zipped arguments
    # (predictions, targets, jet pT)
    for q_p_event, q_t_event, pt_event in zip(
        q_ps_selected, q_ts_selected, pts
    ):
        # for each predicted jet, check if any target top uses the same jet in this slot
        qLUT_event = []
        for q_p in q_p_event:
            predq_pt = pt_event[q_p]
            correct = int(any(q_p == q_t for q_t in q_t_event))
            qLUT_event.append([correct, predq_pt])
        qLUT.append(qLUT_event)
    return qLUT

In [None]:
# A target look up table is in shape
# [event,
#    target_t,
#        [retrieved, FRt_pt]]
def gen_target_LUT(b_ps_selected, q1_ps_selected, q2_ps_selected, b_ts_selected, q1_ts_selected, q2_ts_selected, FRt_pts):
    """Build the per-target look-up table for full {b, q1, q2} top assignments.

    Parameters
    ----------
    b_ps_selected, q1_ps_selected, q2_ps_selected : per-event sequences of the
        predicted b / q1 / q2 jet indices (one entry per predicted top).
    b_ts_selected, q1_ts_selected, q2_ts_selected : per-event sequences of the
        target b / q1 / q2 jet indices (one entry per target top).
    FRt_pts : per-event sequences of target top pT; entry ``i`` is read for the
        ``i``-th target top of the event.

    Returns
    -------
    list
        ``LUT[event][target_top] == [retrieved, FRt_pt]`` where ``retrieved`` is 1
        when some predicted top in the same event has exactly the same
        (b, q1, q2) indices (q1 and q2 are NOT interchangeable) and 0 otherwise.
    """
    LUT = []
    # for each event -- the unpacking order must follow the order of the zipped
    # arguments: predictions first, then targets, then the target pT
    for b_p_event, q1_p_event, q2_p_event, b_t_event, q1_t_event, q2_t_event, FRt_pt_event in zip(
        b_ps_selected, q1_ps_selected, q2_ps_selected, b_ts_selected, q1_ts_selected, q2_ts_selected, FRt_pts
    ):
        # materialise the predicted triplets once; they are scanned for every target
        pred_triplets = list(zip(b_p_event, q1_p_event, q2_p_event))

        # for each target {b,q1,q2} assignment, check if any prediction has the same {b,q1,q2} assignment
        LUT_event = []
        for i, (b_t, q1_t, q2_t) in enumerate(zip(b_t_event, q1_t_event, q2_t_event)):
            FRt_pt = FRt_pt_event[i]
            retrieved = int(
                any(
                    b_p == b_t and q1_p == q1_t and q2_p == q2_t
                    for b_p, q1_p, q2_p in pred_triplets
                )
            )
            LUT_event.append([retrieved, FRt_pt])
        LUT.append(LUT_event)
    return LUT

# A target look up table is in shape
# [event,
#    target_q,
#        [retrieved, jet_pt]]
def gen_target_qLUT(q_ps_selected, q_ts_selected, pts):
    """Build the per-target look-up table for a single jet slot (b, q1 or q2).

    Parameters
    ----------
    q_ps_selected : per-event sequences of predicted jet indices for the slot.
    q_ts_selected : per-event sequences of target jet indices for the slot.
    pts : per-event sequences of jet pT, indexable by jet index.

    Returns
    -------
    list
        ``qLUT[event][target] == [retrieved, jet_pt]`` where ``retrieved`` is 1
        when the target index equals any predicted index of the same event
        (0 otherwise) and ``jet_pt`` is the pT of the target jet.
    """
    qLUT = []
    # for each event -- unpack in the same order as the zipped arguments
    # (predictions, targets, jet pT)
    for q_p_event, q_t_event, pt_event in zip(
        q_ps_selected, q_ts_selected, pts
    ):
        # for each target jet, check if any predicted top uses the same jet in this slot
        qLUT_event = []
        for q_t in q_t_event:
            # pT of the *target* jet; the prediction loop variable does not exist yet here
            targetq_pt = pt_event[q_t]
            retrieved = int(any(q_p == q_t for q_p in q_p_event))
            qLUT_event.append([retrieved, targetq_pt])
        qLUT.append(qLUT_event)
    return qLUT

In [None]:
# generate pred/target LUT
# each entry corresponds to [reco t correct or not, reco t pt]
# or
# [target t retrieved or not, target t pt]
def parse_pred_w_target(testfile, predfile, dp_cut=0.5, ap_cut=1 / 13):
    """Read targets and predictions for FRt1/FRt2 and build all look-up tables.

    Parameters
    ----------
    testfile : mapping-like (e.g. an open HDF5 file) providing
        ``["TARGETS"][FRt1|FRt2][pt|mask|b|q1|q2]`` and
        ``["INPUTS"]["Jets"][pt|eta|phi|mass]``.
    predfile : mapping-like providing
        ``["TARGETS"][FRt1|FRt2][b|q1|q2|detection_probability|assignment_probability]``.
    dp_cut, ap_cut : thresholds forwarded to ``sel_pred_by_dp_ap``.

    Returns
    -------
    tuple
        ``(LUT_pred, LUT_target, bLUT_pred, bLUT_target,
        q1LUT_pred, q1LUT_target, q2LUT_pred, q2LUT_target)`` as produced by
        ``gen_pred_LUT`` / ``gen_target_LUT`` for the full tops and by
        ``gen_pred_qLUT`` / ``gen_target_qLUT`` for each jet slot.
    """
    top_keys = ("FRt1", "FRt2")

    def _per_top(h5file, field):
        # Read `field` for both top candidates and put them side by side:
        # column 0 is FRt1, column 1 is FRt2.
        columns = tuple(
            np.array(h5file["TARGETS"][key][field]).reshape(-1, 1) for key in top_keys
        )
        return ak.Array(np.concatenate(columns, axis=1))

    # truth side: top pT, validity mask and target jet entries
    FRt_pts = _per_top(testfile, "pt")
    FRt_masks = _per_top(testfile, "mask")
    b_ts = _per_top(testfile, "b")
    q1_ts = _per_top(testfile, "q1")
    q2_ts = _per_top(testfile, "q2")

    # prediction side: predicted jet entries plus detection / assignment probabilities
    b_ps = _per_top(predfile, "b")
    q1_ps = _per_top(predfile, "q1")
    q2_ps = _per_top(predfile, "q2")
    dps = _per_top(predfile, "detection_probability")
    aps = _per_top(predfile, "assignment_probability")

    # jet kinematics -> four-momenta (built from the fields rho/phi/eta/tau)
    jets = testfile["INPUTS"]["Jets"]
    pts = np.array(jets["pt"])
    etas = np.array(jets["eta"])
    phis = np.array(jets["phi"])
    masses = np.array(jets["mass"])
    jet4moms = ak.zip(
        {"rho": pts, "phi": phis, "eta": etas, "tau": masses},
        with_name="Momentum4D",
    )
    pts = ak.Array(pts)

    # select predictions and targets
    targets_selected = sel_target_by_mask(b_ts, q1_ts, q2_ts, FRt_pts, FRt_masks)
    b_ts_selected, q1_ts_selected, q2_ts_selected, FRt_selected_pts = targets_selected
    preds_selected = sel_pred_by_dp_ap(dps, aps, b_ps, q1_ps, q2_ps, dp_cut, ap_cut)
    b_ps_selected, q1_ps_selected, q2_ps_selected = preds_selected

    # correct/retrieved LUTs for the full {b, q1, q2} tops
    LUT_pred = gen_pred_LUT(
        b_ps_selected, q1_ps_selected, q2_ps_selected,
        b_ts_selected, q1_ts_selected, q2_ts_selected,
        jet4moms,
    )
    LUT_target = gen_target_LUT(
        b_ps_selected, q1_ps_selected, q2_ps_selected,
        b_ts_selected, q1_ts_selected, q2_ts_selected,
        FRt_selected_pts,
    )

    # same tables per jet slot, binned later in the pT of the individual jet
    bLUT_pred = gen_pred_qLUT(b_ps_selected, b_ts_selected, pts)
    bLUT_target = gen_target_qLUT(b_ps_selected, b_ts_selected, pts)
    q1LUT_pred = gen_pred_qLUT(q1_ps_selected, q1_ts_selected, pts)
    q1LUT_target = gen_target_qLUT(q1_ps_selected, q1_ts_selected, pts)
    q2LUT_pred = gen_pred_qLUT(q2_ps_selected, q2_ts_selected, pts)
    q2LUT_target = gen_target_qLUT(q2_ps_selected, q2_ts_selected, pts)

    return (
        LUT_pred,
        LUT_target,
        bLUT_pred,
        bLUT_target,
        q1LUT_pred,
        q1LUT_target,
        q2LUT_pred,
        q2LUT_target,
    )

In [None]:
# Look-up tables for the SPANet predictions, using the notebook-level dp_cut
# (ap_cut is left at parse_pred_w_target's default).
(
    LUT_pred_spanet,
    LUT_target_spanet,
    bLUT_pred_spanet,
    bLUT_target_spanet,
    q1LUT_pred_spanet,
    q1LUT_target_spanet,
    q2LUT_pred_spanet,
    q2LUT_target_spanet,
) = parse_pred_w_target(testfile=test_h5, predfile=s_h5, dp_cut=dp_cut)

In [None]:
# Look-up tables for the baseline predictions, using the notebook-level dp_cut
# (ap_cut is left at parse_pred_w_target's default).
(
    LUT_pred_baseline,
    LUT_target_baseline,
    bLUT_pred_baseline,
    bLUT_target_baseline,
    q1LUT_pred_baseline,
    q1LUT_target_baseline,
    q2LUT_pred_baseline,
    q2LUT_target_baseline,
) = parse_pred_w_target(testfile=test_h5, predfile=b_h5, dp_cut=dp_cut)

In [None]:
# calculate the fraction of flagged entries (flag == 1) of a LUT in bins of pT
# if bins=None, put all data in a single bin
def calc_eff(LUT_pred, bins=None):
    """Binned fraction of LUT entries whose flag is 1, with Clopper-Pearson errors.

    The function is agnostic to what the flag means: in this notebook a
    prediction LUT ([correct, pt]) gives the fraction of predictions that are
    correct (plotted as "Matching purity"), and a target LUT ([retrieved, pt])
    gives the fraction of targets that were retrieved (plotted as
    "Matching efficiency").

    Parameters
    ----------
    LUT_pred : nested list ``[event][object] -> [flag, pt]``.
    bins : increasing bin edges in pT, or ``None`` to put every entry in one bin.
        Entries outside ``[bins[0], bins[-1])`` are ignored.

    Returns
    -------
    means : per-bin mean of the flags (``len(bins) - 1`` entries).
    errs : absolute distance between the Clopper-Pearson interval bounds and
        ``means``; first axis is (lower, upper), usable directly as ``yerr``.
        Bins without entries have no defined mean -- presumably NaN/None there;
        TODO confirm before relying on it.
    """
    if bins is None:
        # a single bin whose edges accept every finite pT value
        bins = np.array([-np.inf, np.inf])

    # flatten [event][object] -> (n_objects, 2); the reshape keeps the two-column
    # layout even when the LUT holds no entries at all
    entries = np.array(
        [entry for event in LUT_pred for entry in event], dtype=float
    ).reshape(-1, 2)
    flags, pt_values = entries[:, 0], entries[:, 1]

    # digitize returns i for bins[i-1] <= pt < bins[i]; 0 and len(bins) are under/overflow
    bin_inds = np.digitize(pt_values, bins)

    flags_per_bin = ak.Array(
        [flags[bin_inds == bin_i] for bin_i in range(1, len(bins))]
    )

    means = ak.mean(flags_per_bin, axis=-1)

    errs = np.abs(
        clopper_pearson_interval(
            num=ak.sum(flags_per_bin, axis=-1),
            denom=ak.num(flags_per_bin, axis=-1),
        )
        - means
    )

    return means, errs

In [None]:
# # calculate purity (percentage of the time that the model correctly found a target particle)
# def calc_pur(LUT_target, bins):
#     targetTops = [targetTop for event in LUT_target for targetTop in event]
#     targetTops = np.array(targetTops)

#     targetTops_inds = np.digitize(targetTops[:, 1], bins)

#     correctTruth_per_bin = []
#     for bin_i in range(1, len(bins)):
#         correctTruth_per_bin.append(targetTops[:, 0][targetTops_inds == bin_i])
#     correctTruth_per_bin = ak.Array(correctTruth_per_bin)

#     means = ak.mean(correctTruth_per_bin, axis=-1)

#     errs = np.abs(
#         clopper_pearson_interval(
#             num=ak.sum(correctTruth_per_bin, axis=-1),
#             denom=ak.num(correctTruth_per_bin, axis=-1),
#         )
#         - means
#     )

#     return means, errs

In [None]:
# pT bin edges in GeV: 200, 300, ..., 900 (the stop value 1000 is excluded)
bins = np.arange(200, 1000, 100)
# midpoint of every pair of neighbouring edges
bin_centers = [(lo + hi) / 2 for lo, hi in zip(bins[:-1], bins[1:])]
# half bin width for every bin (edges are equally spaced)
xerr = np.full(bins.size - 1, (bins[1] - bins[0]) / 2)

In [None]:
# dp_cut = 0.85
# LUT_pred_spanet, LUT_target_spanet = parse_pred_w_target(test_h5, s_h5, dp_cut=dp_cut)
# LUT_pred_baseline, LUT_target_baseline = parse_pred_w_target(
#     test_h5, b_h5, dp_cut=dp_cut
# )
# eff_s, efferr_s = calc_eff(LUT_pred_spanet, bins)
# eff_b, efferr_b = calc_eff(LUT_pred_baseline, bins)
# pur_s, purerr_s = calc_pur(LUT_target_spanet, bins)
# pur_b, purerr_b = calc_pur(LUT_target_baseline, bins)

# Full-top matching performance versus top pT.
# The prediction LUTs feed the left ("Matching purity") panel and the target
# LUTs the right ("Matching efficiency") panel.
eff_s, efferr_s = calc_eff(LUT_pred_spanet, bins)
eff_b, efferr_b = calc_eff(LUT_pred_baseline, bins)
pur_s, purerr_s = calc_eff(LUT_target_spanet, bins)
pur_b, purerr_b = calc_eff(LUT_target_baseline, bins)

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

# left panel: binned in reconstructed top pT
ax[0].errorbar(bin_centers, eff_s, yerr=efferr_s, fmt="o", capsize=5, label="SPANet")
ax[0].errorbar(bin_centers, eff_b, yerr=efferr_b, fmt="x", capsize=5, label="Baseline")
ax[0].set_xlabel(r"Reco top $p_T$ [GeV]")
ax[0].set_ylabel(r"Matching purity")
ax[0].set_ylim(0, 1)
ax[0].set_xlim(300, 900)

# right panel: binned in true top pT
ax[1].errorbar(bin_centers, pur_s, yerr=purerr_s, fmt="o", capsize=5, label="SPANet")
ax[1].errorbar(bin_centers, pur_b, yerr=purerr_b, fmt="x", capsize=5, label="Baseline")
ax[1].set_xlabel(r"True top $p_T$ [GeV]")
ax[1].set_ylabel(r"Matching efficiency")
ax[1].set_ylim(0, 1)
ax[1].set_xlim(300, 900)

# the legend is drawn on the left panel only
ax[0].legend()
fig.tight_layout()
fig.savefig("spanet_baseline_pur_eff.pdf")

In [None]:
# b-jet matching performance versus b-jet pT.
# The prediction LUTs feed the left ("Matching purity") panel and the target
# LUTs the right ("Matching efficiency") panel.
b_eff_s, b_efferr_s = calc_eff(bLUT_pred_spanet, bins)
b_eff_b, b_efferr_b = calc_eff(bLUT_pred_baseline, bins)
b_pur_s, b_purerr_s = calc_eff(bLUT_target_spanet, bins)
b_pur_b, b_purerr_b = calc_eff(bLUT_target_baseline, bins)

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

# left panel: per predicted b-jet
ax[0].errorbar(bin_centers, b_eff_s, yerr=b_efferr_s, fmt="o", capsize=5, label="SPANet")
ax[0].errorbar(bin_centers, b_eff_b, yerr=b_efferr_b, fmt="x", capsize=5, label="Baseline")
ax[0].set_xlabel(r"bjet $p_T$ [GeV]")
ax[0].set_ylabel(r"Matching purity")
ax[0].set_ylim(0, 1)
ax[0].set_xlim(300, 900)

# right panel: per target b-jet
ax[1].errorbar(bin_centers, b_pur_s, yerr=b_purerr_s, fmt="o", capsize=5, label="SPANet")
ax[1].errorbar(bin_centers, b_pur_b, yerr=b_purerr_b, fmt="x", capsize=5, label="Baseline")
ax[1].set_xlabel(r"bjet $p_T$ [GeV]")
ax[1].set_ylabel(r"Matching efficiency")
ax[1].set_ylim(0, 1)
ax[1].set_xlim(300, 900)

# the legend is drawn on the left panel only
ax[0].legend()
fig.tight_layout()
fig.savefig("spanet_baseline_bjet_pur_eff.pdf")

In [None]:
# q1-jet matching performance versus q1-jet pT.
# The prediction LUTs feed the left ("Matching purity") panel and the target
# LUTs the right ("Matching efficiency") panel.
q1_eff_s, q1_efferr_s = calc_eff(q1LUT_pred_spanet, bins)
q1_eff_b, q1_efferr_b = calc_eff(q1LUT_pred_baseline, bins)
q1_pur_s, q1_purerr_s = calc_eff(q1LUT_target_spanet, bins)
q1_pur_b, q1_purerr_b = calc_eff(q1LUT_target_baseline, bins)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# plot the q1_* results computed above (not the full-top eff_*/pur_* arrays)
ax[0].errorbar(
    x=bin_centers, y=q1_eff_s, yerr=q1_efferr_s, fmt="o", capsize=5, label="SPANet"
)
ax[0].errorbar(
    x=bin_centers, y=q1_eff_b, yerr=q1_efferr_b, fmt="x", capsize=5, label="Baseline"
)
ax[0].set(xlabel=r"q1-jet $p_T$ [GeV]", ylabel=r"Matching purity")
ax[0].set_ylim(0, 1)
ax[0].set_xlim(300, 900)

ax[1].errorbar(
    x=bin_centers, y=q1_pur_s, yerr=q1_purerr_s, fmt="o", capsize=5, label="SPANet"
)
ax[1].errorbar(
    x=bin_centers, y=q1_pur_b, yerr=q1_purerr_b, fmt="x", capsize=5, label="Baseline"
)
ax[1].set(xlabel=r"q1-jet $p_T$ [GeV]", ylabel=r"Matching efficiency")
ax[1].set_ylim(0, 1)
ax[1].set_xlim(300, 900)

# the legend is drawn on the left panel only
ax[0].legend()
plt.tight_layout()
# dedicated file name so this figure does not overwrite the full-top plot
plt.savefig("spanet_baseline_q1jet_pur_eff.pdf")

In [None]:
# q2-jet matching performance versus q2-jet pT.
# The prediction LUTs feed the left ("Matching purity") panel and the target
# LUTs the right ("Matching efficiency") panel.
q2_eff_s, q2_efferr_s = calc_eff(q2LUT_pred_spanet, bins)
q2_eff_b, q2_efferr_b = calc_eff(q2LUT_pred_baseline, bins)
q2_pur_s, q2_purerr_s = calc_eff(q2LUT_target_spanet, bins)
q2_pur_b, q2_purerr_b = calc_eff(q2LUT_target_baseline, bins)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# plot the q2_* results computed above (not the full-top eff_*/pur_* arrays)
ax[0].errorbar(
    x=bin_centers, y=q2_eff_s, yerr=q2_efferr_s, fmt="o", capsize=5, label="SPANet"
)
ax[0].errorbar(
    x=bin_centers, y=q2_eff_b, yerr=q2_efferr_b, fmt="x", capsize=5, label="Baseline"
)
ax[0].set(xlabel=r"q2-jet $p_T$ [GeV]", ylabel=r"Matching purity")
ax[0].set_ylim(0, 1)
ax[0].set_xlim(300, 900)

ax[1].errorbar(
    x=bin_centers, y=q2_pur_s, yerr=q2_purerr_s, fmt="o", capsize=5, label="SPANet"
)
ax[1].errorbar(
    x=bin_centers, y=q2_pur_b, yerr=q2_purerr_b, fmt="x", capsize=5, label="Baseline"
)
ax[1].set(xlabel=r"q2-jet $p_T$ [GeV]", ylabel=r"Matching efficiency")
ax[1].set_ylim(0, 1)
ax[1].set_xlim(300, 900)

# the legend is drawn on the left panel only
ax[0].legend()
plt.tight_layout()
# dedicated file name so this figure does not overwrite the full-top plot
plt.savefig("spanet_baseline_q2jet_pur_eff.pdf")