In [1]:
import uproot
import awkward as ak
import numpy as np
import hist
import matplotlib.pyplot as plt
import mplhep as hep
plt.style.use(hep.style.CMS)
from collections import defaultdict
import correctionlib
import os, json

No selection is applied on the fatjet, since that is nested in res2b. This is probably the best region, but there is no single obvious choice:
({elliptical_cut_90}) && bjet1_idx>=0 && bjet2_idx>=0 && ({cat_reqs["HPSTau"]})"

https://indico.cern.ch/event/1451226/contributions/6253601/attachments/2977214/5241283/HHbbtautau_updates_29_11_2024.pdf


https://btv-wiki.docs.cern.ch/PerformanceCalibration/shapeCorrectionSFRecommendations/

weights.all = ["genWeightFixed", "puWeight", "trigSF",
            "idAndIsoAndFakeSF", "L1PreFiringWeight", "PUjetID_SF",
            "DYstitchWeight", "pdfWeight", "scaleWeight"
            ]
        if self.year == 2018:
            weights.all.append("hem_weight") # hem issue is only for 2018
        elif self.year >= 2016 and self.year <= 2017: pass
        else: raise ValueError()
        for category in self.categories:
            weights[category.name] = copy(weights.all) # very important to copy otherwise all updates are common to categories !
            if not "resolved_" in category.name:
                weights[category.name].append("fatjet_pnet_SF")
            if not "boosted_bb_boostedTau" in category.name: # we only consider boosted_bb for boostedTau thus no need for AK4 btag SFs
                weights[category.name].append("bTagweightReshape")

In [2]:
# Exploratory sample: concatenate the first two preprocessed tt_sl files and
# keep only the branches needed for the weight / b-tag checks below.
ar = uproot.concatenate(
    [
        f"/grid_mnt/data__data.polcms/cms/cuisset/cmt/PreprocessRDF/bul_2017_ZZ_v12/tt_sl/cat_base_selection/prod_250218/data_{file_index}.root:Events"
        for file_index in range(2)
    ],
    filter_name=[
        # event weights and scale factors
        "genWeightFixed", "*Weight", "trigSF", "idAndIsoAndFakeSF", "PUjetID_SF",
        "hem_weight", "fatjet_pnet_SF", "bTagweightReshape_*", "L1PreFiringWeight_Nom",
        # selected b-jet candidates and pair information
        "bjet1_JetIdx", "bjet2_JetIdx", "pairType", "isBoostedTau",
        "bjet1_btagDeepFlavB", "bjet2_btagDeepFlavB",
        # jet-level information
        "Jet_hadronFlavour", "jets_hem_preselection",
        # multiplicities
        "nJet", "LHE_Nb", "LHE_Njets",
    ],
)
ar

In [None]:
# Year and production tag; both are interpolated into the input paths below.
year = 2018
version = "prod_250315"

# DeepJet threshold that the BTV correction set returns for working point "M".
deepjet_wp = (
    correctionlib.CorrectionSet
    .from_file(f"/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/{year}_UL/btagging.json.gz")
    ["deepJet_wp_values"]
    .evaluate("M")
)

# b-tag reshaping systematics that are not tied to a JEC source.
btag_systs_nojec = [
    "cferr1", "cferr2",
    "hf", "hfstats1", "hfstats2",
    "lf", "lfstats1", "lfstats2",
]

# b-tag reshaping systematics named after JEC sources.
btag_systs_jec = [
    "jesFlavorQCD",
    "jesRelativeBal",
    "jesHF",
    "jesBBEC1",
    "jesEC2",
    "jesAbsolute",
    f"jesBBEC1_{year}",
    f"jesEC2_{year}",
    f"jesAbsolute_{year}",
    f"jesHF_{year}",
    f"jesRelativeSample_{year}",
]

# Maps the "jec_N" labels used in file names to the suffix that follows
# "bTagweightReshape" in the corresponding weight branch.
jec_nba_to_branch_map = dict([
    ("jec_1", "_smeared_FlavorQCD"),
    ("jec_2", "_smeared_RelativeBal"),
    ("jec_3", "_smeared_HF"),
    ("jec_4", "_smeared_BBEC1"),
    ("jec_5", "_smeared_EC2"),
    ("jec_6", "_smeared_Absolute"),
    ("jec_7", f"_smeared_BBEC1_{year}"),
    ("jec_8", f"_smeared_EC2_{year}"),
    ("jec_9", f"_smeared_Absolute_{year}"),
    ("jec_10", f"_smeared_HF_{year}"),
    ("jec_11", f"_smeared_RelativeSample_{year}"),
])

