In [None]:
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import matplotlib as mpl

from tqdm.auto import tqdm

import warnings
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

from multiprocess import Pool

In [None]:
import pyanalib.pandas_helpers as ph
from makedf.util import *

import kinematics
import gump_cuts as gc

In [None]:
import os

# Output directory for figures. Defaults to the author's local checkout; set the
# PLOTDIR environment variable to redirect output without editing the notebook.
PLOTDIR = os.environ.get("PLOTDIR", "/Users/gputnam/Work/osc/cafpyana/plots/")

# When True, the plotting loops write PDF/PNG files and close each figure.
DOSAVE = True

# Create the output tree up front: <PLOTDIR>, <PLOTDIR>/png and <PLOTDIR>/pdf.
for _subdir in ("", "png", "pdf"):
    os.makedirs(os.path.join(PLOTDIR, _subdir), exist_ok=True)

In [None]:
# Detector / reconstruction to analyse. The branches below recognise "ICARUS",
# "SBND" and "SBND SPINE"; any DETECTOR containing "SPINE" reads the "sevt"
# event tables and uses the SPINE cut definitions.
DETECTOR = "SBND SPINE"

In [None]:
# Beam-ON / beam-OFF data dataframe files for the selected DETECTOR.
# ONBEAMPOT is only defined (and only used) for ICARUS, where the POT table is
# merged onto the ON-beam headers.
if DETECTOR == "ICARUS":
    ONBEAM = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/Run2_BNB_uncalo_prescaled.df"
    OFFBEAM = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/Run2_BNBoff_uncalo_prescaled.df"

    ONBEAMPOT = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/Run2_BNB_uncalo_unblind_POT.df"

elif DETECTOR == "SBND":
    ONBEAM = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/SBND_SpringBNBData_Dev.df"
    OFFBEAM = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/SBND_SpringBNBOffData_5000.df"

elif DETECTOR == "SBND SPINE":
    ONBEAM = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/SBND_SPINE_SpringBNBDevData.df"
    OFFBEAM = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/SBND_SPINE_SpringBNBOffData.df"

else:
    # Fail here rather than with a NameError on ONBEAM several cells later.
    raise ValueError("Unknown DETECTOR %r; expected 'ICARUS', 'SBND' or 'SBND SPINE'" % DETECTOR)

In [None]:
import h5py

def read_dfs(file, key):
    """Read and concatenate every HDF5 dataframe in `file` whose key starts with `key`.

    The dataframe files store tables split into numbered pieces (e.g. ``hdr_0``,
    ``hdr_1``, ...). All matching pieces are concatenated in numeric-suffix order
    (``x_2`` before ``x_10``) rather than h5py's lexicographic key order.

    Note: matching is a plain prefix test, so any other table whose name merely
    begins with `key` is picked up as well.

    Args:
        file: Path to the HDF5 dataframe file.
        key: Key prefix to select.

    Returns:
        The concatenated DataFrame.

    Raises:
        ValueError: if no key in `file` starts with `key`. Without this check,
            pandas would report only "No objects to concatenate".
    """
    # Only list the keys with h5py; close it before pandas opens the file itself.
    with h5py.File(file, "r") as f:
        keys = [k for k in f.keys() if k.startswith(key)]

    if not keys:
        raise ValueError("no dataframe key starting with %r in %s" % (key, file))

    def _piece_order(k):
        # Sort by the trailing integer suffix when present, then by name.
        suffix = k.rsplit("_", 1)[-1]
        return (int(suffix) if suffix.isdigit() else -1, k)

    return pd.concat([pd.read_hdf(file, k) for k in sorted(keys, key=_piece_order)])

In [None]:
# Number of beam gates in the ON- and OFF-beam samples, and OFF_w, the weight
# applied to OFF-beam counts before they are subtracted from ON-beam counts.
if DETECTOR == "ICARUS":
    # Summed gate_delta over the trigger tables, each multiplied by a fixed
    # factor (1 - 1/100 for ON, 1 - 1/20 for OFF).
    # NOTE(review): these look like prescale corrections; confirm the origin of 100 and 20.
    ngates_ON = read_dfs(ONBEAM, "trig").gate_delta.sum()*(1-1/100.)
    ngates_OFF = read_dfs(OFFBEAM, "trig").gate_delta.sum()*(1-1/20.)
    
    OFF_w = ngates_ON / ngates_OFF
elif "SBND" in DETECTOR:
    # ON: one gate per row of the "bnb" table; OFF: summed noffbeambnb from the headers.
    ngates_ON = read_dfs(ONBEAM, "bnb").shape[0]
    ngates_OFF = read_dfs(OFFBEAM, "hdr").noffbeambnb.sum()

    # NOTE(review): f_factor lowers the OFF weight by 7.54%; its source is not
    # shown in this notebook — TODO document where 0.0754 comes from.
    f_factor = 0.0754
    OFF_w = (1. - f_factor) * (ngates_ON) / (ngates_OFF)
    
ngates_ON, ngates_OFF, OFF_w

In [None]:
# ICARUS only: inverse of the mean gate_delta for the ON and OFF samples.
if DETECTOR == "ICARUS":
    print(
        "ON:", 1 / read_dfs(ONBEAM, "trig").gate_delta.mean(),
        "OFF:", 1 / read_dfs(OFFBEAM, "trig").gate_delta.mean(),
    )

In [None]:
# Total protons-on-target (POT) of the ON-beam data sample.
if DETECTOR == "ICARUS":
    # POT = pd.read_hdf(ONBEAMPOT).pot.sum()*1e12
    # Left-join the ON-beam headers to the POT table on the index, so only events
    # present in the headers contribute; pot_y is presumably the POT table's `pot`
    # column after the merge suffixes the name clash. Scaled by 1e12.
    POT = read_dfs(ONBEAM, "hdr").merge(pd.read_hdf(ONBEAMPOT), left_index=True, right_index=True, how="left").pot_y.sum()*1e12
