
# OVI Detection Prediction with Explainable Boosting Machines (EBM)

An end‑to‑end, fully explainable pipeline that uses `interpret`'s Explainable Boosting Classifier (EBM) to predict **has_OVI_absorber** from physically meaningful features.
Includes:
- Robust data loading/cleaning
- Class imbalance handling via sample weights
- Train/validation/test split with stratification
- EBM‑GAM (no interactions) and EBM‑GA2M (pairwise interactions) models
- Metrics: Accuracy, Balanced Accuracy, ROC‑AUC, PR‑AUC, Brier score, calibration
- Threshold selection (max F1 and Youden’s J)
- Global interpretability (feature importances, shape functions)
- Interaction surfaces (2‑D heatmaps for top interactions)
- Local interpretability (per‑example contribution breakdowns)
- Slice metrics by astrophysically relevant cohorts (`is_star_forming`, `is_central`, mass/impact parameter bins)
- Artifact export: models, metrics, predictions, and figures


In [26]:

# Flip INSTALL_PACKAGES to True when running somewhere that does not already
# provide `interpret`, scikit-learn and joblib.
INSTALL_PACKAGES = False

if INSTALL_PACKAGES:
    import sys
    import subprocess

    def _pip(pkg):
        """pip-install ``pkg`` into the interpreter that runs this notebook."""
        cmd = [sys.executable, "-m", "pip", "install", pkg]
        subprocess.check_call(cmd)

    for _requirement in ("interpret>=0.5.0", "scikit-learn>=1.1.0", "joblib>=1.2.0"):
        _pip(_requirement)


In [27]:

import os, json, math, warnings
from pathlib import Path

import numpy as np
import pandas as pd

# EBM (Explainable Boosting Machine)
from interpret.glassbox import ExplainableBoostingClassifier

# Sklearn basics
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
                             roc_auc_score, average_precision_score,
                             precision_recall_curve, roc_curve,
                             confusion_matrix, classification_report,
                             brier_score_loss)
from sklearn.calibration import calibration_curve
from sklearn.preprocessing import KBinsDiscretizer

from joblib import dump

import matplotlib.pyplot as plt

# Silence library warnings globally so notebook output stays readable.
# NOTE(review): this also hides deprecation/convergence warnings; consider a
# narrower filter (by category or module) when debugging.
warnings.filterwarnings("ignore")

# Shared Matplotlib typography for every exported figure.
plt.rcParams.update({
    "font.size": 12,
    "axes.titlesize": 13,
    "axes.labelsize": 12,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "figure.titlesize": 14,
})

# Seed for NumPy's global RNG; the same RANDOM_STATE is passed explicitly to
# the splitters and estimators later in the notebook.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)


In [28]:

# ---- Configure paths ----
# Primary (user-provided) data root. Set the OVI_BASE_PATH environment variable
# to point the notebook at another machine/mount without editing this cell.
BASE_PATH = os.environ.get('OVI_BASE_PATH', '/Users/wavefunction/ASU Dropbox/Tanmay Singh/')
RESULTS_DIR = os.path.join(BASE_PATH, 'Synthetic_IGrM_Sightlines/TNG50_fitting_results')
CSV_PATH_PRIMARY = os.path.join(RESULTS_DIR, 'feature_table.csv')

# Fallback (for alternative runs)
CSV_PATH_FALLBACK = '/mnt/data/feature_table.csv'

# Output directory for artifacts. If it cannot be created (missing mount,
# permissions, read-only filesystem) fall back to ./ebm_outputs. Only OSError
# is caught so genuine programming errors are not masked.
OUTDIR = Path(RESULTS_DIR) / 'ebm_outputs'
try:
    OUTDIR.mkdir(parents=True, exist_ok=True)
except OSError as err:
    print(f"[config] Cannot create {OUTDIR} ({err}); using ./ebm_outputs instead")
    OUTDIR = Path("./ebm_outputs")
    OUTDIR.mkdir(parents=True, exist_ok=True)

# ---- Feature/Target selection ----
TARGET = 'has_OVI_absorber'

# Keep the physically interpretable set (mirrors prior XGBoost run), exclude Velocity_Offset to avoid target leakage.
FEATURES = [
    'log_M_halo',
    'log_M_star_group',
    'impact_param_group',
    'impact_param_galaxy',
    'log_M_star_galaxy',
    'log_sSFR_galaxy',
    'is_central',
    'is_star_forming',
    'is_bound'
]

# Optional subsampling for very large runs.
SAMPLE_FRACTION = 1.0  # set to e.g. 0.25 for quick iteration


In [29]:

# ---- Load CSV ----
if os.path.exists(CSV_PATH_PRIMARY):
    CSV_PATH = CSV_PATH_PRIMARY
elif os.path.exists(CSV_PATH_FALLBACK):
    CSV_PATH = CSV_PATH_FALLBACK
else:
    raise FileNotFoundError("feature_table.csv not found at either the primary or fallback path. Update CSV_PATHs.")

print(f"[load] Using CSV at: {CSV_PATH}")
data = pd.read_csv(CSV_PATH)

# Drop known empty column if present
data = data.drop(columns=['Unnamed: 20'], errors='ignore')

# Ensure required columns exist (sorted so the message is deterministic)
missing = sorted(set([TARGET] + FEATURES) - set(data.columns))
if missing:
    raise KeyError(f"Missing expected columns: {missing}")

# Force numeric types on the model inputs; EBM can handle NaN (as separate bins).
# Derived from FEATURES so this list can never drift from the model inputs.
to_numeric_cols = list(FEATURES)
for c in to_numeric_cols:
    data[c] = pd.to_numeric(data[c], errors='coerce')

# Target -> {0, 1}. Missing/unparseable labels are still treated as negatives
# (unchanged behaviour), but they are counted and reported so label noise is
# not silent.
target_raw = pd.to_numeric(data[TARGET], errors='coerce').astype('float64')
n_missing_target = int(target_raw.isna().sum())
if n_missing_target:
    print(f"[load] WARNING: {n_missing_target:,} missing/non-numeric '{TARGET}' values set to 0")
unexpected = sorted(float(v) for v in set(target_raw.dropna().unique()) - {0.0, 1.0})
if unexpected:
    # Downstream code (class weights from the label mean, stratified splits,
    # ROC/PR metrics) assumes a binary 0/1 target.
    raise ValueError(f"'{TARGET}' must be binary (0/1); found other values, e.g. {unexpected[:10]}")
data[TARGET] = target_raw.fillna(0).astype(int)

# Optional subsample
if 0 < SAMPLE_FRACTION < 1.0:
    data = data.sample(frac=SAMPLE_FRACTION, random_state=RANDOM_STATE)

# Basic summary (guarded so an empty table does not raise ZeroDivisionError)
n = len(data)
pos = int(data[TARGET].sum())
neg = n - pos
print(f"[summary] Rows={n:,}  Positives={pos:,}  Negatives={neg:,}  PosRate={(pos / n if n else float('nan')):.3f}")


[load] Using CSV at: /Users/wavefunction/ASU Dropbox/Tanmay Singh/Synthetic_IGrM_Sightlines/TNG50_fitting_results/feature_table.csv
[summary] Rows=1,199,446  Positives=248,496  Negatives=950,950  PosRate=0.207


In [30]:

# Train/Val/Test split: 70/15/15 (stratified)
X, y = data[FEATURES], data[TARGET]

# Step 1 holds out 30% of rows; step 2 halves that holdout into validation and
# test. Both steps stratify on the label, so every split keeps (approximately)
# the overall positive rate.
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X, y,
    test_size=0.30,
    stratify=y,
    random_state=RANDOM_STATE,
)
X_valid, X_test, y_valid, y_test = train_test_split(
    X_tmp, y_tmp,
    test_size=0.50,
    stratify=y_tmp,
    random_state=RANDOM_STATE,
)

print(f"[split] train={len(X_train):,}  valid={len(X_valid):,}  test={len(X_test):,}")


[split] train=839,612  valid=179,917  test=179,917


In [31]:

# Inverse-frequency sample weights, computed from the TRAIN labels only so the
# valid/test splits never influence training. For 0/1 labels each class then
# carries half of the total weight, and the weights sum to len(y_train).
pos_rate = y_train.mean()
w_pos, w_neg = 0.5 / pos_rate, 0.5 / (1.0 - pos_rate)

sample_weight_train = np.where(y_train.values == 1, w_pos, w_neg).astype(np.float64)

print(f"[weights] pos_rate={pos_rate:.4f}  w_pos={w_pos:.3f}  w_neg={w_neg:.3f}")


[weights] pos_rate=0.2072  w_pos=2.413  w_neg=0.631


In [32]:

# Hyperparameters shared by both EBMs; the two models differ only in whether
# pairwise interaction terms are learned.
_EBM_SHARED_KWARGS = dict(
    outer_bags=8,
    inner_bags=0,
    learning_rate=0.02,
    max_bins=256,
    max_leaves=3,
    min_samples_leaf=200,
    max_rounds=5000,
    validation_size=0.15,
    early_stopping_tolerance=1e-5,
    n_jobs=-1,
    random_state=RANDOM_STATE,
)

# EBM-GAM: additive main effects only
ebm_gam = ExplainableBoostingClassifier(interactions=0, **_EBM_SHARED_KWARGS)

# EBM-GA2M: learn up to 10 pairwise interactions on top of the main effects
ebm_ga2m = ExplainableBoostingClassifier(interactions=10, **_EBM_SHARED_KWARGS)

# Both models see the same training rows and class-balancing weights.
print("[fit] Training EBM-GAM...")
ebm_gam.fit(X_train, y_train, sample_weight=sample_weight_train)

print("[fit] Training EBM-GA2M (interactions)...")
ebm_ga2m.fit(X_train, y_train, sample_weight=sample_weight_train)


[fit] Training EBM-GAM...
[fit] Training EBM-GA2M (interactions)...


0,1,2
,feature_names,
,feature_types,
,max_bins,256
,max_interaction_bins,64
,interactions,10
,exclude,
,validation_size,0.15
,outer_bags,8
,inner_bags,0
,learning_rate,0.02


In [33]:

def _save_line_plot(path, x, y, xlabel, ylabel, title, *, diagonal=False, marker=None):
    """Save a single-curve figure to ``path``, optionally with a dashed y = x reference line."""
    plt.figure()
    try:
        plt.plot(x, y, marker=marker, linewidth=2)
        if diagonal:
            plt.plot([0, 1], [0, 1], linestyle="--")
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(path, dpi=150)
    finally:
        # Release the figure even if saving fails, so figures don't accumulate.
        plt.close()


def evaluate_model(name, model, Xv, yv, prefix, outdir: Path):
    """Score ``model`` on (Xv, yv) and write metrics, plots, predictions and a report.

    Artifacts written to ``outdir`` (all named ``{prefix}_{name}_*``):
    ``metrics.json``, ``ROC.png``, ``PR.png``, ``Calibration.png``,
    ``predictions.csv`` and ``cls_report.txt``.

    Parameters
    ----------
    name : str
        Model label used in titles, file names and the returned dict.
    model : fitted binary classifier exposing ``predict_proba``.
    Xv : evaluation features.
    yv : pandas.Series
        Binary (0/1) evaluation labels.
    prefix : str
        Split tag (e.g. ``"valid"`` / ``"test"``) prepended to file names.
    outdir : Path
        Existing output directory.

    Returns
    -------
    dict
        The metrics also written to ``{prefix}_{name}_metrics.json``.

    Notes
    -----
    The max-F1 and Youden-J thresholds are chosen on the same data they are
    reported on; on the test split they are therefore optimistic, and the
    validation-split thresholds are the ones to carry forward.
    NOTE: the EBMs in this notebook are fit with class-balancing sample
    weights, which generally inflates positive-class probabilities relative to
    the true base rate; keep that in mind when reading the Brier score and the
    calibration curve.
    """
    # Scores
    p = model.predict_proba(Xv)[:, 1]
    yhat = (p >= 0.5).astype(int)

    acc  = accuracy_score(yv, yhat)
    bacc = balanced_accuracy_score(yv, yhat)
    roc  = roc_auc_score(yv, p)
    ap   = average_precision_score(yv, p)
    bs   = brier_score_loss(yv, p)

    # Max-F1 threshold. precision_recall_curve returns len(thr) + 1 points:
    # point i (i < len(thr)) is the precision/recall at threshold thr[i], and
    # the final point (precision=1, recall=0) has no threshold. Index thr
    # directly -- `thr[idx - 1]` picked the next-lower threshold instead.
    prec, rec, thr = precision_recall_curve(yv, p)
    f1 = (2 * prec * rec) / np.maximum(prec + rec, 1e-12)
    best_f1_idx = int(np.nanargmax(f1))
    best_f1_thr = float(thr[best_f1_idx]) if best_f1_idx < len(thr) else 0.5

    # Youden's J = TPR - FPR, maximised over the ROC thresholds.
    fpr, tpr, thr_roc = roc_curve(yv, p)
    best_j_idx = int(np.nanargmax(tpr - fpr))
    best_j_thr = float(thr_roc[best_j_idx])
    if not np.isfinite(best_j_thr):
        # roc_curve's first threshold can be +inf; it only wins when no
        # threshold beats chance, and inf would also be written as invalid JSON.
        best_j_thr = 0.5

    # Confusion matrices; labels fixed so the shape is always 2x2.
    def _cm_at(t):
        return confusion_matrix(yv, (p >= t).astype(int), labels=[0, 1])

    cm_05 = _cm_at(0.5)
    cm_f1 = _cm_at(best_f1_thr)
    cm_j  = _cm_at(best_j_thr)

    # Save metrics
    metrics = {
        "model": name,
        "acc@0.5": float(acc),
        "bacc@0.5": float(bacc),
        "roc_auc": float(roc),
        "pr_auc": float(ap),
        "brier": float(bs),
        "best_f1_thr": float(best_f1_thr),
        "best_j_thr": float(best_j_thr),
        "cm@0.5": cm_05.tolist(),
        "cm@best_f1": cm_f1.tolist(),
        "cm@best_j": cm_j.tolist(),
    }
    with open(outdir / f"{prefix}_{name}_metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    # Plots: ROC (reusing the curve computed above), PR, calibration.
    _save_line_plot(outdir / f"{prefix}_{name}_ROC.png", fpr, tpr,
                    "False Positive Rate", "True Positive Rate",
                    f"ROC: {name} (AUC={roc:.3f})", diagonal=True)
    _save_line_plot(outdir / f"{prefix}_{name}_PR.png", rec, prec,
                    "Recall", "Precision", f"PR: {name} (AP={ap:.3f})")
    prob_true, prob_pred = calibration_curve(yv, p, n_bins=15, strategy="uniform")
    _save_line_plot(outdir / f"{prefix}_{name}_Calibration.png", prob_pred, prob_true,
                    "Predicted probability", "Empirical probability",
                    f"Calibration: {name} (Brier={bs:.3f})", diagonal=True, marker="o")

    # Predictions CSV
    preds = pd.DataFrame({
        "y_true": yv.values,
        "p": p,
        "y_pred@0.5": yhat,
        f"y_pred@best_f1({best_f1_thr:.3f})": (p >= best_f1_thr).astype(int),
        f"y_pred@best_j({best_j_thr:.3f})": (p >= best_j_thr).astype(int),
    })
    preds.to_csv(outdir / f"{prefix}_{name}_predictions.csv", index=False)

    # Text report (at the default 0.5 threshold)
    report = classification_report(yv, yhat, digits=3)
    with open(outdir / f"{prefix}_{name}_cls_report.txt", "w") as f:
        f.write(report)

    return metrics


In [34]:

# Score both models on validation and test. evaluate_model also writes the
# per-split artifacts (metrics JSON, ROC/PR/calibration plots, predictions).
VAL_METRICS_GAM = evaluate_model("EBM_GAM", ebm_gam, X_valid, y_valid, "valid", OUTDIR)
VAL_METRICS_GA2M = evaluate_model("EBM_GA2M", ebm_ga2m, X_valid, y_valid, "valid", OUTDIR)
TEST_METRICS_GAM = evaluate_model("EBM_GAM", ebm_gam, X_test, y_test, "test", OUTDIR)
TEST_METRICS_GA2M = evaluate_model("EBM_GA2M", ebm_ga2m, X_test, y_test, "test", OUTDIR)

# One combined file, keyed by split and then by model family.
with open(OUTDIR / "summary_metrics.json", "w") as fh:
    json.dump(
        {
            "valid": {"GAM": VAL_METRICS_GAM, "GA2M": VAL_METRICS_GA2M},
            "test": {"GAM": TEST_METRICS_GAM, "GA2M": TEST_METRICS_GA2M},
        },
        fh,
        indent=2,
    )

print("[eval] Validation/Test metrics saved.")


[eval] Validation/Test metrics saved.


In [36]:
# =========================
# Robust EBM explanation helpers (version-agnostic)
# =========================
import numpy as np, pandas as pd, matplotlib.pyplot as plt, json
from pathlib import Path
from itertools import combinations

# ---------- universal utilities ----------
def _safe_logit(p, eps=1e-9):
    p = np.clip(np.asarray(p), eps, 1 - eps)
    return np.log(p / (1 - p))

def _percentile_grid(x, n=40):
    x = np.asarray(x)
    x = x[np.isfinite(x)]
    if x.size == 0:
        return np.linspace(0, 1, n)
    qs = np.linspace(0, 1, n)
    return np.unique(np.quantile(x, qs))

def _baseline_row(X: pd.DataFrame):
    base = {}
    for c in X.columns:
        if pd.api.types.is_numeric_dtype(X[c]):
            base[c] = np.nanmedian(X[c].values)
        else:
            base[c] = X[c].mode().iloc[0]
    return pd.Series(base, index=X.columns)

def _predict_logit(model, Xdf: pd.DataFrame):
    """Return the positive-class log-odds of ``model`` on ``Xdf`` (clipped via _safe_logit)."""
    return _safe_logit(model.predict_proba(Xdf)[:, 1])

def _exp_data(expl):
    # interpret Explanation: .data may be a method or a property
    return expl.data() if callable(getattr(expl, "data", None)) else expl.data

def _glob_data(glob, i):
    # explain_global().data(i) OR explain_global().data[i]
    if callable(getattr(glob, "data", None)):
        return glob.data(i)
    arr = glob.data
    return arr if i == -1 else arr[i]

def _feat_names_in(ebm, fallback_cols):
    # Use any available attribute and coerce to list
    for attr in ("feature_names_in_", "feature_names_"):
        names = getattr(ebm, attr, None)
        if names is not None:
            try:
                return list(names)
            except Exception:
                try:
                    return names.tolist()
                except Exception:
                    pass
    return list(fallback_cols)

def _term_features(ebm):
    tf = getattr(ebm, "term_features_", None)
    # Expect list[tuple[int]]; if missing, synthesize singletons
    if tf is None:
        fn = _feat_names_in(ebm, [])
        return [(i,) for i in range(len(fn))]
    return tf

def _term_names(ebm, feat_names):
    tn = getattr(ebm, "term_names_", None)
    if tn is not None:
        try:
            return list(tn)
        except Exception:
            pass
    # fallback from term_features_
    out = []
    for idxs in _term_features(ebm):
        if len(idxs) == 1:
            out.append(f"{feat_names[idxs[0]]}")
        elif len(idxs) == 2:
            a, b = idxs
            out.append(f"{feat_names[a]} × {feat_names[b]}")
        else:
            out.append(" + ".join(feat_names[i] for i in idxs))
    return out

# ---------- GLOBAL: overall importances ----------
def save_global_importances(ebm, name: str, outdir: Path, top_k: int = 12):
    """Write the EBM's global per-term importances to CSV and a bar chart.

    Produces ``global_{name}_importances.csv`` (every term, sorted by
    descending importance) and ``global_{name}_importances.png`` (the
    ``top_k`` most important terms).

    Returns
    -------
    pandas.DataFrame
        Columns ``term`` and ``score``, sorted by descending score.
    """
    outdir.mkdir(parents=True, exist_ok=True)
    glob = ebm.explain_global(name=name)
    # The importance table is the explanation's OVERALL payload, returned by
    # `data()` with no key. `data(0)` (used previously) returns the shape
    # function of term 0 -- bin edges as "names" and per-bin scores -- so the
    # exported "importances" were really one feature's binned curve.
    overall = glob.data() if callable(getattr(glob, "data", None)) else glob.data
    term_names = np.asarray(overall["names"])
    term_scores = np.asarray(overall["scores"], dtype=float)
    order = np.argsort(-term_scores, kind="stable")
    df_imp = pd.DataFrame({"term": term_names[order], "score": term_scores[order]})
    df_imp.to_csv(outdir / f"global_{name}_importances.csv", index=False)

    # Reverse so the most important term is drawn at the top of the chart.
    top = df_imp.iloc[:top_k].iloc[::-1]
    plt.figure(figsize=(8, 5))
    try:
        plt.barh(top["term"].astype(str), top["score"])
        plt.xlabel("Importance (term score)")
        plt.title(f"Global Feature Importances — {name}")
        plt.tight_layout()
        plt.savefig(outdir / f"global_{name}_importances.png", dpi=150)
    finally:
        plt.close()
    return df_imp

# ---------- GLOBAL: main-effect curves (exact via eval_terms) ----------
def save_main_effects(ebm, X_ref: pd.DataFrame, name: str, outdir: Path, n_grid=40, features=None):
    """Plot each single-feature term's logit contribution across a grid of that feature.

    For every main-effect term whose feature is listed in ``features``
    (default: all columns of ``X_ref``), copies of ``_baseline_row(X_ref)``
    are varied in that one feature -- over ``n_grid`` quantiles for numeric
    columns, over the observed unique values otherwise -- and the term's
    column of ``ebm.eval_terms`` is plotted. One PNG per term:
    ``global_{name}_shape_{term_index:02d}_{feature}.png``.
    """
    outdir.mkdir(parents=True, exist_ok=True)
    frame = X_ref.copy()
    anchor = _baseline_row(frame)

    members_per_term = _term_features(ebm)            # feature-index tuple per term
    col_names = _feat_names_in(ebm, frame.columns)
    labels = _term_names(ebm, col_names)
    wanted = list(frame.columns) if features is None else features

    for term_pos, members in enumerate(members_per_term):
        if len(members) != 1:
            continue  # only main-effect terms here; pairs are plotted by save_interactions
        col = col_names[members[0]]
        if col not in wanted or col not in frame.columns:
            continue

        numeric = pd.api.types.is_numeric_dtype(frame[col])
        points = (_percentile_grid(frame[col].values, n=n_grid) if numeric
                  else np.unique(frame[col].values))

        probe = pd.DataFrame([anchor] * len(points))
        probe[col] = points
        effect = ebm.eval_terms(probe)[:, term_pos]   # this term's logit contribution

        xs = points if numeric else np.arange(len(points))
        plt.figure()
        plt.plot(xs, effect, marker="o", linewidth=2)
        if not numeric:
            plt.xticks(xs, [str(v) for v in points], rotation=45, ha="right")
        plt.xlabel(col)
        plt.ylabel("Logit contribution")
        plt.title(f"Main effect: {labels[term_pos]} — {name}")
        plt.tight_layout()
        plt.savefig(outdir / f"global_{name}_shape_{term_pos:02d}_{col}.png", dpi=150)
        plt.close()

# ---------- GLOBAL: interaction surfaces (exact via eval_terms on pair terms) ----------
def save_interactions(ebm, X_ref: pd.DataFrame, name: str, outdir: Path, k=5, n_grid=25):
    """Plot 2-D surfaces for the ``k`` most important pairwise terms of ``ebm``.

    Each surface is the pair term's OWN logit contribution (its column of
    ``ebm.eval_terms``; main effects are not added), evaluated with every
    other column held at ``_baseline_row(X_ref)``. Numeric axes use
    ``n_grid`` quantiles of ``X_ref``; non-numeric axes use the observed
    unique values. Writes ``global_{name}_interaction_{rank:02d}_{fa}__{fb}.png``;
    does nothing when the model has no pair terms.
    """
    outdir.mkdir(parents=True, exist_ok=True)
    X = X_ref.copy()
    base = _baseline_row(X)

    term_feats = _term_features(ebm)
    feat_names = _feat_names_in(ebm, X.columns)
    term_names = _term_names(ebm, feat_names)

    pair_terms = [(t_idx, tuple(feat_names[i] for i in fidxs))
                  for t_idx, fidxs in enumerate(term_feats) if len(fidxs) == 2]
    if not pair_terms:
        return

    # Rank pairs by global importance. The importance table is the
    # explanation's OVERALL payload (`data()` with no key). `data(0)` returns
    # term 0's shape function, whose "names" are bin edges, so no pair name
    # ever matched: every pair scored 0.0 and "top k" was just the first k.
    glob = ebm.explain_global(name=name)
    overall = glob.data() if callable(getattr(glob, "data", None)) else glob.data
    imp_scores = {str(t): float(s) for t, s in zip(overall["names"], overall["scores"])}

    pair_terms_sorted = sorted(
        pair_terms,
        key=lambda item: imp_scores.get(str(term_names[item[0]]), 0.0),
        reverse=True,
    )[:k]

    def _axis_values(col):
        # Quantile grid for numeric columns, observed values otherwise.
        if pd.api.types.is_numeric_dtype(X[col]):
            return _percentile_grid(X[col].values, n=n_grid)
        return np.unique(X[col].values)

    for rank, (t_idx, (fa, fb)) in enumerate(pair_terms_sorted):
        a_vals, b_vals = _axis_values(fa), _axis_values(fb)

        # Evaluate the whole (a x b) grid with a single eval_terms call;
        # Z[i, j] is the pair-term logit at (a_vals[i], b_vals[j]).
        A, B = np.meshgrid(a_vals, b_vals, indexing="ij")
        grid = pd.DataFrame([base] * A.size)
        grid[fa] = A.ravel()
        grid[fb] = B.ravel()
        Z = ebm.eval_terms(grid)[:, t_idx].reshape(len(a_vals), len(b_vals))

        both_numeric = (pd.api.types.is_numeric_dtype(X[fa])
                        and pd.api.types.is_numeric_dtype(X[fb]))
        plt.figure(figsize=(7, 5))
        try:
            if both_numeric:
                plt.pcolormesh(a_vals, b_vals, Z.T, shading="auto")
            else:
                xi = np.arange(len(a_vals))
                yi = np.arange(len(b_vals))
                plt.imshow(Z.T, origin="lower", aspect="auto",
                           extent=[xi.min() - 0.5, xi.max() + 0.5, yi.min() - 0.5, yi.max() + 0.5])
                plt.xticks(xi, [str(v) for v in a_vals], rotation=45, ha="right")
                plt.yticks(yi, [str(v) for v in b_vals])
            plt.xlabel(fa)
            plt.ylabel(fb)
            plt.title(f"Interaction surface: {term_names[t_idx]} — {name}")
            cb = plt.colorbar()
            cb.set_label("Logit contribution")
            plt.tight_layout()
            plt.savefig(outdir / f"global_{name}_interaction_{rank:02d}_{fa}__{fb}.png", dpi=150)
        finally:
            plt.close()

# ---------- LOCAL: per-sample breakdown via eval_terms ----------
def save_local_explanations(ebm, X: pd.DataFrame, y: pd.Series, name: str, outdir: Path, n_examples=5):
    """Save per-term logit breakdowns for the most and least confident rows of ``X``.

    For the ``n_examples`` highest- and lowest-probability rows, plots the 10
    terms with the largest |logit contribution| from ``ebm.eval_terms``
    (``local_{name}_{HighP|LowP}_idx{i}.png``) and collects all examples in
    ``local_{name}.json`` together with the model intercept.

    In the JSON, ``row_index`` is the POSITIONAL index into ``X`` (it also
    appears in the PNG name); ``df_index`` is the matching label from
    ``X.index``, so rows can be traced back to the source table after
    train_test_split reorders them; ``y_true`` is the observed label (None if
    ``y`` is None).
    """
    outdir.mkdir(parents=True, exist_ok=True)

    def _json_scalar(v):
        # Convert NumPy scalars to built-in Python types for json.dump.
        return v.item() if hasattr(v, "item") else v

    p = ebm.predict_proba(X)[:, 1]
    z = _safe_logit(p)
    T = ebm.eval_terms(X)                      # [n_samples, n_terms]
    term_names = _term_names(ebm, _feat_names_in(ebm, X.columns))
    y_arr = None if y is None else np.asarray(y)

    hi = np.argsort(-p)[:n_examples]
    lo = np.argsort(p)[:n_examples]

    # Intercept may be a scalar or an array; take its first element either way.
    icpt_val = float(np.ravel(getattr(ebm, "intercept_", 0.0))[0])

    payload = {"name": name, "intercept_logit": icpt_val, "examples": []}

    def _dump(idx, label):
        contribs = T[idx]  # per-term logits
        items = list(zip(term_names, contribs.tolist()))
        items_sorted = sorted(items, key=lambda t: abs(t[1]), reverse=True)[:10]

        plt.figure(figsize=(8, 5))
        try:
            plt.barh([k for k, _ in items_sorted][::-1], [v for _, v in items_sorted][::-1])
            plt.xlabel("Logit contribution")
            plt.title(f"Local Explanation — {name} [{label}] idx={idx}")
            plt.tight_layout()
            plt.savefig(outdir / f"local_{name}_{label}_idx{idx}.png", dpi=150)
        finally:
            plt.close()

        payload["examples"].append({
            "row_index": int(idx),
            "df_index": _json_scalar(X.index[idx]),
            "label": label,
            "y_true": None if y_arr is None else _json_scalar(y_arr[idx]),
            "p": float(p[idx]),
            "logit": float(z[idx]),
            "top_terms": items_sorted
        })

    for i in hi:
        _dump(i, "HighP")
    for i in lo:
        _dump(i, "LowP")

    with open(outdir / f"local_{name}.json", "w") as f:
        json.dump(payload, f, indent=2)

In [37]:
# Global importance tables and bar charts for both models.
imp_gam = save_global_importances(ebm_gam, "EBM_GAM", OUTDIR)
imp_ga2m = save_global_importances(ebm_ga2m, "EBM_GA2M", OUTDIR)

# Main-effect shape functions, evaluated over the training distribution.
_all_features = list(X_train.columns)
for _ebm, _label in ((ebm_gam, "EBM_GAM"), (ebm_ga2m, "EBM_GA2M")):
    save_main_effects(_ebm, X_train, _label, OUTDIR, n_grid=40, features=_all_features)

# Pairwise surfaces exist only for the GA2M model.
save_interactions(ebm_ga2m, X_train, "EBM_GA2M", OUTDIR, k=5, n_grid=25)

# Per-example breakdowns on the held-out test split.
for _ebm, _label in ((ebm_gam, "EBM_GAM"), (ebm_ga2m, "EBM_GA2M")):
    save_local_explanations(_ebm, X_test, y_test, _label, OUTDIR, n_examples=5)

In [38]:

def cohort_metrics(y_true, p, thr=0.5):
    """Classification metrics for one cohort at decision threshold ``thr``.

    Returns accuracy and balanced accuracy at ``thr``, ROC-AUC, PR-AUC
    (average precision), the cohort size ``n`` and its positive rate
    ``pos_rate``.

    ROC-AUC is undefined for a cohort containing a single class (sklearn
    raises ValueError), which used to abort the whole slice export. In that
    case both ranking metrics are reported as None. Every metric is None for
    an empty cohort.
    """
    y_true = np.asarray(y_true)
    p = np.asarray(p)
    n = int(y_true.size)
    if n == 0:
        return {"acc": None, "bacc": None, "roc_auc": None, "pr_auc": None,
                "n": 0, "pos_rate": None}

    y_hat = (p >= thr).astype(int)
    both_classes = np.unique(y_true).size > 1
    return {
        "acc": float(accuracy_score(y_true, y_hat)),
        "bacc": float(balanced_accuracy_score(y_true, y_hat)),
        "roc_auc": float(roc_auc_score(y_true, p)) if both_classes else None,
        "pr_auc": float(average_precision_score(y_true, p)) if both_classes else None,
        "n": n,
        "pos_rate": float(np.mean(y_true == 1)),
    }

def _quantile_cohorts(values: pd.Series, n_bins: int = 4):
    """Bin ``values`` into quantile cohorts labelled Q1..Qk (k <= n_bins).

    Tied data can produce repeated quantile edges. Those are collapsed, and
    the labels are regenerated to match the surviving bins. A fixed 4-label
    list with ``pd.cut(duplicates="drop")`` raises ValueError in that case.
    Returns None when fewer than two distinct edges remain (constant or
    all-NaN column). NaN inputs stay NaN and so belong to no cohort.
    """
    present = values.dropna()
    if present.empty:
        return None
    edges = np.unique(np.quantile(present, np.linspace(0, 1, n_bins + 1)))
    if edges.size < 2:
        return None
    labels = [f"Q{i}" for i in range(1, edges.size)]
    return pd.cut(values, bins=edges, include_lowest=True, labels=labels)


def save_slice_metrics(model, X, y, name, outdir: Path):
    """Write per-cohort metrics for ``model`` on (X, y) to ``slices_{name}.json``.

    Cohorts:
    - ``is_star_forming`` = 0/1 and ``is_central`` = 0/1;
    - quantile bins of ``log_M_halo`` and ``impact_param_group``. These are
      Q1..Q4, or fewer when edges tie, and the edges are computed on ``X``
      itself.

    Columns absent from ``X`` and empty cohorts are skipped. Every cohort
    uses ``cohort_metrics`` at its default 0.5 threshold.
    """
    p = model.predict_proba(X)[:, 1]
    df = X.copy()
    df["y_true"] = y.values
    df["p"] = p

    slices = {}

    def _add(key, sel):
        # Record one cohort if it has at least one row.
        if sel.any():
            slices[key] = cohort_metrics(df.loc[sel, "y_true"].values,
                                         df.loc[sel, "p"].values)

    # Binary cohorts
    for col in ("is_star_forming", "is_central"):
        if col in df.columns:
            for val in (0, 1):
                _add(f"{col}={val}", df[col] == val)

    # Quantile cohorts: halo mass and group impact parameter
    for col in ("log_M_halo", "impact_param_group"):
        if col not in df.columns:
            continue
        cats = _quantile_cohorts(df[col])
        if cats is None:
            continue
        for lab in cats.dropna().unique():
            _add(f"{col}={lab}", cats == lab)

    with open(outdir / f"slices_{name}.json", "w") as f:
        json.dump(slices, f, indent=2)

# Cohort-level metrics on the held-out test split, one JSON per model.
for _model, _label in ((ebm_gam, "EBM_GAM"), (ebm_ga2m, "EBM_GA2M")):
    save_slice_metrics(_model, X_test, y_test, _label, OUTDIR)
print("[slices] Slice metrics saved.")


[slices] Slice metrics saved.


In [16]:

# Optional K-fold CV on training data (stratified) for ROC-AUC and PR-AUC.
# Shuffled with the shared seed, so every skf.split(...) call yields the same folds.
K = 5
skf = StratifiedKFold(random_state=RANDOM_STATE, shuffle=True, n_splits=K)

def cv_scores_ebm(model_factory, X, y):
    """Cross-validated ROC-AUC and PR-AUC for models built by ``model_factory``.

    Folds come from the module-level ``skf``. Each fold trains a freshly
    constructed model with inverse-frequency class weights computed from that
    fold's training labels only.

    Returns
    -------
    tuple
        (roc_mean, roc_std, ap_mean, ap_std); std is NumPy's default
        population standard deviation (ddof=0).
    """
    fold_roc, fold_ap = [], []
    for fit_rows, score_rows in skf.split(X, y):
        y_fit = y.iloc[fit_rows]
        y_score = y.iloc[score_rows]

        rate = y_fit.mean()
        weights = np.where(y_fit.values == 1, 0.5 / rate, 0.5 / (1.0 - rate)).astype(np.float64)

        est = model_factory()
        est.fit(X.iloc[fit_rows], y_fit, sample_weight=weights)
        proba = est.predict_proba(X.iloc[score_rows])[:, 1]

        fold_roc.append(roc_auc_score(y_score, proba))
        fold_ap.append(average_precision_score(y_score, proba))

    return np.mean(fold_roc), np.std(fold_roc), np.mean(fold_ap), np.std(fold_ap)

def _mk_cv_ebm(interactions):
    """Build an EBM with the CV hyperparameters; only the interaction budget varies."""
    # NOTE(review): max_rounds=2000 here vs 5000 for the final ebm_gam/ebm_ga2m,
    # so the CV scores describe a slightly different configuration than the
    # saved models.
    return ExplainableBoostingClassifier(
        interactions=interactions,
        outer_bags=8, inner_bags=0, learning_rate=0.02,
        max_bins=256, max_leaves=3, min_samples_leaf=200, max_rounds=2000,
        validation_size=0.15, early_stopping_tolerance=1e-5,
        n_jobs=-1, random_state=RANDOM_STATE,
    )


def mk_gam():
    """Factory for the additive-only EBM (no pairwise terms) used in CV."""
    return _mk_cv_ebm(interactions=0)


def mk_ga2m():
    """Factory for the EBM with up to 10 pairwise interactions used in CV."""
    return _mk_cv_ebm(interactions=10)

# Both model families are scored on identical folds (skf is seeded).
gam_roc_m, gam_roc_s, gam_ap_m, gam_ap_s = cv_scores_ebm(mk_gam, X_train, y_train)
ga2m_roc_m, ga2m_roc_s, ga2m_ap_m, ga2m_ap_s = cv_scores_ebm(mk_ga2m, X_train, y_train)

_cv_keys = ("roc_auc_mean", "roc_auc_std", "pr_auc_mean", "pr_auc_std")
cv_summary = {
    "GAM": dict(zip(_cv_keys, (gam_roc_m, gam_roc_s, gam_ap_m, gam_ap_s))),
    "GA2M": dict(zip(_cv_keys, (ga2m_roc_m, ga2m_roc_s, ga2m_ap_m, ga2m_ap_s))),
}
with open(OUTDIR / "cv_summary.json", "w") as fh:
    json.dump(cv_summary, fh, indent=2)

print("[cv] 5-fold CV summary saved.")


[cv] 5-fold CV summary saved.


In [17]:

# Persist both fitted models next to the other artifacts.
for _label, _model in (("EBM_GAM", ebm_gam), ("EBM_GA2M", ebm_ga2m)):
    dump(_model, OUTDIR / f"{_label}.joblib")
print(f"[save] Models saved to: {OUTDIR}")


[save] Models saved to: /Users/wavefunction/ASU Dropbox/Tanmay Singh/Synthetic_IGrM_Sightlines/TNG50_fitting_results/ebm_outputs



## Notes
- `SAMPLE_FRACTION` can be reduced for faster iteration if hardware is constrained.
- `interactions=0` yields a GAM (additive main effects). `interactions>0` yields a GA2M that learns pairwise terms.
- All plots use pure Matplotlib to keep environments simple.
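- The exports `global_*_importances.csv` (per-term global importances) and `local_*.json` (model intercept plus the top-10 per-term logit contributions for the highest- and lowest-probability test examples) are compact summaries for reproducible analysis, not full `interpret` explanation payloads.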
- To deploy, load the `joblib` artifacts and call `predict_proba` on new feature tables constructed identically to training.
