In [None]:
import uproot
import awkward as ak
from coffea import nanoevents
from coffea.nanoevents.methods.base import NanoEventsArray
from coffea.analysis_tools import Weights, PackedSelection
from coffea.nanoevents.methods import nanoaod
from coffea.nanoevents.methods import vector
from coffea.lookup_tools.dense_lookup import dense_lookup

# Merge coffea's vector behaviors into awkward's global behavior registry.
ak.behavior.update(vector.behavior)

import pickle, json, gzip
import numpy as np

from typing import Optional, List, Dict
from copy import copy

import matplotlib.pyplot as plt
import mplhep as hep
from matplotlib import colors

from tqdm import tqdm

import os

# import corrections
import correctionlib

# import utils

In [None]:
# Automatically reload imported modules before executing code, so edits to helper modules are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
# PDG Monte Carlo particle ID numbers used to identify generator-level particles.

# quarks d, u, s, c, b, t (1-6) and the gluon
d_PDGID, u_PDGID, s_PDGID, c_PDGID, b_PDGID, TOP_PDGID = range(1, 7)
g_PDGID = 21

# charged leptons and their neutrinos, interleaved (11-16)
ELE_PDGID, vELE_PDGID, MU_PDGID, vMU_PDGID, TAU_PDGID, vTAU_PDGID = range(11, 17)

# photon, Z, W, Higgs (22-25)
G_PDGID, Z_PDGID, W_PDGID, HIGGS_PDGID = range(22, 26)
# Y scalar of the X -> YH signal
Y_PDGID = 35
# graviton
GRAV_PDGID = 39

# B mesons
b_PDGIDS = [511, 521, 523]

# GenPart status flags required (via hasFlags) when selecting generator particles
GEN_FLAGS = ["fromHardProcess", "isLastCopy"]

In [None]:
plot_dir = "../../../plots/ScaleFactors/Nov23"
# Create the output directory (and any parents) if it does not exist yet. Unlike shelling out
# to `mkdir -p`, this does not go through a shell and raises if the directory cannot be created.
os.makedirs(plot_dir, exist_ok=True)

In [None]:
# NanoAOD four-vector field name -> suffix used for the corresponding output branch.
P4 = dict(eta="Eta", phi="Phi", mass="Mass", pt="Pt")


# Fill value used for padded / missing entries.
PAD_VAL = -99999


# Per-collection mapping of input field name -> output branch suffix.
skim_vars = dict(
    FatJet=dict(
        P4,
        msoftdrop="Msd",
        particleNetMD_QCD="ParticleNetMD_QCD",
        particleNetMD_Xbb="ParticleNetMD_Xbb",
        particleNet_H4qvsQCD="ParticleNet_Th4q",
        particleNet_mass="ParticleNetMass",
    ),
    GenHiggs=P4,
    other=dict(MET_pt="MET_pt"),
)


def pad_val(
    arr: ak.Array,
    target: int,
    value: float = PAD_VAL,
    axis: int = 0,
    to_numpy: bool = True,
    clip: bool = True,
):
    """
    Pad ``arr`` to length ``target`` along ``axis`` and fill the gaps with ``value``.

    Args:
        arr: awkward array to pad.
        target: length each list along ``axis`` is padded up to.
        value: replacement for missing entries at ``axis`` (the new padding, and presumably
            any pre-existing None there as well). Defaults to ``PAD_VAL``.
        axis: axis along which to pad.
        to_numpy: if True, convert the result to a numpy array; otherwise return awkward.
        clip: if True, lists longer than ``target`` are truncated to ``target``.

    Returns:
        The padded and filled array, as numpy when ``to_numpy`` is True.
    """
    padded = ak.pad_none(arr, target, axis=axis, clip=clip)
    filled = ak.fill_none(padded, value, axis=axis)
    if not to_numpy:
        return filled
    return filled.to_numpy()


