In [1]:
# ============================================================
# ✅ FULL UPDATED END-TO-END SCRIPT (Z-CROSS) — FIXED SHAP + BETTER RELIABILITY
# Train on one PROMISE dataset, Test on another (PC1–PC4)
#
# ✅ Supports your selected 12 pairs:
#   PC1→PC2, PC1→PC3, PC1→PC4,
#   PC2→PC1, PC2→PC3, PC2→PC4,
#   PC3→PC1, PC3→PC2, PC3→PC4,
#   PC4→PC1, PC4→PC2, PC4→PC3
#
# ✅ Uses PC1→NASA mapping before alignment
# ✅ TRAIN-only: SMOTE + scaler + calibration + threshold tuning
# ✅ TEST-only: evaluation + GLR + Hit@k + ECE + ReliabilityScore
#
# ✅ FIX: SHAP now works even when model is CalibratedClassifierCV
#    - calibrated model used for probabilities (ECE/Brier)
#    - base estimator unwrapped for TreeSHAP
#
# ============================================================

import warnings, numpy as np, pandas as pd, matplotlib.pyplot as plt, random, os, re
warnings.filterwarnings("ignore")

if not hasattr(np, "bool"): np.bool = np.bool_
if not hasattr(np, "int"):  np.int  = int

from typing import Tuple, Optional, List, Dict
import inspect

from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, brier_score_loss
from sklearn.inspection import permutation_importance
from sklearn.calibration import CalibratedClassifierCV
from imblearn.over_sampling import SMOTE
from scipy.stats import spearmanr

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from sklearn.isotonic import IsotonicRegression

_HAS_XGB = _HAS_LGBM = _HAS_CB = False
PRINTED_SHAP_FALLBACK = False

try:
    from xgboost import XGBClassifier
    _HAS_XGB = True
except Exception:
    pass

try:
    from lightgbm import LGBMClassifier
    _HAS_LGBM = True
except Exception:
    pass

try:
    from catboost import CatBoostClassifier
    _HAS_CB = True
except Exception:
    pass

try:
    from scipy.io import arff as _arff
except Exception:
    _arff = None

try:
    import shap
    _HAS_SHAP = True
except Exception:
    _HAS_SHAP = False

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import (
    RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier, AdaBoostClassifier
)
from sklearn.neural_network import MLPClassifier

# -----------------------------
# Metric-family mapping (PC1-style ↔ PC2/PC3/PC4-style)
# -----------------------------
# Column-name translation table: short PC1-style metric names (keys) to the
# long NASA-style names (values). Keys are matched case-insensitively through
# the case-folded copy built from this dict.
PC1_TO_NASA_MAP = {
    # Size metrics
    "loc": "LOC_EXECUTABLE",
    "lOComment": "LOC_COMMENTS",
    "locCodeAndComment": "LOC_CODE_AND_COMMENT",
    # Complexity metrics
    "v(g)": "CYCLOMATIC_COMPLEXITY",
    "iv(G)": "DESIGN_COMPLEXITY",
    "ev(g)": "ESSENTIAL_COMPLEXITY",
    # Halstead-ish (PC2/3/4 often have these)
    "D": "HALSTEAD_DIFFICULTY",
    "E": "HALSTEAD_EFFORT",
    "I": "HALSTEAD_CONTENT",
    # best-effort matches (may or may not exist depending on PC2/3/4)
    "B": "HALSTEAD_ERROR_EST",
}

def _norm_raw_key(x: str) -> str:
    return str(x).strip().lower()

def _build_casefold_map(m: dict) -> dict:
    """Return a copy of `m` whose keys are normalised with `_norm_raw_key`.

    If two keys normalise to the same string, the later one in iteration
    order wins.
    """
    return {_norm_raw_key(raw_key): value for raw_key, value in m.items()}

# Case-insensitive lookup table (normalised key -> NASA-style name), built once
# at import time from PC1_TO_NASA_MAP.
_PC1_TO_NASA_CASEFOLD = _build_casefold_map(PC1_TO_NASA_MAP)

# -----------------------------
# Target inference helpers
# -----------------------------
# Substrings that, when found in a lower-cased column name, raise that column's
# score as a target candidate (see _score_target_candidate).
_KEYWORDS = ["defect", "bug", "fault", "class", "label", "target", "isdefect", "is_defect"]

def _fit_optional_weight(estimator, X, y, sample_weight=None, **kwargs):
    """Call estimator.fit, passing sample_weight only if supported."""
    try:
        sig = inspect.signature(estimator.fit)
        if "sample_weight" in sig.parameters and sample_weight is not None:
            return estimator.fit(X, y, sample_weight=sample_weight, **kwargs)
        else:
            return estimator.fit(X, y, **kwargs)
    except (TypeError, ValueError):
        return estimator.fit(X, y, **kwargs)

def _class_balanced_weights(y: np.ndarray) -> np.ndarray:
    """Return one weight per element of `y`, using sklearn's 'balanced' class weights."""
    classes = np.unique(y)
    weights = compute_class_weight(class_weight='balanced', classes=classes, y=y)
    weight_of = dict(zip(classes, weights))
    lookup = np.vectorize(weight_of.get)
    return lookup(y)

def _tune_threshold(y_val: np.ndarray, p_val_cal: np.ndarray) -> float:
    """Return the cut-off in [0.05, 0.95] (61 evenly spaced candidates) with the best F1.

    Ties are resolved in favour of the smallest candidate threshold.
    """
    candidates = np.linspace(0.05, 0.95, 61)  # wider grid for imbalance
    f1_per_cut = [
        f1_score(y_val, (p_val_cal >= cut).astype(int), zero_division=0)
        for cut in candidates
    ]
    # argmax returns the first maximum, i.e. the lowest threshold among ties.
    winner = int(np.argmax(f1_per_cut))
    return float(candidates[winner])

def _safe_decode_col(s: pd.Series) -> pd.Series:
    if s.dtype == object:
        try:
            return s.apply(lambda x: x.decode("utf-8") if isinstance(x, (bytes, bytearray)) else x)
        except Exception:
            return s
    return s

