# Parse, combine and interpolate limits

In [None]:
from __future__ import annotations

import os
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import mplhep as hep
import numpy as np
import pandas as pd
from scipy import interpolate
from tqdm import tqdm

from HHbbVV.hh_vars import res_sigs
from HHbbVV.postprocessing import plotting
from HHbbVV.postprocessing.utils import mxmy
from HHbbVV.resonant import ProcessLimits
from HHbbVV.resonant.ProcessLimits import get_lim

In [None]:
# Reload imported modules (e.g. the HHbbVV helpers) before each cell runs,
# so library edits take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
MAIN_DIR = "../../../"

# Every plot produced by this notebook is written here; created on first run.
plot_dir = Path(f"{MAIN_DIR}/plots/XHY/Limits/25Feb22Unblinding")
plot_dir.mkdir(exist_ok=True, parents=True)

# Cards / fit outputs parsed by ProcessLimits.get_limits.
# Alternative set of cards (swap the comment to use it):
# cards_dir = Path("/eos/uscms/store/user/rkansal/bbVV/cards/25Feb27QCDTF")
cards_dir = Path("/eos/uscms/store/user/rkansal/bbVV/cards/25Feb19ResUnblinded")

## Load / process limits

In [None]:
# Boosted-analysis limits parsed from `cards_dir`. Later cells index this as
# limits[key] with key in {" 2.5", "16.0", "50.0", "84.0", "97.5", "Observed",
# "Significance"}; each value is an array whose rows are (mX, mY, value).
limits = ProcessLimits.get_limits(cards_dir)

### Load Amitav's (semi-boosted) limits

In [None]:
# Directory holding Amitav's (semi-boosted) limit CSV files.
alimits_path = Path(
    "/uscms/home/ammitra/nobackup/2DAlphabet/fitting/CMSSW_14_1_0_pre4/src/XHYbbWW/limits/"
)

# Translation from the keys used by ProcessLimits (left) to Amitav's CSV file stems (right).
key_map = {
    " 2.5": "limits_Minus2",
    "16.0": "limits_Minus1",
    "50.0": "limits_Expected",
    "84.0": "limits_Plus1",
    "97.5": "limits_Plus2",
    "Observed": "limits_OBSERVED",
    "Significance": "significance",
}

# Load each CSV and drop its first column so the layout matches `limits`.
alimits = {
    mkey: pd.read_csv(alimits_path / f"{akey}.csv").values[:, 1:]
    for mkey, akey in key_map.items()
}

Minimum median expected limits

In [None]:
# Smallest median expected limit: boosted first, then semi-boosted (column 2 = limit).
for lims in (limits, alimits):
    print(np.min(lims["50.0"][:, 2]))

Checking excesses (observed limit above the 97.5% expected quantile)

In [None]:
# Mass points whose observed limit lies above the 97.5% expected quantile.
twosigma = limits["97.5"][:, 2] < limits["Observed"][:, 2]

selected = [limits[k][twosigma] for k in ("50.0", "97.5", "Observed", "Significance")]
for med_row, hi_row, obs_row, sig_row in zip(*selected):
    mx, my = med_row[:2]
    print(
        f"({mx}, {my}):\t Expected {med_row[2]}+{hi_row[2]}\t Observed {obs_row[2]:.2f}\t Sign {sig_row[2]:.2f}"
    )

# print(limits["50.0"][twosigma], limits["97.5"][twosigma], limits["Observed"][twosigma])

In [None]:
# Display the semi-boosted row (mX, mY, significance) with the largest significance.
asig = alimits["Significance"]
asig[asig[:, 2].argmax()]

## Process and plot

### Boosted alone

In [None]:
# Log-spaced (mX, mY) grid for the boosted-only plots; mY is capped at `mymax`.
mymax = 250
mxs = np.logspace(np.log10(600), np.log10(3999), 100, base=10)
mys = np.logspace(np.log10(60), np.log10(mymax), 100, base=10)

xx, yy = np.meshgrid(mxs, mys)

interpolated = {}
grids = {}

for key, val in limits.items():
    if key == "Significance":
        # Significances are not guaranteed to be positive; np.log of a value <= 0 is
        # -inf/nan and would blank out whole triangles of the grid, so interpolate
        # them on a linear scale.
        interpolated[key] = interpolate.LinearNDInterpolator(val[:, :2], val[:, 2])
        grids[key] = interpolated[key](xx, yy)
    else:
        # Limits span orders of magnitude: interpolate linearly in log space.
        interpolated[key] = interpolate.LinearNDInterpolator(val[:, :2], np.log(val[:, 2]))
        grids[key] = np.exp(interpolated[key](xx, yy))

In [None]:
# One colour-mesh PDF per grid. The labels match the combined-limit plots, so the
# "Observed" and "Significance" grids are no longer mislabelled as
# "<key>% expected exclusion limits".
for key, grid in grids.items():
    if key == "50.0":
        label = "Median expected exclusion limits (fb)"
    elif key == "Observed":
        label = "Exclusion limits (fb)"
    elif key == "Significance":
        label = "Signal Significance"
    else:
        label = f"{key}% expected exclusion limits (fb)"

    plotting.colormesh(xx, yy, grid, label, f"{plot_dir}/upper{mymax}_mesh_{key}_turbo.pdf")

In [None]:
# Scatter plot of the boosted median expected limit at each (mX, mY) point.
key = "50.0"
val = limits[key]
scatter_label = "Median expected exclusion limits (fb)"
plotting.scatter2d(val, scatter_label, f"{plot_dir}/scatter_{key}.pdf")

