In [None]:
from collections import OrderedDict

import uproot
import numpy as np
import matplotlib.pyplot as plt

import hist
from hist import Hist

import plotting
from hh_vars import data_key
from postprocessing import res_shape_vars, get_res_selection_regions

import os

In [None]:
# Automatically reload edited local modules (e.g. plotting, postprocessing)
# before executing code, so changes to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
MAIN_DIR = "../../../"

# All post-fit plots produced by this notebook are written under this directory.
plot_dir = os.path.join(MAIN_DIR, "plots/PostFit/23May2NewWP")
# Create it without spawning a shell; unlike ``os.system("mkdir -p ...")``,
# a failure (e.g. permissions) raises here instead of being silently ignored
# and only surfacing later when a plot fails to save.
os.makedirs(plot_dir, exist_ok=True)

In [None]:
# Datacard directory of the scan point being plotted, relative to the cards area.
# NOTE: the MX/MY in this path should match the ``mx, my`` used to build the
# sample labels in the next cell.
cards_dir = "23May1Scan/txbb_HP_thww_0.8/NMSSM_XToYHTo2W2BTo4Q2B_MX-3000_MY-250/"
cards_base_dir = "/uscms/home/rkansal/eos/bbVV/cards"
fit_shapes_path = f"{cards_base_dir}/{cards_dir}/FitShapes.root"

# Kept open for the rest of the notebook: later cells read templates from it.
file = uproot.open(fit_shapes_path)

In [None]:
mx, my = 3000, 250

# Ordered (label used in the templates/plots, process name in the datacards) pairs.
# Diboson is currently not included.
_label_card_pairs = [
    ("QCD", "XHYbbWW_boosted_qcd_datadriven"),
    ("TT", "ttbar"),
    ("ST", "singletop"),
    ("V+Jets", "vjets"),
    (f"X[{mx}]->H(bb)Y[{my}](VV)", f"xhy_mx{mx}_my{my}"),
    (data_key, "data_obs"),
]
hist_label_map_inverse = OrderedDict(_label_card_pairs)

# Reverse lookup: datacard process name -> template label.
hist_label_map = {}
for _label, _card_name in hist_label_map_inverse.items():
    hist_label_map[_card_name] = _label

# Template labels, in the same order as above.
samples = list(hist_label_map.values())

In [None]:
# Fit stages to plot: suffix used in the FitShapes.root directory names -> plot label.
# (The S+B post-fit, "shapes_fit_s", is currently disabled.)
shapes = dict(
    [
        ("prefit", "Pre-Fit"),
        ("postfit", "B-only Post-Fit"),
    ]
)

shape_vars = res_shape_vars

# Selection region key -> human-readable label used in plot titles.
selection_regions = dict(
    [
        ("pass", "Pass"),
        ("fail", "Fail"),
        ("passBlinded", "Validation Pass"),
        ("failBlinded", "Validation Fail"),
    ]
)

In [None]:
# Build one Hist per (fit shape, selection region) with axes
# (Sample, shape_vars[0], shape_vars[1]), filled from the per-mX-bin templates
# stored in FitShapes.root under "mXbin{i}{region}_{shape}".
hists = {}

# Bin widths of the first shape variable; currently unused, kept for optional
# per-bin-width scaling of the templates.
bins = list(shape_vars[0].axis)
binsize = np.array([b[1] - b[0] for b in bins])

# Build the Sample categories from the label map itself rather than from the
# global ``samples``: a later plotting cell reuses that name, and re-running this
# cell afterwards would otherwise drop the data/signal categories.
sample_labels = list(hist_label_map_inverse)
data_card_name = hist_label_map_inverse[data_key]

for shape in shapes:
    hists[shape] = {}

    for region in selection_regions:
        h = Hist(
            hist.axis.StrCategory(sample_labels, name="Sample"),
            *[shape_var.axis for shape_var in shape_vars],
            storage="double",
        )
        hists[shape][region] = h

        # Category label -> index, computed once per histogram instead of per bin.
        sample_index = {label: j for j, label in enumerate(h.axes[0])}
        # Writable view of the bin contents (no flow bins).
        values = h.view(flow=False)

        for mx_bin in range(len(shape_vars[1].axis)):  # mX bins
            templates = file[f"mXbin{mx_bin}{region}_{shape}"]

            for label, card_name in hist_label_map_inverse.items():
                if label == data_key:
                    continue
                if card_name not in templates:
                    # Process absent for this mX bin / region: leave its bins at zero.
                    continue
                values[sample_index[label], :, mx_bin] = templates[card_name].values()

            # Data may contain NaNs; replace them (nan_to_num) before storing.
            values[sample_index[data_key], :, mx_bin] = np.nan_to_num(
                templates[data_card_name].values()
            )

In [None]:
pass_ylim = 300
fail_ylim = 170000

# One ratio plot per (fit shape, selection region, shape variable) projection.
for shape, shape_label in shapes.items():
    for region, region_label in selection_regions.items():
        # "pass"/"passBlinded" share one y-limit, "fail"/"failBlinded" the other.
        region_ylim = pass_ylim if region.startswith("pass") else fail_ylim
        # Data is drawn in every region except "pass" (presumably kept blinded).
        show_data = region != "pass"
        region_hist = hists[shape][region]

        # Axis 0 is the Sample axis, so shape variable k lives on axis k + 1.
        for axis_idx, shape_var in enumerate(shape_vars, start=1):
            plotting.ratioHistPlot(
                hists=region_hist.project(0, axis_idx),
                sig_keys=[f"X[{mx}]->H(bb)Y[{my}](VV)"],
                bg_keys=["QCD", "V+Jets", "TT", "ST"],
                sig_scale_dict=None,
                show=True,
                year="all",
                ylim=region_ylim,
                plot_data=show_data,
                title=f"{shape_label} {region_label} Region",
                name=f"{plot_dir}/{shape}_{region}_{shape_var.var}.pdf",
            )

In [None]:
# Per-sample 2D (shape_vars[0] x shape_vars[1]) plots via plotting.hist2ds,
# written into one sub-directory per fit shape.
signal_label = f"X[{mx}]->H(bb)Y[{my}](VV)"

for shape in shapes:
    shape_plot_dir = f"{plot_dir}/{shape}"
    os.makedirs(shape_plot_dir, exist_ok=True)

    # Dedicated name: reassigning the global ``samples`` here would break
    # re-running the histogram-filling cell, which builds its Sample axis from it.
    plot_samples = ["TT", "V+Jets", "QCD"]
    if shape == "prefit":
        # Shape keys are "prefit"/"postfit"; the old "shapes_prefit" check never
        # matched, and its hard-coded "Y[190]" label did not exist on the Sample axis.
        plot_samples.append(signal_label)
        # NOTE(review): the stale list also included data. It is left out because
        # hist2ds gets all regions at once, including "pass", where the 1D plots
        # deliberately hide data. Add ``data_key`` only if unblinding is intended.

    plotting.hist2ds(
        hists[shape],
        f"{shape_plot_dir}/",
        regions=list(selection_regions),
        region_labels=selection_regions,
        samples=plot_samples,
    )