# Control Plots

Author(s): Haoyang Li, Raghav Kansal

In [None]:
from pathlib import Path

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import mplhep as hep
from matplotlib import colors

from boostedhh import utils, hh_vars, plotting
from boostedhh.utils import PAD_VAL
import Samples, postprocessing
from Samples import CHANNELS, SAMPLES, SIGNALS

import logging

_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Root logging stays at INFO; only the boostedhh.utils logger is made verbose.
logger = logging.getLogger("boostedhh.utils")
logger.setLevel(logging.DEBUG)

In [None]:
# Automatically reload imported modules (e.g. postprocessing, Samples) before
# each cell runs, so edits to their source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
MAIN_DIR = Path("../../../")
CHANNEL = CHANNELS["hm"]  # options: "hh", "he", "hm"

# Per-channel output directory for the control plots (created if missing).
plot_dir = MAIN_DIR / "plots" / "ControlPlots" / f"25Apr15{CHANNEL.key}"
plot_dir.mkdir(parents=True, exist_ok=True)

year = "2022"

base_dir = Path("/ceph/cms/store/user/rkansal/bbtautau/skimmer/")
# NOTE(review): "signal", "data" and "bg" all point at the same
# "private_signal" skim directory — confirm this is intended for data and bg.
_skim_tag = "25Apr17bbpresel_v12_private_signal"
data_paths = {sample_type: base_dir / _skim_tag for sample_type in ("signal", "data", "bg")}

# Signal samples for the selected channel, keyed by "<signal><channel key>".
sigs = {sig_key: SAMPLES[sig_key] for sig_key in (s + CHANNEL.key for s in SIGNALS)}
# Every sample whose type is "bg".
bgs = {name: sample for name, sample in SAMPLES.items() if sample.get_type() == "bg"}

## Load samples

In [None]:
# Load all samples for this year/channel (backgrounds included), using the
# filters from postprocessing.bb_filters configured for up to 3 fat jets.
presel_filters = postprocessing.bb_filters(num_fatjets=3)
events_dict = postprocessing.load_samples(
    year, CHANNEL, data_paths, load_bgs=True, filters_dict=presel_filters
)

# Cutflow table with one row per loaded sample; each add_to_cutflow call
# records a named selection step using the "finalWeight" column.
cutflow = pd.DataFrame(index=[*events_dict])
utils.add_to_cutflow(events_dict, "Preselection", "finalWeight", cutflow)
cutflow

## Triggers

In [None]:
# Apply the trigger selection for this year/channel. The return value is not
# used, so this presumably filters events_dict in place — confirm in postprocessing.
postprocessing.apply_triggers(events_dict, year, CHANNEL)
# Record a "Triggers" step in the cutflow, using the "finalWeight" column.
utils.add_to_cutflow(events_dict, "Triggers", "finalWeight", cutflow)
cutflow

In [None]:
# Build the per-jet legacy ParticleNet discriminant T_Xbb = Xbb / (Xbb + QCD)
# and store it as flat columns "ak8FatJetPNetXbbvsQCDLegacy{i}", one per fat jet.
for events in events_dict.values():
    Xbb = events["ak8FatJetPNetXbbLegacy"]
    QCD = events["ak8FatJetPNetQCDLegacy"]
    Xbb_vs_QCD = Xbb / (Xbb + QCD)
    # Padded (missing) jets are assumed to carry PAD_VAL in both scores; the
    # ratio would then be exactly 0.5 and masquerade as a real mid-range jet,
    # so keep such entries at PAD_VAL instead.
    Xbb_vs_QCD = Xbb_vs_QCD.where(Xbb != PAD_VAL, PAD_VAL)
    # One output column per stored jet (jet labels come from the data rather
    # than a hard-coded count).
    for i in Xbb_vs_QCD.columns:
        events[f"ak8FatJetPNetXbbvsQCDLegacy{i}"] = Xbb_vs_QCD.loc[:, i]