def _coerce_binary(series: pd.Series) -> pd.Series:
    s = series.copy()
    if pd.api.types.is_numeric_dtype(s):
        uniq = sorted(pd.unique(s.dropna()))
        if set(uniq).issubset({0,1}):
            return s.astype(int)
        if set(uniq).issubset({-1,1}):
            return ((s + 1) // 2).astype(int)
        if len(uniq) == 2:
            m = {uniq[0]:0, uniq[1]:1}
            return s.map(m).astype(int)

    sval = s.astype(str).str.strip().str.lower()
    mapping = {
        "true":1, "false":0, "t":1, "f":0,
        "yes":1, "no":0, "y":1, "n":0,
        "defective":1, "non-defective":0, "nondefective":0, "clean":0, "faulty":1,
        "positive":1, "negative":0, "pos":1, "neg":0,
        "1":1, "0":0
    }
    mapped = sval.map(mapping)
    if not mapped.isna().all():
        if mapped.isna().any():
            u = list(pd.unique(sval))
            if len(u) == 2:
                m = {u[0]:0, u[1]:1}
                mapped = sval.map(m)
        if mapped.isna().any():
            raise ValueError("Could not fully coerce target to binary.")
        return mapped.astype(int)

    u = list(pd.unique(sval))
    if len(u) == 2:
        return sval.map({u[0]:0, u[1]:1}).astype(int)
    raise ValueError("Column is not binary or has unexpected labels.")

def _is_binary_col(s: pd.Series) -> bool:
    """Return True when `_coerce_binary(s)` succeeds, False when it raises."""
    try:
        _coerce_binary(s)
    except Exception:
        return False
    return True

def _score_target_candidate(colname: str, s: pd.Series, pos_index: int, total_cols: int) -> float:
    """Heuristic score of how likely column `colname` is the binary target.

    +2.0 if the lower-cased name contains a keyword from `_KEYWORDS`,
    +1.0 if the column is the last one,
    +2.0 if it coerces to binary, and then
    +1.0 if the positive rate lies in [0.05, 0.95], otherwise +0.5.
    """
    lowered = colname.lower()
    points = 0.0

    if any(keyword in lowered for keyword in _KEYWORDS):
        points += 2.0
    if pos_index == total_cols - 1:
        points += 1.0
    if _is_binary_col(s):
        points += 2.0

    try:
        positive_rate = _coerce_binary(s).mean()
    except Exception:
        # Not coercible: no class-balance bonus.
        return points

    points += 1.0 if 0.05 <= positive_rate <= 0.95 else 0.5
    return points

def infer_target_and_features(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series, str]:
    """
    Pick the binary target column of `df` and split it from the features.

    Every column is scored with `_score_target_candidate`; candidates are then
    tried from the highest score down and the first one that coerces to 0/1 is
    used as the target.

    Note: byte values are decoded by assigning back into `df`, so the caller's
    DataFrame is modified.

    Returns:
        (X, y, target_name): a copy of `df` without the target column, the
        target as an int Series of 0/1, and the name of the chosen column.

    Raises:
        ValueError: if no column can be used as a binary target.
    """
    # Decode bytes -> str column by column (writes into the caller's df).
    for c in df.columns:
        df[c] = _safe_decode_col(df[c])

    cols = list(df.columns)
    scores = []
    for i, c in enumerate(cols):
        s = df[c]
        scores.append((_score_target_candidate(c, s, i, len(cols)), i, c))
    # Tuples compare as (score, index, name): on equal score the column with
    # the larger index comes first.
    scores.sort(reverse=True)

    for _, idx, cname in scores:
        s = df[cname]
        if _is_binary_col(s):
            try:
                y = _coerce_binary(s)
                X = df.drop(columns=[cname]).copy()
                print(f"[Target] Auto-detected column: '{cname}' (pos {idx+1}/{len(cols)})")
                return X, y.astype(int), cname
            except Exception:
                continue

    # NOTE(review): the loop above already visits every column, including the
    # last one, so this fallback looks unreachable — confirm before relying on it.
    last = cols[-1]
    if _is_binary_col(df[last]):
        y = _coerce_binary(df[last])
        X = df.iloc[:, :-1].copy()
        print(f"[Target] Fallback to last column: '{last}'")
        return X, y.astype(int), last

    # Last resort: any column with exactly two distinct non-null values.
    # _coerce_binary is not guarded here, so its ValueError propagates.
    for c in cols:
        if df[c].nunique(dropna=True) == 2:
            y = _coerce_binary(df[c])
            X = df.drop(columns=[c]).copy()
            print(f"[Target] Fallback to 2-unique column: '{c}'")
            return X, y.astype(int), c

    raise ValueError("Could not automatically infer a binary target column.")

def load_dataset_auto(data_path: str) -> pd.DataFrame:
    """Load a .csv or .arff file into a DataFrame, chosen by file extension.

    For .arff files, byte columns are decoded and object columns are converted
    to numeric where possible.

    Raises:
        ValueError: for any other extension.
        ImportError: for .arff when scipy.io.arff could not be imported.
    """
    ext = os.path.splitext(data_path)[1].lower()

    if ext == ".csv":
        return pd.read_csv(data_path)
    if ext != ".arff":
        raise ValueError(f"Unsupported file extension: {ext}. Use .csv or .arff")
    if _arff is None:
        raise ImportError("scipy is required for reading .arff files (scipy.io.arff).")

    records, _meta = _arff.loadarff(data_path)
    frame = pd.DataFrame(records)

    # ARFF nominal/string attributes arrive as bytes.
    for col in frame.columns:
        frame[col] = _safe_decode_col(frame[col])

    # Best-effort numeric conversion; columns that fail stay as they are.
    for col in frame.columns:
        if frame[col].dtype != object:
            continue
        try:
            frame[col] = pd.to_numeric(frame[col])
        except Exception:
            pass
    return frame

def load_and_preprocess(data_path: str) -> Tuple[pd.DataFrame, pd.Series, str]:
    """
    Load a dataset, split off the auto-detected target and clean the features.

    Steps, in order: drop constant columns, forward/back fill missing values,
    then convert object columns to numeric where every value parses.

    Returns:
        (X, y, target_name): cleaned feature frame, int 0/1 target, and the
        name of the column chosen as target.

    Raises:
        ValueError / ImportError: propagated from loading or target inference.
    """
    df = load_dataset_auto(data_path)
    X, y, tgt = infer_target_and_features(df)

    # A column with a single distinct value (NaN counted as a value) carries no signal.
    const_cols = [c for c in X.columns if X[c].nunique(dropna=False) <= 1]
    if const_cols:
        print("Dropping constant columns:", const_cols)
        X = X.drop(columns=const_cols)

    if X.isna().any().any():
        print("Missing values detected; forward/back fill.")
        # fillna(method=...) is deprecated in recent pandas; ffill()/bfill()
        # are the supported equivalents.
        X = X.ffill().bfill()
    else:
        print("No missing values detected.")

    for c in X.columns:
        if X[c].dtype == object:
            try:
                X[c] = pd.to_numeric(X[c])
            except (ValueError, TypeError):
                # Genuinely non-numeric column: leave it untouched.
                pass

    return X, y.astype(int), tgt

# -----------------------------
# Models
# -----------------------------
def make_model(name: str, random_state: int):
    """Build an unfitted classifier from a (case-insensitive) model name or alias.

    When xgboost / lightgbm / catboost is requested but not installed, a
    warning is printed and a RandomForest is returned instead.

    Raises:
        ValueError: if `name` is not a known model name or alias.
    """
    aliases = {
        "random_forest": "rf", "rf": "rf",
        "extra_trees": "et", "extratrees": "et", "et": "et",
        "gradient_boosting": "gb", "gbrt": "gb", "gb": "gb",
        "adaboost": "ada", "ada": "ada",
        "xgboost": "xgb", "xgb": "xgb",
        "lightgbm": "lgbm", "lgbm": "lgbm", "lgb": "lgbm",
        "catboost": "cb", "cb": "cb",
        "mlp": "mlp", "nn": "mlp", "pytorch_mlp": "mlp",
    }
    kind = aliases.get(name.strip().lower())
    if kind is None:
        raise ValueError(f"Unknown model name: {name}")

    if kind == "rf":
        return RandomForestClassifier(
            n_estimators=800, max_depth=None, min_samples_leaf=2,
            random_state=random_state, n_jobs=-1,
            class_weight="balanced_subsample", oob_score=False,
        )

    if kind == "et":
        return ExtraTreesClassifier(
            n_estimators=1200, max_depth=None, min_samples_leaf=1,
            random_state=random_state, n_jobs=-1,
            class_weight="balanced_subsample",
        )

    if kind == "gb":
        return GradientBoostingClassifier(
            random_state=random_state, learning_rate=0.05,
            n_estimators=1000, max_depth=3,
        )

    if kind == "ada":
        stump = DecisionTreeClassifier(max_depth=2, random_state=random_state)
        boost_args = dict(n_estimators=1000, learning_rate=0.05, random_state=random_state)
        # Try the `estimator` keyword first, then `base_estimator` if it is rejected.
        try:
            return AdaBoostClassifier(estimator=stump, **boost_args)
        except TypeError:
            return AdaBoostClassifier(base_estimator=stump, **boost_args)

    if kind == "xgb":
        if _HAS_XGB:
            return XGBClassifier(
                n_estimators=1500, max_depth=6, learning_rate=0.05,
                subsample=0.8, colsample_bytree=0.8, reg_lambda=1.0,
                eval_metric="logloss", tree_method="hist",
                random_state=random_state, n_jobs=-1,
            )
        print("[WARN] xgboost not installed; falling back to RandomForest.")
        return make_model("rf", random_state)

    if kind == "lgbm":
        if _HAS_LGBM:
            return LGBMClassifier(
                n_estimators=2000, num_leaves=31, learning_rate=0.03,
                subsample=0.8, colsample_bytree=0.8,
                random_state=random_state, n_jobs=-1,
            )
        print("[WARN] lightgbm not installed; falling back to RandomForest.")
        return make_model("rf", random_state)

    if kind == "cb":
        if _HAS_CB:
            return CatBoostClassifier(
                iterations=1500, depth=6, learning_rate=0.05,
                verbose=False, random_seed=random_state, loss_function="Logloss",
            )
        print("[WARN] catboost not installed; falling back to RandomForest.")
        return make_model("rf", random_state)

    # kind == "mlp"
    return MLPClassifier(
        hidden_layer_sizes=(256, 128, 64), activation="relu", solver="adam",
        alpha=3e-4, learning_rate_init=1e-3, max_iter=500,
        batch_size='auto', early_stopping=True, n_iter_no_change=25,
        random_state=random_state,
    )

# -----------------------------
# Reliability helpers
# -----------------------------
def spearman_rank_corr(a: pd.Series, b: pd.Series) -> float:
    """Spearman rank correlation of two Series on their shared index labels.

    Returns 1.0 when both inputs are constant and 0.0 when exactly one is.
    """
    left, right = a.align(b, join='inner')
    x, y = left.values, right.values

    x_is_flat = np.nanstd(x) == 0
    y_is_flat = np.nanstd(y) == 0
    if x_is_flat and y_is_flat:
        return 1.0
    if x_is_flat or y_is_flat:
        return 0.0

    x_ranks = pd.Series(x).rank()
    y_ranks = pd.Series(y).rank()
    outcome = spearmanr(x_ranks, y_ranks, nan_policy='omit')
    return outcome.correlation

def bce_loss(p, y):
    """Element-wise binary cross-entropy of probabilities `p` against labels `y`.

    Probabilities are clipped to [1e-8, 1 - 1e-8] so the logarithms stay finite.
    """
    clipped = np.clip(p, 1e-8, 1-1e-8)
    positive_term = y * np.log(clipped)
    negative_term = (1 - y) * np.log(1 - clipped)
    return -(positive_term + negative_term)

def finite_diff_grad_loss(model, X_eval: np.ndarray, y_eval: np.ndarray, eps_vec: np.ndarray):
    """Absolute central-difference estimate of d(BCE loss)/d(feature) per sample.

    Each feature j is shifted by +/- eps_vec[j] for all rows at once and the
    loss difference is divided by 2 * eps_vec[j] (plus 1e-12 for stability).
    Returns an array with the same (n_samples, n_features) shape as `X_eval`.
    """
    n_rows, n_cols = X_eval.shape
    abs_grads = np.zeros((n_rows, n_cols), dtype=float)

    # One reusable offset matrix: only column j is non-zero during iteration j.
    offset = np.zeros_like(X_eval)
    for j in range(n_cols):
        step = eps_vec[j]
        offset[:, j] = step
        loss_plus = bce_loss(model.predict_proba(X_eval + offset)[:, 1], y_eval)
        loss_minus = bce_loss(model.predict_proba(X_eval - offset)[:, 1], y_eval)
        abs_grads[:, j] = (loss_plus - loss_minus) / (2.0 * step + 1e-12)
        offset[:, j] = 0
    return np.abs(abs_grads)

def normalize_rows(A: np.ndarray, eps=1e-12):
    """Divide every row of `A` by (its sum + eps); eps guards all-zero rows."""
    row_totals = A.sum(axis=1, keepdims=True)
    return A / (row_totals + eps)

def expected_calibration_error(y_true, y_prob, n_bins=10):
    """
    Expected Calibration Error over `n_bins` equal-width probability bins.

    ECE = sum over bins of (bin share of samples) * |mean label - mean probability|.

    Args:
        y_true: binary labels (cast to int).
        y_prob: predicted probabilities of the positive class (cast to float).
        n_bins: number of equal-width bins on [0, 1].

    Returns:
        The ECE as a float; 0.0 for empty input.
    """
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)
    bins = np.linspace(0.0, 1.0, n_bins+1)

    # np.digitize puts p == 1.0 past the last edge (index n_bins after the -1
    # shift). Clip so the last bin is closed on the right; otherwise fully
    # confident predictions would be silently left out of the score.
    # Values below 0 are likewise folded into the first bin.
    inds = np.clip(np.digitize(y_prob, bins) - 1, 0, n_bins - 1)

    ece = 0.0
    total = len(y_true)
    for b in range(n_bins):
        mask = inds == b
        if not np.any(mask):
            continue
        conf = y_prob[mask].mean()
        acc = y_true[mask].mean()
        ece += (np.sum(mask) / total) * np.abs(acc - conf)
    return ece

