In [1]:
from __future__ import annotations

import pickle
from collections import OrderedDict
from pathlib import Path

import hist
import matplotlib.pyplot as plt
import numpy as np
import uproot
from hist import Hist

from boostedhh.hh_vars import data_key, years
from bbtautau.postprocessing.datacardHelpers import sum_templates
from bbtautau.postprocessing.postprocessing import shape_vars
from bbtautau.postprocessing import plotting
from bbtautau.postprocessing import utils as putils
from bbtautau.postprocessing.Samples import BGS, CHANNELS

In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to them are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
# Base directory, three levels above the notebook's working directory
# (presumably the repository root — confirm).
MAIN_DIR = Path("..") / ".." / ".."

# Output directory for every figure written by this notebook; created on demand.
plot_dir = MAIN_DIR.joinpath("plots", "PostFit", "25Apr25AllYears")
plot_dir.mkdir(parents=True, exist_ok=True)

# years = ["2022"]

In [9]:
# Tag of the datacard directory whose fit output is plotted.
cards_dir = "25Apr25PassFix"

# NOTE(review): absolute, user-specific location — the notebook only runs on a
# machine where this path exists.
fit_shapes_path = f"/home/users/rkansal/hhcombine/bbtautau/cards/{cards_dir}/FitShapes.root"
file = uproot.open(fit_shapes_path)

# Directory of the pre-fit template pickles; the fill loop appends the channel
# prefix of each region to it when loading signal templates.
templates_dir = "templates/25Apr25LudoCuts"

In [None]:
# sig_templates_dir = Path(f"templates/25Apr23/{CHANNEL.key}")
# bg_templates_dir = Path(f"templates/25Apr23/{CHANNEL.key}")

# templates_dict = {}
# for year in years:
#     with (sig_templates_dir / f"{year}_templates.pkl").open("rb") as f:
#         templates_dict[year] = pickle.load(f)

# sig_pre_templates = sum_templates(templates_dict, years)

# templates_dict = {}
# for year in years:
#     with (bg_templates_dir / f"{year}_templates.pkl").open("rb") as f:
#         templates_dict[year] = pickle.load(f)

# bg_pre_templates = sum_templates(templates_dict, years)

In [5]:
def get_pre_templates(templates_dir):
    """Load the per-year template pickles from ``templates_dir`` and sum them over years.

    Args:
        templates_dir: directory (``str`` or ``pathlib.Path``) that contains one
            ``{year}_templates.pkl`` file for every entry of the global ``years``.

    Returns:
        The result of ``sum_templates`` applied to the ``{year: templates}``
        mapping loaded from the pickles.

    Raises:
        FileNotFoundError: if the pickle for one of the years is missing.
    """
    # Accept plain strings as well as Path objects.
    templates_dir = Path(templates_dir)

    templates_dict = {}
    for year in years:
        # NOTE: unpickling can execute arbitrary code — only load trusted files.
        with (templates_dir / f"{year}_templates.pkl").open("rb") as f:
            templates_dict[year] = pickle.load(f)

    return sum_templates(templates_dict, years)

In [6]:
# Key under which the observed data is stored in the fit-output file.
workspace_data_key = "data_obs"

# Processes that carry the same name in the templates and in the cards.
_same_name_keys = ("ttbarsl", "ttbarll", "ttbarhad", "dyjets", "wjets", "zjets", "hbb")

# Ordered mapping {name in templates: name in cards}.
hist_label_map_inverse = OrderedDict(
    [("qcd", "CMS_bbtautau_boosted_qcd_datadriven")]
    + [(name, name) for name in _same_name_keys]
    + [(data_key, workspace_data_key)]
)

# Reverse lookup {name in cards: name in templates}.
hist_label_map = dict(zip(hist_label_map_inverse.values(), hist_label_map_inverse.keys()))

sig_keys = ["bbtt"]

