# In [None]:
import json
import numpy as np
import uproot
from uproot.writing.identify import to_TH1x, to_TAxis

# Input manifest passed to extract_samples_from_json(): a JSON object whose
# values are iterated to obtain each sample's conditions.
JSON_FILE = "nanoaod_inputs.json"
# Yield cut for the pseudodata step: a channel is skipped when any of its three
# input histograms has a summed (NaN-ignoring) bin content <= this value.
EMPTY_HIST_YIELD = 1.0

def extract_samples_from_json(json_file):
    """Flatten a ``{sample: conditions}`` JSON manifest into (sample, condition) pairs.

    Parameters
    ----------
    json_file : str or path-like
        JSON file whose top level is an object. Each value is iterated to get
        that sample's conditions (expected to be a list of names; a bare
        string would be iterated character by character).

    Returns
    -------
    list of tuple
        ``(sample, condition)`` pairs, in the order they appear in the file.
    """
    # JSON is UTF-8 by specification (RFC 8259); don't depend on the
    # platform's locale encoding when reading it.
    with open(json_file, "r", encoding="utf-8") as fd:
        data = json.load(fd)
    # The file is closed before flattening; nothing below needs the handle.
    return [
        (sample, condition)
        for sample, conditions in data.items()
        for condition in conditions
    ]

def list_channel_keys(file_handle, process_tag):
    """Index the objects of ``file_handle`` whose name ends in ``_<process_tag>``.

    Each key from ``file_handle.keys(cycle=False)`` has any ``;cycle`` suffix
    stripped. Names ending in ``"_" + process_tag`` are recorded under the
    prefix in front of that suffix (the "channel").

    Returns
    -------
    dict
        ``{channel: full_object_name}``. If two keys produce the same channel,
        the later one wins.
    """
    suffix = "_" + process_tag
    suffix_len = len(suffix)
    channel_to_name = {}
    for raw_key in file_handle.keys(cycle=False):
        # Defensive: drop a trailing ";<cycle>" if one is present anyway.
        base_name = raw_key.partition(";")[0]
        if not base_name.endswith(suffix):
            continue
        channel_to_name[base_name[:-suffix_len]] = base_name
    return channel_to_name

def h_sum_yield(h):
    """Return the NaN-ignoring sum of the bin contents from ``h.to_numpy()``.

    ``h.to_numpy()`` must return exactly two items (contents, edges); the
    result is converted to a plain Python float.
    """
    bin_contents, _bin_edges = h.to_numpy()
    total = np.nansum(bin_contents)
    return float(total)

def try_variances(h):
    """Best-effort fetch of ``h.variances()`` as a NumPy array.

    Returns None when ``h`` has no usable variances. That covers three cases:
    the call returns None, the call raises any ``Exception`` (including a
    missing ``variances`` attribute), or the array conversion fails.
    """
    try:
        raw = h.variances()
        return None if raw is None else np.asarray(raw)
    except Exception:
        # Deliberately broad: callers treat "no variances" as a normal case.
        return None

def build_th1x(vals, edges, name, title, var=None):
    """Assemble a writable 1D float64 histogram via ``uproot``'s ``to_TH1x``.

    Parameters
    ----------
    vals : array-like
        In-range bin contents (converted to float64).
    edges : array-like
        Bin edges; the histogram has ``len(edges) - 1`` bins (converted to float64).
    name, title : str
        Name and title given to the histogram object.
    var : array-like or None
        Per-bin sum of squared weights. When None, the bin contents themselves
        are used as sumw2, which gives sqrt(content) errors.

    Returns
    -------
    The object returned by ``to_TH1x``.

    Notes
    -----
    * The underflow (index 0) and overflow (index -1) slots of the content
      and sumw2 arrays are left at zero.
    * The entry count is set to the sum of weights, not an unweighted fill count.
    * The global sums (w, w^2, w*x, w*x^2) run over in-range bins only, using
      bin centres as x.
    """
    contents = np.asarray(vals, dtype="float64")
    bin_edges = np.asarray(edges, dtype="float64")
    n_bins = len(bin_edges) - 1

    # Content array with one extra slot at each end for under/overflow.
    padded_contents = np.zeros(n_bins + 2, dtype="float64")
    padded_contents[1:-1] = contents

    inner_sumw2 = contents if var is None else np.asarray(var, dtype="float64")
    padded_sumw2 = np.zeros(n_bins + 2, dtype="float64")
    padded_sumw2[1:-1] = inner_sumw2

    mids = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    sum_w = float(np.sum(contents))
    sum_w2 = float(np.sum(inner_sumw2))
    sum_wx = float(np.sum(contents * mids))
    sum_wx2 = float(np.sum(contents * mids**2))

    x_axis = to_TAxis("xaxis", "", n_bins, float(bin_edges[0]), float(bin_edges[-1]), bin_edges)

    # NOTE(review): an explicit max (120% of the peak) and min (0.0) are stored.
    # If ROOT treats these as user-set y-range limits, overlays and log-y plots
    # will be clamped to this range. Confirm this is intended.
    peak = np.max(contents)
    y_max = float(peak * 1.2) if peak > 0 else 1.0

    # Positional arguments shared by both calls below: the entry count is
    # sum_w, passed first and then again as the sum of weights.
    required_args = (
        name, title, padded_contents,
        sum_w, sum_w, sum_w2, sum_wx, sum_wx2,
        padded_sumw2, x_axis,
    )
    try:
        # Extra positionals: fYaxis, fZaxis, fNcells, fBarOffset, fBarWidth,
        # fMaximum, fMinimum.
        return to_TH1x(*required_args, None, None, None, 0, 1, y_max, 0.0)
    except TypeError:
        # Retry with only the required arguments, presumably for uproot
        # versions whose to_TH1x signature differs (TODO confirm).
        return to_TH1x(*required_args)