In [3]:
def select_region(ar_):
    """Return a boolean mask of the events belonging to the studied region.

    An event passes when ``bjet1_JetIdx`` and ``bjet2_JetIdx`` are both >= 0,
    ``isBoostedTau`` is false and ``pairType`` is >= 0.
    """
    has_bjet1 = ar_.bjet1_JetIdx >= 0
    has_bjet2 = ar_.bjet2_JetIdx >= 0
    not_boosted_tau = ~(ar_.isBoostedTau)
    has_pair = ar_.pairType >= 0
    return has_bjet1 & has_bjet2 & not_boosted_tau & has_pair
def get_base_weights(ar_):
    """Return the per-event product of the base weights (no b-tag reshape weight).

    The product covers genWeightFixed, puWeight, trigSF, idAndIsoAndFakeSF,
    L1PreFiringWeight_Nom and PUjetID_SF. pdfWeight and scaleWeight are not
    included.
    """
    weight_fields = (
        "genWeightFixed",
        "puWeight",
        "trigSF",
        "idAndIsoAndFakeSF",
        "L1PreFiringWeight_Nom",
        "PUjetID_SF",
    )
    # Multiply left to right, in the order listed above.
    product = getattr(ar_, weight_fields[0])
    for field_name in weight_fields[1:]:
        product = product * getattr(ar_, field_name)
    return product
def getFactor(ar_):
    """Return sum(base weights) / sum(base weights * bTagweightReshape_smeared).

    The sums run over all events of ``ar_``; apply any selection before calling.
    """
    weights = get_base_weights(ar_)
    yield_without_reshape = ak.sum(weights)
    yield_with_reshape = ak.sum(weights * ar_.bTagweightReshape_smeared)
    return yield_without_reshape / yield_with_reshape
def filter_res2b(ar_):
    """Mask of events where both selected b-jets have a DeepJet score >= ``deepjet_wp``."""
    first_tagged = ar_.bjet1_btagDeepFlavB >= deepjet_wp
    second_tagged = ar_.bjet2_btagDeepFlavB >= deepjet_wp
    return first_tagged & second_tagged
def filter_res1b(ar_):
    """Mask of events where at least one selected b-jet has a DeepJet score
    >= ``deepjet_wp`` and the event is not selected by ``filter_res2b``."""
    first_tagged = ar_.bjet1_btagDeepFlavB >= deepjet_wp
    second_tagged = ar_.bjet2_btagDeepFlavB >= deepjet_wp
    at_least_one_tagged = first_tagged | second_tagged
    return at_least_one_tagged & (~filter_res2b(ar_))
def filter_res0b(ar_):
    """Mask of events where both selected b-jets have a DeepJet score < ``deepjet_wp``."""
    first_untagged = ar_.bjet1_btagDeepFlavB < deepjet_wp
    second_untagged = ar_.bjet2_btagDeepFlavB < deepjet_wp
    return first_untagged & second_untagged

In [37]:
# Normalisation factor for the whole selected region (no b-tag category split).
ar_region = ar[select_region(ar)]
base_weights = get_base_weights(ar_region)
# getFactor evaluates sum(base weights) / sum(base weights * reshape weight).
getFactor(ar_region)

0.9877586121169092

In [14]:
# Normalisation factor in the 0b, 1b and 2b sub-regions, in that order.
tuple(
    getFactor(ar[select_region(ar) & btag_filter(ar)])
    for btag_filter in (filter_res0b, filter_res1b, filter_res2b)
)

(0.8333372943486229, 0.9776176429196627, 1.0992696876700334)