def rescale01_rho(rho):
    """Map a correlation from [-1, 1] onto [0, 1]; NaN stays NaN."""
    if np.isnan(rho):
        return np.nan
    return (rho + 1.0) / 2.0

def _shap_to_posclass_2d(sv):
    if isinstance(sv, list):
        return sv[1] if len(sv) > 1 else sv[0]
    sv = np.asarray(sv)
    if sv.ndim == 3:
        return sv[..., 1] if sv.shape[-1] >= 2 else sv.mean(axis=-1)
    return sv

def make_eps_vec_from_train(Xtr: pd.DataFrame, base_eps: float = 1e-3) -> np.ndarray:
    """
    Per-feature perturbation sizes: `base_eps` times each column's standard deviation.

    Columns whose standard deviation is (near) zero or not finite — e.g. a
    single-row frame, where the sample std is NaN — use a scale of 1.0, so the
    returned step is always a finite, non-zero multiple of `base_eps`.

    Args:
        Xtr: training features (numeric columns).
        base_eps: multiplier applied to each column's scale.

    Returns:
        1-D float array with one step size per column of `Xtr`.
    """
    std = Xtr.std(axis=0).values.astype(float)
    # A NaN std would otherwise slip past the `< 1e-12` test and poison every
    # perturbed input built from this vector.
    unusable = ~np.isfinite(std) | (std < 1e-12)
    std = np.where(unusable, 1.0, std)
    return base_eps * std