def add_selection(
    name: str,
    sel: np.ndarray,
    selection: PackedSelection,
    cutflow: dict,
    isData: bool,
    genWeights: ak.Array = None,
):
    """
    Add a selection to ``selection`` and record its cumulative cutflow entry.

    The value stored in ``cutflow[name]`` is evaluated after applying ALL selections added so
    far (including this one): the number of passing events for data, or the sum of
    ``genWeights`` over passing events for MC.

    Args:
        name: name of the selection and of the cutflow entry.
        sel: per-event mask (numpy or awkward); cast to bool before being stored.
        selection: PackedSelection the mask is added to.
        cutflow: dict updated in place with ``cutflow[name]``.
        isData: if True, count events; otherwise sum ``genWeights``.
        genWeights: per-event weights; required when ``isData`` is False.

    Raises:
        ValueError: if ``isData`` is False and ``genWeights`` is None. This is checked before
            ``selection`` is modified.
    """
    if not isData and genWeights is None:
        raise ValueError(f"add_selection('{name}'): genWeights must be provided for MC")

    if isinstance(sel, ak.Array):
        sel = sel.to_numpy()

    selection.add(name, sel.astype(bool))

    # cumulative mask of every selection added so far; computed once and reused
    passing = selection.all(*selection.names)
    # add up genWeights for MC
    cutflow[name] = np.sum(passing) if isData else np.sum(genWeights[passing])

In [None]:
# Signal NanoAOD file (2017, NMSSM XToYHTo2W2BTo4Q2B, MX-1600, MY-125, per the path).
# Other files previously used here:
#   "../../../../data/2017_UL_nano/NMSSM_XToYH_MX1000_MY400_HTo2bYTo2W_hadronicDecay/nano_mc2017_101.root"
#   "../../../../data/2017_UL_nano/NMSSM_XToYHTo2W2BTo4Q2B_MX-3000_MY-190/nano_mc2016pre_13.root"
#   "/eos/uscms//store/user/lpcpfnano/rkansal/v2_3/2016APV/XHY/NMSSM_XToYHTo2W2BTo4Q2B_MX-3000_MY-190_TuneCP5_13TeV-madgraph-pythia8/NMSSM_XToYHTo2W2BTo4Q2B_MX-3000_MY-190/230323_173705/0000/nano_mc2016pre_13.root"
#   "/eos/uscms//store/user/lpcpfnano/rkansal/v2_3/2016/XHY/NMSSM_XToYHTo2W2BTo4Q2B_MX-3000_MY-190_TuneCP5_13TeV-madgraph-pythia8/NMSSM_XToYHTo2W2BTo4Q2B_MX-3000_MY-190/230323_193051/0000/nano_mc2016post_1-3.root"
signal_file = "/eos/uscms/store/user/lpcpfnano/ammitra/v2_3/2017/XHY/NMSSM_XToYHTo2W2BTo4Q2B_MX-1600_MY-125_TuneCP5_13TeV-madgraph-pythia8/NMSSM_XToYHTo2W2BTo4Q2B_MX-1600_MY-125/230323_184451/0000/nano_mc2017_23.root"

factory = nanoevents.NanoEventsFactory.from_root(signal_file, schemaclass=nanoevents.NanoAODSchema)
events = factory.events()

In [None]:
isData = False
# For MC, events are counted by the sign of their generator weight (negative-weight events
# subtract); add_selection(..., signGenWeights) fills later cutflow entries the same way.
signGenWeights = None if isData else np.sign(events["genWeight"])
n_events = len(events) if isData else int(np.sum(signGenWeights))
selection = PackedSelection()

# Start the cutflow with the same counting convention as the later entries (raw count for
# data, sum of genWeight signs for MC) so that all entries are directly comparable.
cutflow = {}
cutflow["all"] = n_events

In [None]:
# NOTE(review): this rebinds `skim_vars` (the per-collection dict defined earlier) to the flat
# P4 field mapping; everything below iterates over the four-vector fields only.
skim_vars = P4
fatjets = events.FatJet

# maximum number of daughter generations stepped through when a V's daughters include a photon
MAX_PHOTON_ITERATIONS = 5

# gen Higgs (fromHardProcess, isLastCopy) and its b-quark daughters
higgs = events.GenPart[
    (abs(events.GenPart.pdgId) == HIGGS_PDGID) & events.GenPart.hasFlags(GEN_FLAGS)
]
GenHiggsVars = {f"GenHiggs{key}": higgs[var].to_numpy() for (var, key) in skim_vars.items()}
is_bb = abs(higgs.children.pdgId) == b_PDGID
has_bb = ak.sum(ak.flatten(is_bb, axis=2), axis=1) == 2

bb = ak.flatten(higgs.children[is_bb], axis=2)
GenbbVars = {f"Genbb{key}": pad_val(bb[var], 2, axis=1) for (var, key) in skim_vars.items()}