elif "SBND" in DETECTOR:
    # Sum of the TOR875 column of the "bnb" table.
    POT = read_dfs(ONBEAM, "bnb").TOR875.sum()
    
POT

In [None]:
# Exposure in units of 1e19 POT. Reuse POT from the cell above instead of
# re-reading the "bnb" table: for SBND that is the same number, while for ICARUS
# POT comes from the header/POT tables and a "bnb" table may not exist.
POT / 1e19

In [None]:
# Average number of ON-beam gates per 5e12 POT.
print("N GATES ON / 5e12 POT")
print(ngates_ON * 5e12 / POT)

In [None]:
# CRT tables for the ON- and OFF-beam samples.
# NOTE(review): rows appear to be individual CRT hits; the first two index levels
# identify the event (they are grouped on with groupby(level=[0, 1]) later).
crt_ON = read_dfs(ONBEAM, "crt")
crt_OFF = read_dfs(OFFBEAM, "crt")

In [None]:
def top_crt(crtdf):
    """Boolean mask of CRT rows whose ``plane`` is in the inclusive range [30, 40].

    These planes are treated as the top CRT throughout this notebook.
    """
    plane = crtdf.plane
    return plane.between(30, 40)


def side_crt(crtdf):
    """Boolean mask of CRT rows that `top_crt` does not select (every other plane)."""
    is_top = top_crt(crtdf)
    return ~is_top

In [None]:
# Disabled alternative: earliest top-CRT hit time per event, requiring time > -1.
# CRT_intime_hit_ON = crt_ON.time[top_crt(crt_ON) & (crt_ON.time > -1)].groupby(level=[0,1]).min()
# CRT_intime_hit_OFF = crt_OFF.time[top_crt(crt_OFF) & (crt_OFF.time > -1)].groupby(level=[0,1]).min()

# Times of all side-CRT hits. Despite the name, no time window or per-event
# reduction is applied; these feed the timing histogram.
CRT_intime_hit_ON = crt_ON.time[side_crt(crt_ON)]
CRT_intime_hit_OFF = crt_OFF.time[side_crt(crt_OFF)]

In [None]:
# Side-CRT hit-time distribution: ON-beam data vs OFF-beam data scaled by OFF_w,
# with sqrt(N) error bars. Presumably used to choose the CRTLO/CRTHI window.
# NOTE(review): this overwrites the globals `bins`, `N` and `centers`.
bins = np.linspace(-2, 3, 51)
# bins = np.linspace(-10, 10, 101)

N,bins = np.histogram(CRT_intime_hit_ON, bins=bins)
centers = (bins[:-1] + bins[1:]) / 2

plt.errorbar(centers, N, np.sqrt(N), color="black", linestyle="none", marker=".", label="Beam ON")

# OFF-beam counts scaled to the ON exposure; errors scale by the same weight.
Noff,_ = np.histogram(CRT_intime_hit_OFF, bins=bins)
plt.errorbar(centers, Noff*OFF_w, np.sqrt(Noff)*OFF_w, color="red", linestyle="none", marker=".", label="Beam OFF")

plt.legend()

In [None]:
# CRT time window (same units as the CRT `time` column). An event is flagged if
# any CRT hit satisfies CRTLO < time < CRTHI (exclusive on both ends).
CRTLO = -1
CRTHI = 1.8

In [None]:
# Monte Carlo dataframe file for the selected DETECTOR.
if DETECTOR == "ICARUS":
    FILE = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/ICARUS_SpringMC_Dev.df"
elif DETECTOR == "SBND":
    FILE = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/SBND_SpringMC_5000.df"
elif DETECTOR == "SBND SPINE":
    FILE = "/Users/gputnam/Work/osc/cafpyana/analysis_village/gump/SBND_SPINE_SpringMC.df"
else:
    # Fail here rather than with a NameError on FILE in the next cells.
    raise ValueError("Unknown DETECTOR %r; expected 'ICARUS', 'SBND' or 'SBND SPINE'" % DETECTOR)

In [None]:
# Show which MC file was selected.
FILE

In [None]:
def load_data(file, nfiles=-1):
    """Load MC events (with per-universe systematic weights) and headers.

    For each numbered piece ``s`` of `file`, reads:
      - the event table ``evt_s`` (``sevt_s`` when DETECTOR contains "SPINE");
      - the header table ``hdr_s``;
      - the weight table ``wgt_s``.

    Flux and GENIE weights are combined into ``WGT_univ_<i>`` columns and
    merged onto the events (see `_attach_universe_weights`).

    Args:
        file: Path to the HDF5 dataframe file.
        nfiles: Number of pieces to read; a negative value means all pieces,
            counted from the "hdr" keys.

    Returns:
        (events, headers): each concatenated over all pieces read.

    Raises:
        ValueError: if there are no pieces to read.
    """
    if nfiles < 0:
        with h5py.File(file, "r") as f:
            nfiles = len([k for k in f.keys() if k.startswith("hdr")])

    # Previously this surfaced as an UnboundLocalError on the return line.
    if nfiles == 0:
        raise ValueError("no 'hdr' pieces to load from %s" % file)

    evt_prefix = "sevt_" if "SPINE" in DETECTOR else "evt_"

    evt_pieces = []
    hdr_pieces = []
    for s in range(nfiles):
        print("df index:" + str(s))
        df_evt = pd.read_hdf(file, evt_prefix + str(s))
        df_hdr = pd.read_hdf(file, "hdr_" + str(s))
        df_mcnu = pd.read_hdf(file, "wgt_" + str(s))

        evt_pieces.append(_attach_universe_weights(df_evt, df_mcnu))
        hdr_pieces.append(df_hdr)

    # Concatenate once at the end; growing the result inside the loop is quadratic.
    return pd.concat(evt_pieces), pd.concat(hdr_pieces)