items = extract_samples_from_json(JSON_FILE)
everything_roots = [f"everything_merged_{s}__{c}.root" for (s, c) in items]

req_ttbar_me = "everything_merged_ttbar__ME_var.root"
req_ttbar_ps = "everything_merged_ttbar__PS_var.root"
req_wjets_no = "everything_merged_wjets__nominal.root"

# Inputs that were actually opened and copied. A file can be listed in the
# JSON manifest yet be missing on disk, so being listed in everything_roots is
# not enough to safely reopen it for the pseudodata step.
copied_roots = set()
# Number of pseudodata histograms actually written (drives the final message).
n_pseudodata = 0

with uproot.recreate("histograms_merged.root") as f_out:
    # Step 1: copy every object from every per-sample input into the output.
    for h_file in everything_roots:
        try:
            with uproot.open(h_file) as f_in:
                for key in f_in.keys(cycle=False):
                    f_out[key] = f_in[key]
        except FileNotFoundError:
            print(f"[WARN] missing: {h_file}")
        else:
            copied_roots.add(h_file)

    # Step 2: pseudodata_nominal = 0.5*(ttbar_ME + ttbar_PS) + wjets_nominal,
    # built per channel present in all three inputs.
    required_inputs = (req_ttbar_me, req_ttbar_ps, req_wjets_no)
    unavailable = [p for p in required_inputs if p not in copied_roots]
    if unavailable:
        print(f"[WARN] pseudodata_nominal not built; unavailable inputs: {unavailable}")
    else:
        with uproot.open(req_ttbar_me) as f_ttbar_ME, \
             uproot.open(req_ttbar_ps) as f_ttbar_PS, \
             uproot.open(req_wjets_no) as f_wjets:
            ttbar_me_map = list_channel_keys(f_ttbar_ME, "ttbar_ME_var")
            ttbar_ps_map = list_channel_keys(f_ttbar_PS, "ttbar_PS_var")
            wjets_map    = list_channel_keys(f_wjets,    "wjets_nominal")
            common_channels = sorted(set(ttbar_me_map) & set(ttbar_ps_map) & set(wjets_map))
            print(f"[INFO] channels for pseudodata_nominal: {len(common_channels)}")
            for channel in common_channels:
                h_me = f_ttbar_ME[ttbar_me_map[channel]]
                h_ps = f_ttbar_PS[ttbar_ps_map[channel]]
                h_wj = f_wjets[wjets_map[channel]]
                if (
                    h_sum_yield(h_me) <= EMPTY_HIST_YIELD or
                    h_sum_yield(h_ps) <= EMPTY_HIST_YIELD or
                    h_sum_yield(h_wj) <= EMPTY_HIST_YIELD
                ):
                    continue
                vals_me, edges    = h_me.to_numpy()
                vals_ps, edges_ps = h_ps.to_numpy()
                vals_wj, edges_wj = h_wj.to_numpy()
                # Compare shapes first: np.allclose raises on incompatible
                # lengths and silently broadcasts a length-1 array, so the
                # "skip" path would otherwise crash or be bypassed.
                same_binning = (
                    edges.shape == edges_ps.shape == edges_wj.shape
                    and np.allclose(edges, edges_ps)
                    and np.allclose(edges, edges_wj)
                )
                if not same_binning:
                    print(f"[WARN] binning mismatch in {channel}, skipping")
                    continue
                new_vals = 0.5 * (vals_me + vals_ps) + vals_wj
                var_me = try_variances(h_me)
                var_ps = try_variances(h_ps)
                var_wj = try_variances(h_wj)
                if var_me is None or var_ps is None or var_wj is None:
                    # Let build_th1x fall back to sumw2 = bin contents.
                    new_var = None
                else:
                    # Variance of 0.5*a + 0.5*b + c, assuming the three inputs
                    # are statistically independent.
                    new_var = 0.25 * (var_me + var_ps) + var_wj
                hname  = f"{channel}_pseudodata_nominal"
                htitle = "Pseudodata = 0.5*(ttbar_ME + ttbar_PS) + wjets_nominal"
                f_out[hname] = build_th1x(new_vals, edges, hname, htitle, new_var)
                n_pseudodata += 1

if n_pseudodata:
    print("[OK] histograms_merged.root created with pseudodata added.")
else:
    print("[OK] histograms_merged.root created (no pseudodata added).")