In [None]:
# Per-sample masks from postprocessing.bbtautau_assignment (presumably which fat
# jet is the bb vs. tautau candidate — confirm there); used as bbtt_masks in the
# control plots below.
bbtt_masks = postprocessing.bbtautau_assignment(events_dict, CHANNEL)

In [None]:
# Cut summary drawn on the control plots, one requirement per line.
cutlabel = "\n".join(
    [
        r"$\geq 1$ AK8 jet with $p_T > 250$ & $m_{reg} > 50$ GeV",
        r"$\geq 1$ AK8 jet with $T_{Xbb} > 0.3$",
        r"$\geq 2$ AK8 jets with $p_T > 200$ GeV",
    ]
)

plot_options = {
    "bbtt_masks": bbtt_masks,
    "cutlabel": cutlabel,
    "plot_significance": False,
    "show": True,
}
postprocessing.control_plots(
    events_dict,
    CHANNEL,
    sigs,
    bgs,
    postprocessing.control_plot_vars,
    plot_dir,
    year,
    **plot_options,
)

In [None]:
def make_rocs(
    events_dict: dict[str, pd.DataFrame],
    scores_key: str,
    weight_key: str,
    sig_key: str,
    bg_keys: list[str],
    jshift: str = "",
):
    """Compute weighted ROC curves of one signal vs each background and vs all combined.

    Args:
        events_dict: per-sample event DataFrames.
        scores_key: column holding the discriminant score.
        weight_key: column holding per-event weights.
        sig_key: key of the signal sample in ``events_dict``.
        bg_keys: keys of the background samples.
        jshift: optional systematic-shift tag. When non-empty and the signal
            frame has a column ``f"{scores_key}_{jshift}"``, that shifted score
            is used for the signal; otherwise ``scores_key`` is used.
            Backgrounds always use ``scores_key``.

    Returns:
        dict keyed by each background key plus ``"merged"`` (all backgrounds
        together); each value holds ``fpr``, ``tpr``, ``thresholds`` from
        ``roc_curve`` and a plot ``label``.

    Note:
        ``roc_curve`` is not among this notebook's visible imports; it is
        presumably ``sklearn.metrics.roc_curve`` — confirm it is in scope.
    """
    sig_events = events_dict[sig_key]

    # Resolve the (optionally shifted) signal score column; fall back to the
    # nominal one when no shifted column exists.
    sig_scores_key = scores_key
    if jshift and f"{scores_key}_{jshift}" in sig_events:
        sig_scores_key = f"{scores_key}_{jshift}"

    def _roc_vs(bkg_list):
        """Weighted ROC of the signal (label 1) vs the union of ``bkg_list`` (label 0)."""
        bkg_frames = [events_dict[bkg] for bkg in bkg_list]
        scores = np.concatenate(
            [sig_events[sig_scores_key]] + [frame[scores_key] for frame in bkg_frames]
        )
        truth = np.concatenate(
            [np.ones(len(sig_events))] + [np.zeros(len(frame)) for frame in bkg_frames]
        )
        weights = np.concatenate(
            [sig_events[weight_key]] + [frame[weight_key] for frame in bkg_frames]
        )
        return roc_curve(truth, scores, sample_weight=weights)

    rocs = {}
    for bkg in [*bg_keys, "merged"]:
        is_merged = bkg == "merged"
        fpr, tpr, thresholds = _roc_vs(bg_keys if is_merged else [bkg])
        rocs[bkg] = {
            "fpr": fpr,
            "tpr": tpr,
            "thresholds": thresholds,
            "label": "Combined" if is_merged else plotting.label_by_sample[bkg],
        }

    return rocs