# pbg_keys = [bk for bk in bg_keys if bk not in ["Diboson", "Hbb", "HWW"]]
pbg_keys = BGS
# Categories of the "Sample" axis: backgrounds, then signals, then data.
samples = pbg_keys + sig_keys + [data_key]

In [7]:
# Fit stages to plot: {suffix of the directory name in the fit file: display label}.
shapes = dict(
    prefit="Pre-Fit",
    # shapes_fit_s="S+B Post-Fit",
    postfit="B-only Post-Fit",
)

# One pass and one fail region per channel: {region name: display label}.
selection_regions = {}
for channel in CHANNELS.values():
    selection_regions.update(
        {
            f"{channel.key}pass": f"{channel.label} Pass",
            f"{channel.key}fail": f"{channel.label} Fail",
        }
    )

In [None]:
# Build, for every fit stage and region, one histogram with a "Sample" axis
# filled from the fit output, plus the total-background uncertainty per bin.
hists = {}
bgerrs = {}

# Summed pre-fit templates cached per channel, so the per-year pickles are read
# once per channel instead of once per (shape, region, signal key).
_pre_templates_cache = {}


def _sample_index(h, key):
    """Return the position of ``key`` along the first ("Sample") axis of ``h``."""
    return np.where(np.array(list(h.axes[0])) == key)[0][0]


for shape in shapes:
    print(shape)
    hists[shape] = {
        region: Hist(
            hist.axis.StrCategory(samples, name="Sample"),
            *[shape_var.axis for shape_var in shape_vars],
            storage="double",
        )
        for region in selection_regions
    }
    bgerrs[shape] = {}

    for region in selection_regions:
        h = hists[shape][region]
        templates = file[f"{region}_{shape}"]
        # Region names are "<channel key><pass|fail>"; the slicing assumes
        # two-character channel keys.
        channel_key, pass_fail = region[:2], region[2:]

        for key, file_key in hist_label_map_inverse.items():
            if key == data_key:
                # data is filled separately below
                continue
            if file_key not in templates:
                print(f"No {key} in {region}")
                continue

            h.view(flow=False)[_sample_index(h, key), :] = templates[file_key].values()

        # # if key not in fit output, take from templates
        # for key in pbg_keys:
        #     if key not in hist_label_map_inverse:
        #         data_key_index = np.where(np.array(list(h.axes[0])) == key)[0][0]
        #         h.view(flow=False)[data_key_index, :] = bg_pre_templates[region][key, ...].values()

        # if key not in fit output, take from templates
        for key in sig_keys:
            if key in hist_label_map_inverse:
                continue

            if channel_key not in _pre_templates_cache:
                _pre_templates_cache[channel_key] = get_pre_templates(
                    Path(f"{templates_dir}/{channel_key}")
                )
            sig_pre_templates = _pre_templates_cache[channel_key]

            h.view(flow=False)[_sample_index(h, key), :] = sig_pre_templates[pass_fail][
                key + channel_key, ...
            ].values()

        # Observed data; NaNs in the fit output are replaced by zeros.
        h.view(flow=False)[_sample_index(h, data_key), :] = np.nan_to_num(
            templates[hist_label_map_inverse[data_key]].values()
        )

        # Cap the background uncertainty at the bin content so the band cannot
        # extend below zero.
        bgerrs[shape][region] = np.minimum(
            templates["TotalBkg"].errors(), templates["TotalBkg"].values()
        )

In [9]:
# if not unblinded:
#     for shapeh in hists.values():
#         for region, h in shapeh.items():
#             if region != "fail":
#                 utils.blindBins(h, [100, 150], data_key, axis=0)

In [None]:
# ylims = {"hhpass": 1, "passvbf": 11, "fail": 7e5}
# Factor by which each signal is scaled up for display (pass regions only).
sig_scale_dict = {"bbtt": 100}

(plot_dir / "preliminary").mkdir(exist_ok=True, parents=True)
(plot_dir / "final").mkdir(exist_ok=True, parents=True)

