In [None]:
from __future__ import annotations

import os
from collections import OrderedDict
import pickle

import hist
import matplotlib.pyplot as plt
import numpy as np
import plotting
import uproot
from HHbbVV.hh_vars import data_key, bg_keys, years, qcd_key
from hist import Hist
from HHbbVV.postprocessing.postprocessing import res_shape_vars
from pathlib import Path

from datacardHelpers import sum_templates

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Base directory (relative to the notebook's working directory) under which
# the plots are written.
MAIN_DIR = "../../../"

# Output directory for this set of plots; created together with any missing parents.
plot_dir = Path(MAIN_DIR) / "plots" / "QCDTF" / "25Feb2722"
plot_dir.mkdir(parents=True, exist_ok=True)

# ROOT file whose per-bin fit shapes are read further below.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
file = uproot.open(
    "/home/users/rkansal/hhcombine/cards/qcdftests_readw2/nTF22/FitShapesB.root"
)

In [None]:
# Ordered pairs of (name in templates, name in cards)
hist_label_map_inverse = OrderedDict(
    (
        ("QCD Fit", "qcd"),
        ("QCD MC", "data_obs"),
    )
)

# Reverse lookup: name in cards -> name in templates.
hist_label_map = dict(
    (card_name, template_name)
    for template_name, card_name in hist_label_map_inverse.items()
)
# Template-side sample names, in the order declared above.
samples = [*hist_label_map_inverse]

In [None]:
# Directory holding one pickled templates file per year ("<year>_templates.pkl").
templates_dir = Path("templates/25Feb23ResTemplatesHbbUncs")

# Unpickle every year's templates, keyed by year.
# NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
templates_dict = {}
for year in years:
    year_path = templates_dir / f"{year}_templates.pkl"
    with year_path.open("rb") as f:
        templates_dict[year] = pickle.load(f)

# Per-year templates handed to sum_templates together with the list of years.
pre_templates = sum_templates(templates_dict, years)

In [None]:
# Fit stages and their display labels (the pre-fit entry is currently disabled).
shapes = {
    # "prefit": "Pre-Fit",
    "postfit": "B-only Post-Fit",
}

# Variables defining the histogram axes; taken from the postprocessing module.
shape_vars = res_shape_vars

# Region key -> display label, in plotting order.
selection_regions = OrderedDict(
    (
        ("fail", "SR Fail"),
        ("pass", "SR Pass"),
    )
)

In [None]:
# Fit stage whose shapes are read from the FitShapes file (a key of `shapes`).
shape = "postfit"

# One weighted histogram per selection region, with a leading "Sample" category
# axis followed by one axis per shape variable.
hists = {
    region: Hist(
        hist.axis.StrCategory(samples, name="Sample"),
        *[shape_var.axis for shape_var in shape_vars],
        storage="weight",
    )
    for region in selection_regions
}

for region in selection_regions:
    h = hists[region]

    # Position of every sample along the category axis. This does not depend on
    # the mX bin, so it is computed once per region rather than inside the loops.
    sample_index = {label: idx for idx, label in enumerate(h.axes[0])}

    for i in range(len(shape_vars[1].axis)):  # mX bins
        # The file stores one directory of shapes per (mX bin, region, fit stage).
        templates = file[f"mXbin{i}{region}_{shape}"]

        for key, file_key in hist_label_map_inverse.items():
            if file_key == "qcd":
                # QCD prediction as stored in the fit-shapes file.
                source = templates[file_key]
            else:
                # Every other sample (here "QCD MC") is taken from the QCD entry of
                # `pre_templates` for this mX bin, not from the file.
                source = pre_templates[region][qcd_key, :, i]

            # Fetch the source object once, then write its contents in place
            # through the histogram's value / variance views.
            h.values(flow=False)[sample_index[key], :, i] = source.values()
            h.variances(flow=False)[sample_index[key], :, i] = source.variances()

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import mplhep as hep

# Apply the CMS plotting style (a single call is sufficient).
hep.style.use("CMS")

# Scientific-notation tick formatter; it is not attached to any axis in this cell.
formatter = mticker.ScalarFormatter(useMathText=True)
formatter.set_powerlimits((-3, 3))

# For each shape variable, draw the fail and pass regions side by side:
# QCD fit as a filled histogram with a hatched uncertainty band, QCD MC as points.
for i, shape_var in enumerate(shape_vars):
    fig, axs = plt.subplots(
        1,
        2,
        figsize=(24, 11),
        gridspec_kw=dict(wspace=0.17),
    )

    for j, region in enumerate(selection_regions):
        ax = axs[j]
        # Project onto (Sample, current shape variable) and pass the result through
        # the project's bin-width division helper (first returned element is used).
        h = plotting._divide_bin_widths(hists[region].project(0, i + 1), 1, 1, None)[0]

        # Slice each sample once and reuse it for the plot and the error band.
        qcd_fit = h["QCD Fit", ...]
        qcd_mc = h["QCD MC", ...]

        hep.histplot(
            qcd_fit,
            ax=ax,
            histtype="fill",
            stack=True,
            label="QCD Fit",
            color=plotting.colours[plotting.BG_COLOURS["QCD"]],
        )

        # +/- one standard deviation band around the QCD fit.
        fit_vals = qcd_fit.values()
        fit_errs = qcd_fit.variances() ** 0.5
        bg_err = [fit_vals - fit_errs, fit_vals + fit_errs]

        # Duplicate the edges / values so the band is drawn as steps that
        # follow the bin boundaries.
        ax.fill_between(
            np.repeat(h.axes[1].edges, 2)[1:-1],
            np.repeat(bg_err[0], 2),
            np.repeat(bg_err[1], 2),
            color="black",
            alpha=0.2,
            hatch="//",
            linewidth=0,
            label="QCD Fit Uncertainty",
        )

        hep.histplot(
            qcd_mc,
            ax=ax,
            histtype="errorbar",
            label="QCD MC",
            color="black",
            xerr=True,
            markersize=15,
        )

        plotting.add_cms_label(ax, "all", loc=0)
        ax.set_ylabel("Events / GeV")
        ax.set_xlabel(shape_var.label)
        ax.legend()

    # Save through the figure object rather than pyplot's implicit current figure.
    fig.savefig(plot_dir / f"{shape_var.var}.pdf")
    plt.show()
    # Release the figure so repeated iterations do not accumulate open figures.
    plt.close(fig)