# -----------------------------
# ✅ NEW: unwrap calibrated model for SHAP
# -----------------------------
def unwrap_for_shap(model):
    """
    Return the estimator that SHAP should explain.

    For a CalibratedClassifierCV this is the estimator held by its first
    calibrated classifier (attribute `estimator`, or `base_estimator` if the
    former is absent). Any other model — or a calibrated one that cannot be
    unwrapped — is returned as is.
    """
    if not isinstance(model, CalibratedClassifierCV):
        return model

    try:
        first_fold = model.calibrated_classifiers_[0]
        for attr in ("estimator", "base_estimator"):
            if hasattr(first_fold, attr):
                return getattr(first_fold, attr)
    except Exception:
        return model
    return model

# -----------------------------
# ✅ UPDATED: global_importance uses unwrapped estimator for TreeSHAP
# -----------------------------
def global_importance(model, X_train: pd.DataFrame, X_eval: pd.DataFrame, y_eval: Optional[pd.Series] = None):
    """
    Compute feature importance on `X_eval`, trying three strategies in order.

    1. AdaBoost (after unwrapping): the model's own `feature_importances_`.
    2. SHAP TreeExplainer, when shap is installed and the unwrapped model's
       class name looks like a supported tree ensemble.
    3. Permutation importance on the model as passed in (not the unwrapped one).

    Returns:
        (imp_s, sv_abs): a Series of per-feature importance indexed by
        `X_eval.columns`, and a 2-D array with one row per sample of `X_eval`.
        Only the SHAP path yields genuinely per-sample values; the other two
        paths repeat the global vector for every row.
    """
    feature_names = X_eval.columns

    model_shap = unwrap_for_shap(model)

    # AdaBoost is not in the TreeSHAP whitelist below, so use its built-in importances.
    if isinstance(model_shap, AdaBoostClassifier) and hasattr(model_shap, "feature_importances_"):
        imp = np.abs(np.asarray(model_shap.feature_importances_, dtype=float))
        imp_s = pd.Series(imp, index=feature_names, name="mean|FI|")
        sv_abs = np.tile(imp, (len(X_eval), 1))
        return imp_s, sv_abs

    def _is_supported_tree(m):
        # Match on class name so the optional libraries need not be imported.
        name = m.__class__.__name__.lower()
        return any(k in name for k in ["randomforest", "extratrees", "gradientboost", "xgb", "lgbm", "catboost"])

    if _HAS_SHAP:
        try:
            # Background sample of at most 256 training rows.
            bg = shap.sample(X_train, min(256, len(X_train)))
            if _is_supported_tree(model_shap):
                explainer = shap.TreeExplainer(
                    model_shap, data=bg, feature_perturbation="interventional", model_output="probability"
                )
                sv = explainer.shap_values(X_eval, check_additivity=False)
                sv = _shap_to_posclass_2d(sv)
                sv_abs = np.abs(np.asarray(sv))
                return pd.Series(sv_abs.mean(axis=0), index=feature_names, name="mean|SHAP|"), sv_abs
            # Unsupported model types fall through to permutation importance.
        except Exception as e:
            # Report the SHAP failure only once per process, then fall through.
            global PRINTED_SHAP_FALLBACK
            if not PRINTED_SHAP_FALLBACK:
                print(f"[SHAP] Falling back to permutation importance: {e}")
                PRINTED_SHAP_FALLBACK = True

    # Fallback permutation importance
    if y_eval is None or len(np.unique(y_eval)) != 2:
        # Without two-class labels, score against the model's own 0.5-thresholded
        # predictions and use the Brier score.
        y_for_perm = (model.predict_proba(X_eval)[:, 1] >= 0.5).astype(int)
        scoring = "neg_brier_score"
    else:
        y_for_perm = y_eval
        scoring = "roc_auc"

    pi = permutation_importance(
        model, X_eval, y_for_perm,
        scoring=scoring, n_repeats=8, random_state=42, n_jobs=-1
    )
    imp = np.abs(pi.importances_mean)
    imp_s = pd.Series(imp, index=feature_names, name="mean|PermImp|")
    sv_abs = np.tile(imp, (len(X_eval), 1))
    return imp_s, sv_abs

def safe_smote(Xtr: pd.DataFrame, ytr: pd.Series, random_state: int):
    """
    Oversample the training set with SMOTE, returning the inputs unchanged on failure.

    `k_neighbors` is capped by the size of the smallest class (at most 5, at
    least 1), whichever label that class happens to carry.

    Returns:
        (X_resampled, y_resampled), or (Xtr, ytr) untouched if SMOTE could not run.
    """
    try:
        # Use the true minority size rather than assuming label 1 is the rare
        # class; otherwise k can exceed the minority size and SMOTE is skipped.
        _, class_counts = np.unique(np.asarray(ytr), return_counts=True)
        minority = int(class_counts.min())
        k = min(5, max(1, minority - 1))
        return SMOTE(random_state=random_state, k_neighbors=k).fit_resample(Xtr, ytr)
    except Exception as e:
        # Deliberate best-effort: training proceeds on the original data.
        print(f"[SMOTE] skipped: {e}")
        return Xtr, ytr

# -----------------------------
# Improved evaluation: calibrate + threshold tune for ALL models
# -----------------------------
def evaluate_once_train_test(model, Xtr_fit, ytr_fit, Xte: pd.DataFrame, yte: pd.Series,
                             model_name: str, seed: int,
                             calibrate_all: bool = True,
                             calib_method: str = "sigmoid",
                             calib_cv: int = 3):
    """
    Fit on the training set, optionally calibrate, tune a threshold, score on the test set.

    Args:
        model: unfitted estimator; it is fitted in place.
        Xtr_fit, ytr_fit: training features / labels (already resampled by the caller).
        Xte, yte: test features / labels; used for evaluation only.
        model_name: MLP aliases ("mlp", "nn", "pytorch_mlp") turn on standardisation.
        seed: random state of the threshold-tuning split.
        calibrate_all: wrap the model in CalibratedClassifierCV when True.
        calib_method, calib_cv: passed to CalibratedClassifierCV.

    Returns:
        (proba_te, metrics, (scaler, Xtr_use, Xte_use, fitted_model)) where
        `proba_te` holds positive-class probabilities on the test set, `metrics`
        is a dict with AUC/F1/Precision/Recall/Brier/thr, `scaler` is None
        unless an MLP was requested, and `fitted_model` is the calibrated model
        (or the plain one if calibration was off or failed).
    """
    name = model_name.strip().lower()

    scaler = None
    Xtr_use, Xte_use = Xtr_fit, Xte
    sample_weight = None

    # MLP only: standardise with statistics learned on the training set.
    if name in ("mlp","nn","pytorch_mlp"):
        scaler = StandardScaler().fit(Xtr_fit.values.astype(float))
        Xtr_use = pd.DataFrame(scaler.transform(Xtr_fit.values.astype(float)), columns=Xtr_fit.columns)
        Xte_use = pd.DataFrame(scaler.transform(Xte.values.astype(float)),      columns=Xte.columns)
        sample_weight = _class_balanced_weights(np.asarray(ytr_fit).astype(int))

    _fit_optional_weight(model, Xtr_use, ytr_fit, sample_weight=sample_weight)

    fitted_model = model
    if calibrate_all:
        try:
            # NOTE(review): with an integer cv the calibrator presumably refits
            # its own copies of `model`, which would make the fit above matter
            # only as the fallback below — confirm against the sklearn docs.
            fitted_model = CalibratedClassifierCV(model, method=calib_method, cv=calib_cv)
            fitted_model.fit(Xtr_use, ytr_fit)
        except Exception as e:
            print(f"[Calib] skipped ({calib_method}): {e}")
            fitted_model = model

    thresh = 0.5
    try:
        # The 20% tuning split is drawn from the same rows the model was fitted
        # on, so the threshold is tuned on data the model has already seen.
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
        (i_tr, i_val), = sss.split(Xtr_use, ytr_fit)
        # Uses .iloc, so both objects must be pandas; otherwise the except
        # branch keeps the default threshold.
        Xval, yval = Xtr_use.iloc[i_val], ytr_fit.iloc[i_val]
        p_val = fitted_model.predict_proba(Xval)[:, 1]
        thresh = _tune_threshold(np.asarray(yval).astype(int), np.asarray(p_val).astype(float))
    except Exception as e:
        print(f"[Thr] tuning skipped: {e}")
        thresh = 0.5

    proba_te = fitted_model.predict_proba(Xte_use)[:, 1]
    pred = (proba_te >= thresh).astype(int)

    metrics = {
        # AUC is undefined when the test set holds a single class.
        "AUC": roc_auc_score(yte, proba_te) if len(np.unique(yte)) == 2 else np.nan,
        "F1":  f1_score(yte, pred, zero_division=0),
        "Precision": precision_score(yte, pred, zero_division=0),
        "Recall":    recall_score(yte, pred, zero_division=0),
        "Brier":     brier_score_loss(yte, proba_te),
        "thr":       float(thresh),
    }
    return proba_te, metrics, (scaler, Xtr_use, Xte_use, fitted_model)