def _attach_universe_weights(df_evt, df_mcnu, nuniv=50):
    """Left-merge combined Flux x GENIE universe weights onto one event piece.

    The first `nuniv` Flux and GENIE weight columns are paired up and multiplied
    into ``WGT_univ_<i>``. These weights and ``genie_mode`` are merged onto the
    events by (__ntuple, entry, tmatch_idx <-> rec.mc.nu..index). Unmatched
    slices get weight 1. Column names are flattened to "_"-joined strings.
    """
    genie_col = [c for c in df_mcnu.columns if c[0].startswith("GENIE")][:nuniv]
    flux_col = [c for c in df_mcnu.columns if c[0].startswith("Flux")][:nuniv]

    wgt_cols = []
    for i, (f, g) in enumerate(zip(flux_col, genie_col)):
        name = "WGT_univ_%i" % i
        df_mcnu[name] = df_mcnu[f] * df_mcnu[g]
        wgt_cols.append(name)

    # Keep only the combined weights and genie_mode.
    cols_to_drop = [c for c in df_mcnu.columns if c[0] not in wgt_cols and (c[0] != "genie_mode")]
    df_mcnu = df_mcnu.drop(cols_to_drop, axis=1)

    matchdf = df_evt.copy()
    matchdf.columns = pd.MultiIndex.from_tuples([(col, '') for col in matchdf.columns])
    merged = ph.multicol_merge(matchdf.reset_index(), df_mcnu.reset_index(),
                               left_on=[("__ntuple", ""), ("entry", ""), ("tmatch_idx", "")],
                               right_on=[("__ntuple", ""), ("entry", ""), ("rec.mc.nu..index", "")],
                               how="left")  # keep all slices, matched or not

    # Slices without a matched neutrino get unit weight in every universe.
    for column in wgt_cols:
        merged[column] = merged[column].fillna(1)

    merged.columns = ["_".join([part for part in c if part]) for c in merged.columns]
    return merged

In [None]:
# Load all MC pieces: events with universe weights attached, and headers.
df, hdr = load_data(FILE)

In [None]:
# Inspect the merged MC event table.
df

In [None]:
# Inspect the first trigger piece of the MC file.
pd.read_hdf(FILE, "trig_0")

In [None]:
# MC header count, total MC POT, and headers per 1e15 POT.
hdr.shape[0], hdr.pot.sum(), hdr.shape[0] / (hdr.pot.sum() / 1e15)

In [None]:
# Header (event) counts per data sample, and the OFF-subtracted ON count.
NEVT_ON, NEVT_OFF = (len(read_dfs(path, "hdr")) for path in (ONBEAM, OFFBEAM))

NEVT = NEVT_ON - OFF_w * NEVT_OFF

In [None]:
# ON-beam header count, data POT, and ON-beam events per 1e15 POT.
NEVT_ON, POT, NEVT_ON / (POT / 1e15)

In [None]:
# OFF-subtracted event count, data POT, and OFF-subtracted events per 1e15 POT.
NEVT, POT, NEVT / (POT / 1e15)

In [None]:
def scale_pot(df, df_hdr, desired_pot):
    """Scale MC to a target exposure by attaching a constant weight column.

    Sets ``df['glob_scale'] = desired_pot / pot``, where ``pot`` is the summed
    ``pot`` column of `df_hdr`.

    Args:
        df: MC event DataFrame; modified in place (gains ``glob_scale``).
        df_hdr: Header DataFrame with a ``pot`` column.
        desired_pot: Exposure to scale to (e.g. the data POT).

    Returns:
        (pot, scale): the summed header POT and the applied scale factor.

    Raises:
        ValueError: if the summed header POT is not positive.
    """
    pot = float(df_hdr.pot.sum())
    print(f"POT: {pot}\nScaling to: {desired_pot}")
    if not pot > 0:
        raise ValueError(f"cannot scale: total header POT is {pot}")
    # Scale to the requested exposure. Previously the global POT was used here
    # and `desired_pot` was only printed.
    scale = desired_pot / pot
    df['glob_scale'] = scale
    return pot, scale

# Attach glob_scale so the MC is normalized to the ON-beam data POT.
scale_pot(df, hdr, POT)

In [None]:
# Reconstructed event tables for ON- and OFF-beam data ("sevt" for SPINE, "evt" otherwise).
ONdf = read_dfs(ONBEAM, "evt" if "SPINE" not in DETECTOR else "sevt")
OFFdf = read_dfs(OFFBEAM, "evt" if "SPINE" not in DETECTOR else "sevt")

In [None]:
# OFF-beam: per event (first two index levels of the CRT table), True if any CRT
# hit has CRTLO < time < CRTHI. The mean is the flagged fraction among events
# that appear in the CRT table.
crthit = ((crt_OFF.time > CRTLO) & (crt_OFF.time < CRTHI)).groupby(level=[0, 1]).any()
crthit.name = "crthit"

crthit.mean()

In [None]:
# Attach the flag to OFF-beam slices by (__ntuple, entry). This is a left join,
# so events absent from the CRT table get NaN. Not idempotent: re-running raises
# on the already-present "crthit" column.
OFFdf = OFFdf.join(crthit, on=["__ntuple", "entry"])

In [None]:
# ON-beam: same per-event CRT in-window flag as for OFF-beam.
crthit = ((crt_ON.time > CRTLO) & (crt_ON.time < CRTHI)).groupby(level=[0, 1]).any()
crthit.name = "crthit"

crthit.mean()

In [None]:
# Attach the flag to ON-beam slices (left join; NaN where the event has no CRT rows).
ONdf = ONdf.join(crthit, on=["__ntuple", "entry"])

In [None]:
# CRT table for the MC sample.
crt_MC = read_dfs(FILE, "crt")

In [None]:
# MC: same per-event CRT in-window flag as for data.
crthit = ((crt_MC.time > CRTLO) & (crt_MC.time < CRTHI)).groupby(level=[0, 1]).any()
crthit.name = "crthit"

crthit.mean()

In [None]:
# Attach the flag to MC slices (left join; NaN where the event has no CRT rows).
df = df.join(crthit, on=["__ntuple", "entry"])