# gen Y and its W/Z daughters
Ys = events.GenPart[(abs(events.GenPart.pdgId) == Y_PDGID) & events.GenPart.hasFlags(GEN_FLAGS)]
GenYVars = {f"GenY{key}": Ys[var].to_numpy() for (var, key) in skim_vars.items()}
is_VV = (abs(Ys.children.pdgId) == W_PDGID) | (abs(Ys.children.pdgId) == Z_PDGID)
has_VV = ak.sum(ak.flatten(is_VV, axis=2), axis=1) == 2

add_selection("has_bbVV", has_bb & has_VV, selection, cutflow, False, signGenWeights)

VV = ak.flatten(Ys.children[is_VV], axis=2)
GenVVVars = {f"GenVV{key}": VV[var][:, :2].to_numpy() for (var, key) in skim_vars.items()}

VV_children = VV.children

# In events where any V daughter is a photon, replace every V's daughters by the next
# generation (the daughters' own children). Repeat until no event has a photon daughter or the
# iteration limit is reached, to get to the final daughter quarks.
for _ in range(MAX_PHOTON_ITERATIONS):
    photon_mask = ak.any(ak.flatten(abs(VV_children.pdgId), axis=2) == G_PDGID, axis=1)
    if not np.any(photon_mask):
        break

    # only events flagged by photon_mask step down one generation
    VV_children = ak.where(photon_mask, ak.flatten(VV_children.children, axis=3), VV_children)

# require every V daughter to be a quark (|pdgId| <= b_PDGID)
quarks = abs(VV_children.pdgId) <= b_PDGID
all_q = ak.all(ak.all(quarks, axis=2), axis=1)
add_selection("all_q", all_q, selection, cutflow, False, signGenWeights)

# require exactly two daughters per V (the product over Vs is nonzero only if every V has two)
V_has_2q = ak.count(VV_children.pdgId, axis=2) == 2
has_4q = ak.values_astype(ak.prod(V_has_2q, axis=1), bool)
add_selection("has_4q", has_4q, selection, cutflow, False, signGenWeights)

# V-daughter kinematics, padded/clipped to 2 Vs x 2 daughters per event
Gen4qVars = {
    f"Gen4q{key}": ak.to_numpy(
        ak.fill_none(
            ak.pad_none(ak.pad_none(VV_children[var], 2, axis=1, clip=True), 2, axis=2, clip=True),
            PAD_VAL,
        )
    )
    for (var, key) in skim_vars.items()
}

# fatjet gen matching: compare the first two fatjets to the first gen Higgs and gen Y
Hbb = ak.pad_none(higgs, 1, axis=1, clip=True)[:, 0]
# NOTE: despite its name, HVV is the gen Y (the particle decaying to VV)
HVV = ak.pad_none(Ys, 1, axis=1, clip=True)[:, 0]

bbdr = fatjets[:, :2].delta_r(Hbb)
vvdr = fatjets[:, :2].delta_r(HVV)

match_dR = 0.8
Hbb_match = bbdr <= match_dR
HVV_match = vvdr <= match_dR

# overlap removal - in the case where fatjet is matched to both, match it only to the closest one.
# The right-hand side is evaluated in full before assignment, so both masks are rebuilt from the
# *original* matches and the result does not depend on the order of the two updates.
both_match = Hbb_match & HVV_match
Hbb_match, HVV_match = (
    (Hbb_match & ~HVV_match) | (both_match & (bbdr <= vvdr)),
    (HVV_match & ~Hbb_match) | (both_match & (bbdr > vvdr)),
)

# First fatjet matched to HVV. The match masks only cover the first two fatjets, so the jet list
# is sliced to the same length (a <= 2-long mask cannot be applied to a longer jet list).
VVJets = ak.pad_none(fatjets[:, :2][HVV_match], 1, axis=1)[:, 0]
# number of V daughters within match_dR of the HVV-matched fatjet
quarkdrs = ak.flatten(VVJets.delta_r(VV_children), axis=2)
num_prongs = ak.sum(quarkdrs < match_dR, axis=1)

GenMatchingVars = {
    "ak8FatJetHbb": pad_val(Hbb_match, 2, axis=1),
    "ak8FatJetHVV": pad_val(HVV_match, 2, axis=1),
    "ak8FatJetHVVNumProngs": ak.fill_none(num_prongs, PAD_VAL).to_numpy(),
}

genbb, gen4q = bb, ak.flatten(VV_children, axis=2)

