In [1]:
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from pmlb import fetch_data

SEED = 42

In [2]:
datasets = ["allbp", "Hill_Valley_with_noise","Hill_Valley_without_noise","adult","allhyper","breast_cancer"]

In [3]:
def tabnet_benchmark(dataset_name, seed):
    df = fetch_data(dataset_name)

    X_train, X_valid, y_train, y_valid = train_test_split(df.drop(["target"], axis="columns"), df["target"], train_size = 0.8, stratify=df["target"], random_state = seed)

        # Step 1: Fit TabNet
    tabnet = TabNetClassifier(
        n_d=8, n_a=8, n_steps=3,
        gamma=1.3, lambda_sparse=1e-3,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2),
        mask_type='sparsemax',
        verbose=0
    )

    tabnet.fit(
        X_train=X_train.values, y_train=y_train.values,
        max_epochs=100,
        patience=10,
        batch_size=1024,
        virtual_batch_size=128,
        num_workers=0,
        drop_last=False
    )

    # Predict on the test set
    y_pred = tabnet.predict(X_valid.values)

    acc = accuracy_score(y_valid, y_pred)
    print(f"Accuracy on {dataset_name}: {acc}")
    return acc

In [4]:
tabnet_acc = []
for i in datasets:
    tabnet_acc.append(tabnet_benchmark(i, SEED))

for i in range(len(datasets)):
    print(f"Accuracy of Tabnet on {datasets[i]}:{tabnet_acc[i]}")



Accuracy on allbp: 0.9615894039735099




Accuracy on Hill_Valley_with_noise: 0.4403292181069959




Accuracy on Hill_Valley_without_noise: 0.48559670781893005




Accuracy on adult: 0.852492578564848




Accuracy on allhyper: 0.9867549668874173




Accuracy on breast_cancer: 0.5517241379310345
Accuracy of Tabnet on allbp:0.9615894039735099
Accuracy of Tabnet on Hill_Valley_with_noise:0.4403292181069959
Accuracy of Tabnet on Hill_Valley_without_noise:0.48559670781893005
Accuracy of Tabnet on adult:0.852492578564848
Accuracy of Tabnet on allhyper:0.9867549668874173
Accuracy of Tabnet on breast_cancer:0.5517241379310345