def glr_rhos_for_test(shap_abs_test: np.ndarray, grad_norm: np.ndarray, feature_names):
    """Per-sample Spearman correlation between attribution and gradient rows.

    Returns a list with one float per row of `shap_abs_test`.
    """
    def _row_rho(row):
        attributions = pd.Series(shap_abs_test[row, :], index=feature_names)
        gradients = pd.Series(grad_norm[row, :], index=feature_names)
        return float(spearman_rank_corr(attributions, gradients))

    return [_row_rho(row) for row in range(shap_abs_test.shape[0])]

def epsilon_hitk_for_subset(model, Xte_np, yte_np, shap_abs_test, top_k, eps, m_limit, chosen_idx):
    """Hit@k check for the samples in `chosen_idx`.

    For each sample, the `m_limit` features with the largest attribution are
    each nudged by `eps` (one feature per perturbed copy). The sample counts
    as a hit (1) when the top-k features by attribution and the top-k by loss
    increase share at least one feature, else 0.

    Returns a list of 0/1 flags in the order of `chosen_idx`.
    """
    n_candidates = min(m_limit, Xte_np.shape[1])
    copy_rows = np.arange(n_candidates)
    reference_loss = bce_loss(model.predict_proba(Xte_np)[:, 1], yte_np)

    flags = []
    for row in chosen_idx:
        # Candidate features, strongest attribution first.
        ranked = np.argsort(-shap_abs_test[row, :])[:n_candidates]

        # Copy number m of the sample has candidate feature m shifted by eps.
        perturbed = np.tile(Xte_np[row, :], (n_candidates, 1))
        perturbed[copy_rows, ranked] += eps

        labels = np.full(n_candidates, yte_np[row])
        loss_change = bce_loss(model.predict_proba(perturbed)[:, 1], labels) - reference_loss[row]

        by_attribution = set(ranked[:top_k])
        by_loss_change = set(ranked[np.argsort(-loss_change)[:top_k]])
        flags.append(int(len(by_attribution & by_loss_change) > 0))
    return flags

# -----------------------------
# Z-CROSS alignment with PC1 mapping
# -----------------------------
def apply_pc1_to_nasa_mapping(X: pd.DataFrame) -> pd.DataFrame:
    """Rename PC1-style columns of `X` to their NASA-style names.

    Matching is case-insensitive. A renamed copy is returned when at least one
    column matches; otherwise `X` itself is returned.
    """
    renames = {}
    for col in X.columns:
        target = _PC1_TO_NASA_CASEFOLD.get(_norm_raw_key(col))
        if target is not None:
            renames[col] = target

    if not renames:
        return X
    return X.rename(columns=renames).copy()

def _norm_colname(c: str) -> str:
    return re.sub(r"[^a-z0-9]+", "_", str(c).strip().lower())

def align_train_test(Xtr: pd.DataFrame, Xte: pd.DataFrame, train_name: str, test_name: str, verbose: bool = True):
    """Restrict train and test frames to their shared features, in the same order.

    A frame whose dataset name contains "pc1" (case-insensitive) is first
    renamed with `apply_pc1_to_nasa_mapping`. Columns are then matched on their
    `_norm_colname` form, and both outputs use the sorted normalised names as
    column labels.

    Returns:
        (Xtr_aligned, Xte_aligned, common) with `common` the sorted list of
        shared normalised names.

    Raises:
        ValueError: if the two frames share no feature.
    """
    def _mapped_if_pc1(frame, dataset_name):
        if "pc1" in dataset_name.lower():
            return apply_pc1_to_nasa_mapping(frame)
        return frame

    Xtr = _mapped_if_pc1(Xtr, train_name)
    Xte = _mapped_if_pc1(Xte, test_name)

    # normalised name -> original column label
    train_lookup = {_norm_colname(col): col for col in Xtr.columns}
    test_lookup = {_norm_colname(col): col for col in Xte.columns}
    common = sorted(train_lookup.keys() & test_lookup.keys())

    if not common:
        raise ValueError("No common features between train and test after mapping + normalization.")

    Xtr_aligned = Xtr[[train_lookup[key] for key in common]].copy()
    Xtr_aligned.columns = common
    Xte_aligned = Xte[[test_lookup[key] for key in common]].copy()
    Xte_aligned.columns = common

    if verbose:
        print(f"[Align] common_features={len(common)} | drop_train={Xtr.shape[1]-len(common)} | drop_test={Xte.shape[1]-len(common)}")
    return Xtr_aligned, Xte_aligned, common