In [None]:
def get_fatjet_pfcands(events, fatjet_idx: int):
    """
    Return, per event, the PF candidates associated with fatjet ``fatjet_idx``.

    Each ``events.FatJetPFCands`` entry links a fatjet (``jetIdx``) to a PF candidate
    (``pFCandsIdx``, an index into ``events.PFCands``).

    Args:
        events: NanoEvents array with FatJetPFCands and PFCands collections.
        fatjet_idx: index of the fatjet within each event's FatJet collection.

    Returns:
        Jagged array of PFCands records, one list per event (empty if no candidates
        are linked to that fatjet index).
    """
    jet_links = events.FatJetPFCands
    jet_links = jet_links[jet_links.jetIdx == fatjet_idx]
    return events.PFCands[jet_links.pFCandsIdx]


pfcands0 = get_fatjet_pfcands(events, 0)
pfcands1 = get_fatjet_pfcands(events, 1)

In [None]:
# Event-level preselection: True only if the first two fatjets both have pt > 250,
# |eta| < 2.4 and particleNet_mass >= 50. Jets beyond the second are clipped away, and missing
# jets are padded with False, so events with fewer than two fatjets are rejected.
fatjet_passes = (fatjets.pt > 250) & (np.abs(fatjets.eta) < 2.4) & (fatjets.particleNet_mass >= 50)
sel = np.all(pad_val(fatjet_passes, 2, value=False, axis=1), axis=1)

In [None]:
# Inspect the PF candidates of fatjet 0 (jetIdx == 0) for events passing `sel`.
pfcands0[sel]

In [None]:
# Fraction of fatjet 1's PF candidates that are photons, for events passing `sel`.
# The x-range is restricted to the high-photon-fraction region [0.8, 1].
pfcands1_sel = pfcands1[sel]
photon_frac1 = ak.sum(pfcands1_sel.pdgId == G_PDGID, axis=1) / ak.count(pfcands1_sel.pdgId, axis=1)
fig, ax = plt.subplots()
_ = ax.hist(photon_frac1, np.linspace(0.8, 1, 101), histtype="step")
_ = ax.set(xlabel="Photon fraction of fatjet 1 PF candidates", ylabel="Events", title="Selected events")

In [None]:
# Fraction of fatjet 0's PF candidates that are photons, over ALL events (no `sel` applied).
# NOTE(review): events whose fatjet 0 has no PF candidates give 0/0 here — confirm how they
# should be treated.
photon_frac0 = ak.sum(pfcands0.pdgId == G_PDGID, axis=1) / ak.count(pfcands0.pdgId, axis=1)
fig, ax = plt.subplots()
_ = ax.hist(photon_frac0, np.linspace(0, 1, 101), histtype="step")
_ = ax.set(xlabel="Photon fraction of fatjet 0 PF candidates", ylabel="Events", title="All events")

In [None]:
# Number of PF candidates in fatjet 0 per selected event (integer-centred bins 0..20).
fig, ax = plt.subplots()
_ = ax.hist(ak.count(pfcands0[sel].pdgId, axis=1), np.linspace(-0.5, 20.5, 22), histtype="step")
_ = ax.set(xlabel="Number of PF candidates in fatjet 0", ylabel="Events", title="Selected events")

In [None]:
# Number of PF candidates in fatjet 1 per selected event (integer-centred bins 0..20).
fig, ax = plt.subplots()
_ = ax.hist(ak.count(pfcands1[sel].pdgId, axis=1), np.linspace(-0.5, 20.5, 22), histtype="step")
_ = ax.set(xlabel="Number of PF candidates in fatjet 1", ylabel="Events", title="Selected events")

In [None]:
# Positions (within the `sel`-selected events) of events whose fatjet 0 has exactly 3 PF candidates.
np.where(ak.count(pfcands0[sel].pdgId, axis=1) == 3)

In [None]:
# First selected event whose fatjet 0 has exactly 3 PF candidates, used as the example event below.
# Raises IndexError if no such event exists.
event_idx = np.where(ak.count(pfcands0[sel].pdgId, axis=1) == 3)[0][0]

In [None]:
# pdgIds of fatjet 0's PF candidates in the example event.
pfcands0[sel][event_idx].pdgId

In [None]:
# pt of fatjet 0's PF candidates in the example event.
pfcands0[sel][event_idx].pt

In [None]:
# pdgIds of the V daughters (after the photon-generation stepping) in the example event, nested per V.
VV_children[sel][event_idx].pdgId