In [None]:
mode_list = [0, 10, 1]
mode_labels = ['QE', 'MEC', 'RES', 'Other $\\nu$', "Cosmic"]
mode_colors = ["#315031", "#d54c28", "#1e3f54", "#c89648", "#95af8b"]

def breakdown_mode(var, df):
    """Split `var` into one selection per interaction-mode category.

    Returns a list parallel to `mode_labels`:
      - one entry per genie_mode in `mode_list`, in order;
      - then every other non-NaN mode ("Other nu");
      - then rows whose genie_mode is NaN ("Cosmic").
    """
    genie_mode = df.genie_mode
    in_mode = [genie_mode == code for code in mode_list]
    no_listed_mode = sum(in_mode) == 0
    is_cosmic = np.isnan(genie_mode)

    selections = in_mode + [no_listed_mode & ~is_cosmic, is_cosmic]
    return [var[sel] for sel in selections]

In [None]:
# True-PDG categories for breakdown_pdg: proton, muon and charged pion, followed
# by an "Other" bucket. Colors reuse the first four entries of mode_colors.
pdg_list = [2212, 13, 211]
pdg_labels = ["$p$", "$\\mu$", "$\\pi^\\pm$", "Other"]
pdg_colors = mode_colors[:4]

def breakdown_pdg(var, df, particle="p", pdgs=None):
    """Split `var` by the true PDG code of a candidate particle.

    Args:
        var: Series to split (index-aligned with `df`).
        df: Frame containing a ``<particle>_true_pdg`` column.
        particle: Column prefix selecting the candidate (default "p").
        pdgs: PDG codes to break out; defaults to the module-level `pdg_list`.
            Matching uses |pdg|, so particle and antiparticle share a category.

    Returns:
        List of selections of `var`: one per code in `pdgs`, then all remaining
        rows ("Other").
    """
    if pdgs is None:
        pdgs = pdg_list
    # Take |pdg| before comparing. Applying abs to the boolean comparison, as
    # before, sent negative codes (e.g. pi-) into "Other".
    abs_pdg = np.abs(df["%s_true_pdg" % particle])
    matches = [abs_pdg == code for code in pdgs]
    ret = [var[m] for m in matches]
    ret.append(var[sum(matches) == 0])
    return ret

In [None]:
FONTSIZE = 14
HAWKS_COLORS = ["#315031", "#d54c28", "#1e3f54", "#c89648", "#43140b", "#95af8b"]

def add_style(ax, xlabel, title="", det="ICARUS", ylabel='Area Normalized'):
    """Apply the notebook's standard axis styling, labels, title and legend.

    Ticks point inward on all four sides, the frame is thickened, axis labels
    are bold, and the title starts with the detector name in bold.

    Args:
        ax: matplotlib Axes to style.
        xlabel: x-axis label.
        title: Text shown after the detector name in the title.
        det: Detector name rendered in bold at the start of the title.
        ylabel: y-axis label (the default keeps the previously hard-coded text).
    """
    ax.tick_params(axis='both', which='both', direction='in', length=6, width=1.5, labelsize=FONTSIZE, top=True, right=True)
    for spine in ax.spines.values():
        spine.set_linewidth(1.5)
    ax.set_xlabel(xlabel, fontsize=FONTSIZE, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=FONTSIZE, fontweight='bold')
    ax.set_title(f"$\\bf{{{det}}}$  {title}", fontsize=FONTSIZE+2)
    ax.legend(fontsize=FONTSIZE)


In [None]:
def f_cov(df, cut, var, bins=None, areanorm=False, wgts=("WGT",), nuniv=50, mcwgt=None):
    """Central-value histogram and systematic covariance from weight universes.

    For each universe i, ``df[var]`` (rows passing `cut`) is histogrammed with
    weight ``prod_w df["<w>_univ_<i>"]`` over `wgts`, times `mcwgt` if given.
    The covariance is the mean over universes of outer(N_i - N_CV, N_i - N_CV).

    Args:
        df: Event DataFrame.
        cut: Callable returning a boolean mask for `df` (evaluated once).
        var: Column to histogram.
        bins: Passed to np.histogram.
        areanorm: If True, every histogram is divided by its area (sum of counts
            times bin width), unless that area is <= 1e-5.
        wgts: Prefixes of the per-universe weight columns to multiply together.
        nuniv: Number of universes.
        mcwgt: Optional per-event weight column applied to the CV and every universe.

    Returns:
        (NCV, cov): central-value bin contents and the (nbins, nbins) covariance.
        The covariance is all zeros when `nuniv` <= 0.
    """
    # Evaluate the selection once instead of ~2*nuniv times.
    mask = cut(df)
    values = df.loc[mask, var]
    base_wgt = None if mcwgt is None else df.loc[mask, mcwgt]

    NCV, bins = np.histogram(values, bins=bins, weights=base_wgt)
    widths = bins[1:] - bins[:-1]

    def _norm(h):
        # Area-normalize, leaving (near-)empty histograms untouched.
        if not areanorm:
            return h
        area = np.sum(h * widths)
        return h / area if area > 1e-5 else h

    NCV = _norm(NCV)

    N_univ = []
    for i_univ in range(nuniv):
        wgt = np.prod([df.loc[mask, "%s_univ_%i" % (w, i_univ)] for w in wgts], axis=0)
        if base_wgt is not None:
            wgt = wgt * base_wgt
        N_univ.append(_norm(np.histogram(values, bins=bins, weights=wgt)[0]))

    if nuniv <= 0:
        return NCV, np.zeros((NCV.size, NCV.size))

    # sum_i outer(d_i, d_i) == D^T D, where the rows of D are the universe deviations.
    diffs = np.array(N_univ) - NCV
    return NCV, diffs.T @ diffs / nuniv

In [None]:
# Leftover interactive check: only displays the exception class, no effect on the analysis.
np.linalg.LinAlgError

