In [1]:
%%capture
!pip install pytorch-tabnet

In [2]:
import numpy as np
import pandas as pd
import os
import shutil
import joblib
import pickle
import json

# ml frameworks
from sklearn.model_selection import train_test_split
from sklearn.inspection import permutation_importance
from sklearn.datasets import make_classification # for test of funcs
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    accuracy_score,
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    precision_recall_curve,
    auc,
    RocCurveDisplay,
    PrecisionRecallDisplay,
)

import torch
#time management
from tqdm import tqdm
import time

#stats
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import cdist
from itertools import combinations
#download data from hub
from huggingface_hub import hf_hub_download
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.callbacks import EarlyStopping

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [28]:
REPO_ID = "powidla/Friend-Or-Foe"

X_train_ID = "Generative/AGORA/50/GEN/df_train_AG-50.csv"
X_test_ID = "Generative/AGORA/50/GEN/df_test_AG-50.csv"

train = pd.read_csv(hf_hub_download(repo_id=REPO_ID, filename=X_train_ID, repo_type="dataset"))
test = pd.read_csv(hf_hub_download(repo_id=REPO_ID, filename=X_test_ID, repo_type="dataset"))

In [24]:
aug = pd.read_csv('ctgan.csv')

In [25]:
X_train, y_train = train.drop(columns="label"), train["label"]
X_test, y_test = test.drop(columns="label"), test["label"]
X_aug, y_aug = aug.drop(columns="label"), aug["label"]

In [26]:
X_aug.columns = X_train.columns.values

In [27]:
def create_confusion_matrix(y_true, y_pred):
    '''
    Description: Create a confusion matrix.
    Arguments: y_true (array-like): Ground truth labels;
               y_pred (array-like): Predicted labels.
    Outputs:
        pd.DataFrame: A confusion matrix as a pandas DataFrame.
    '''
    cm = confusion_matrix(y_true, y_pred)
    cm_df = pd.DataFrame(cm, index=["True Negative", "True Positive"],
                             columns=["Predicted Negative", "Predicted Positive"])
    return cm_df

def score_metrics(y_true, y_pred, y_prob):
    '''
    Description: Calculate various metrics for binary classification.
    Arguments: y_true (array-like): Ground truth labels;
               y_pred (array-like): Predicted labels;
               y_prob (array-like): Predicted probabilities for the positive class.
    Outputs:
        dict
    '''
    metrics = {
        "Accuracy": accuracy_score(y_true, y_pred),
        "ROC AUC": roc_auc_score(y_true, y_prob),
        "Precision": precision_score(y_true, y_pred),
        "Recall": recall_score(y_true, y_pred),
        "F1 Score": f1_score(y_true, y_pred),
        "MCC": matthews_corrcoef(y_true, y_pred),
    }
    # PR AUC
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    metrics["PR AUC"] = auc(recall, precision)
    return metrics

In [16]:
def train_and_evaluate_with_augmentation(
    X_train, y_train, X_val, y_val, X_test, y_test, X_aug, y_aug,
    output_dir="tabnet_comparison", seed=4221, max_epochs=100, patience=10
):
    os.makedirs(output_dir, exist_ok=True)

    results = {}

    for tag, (xtr, ytr) in {
        "Original": (X_train, y_train),
        "Augmented": (pd.concat([X_train, X_aug]), pd.concat([y_train, y_aug])),
    }.items():
        print(f"\n--- Training on {tag} data ---")

        clf = TabNetClassifier(
            cat_idxs=[],
            cat_dims=[],
            cat_emb_dim=1,
            optimizer_fn=torch.optim.AdamW,
            optimizer_params=dict(lr=1e-4, weight_decay=0.02),
            scheduler_params={"step_size": 50, "gamma": 0.99},
            scheduler_fn=torch.optim.lr_scheduler.StepLR,
            mask_type='entmax',
            n_d=64,
            n_a=64,
            seed=seed,
            device_name=device,
        )

        clf.fit(
            X_train=xtr.values, y_train=ytr.values,
            eval_set=[(X_val.values, y_val.values)],
            eval_metric=['accuracy'],
            max_epochs=max_epochs,
            patience=patience,
            batch_size=1024,
        )

        y_pred = clf.predict(X_test.values)
        y_proba = clf.predict_proba(X_test.values)[:, 1]

        metrics = score_metrics(y_test, y_pred, y_proba)
        results[tag] = metrics

        with open(os.path.join(output_dir, f"metrics_{tag}.json"), "w") as f:
            json.dump(metrics, f, indent=4)

        clf.save_model(os.path.join(output_dir, f"model_{tag}.zip"))

        print(f"\n{tag} Test Metrics:")
        for k, v in metrics.items():
            print(f"{k}: {v:.4f}")

    return results