In [None]:
def computeForDataset(dataset_name, max_events=1e4, jec_uncertainty=None):
    """Compute the b-tag reshape normalisation ratios of one dataset.

    For every category the ratio is
    ``sum(base weights) / sum(base weights * b-tag reshape weight)``,
    accumulated over the events passing ``select_region``.

    Input files ``data_{i}.root`` (or ``data_{jec_uncertainty}_{i}.root``) are read
    for i = 0, 1, ... until a file with i >= 1 is missing, every category has more
    than ``max_events`` selected events, or the inclusive category has more than
    ``5 * max_events``. A missing first file is not skipped: ``uproot.open`` raises.

    Relies on the notebook-level names ``year``, ``version``, ``btag_systs_nojec``,
    ``jec_nba_to_branch_map``, ``select_region`` and ``get_base_weights``.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset directory.
    max_events : number
        Threshold on the number of selected events used by the stopping rule.
    jec_uncertainty : str or None
        ``None`` (or any falsy value) computes the nominal ratio and the up/down
        variations of ``btag_systs_nojec``. Otherwise a label such as
        ``"jec_1_up"``: the matching files are read and only that variation
        is computed.

    Returns
    -------
    dict
        Keys ``ratios`` (category -> systematic -> ratio),
        ``sums_weights_noReshapeWeight`` (category -> sum),
        ``sum_weights_withReshapeWeight`` (category -> systematic -> sum) and
        ``processed_event_count`` (category -> number of selected events).

    Raises
    ------
    ValueError
        If ``jec_uncertainty`` does not end in ``_up`` or ``_down``.
    KeyError
        If the JEC source of ``jec_uncertainty`` is not in ``jec_nba_to_branch_map``.
    """
    pattern_start = f"/grid_mnt/data__data.polcms/cms/cuisset/cmt/PreprocessRDF/bul_{year}_ZZ_v12/"
    pattern_end = f"/cat_base_selection/{version}/"

    # pairType values accepted by each category; None means no pairType requirement.
    category_pair_types = {
        "inclusive": None,
        "etau": (1,),
        "mutau": (0,),
        "tautau": (2,),
        "e&mutau": (0, 1),
    }
    category_names = list(category_pair_types)

    # Resolve once, before any file I/O, which reshape-weight branch feeds each
    # systematic label. Invalid JEC labels therefore fail immediately.
    if jec_uncertainty:
        jec_source, _, direction = jec_uncertainty.rpartition("_")
        if direction not in ("up", "down"):
            raise ValueError(f"jec_uncertainty must end in '_up' or '_down', got {jec_uncertainty!r}")
        file_prefix = f"data_{jec_uncertainty}_"
        reshape_branches = {
            jec_uncertainty: f"bTagweightReshape{jec_nba_to_branch_map[jec_source]}_{direction}"
        }
    else:
        file_prefix = "data_"
        reshape_branches = {"nominal": "bTagweightReshape_smeared"}
        # All "_up" variations first, then all "_down" ones.
        for direction in ("up", "down"):
            for syst in btag_systs_nojec:
                label = f"{syst}_{direction}"
                reshape_branches[label] = f"bTagweightReshape_smeared_{label}"

    sums_weights_noReshapeWeight = {cat: 0. for cat in category_names}
    sums_weights_withReshapeWeight = {
        cat: {syst: 0. for syst in reshape_branches} for cat in category_names
    }
    processed_event_count = {cat: 0 for cat in category_names}

    i = 0
    while True:
        file_path = pattern_start + dataset_name + pattern_end + f"{file_prefix}{i}.root"
        # Only stop silently after the first file; a missing data_0 should surface as an error.
        if not os.path.isfile(file_path) and i >= 1:
            break
        with uproot.open(file_path + ":Events") as t:
            ar = t.arrays(filter_name=[
                "genWeightFixed", "puWeight", "trigSF", "idAndIsoAndFakeSF", "PUjetID_SF", "bTagweightReshape_*", "L1PreFiringWeight_Nom",  # "hem_weight", "fatjet_pnet_SF",
                "bjet1_JetIdx", "bjet2_JetIdx", "pairType", "isBoostedTau",
            ])
            ar_region = ar[select_region(ar)]
            base_weights = get_base_weights(ar_region)

            for category_name, accepted_pair_types in category_pair_types.items():
                if accepted_pair_types is None:
                    cat_filter = ak.full_like(base_weights, True, dtype=bool)
                else:
                    cat_filter = ar_region.pairType == accepted_pair_types[0]
                    for pair_type in accepted_pair_types[1:]:
                        cat_filter = cat_filter | (ar_region.pairType == pair_type)

                # Category weights are shared by every systematic: compute them once.
                cat_weights = base_weights[cat_filter]
                sums_weights_noReshapeWeight[category_name] += ak.sum(cat_weights)
                for syst, branch_name in reshape_branches.items():
                    sums_weights_withReshapeWeight[category_name][syst] += ak.sum(
                        cat_weights * ar_region[branch_name][cat_filter]
                    )

                processed_event_count[category_name] += ak.count_nonzero(cat_filter)

        if min(processed_event_count.values()) > max_events or processed_event_count["inclusive"] > max_events*5:
            break
        i += 1

    if not jec_uncertainty and min(processed_event_count.values()) < 1000:
        print(f"### WARNING low processed event counts: {processed_event_count}")

    # Look the sums up by category name rather than relying on parallel iteration order.
    ratios = {
        cat: {
            syst: sums_weights_noReshapeWeight[cat] / sum_with_reshape
            for syst, sum_with_reshape in sums_weights_withReshapeWeight[cat].items()
        }
        for cat in category_names
    }
    return dict(
        ratios=ratios,
        sums_weights_noReshapeWeight=sums_weights_noReshapeWeight,
        sum_weights_withReshapeWeight=sums_weights_withReshapeWeight,
        processed_event_count=processed_event_count,
    )