In [None]:
def f_chi2(NMC, Ndata, cov):
    """Chi-square between prediction and data using a full covariance matrix.

    Bins where the prediction is not positive are excluded as potentially
    singular entries. Then chi2 = d^T C^-1 d, with d = NMC - Ndata.

    Args:
        NMC: Predicted bin contents (array-like, length n).
        Ndata: Observed bin contents (array-like, length n).
        cov: (n, n) covariance matrix (array-like).

    Returns:
        (chi2, nbins): the chi2 value and the number of bins used. chi2 is -1
        when the reduced covariance matrix is singular.
    """
    NMC = np.asarray(NMC, dtype=float)
    Ndata = np.asarray(Ndata, dtype=float)
    cov = np.asarray(cov, dtype=float)

    which_bin = NMC > 0
    NMC = NMC[which_bin]
    Ndata = Ndata[which_bin]
    cov = cov[np.ix_(which_bin, which_bin)]

    delta = NMC - Ndata
    try:
        # Solving C x = delta is numerically preferable to forming C^-1 explicitly.
        chi2 = delta @ np.linalg.solve(cov, delta)
    except np.linalg.LinAlgError:
        return -1, which_bin.sum()

    return chi2, which_bin.sum()

In [None]:
def make_plot_data(var, bins, cut, mc_weight, breakdown, areanorm, breakdown_labels, breakdown_colors, xlabel, title,
                   det="ICARUS", fillna=np.nan, nsystuniv=50):
    """Histogram stacked MC and OFF-subtracted data for one variable and selection.

    Reads these module-level objects:
      - the MC frame `df`;
      - the data frames `ONdf` / `OFFdf`;
      - the OFF-beam weight `OFF_w`;
      - the exposure `POT`.

    Args:
        var: Column to histogram; NaN values are replaced by `fillna`.
        bins: Bin edges, or a bin count, passed to np.histogram.
        cut: Callable returning a boolean mask for a frame.
        mc_weight: MC column used as the per-event weight.
        breakdown: Callable (values, frame) -> list of per-category selections.
        areanorm: If True, the MC total, the MC categories and the data are each
            divided by their area (sum of counts times bin width), unless that
            area is <= 1e-5.
        breakdown_labels: Legend label per MC category, stored for ratio_plot.
        breakdown_colors: Color per MC category, stored for ratio_plot.
        xlabel: x-axis label, stored for ratio_plot.
        title: Plot title text, stored for ratio_plot.
        det: Detector name, stored for ratio_plot.
        fillna: Replacement value for NaN entries of `var`.
        nsystuniv: Number of systematic universes passed to f_cov.

    Returns:
        Dict with the bins, MC total and breakdown, data and its stat error,
        systematic covariance with and without data stat error, chi2/ndof, POT
        and the plot text fields.
    """
    # Evaluate the MC selection once and reuse it.
    mc_mask = cut(df)
    mc_sel = df[mc_mask]
    mc_vals = df.loc[mc_mask, var].fillna(fillna)
    mc_wgts = df.loc[mc_mask, mc_weight]

    # Histogram the MC total first. This fixes the bin edges (also when `bins`
    # is a count), so every category below uses identical edges.
    NMC, bins = np.histogram(mc_vals, bins=bins, weights=mc_wgts)

    pvars = breakdown(mc_vals, mc_sel)
    weights = breakdown(mc_wgts, mc_sel)
    NMC_breakdown = [np.histogram(pvar, bins=bins, weights=w)[0] for pvar, w in zip(pvars, weights)]

    diff = bins[1:] - bins[:-1]
    if areanorm:
        norm = np.sum(NMC*diff)
        if norm > 1e-5:
            NMC = NMC / norm
            NMC_breakdown = [h / norm for h in NMC_breakdown]

    NMC_breakdown = np.array(NMC_breakdown)

    NON, _ = np.histogram(ONdf.loc[cut(ONdf), var].fillna(fillna), bins=bins)
    NOff, _ = np.histogram(OFFdf.loc[cut(OFFdf), var].fillna(fillna), bins=bins)

    # Subtract the OFF-beam counts scaled to the ON exposure. The Poisson errors
    # of both samples are added in quadrature.
    N = NON - NOff*OFF_w
    Nerr = np.sqrt(NON + NOff*OFF_w**2)
    if areanorm:
        norm = np.sum(N*diff)
        if norm > 1e-5:
            N = N / norm
            Nerr = Nerr / norm

    _, cov = f_cov(df, cut, var, bins=bins, areanorm=areanorm, nuniv=nsystuniv, mcwgt=mc_weight)

    cov_w_stat = cov + np.diag(Nerr**2)  # add data stat uncertainty
    chi2, ndof = f_chi2(NMC, N, cov_w_stat)

    return {
        "det": det,
        "title": title,
        "xlabel": xlabel,
        "bins": bins,
        "areanorm": areanorm,
        "breakdown_labels": breakdown_labels,
        "breakdown_colors": breakdown_colors,
        "NMC_breakdown": NMC_breakdown,
        "NMC_total": NMC,
        "NData": N,
        "NDataErr": Nerr,
        "cov": cov,
        "cov_w_stat": cov_w_stat,
        "chi2": chi2,
        "ndof": ndof,
        "POT": POT
    }