In [29]:
X_train_base, X_val, y_train_base, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train)

# Train and evaluate comparing test metrincs
results = train_and_evaluate_with_augmentation(
    X_train_base, y_train_base,
    X_val, y_val,
    X_test, y_test,
    X_aug, y_aug,
    output_dir="tabnet_comparison_AG50"
)


--- Training on Original data ---




epoch 0  | loss: 1.79068 | val_0_accuracy: 0.51031 |  0:00:00s
epoch 1  | loss: 1.74192 | val_0_accuracy: 0.49938 |  0:00:01s
epoch 2  | loss: 1.65618 | val_0_accuracy: 0.49406 |  0:00:02s
epoch 3  | loss: 1.58204 | val_0_accuracy: 0.49344 |  0:00:02s
epoch 4  | loss: 1.52516 | val_0_accuracy: 0.49562 |  0:00:03s
epoch 5  | loss: 1.46996 | val_0_accuracy: 0.49375 |  0:00:04s
epoch 6  | loss: 1.41495 | val_0_accuracy: 0.49531 |  0:00:04s
epoch 7  | loss: 1.38491 | val_0_accuracy: 0.50156 |  0:00:05s
epoch 8  | loss: 1.33403 | val_0_accuracy: 0.50156 |  0:00:06s
epoch 9  | loss: 1.29832 | val_0_accuracy: 0.49781 |  0:00:07s
epoch 10 | loss: 1.26726 | val_0_accuracy: 0.50438 |  0:00:07s

Early stopping occurred at epoch 10 with best_epoch = 0 and best_val_0_accuracy = 0.51031




Successfully saved model at tabnet_comparison_AG50/model_Original.zip.zip

Original Test Metrics:
Accuracy: 0.4918
ROC AUC: 0.4988
Precision: 0.5074
Recall: 0.1010
F1 Score: 0.1685
MCC: -0.0016
PR AUC: 0.5106

--- Training on Augmented data ---




epoch 0  | loss: 1.59151 | val_0_accuracy: 0.51344 |  0:00:01s
epoch 1  | loss: 1.51977 | val_0_accuracy: 0.49312 |  0:00:02s
epoch 2  | loss: 1.45153 | val_0_accuracy: 0.4925  |  0:00:03s
epoch 3  | loss: 1.36601 | val_0_accuracy: 0.49781 |  0:00:03s
epoch 4  | loss: 1.33931 | val_0_accuracy: 0.49688 |  0:00:04s
epoch 5  | loss: 1.26871 | val_0_accuracy: 0.50062 |  0:00:05s
epoch 6  | loss: 1.23794 | val_0_accuracy: 0.49938 |  0:00:06s
epoch 7  | loss: 1.21242 | val_0_accuracy: 0.50531 |  0:00:07s
epoch 8  | loss: 1.15156 | val_0_accuracy: 0.50719 |  0:00:08s
epoch 9  | loss: 1.13052 | val_0_accuracy: 0.5075  |  0:00:09s
epoch 10 | loss: 1.12723 | val_0_accuracy: 0.50875 |  0:00:10s

Early stopping occurred at epoch 10 with best_epoch = 0 and best_val_0_accuracy = 0.51344




Successfully saved model at tabnet_comparison_AG50/model_Augmented.zip.zip

Augmented Test Metrics:
Accuracy: 0.4938
ROC AUC: 0.5058
Precision: 0.5167
Recall: 0.1059
F1 Score: 0.1758
MCC: 0.0048
PR AUC: 0.5168
