TABNET

In [None]:
!pip install pytorch-tabnet

In [None]:
import torch
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.metrics import (
    accuracy_score, classification_report, roc_auc_score,
    roc_curve, f1_score
)
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [None]:
def _to_label_vector(labels):
    """Return *labels* as a flat 1-D NumPy array.

    Accepts pandas Series/DataFrames, NumPy arrays (including ``(n, 1)``
    column vectors) and plain sequences, so every label input is flattened
    the same way regardless of its container type.
    """
    if isinstance(labels, (pd.Series, pd.DataFrame)):
        labels = labels.values
    return np.asarray(labels).ravel()


def _best_f1_threshold(y_true, probs, thresholds):
    """Return the entry of *thresholds* that maximises F1 for ``probs >= t``.

    Ties are resolved in favour of the first (lowest-index) threshold, because
    ``np.argmax`` returns the first maximum.
    """
    # zero_division=0 keeps the score at 0.0 for degenerate cut-offs (e.g.
    # t = 1.0, where usually nothing is predicted positive) without emitting
    # an UndefinedMetricWarning for every such threshold in the sweep.
    scores = [
        f1_score(y_true, (probs >= t).astype(int), zero_division=0)
        for t in thresholds
    ]
    return thresholds[int(np.argmax(scores))]


def run_tabnet_pipeline(X_train, y_train, X_val, y_val, X_test, y_test, feature_names):
    """Train a binary TabNet classifier, tune its threshold and evaluate it.

    Steps performed:

    1. Fit a ``TabNetClassifier`` on the training split, passing the
       validation split as ``eval_set`` with ``eval_metric=["auc"]``,
       ``max_epochs=100`` and ``patience=10``.
    2. Take the second column of ``predict_proba`` as the positive-class
       score on the validation and test splits.
    3. Pick the decision threshold (101 evenly spaced values in ``[0, 1]``)
       with the best F1 on the validation split.
    4. Print accuracy, a classification report and ROC-AUC for the test
       split, show the ROC curve, and print the ten largest feature
       importances.

    Args:
        X_train, X_val, X_test: Feature matrices. DataFrames are unwrapped to
            their ``.values``; anything else is handed to TabNet unchanged.
        y_train, y_val, y_test: Labels, expected to be encoded as 0/1. Pandas
            objects, NumPy arrays and sequences are flattened to 1-D arrays.
        feature_names: Iterable of names, one per column of ``X_train``.

    Returns:
        A ``(model_tabnet, results_tabnet)`` tuple. ``results_tabnet`` is a
        dict with keys ``val_probs``, ``test_probs``, ``test_preds``,
        ``accuracy``, ``auc``, ``best_threshold``, ``fpr``, ``tpr``,
        ``report_df`` and ``feature_importance``.

    Raises:
        ValueError: If ``feature_names`` does not have one entry per feature
            column, or if the training labels are not exactly the two
            classes 0 and 1.
    """
    # === Convert to NumPy ===
    X_train_np = X_train.values if isinstance(X_train, pd.DataFrame) else X_train
    X_val_np = X_val.values if isinstance(X_val, pd.DataFrame) else X_val
    X_test_np = X_test.values if isinstance(X_test, pd.DataFrame) else X_test
    y_train_np = _to_label_vector(y_train)
    y_val_np = _to_label_vector(y_val)
    y_test_np = _to_label_vector(y_test)

    # === Fail fast, before spending time on training ===
    feature_names = list(feature_names)
    train_shape = np.shape(X_train_np)
    if len(train_shape) == 2 and len(feature_names) != train_shape[1]:
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but X_train has "
            f"{train_shape[1]} columns"
        )

    # The thresholding below compares 0/1 predictions against the labels and
    # reads column 1 of predict_proba, so anything but {0, 1} cannot work.
    train_classes = set(np.unique(y_train_np).tolist())
    if train_classes != {0, 1}:
        raise ValueError(
            "run_tabnet_pipeline expects binary labels encoded as 0/1; "
            f"found classes {sorted(train_classes, key=str)} in y_train"
        )

    # === 1. Train TabNet ===
    model_tabnet = TabNetClassifier(
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        scheduler_params={"step_size": 10, "gamma": 0.9},
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type='sparsemax',
        seed=42,
        verbose=1
    )

    model_tabnet.fit(
        X_train=X_train_np, y_train=y_train_np,
        eval_set=[(X_val_np, y_val_np)],
        eval_name=["val"],
        eval_metric=["auc"],
        max_epochs=100,
        patience=10,
        batch_size=256,
        virtual_batch_size=128,
        num_workers=0,
        drop_last=False
    )

    # === 2. Predict (positive-class score = second probability column) ===
    val_probs = model_tabnet.predict_proba(X_val_np)[:, 1]
    test_probs = model_tabnet.predict_proba(X_test_np)[:, 1]

    # === 3. Best Threshold (tuned on validation data only) ===
    thresholds = np.linspace(0, 1, 101)
    best_threshold = _best_f1_threshold(y_val_np, val_probs, thresholds)
    print(f"✅ Best threshold (F1): {best_threshold:.2f}")

    # === 4. Evaluate ===
    test_preds = (test_probs >= best_threshold).astype(int)
    accuracy = accuracy_score(y_test_np, test_preds)
    # AUC is threshold-independent, so it is computed from the raw scores.
    auc = roc_auc_score(y_test_np, test_probs)
    print(f"\n🔎 Test Accuracy: {accuracy:.2f}")
    print("Classification Report:")
    print(classification_report(y_test_np, test_preds, target_names=["Class 0", "Class 1"]))
    print(f"AUC-ROC (Test): {auc:.2f}")

    # Same report as a DataFrame; no target_names here, so rows keep the raw
    # label keys that callers of `report_df` already rely on.
    report_df = pd.DataFrame(
        classification_report(y_test_np, test_preds, output_dict=True)
    ).T

    # === 5. ROC Curve ===
    fpr, tpr, _ = roc_curve(y_test_np, test_probs)
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, label=f"AUC = {auc:.2f}")
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
    plt.title("ROC Curve - TabNet")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # === 6. Feature Importance ===
    feature_importance = pd.DataFrame({
        'Feature': feature_names,
        'Importance': model_tabnet.feature_importances_
    }).sort_values(by='Importance', ascending=False)

    print("\n📊 Top 10 Important Features (TabNet):")
    print(feature_importance.head(10))

    # === 7. Return Structured Results ===
    results_tabnet = {
        'val_probs': val_probs,
        'test_probs': test_probs,
        'test_preds': test_preds,
        'accuracy': accuracy,
        'auc': auc,
        'best_threshold': best_threshold,
        'fpr': fpr,
        'tpr': tpr,
        'report_df': report_df,
        'feature_importance': feature_importance
    }

    return model_tabnet, results_tabnet