In [None]:
def ratio_plot(plt, plotdata):
    """Draw stacked MC vs OFF-subtracted data, with a data/MC ratio panel below.

    Args:
        plt: The matplotlib.pyplot module (used for ``subplots`` and
            ``subplots_adjust``).
        plotdata: Dict produced by `make_plot_data`.

    The hatched band is the MC systematic uncertainty, sqrt(diag(plotdata["cov"])).
    The chi2 shown in the plot loses one degree of freedom for area-normalized plots.
    """
    fig, (ax0, ax1) = plt.subplots(2, 1, height_ratios=[3, 1], sharex=True)
    bins = plotdata["bins"]
    centers = (bins[:-1] + bins[1:])/2

    def _step_y(y):
        # fill_between(step="post") over all bin edges needs one more y value than
        # there are bins; repeating the last value draws the final bin at full width.
        return np.append(y, y[-1])

    # Stacked MC: one pseudo-entry per category at each bin center, weighted by
    # the precomputed bin contents.
    NMC_breakdown = plotdata["NMC_breakdown"]
    fill = np.array([centers for _ in range(NMC_breakdown.shape[0])]).T
    ax0.hist(fill, bins=bins, stacked=True, label=plotdata["breakdown_labels"],
             color=plotdata["breakdown_colors"], weights=NMC_breakdown.T)

    NData = plotdata["NData"]
    NDataErr = plotdata["NDataErr"]
    line = ax0.errorbar(centers, NData, NDataErr, color="black", linestyle="none", marker=".")

    NMC = plotdata["NMC_total"]
    err = np.sqrt(np.diag(plotdata["cov"]))
    ax0.fill_between(bins, _step_y(NMC+err), _step_y(NMC-err), facecolor="none", hatch="//",
                     edgecolor="gray", linewidth=0.0, step="post")

    # Empty MC bins give inf/nan ratios; silence the warnings, matplotlib skips those points.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = NData/NMC
        ratio_err = NDataErr/NMC
        rel_err = err/NMC
    ax1.errorbar(centers, ratio, ratio_err, color="black", linestyle="none", marker=".")
    ax1.set_ylim([0.5, 1.5])
    ax1.axhline(1, color="red", linestyle="--")
    ax1.fill_between(bins, _step_y(1+rel_err), _step_y(1-rel_err), facecolor="none", hatch="//",
                     edgecolor="gray", linewidth=0.0, step="post")

    # Same tick and frame styling on both panels.
    for ax in (ax0, ax1):
        ax.tick_params(axis='both', which='both', direction='in', length=6, width=1.5, labelsize=FONTSIZE, top=True, right=True)
        for spine in ax.spines.values():
            spine.set_linewidth(1.5)
    ax1.set_xlabel(plotdata["xlabel"], fontsize=FONTSIZE, fontweight='bold')

    if plotdata["areanorm"]:
        # Plain label: `'Area Normalized' % (...)` raised TypeError (no format specifier).
        ax0.set_ylabel('Area Normalized', fontsize=FONTSIZE, fontweight='bold')
    else:
        ax0.set_ylabel('Events / %.1f$\\times 10^{19}$ POT' % (plotdata["POT"]/1e19), fontsize=FONTSIZE, fontweight='bold')

    det = plotdata["det"]
    title = plotdata["title"]
    ax0.set_title(f"$\\bf{{{det}}}$ {title}", fontsize=FONTSIZE+2)
    ld = ax0.legend([line], ["Data\n(ON Beam - OFF)"], frameon=False, loc="upper left", fontsize=10)

    # Headroom for the legends and the chi2 text.
    ax0_lo, ax0_hi = ax0.get_ylim()
    ax0.set_ylim([ax0_lo, ax0_hi*1.2])

    # The MC legend replaces the data legend, so re-add the data legend as an artist.
    ax0.legend(fontsize=12)
    ax0.add_artist(ld)

    chi2_str = "$\\chi^2_\\mathrm{shape}$" if plotdata["areanorm"] else "$\\chi^2$"
    ax0.text(0.5, 0.98, "%s: %.1f / %i" % (chi2_str, plotdata["chi2"], plotdata["ndof"] - int(plotdata["areanorm"])),
             verticalalignment="top", horizontalalignment="center", fontsize=FONTSIZE-2, transform=ax0.transAxes)

    plt.subplots_adjust(hspace=0.05)
    

In [None]:
def FV(df):
    """Containment selection for the current DETECTOR.

    Combines the gump_cuts slice, muon-candidate and proton-candidate fiducial
    volume cuts. They are evaluated for the base detector name, i.e. the first
    word of DETECTOR ("SBND" for "SBND SPINE"). SPINE additionally requires
    ``is_time_contained``.
    """
    det = DETECTOR.split(" ")[0]
    ret = gc.slcfv_cut(df, det) & gc.mufv_cut(df, det) & gc.pfv_cut(df, det)

    if "SPINE" in DETECTOR:
        ret = ret & (df.is_time_contained)

    return ret

def simple_cosmic_rej(df):
    """FV plus crlongtrkdiry > -0.3."""
    return FV(df) & (df.crlongtrkdiry > -0.3)

def crtveto(df):
    """FV plus no CRT hit flagged in the CRTLO < time < CRTHI window.

    Events that carry no CRT flag are treated as not vetoed. These are the NaN
    entries left by the left join onto the event table.
    """
    # After a join with missing rows, `crthit` can be an object column mixing
    # bools and NaN. There `~` is not a reliable logical NOT (Python's ~True is
    # -2, and ~nan raises), so build an explicit boolean mask first.
    vetoed = df.crthit.eq(True)
    return FV(df) & ~vetoed

def twoprong_cut(df):
    """FV plus no additional shower or track (both lengths NaN)."""
    return FV(df) & np.isnan(df.other_shw_length) & np.isnan(df.other_trk_length)

def pid_cut(df):
    """Two-prong selection plus particle-ID requirements.

    Non-SPINE: gump_cuts.pid_cut_df. SPINE: prot_chi2_of_prot_cand > 0.6 and
    mu_chi2_of_mu_cand > 0.6.
    """
    is_spine = "SPINE" in DETECTOR
    if not is_spine:
        return twoprong_cut(df) & gc.pid_cut_df(df)
    else:
        return twoprong_cut(df) & (df.prot_chi2_of_prot_cand > 0.6) & (df.mu_chi2_of_mu_cand > 0.6)


In [None]:
# Selections and variables for the POT-normalized (areanorm=False) plots.
# NOTE(review): only the commented-out cells below consume this configuration;
# cuts/cutnames/plotvars/bins/labels are all redefined before the active plots.
# NOTE(review): there is no branch for plain "SBND", so cuts/cutnames are not set
# here for that detector.
if DETECTOR == "ICARUS":
    cuts = [
        FV,
        crtveto,
        simple_cosmic_rej,
        twoprong_cut,
        pid_cut,
    ]
    
    cutnames = [
        "Contained",
        "CRT Veto",
        "Simple Cos. Rej.",
        "Two Prong Cut",
        "PID Cut",
    ]
