In [None]:
import pickle, json, gzip
import numpy as np
import hist
from hist import Hist

from typing import Optional, List, Dict
from copy import copy

import matplotlib.pyplot as plt
import mplhep as hep
from matplotlib import colors

from tqdm import tqdm

from pathlib import Path
import os

from HHbbVV.hh_vars import years, bg_keys
from HHbbVV.postprocessing import datacardHelpers
from postprocessing import nonres_shape_vars as shape_vars
import plotting

# Global matplotlib styling applied to every figure produced by this notebook.
# NOTE(review): the CMS style sheet below is applied AFTER this font-size override, so if
# that style sheet defines "font.size" it replaces the 16 set here — confirm the intended order.
plt.rcParams.update({"font.size": 16})
plt.style.use(hep.style.CMS)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Output directory for the figures saved by the plotting cell; created (with parents) if missing.
# Both paths are relative, so they resolve against the current working directory.
plot_dir = Path("../../plots/Interpolate/24Aug1")
plot_dir.mkdir(parents=True, exist_ok=True)

# Directory holding the per-year pickled templates, read below as "<year>_templates.pkl".
templates_dir = Path("templates/24Apr26NonresBDT995AllSigs")

## Load and process templates

In [None]:
# Per-year templates, keyed by year, each passed through `datacardHelpers.rem_neg` after loading.
templates_dict: dict[str, dict[str, Hist]] = {}

for year in years:
    pkl_path = templates_dir / f"{year}_templates.pkl"
    # NOTE: pickle.load can execute arbitrary code from the file; only load trusted templates.
    with pkl_path.open("rb") as pkl_file:
        year_templates = pickle.load(pkl_file)
    templates_dict[year] = datacardHelpers.rem_neg(year_templates)

# Combination of the per-year templates over all `years`, via `datacardHelpers.sum_templates`.
templates = datacardHelpers.sum_templates(templates_dict, years)

In [None]:
# Sample keys of the VBF templates used as interpolation inputs.
# Order matters: `get_hist_interp` substitutes entry j for the symbol xs{j}, whose couplings
# are row j of `csamples` in the interpolation-coefficients cell.
vbf_keys = [
    "VBFHHbbVV",
    "qqHH_CV_1_C2V_1_kl_0_HHbbVV",
    "qqHH_CV_1_C2V_1_kl_2_HHbbVV",
    "qqHH_CV_1_C2V_0_kl_1_HHbbVV",
    "qqHH_CV_1_C2V_2_kl_1_HHbbVV",
    "qqHH_CV_1p5_C2V_1_kl_1_HHbbVV",
]

# For every entry of `templates`: the histograms of the `vbf_keys` samples, in `vbf_keys` order.
vbf_hists = {
    region: [region_hist[sample_key, ...] for sample_key in vbf_keys]
    for region, region_hist in templates.items()
}

## Interpolation coefficients

In [None]:
import sympy

# (CV, C2V, kl) coupling values of the input samples; row j pairs with the symbol xs{j} below.
csamples = [
    # CV, C2V, kl
    (1.0, 1.0, 1.0),
    (1.0, 1.0, 0.0),
    (1.0, 1.0, 2.0),
    (1.0, 0.0, 1.0),
    (1.0, 2.0, 1.0),
    # (0.5, 1.0, 1.0),
    (1.5, 1.0, 1.0),
]


def _coupling_terms(cv, c2v, kl):
    """Return the six coupling monomials, in the single order shared by `M` and `c`."""
    return [
        cv**2 * kl**2,
        cv**4,
        c2v**2,
        cv**3 * kl,
        cv * c2v * kl,
        cv**2 * c2v,
    ]


# symbolic couplings and the (column) vector of coupling monomials built from them
CV, C2V, kl = sympy.symbols("CV C2V kl")
c = sympy.Matrix(_coupling_terms(CV, C2V, kl))

# one row per input sample: the same monomials evaluated at that sample's couplings
M = sympy.Matrix([_coupling_terms(*couplings) for couplings in csamples])

# the (column) vector of symbolic sample cross sections xs0, xs1, ...
s = sympy.Matrix([sympy.Symbol("xs{}".format(j)) for j in range(len(csamples))])

# actual computation: sigma = c^T * pinv(M) * s, a 1x1 matrix that is linear in the xs{j}
# and polynomial in (CV, C2V, kl)
M_inv = M.pinv()
coeffs = c.T * M_inv
sigma = coeffs * s

