# Notebook: de-duplicate per event by Bp_P clustering

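Goal: for each input ROOT file and each event in it, cluster the entries by relative Bp_P proximity (two values are linked when `|p - p'| / min(p, p')` is at most `EPSILON`). Keep one representative per cluster (the entry whose Bp_P is closest to the cluster median) and drop the rest. Entries with a missing or non-positive Bp_P are dropped as well. Write a cleaned file with the same tree name.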

In [None]:
import ROOT as r
import numpy as np
from scipy.cluster.hierarchy import fclusterdata
from pathlib import Path

# Switch on ROOT's implicit multi-threading before any file is opened.
# NOTE(review): the per-entry loops in this notebook are plain Python loops,
# so any gain is limited to what ROOT parallelises internally -- confirm this
# call is actually wanted.
r.ROOT.EnableImplicitMT()  # parallel where possible

# I/O
# Input files are discovered recursively under BASE_IN; cleaned copies are
# written under BASE_OUT, which is created here if it does not exist yet.
BASE_IN  = Path("data/processed")               # input ROOT files (flat or nested)
BASE_OUT = Path("data/processed_clean_bp_p")    # outputs
BASE_OUT.mkdir(parents=True, exist_ok=True)

# Trees to process
# Maps a short decay label to the name of the tree looked up in each file.
TREENAMES = {
    "B2OC": "ST-b2oc",
    "B2CC": "ST-b2cc",
}

# Relative difference threshold for Bp_P clustering (e.g., 0.005 = 0.5%)
# It is compared against |p - p'| / min(p, p') between two entries of one event.
EPSILON = 0.005

In [None]:
def list_root_files(root_dir: Path) -> list[Path]:
    """
    Find every `*.root` file below `root_dir`, at any depth.

    Returns the matching paths as a list in ascending path order, which keeps
    the processing order reproducible between runs.
    """
    matches = list(root_dir.rglob("*.root"))
    matches.sort()
    return matches

def open_tree(filename: Path, treename: str):
    """
    Open a ROOT file read-only and fetch one tree from it.

    Returns `(file, tree)` on success; the caller owns the file and must
    `Close()` it (the tree is only usable while the file is open).
    Returns `(None, None)` after printing a "skip" message when the file
    cannot be opened or does not contain `treename`.
    """
    try:
        f = r.TFile.Open(str(filename))
    except OSError:
        # Recent PyROOT versions raise OSError instead of handing back a null
        # or zombie file; treat that like any other unreadable input so one
        # bad file does not abort the whole run.
        f = None
    if not f or f.IsZombie():
        print(f"skip (cannot open): {filename}")
        return None, None
    t = f.Get(treename)
    if not t:
        print(f"skip (missing tree {treename}): {filename}")
        # Nothing is handed back to the caller, so release the file here.
        f.Close()
        return None, None
    return f, t

def collect_event_bp_indices(tree) -> dict:
    """
    Group the usable entries of `tree` by event.

    Returns {event_id: [(bp_p, entry_index), ...]}, with entries listed in
    increasing entry index and `bp_p` converted to a Python float.

    An entry is skipped when the `event` or `Bp_P` attribute is missing, or
    when Bp_P is not a strictly positive finite number (zero, negative, NaN
    or infinite). Non-finite values would otherwise reach the clustering
    step, where they yield non-finite pairwise distances.

    NOTE(review): entries are grouped on `event` alone -- confirm that this
    value is unique within a file (e.g. not repeated across runs).
    """
    positive_inf = float("inf")
    events = {}
    for i in range(tree.GetEntries()):
        # Load entry i so the attribute reads below refer to it.
        tree.GetEntry(i)
        ev = getattr(tree, "event", None)
        p = getattr(tree, "Bp_P", None)
        if ev is None or p is None:
            continue
        value = float(p)
        # NaN fails every comparison, so this one test rejects NaN, +/-inf
        # and all non-positive values together.
        if not 0.0 < value < positive_inf:
            continue
        events.setdefault(ev, []).append((value, i))
    return events

def relative_metric(u: np.ndarray, v: np.ndarray) -> float:
    """
    Relative separation of two one-element points u=[p] and v=[p'].

    The result is |p - p'| divided by the smaller of the two values. When
    that smaller value is not strictly positive the ratio is undefined and
    +inf is returned instead.
    """
    first = float(u[0])
    second = float(v[0])
    smaller = min(first, second)
    if smaller > 0.0:
        return abs(first - second) / smaller
    return float("inf")

def choose_representative(bp_p_list: list[tuple[float, int]], labels: np.ndarray, cluster_id: int) -> int:
    """
    Return the entry index that represents cluster `cluster_id`.

    `labels[k]` is the cluster label of `bp_p_list[k]`. Among the members of
    the requested cluster, the one whose Bp_P lies nearest to the cluster's
    median Bp_P wins; on a tie the member that appears first in
    `bp_p_list` is taken, so the choice is deterministic.
    """
    # Positions (into bp_p_list) of the entries carrying this label.
    positions = np.flatnonzero(np.asarray(labels) == cluster_id)
    momenta = np.array([bp_p_list[k][0] for k in positions], dtype=float)
    centre = float(np.median(momenta))
    # argmin returns the first minimum, which gives the tie-break above.
    nearest = int(np.argmin(np.abs(momenta - centre)))
    return bp_p_list[positions[nearest]][1]

def select_keep_indices(bp_p_list: list[tuple[float, int]], epsilon: float) -> set[int]:
    """
    Decide which entries of a single event survive de-duplication.

    `bp_p_list` holds (Bp_P, entry_index) pairs. The Bp_P values are grouped
    by single-linkage clustering on the relative distance, cut at `epsilon`,
    and one representative entry index is taken from each cluster.
    Returns the set of entry indices to keep (empty for an empty input).
    """
    # Zero or one candidate: nothing to cluster.
    if not bp_p_list:
        return set()
    if len(bp_p_list) == 1:
        return {bp_p_list[0][1]}

    # fclusterdata expects a 2D (n_points, n_features) array; here one feature.
    momenta = np.array([[p] for p, _ in bp_p_list], dtype=float)
    cluster_labels = fclusterdata(
        momenta,
        t=epsilon,
        criterion="distance",
        method="single",
        metric=relative_metric,
    )
    return {
        choose_representative(bp_p_list, cluster_labels, cluster)
        for cluster in set(cluster_labels)
    }

def write_filtered_tree(infile: Path, outdir: Path, treename: str, keep_indices: set[int]) -> None:
    """
    Copy the selected entries of one tree into a new file.

    The output is `outdir / infile.name` (created with RECREATE, so an
    existing file of that name is replaced) and contains a tree with the
    same name and branch structure as the input tree. `outdir` is created if
    needed. Mirroring of nested input folders is up to the caller, through
    the `outdir` it passes.

    Entries are written in increasing entry index; indices outside the
    tree's range are ignored. Prints a "skip" message and writes nothing if
    the output would overwrite the input, the input tree cannot be opened,
    or the output file cannot be created.
    """
    outdir.mkdir(parents=True, exist_ok=True)
    outfile = outdir / infile.name
    # RECREATE truncates its target, so writing onto the input would destroy it.
    if outfile.resolve() == infile.resolve():
        print(f"skip (output would overwrite input): {infile}")
        return
    fin, t = open_tree(infile, treename)
    if not t:
        return
    # The output file is opened after the input so that it is the current
    # directory when the empty clone is created below.
    try:
        out = r.TFile(str(outfile), "RECREATE")
    except OSError:
        # Recent PyROOT versions raise instead of returning a zombie file.
        out = None
    if not out or out.IsZombie():
        print(f"skip (cannot create): {outfile}")
        if out:
            out.Close()
        fin.Close()
        return
    written = 0
    try:
        n = t.GetEntries()
        # CloneTree(0) copies the branch layout but no entries.
        newt = t.CloneTree(0)
        for i in sorted(keep_indices):
            if 0 <= i < n:
                # Load the entry into the shared branch buffers, then store it.
                t.GetEntry(i)
                newt.Fill()
                written += 1
        newt.Write()
    finally:
        # Close both files even if the copy loop fails part-way.
        out.Close()
        fin.Close()
    print(f"cleaned -> {outfile} (kept {written} entries)")

In [None]:
def clean_all(base_in: Path, base_out: Path, treename: str, epsilon: float) -> None:
    """
    De-duplicate tree `treename` in every ROOT file found under `base_in`.

    For each file: group entries by event, cluster each event's Bp_P values
    with relative threshold `epsilon`, keep one entry per cluster, and write
    the kept entries to a file of the same name under `base_out`. The
    sub-directory of the input relative to `base_in` is reproduced under
    `base_out`, so equally named files in different input folders do not
    overwrite each other (flat inputs land directly in `base_out`).

    Files that cannot be opened or lack the tree are skipped. Prints a
    message and returns if no `.root` file is found.
    """
    roots = list_root_files(base_in)
    if not roots:
        print(f"no .root files under {base_in}")
        return

    for fpath in roots:
        fin, tree = open_tree(fpath, treename)
        if not tree:
            continue

        # Collect per-event entries; the file is only needed for this scan.
        try:
            events = collect_event_bp_indices(tree)
        finally:
            fin.Close()

        # Build keep set by clustering each event independently.
        keep = set()
        for bp_list in events.values():
            keep |= select_keep_indices(bp_list, epsilon)

        # Mirror the input's sub-directory below base_out ("." for flat input).
        outdir = base_out / fpath.parent.relative_to(base_in)
        write_filtered_tree(fpath, outdir, treename, keep)

In [None]:
# Run the clean-up once per decay tree: B2OC first, then B2CC.
# NOTE(review): both passes write to the same output file name per input and
# the writer recreates that file, so an input holding both trees would keep
# only the tree from the second pass -- confirm each input has just one.
for decay in ("B2OC", "B2CC"):
    clean_all(BASE_IN, BASE_OUT, treename=TREENAMES[decay], epsilon=EPSILON)