elif DETECTOR == "SBND SPINE":
    cuts = [
        FV,
        twoprong_cut,
        pid_cut,
    ]
    
    cutnames = [
        "Contained",
        "Two Prong Cut",
        "PID Cut",
    ]

plotvars = [
    "crlongtrkdiry",
    "nu_score",
    "mu_chi2_of_prot_cand",
    "prot_chi2_of_prot_cand",
    "mu_chi2_of_mu_cand",
    "prot_chi2_of_mu_cand",  
]

# Bin edges parallel to plotvars; SPINE uses [0, 1] ranges for the PID variables.
if "SPINE" not in DETECTOR:
    bins = [
        np.linspace(-1,1,21),
        np.linspace(0, 1, 21),
        np.linspace(0, 80, 21),
        np.linspace(0, 300, 21),
        np.linspace(0, 80, 21),
        np.linspace(0, 300, 21),
    ]
else:
    bins = [
        np.linspace(-1,1,21),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
    ]

# Axis labels parallel to plotvars.
labels = [
    "CRLongTrkDirY",
    "$\\nu$ Score",
    "Proton Cand. $\\mu$-like PID",
    "Proton Cand. $p$-like PID",
    "Muon Cand. $\\mu$-like PID",
    "Muon Cand. $p$-like PID",
]

In [None]:
def inner(dat):
    """Pool worker for the POT-normalized plots (areanorm=False, mode breakdown).

    `dat` is a (variable, bins, label, cut, cutname) tuple; returns
    (variable, plotdata). NOTE(review): `inner` is redefined later in the
    notebook, and those definitions shadow this one.
    """
    (v, b, l, cut, cutname) = dat
    return v, make_plot_data(v, b, cut, "glob_scale", breakdown_mode, False, mode_labels, 
                                  mode_colors, l, cutname, fillna=-1, det=DETECTOR)

In [None]:
# all_plotdata_normed = {}


# with Pool(10) as p:
#     for cut, cutname in zip(cuts, cutnames):
#         all_plotdata_normed[cutname] = {}
#         inputs = [(v, b, l, cut, cutname) for (v, b, l) in zip(plotvars, bins, labels)]
#         for v, plotdata in tqdm(p.imap_unordered(inner, inputs), total=len(inputs)):    
#             all_plotdata_normed[cutname][v] = plotdata


In [None]:
# ifig = 0
# for cname in cutnames:
#     for v in plotvars:
#         plt.figure(ifig)
#         ratio_plot(plt, all_plotdata_normed[cname][v])
        
#         if DOSAVE:
#             savename_pdf = PLOTDIR + "/pdf/%s_%s_%s_potnorm.pdf" % (all_plotdata_normed[cname][v]["det"], cname.replace(" ", "").replace(".", "").lower(), v)
#             savename_png = PLOTDIR + "/png/%s_%s_%s_potnorm.png" % (all_plotdata_normed[cname][v]["det"], cname.replace(" ", "").replace(".", "").lower(), v)
#             plt.savefig(savename_pdf, bbox_inches="tight")
#             plt.savefig(savename_png, bbox_inches="tight")
#             plt.close()
#         else:
#             ifig += 1

In [None]:
# Selections and variables for the active area-normalized, mode-breakdown plots.
# NOTE(review): there is no branch for plain "SBND", so cuts/cutnames keep
# whatever value they had before this cell for that detector.
if DETECTOR == "ICARUS":
    cuts = [
        FV,
        crtveto,
        simple_cosmic_rej,
        twoprong_cut,
        pid_cut,
    ]
    
    cutnames = [
        "Contained",
        "CRT Veto",
        "Simple Cos. Rej.",
        "Two Prong Cut",
        "PID Cut",
    ]
elif DETECTOR == "SBND SPINE":
    cuts = [
        FV,
        twoprong_cut,
        pid_cut,
    ]
    
    cutnames = [
        "Contained",
        "Two Prong Cut",
        "PID Cut",
    ]

plotvars = [
    "crlongtrkdiry",
    "nu_score",
    "other_trk_length",
    "other_shw_length",
    "mu_chi2_of_prot_cand",
    "prot_chi2_of_prot_cand",
    "mu_chi2_of_mu_cand",
    "prot_chi2_of_mu_cand",  
    "del_p",
    "p_len",
]

# Bin edges parallel to plotvars. The two length variables get an extra bin below
# 0, presumably so that NaN lengths (filled with -1 via fillna=-1 in the plotting
# worker) remain visible.
if "SPINE" not in DETECTOR:
    bins = [
        np.linspace(-1,1,21),
        np.linspace(0, 1, 21),
        np.array([-5] + list(np.linspace(0, 20, 5))),
        np.array([-10] + list(np.linspace(0, 100, 11))),
        np.linspace(0, 80, 21),
        np.linspace(0, 300, 21),
        np.linspace(0, 80, 21),
        np.linspace(0, 300, 21),
        np.linspace(0, 1.5, 16),
        np.linspace(0, 50, 11)
    ]
else:
    bins = [
        np.linspace(-1,1,21),
        np.linspace(0, 1, 21),
        np.array([-5] + list(np.linspace(0, 20, 5))),
        np.array([-10] + list(np.linspace(0, 100, 11))),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        np.linspace(0, 1, 21),
        # np.linspace(0, 1.5, 16),
        np.linspace(0, 1.0, 11),
        np.linspace(0, 50, 11)
    ]

# Axis labels parallel to plotvars.
labels = [
    "CRLongTrkDirY",
    "$\\nu$ Score",
    "Maximum Third Track Length [cm]",
    "Maximum Shower Length [cm]",
    "Proton Cand. $\\mu$-like PID",
    "Proton Cand. $p$-like PID",
    "Muon Cand. $\\mu$-like PID",
    "Muon Cand. $p$-like PID",
    "Transverse Momentum [GeV]",
    "Proton Cand. Length [cm]",
]