In [None]:
def get_hist_interp(cv, c2v, Kl, hists):
    """Interpolate bin yields to the coupling point (cv, c2v, Kl).

    Evaluates the module-level symbolic expression ``sigma`` at the requested couplings and
    applies the resulting linear combination to the input histograms bin by bin.

    Args:
        cv, c2v, Kl: numeric values substituted for the sympy symbols ``CV``, ``C2V`` and
            ``kl`` (the last argument is capitalised so it does not shadow the symbol ``kl``).
        hists: sequence of histograms exposing ``.values()`` and ``.variances()``; entry ``j``
            is substituted for the symbol ``xs{j}``, so the order must match the samples
            used to build ``sigma``.

    Returns:
        ``(counts, errs)`` as float numpy arrays with the shape of ``hists[0].values()``.
        ``counts`` is the weighted sum of the input bin values, with anything below 1e-12
        (including negative results) set to 0. ``errs`` is the same weighted sum applied to
        each input's relative uncertainty ``sqrt(variance) / value``.

    Notes:
        * Bins where an input has zero yield contribute a relative uncertainty of 0
          (previously a non-zero variance over a zero yield became ~1.8e308 via
          ``np.nan_to_num``).
        * NOTE(review): the relative uncertainties are combined linearly with the same signed
          weights as the yields; propagation for uncorrelated inputs would instead sum
          ``weights**2 * variances`` — confirm this is the intended treatment.
    """
    sigma_val = sigma.subs({CV: cv, C2V: c2v, kl: Kl})

    # sigma is linear in the xs{j} symbols, so its numeric weights can be read off once by
    # substituting unit vectors, instead of redoing two sympy substitutions for every bin.
    n_inputs = len(hists)
    xs_syms = [sympy.Symbol(f"xs{j}") for j in range(n_inputs)]
    weights = np.array(
        [
            float(
                sigma_val.subs(
                    {sym: (1.0 if k == j else 0.0) for k, sym in enumerate(xs_syms)}
                )[0, 0]
            )
            for j in range(n_inputs)
        ]
    )

    # stack inputs along a new leading axis: shape (n_inputs, *bin_shape)
    values = np.stack([np.asarray(h.values(), dtype=float) for h in hists])
    variances = np.stack([np.asarray(h.variances(), dtype=float) for h in hists])

    # relative uncertainty per input and bin; empty bins (0/0 or x/0) are mapped to 0
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_errs = np.sqrt(variances) / values
    rel_errs = np.nan_to_num(rel_errs, nan=0.0, posinf=0.0, neginf=0.0)

    # contract the input axis with the weights
    counts = np.tensordot(weights, values, axes=1)
    errs = np.tensordot(weights, rel_errs, axes=1)

    # clip numerically-zero and negative interpolated yields to exactly 0
    counts = np.where(counts < 1e-12, 0.0, counts)

    return counts, errs

## Add interpolated signals to templates

In [None]:
# C2V values to interpolate to (CV and kl are fixed to 1.0 in the call below), and the
# sample name given to each interpolated point.
interp_points = np.arange(-1.0, 3.1, 0.1)
samples = [f"qqHH_CV_1_C2V_{c2v:.1f}_kl_1_HHbbVV" for c2v in interp_points]

# region -> histogram holding one interpolated template per entry of `samples`
interp_hists = {}

for region in ["passvbf", "passggf", "fail"]:
    print(region)
    # take the non-sample axes from this region's own template (not always "passvbf"), so the
    # interpolated histogram matches the template it is merged with in the next cell
    region_hist = Hist(
        hist.axis.StrCategory(samples, name="Sample"),
        *templates[region].axes[1:],
        storage=hist.storage.Weight(),
    )
    # tqdm wraps the array itself so the progress bar knows the total; the loop variable is
    # named `c2v` so it does not overwrite the sympy coupling vector `c`
    for i, c2v in enumerate(tqdm(interp_points)):
        c_h, c_err = get_hist_interp(1.0, c2v, 1.0, vbf_hists[region])
        region_hist.values()[i, :] = c_h
        # get_hist_interp returns relative errors; convert to absolute variances
        region_hist.variances()[i, :] = (c_err * c_h) ** 2

    interp_hists[region] = region_hist

In [None]:
# region -> histogram containing the original template samples plus the interpolated ones
ctemplates = {}

for region in ["passvbf", "passggf", "fail"]:
    template = templates[region]
    # combined sig + bg samples
    # NOTE(review): this rebinds `csamples`, which the interpolation-coefficients cell uses
    # for the list of coupling tuples — consider a distinct name.
    csamples = list(template.axes[0]) + samples

    # new hist with all samples, sharing the non-sample axes of the original template
    ctemplate = Hist(
        hist.axis.StrCategory(csamples, name="Sample"),
        *template.axes[1:],
        storage="weight",
    )

    axis_names = list(ctemplate.axes[0])
    combined_view = ctemplate.view(flow=True)

    # copy the original template samples first, then the interpolated signal samples
    sources = (
        (template, template.axes[0]),
        (interp_hists[region], samples),
    )
    for source_hist, source_samples in sources:
        for sample in source_samples:
            sample_key_index = axis_names.index(sample)
            combined_view[sample_key_index, ...] = source_hist[sample, ...].view(flow=True)

    ctemplates[region] = ctemplate

## Plot

In [None]:
# regions to plot -> label used in the plot title
selection_regions = {
    "passvbf": "VBF",
    "passggf": "ggF",
    # "fail": "Fail",
}

# y-axis limit passed to the plotting function for each region
ylims = {"passggf": 200, "passvbf": 100, "fail": 7e5}

# signal samples to draw -> scale factor handed to the plotting function
sig_scale_dict = {
    # "HHbbVV": 100,
    # "VBFHHbbVV": 2000,
    "qqHH_CV_1_C2V_1.6_kl_1_HHbbVV": 1,
    "qqHH_CV_1_C2V_0.6_kl_1_HHbbVV": 1,
    "qqHH_CV_1_C2V_0_kl_1_HHbbVV": 1,
    "qqHH_CV_1_C2V_2_kl_1_HHbbVV": 1,
}

# one signal-only, log-scale plot per (region, shape variable), saved under `plot_dir`
for region, region_label in selection_regions.items():
    is_pass_region = region.startswith("pass")
    for shape_var in shape_vars:
        plotting.ratioHistPlot(
            hists=ctemplates[region],
            sig_keys=list(sig_scale_dict.keys()),
            bg_keys=[],
            sig_scale_dict=sig_scale_dict if is_pass_region else None,
            show=True,
            year="all",
            ylim=ylims[region],
            title=f"Pre-fit {region_label} Region",
            name=f"{plot_dir}/interp_{region}_{shape_var.var}_signal_log.pdf",
            ncol=2,  # if region == "passvbf" else 1,
            ratio_ylims=[0, 5] if region == "passvbf" else [0, 2],
            cmslabel="Preliminary",
            plot_data=False,
            log=True,
            data_err=True,
        )