# -----------------------------
# Z-CROSS core runner
# -----------------------------
def train_on_A_test_on_B(
    train_path: str,
    test_path: str,
    model_name: str = "random_forest",
    rng: int = 42,
    top_k: int = 10,
    eps: float = 1e-3,
    calib_bins: int = 10,
    m_limit: int = 20,
    max_eps_samples: int = 200,
    verbose: bool = True
):
    """
    Train on one dataset and evaluate on another (one cross-project run).

    Pipeline: load both files, align their features, SMOTE the training set,
    fit / calibrate / threshold-tune on train, then compute on the test set the
    classification metrics, the attribution-vs-gradient rank agreement (GLR),
    Hit@k, ECE and the combined ReliabilityScore.

    Args:
        train_path, test_path: dataset files (.csv or .arff).
        model_name: name understood by `make_model`.
        rng: seed for numpy / random, SMOTE, the model and the tuning split.
        top_k: k used for Hit@k.
        eps: base perturbation size (scaled per feature for the gradient
            estimate, used as a plain scalar for Hit@k).
        calib_bins: number of bins for ECE.
        m_limit: number of top-attributed features perturbed per sample in Hit@k.
        max_eps_samples: cap on test samples used for Hit@k; <= 0 disables it.
        verbose: print the alignment summary.

    Returns:
        (summary, artifacts): a one-row DataFrame of results (rounded to six
        decimals) and a dict with the raw arrays behind it.
    """
    np.random.seed(rng); random.seed(rng)

    Xtr, ytr, tgt_tr = load_and_preprocess(train_path)
    Xte, yte, tgt_te = load_and_preprocess(test_path)

    # The file names decide whether the PC1 -> NASA renaming is applied.
    Xtr, Xte, common_cols = align_train_test(
        Xtr, Xte,
        train_name=os.path.basename(train_path),
        test_name=os.path.basename(test_path),
        verbose=verbose
    )

    feature_names = Xtr.columns.tolist()
    # Step sizes come from the aligned training data before SMOTE and before
    # any scaling.
    # NOTE(review): for MLP models the data passed to the model is standardised
    # while these steps are in raw units — confirm this is intended.
    eps_vec = make_eps_vec_from_train(Xtr, base_eps=eps)

    # TRAIN-only SMOTE
    Xtr_fit, ytr_fit = safe_smote(Xtr, ytr, random_state=rng)

    base_model = make_model(model_name, random_state=rng)

    proba_te, metrics, (_scaler, Xtr_used, Xte_used, fitted_model) = evaluate_once_train_test(
        base_model, Xtr_fit, ytr_fit, Xte, yte,
        model_name=model_name, seed=rng,
        calibrate_all=True,
        calib_method="sigmoid",
        calib_cv=3
    )

    # Calibrated model for probabilities; unwrapped estimator for attributions.
    model_for_probs = fitted_model
    model_for_shap  = unwrap_for_shap(fitted_model)

    imp_test, shap_abs_test = global_importance(model_for_shap, Xtr_used, Xte_used, yte)

    Xte_np = Xte_used.values.astype(float)
    yte_np = yte.values.astype(int)

    # Loss-gradient magnitudes per sample, normalised so each row sums to ~1.
    grad_abs  = finite_diff_grad_loss(model_for_probs, Xte_np, yte_np, eps_vec)
    grad_norm = normalize_rows(grad_abs)

    # GLR: per-sample rank correlation between attributions and gradients.
    glr_rhos = np.array(glr_rhos_for_test(shap_abs_test, grad_norm, feature_names), dtype=float)
    GLR_mean = float(np.nanmean(glr_rhos)) if len(glr_rhos) else np.nan

    hitk_flags = []
    if max_eps_samples > 0:
        # Random subset of test rows; reproducible through np.random.seed(rng) above.
        idxs = np.arange(len(Xte_np))
        np.random.shuffle(idxs)
        chosen = idxs[:min(max_eps_samples, len(idxs))]
        hitk_flags = epsilon_hitk_for_subset(model_for_probs, Xte_np, yte_np, shap_abs_test,
                                             top_k, eps, m_limit, chosen)

    Hitk = float(np.mean(hitk_flags)) if len(hitk_flags) else np.nan
    ECE = float(expected_calibration_error(yte_np, proba_te, n_bins=calib_bins))
    # Mean of three [0, 1] components (NaN components are ignored):
    # rescaled GLR, Hit@k and 1 - ECE.
    ReliabilityScore = float(np.nanmean([rescale01_rho(GLR_mean), Hitk, max(0.0, 1.0 - ECE)]))

    summary = pd.DataFrame([{
        "Train": os.path.basename(train_path),
        "Test": os.path.basename(test_path),
        "Model": model_name,
        "Train_Target": tgt_tr,
        "Test_Target": tgt_te,
        "n_common_features": len(common_cols),

        "AUC": float(metrics["AUC"]),
        "F1": float(metrics["F1"]),
        "Precision": float(metrics["Precision"]),
        "Recall": float(metrics["Recall"]),
        "Brier": float(metrics["Brier"]),
        "thr": float(metrics["thr"]),

        "GLR_mean": GLR_mean,
        f"Hit@{top_k}": Hitk,
        "ECE": ECE,
        "ReliabilityScore": ReliabilityScore,
        # Computed from the cap, not from len(hitk_flags).
        "eps_samples_used": int(min(max_eps_samples, len(Xte_np))),
    }]).round(6)

    artifacts = {
        "summary": summary,
        "metrics": metrics,
        "proba_test": np.asarray(proba_te, dtype=float),
        "y_test": yte_np,
        "glr_rhos": glr_rhos,
        "hitk_flags": np.asarray(hitk_flags, dtype=int) if len(hitk_flags) else np.array([], dtype=int),
        "importance_test": imp_test.sort_values(ascending=False),
        "common_features": common_cols,
    }

    return summary, artifacts

def run_zcross_selected_pairs(
    base_path: str,
    selected_pairs: List[Tuple[str, str]],
    model_name: str = "random_forest",
    rng: int = 42,
    top_k: int = 10,
    eps: float = 1e-3,
    calib_bins: int = 10,
    m_limit: int = 20,
    max_eps_samples: int = 200,
    save_csv: bool = True,
    out_csv: str = None,
):
    """Run `train_on_A_test_on_B` for every (train_file, test_file) pair.

    File names are resolved relative to `base_path`. When `save_csv` is True
    the combined summary is written to `out_csv`, or to
    `<base_path>/zcross_<model_name>_summary.csv` if `out_csv` is None.

    Returns:
        (summary_df, artifacts): one summary row per pair, and a dict mapping
        (train_file, test_file) to that run's artifacts.
    """
    banner = "=" * 95
    summary_rows = []
    artifacts = {}

    for tr_file, te_file in selected_pairs:
        print("\n" + banner)
        print(f"[ZCROSS] {tr_file} → {te_file} | model={model_name} | seed={rng}")
        print(banner)

        summary, art = train_on_A_test_on_B(
            os.path.join(base_path, tr_file),
            os.path.join(base_path, te_file),
            model_name=model_name,
            rng=rng,
            top_k=top_k,
            eps=eps,
            calib_bins=calib_bins,
            m_limit=m_limit,
            max_eps_samples=max_eps_samples,
            verbose=True,
        )

        artifacts[(tr_file, te_file)] = art
        summary_rows.append(summary.iloc[0].to_dict())

    summary_df = pd.DataFrame(summary_rows)

    if not save_csv:
        return summary_df, artifacts

    if out_csv is None:
        out_csv = os.path.join(base_path, f"zcross_{model_name}_summary.csv")
    summary_df.to_csv(out_csv, index=False)
    print(f"\n[Saved] {out_csv}")
    return summary_df, artifacts

import contextlib, io