def map_dict(f, x):
    """Apply ``f`` to every non-dict leaf of ``x`` and return the result.

    Nested dicts are rebuilt with the same keys; a non-dict ``x`` yields ``f(x)``.
    """
    if not isinstance(x, dict):
        return f(x)
    mapped = {}
    for key, value in x.items():
        mapped[key] = map_dict(f, value)
    return mapped


In [32]:
# Spot check: run a single JEC variation ("jec_1_up") on the dy_ptz6 dataset.
computeForDataset("dy_ptz6", jec_uncertainty="jec_1_up")

{'ratios': {'inclusive': {'jec_1_up': 0.8374379908966685},
  'etau': {'jec_1_up': 0.8385973751357593},
  'mutau': {'jec_1_up': 0.8395373378865951},
  'tautau': {'jec_1_up': 0.8330411106041994},
  'e&mutau': {'jec_1_up': 0.8391860509011405}},
 'sums_weights_noReshapeWeight': {'inclusive': 5452.569314999917,
  'etau': 1460.0926867548112,
  'mutau': 2449.520281853298,
  'tautau': 1542.9563463918096,
  'e&mutau': 3909.612968608108},
 'sum_weights_withReshapeWeight': {'inclusive': {'jec_1_up': 6511.012605436848},
  'etau': {'jec_1_up': 1741.1128749579486},
  'mutau': {'jec_1_up': 2917.702609891759},
  'tautau': {'jec_1_up': 1852.197120587138},
  'e&mutau': {'jec_1_up': 4658.815484849707}},
 'processed_event_count': {'inclusive': 33846,
  'etau': 9436,
  'mutau': 15965,
  'tautau': 8445,
  'e&mutau': 25401}}

In [None]:
# Compute the ratios of every simulated dataset, for the nominal b-tag
# systematics (jec_uncertainty=None) and for each JEC up/down variation.
results_perDataset = {}
# List the same year that computeForDataset reads, instead of a hard-coded one.
datasets_dir = f"/grid_mnt/data__data.polcms/cms/cuisset/cmt/PreprocessRDF/bul_{year}_ZZ_v12"
jec_variations = [None] + [s+"_up" for s in jec_nba_to_branch_map] + [s+"_down" for s in jec_nba_to_branch_map]
for dataset in os.listdir(datasets_dir):
    if dataset.startswith(("data_", "GluGluToXToZZ", "qcd_")): continue # or dataset == "wjets_ht0"
    for jec_uncertainty in jec_variations:
        try:
            res = computeForDataset(dataset, jec_uncertainty=jec_uncertainty)
        except Exception as e:
            # Best effort: keep going, but say which dataset / variation failed.
            print(f"### ERROR dataset={dataset} jec_uncertainty={jec_uncertainty}: {type(e).__name__}: {e}")
            continue
        if not jec_uncertainty:
            nominal_res = {key:val['nominal'] for key, val in res['ratios'].items()}
            print(f"{dataset} ratio={nominal_res} evtsProcessed={res['processed_event_count']['inclusive']}")
        results_perDataset.setdefault(dataset, dict())[jec_uncertainty] = res

# Convert every leaf to a plain float before serialising.
with open(f"studies/SFs/btag_extrap_ratio_{year}_systs_v2.json", "w") as f:
    json.dump(map_dict(float, results_perDataset), f,  indent=4)

