STACKING ENSEMBLE

In [None]:
def _build_meta_classifier(meta_model, random_state):
    """Return an unfitted meta-classifier for ``meta_model`` ("logistic" or "xgboost").

    Raises:
        ValueError: if ``meta_model`` is not one of the supported names.
    """
    if meta_model == "logistic":
        from sklearn.linear_model import LogisticRegression

        return LogisticRegression(max_iter=1000, random_state=random_state)
    if meta_model == "xgboost":
        # Imported lazily so the logistic path does not require xgboost.
        # `use_label_encoder` is deprecated/ignored by current xgboost, so it is omitted.
        from xgboost import XGBClassifier

        return XGBClassifier(eval_metric="logloss", random_state=random_state)
    raise ValueError("Unsupported meta_model")


def _plot_roc_curve(fpr, tpr, auc):
    """Draw the test-set ROC curve with a chance diagonal and show it."""
    import matplotlib.pyplot as plt

    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, label=f"AUC = {auc:.2f}")
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve - Stacked Ensemble")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def _meta_feature_importance(meta_clf, meta_model, model_names):
    """Return a DataFrame of per-base-model importance, sorted descending, rounded to 2 dp.

    Logistic meta-models use |coef|; xgboost uses ``feature_importances_``.
    Importance is diagnostic only, so extraction failures yield an empty frame
    (with a printed notice) instead of aborting the pipeline.
    """
    import numpy as np
    import pandas as pd

    try:
        if meta_model == "logistic":
            importances = np.abs(meta_clf.coef_[0])
        else:
            importances = meta_clf.feature_importances_
        # Raises ValueError if the number of stacked columns != number of models
        # (e.g. a base model supplied a 2-column probability matrix).
        feature_importance = pd.DataFrame({"Model": model_names, "Importance": importances})
    except (AttributeError, IndexError, ValueError) as err:
        print(f"⚠️ Could not compute meta-model feature importance: {err}")
        return pd.DataFrame(columns=["Model", "Importance"])

    # Sort on the raw values first so ties after rounding keep their true order;
    # rounding is monotone, so the frame stays sorted.
    feature_importance = feature_importance.sort_values(by="Importance", ascending=False)
    feature_importance["Importance"] = feature_importance["Importance"].round(2)
    return feature_importance


def run_stacking_ensemble_pipeline(
    base_val_probs: dict,
    base_test_probs: dict,
    y_val, y_test,
    meta_model="logistic",  # or "xgboost"
    random_state=42,
    threshold_cv=None,
):
    """Fit a stacking meta-classifier on base-model probabilities and evaluate it.

    Each base model's validation (resp. test) probabilities become one column of
    the meta-feature matrix, in the key order of ``base_val_probs``. The
    meta-model is fit on the validation stack, an F1-maximizing decision
    threshold is chosen on the validation set, and the test set is scored with
    that threshold. Metrics are printed and a ROC curve is plotted.

    The code assumes a binary target whose positive class is column 1 of
    ``predict_proba`` (predictions are built as 0/1 ints and the report uses two
    class names).

    Args:
        base_val_probs: model name -> validation-set probabilities.
        base_test_probs: model name -> test-set probabilities; must contain
            every key of ``base_val_probs`` (extra keys are ignored).
        y_val, y_test: validation / test labels (array-like, Series or
            single-column DataFrame).
        meta_model: "logistic" or "xgboost".
        random_state: seed passed to the meta-model.
        threshold_cv: ``None`` (default) tunes the threshold on the meta-model's
            in-sample validation probabilities, which is optimistic because the
            meta-model was fit on those same rows. Pass an int / CV splitter to
            tune on out-of-fold probabilities from ``cross_val_predict`` instead.

    Returns:
        ``(meta_clf, results)`` where ``results`` holds val/test probabilities,
        test predictions, accuracy, AUC, best threshold, ROC points, the
        classification report as a DataFrame and the meta feature importance.

    Raises:
        KeyError: if ``base_test_probs`` lacks a model present in ``base_val_probs``.
        ValueError: if ``meta_model`` is unsupported.
    """
    import numpy as np
    import pandas as pd
    from sklearn.metrics import (
        accuracy_score, classification_report, f1_score,
        roc_auc_score, roc_curve
    )

    # === Step 1: Stack validation and test probs into feature matrices
    model_names = list(base_val_probs.keys())
    missing = [m for m in model_names if m not in base_test_probs]
    if missing:
        raise KeyError(f"base_test_probs is missing models: {missing}")
    X_val_stack = np.column_stack([base_val_probs[m] for m in model_names])
    X_test_stack = np.column_stack([base_test_probs[m] for m in model_names])
    y_val = y_val.values.ravel() if hasattr(y_val, "values") else y_val
    y_test = y_test.values.ravel() if hasattr(y_test, "values") else y_test

    # === Step 2: Train Meta Model
    meta_clf = _build_meta_classifier(meta_model, random_state)
    meta_clf.fit(X_val_stack, y_val)

    # === Step 3: Predict
    val_probs = meta_clf.predict_proba(X_val_stack)[:, 1]
    test_probs = meta_clf.predict_proba(X_test_stack)[:, 1]

    # === Step 4: Threshold search
    if threshold_cv is None:
        tuning_probs = val_probs
    else:
        from sklearn.model_selection import cross_val_predict

        # cross_val_predict clones meta_clf, so the fitted model above is untouched.
        tuning_probs = cross_val_predict(
            meta_clf, X_val_stack, y_val, cv=threshold_cv, method="predict_proba"
        )[:, 1]
    thresholds = np.linspace(0, 1, 101)
    # zero_division=0 matches sklearn's default value but silences the warning
    # raised at thresholds that predict no positives.
    f1s = [
        f1_score(y_val, (tuning_probs >= t).astype(int), zero_division=0)
        for t in thresholds
    ]
    best_threshold = thresholds[np.argmax(f1s)]  # first (lowest) threshold on ties
    print(f"\n✅ Best threshold (F1) for ensemble: {best_threshold:.2f}")

    test_preds = (test_probs >= best_threshold).astype(int)

    # === Step 5: Evaluate
    accuracy = accuracy_score(y_test, test_preds)
    auc = roc_auc_score(y_test, test_probs)
    report_str = classification_report(y_test, test_preds, target_names=["Class 0", "Class 1"])
    report_df = pd.DataFrame(classification_report(y_test, test_preds, output_dict=True)).T

    print(f"\n🔎 Test Accuracy (Stacked): {accuracy:.2f}")
    print("Classification Report (Stacked):")
    print(report_str)
    print(f"AUC-ROC (Stacked): {auc:.2f}")

    # === Step 6: ROC Curve
    fpr, tpr, _ = roc_curve(y_test, test_probs)
    _plot_roc_curve(fpr, tpr, auc)

    # === Step 7: Meta Feature Importance
    feature_importance = _meta_feature_importance(meta_clf, meta_model, model_names)

    print("\n📊 Meta-model Feature Importance:")
    print(feature_importance.head(10))

    # === Step 8: Return results
    results_ensemble = {
        "val_probs": val_probs,
        "test_probs": test_probs,
        "test_preds": test_preds,
        "accuracy": accuracy,
        "auc": auc,
        "best_threshold": best_threshold,
        "fpr": fpr,
        "tpr": tpr,
        "report_df": report_df,
        "feature_importance": feature_importance,
    }

    return meta_clf, results_ensemble
