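This notebook implements stacking for multiclass classification. The saved predictions of three base learners (MobileNetV2, YOLOv8 and DenseNet; five columns per model) are concatenated and used as input features for the following meta-learner:
1. TabNet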

# Imports & Dataset Setup

In [1]:
import random

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import torch
from optuna import Trial, visualization
from scipy.stats import randint, uniform
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score, balanced_accuracy_score, classification_report
from sklearn.model_selection import PredefinedSplit, GridSearchCV, RandomizedSearchCV

# One seed for every source of randomness used in this notebook (Python's
# `random`, numpy and torch). The same value is also handed to TabNet through
# `tabnet_params["seed"]`.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
!pip install -q pytorch_tabnet

In [4]:
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
from pytorch_tabnet.tab_model import TabNetClassifier

In [5]:
# Base model 1 (MobileNetV2): saved predictions for the train/val/test splits.
# Each CSV carries an 'Unnamed: 0' column (presumably the index written by
# DataFrame.to_csv), which is dropped.
m_mobilenet_train, m_mobilenet_val, m_mobilenet_test = (
    pd.read_csv(csv_path).drop(columns=['Unnamed: 0'])
    for csv_path in (
        "/kaggle/input/fork-of-koa-mobilenetv2/m_mobilenet_train.csv",
        "/kaggle/input/fork-of-koa-mobilenetv2/m_mobilenet_val.csv",
        "/kaggle/input/fork-of-koa-mobilenetv2/m_mobilenet_test.csv",
    )
)

In [6]:
# Base model 2 (DenseNet): saved predictions for the train/val/test splits.
# The 'Unnamed: 0' column (presumably a to_csv index) is dropped.
m_densenet_train, m_densenet_val, m_densenet_test = (
    pd.read_csv(csv_path).drop(columns=['Unnamed: 0'])
    for csv_path in (
        "/kaggle/input/koa-densenet-preds/m_densenet_train.csv",
        "/kaggle/input/koa-densenet-preds/m_densenet_val.csv",
        "/kaggle/input/koa-densenet-preds/m_densenet_test.csv",
    )
)

In [7]:
# Base model 3 (YOLOv8): saved predictions for the train/val/test splits.
# The 'Unnamed: 0' column (presumably a to_csv index) is dropped.
m_yolov8_train, m_yolov8_val, m_yolov8_test = (
    pd.read_csv(csv_path).drop(columns=['Unnamed: 0'])
    for csv_path in (
        "/kaggle/input/koa-yolov8-preds/m_yolov8_train.csv",
        "/kaggle/input/koa-yolov8-preds/m_yolov8_val.csv",
        "/kaggle/input/koa-yolov8-preds/m_yolov8_test.csv",
    )
)

In [8]:
# The YOLOv8 files call the image key 'FilePath'; rename it to 'FileName'
# (presumably the key's name in the MobileNet/DenseNet files — verify) so the
# merges, which join on all shared column names, line up on it.
# Reassigning instead of inplace=True keeps the cell a pure function of its
# inputs; re-running it is harmless because rename ignores absent labels.
m_yolov8_train = m_yolov8_train.rename(columns={'FilePath': 'FileName'})
m_yolov8_val = m_yolov8_val.rename(columns={'FilePath': 'FileName'})
m_yolov8_test = m_yolov8_test.rename(columns={'FilePath': 'FileName'})

In [32]:
# Put the three base models' outputs side by side, one frame per split.
# DataFrame.merge without `on=` does an inner join on every column name the
# two frames share (presumably 'FileName' and 'y_true' here — verify), so a
# row missing from any one base model's file is dropped.
train = (
    m_mobilenet_train
    .merge(m_yolov8_train)
    .merge(m_densenet_train)
)
val = (
    m_mobilenet_val
    .merge(m_yolov8_val)
    .merge(m_densenet_val)
)
test = (
    m_mobilenet_test
    .merge(m_yolov8_test)
    .merge(m_densenet_test)
)

In [42]:
# Meta-learner inputs: the prediction columns of each base model, ordered
# MobileNet (m_*), YOLOv8 (y_*), DenseNet (d_*). Defined once so the three
# splits cannot drift apart. Expands to
# ['m_0', ..., 'm_4', 'y_0', ..., 'y_4', 'd_0', ..., 'd_4'].
BASE_MODEL_PREFIXES = ('m', 'y', 'd')
N_COLS_PER_MODEL = 5  # presumably one column per class — verify
FEATURE_COLS = [f'{prefix}_{k}' for prefix in BASE_MODEL_PREFIXES for k in range(N_COLS_PER_MODEL)]
TARGET_COLS = ['y_true']

X_train = train[FEATURE_COLS]
X_val = val[FEATURE_COLS]
X_test = test[FEATURE_COLS]
# Double brackets keep the targets as single-column DataFrames.
y_train = train[TARGET_COLS]
y_val = val[TARGET_COLS]
y_test = test[TARGET_COLS]

# TabNet

In [43]:
# TabNet takes numpy arrays. np.asarray / np.array accept both the DataFrames
# built earlier and an already-converted array, so re-running this cell no
# longer fails (a second .to_numpy() call would raise on an ndarray).
# The multitask classifier takes a 2-D target (one column per task); there is
# a single task, hence shape (n_samples, 1). np.array copies, like the
# original flatten() did.
X_train = np.asarray(X_train)
y_train = np.array(y_train, copy=True).reshape(-1, 1)

In [44]:
# Same conversion as for the training split: numpy features and a
# (n_samples, 1) target. Safe to re-run because np.asarray / np.array also
# accept an already-converted array.
X_val = np.asarray(X_val)
y_val = np.array(y_val, copy=True).reshape(-1, 1)

In [75]:
# Meta-learner configuration: Adam at lr=2e-3, a StepLR schedule that
# multiplies the learning rate by gamma=0.9 every step_size=50 scheduler
# steps, entmax attention masks, and the notebook-wide seed.
tabnet_params = dict(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-3),
    scheduler_params=dict(step_size=50, gamma=0.9),
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    mask_type='entmax',
    seed=seed,
)

# NOTE(review): a multitask model is used for a single-target problem, which
# is why its predict/predict_proba outputs are indexed with [0] later on.
clf = TabNetMultiTaskClassifier(**tabnet_params)

In [82]:
# Train the meta-learner on the stacked base-model predictions. Both splits are
# scored every epoch; pytorch_tabnet monitors the last eval set ('valid') for
# early stopping.
# NOTE(review): patience=0 disables early stopping in pytorch_tabnet, so all
# 200 epochs run and the final-epoch weights are kept rather than the
# best-validation ones. The train/valid gap in the metrics suggests
# overfitting; a positive patience may be worth trying.
clf.fit(
    X_train=X_train,
    y_train=y_train,
    eval_set=[(X_train, y_train), (X_val, y_val)],
    eval_name=['train', 'valid'],
    max_epochs=200,
    patience=0,
    batch_size=1024,
    virtual_batch_size=128,  # chunk size for TabNet's Ghost Batch Normalization
    num_workers=0,
    drop_last=False,  # keep the final, possibly incomplete, training batch
)

epoch 0  | loss: 1.72672 | train_logloss: 1.59694 | valid_logloss: 1.606   |  0:00:00s
epoch 1  | loss: 1.4005  | train_logloss: 1.52146 | valid_logloss: 1.54002 |  0:00:01s
epoch 2  | loss: 1.20745 | train_logloss: 1.42974 | valid_logloss: 1.46744 |  0:00:02s
epoch 3  | loss: 1.0203  | train_logloss: 1.34184 | valid_logloss: 1.40344 |  0:00:02s
epoch 4  | loss: 0.85241 | train_logloss: 1.24302 | valid_logloss: 1.33693 |  0:00:03s
epoch 5  | loss: 0.73541 | train_logloss: 1.13543 | valid_logloss: 1.27166 |  0:00:04s
epoch 6  | loss: 0.66878 | train_logloss: 1.06111 | valid_logloss: 1.23024 |  0:00:04s
epoch 7  | loss: 0.61468 | train_logloss: 1.01397 | valid_logloss: 1.20996 |  0:00:05s
epoch 8  | loss: 0.58893 | train_logloss: 0.98658 | valid_logloss: 1.19992 |  0:00:06s
epoch 9  | loss: 0.5459  | train_logloss: 0.95767 | valid_logloss: 1.19363 |  0:00:06s
epoch 10 | loss: 0.52946 | train_logloss: 0.93408 | valid_logloss: 1.18725 |  0:00:07s
epoch 11 | loss: 0.51087 | train_logloss: 0

In [83]:
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score

In [84]:
def report_split_metrics(model, header, X, y):
    """Print accuracy, balanced accuracy and one-vs-rest ROC AUC of ``model`` on one split.

    Parameters
    ----------
    model : fitted single-task TabNetMultiTaskClassifier
        Its ``predict`` / ``predict_proba`` return one entry per task; only
        entry ``[0]`` (the sole task) is used.
    header : str
        Line printed before the metrics.
    X : DataFrame or ndarray
        Features; passed through ``np.asarray`` so either form works.
    y : DataFrame or ndarray
        True labels, one per row; flattened to 1-D.
    """
    X_arr = np.asarray(X)
    y_true = np.asarray(y).ravel()

    # Predict once and reuse for both label-based metrics. Labels are cast to
    # int so they compare against the integer targets.
    y_pred = model.predict(X_arr)[0].astype('int')
    y_proba = model.predict_proba(X_arr)[0]

    # Builtin round() works whether sklearn returns numpy or plain Python
    # floats (some scikit-learn versions return the latter, which have no
    # .round() method).
    print(header)
    print("Accuracy: ", round(float(accuracy_score(y_true, y_pred)), 3))
    print("Balanced Accuracy: ", round(float(balanced_accuracy_score(y_true, y_pred)), 3))
    print("AUC:", round(float(roc_auc_score(y_true, y_proba, multi_class='ovr')), 3))


report_split_metrics(clf, "Testing on training set:", X_train, y_train)
report_split_metrics(clf, "\nTesting on validation set:", X_val, y_val)
report_split_metrics(clf, "\nTesting on testing set:", X_test, y_test)

Testing on training set:
Accuracy:  0.873
Balanced Accuracy:  0.891
AUC: 0.978

Testing on validation set:
Accuracy:  0.607
Balanced Accuracy:  0.626
AUC: 0.811

Testing on testing set:
Accuracy:  0.708
Balanced Accuracy:  0.723
AUC: 0.889