dy_ptz6 ratio={'inclusive': 0.8375235735408024, 'etau': 0.8385479752383694, 'mutau': 0.840165458980845, 'tautau': 0.8323973586835032, 'e&mutau': 0.8395604708033134} evtsProcessed=33671
dy_ptz5 ratio={'inclusive': 0.901365585142078, 'etau': 0.8968222920644164, 'mutau': 0.9021081373449437, 'tautau': 0.9050731223458535, 'e&mutau': 0.9000914175086435} evtsProcessed=48220
dy_ptz4 ratio={'inclusive': 0.9385984299222828, 'etau': 0.9327742438545163, 'mutau': 0.9339309509700956, 'tautau': 0.949723359680108, 'e&mutau': 0.9335203279210954} evtsProcessed=52339
dy_ptz3 ratio={'inclusive': 0.9727412050220756, 'etau': 0.9637564638342211, 'mutau': 0.976924974213991, 'tautau': 0.9739332784793494, 'e&mutau': 0.9725152281476755} evtsProcessed=50098
ggf_sm ratio={'inclusive': 0.976788288577461, 'etau': 0.9787802041216921, 'mutau': 0.9777257939838655, 'tautau': 0.973693053712535, 'e&mutau': 0.9780629772235058} evtsProcessed=38228
dy ratio={'inclusive': 0.967059853040508, 'etau': 0.9526260300583895, 'mutau'

In [10]:
# Inspect the type of one stored ratio value.
type(results_perDataset["zz_lnu"]["ratios"]["inclusive"])

numpy.float64

In [13]:
# Save the results as JSON. map_dict converts every leaf to a plain float at any
# nesting depth, so this works for the flat and the per-systematic layouts alike.
# The file name follows `year` so results of one year are not written under another.
with open(f"btag_extrap_ratio_{year}.json", "w") as f:
    json.dump(map_dict(float, results_perDataset), f, indent=4)


In [44]:
# Nominal-only pass over every simulated dataset (no JEC variations).
results_perDataset = {}
# List the directory of the same `year` that computeForDataset reads from.
for dataset in os.listdir(f"/grid_mnt/data__data.polcms/cms/cuisset/cmt/PreprocessRDF/bul_{year}_ZZ_v12"):
    if dataset.startswith(("data_", "GluGluToXToZZ")) or dataset == "wjets_ht0":
        continue
    res = computeForDataset(dataset)
    print(f"{dataset} ratio={res['ratios']} evtsProcessed={res['processed_event_count']['inclusive']}")
    results_perDataset[dataset] = res

wz_lllnu ratio={'inclusive': 0.926000383370768, 'etau': 0.9228495588897558, 'mutau': 0.9288177740705607, 'tautau': 0.9140606665960284, 'e&mutau': 0.9264106285862679} evtsProcessed=105409
zh_htt ratio={'inclusive': 0.9583144582194365, 'etau': 0.9531750440333058, 'mutau': 0.9618755439152351, 'tautau': 0.9567469608443468, 'e&mutau': 0.9586890985376116} evtsProcessed=195350
zz_dl ratio={'inclusive': 0.8965101521510584, 'etau': 0.8885969556770379, 'mutau': 0.905433418932915, 'tautau': 0.8628600506322012, 'e&mutau': 0.8982833178173455} evtsProcessed=301599
dy_ptz4 ratio={'inclusive': 0.9502086681017841, 'etau': 0.9485156728100025, 'mutau': 0.950553642343978, 'tautau': 0.9521857010201995, 'e&mutau': 0.9495426821438657} evtsProcessed=394428
tt_dl ratio={'inclusive': 0.9814520267098701, 'etau': 0.9811377329127998, 'mutau': 0.9816854229353031, 'tautau': 0.9795100381022406, 'e&mutau': 0.9814719057634458} evtsProcessed=6322376
dy_1j ratio={'inclusive': 0.9635050349974497, 'etau': 0.959790499561007

In [47]:
# Per-dataset (e&mutau, tautau) ratios, ordered by dataset name.
{
    dataset: (
        results_perDataset[dataset]["ratios"]["e&mutau"],
        results_perDataset[dataset]["ratios"]["tautau"],
    )
    for dataset in sorted(results_perDataset)
}

{'dy': (0.9665163682206754, 0.961849797034292),
 'dy_0j': (0.9541915876857185, 0.9570455663308225),
 'dy_1j': (0.9633952914676255, 0.9664634545248503),
 'dy_2j': (0.9708396330638043, 0.9677263086266815),
 'dy_ptz1': (0.9672996164293796, 0.9726843265055428),
 'dy_ptz2': (0.9680616355218153, 0.9661950977330135),
 'dy_ptz3': (0.9710256109646898, 0.9727792603504616),
 'dy_ptz4': (0.9495426821438657, 0.9521857010201995),
 'dy_ptz5': (0.9127935788972893, 0.9144365980445499),
 'dy_ptz6': (0.8502322126757778, 0.8521666267791477),
 'ewk_wminus': (0.9842871083277758, 0.9767306141951412),
 'ewk_wplus': (0.9853165696207518, 0.9785948333625821),
 'ewk_z': (0.96733578204037, 0.966468541030316),
 'ggH_ZZ': (0.9689697985698629, 0.9733999150175852),
 'ggf_sm': (0.9767404270011315, 0.9792871394717136),
 'st_antitop': (0.9870256957450483, 0.9924326486915046),
 'st_top': (0.984888080022178, 0.9832185532982413),
 'st_tw_antitop': (0.9873616092348638, 0.9836377709272521),
 'st_tw_top': (0.9892851723253863, 