def run_zcross_with_repeats(
    base_path,
    selected_pairs,
    model_name="random_forest",
    repeats=5,
    seed0=42,

    # behavior controls
    verbose_epochs=False,     # print a header line per repeat
    verbose_pairs=False,      # let the per-pair output of each repeat through
    final_only=True,          # return only the aggregated frame (no raw repeats)
    collect_importance=True,  # gather per-repeat importances for later plotting

    # saving
    save_csv=True,
    out_csv=None,

    **kwargs
):
    """
    Repeat the selected cross-project runs with different seeds and average them.

    Repeat `r` (0-based) uses seed `seed0 + 101 * r`. Unless `verbose_epochs`
    or `verbose_pairs` is set, everything printed by the inner runs is discarded.

    Args:
        base_path: directory holding the dataset files.
        selected_pairs: iterable of (train_file, test_file) pairs.
        model_name: name understood by `make_model`.
        repeats: number of repetitions; must be >= 1.
        seed0: seed of the first repeat.
        save_csv / out_csv: write the aggregated frame; the default path is
            `<base_path>/zcross_<model_name>_final_repeats<repeats>.csv`.
        **kwargs: forwarded to `run_zcross_selected_pairs`.

    Returns:
        final_only=True:  (final_df, all_pair_importances)
        final_only=False: (final_df, raw_df, all_pair_importances)
        where `final_df` has one row per Train/Test/Model/n_common_features
        group with the mean of each metric over repeats, `raw_df` has one row
        per pair and repeat, and `all_pair_importances` maps each pair to the
        list of importance Series collected (empty lists when
        `collect_importance` is False).

    Raises:
        ValueError: if `repeats` is smaller than 1.
    """
    if repeats < 1:
        # Fail early with a clear message instead of an opaque concat error.
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    # The pairs are re-iterated on every repeat and used as dict keys, so
    # materialise them once as hashable tuples.
    selected_pairs = [tuple(pair) for pair in selected_pairs]

    silence_inner = not verbose_epochs and not verbose_pairs
    # Artifact keys that may hold an importance Series, in order of preference.
    importance_keys = ("importance_test", "mean_importance_test", "top_importance_mean_test")

    all_runs = []
    all_pair_importances = {pair: [] for pair in selected_pairs}

    for r in range(repeats):
        seed = seed0 + 101 * r

        if verbose_epochs:
            print(f"\n######## REPEAT {r+1}/{repeats} | seed={seed} ########")

        # Swallow the inner prints when silenced; otherwise a no-op context.
        if silence_inner:
            output_ctx = contextlib.redirect_stdout(io.StringIO())
        else:
            output_ctx = contextlib.nullcontext()

        with output_ctx:
            summary_df, artifacts = run_zcross_selected_pairs(
                base_path=base_path,
                selected_pairs=selected_pairs,
                model_name=model_name,
                rng=seed,
                save_csv=False,
                **kwargs
            )

        summary_df["repeat"] = r + 1
        all_runs.append(summary_df)

        if not collect_importance:
            continue

        # collect importances for mean-topk plots across repeats
        for pair in selected_pairs:
            art = artifacts.get(pair, None)
            if art is None:
                continue

            imp = None
            for key in importance_keys:
                candidate = art.get(key, None)
                if candidate is None or (hasattr(candidate, "empty") and candidate.empty):
                    continue
                imp = candidate
                break
            if imp is None:
                continue

            if not isinstance(imp, pd.Series):
                imp = pd.Series(imp)

            all_pair_importances[pair].append(imp)

    raw_df = pd.concat(all_runs, ignore_index=True)

    # ---------- FINAL ONE ROW PER PAIR ----------
    grp_cols = ["Train", "Test", "Model", "n_common_features"]

    final_df = raw_df.groupby(grp_cols).agg(
        AUC=("AUC", "mean"),
        F1=("F1", "mean"),
        Precision=("Precision", "mean"),
        Recall=("Recall", "mean"),
        Brier=("Brier", "mean"),
        GLR=("GLR_mean", "mean"),
        ECE=("ECE", "mean"),
        ReliabilityScore=("ReliabilityScore", "mean"),
    ).reset_index()
    final_df = final_df.round(6)

    # save only the final per-pair summary
    if save_csv:
        if out_csv is None:
            out_csv = os.path.join(base_path, f"zcross_{model_name}_final_repeats{repeats}.csv")
        final_df.to_csv(out_csv, index=False)
        print(f"[Saved FINAL summary] {out_csv}")

    if final_only:
        return final_df, all_pair_importances
    return final_df, raw_df, all_pair_importances


def plot_topk(imp: pd.Series, top_k: int = 15, title: str = "", save_path: str = None):
    """
    Simple Top-K horizontal bar plot for a pandas Series (importance).

    Args:
        imp: importance values indexed by feature name (anything accepted by
            `pd.Series` works too). NaN entries are ignored.
        top_k: maximum number of features shown; the largest is drawn on top.
        title: plot title; defaults to "Top <k> Features".
        save_path: when given, the figure is written there (parent directory
            created if needed) and closed; otherwise it is shown.

    Returns:
        None. Nothing is drawn when there is no usable importance value.
    """
    if imp is None or len(imp) == 0:
        print("[Plot] Empty importance. Skipping.")
        return

    if not isinstance(imp, pd.Series):
        imp = pd.Series(imp)

    s = imp.dropna().sort_values(ascending=False)
    if s.empty:
        # Every value was NaN: avoid producing an empty figure.
        print("[Plot] Empty importance. Skipping.")
        return

    k = min(top_k, len(s))
    # Reverse so that the most important feature ends up at the top of the barh.
    top = s.head(k)[::-1]

    plt.figure(figsize=(10, max(4, 0.5 * k)))
    plt.barh(top.index, top.values)
    plt.xlabel("mean(|importance|)")
    plt.ylabel("Feature")
    plt.title(title if title else f"Top {k} Features")
    plt.tight_layout()

    if save_path:
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=200, bbox_inches="tight")
        plt.close()
        print(f"[Saved] {save_path}")
    else:
        plt.show()


def plot_all_zcross_pairs_mean(
    zcross_artifacts: dict,
    pairs: list,
    top_k: int = 15,
    model_name: str = "",
    only_nonempty: bool = True,
    save_dir: str = None
):
    """
    Plots Top-K MEAN importance for every (train,test) pair in pairs.
    ✅ Does NOT rerun repeats.
    ✅ Uses artifacts already computed.

    Expected artifact keys per pair (tries in this order):
      1) "mean_importance_test"  (best: across repeats/folds)
      2) "importance_test"       (single-run importance)
      3) "top_importance_mean_test" (older name)
    """

    for (tr, te) in pairs:
        key = (tr, te)
        if key not in zcross_artifacts:
            print(f"[SKIP] No artifacts found for {tr} → {te}")
            continue

        art = zcross_artifacts[key]

        # try multiple possible keys
        imp = art.get("mean_importance_test", None)
        if imp is None or (hasattr(imp, "empty") and imp.empty):
            imp = art.get("importance_test", None)
        if imp is None or (hasattr(imp, "empty") and imp.empty):
            imp = art.get("top_importance_mean_test", None)

        if imp is None or (hasattr(imp, "empty") and imp.empty) or (isinstance(imp, pd.Series) and len(imp) == 0):
            if only_nonempty:
                print(f"[SKIP] Empty importance for {tr} → {te}")
                continue
            else:
                print(f"[WARN] Empty importance for {tr} → {te}")
                continue

        title = f"{tr} → {te}"
        if model_name:
            title += f" ({model_name})"
        title += " — mean importance on TEST"

        save_path = None
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            safe_name = f"{tr.replace('.','_')}__TO__{te.replace('.','_')}"
            if model_name:
                safe_name += f"__{model_name}"
            save_path = os.path.join(save_dir, safe_name + ".png")

        plot_topk(imp, top_k=top_k, title=title, save_path=save_path)