In [None]:
def inner(dat):
    """Pool worker: area-normalized, mode-broken-down plot data for one variable.

    `dat` is a (variable, bins, label, cut, cutname) tuple. Returns
    (variable, plotdata) so that unordered results can be keyed by variable.
    """
    var, var_bins, xlabel, cut, cutname = dat
    plotdata = make_plot_data(
        var, var_bins, cut,
        mc_weight="glob_scale",
        breakdown=breakdown_mode,
        areanorm=True,
        breakdown_labels=mode_labels,
        breakdown_colors=mode_colors,
        xlabel=xlabel,
        title=cutname,
        det=DETECTOR,
        fillna=-1,
    )
    return var, plotdata

In [None]:
# Build the plot data for every (selection, variable) pair on a 10-process pool.
# Results arrive unordered and are stored as all_plotdata[cutname][variable].
all_plotdata = {}


with Pool(10) as p:
    for cut, cutname in zip(cuts, cutnames):
        all_plotdata[cutname] = {}
        inputs = [(v, b, l, cut, cutname) for (v, b, l) in zip(plotvars, bins, labels)]
        for v, plotdata in tqdm(p.imap_unordered(inner, inputs), total=len(inputs)):    
            all_plotdata[cutname][v] = plotdata


In [None]:
# Draw every (selection, variable) plot. With DOSAVE, write PDF and PNG copies
# and close each figure; otherwise leave the figures open for inline display.
for cname in cutnames:
    for v in plotvars:
        plotdata = all_plotdata[cname][v]
        # ratio_plot creates its own figure via plt.subplots; a preceding
        # plt.figure() call would only leave an extra empty figure behind.
        ratio_plot(plt, plotdata)

        if DOSAVE:
            stem = "%s_%s_%s" % (plotdata["det"].replace(" ", "-"), cname.replace(" ", "").replace(".", "").lower(), v)
            plt.savefig(PLOTDIR + "/pdf/%s.pdf" % stem, bbox_inches="tight")
            plt.savefig(PLOTDIR + "/png/%s.png" % stem, bbox_inches="tight")
            plt.close()

In [None]:
# Selections and candidate-PID variables for the plots broken down by the
# candidate's true PDG. All four selections are applied for every detector.
cuts = [
    FV,
    simple_cosmic_rej,
    twoprong_cut,
    pid_cut,
]

cutnames = [
    "Contained",
    "Simple Cos. Rej.",
    "Two Prong Cut",
    "PID Cut",
]

plotvars = [
    "mu_chi2_of_prot_cand",
    "prot_chi2_of_prot_cand",
    "mu_chi2_of_mu_cand",
    "prot_chi2_of_mu_cand",
]

if "SPINE" not in DETECTOR:
    # chi2-style ranges, matching the non-SPINE bins used for these columns elsewhere.
    bins = [
        np.linspace(0, 80, 21),
        np.linspace(0, 300, 21),
        np.linspace(0, 80, 21),
        np.linspace(0, 300, 21),
    ]

    labels = [
        "Proton Cand. $\\chi^2_\\mu$",
        "Proton Cand. $\\chi^2_p$",
        "Muon Cand. $\\chi^2_\\mu$",
        "Muon Cand. $\\chi^2_p$",
    ]
else:
    # For SPINE these columns are binned on [0, 1] everywhere else in this
    # notebook (and cut at 0.6 in pid_cut). The chi2 ranges above would put
    # essentially every entry into the first bin.
    bins = [np.linspace(0, 1, 21) for _ in plotvars]

    labels = [
        "Proton Cand. $\\mu$-like PID",
        "Proton Cand. $p$-like PID",
        "Muon Cand. $\\mu$-like PID",
        "Muon Cand. $p$-like PID",
    ]

In [None]:
def inner(dat):
    """Pool worker: area-normalized plot data broken down by candidate true PDG.

    `dat` is a (variable, bins, label, cut, cutname) tuple. Returns
    (variable, plotdata) so that unordered results can be keyed by variable.
    """
    var, var_bins, xlabel, cut, cutname = dat
    plotdata = make_plot_data(
        var, var_bins, cut,
        mc_weight="glob_scale",
        breakdown=breakdown_pdg,
        areanorm=True,
        breakdown_labels=pdg_labels,
        breakdown_colors=pdg_colors,
        xlabel=xlabel,
        title=cutname,
        det=DETECTOR,
        fillna=-1,
    )
    return var, plotdata

In [None]:
# Build the PDG-breakdown plot data for every (selection, variable) pair on a
# 4-process pool; results are stored as all_plotdata_pdg[cutname][variable].
all_plotdata_pdg = {}


with Pool(4) as p:
    for cut, cutname in zip(cuts, cutnames):
        all_plotdata_pdg[cutname] = {}
        inputs = [(v, b, l, cut, cutname) for (v, b, l) in zip(plotvars, bins, labels)]
        for v, plotdata in tqdm(p.imap_unordered(inner, inputs), total=len(inputs)):    
            all_plotdata_pdg[cutname][v] = plotdata


In [None]:
# Draw every PDG-breakdown plot. With DOSAVE, write PDF and PNG copies and close
# each figure; otherwise leave the figures open for inline display.
for cname in cutnames:
    for v in plotvars:
        plotdata = all_plotdata_pdg[cname][v]
        # ratio_plot creates its own figure via plt.subplots; a preceding
        # plt.figure() call would only leave an extra empty figure behind.
        ratio_plot(plt, plotdata)

        if DOSAVE:
            # Replace spaces in the detector name (e.g. "SBND SPINE") so the
            # filenames contain no spaces, as the mode-breakdown loop does.
            stem = "%s_%s_%s_bkdwnpdg" % (plotdata["det"].replace(" ", "-"), cname.replace(" ", "").replace(".", "").lower(), v)
            plt.savefig(PLOTDIR + "/pdf/%s.pdf" % stem, bbox_inches="tight")
            plt.savefig(PLOTDIR + "/png/%s.png" % stem, bbox_inches="tight")
            plt.close()