In [None]:
def bdt_roc(events_combined: dict[str, pd.DataFrame], plot_dir: Path, legacy: bool, jshift=""):
    """Plot BDT ROC curves and score shapes for the HH4b ggF and VBF signals.

    Writes into ``plot_dir`` (PNG and PDF each):
      * ``{sig_key}_roc{_jshift}`` — ROC of each signal vs each background and
        vs all backgrounds combined;
      * ``GGF_hh4b_allroc{_jshift}`` / ``VBF_hh4b_allroc{_jshift}`` — combined-
        background ROCs of all ggF (resp. VBF) signals overlaid, with markers
        at fixed BDT thresholds;
      * ``GGF_hh4b_allbdt{_jshift}`` / ``VBF_hh4b_allbdt{_jshift}`` — density-
        normalised BDT score histograms of those signals.

    Args:
        events_combined: per-sample DataFrames with a ``"weight"`` column and
            the ``bdt_score`` / ``bdt_score_vbf`` columns referenced below.
        plot_dir: output directory (joined with ``/``, so a ``pathlib.Path``).
        legacy: forwarded to ``get_legtitle`` for the legend title.
        jshift: optional shift tag; here it is used only as a filename suffix.

    Note:
        ``hist`` and ``get_legtitle`` are not among this notebook's visible
        imports/definitions — confirm they are in scope before calling.
    """
    sig_keys = [
        "hh4b",
        "hh4b-kl0",
        "hh4b-kl2p45",
        "hh4b-kl5",
        "vbfhh4b",
        "vbfhh4b-k2v0",
        "vbfhh4b-k2v2",
        "vbfhh4b-kl2",
    ]
    scores_keys = {
        "hh4b": "bdt_score",
        "hh4b-kl0": "bdt_score",
        "hh4b-kl2p45": "bdt_score",
        "hh4b-kl5": "bdt_score",
        "vbfhh4b": "bdt_score_vbf",
        "vbfhh4b-kl2": "bdt_score_vbf",
        "vbfhh4b-k2v2": "bdt_score_vbf",
        "vbfhh4b-k2v0": "bdt_score_vbf",
    }
    bkg_keys = ["qcd", "ttbar"]
    legtitle = get_legtitle(legacy, pnet_xbb_str="Legacy")
    # Filename suffix, defined up front so every savefig below can use it even
    # when no per-signal plot is produced.
    _jshift = f"_{jshift}" if jshift != "" else ""

    # NOTE(review): only this one VBF sample is dropped when the background has
    # no VBF score; the other VBF samples still read "bdt_score_vbf" from the
    # backgrounds in make_rocs — confirm intent.
    if "bdt_score_vbf" not in events_combined["ttbar"]:
        sig_keys.remove("vbfhh4b-k2v0")

    bkg_colors = {**plotting.color_by_sample, "merged": "orange"}

    # 1) One ROC plot per signal. The ROCs are kept so the overlay plots below
    #    reuse them instead of recomputing.
    rocs_by_sig = {}
    for sig_key in sig_keys:
        rocs = make_rocs(events_combined, scores_keys[sig_key], "weight", sig_key, bkg_keys)
        rocs_by_sig[sig_key] = rocs

        fig, ax = plt.subplots(1, 1, figsize=(18, 12))
        for bg_key in [*bkg_keys, "merged"]:
            ax.plot(
                rocs[bg_key]["tpr"],
                rocs[bg_key]["fpr"],
                linewidth=2,
                color=bkg_colors[bg_key],
                label=rocs[bg_key]["label"],
            )

        ax.set_xlim([0.0, 0.6])
        ax.set_ylim([1e-5, 1e-1])
        ax.set_yscale("log")

        ax.set_title(f"{plotting.label_by_sample[sig_key]} BDT ROC Curve")
        ax.set_xlabel("Signal efficiency")
        ax.set_ylabel("Background efficiency")

        ax.xaxis.grid(True, which="major")
        ax.yaxis.grid(True, which="major")
        ax.legend(title=legtitle, bbox_to_anchor=(1.03, 1), loc="upper left")
        fig.tight_layout()
        fig.savefig(plot_dir / f"{sig_key}_roc{_jshift}.png")
        fig.savefig(plot_dir / f"{sig_key}_roc{_jshift}.pdf", bbox_inches="tight")
        plt.close(fig)

    # 2) Weighted histogram of each signal's own BDT score, one category per signal.
    bdt_axis = hist.axis.Regular(40, 0, 1, name="bdt", label=r"BDT")
    cat_axis = hist.axis.StrCategory([], name="cat", label="cat", growth=True)
    h_bdt = hist.Hist(bdt_axis, cat_axis)
    for sig_key in sig_keys:
        h_bdt.fill(
            events_combined[sig_key][scores_keys[sig_key]],
            sig_key,
            weight=events_combined[sig_key]["weight"],
        )

    def find_nearest(array, value):
        """Return the index of the element of ``array`` closest to ``value``."""
        array = np.asarray(array)
        return (np.abs(array - value)).argmin()

    th_colours = ["#9381FF", "#1f78b4", "#a6cee3", "cyan", "blue"]

    # 3) Overlay plots. NB: vbf_in_sig_key=True *excludes* VBF signals (the ggF
    #    plots); vbf_in_sig_key=False keeps only VBF signals (the VBF plots).
    for vbf_in_sig_key in [True, False]:
        selected = [k for k in sig_keys if ("vbf" in k) != vbf_in_sig_key]
        prod_label = "ggF" if vbf_in_sig_key else "VBF"
        file_prefix = "GGF" if vbf_in_sig_key else "VBF"
        # BDT working points marked on each combined-background ROC.
        plot_thresholds = [0.88, 0.98] if vbf_in_sig_key else [0.98]

        fig, ax = plt.subplots(1, 1, figsize=(18, 12))
        for isig, sig_key in enumerate(selected):
            merged = rocs_by_sig[sig_key]["merged"]
            for k, th in enumerate(plot_thresholds):
                idx = find_nearest(merged["thresholds"], th)
                # Label only the first signal's markers so each threshold
                # appears once in the legend.
                label_kw = {"label": rf"BDT > {th}"} if isig == 0 else {}
                ax.scatter(
                    [merged["tpr"][idx]],
                    [merged["fpr"][idx]],
                    marker="o",
                    s=40,
                    color=th_colours[k],
                    zorder=100,
                    **label_kw,
                )

            ax.plot(
                merged["tpr"],
                merged["fpr"],
                linewidth=2,
                color=plotting.color_by_sample[sig_key],
                label=plotting.label_by_sample[sig_key],
            )

        ax.set_xlim([0.0, 0.6])
        ax.set_ylim([1e-5, 1e-1])
        ax.set_yscale("log")
        ax.set_title(f"{prod_label} BDT ROC Curve")
        ax.set_xlabel("Signal efficiency")
        ax.set_ylabel("Background efficiency")
        ax.xaxis.grid(True, which="major")
        ax.yaxis.grid(True, which="major")
        ax.legend(title=legtitle, bbox_to_anchor=(1.03, 1), loc="upper left")
        fig.tight_layout()
        fig.savefig(plot_dir / f"{file_prefix}_hh4b_allroc{_jshift}.png", bbox_inches="tight")
        fig.savefig(plot_dir / f"{file_prefix}_hh4b_allroc{_jshift}.pdf", bbox_inches="tight")
        plt.close(fig)

        # Normalised BDT score shapes for the same signals.
        fig, ax = plt.subplots(1, 1, figsize=(18, 12))
        for sig_key in selected:
            hep.histplot(
                h_bdt[{"cat": sig_key}],
                ax=ax,
                label=plotting.label_by_sample[sig_key],
                color=plotting.color_by_sample[sig_key],
                histtype="step",
                linewidth=1.5,
                density=True,
                flow="none",
            )
        ax.legend()
        fig.tight_layout()
        fig.savefig(plot_dir / f"{file_prefix}_hh4b_allbdt{_jshift}.png", bbox_inches="tight")
        fig.savefig(plot_dir / f"{file_prefix}_hh4b_allbdt{_jshift}.pdf", bbox_inches="tight")
        plt.close(fig)