# (CMS label text, output sub-directory) for the two plot flavours.
for plabel, pplotdir in (("Preliminary", "preliminary"), ("", "final")):
    for shape, shape_label in shapes.items():
        # if shape != "postfit":
        #     continue
        for region, region_label in selection_regions.items():
            pass_region = "pass" in region
            for shape_var in shape_vars:
                plot_params = {
                    "hists": hists[shape][region],
                    "sig_keys": sig_keys,
                    "bg_keys": pbg_keys,
                    "bg_err": bgerrs[shape][region],
                    "data_err": True,
                    "sig_scale_dict": sig_scale_dict if pass_region else None,
                    "show": True,
                    # NOTE(review): hard-coded although the output directory is
                    # tagged "AllYears" — confirm what ratioHistPlot uses it for.
                    "year": "2022",
                    # "ylim": ylims[region],
                    # "title": f"{shape_label} {region_label} Region{title_label}",
                    "region_label": region_label,
                    # Write into the flavour's own sub-directory so that the
                    # "final" plots cannot overwrite the "preliminary" ones.
                    "name": f"{plot_dir}/{pplotdir}/{shape}_{region}_{shape_var.var}.pdf",
                    "ratio_ylims": [0, 2],
                    "cmslabel": plabel,
                    "leg_args": {"fontsize": 22, "ncol": 2},
                    # Region names start with the channel key (assumed two characters).
                    "channel": CHANNELS[region[:2]],
                }

                plotting.ratioHistPlot(**plot_params)

        # break
    # Only the first ("Preliminary") flavour is produced; remove this break to
    # also write the "final" plots.
    break

## QCD Transfer Factor

In [None]:
import matplotlib.ticker as mticker
import mplhep as hep

# Apply the CMS plotting style from mplhep.
# NOTE(review): both calls select the CMS style; one of them is presumably
# redundant — confirm before removing.
plt.style.use(hep.style.CMS)
hep.style.use("CMS")
# Scalar tick formatter that renders with math text and switches to scientific
# notation below 1e-3 and above 1e3. It is not used by the cells visible here.
formatter = mticker.ScalarFormatter(useMathText=True)
formatter.set_powerlimits((-3, 3))

In [None]:
# Optional upper y-limits keyed by region name; regions without an entry keep
# matplotlib's automatic upper limit.
# NOTE(review): these keys do not match the "<channel>pass" region names built
# in this notebook — update them if fixed limits are wanted.
ylims = {"passggf": 1e-4, "passvbf": 1e-5}
tfs = {}

# Name of the QCD sample on the "Sample" axis (the template-side name in
# hist_label_map_inverse).
qcd_key = "qcd"

for region, region_label in selection_regions.items():
    # A transfer factor is defined for pass regions only.
    if not region.endswith("pass"):
        continue

    # Matching fail region of the same channel.
    fail_region = region[: -len("pass")] + "fail"

    # Post-fit QCD yield in the pass region divided by that in the fail region,
    # bin by bin.
    tf = hists["postfit"][region][qcd_key, ...] / hists["postfit"][fail_region][qcd_key, ...]
    tfs[region] = tf

    hep.histplot(tf)
    plt.title(f"{region_label} Region")
    plt.ylabel("QCD Transfer Factor")
    plt.xlim([50, 250])
    if region in ylims:
        plt.ylim([0, ylims[region]])
    else:
        plt.ylim(bottom=0)
    plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
    plt.savefig(f"{plot_dir}/{region}_QCDTF.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Straight line through the first and last bins of each transfer factor:
# prints the region name, the slope and the y-intercept.
for region, tf in tfs.items():
    centers = tf.axes[0].centers
    tf_vals = tf.values()
    if len(centers) < 2:
        # A single bin does not define a slope.
        print(region, "needs at least two bins")
        continue

    slope = (tf_vals[-1] - tf_vals[0]) / (centers[-1] - centers[0])
    yint = tf_vals[0] - slope * centers[0]
    print(region, slope, yint)