In [None]:
# pdgIds two generations below the example event's V daughters (nested per V / daughter / child).
VV_children[sel][event_idx].children.children.pdgId

In [None]:
# pdgIds one generation below the example event's V daughters, flattened into a single list.
ak.flatten(ak.flatten(VV_children[sel][event_idx].children.pdgId, axis=-1), axis=-1)

In [None]:
# pdgIds two generations below the example event's V daughters, flattened into a single list.
ak.flatten(
    ak.flatten(ak.flatten(VV_children[sel][event_idx].children.children.pdgId, axis=-1), axis=-1),
    axis=-1,
)

In [None]:
# pdgIds three generations below the example event's V daughters, flattened into a single list.
ak.flatten(
    ak.flatten(
        ak.flatten(
            ak.flatten(VV_children[sel][event_idx].children.children.children.pdgId, axis=-1),
            axis=-1,
        ),
        axis=-1,
    ),
    axis=-1,
)

In [None]:
# Print delta_R between fatjet 1 of event 597 and PF candidates of that event.
# NOTE(review): 597 is a hard-coded index into the *unselected* events (unrelated to
# `event_idx`), and the candidates are NOT filtered by jetIdx, so this includes PF candidates
# of every fatjet in the event — confirm that is intended.
for i in events.FatJet[597][1].delta_r(events.PFCands[597][events.FatJetPFCands[597].pFCandsIdx]):
    print(i)

In [None]:
# delta_R between fatjet 1 of the example event and each of its V daughters (nested per V).
events.FatJet[sel][event_idx][1].delta_r(VV_children[sel][event_idx])

In [None]:
# delta_R between fatjet 1 and the (first) gen Higgs in event 597.
events.FatJet[597][1].delta_r(Hbb[597])

In [None]:
# delta_R between fatjet 1 and the (first) gen Y, stored as `HVV`, in event 597.
events.FatJet[597][1].delta_r(HVV[597])

In [None]:
# delta_R between fatjet 1 and each V daughter in event 597 (nested per V).
events.FatJet[597][1].delta_r(VV_children[597])

In [None]:
# delta_R between fatjet 1 and each of its own PF candidates (jetIdx == 1) in event 597.
events.FatJet[597][1].delta_r(pfcands1[597])

In [None]:
# V daughters of each event, flattened across the Vs (one list per event).
gen4q

In [None]:
# Re-derive the PF candidates of fatjets 0 and 1: FatJetPFCands.jetIdx picks the jet and
# pFCandsIdx indexes into events.PFCands.
pfcands0, pfcands1 = (
    events.PFCands[events.FatJetPFCands[events.FatJetPFCands.jetIdx == jet_idx].pFCandsIdx]
    for jet_idx in (0, 1)
)
# leave the same helper variables defined as before (the last jet index and its link entries)
fatjet_idx = 1
ak8_pfcands = events.FatJetPFCands[events.FatJetPFCands.jetIdx == fatjet_idx]

In [None]:
# Events in which both of the first two fatjets pass pt > 250, |eta| < 2.4 and
# particleNet_mass >= 50 (missing jets are padded with False, so events with < 2 fatjets fail).
# Cast to bool: np.prod over booleans returns integers, and an integer `sel` would
# integer-index events (e.g. pfcands1[sel] -> events 0 and 1) instead of masking them.
sel = np.prod(
    pad_val(
        (fatjets.pt > 250) * (np.abs(fatjets.eta) < 2.4) * (fatjets.particleNet_mass >= 50),
        2,
        False,
        axis=1,
    ),
    axis=1,
).astype(bool)

In [None]:
# Smallest number of PF candidates in fatjet 1 over the events picked by `sel`.
# NOTE(review): `sel` must be a boolean mask here — an integer array would integer-index events instead.
np.min(ak.count(pfcands1[sel].pt, axis=1))

In [None]:
# Re-derive the PF candidates of fatjets 0 and 1: FatJetPFCands.jetIdx picks the jet and
# pFCandsIdx indexes into events.PFCands.
pfcands0, pfcands1 = (
    events.PFCands[events.FatJetPFCands[events.FatJetPFCands.jetIdx == jet_idx].pFCandsIdx]
    for jet_idx in (0, 1)
)
# leave the same helper variables defined as before (the last jet index and its link entries)
fatjet_idx = 1
ak8_pfcands = events.FatJetPFCands[events.FatJetPFCands.jetIdx == fatjet_idx]