In [2]:
# ---- Z-CROSS batch: models trained on PC1 / PC2 ----
BASE_PATH = "/content/"
MODEL_NAME = "lightgbm"
REPEATS = 5

SELECTED_PAIRS = [
    ("pc1.arff","pc2.arff"), ("pc1.arff","pc3.arff"), ("pc1.arff","pc4.arff"),
    ("pc2.arff","pc1.arff"), ("pc2.arff","pc3.arff"), ("pc2.arff","pc4.arff"),
    # ("pc3.arff","pc1.arff"), ("pc3.arff","pc2.arff"), ("pc3.arff","pc4.arff"),
    # ("pc4.arff","pc1.arff"), ("pc4.arff","pc2.arff"), ("pc4.arff","pc3.arff"),
]

# The default summary file name depends only on model name and repeat count,
# so any other batch run with the same settings would overwrite this one.
# Give this batch its own explicit output file.
OUT_CSV = os.path.join(
    BASE_PATH, f"zcross_{MODEL_NAME}_final_repeats{REPEATS}_train_pc1_pc2.csv"
)

# ---- Recommended: repeats for stability ----
# With final_only=True the call returns (per-pair summary, per-pair importances).
final_df, all_pair_importances = run_zcross_with_repeats(
    base_path=BASE_PATH,
    selected_pairs=SELECTED_PAIRS,
    model_name=MODEL_NAME,
    repeats=REPEATS,
    seed0=42,
    verbose_epochs=False,
    verbose_pairs=False,
    final_only=True,
    collect_importance=True,
    top_k=10,
    eps=1e-3,
    calib_bins=10,
    m_limit=20,
    max_eps_samples=200,
    save_csv=True,
    out_csv=OUT_CSV,
)

print(final_df)

# Average the collected importance Series of each pair (features aligned on
# the index by pd.concat); pairs with no collected importances are left out.
zcross_artifacts = {
    pair: {"mean_importance_test": pd.concat(imp_list, axis=1).mean(axis=1)}
    for pair, imp_list in all_pair_importances.items()
    if imp_list
}

plot_all_zcross_pairs_mean(
    zcross_artifacts=zcross_artifacts,
    pairs=SELECTED_PAIRS,
    top_k=15,
    model_name=MODEL_NAME,
    only_nonempty=True,
    save_dir=os.path.join(BASE_PATH, "zcross_plots_mean")
)



[Saved FINAL summary] /content/zcross_lightgbm_final_repeats5.csv
      Train      Test     Model  n_common_features       AUC        F1  \
0  pc1.arff  pc2.arff  lightgbm                 10  0.874505  0.122346   
1  pc1.arff  pc3.arff  lightgbm                 10  0.763076  0.297735   
2  pc1.arff  pc4.arff  lightgbm                 10  0.774102  0.245958   
3  pc2.arff  pc1.arff  lightgbm                 10  0.695553  0.192975   
4  pc2.arff  pc3.arff  lightgbm                 36  0.694107  0.040807   
5  pc2.arff  pc4.arff  lightgbm                 36  0.530944  0.002210   

   Precision    Recall     Brier       GLR       ECE  ReliabilityScore  
0   0.087661  0.208696  0.009916  0.220888  0.036727          0.857906  
1   0.361270  0.253750  0.100885  0.250483  0.077477          0.849255  
2   0.430620  0.173033  0.110542  0.183926  0.080381          0.837194  
3   0.256948  0.155844  0.078149  0.103722  0.074985          0.825626  
4   0.219118  0.022500  0.103828  0.378647  0.1020

In [3]:
# ---- Z-CROSS batch: models trained on PC3 / PC4 ----
BASE_PATH = "/content/"
MODEL_NAME = "lightgbm"
REPEATS = 5

SELECTED_PAIRS = [
    # ("pc1.arff","pc2.arff"), ("pc1.arff","pc3.arff"), ("pc1.arff","pc4.arff"),
    # ("pc2.arff","pc1.arff"), ("pc2.arff","pc3.arff"), ("pc2.arff","pc4.arff"),
    ("pc3.arff","pc1.arff"), ("pc3.arff","pc2.arff"), ("pc3.arff","pc4.arff"),
    ("pc4.arff","pc1.arff"), ("pc4.arff","pc2.arff"), ("pc4.arff","pc3.arff"),
]

# The default summary file name depends only on model name and repeat count,
# so this batch would overwrite the summary of any earlier batch run with the
# same settings. Give this batch its own explicit output file.
OUT_CSV = os.path.join(
    BASE_PATH, f"zcross_{MODEL_NAME}_final_repeats{REPEATS}_train_pc3_pc4.csv"
)

# ---- Recommended: repeats for stability ----
# With final_only=True the call returns (per-pair summary, per-pair importances).
final_df, all_pair_importances = run_zcross_with_repeats(
    base_path=BASE_PATH,
    selected_pairs=SELECTED_PAIRS,
    model_name=MODEL_NAME,
    repeats=REPEATS,
    seed0=42,
    verbose_epochs=False,
    verbose_pairs=False,
    final_only=True,
    collect_importance=True,
    top_k=10,
    eps=1e-3,
    calib_bins=10,
    m_limit=20,
    max_eps_samples=200,
    save_csv=True,
    out_csv=OUT_CSV,
)

print(final_df)

# Average the collected importance Series of each pair (features aligned on
# the index by pd.concat); pairs with no collected importances are left out.
zcross_artifacts = {
    pair: {"mean_importance_test": pd.concat(imp_list, axis=1).mean(axis=1)}
    for pair, imp_list in all_pair_importances.items()
    if imp_list
}

plot_all_zcross_pairs_mean(
    zcross_artifacts=zcross_artifacts,
    pairs=SELECTED_PAIRS,
    top_k=15,
    model_name=MODEL_NAME,
    only_nonempty=True,
    save_dir=os.path.join(BASE_PATH, "zcross_plots_mean")
)




[Saved FINAL summary] /content/zcross_lightgbm_final_repeats5.csv
      Train      Test     Model  n_common_features       AUC        F1  \
0  pc3.arff  pc1.arff  lightgbm                 10  0.799967  0.266768   
1  pc3.arff  pc2.arff  lightgbm                 36  0.777147  0.071655   
2  pc3.arff  pc4.arff  lightgbm                 37  0.729632  0.124047   
3  pc4.arff  pc1.arff  lightgbm                 10  0.744561  0.197015   
4  pc4.arff  pc2.arff  lightgbm                 36  0.827323  0.071648   
5  pc4.arff  pc3.arff  lightgbm                 37  0.729776  0.143899   

   Precision    Recall     Brier       GLR       ECE  ReliabilityScore  
0   0.307399  0.236363  0.068123  0.320732  0.037077          0.874430  
1   0.051613  0.130435  0.011407  0.124310  0.063462          0.832898  
2   0.235644  0.084270  0.119672  0.293922  0.075919          0.857014  
3   0.205284  0.189610  0.078594  0.445295  0.043711          0.892979  
4   0.038483  0.521739  0.033657  0.080766  0.0789