### Check which analysis has the better expected limit

In [None]:
# Collect the (mX, mY) points where the semi-boosted median expected limit (alimits)
# is lower than the boosted one (limits), with the fractional improvement relative
# to the boosted limit.
sb_better = []
alim_med = alimits["50.0"]

for mx, my, lim in limits["50.0"]:
    match = (alim_med[:, 0] == mx) & (alim_med[:, 1] == my)
    if not np.any(match):
        # No semi-boosted limit at this point, so there is nothing to compare.
        continue

    alim = float(alim_med[match, 2][0])
    if alim < lim:
        pbetter = (lim - alim) / lim
        print(f"Semiboosted better for ({mx}, {my}) by {pbetter * 100:.2f}%")
        sb_better.append([mx, my, pbetter])

# reshape keeps an (N, 3) layout even when no point qualifies
sb_better = np.array(sb_better).reshape(-1, 3)

In [None]:
# Overlay the points where the semi-boosted analysis wins (sb_better) on the
# boosted median expected limits, and show the figure inline.
overlay_label = "Median expected exclusion limits (fb)"
overlay_path = f"{plot_dir}/scatter_overlay.pdf"
plotting.scatter2d_overlay(limits["50.0"], sb_better, overlay_label, overlay_path, show=True)

### Combined

In [None]:
# Combine both analyses point by point. For every (mX, mY) with mX >= 900 found in
# either result, take all quantiles / observed / significance from whichever analysis
# has the lower (better) median expected limit.
combined_limits = {
    " 2.5": [],
    "16.0": [],
    "50.0": [],
    "84.0": [],
    "97.5": [],
    "Observed": [],
    "Significance": [],
}
alim_med = alimits["50.0"]
blim_med = limits["50.0"]

checked_mxmy = []

for mx, my in np.vstack((alim_med, blim_med))[:, :2]:
    mxy = (int(mx), int(my))
    if mx < 900:
        continue

    # Most points appear in both result sets. Record the point before any further
    # `continue`, so that skipped points are also handled (and reported) only once.
    if mxy in checked_mxmy:
        continue
    checked_mxmy.append(mxy)

    amatch, alim = get_lim(alim_med, mxy)
    bmatch, blim = get_lim(blim_med, mxy)

    # A missing result counts as an infinitely weak limit.
    alim = alim[0, 2] if np.any(amatch) else np.inf
    blim = blim[0, 2] if np.any(bmatch) else np.inf

    if alim < blim and (my < 200):
        # skipping samples for which 2018 PFNano failed !! :(
        print(f"Skipping {mxy} because of missing PFNano!")
        continue

    use_lims = alimits if alim < blim else limits

    for key, lims in combined_limits.items():
        umatch, _ = get_lim(use_lims[key], mxy)
        if np.any(umatch):
            lims.append([*mxy, use_lims[key][umatch][0, 2]])
        else:
            print(f"Missing {mxy} for {key}!")

for key, val in combined_limits.items():
    # reshape keeps an (N, 3) layout even if no point was selected for this key,
    # so later `[:, 2]` indexing does not fail
    combined_limits[key] = np.array(val).reshape(-1, 3)

In [None]:
# Print the most significant point of each analysis: semi-boosted first, then boosted.
for sig in (alimits["Significance"], limits["Significance"]):
    idx = np.argmax(sig[:, 2])
    print(sig[idx])

Checking excesses in the combined limits

In [None]:
# Combined mass points whose observed limit lies above the 97.5% expected quantile.
twosigma = combined_limits["97.5"][:, 2] < combined_limits["Observed"][:, 2]

picked = [combined_limits[k][twosigma] for k in ("50.0", "97.5", "Observed", "Significance")]
for med_row, hi_row, obs_row, sig_row in zip(*picked):
    mx, my = med_row[:2]
    print(
        f"({mx}, {my}): Expected {med_row[2]}+{hi_row[2]}\t Observed {obs_row[2]:.2f}\t Sign {sig_row[2]:.2f}"
    )

In [None]:
# Fine log-spaced (mX, mY) grid covering the combined result (mY up to 2800).
mxs = np.logspace(np.log10(800), np.log10(3999), 300, base=10)
mys = np.logspace(np.log10(60), np.log10(2800), 300, base=10)
cxx, cyy = np.meshgrid(mxs, mys)

for key, val in combined_limits.items():
    if key == "Significance":
        # Significances are not guaranteed to be positive; np.log of a value <= 0 is
        # -inf/nan and would blank out parts of the grid, so interpolate linearly.
        interpolated = interpolate.LinearNDInterpolator(val[:, :2], val[:, 2])
        grid = interpolated(cxx, cyy)
    else:
        # Limits span orders of magnitude: interpolate linearly in log space.
        interpolated = interpolate.LinearNDInterpolator(val[:, :2], np.log(val[:, 2]))
        grid = np.exp(interpolated(cxx, cyy))

    if key == "50.0":
        label = "Median expected exclusion limits (fb)"
    elif key == "Observed":
        label = "Exclusion limits (fb)"
    elif key == "Significance":
        label = "Signal Significance"
    else:
        label = f"{key}% expected exclusion limits (fb)"

    plotting.colormesh(
        cxx, cyy, grid, label, f"{plot_dir}/combined_mesh_{key}.pdf", figsize=(12, 8), show=False
    )