In [237]:
import glob
import os
import json
import pickle
import yaml

import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None
import pyarrow.parquet as pq
from sklearn.metrics import auc, roc_curve
from scipy.special import softmax

import hist as hist2
import matplotlib.pyplot as plt
import mplhep as hep

plt.style.use(hep.style.CMS)

import utils_farouk as utils
plt.rcParams.update({"font.size": 20})

#!/usr/bin/python

import glob
import json
import os
import pickle as pkl
import warnings

import hist as hist2
import numpy as np
import pandas as pd

In [75]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [76]:
! ls ../Apr12_presel_2017

[34mDYJetsToLL_Pt-100To250[m[m                  [34mSingleElectron_Run2017F[m[m
[34mDYJetsToLL_Pt-250To400[m[m                  [34mSingleMuon_Run2017B[m[m
[34mDYJetsToLL_Pt-400To650[m[m                  [34mSingleMuon_Run2017C[m[m
[34mDYJetsToLL_Pt-50To100[m[m                   [34mSingleMuon_Run2017D[m[m
[34mDYJetsToLL_Pt-650ToInf[m[m                  [34mSingleMuon_Run2017E[m[m
[34mGluGluHToTauTau[m[m                         [34mSingleMuon_Run2017F[m[m
[34mGluGluHToWW_Pt-200ToInf_M-125[m[m           [34mTTTo2L2Nu[m[m
[34mGluGluZH_HToWW_ZTo2L_M-125[m[m              [34mTTToHadronic[m[m
[34mHWminusJ_HToWW_M-125[m[m                    [34mTTToSemiLeptonic[m[m
[34mHWplusJ_HToWW_M-125[m[m                     [34mVBFHToWWToLNuQQ_M-125_withDipoleRecoil[m[m
[34mHZJ_HToWW_M-125[m[m                         [34mWJetsToLNu_HT-100To200[m[m
[34mQCD_Pt_1000to1400[m[m                       [34mWJetsToLNu_HT-1200To2500

In [77]:

def get_sum_sumgenweight(pkl_files, year, sample):
    sum_sumgenweight = 0
    for ifile in pkl_files:
        # load and sum the sumgenweight of each
        with open(ifile, "rb") as f:
            metadata = pkl.load(f)
        sum_sumgenweight = sum_sumgenweight + metadata[sample][year]["sumgenweight"]
    return sum_sumgenweight


def get_xsecweight(pkl_files, year, sample, is_data, luminosity):
    if not is_data:
        # find xsection
        f = open("../fileset/xsec_pfnano.json")
        xsec = json.load(f)
        f.close()
        try:
            xsec = eval(str((xsec[sample])))
        except ValueError:
            print(f"sample {sample} doesn't have xsecs defined in xsec_pfnano.json so will skip it")
            return None

        # get overall weighting of events.. each event has a genweight...
        # sumgenweight sums over events in a chunk... sum_sumgenweight sums over chunks
        xsec_weight = (xsec * luminosity) / get_sum_sumgenweight(pkl_files, year, sample)
    else:
        xsec_weight = 1
    return xsec_weight


# Read parquets
- loads parquet dataframes and combine different pT-bins per sample
- applies the preselection specefied in the next cell
- saves the combined dataframe under `events[ch][sample]`

- axis1=samples
- axis2=reconstructed mass
for keys:
- cat1_sr: signal region (we can change the names depending on how many categories you have in the signal region
- qcd_cr: qcd control region
- wjets_cr: wjets control region
- tt_cr: ttbar control region

In [87]:

# new stuff
combine_samples = {
    # data
    # "SingleElectron_": "SingleElectron",
    "SingleElectron_": "Data",
    # "SingleMuon_": "SingleMuon_",
    "SingleMuon_": "Data",
    # "EGamma_": "EGamma",
    "EGamma_": "Data",
    # signal
    "GluGluHToWW_Pt-200ToInf_M-125": "HWW",
    "HToWW_M-125": "VH",
    "VBFHToWWToLNuQQ_M-125_withDipoleRecoil": "VBF",
    "ttHToNonbb_M125": "ttH",
    # bkg
    "QCD_Pt": "QCD",
    "DYJets": "DYJets",
    "WJetsToLNu_": "WJetsLNu",
    "JetsToQQ": "WZQQ",
    "TT": "TTbar",
    "ST_": "SingleTop",
    "WW": "Diboson",
    "WZ": "Diboson",
    "ZZ": "Diboson",
    "GluGluHToTauTau": "HTauTau",
}
signals = ["HWW", "ttH", "VH", "VBF"]

data_by_ch = {
    "ele": "SingleElectron",
    "mu": "SingleMuon",
}

weights = {
    "mu": {
        "weight_genweight": 1,
        "weight_L1Prefiring": 1,
        "weight_pileup": 1,
        "weight_trigger_iso_muon": 1,
        "weight_trigger_noniso_muon": 1,
        "weight_isolation_muon": 1,
        "weight_id_muon": 1,
        "weight_vjets_nominal": 1,
    },
    "ele": {
        "weight_genweight": 1,
        "weight_L1Prefiring": 1,
        "weight_pileup": 1,
        "weight_trigger_electron": 1,
        "weight_reco_electron": 1,
        "weight_id_electron": 1,
        "weight_vjets_nominal": 1,
    },
}


# tagger definitions
def disc_score(df, sigs, bkgs):
    num = df[sigs].sum(axis=1)
    den = df[sigs].sum(axis=1) + df[bkgs].sum(axis=1)
    return num / den


# scores definition
hwwev = ["fj_PN_probHWqqWev0c", "fj_PN_probHWqqWev1c", "fj_PN_probHWqqWtauev0c", "fj_PN_probHWqqWtauev1c"]
hwwmv = ["fj_PN_probHWqqWmv0c", "fj_PN_probHWqqWmv1c", "fj_PN_probHWqqWtaumv0c", "fj_PN_probHWqqWtaumv1c"]
hwwhad = [
    "fj_PN_probHWqqWqq0c",
    "fj_PN_probHWqqWqq1c",
    "fj_PN_probHWqqWqq2c",
    "fj_PN_probHWqqWq0c",
    "fj_PN_probHWqqWq1c",
    "fj_PN_probHWqqWq2c",
    "fj_PN_probHWqqWtauhv0c",
    "fj_PN_probHWqqWtauhv1c",
]
qcd = ["fj_PN_probQCDbb", "fj_PN_probQCDcc", "fj_PN_probQCDb", "fj_PN_probQCDc", "fj_PN_probQCDothers"]

tope = ["fj_PN_probTopbWev", "fj_PN_probTopbWtauev"]
topm = ["fj_PN_probTopbWmv", "fj_PN_probTopbWtaumv"]
tophad = ["fj_PN_probTopbWqq0c", "fj_PN_probTopbWqq1c", "fj_PN_probTopbWq0c", "fj_PN_probTopbWq1c", "fj_PN_probTopbWtauhv"]

top = tope + topm + tophad

sigs = hwwev + hwwmv + hwwhad

qcd_bkg = [b.replace("PN", "ParT") for b in qcd]
top_bkg = [b.replace("PN", "ParT") for b in top]
inclusive_bkg = [b.replace("PN", "ParT") for b in qcd + top]
new_sig = [s.replace("PN", "ParT") for s in sigs]

In [212]:
year = "2017"
channels = ["mu", "ele"]

# samples_dir = "../Mar23_2017"
samples_dir = f"../Apr12_presel_{year}"

samples = [
    "HWW", 
    "VH",
    "VBF",
    "ttH",
    "QCD",
    "DYJets",
    "WJetsLNu",
    "WZQQ",
    "TTbar",
    "SingleTop",
    "Diboson",
    "Data"
]

In [228]:



regions_selections = {
    "cat1_sr": "( (inclusive_score>0.99) & (n_bjets_M < 2) & (lep_fj_dr<0.3) )",  
    "wjets_cr": "( (inclusive_score>0.99) & (n_bjets_M < 1) & (lep_fj_dr>0.3) )",  
    "tt_cr": "( (inclusive_score<0.90) & (n_bjets_M >=2 ) & (lep_fj_dr>0.3) )",  
}

# initialzie th histograms
regions = ["cat1_sr", "wjets_cr", "tt_cr"]
hists = {}
for region in regions:
    hists[region] = hist2.Hist(    
        hist2.axis.StrCategory([], name="samples", growth=True),
        hist2.axis.Regular(30, 200, 600, name="fj_pt", label=r"Jet $p_T$ [GeV]", overflow=True),
        hist2.axis.Regular(25, 50, 480, name="rec_higgs_m", label=r"Higgs reconstructed mass [GeV]", overflow=True),
    )
    

for ch in channels:

    # get lumi
    luminosity = 0
    with open("../fileset/luminosity.json") as f:
        luminosity += json.load(f)[ch][year]

    condor_dir = os.listdir(samples_dir)
    for sample in condor_dir:
        
        if sample == "DYJetsToLL_M-10to50":
            # ParT is not there for some reason
            continue

        # get a combined label to combine samples of the same process
        for key in combine_samples:
            if key in sample:
                sample_to_use = combine_samples[key]
                break
            else:
                sample_to_use = sample

        if sample_to_use not in samples:
            print(f"ATTENTION: {sample} will be skipped")
            continue

        is_data = False
        if sample_to_use == "Data":
            is_data = True

        print(f"Finding {sample} samples and should combine them under {sample_to_use}")

        out_files = f"{samples_dir}/{sample}/outfiles/"
        parquet_files = glob.glob(f"{out_files}/*_{ch}.parquet")
        pkl_files = glob.glob(f"{out_files}/*.pkl")

        if not parquet_files:
            print(f"No parquet file for {sample}")
            continue

        data = pd.read_parquet(parquet_files)
        if len(data) == 0:
            continue

        # replace the weight_pileup of the strange events with the mean weight_pileup of all the other events
        if not is_data:
            strange_events = data["weight_pileup"] > 6
            if len(strange_events) > 0:
                data["weight_pileup"][strange_events] = data[~strange_events]["weight_pileup"].mean(axis=0)

        # apply selection
#         print("---> Applying preselection.")
        for selection in presel[ch]:
#             print(f"applying {selection} selection on {len(data)} events")
            data = data.query(presel[ch][selection])
#             print("---> Done with preselection.")

        # get event_weight
        if not is_data:
#                 print("---> Accumulating event weights.")
            event_weight = get_xsecweight(pkl_files, year, sample, is_data, luminosity)
            for w in weights[ch]:
                if w not in data.keys():
#                     print(f"{w} weight is not stored in parquet")
                    continue
                if weights[ch][w] == 1:
#                     print(f"Applying {w} weight")
                    event_weight *= data[w]

#                 print("---> Done with accumulating event weights.")
        else:
            event_weight = np.ones_like(data["fj_pt"])

        data["event_weight"] = event_weight

        # add tagger scores
        data["inclusive_score"] = disc_score(data, new_sig, inclusive_bkg)

        for region in regions:
            data1 = data.copy()   # get fresh copy of the data to apply selections on
#             print(f"{region}: applying selection on {len(data1)} events")
            data1 = data1.query(regions_selections[region])
#             print(f"will fill the {sample_to_use} dataframe with the remaining {len(data1)} events")

            hists[region].fill(
            samples=sample_to_use,
            fj_pt=data1["fj_pt"],
            rec_higgs_m=data1["rec_higgs_m"],
            weight=data1["event_weight"],
            )

Finding WJetsToLNu_HT-100To200 samples and should combine them under WJetsLNu
Finding DYJetsToLL_Pt-400To650 samples and should combine them under DYJets
Finding VBFHToWWToLNuQQ_M-125_withDipoleRecoil samples and should combine them under VBF
Finding HWminusJ_HToWW_M-125 samples and should combine them under VH
Finding WJetsToLNu_HT-800To1200 samples and should combine them under WJetsLNu
Finding TTToSemiLeptonic samples and should combine them under TTbar
Finding DYJetsToLL_Pt-250To400 samples and should combine them under DYJets
Finding ST_t-channel_top_4f_InclusiveDecays samples and should combine them under SingleTop
Finding ST_s-channel_4f_hadronicDecays samples and should combine them under SingleTop
Finding WJetsToLNu_HT-1200To2500 samples and should combine them under WJetsLNu
Finding WJetsToLNu_HT-200To400 samples and should combine them under WJetsLNu
Finding ST_tW_top_5f_inclusiveDecays samples and should combine them under SingleTop
Finding GluGluHToWW_Pt-200ToInf_M-125 sam

Finding QCD_Pt_800to1000 samples and should combine them under QCD
Finding WJetsToQQ_HT-400to600 samples and should combine them under WZQQ
Finding WJetsToLNu_HT-400To600 samples and should combine them under WJetsLNu
Finding QCD_Pt_470to600 samples and should combine them under QCD
Finding HZJ_HToWW_M-125 samples and should combine them under VH
Finding WZ samples and should combine them under Diboson
Finding QCD_Pt_1400to1800 samples and should combine them under QCD
Finding DYJetsToLL_Pt-100To250 samples and should combine them under DYJets
Finding SingleMuon_Run2017F samples and should combine them under Data


In [238]:
hists.keys()

dict_keys(['cat1_sr', 'wjets_cr', 'tt_cr'])

In [239]:
hists["cat1_sr"]

Hist(
  StrCategory(['WJetsLNu', 'DYJets', 'VBF', 'VH', 'TTbar', 'SingleTop', 'HWW', 'QCD', 'Data', 'Diboson', 'WZQQ', 'ttH'], growth=True, name='samples'),
  Regular(30, 200, 600, name='fj_pt', label='Jet $p_T$ [GeV]'),
  Regular(25, 50, 480, name='rec_higgs_m', label='Higgs reconstructed mass [GeV]'),
  storage=Double()) # Sum: 7131.306830118755 (7504.16116732042 with flow)

In [240]:
hists["cat1_sr"][{"fj_pt": sum}]

# Store the hists in a rootfile

In [232]:
for region in hists.keys():
    file = uproot.recreate(f"hww_templates/{region}.root")
    
    for sample in hists[region].axes["samples"]:
        if sample=="Data":
            file[f"{region}/data_obs"] = hists[region][{"fj_pt": sum, "samples": sample}]
            continue
        file[f"{region}/{sample}"] = hists[region][{"fj_pt": sum, "samples": sample}]

In [245]:
a = uproot.open("hww_templates/cat1_sr.root")
a.keys()

['cat1_sr;1',
 'cat1_sr/WJetsLNu;1',
 'cat1_sr/DYJets;1',
 'cat1_sr/VBF;1',
 'cat1_sr/VH;1',
 'cat1_sr/TTbar;1',
 'cat1_sr/SingleTop;1',
 'cat1_sr/HWW;1',
 'cat1_sr/QCD;1',
 'cat1_sr/data_obs;1',
 'cat1_sr/Diboson;1',
 'cat1_sr/WZQQ;1',
 'cat1_sr/ttH;1']

In [250]:
a["cat1_sr/HWW"].to_hist()

In [242]:
a = uproot.open("hww_templates/wjets_cr.root")
a.keys()

['wjets_cr;1',
 'wjets_cr/WJetsLNu;1',
 'wjets_cr/DYJets;1',
 'wjets_cr/VBF;1',
 'wjets_cr/VH;1',
 'wjets_cr/TTbar;1',
 'wjets_cr/SingleTop;1',
 'wjets_cr/HWW;1',
 'wjets_cr/data_obs;1',
 'wjets_cr/QCD;1',
 'wjets_cr/Diboson;1',
 'wjets_cr/WZQQ;1',
 'wjets_cr/ttH;1']

In [235]:
a["wjets_cr/WJetsLNu"].to_numpy()

(array([  1.91875912,  30.14432289, 102.06678237, 185.05198495,
        273.34322459, 328.88311224, 337.97492873, 325.03006213,
        335.32459516, 307.09825714, 283.07327381, 238.0810131 ,
        194.74329398, 166.51987618, 129.79129267, 113.37776244,
         86.96370929,  66.7398087 ,  50.47681562,  38.22177202,
         25.07622356,  20.79568437,  19.49033967,  11.26194787,
          9.05860176]),
 array([ 50. ,  67.2,  84.4, 101.6, 118.8, 136. , 153.2, 170.4, 187.6,
        204.8, 222. , 239.2, 256.4, 273.6, 290.8, 308. , 325.2, 342.4,
        359.6, 376.8, 394. , 411.2, 428.4, 445.6, 462.8, 480. ]))

In [236]:
a["wjets_cr/TTbar"].to_numpy()

(array([ 0.        ,  0.6629599 ,  1.82597946,  3.28765297,  6.10959692,
         9.56550009, 12.36021317, 13.84239876, 13.32106499, 12.69352841,
        13.12131288, 11.28368488,  9.69319149,  9.38160001,  9.03113202,
         7.05978494,  5.48594933,  4.91750024,  3.1108927 ,  3.16715066,
         1.61765083,  2.46486222,  1.68065109,  1.05804321,  0.60996375]),
 array([ 50. ,  67.2,  84.4, 101.6, 118.8, 136. , 153.2, 170.4, 187.6,
        204.8, 222. , 239.2, 256.4, 273.6, 290.8, 308. , 325.2, 342.4,
        359.6, 376.8, 394. , 411.2, 428.4, 445.6, 462.8, 480. ]))

In [134]:
! ls hists_tests.pkl

hists_tests.pkl


In [135]:
with open("hists_tests.pkl", "rb") as f:
    hists_tests = pickle.load(f)

In [139]:
hists_tests["cat1_sr"][{"fj_pt":sum}]

In [115]:
! ls ../melissa/datacards/datacards_jan182021/allhad4top_2017_shapehists.root

../melissa/datacards/datacards_jan182021/allhad4top_2017_shapehists.root


In [117]:
import uproot
f = uproot.open("../melissa/datacards/datacards_jan182021/allhad4top_2017_shapehists.root")

In [120]:
f.keys()

['cut2bin1_2017_data_obs;1',
 'cut2bin1_2017_TTTT;1',
 'cut2bin1_2017_TTX;1',
 'cut2bin1_2017_other;1',
 'cut2bin1_2017_DDBKG;1',
 'cut2bin0_2017_data_obs;1',
 'cut2bin0_2017_TTTT;1',
 'cut2bin0_2017_TTX;1',
 'cut2bin0_2017_other;1',
 'cut2bin0_2017_DDBKG;1',
 'cut1bin0_2017_data_obs;1',
 'cut1bin0_2017_TTTT;1',
 'cut1bin0_2017_TTX;1',
 'cut1bin0_2017_other;1',
 'cut1bin0_2017_DDBKG;1',
 'cut1bin1_2017_data_obs;1',
 'cut1bin1_2017_TTTT;1',
 'cut1bin1_2017_TTX;1',
 'cut1bin1_2017_other;1',
 'cut1bin1_2017_DDBKG;1',
 'cut0bin1_2017_data_obs;1',
 'cut0bin1_2017_TTTT;1',
 'cut0bin1_2017_TTX;1',
 'cut0bin1_2017_other;1',
 'cut0bin1_2017_DDBKG;1',
 'cut0bin0_2017_data_obs;1',
 'cut0bin0_2017_TTTT;1',
 'cut0bin0_2017_TTX;1',
 'cut0bin0_2017_other;1',
 'cut0bin0_2017_DDBKG;1',
 'cut0bin3_2017_data_obs;1',
 'cut0bin3_2017_TTTT;1',
 'cut0bin3_2017_TTX;1',
 'cut0bin3_2017_other;1',
 'cut0bin3_2017_DDBKG;1',
 'cut0bin2_2017_data_obs;1',
 'cut0bin2_2017_TTTT;1',
 'cut0bin2_2017_TTX;1',
 'cut0bin2_2

In [130]:
h = f["cut0bin6_2017_TTX_btagHFstats1_2017Up"]
h

<TH1F (version 2) at 0x7f96917ef100>

In [131]:
uproot.from_pyroot(h)

ModuleNotFoundError: No module named 'ROOT'

# Plot histograms

In [102]:
! ls /Users/fmokhtar/Desktop/hww/templates

[34m23May13[m[m           [34m23May13MP[m[m         AN2021_126_v4.pdf


In [106]:
with open('/Users/fmokhtar/Desktop/hww/templates/23May13/2017_templates.pkl', 'rb') as f:
    data = pickle.load(f)

In [107]:
data.keys()

dict_keys(['pass', 'fail', 'pass_JES_up', 'fail_JES_up', 'pass_JES_down', 'fail_JES_down', 'pass_JER_up', 'fail_JER_up', 'pass_JER_down', 'fail_JER_down', 'pass_JMS_up', 'fail_JMS_up', 'pass_JMS_down', 'fail_JMS_down', 'pass_JMR_up', 'fail_JMR_up', 'pass_JMR_down', 'fail_JMR_down', 'passBlinded', 'failBlinded', 'pass_JES_upBlinded', 'fail_JES_upBlinded', 'pass_JES_downBlinded', 'fail_JES_downBlinded', 'pass_JER_upBlinded', 'fail_JER_upBlinded', 'pass_JER_downBlinded', 'fail_JER_downBlinded', 'pass_JMS_upBlinded', 'fail_JMS_upBlinded', 'pass_JMS_downBlinded', 'fail_JMS_downBlinded', 'pass_JMR_upBlinded', 'fail_JMR_upBlinded', 'pass_JMR_downBlinded', 'fail_JMR_downBlinded'])

In [109]:
plt.rcParams.update({"font.size": 20})
data["pass"]

In [None]:
#### Inputs to the function
    events_dict: Dict[str, pd.DataFrame],
    bb_masks: Dict[str, pd.DataFrame],
    year: str,
    sig_keys: List[str],
    selection_regions: Dict[str, Region],
    shape_vars: List[ShapeVar],
    systematics: Dict,
    template_dir: str = "",
    bg_keys: List[str] = bg_keys,
    plot_dir: str = "",
    prev_cutflow: pd.DataFrame = None,
    weight_key: str = "finalWeight",
    sig_splits: List[List[str]] = None,
    weight_shifts: Dict = {},
    jshift: str = "",
    plot_shifts: bool = False,
    pass_ylim: int = None,
    fail_ylim: int = None,
    blind_pass: bool = False,
    show: bool = False,

In [None]:
# (1)

events_dict = _load_samples(args, bg_samples, sig_samples, cutflow)

### inputs
cutflow = pd.DataFrame(index=list(all_samples.keys()))  # save cutflow as pandas table

### inputs built from another function
sig_samples = res_samples if args.resonant else nonres_samples

if args.read_sig_samples:
    # read all signal samples in directory
    read_year = args.year if args.year != "all" else "2017"
    read_samples = os.listdir(f"{args.signal_data_dir}/{args.year}")
    sig_samples = OrderedDict()
    for sample in read_samples:
        if sample.startswith("NMSSM_XToYHTo2W2BTo4Q2B_MX-"):
            mY = int(sample.split("-")[-1])
            mX = int(sample.split("NMSSM_XToYHTo2W2BTo4Q2B_MX-")[1].split("_")[0])

            sig_samples[f"X[{mX}]->H(bb)Y[{mY}](VV)"] = sample

if args.sig_samples is not None:
    for sig_key, sample in list(sig_samples.items()):
        if sample not in args.sig_samples:
            del sig_samples[sig_key]

bg_samples = deepcopy(samples)
for bg_key, sample in list(bg_samples.items()):
    if bg_key not in args.bg_keys and bg_key != data_key:
        del bg_samples[bg_key]

if not args.resonant:
    for key in sig_samples.copy():
        if key not in BDT_sample_order:
            del sig_samples[key]

    for key in bg_samples.copy():
        if key not in BDT_sample_order:
            del bg_samples[key]

if not args.data:
    del bg_samples[data_key]

# (4)
sig_keys = list(sig_samples.keys())
bg_keys = list(bg_samples.keys())


### the function itself
def _load_samples(args, samples, sig_samples, cutflow):
    filters = old_filters if args.old_processor else new_filters
    events_dict = utils.load_samples(args.signal_data_dir, sig_samples, args.year, filters)
    events_dict |= utils.load_samples(args.data_dir, samples, args.year, filters)
    utils.add_to_cutflow(events_dict, "Pre-selection", "weight", cutflow)

    print("")
    # print weighted sample yields
    wkey = "finalWeight" if "finalWeight" in list(events_dict.values())[0] else "weight"
    for sample in events_dict:
        tot_weight = np.sum(events_dict[sample][wkey].values)
        print(f"Pre-selection {sample} yield: {tot_weight:.2f}")

    return events_dict


### helper functions
def utils.load_samples(
    data_dir: str,
    samples: Dict[str, str],
    year: str,
    filters: List = None,
    columns: List = None,
) -> Dict[str, pd.DataFrame]:
    """
    Loads events with an optional filter.
    Reweights samples by nevents.

    Args:
        data_dir (str): path to data directory.
        samples (Dict[str, str]): dictionary of samples and selectors to load.
        year (str): year.
        filters (List): Optional filters when loading data.

    Returns:
        Dict[str, pd.DataFrame]: ``events_dict`` dictionary of events dataframe for each sample.

    """

    from os import listdir

    full_samples_list = listdir(f"{data_dir}/{year}")
    events_dict = {}

    for label, selector in samples.items():
        events_dict[label] = []
        for sample in full_samples_list:
            if not check_selector(sample, selector):
                continue

            # print(sample)
            # if sample.startswith("QCD") and not sample.endswith("_PSWeights_madgraph"):
            #     continue

            if not exists(f"{data_dir}/{year}/{sample}/parquet"):
                print(f"No parquet file for {sample}")
                continue

            # print(f"Loading {sample}")
            events = pd.read_parquet(
                f"{data_dir}/{year}/{sample}/parquet", filters=filters, columns=columns
            )
            not_empty = len(events) > 0
            pickles_path = f"{data_dir}/{year}/{sample}/pickles"

            if label != data_key:
                if label in nonres_sig_keys + res_sig_keys:
                    n_events = get_cutflow(pickles_path, year, sample)["has_4q"]
                else:
                    n_events = get_nevents(pickles_path, year, sample)

                if not_empty:
                    if "weight_noxsec" in events:
                        if np.all(events["weight"] == events["weight_noxsec"]):
                            print(f"WARNING: {sample} has not been scaled by its xsec and lumi")

                    events["weight_nonorm"] = events["weight"]

                    if "weight_noTrigEffs" in events and not np.all(
                        np.isclose(events["weight"], events["weight_noTrigEffs"], rtol=1e-5)
                    ):
                        events["finalWeight"] = events["weight"] / n_events
                        events["finalWeight_noTrigEffs"] = events["weight_noTrigEffs"] / n_events
                    else:
                        events["weight"] /= n_events

            if not_empty:
                events_dict[label].append(events)

            print(f"Loaded {sample: <50}: {len(events)} entries")

        if len(events_dict[label]):
            events_dict[label] = pd.concat(events_dict[label])
        else:
            del events_dict[label]

    return events_dict

def add_to_cutflow(
    events_dict: Dict[str, pd.DataFrame], key: str, weight_key: str, cutflow: pd.DataFrame
):
    cutflow[key] = [
        np.sum(events_dict[sample][weight_key]).squeeze() for sample in list(cutflow.index)
    ]

In [None]:
# (2)
bb_masks = bb_VV_assignment(events_dict)
sig_keys, sig_samples, bg_keys, bg_samples = _process_samples(args)

In [None]:
# (3) 
year = "2017"

In [None]:
# (5)
selection_regions = (
    get_nonres_selection_regions(args.year, **cutargs)
    if not args.resonant
    else get_res_selection_regions(args.year, **cutargs)
)


def get_nonres_selection_regions(
    year: str,
    txbb_wp: str = "HP",
    bdt_wp: float = 0.99,
):
    pt_cuts = [300, CUT_MAX_VAL]
    txbb_cut = txbb_wps[year][txbb_wp]

    return {
        # {label: {cutvar: [min, max], ...}, ...}
        "pass": Region(
            cuts={
                "bbFatJetPt": pt_cuts,
                "VVFatJetPt": pt_cuts,
                "BDTScore": [bdt_wp, CUT_MAX_VAL],
                "bbFatJetParticleNetMD_Txbb": [txbb_cut, CUT_MAX_VAL],
            },
            label="Pass",
        ),
        "fail": Region(
            cuts={
                "bbFatJetPt": pt_cuts,
                "VVFatJetPt": pt_cuts,
                "bbFatJetParticleNetMD_Txbb": [0.8, txbb_cut],
            },
            label="Fail",
        ),
        "lpsf": Region(
            cuts={  # cut for which LP SF is calculated
                "BDTScore": [bdt_wp, CUT_MAX_VAL],
            },
            label="LP SF Cut",
        ),
    }


def get_res_selection_regions(
    year: str, mass_window: List[float] = [110, 145], txbb_wp: str = "HP", thww_wp: float = 0.96
):
    pt_cuts = [300, CUT_MAX_VAL]
    mwsize = mass_window[1] - mass_window[0]
    mw_sidebands = [
        [mass_window[0] - mwsize / 2, mass_window[0]],
        [mass_window[1], mass_window[1] + mwsize / 2],
    ]
    txbb_cut = txbb_wps[year][txbb_wp]

    return {
        # "unblinded" regions:
        "pass": Region(
            cuts={
                "bbFatJetPt": pt_cuts,
                "VVFatJetPt": pt_cuts,
                "bbFatJetParticleNetMass": mass_window,
                "bbFatJetParticleNetMD_Txbb": [txbb_cut, CUT_MAX_VAL],
                "VVFatJetParTMD_THWWvsT": [thww_wp, CUT_MAX_VAL],
            },
            label="Pass",
        ),
        "fail": Region(
            cuts={
                "bbFatJetPt": pt_cuts,
                "VVFatJetPt": pt_cuts,
                "bbFatJetParticleNetMass": mass_window,
                "bbFatJetParticleNetMD_Txbb": [0.8, txbb_cut],
                "VVFatJetParTMD_THWWvsT": [-CUT_MAX_VAL, thww_wp],
            },
            label="Fail",
        ),
        # "blinded" validation regions:
        "passBlinded": Region(
            cuts={
                "bbFatJetPt": pt_cuts,
                "VVFatJetPt": pt_cuts,
                "bbFatJetParticleNetMass": mw_sidebands,
                "bbFatJetParticleNetMD_Txbb": [txbb_cut, CUT_MAX_VAL],
                "VVFatJetParTMD_THWWvsT": [thww_wp, CUT_MAX_VAL],
            },
            label="Validation Pass",
        ),
        "failBlinded": Region(
            cuts={
                "bbFatJetPt": pt_cuts,
                "VVFatJetPt": pt_cuts,
                "bbFatJetParticleNetMass": mw_sidebands,
                "bbFatJetParticleNetMD_Txbb": [0.8, txbb_cut],
                "VVFatJetParTMD_THWWvsT": [-CUT_MAX_VAL, thww_wp],
            },
            label="Validation Fail",
        ),
        # cut for which LP SF is calculated
        "lpsf": Region(
            cuts={"VVFatJetParTMD_THWWvsT": [thww_wp, CUT_MAX_VAL]},
            label="LP SF Cut",
        ),
    }

In [None]:
# (6)

# fitting on bb regressed mass for nonresonant
nonres_shape_vars = [
    ShapeVar(
        "bbFatJetParticleNetMass",
        r"$m^{bb}_{Reg}$ (GeV)",
        [20, 50, 250],
        reg=True,
        blind_window=[100, 150],
    )
]


# fitting on VV regressed mass + dijet mass for resonant
res_shape_vars = [
    ShapeVar(
        "VVFatJetParticleNetMass",
        r"$m^{VV}_{Reg}$ (GeV)",
        list(range(50, 110, 10)) + list(range(110, 200, 15)) + [200, 220, 250],
        reg=False,
    ),
    ShapeVar(
        "DijetMass",
        r"$m^{jj}$ (GeV)",
        list(range(800, 1400, 100)) + [1400, 1600, 2000, 3000, 4400],
        reg=False,
    ),
]

shape_vars, scan, scan_cuts, scan_wps = _init(args)

def _init(args):
    if not (args.control_plots or args.templates or args.scan):
        print("You need to pass at least one of --control-plots, --templates, or --scan")
        return

    if not args.resonant:
        scan = len(args.nonres_txbb_wp) > 1 or len(args.nonres_bdt_wp) > 1
        scan_wps = list(itertools.product(args.nonres_txbb_wp, args.nonres_bdt_wp))
        scan_cuts = nonres_scan_cuts
        shape_vars = nonres_shape_vars
    else:
        scan = len(args.res_txbb_wp) > 1 or len(args.res_thww_wp) > 1
        scan_wps = list(itertools.product(args.res_txbb_wp, args.res_thww_wp))
        scan_cuts = res_scan_cuts
        shape_vars = res_shape_vars

    return shape_vars, scan, scan_cuts, scan_wps

In [None]:
# (7)

systematics = _check_load_systematics(systs_file, args.year)

def _check_load_systematics(systs_file: str, year: str):
    if os.path.exists(systs_file):
        print("Loading systematics")
        with open(systs_file, "r") as f:
            systematics = json.load(f)
    else:
        systematics = {}

    if year not in systematics:
        systematics[year] = {}

    return systematics

In [None]:
# (8) 
# TODO: check which of these applies to resonant as well
weight_shifts = {
    "pileup": Syst(samples=nonres_sig_keys + res_sig_keys + bg_keys, label="Pileup"),
    "PDFalphaS": Syst(samples=nonres_sig_keys, label="PDF"),
    "ISRPartonShower": Syst(samples=nonres_sig_keys + ["V+Jets"], label="ISR Parton Shower"),
    "FSRPartonShower": Syst(samples=nonres_sig_keys + ["V+Jets"], label="FSR Parton Shower"),
    "L1EcalPrefiring": Syst(
        samples=nonres_sig_keys + res_sig_keys + bg_keys,
        years=["2016APV", "2016", "2017"],
        label="L1 ECal Prefiring",
    ),
    # "top_pt": ["TT"],
}

In [None]:
def _get_fill_data(
    events: pd.DataFrame, bb_mask: pd.DataFrame, shape_vars: List[ShapeVar], jshift: str = ""
):
    return {
        shape_var.var: utils.get_feat(
            events,
            shape_var.var if jshift == "" else utils.check_get_jec_var(shape_var.var, jshift),
            bb_mask,
        )
        for shape_var in shape_vars
    }

def get_feat(events: pd.DataFrame, feat: str, bb_mask: pd.DataFrame = None):
    if feat in events:
        return events[feat].values.squeeze()
    elif feat.startswith("bb") or feat.startswith("VV"):
        assert bb_mask is not None, "No bb mask given!"
        return events["ak8" + feat[2:]].values[bb_mask ^ feat.startswith("VV")].squeeze()



def get_templates(
    events_dict: Dict[str, pd.DataFrame],
    bb_masks: Dict[str, pd.DataFrame],
    year: str,
    sig_keys: List[str],
    selection_regions: Dict[str, Region],
    shape_vars: List[ShapeVar],
    systematics: Dict,
    template_dir: str = "",
    bg_keys: List[str] = bg_keys,
    plot_dir: str = "",
    prev_cutflow: pd.DataFrame = None,
    weight_key: str = "finalWeight",
    sig_splits: List[List[str]] = None,
    weight_shifts: Dict = {},
    jshift: str = "",
    plot_shifts: bool = False,
    pass_ylim: int = None,
    fail_ylim: int = None,
    blind_pass: bool = False,
    show: bool = False,
) -> Dict[str, Hist]:
    """
    (1) Makes histograms for each region in the ``selection_regions`` dictionary,
    (2) Applies the Txbb scale factor in the pass region,
    (3) Calculates trigger uncertainty,
    (4) Calculates weight variations if ``weight_shifts`` is not empty (and ``jshift`` is ""),
    (5) Takes JEC / JSMR shift into account if ``jshift`` is not empty,
    (6) Saves a plot of each (if ``plot_dir`` is not "").

    Args:
        selection_region (Dict[str, Dict]): Dictionary of ``Region``s including cuts and labels.
        bg_keys (list[str]): background keys to plot.

    Returns:
        Dict[str, Hist]: dictionary of templates, saved as hist.Hist objects.

    """
    do_jshift = jshift != ""
    jlabel = "" if not do_jshift else "_" + jshift
    templates = {}

    for rname, region in selection_regions.items():
        pass_region = rname.startswith("pass")

        if rname == "lpsf":
            continue

        if not do_jshift:
            print(rname)

        # make selection, taking JEC/JMC variations into account
        sel, cf = utils.make_selection(
            region.cuts, events_dict, bb_masks, prev_cutflow=prev_cutflow, jshift=jshift
        )

        if template_dir != "":
            cf.to_csv(f"{template_dir}/{rname}_cutflow{jlabel}.csv")

        # trigger uncertainties
        if not do_jshift:
            systematics[year][rname] = {}
            total, total_err = corrections.get_uncorr_trig_eff_unc(events_dict, bb_masks, year, sel)
            systematics[year][rname]["trig_total"] = total
            systematics[year][rname]["trig_total_err"] = total_err
            print(f"Trigger SF Unc.: {total_err / total:.3f}\n")

        # ParticleNetMD Txbb and ParT LP SFs
        sig_events = {}
        for sig_key in sig_keys:
            sig_events[sig_key] = deepcopy(events_dict[sig_key][sel[sig_key]])
            sig_bb_mask = bb_masks[sig_key][sel[sig_key]]

            if pass_region:
                # scale signal by LP SF
                for wkey in [weight_key, f"{weight_key}_noTrigEffs"]:
                    sig_events[sig_key][wkey] *= systematics[sig_key]["lp_sf"]

                corrections.apply_txbb_sfs(sig_events[sig_key], sig_bb_mask, year, weight_key)

        # if not do_jshift:
        #     print("\nCutflow:\n", cf)

        # set up samples
        hist_samples = list(events_dict.keys())

        if not do_jshift:
            # set up weight-based variations
            for shift in ["down", "up"]:
                if pass_region:
                    for sig_key in sig_keys:
                        hist_samples.append(f"{sig_key}_txbb_{shift}")

                for wshift, wsyst in weight_shifts.items():
                    # if year in wsyst.years:
                    # add to the axis even if not applied to this year to make it easier to sum later
                    for wsample in wsyst.samples:
                        if wsample in events_dict:
                            hist_samples.append(f"{wsample}_{wshift}_{shift}")

        # histograms
        h = Hist(
            hist.axis.StrCategory(hist_samples, name="Sample"),
            *[shape_var.axis for shape_var in shape_vars],
            storage="weight",
        )

        # fill histograms
        for sample in events_dict:
            events = sig_events[sample] if sample in sig_keys else events_dict[sample][sel[sample]]
            if not len(events):
                continue

            bb_mask = bb_masks[sample][sel[sample]]
            fill_data = _get_fill_data(
                events, bb_mask, shape_vars, jshift=jshift if sample != data_key else None
            )
            weight = events[weight_key].values.squeeze()
            h.fill(Sample=sample, **fill_data, weight=weight)

            if not do_jshift:
                # add weight variations
                for wshift, wsyst in weight_shifts.items():
                    if sample in wsyst.samples and year in wsyst.years:
                        # print(wshift)
                        for skey, shift in [("Down", "down"), ("Up", "up")]:
                            # reweight based on diff between up/down and nominal weights
                            sweight = (
                                weight
                                * (
                                    events[f"weight_{wshift}{skey}"][0] / events["weight_nonorm"]
                                ).values.squeeze()
                            )
                            h.fill(Sample=f"{sample}_{wshift}_{shift}", **fill_data, weight=sweight)

        if pass_region:
            # blind signal mass windows in pass region in data
            for i, shape_var in enumerate(shape_vars):
                if shape_var.blind_window is not None:
                    utils.blindBins(h, shape_var.blind_window, data_key, axis=i)

        if pass_region and not do_jshift:
            for sig_key in sig_keys:
                if not len(sig_events[sig_key]):
                    continue

                # ParticleNetMD Txbb SFs
                fill_data = _get_fill_data(
                    sig_events[sig_key], bb_masks[sig_key][sel[sig_key]], shape_vars
                )
                for shift in ["down", "up"]:
                    h.fill(
                        Sample=f"{sig_key}_txbb_{shift}",
                        **fill_data,
                        weight=sig_events[sig_key][f"{weight_key}_txbb_{shift}"],
                    )

        templates[rname + jlabel] = h

        # plot templates incl variations
        if plot_dir != "" and (not do_jshift or plot_shifts):
            if pass_region:
                sig_scale_dict = {"HHbbVV": 1, **{skey: 1 for skey in res_sig_keys}}

            title = (
                f"{region.label} Region Pre-Fit Shapes"
                if not do_jshift
                else f"{region.label} Region {jshift} Shapes"
            )

            if sig_splits is None:
                sig_splits = [sig_keys]

            for i, shape_var in enumerate(shape_vars):
                for j, plot_sig_keys in enumerate(sig_splits):
                    split_str = "" if len(sig_splits) == 1 else f"sigs{j}_"
                    plot_params = {
                        "hists": h.project(0, i + 1),
                        "sig_keys": plot_sig_keys,
                        "bg_keys": bg_keys,
                        "sig_scale_dict": {key: sig_scale_dict[key] for key in plot_sig_keys}
                        if pass_region
                        else None,
                        "show": show,
                        "year": year,
                        "ylim": pass_ylim if pass_region else fail_ylim,
                        "plot_data": not (rname == "pass" and blind_pass),
                    }

                    plot_name = (
                        f"{plot_dir}/"
                        f"{'jshifts/' if do_jshift else ''}"
                        f"{split_str}{rname}_region_{shape_var.var}"
                    )

                    plotting.ratioHistPlot(
                        **plot_params,
                        title=title,
                        name=f"{plot_name}{jlabel}.pdf",
                    )

                    if not do_jshift and plot_shifts:
                        plot_name = (
                            f"{plot_dir}/wshifts/" f"{split_str}{rname}_region_{shape_var.var}"
                        )

                        for wshift, wsyst in weight_shifts.items():
                            if wsyst.samples == [sig_key]:
                                plotting.ratioHistPlot(
                                    **plot_params,
                                    sig_err=wshift,
                                    title=f"{region.label} Region {wsyst.label} Unc. Shapes",
                                    name=f"{plot_name}_{wshift}.pdf",
                                )
                            else:
                                for skey, shift in [("Down", "down"), ("Up", "up")]:
                                    plotting.ratioHistPlot(
                                        **plot_params,
                                        variation=(wshift, shift, wsyst.samples),
                                        title=f"{region.label} Region {wsyst.label} Unc. {skey} Shapes",
                                        name=f"{plot_name}_{wshift}_{shift}.pdf",
                                    )

                        if pass_region:
                            plotting.ratioHistPlot(
                                **plot_params,
                                sig_err="txbb",
                                title=rf"{region.label} Region $T_{{Xbb}}$ Shapes",
                                name=f"{plot_name}_txbb.pdf",
                            )

    return templates