In [None]:
# https://github.com/dreamquark-ai/tabnet

In [2]:
import torch
import pandas as pd
import numpy as np
import os

In [3]:
Data = pd.read_csv("dataset.tsv", sep="\t")
LabelNames = Data.columns[56:]

In [4]:
#X, Y = Data.loc[:, SelectedCols], Data.loc[:, LabelNames]
X, Y = Data.iloc[:, 1:56], Data.loc[:, LabelNames]

In [10]:
X.iloc[:,0]

0       Female
1       Female
2       Female
3       Female
4       Female
         ...  
4639    Female
4640      Male
4641    Female
4642      Male
4643      Male
Name: Sex.Male.1..Female.2., Length: 4644, dtype: object

In [11]:
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

<br><br><br><br><br><br><br>
## Fine-tunning the Pretrained!

In [12]:
import pytorch_tabnet
from pytorch_tabnet.pretraining import TabNetPretrainer
nCategory = [len(X_train.iloc[:,idx].unique()) for idx, col in enumerate(X_train.columns)]
unsupervised_model = TabNetPretrainer(
    cat_idxs=[idx for idx, nCat in enumerate(nCategory) if nCat <= 2],
    cat_dims=[nCat for idx, nCat in enumerate(nCategory) if nCat <= 2],
    cat_emb_dim=5,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=1e-2),
    mask_type='entmax', # "sparsemax",
    n_shared_decoder=1, # nb shared glu for decoding
    n_indep_decoder=1, # nb independent glu for decoding
    # grouped_features=[[0, 1]], # you can group features together here
    verbose=5,
)



In [None]:
unsupervised_model.load_model('./MODEL/PretrainedModel_0.3788.zip')

In [6]:
unsupervised_model.fit(
    X_train=X_train.values,
    eval_set=[X_test.values],
    max_epochs=2000,
    patience=150,
    batch_size=512,
    virtual_batch_size=128,
    num_workers=0,
    drop_last=False,
    pretraining_ratio=0.5,
)

epoch 0  | loss: 10076.23709| val_0_unsup_loss_numpy: 28.58395004272461|  0:00:01s
epoch 5  | loss: 2.20777 | val_0_unsup_loss_numpy: 1.53056001663208|  0:00:03s
epoch 10 | loss: 1.14497 | val_0_unsup_loss_numpy: 1.1392500400543213|  0:00:05s
epoch 15 | loss: 1.10751 | val_0_unsup_loss_numpy: 1.0642499923706055|  0:00:06s
epoch 20 | loss: 1.07711 | val_0_unsup_loss_numpy: 1.024049997329712|  0:00:08s
epoch 25 | loss: 1.03102 | val_0_unsup_loss_numpy: 0.9632899761199951|  0:00:10s
epoch 30 | loss: 0.94618 | val_0_unsup_loss_numpy: 0.8659800291061401|  0:00:12s
epoch 35 | loss: 0.87737 | val_0_unsup_loss_numpy: 0.7919399738311768|  0:00:14s
epoch 40 | loss: 0.84995 | val_0_unsup_loss_numpy: 0.7683899998664856|  0:00:16s
epoch 45 | loss: 0.80579 | val_0_unsup_loss_numpy: 0.7403200268745422|  0:00:17s
epoch 50 | loss: 0.81125 | val_0_unsup_loss_numpy: 0.7226899862289429|  0:00:19s
epoch 55 | loss: 0.78451 | val_0_unsup_loss_numpy: 0.7099300026893616|  0:00:21s
epoch 60 | loss: 0.77585 | va

epoch 505| loss: 0.5565  | val_0_unsup_loss_numpy: 0.3974800109863281|  0:03:07s
epoch 510| loss: 0.55386 | val_0_unsup_loss_numpy: 0.3972199857234955|  0:03:09s
epoch 515| loss: 0.54283 | val_0_unsup_loss_numpy: 0.39548999071121216|  0:03:11s
epoch 520| loss: 0.5343  | val_0_unsup_loss_numpy: 0.4000000059604645|  0:03:12s
epoch 525| loss: 0.5531  | val_0_unsup_loss_numpy: 0.39478999376296997|  0:03:14s
epoch 530| loss: 0.55051 | val_0_unsup_loss_numpy: 0.3981100022792816|  0:03:16s
epoch 535| loss: 0.55632 | val_0_unsup_loss_numpy: 0.3954800069332123|  0:03:18s
epoch 540| loss: 0.56735 | val_0_unsup_loss_numpy: 0.3996100127696991|  0:03:20s
epoch 545| loss: 0.56326 | val_0_unsup_loss_numpy: 0.40084001421928406|  0:03:22s
epoch 550| loss: 0.54587 | val_0_unsup_loss_numpy: 0.39831000566482544|  0:03:23s
epoch 555| loss: 0.54907 | val_0_unsup_loss_numpy: 0.3943299949169159|  0:03:25s
epoch 560| loss: 0.54799 | val_0_unsup_loss_numpy: 0.39438000321388245|  0:03:27s
epoch 565| loss: 0.5321

epoch 1010| loss: 0.55613 | val_0_unsup_loss_numpy: 0.39333000779151917|  0:06:16s
epoch 1015| loss: 0.53448 | val_0_unsup_loss_numpy: 0.3737800121307373|  0:06:18s
epoch 1020| loss: 0.52174 | val_0_unsup_loss_numpy: 0.3698900043964386|  0:06:20s
epoch 1025| loss: 0.51175 | val_0_unsup_loss_numpy: 0.38106998801231384|  0:06:21s
epoch 1030| loss: 0.53075 | val_0_unsup_loss_numpy: 0.3800300061702728|  0:06:23s
epoch 1035| loss: 0.52976 | val_0_unsup_loss_numpy: 0.37564998865127563|  0:06:25s
epoch 1040| loss: 0.5223  | val_0_unsup_loss_numpy: 0.3884899914264679|  0:06:27s
epoch 1045| loss: 0.53956 | val_0_unsup_loss_numpy: 0.3829199969768524|  0:06:29s
epoch 1050| loss: 0.54764 | val_0_unsup_loss_numpy: 0.38596999645233154|  0:06:31s
epoch 1055| loss: 0.53472 | val_0_unsup_loss_numpy: 0.3803099989891052|  0:06:32s
epoch 1060| loss: 0.53371 | val_0_unsup_loss_numpy: 0.37970998883247375|  0:06:34s
epoch 1065| loss: 0.5107  | val_0_unsup_loss_numpy: 0.3782399892807007|  0:06:36s
epoch 1070|



In [148]:
unsupervised_model.save_model('./MODEL/PretrainedModel_0.3600')

Successfully saved model at ./MODEL/PretrainedModel_0.3600.zip


'./MODEL/PretrainedModel_0.3600.zip'

<br><br><br><br><br><br><br>
## Fine-tunning the Multi-label Classifier!!

In [150]:
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
from torch.optim.lr_scheduler import CosineAnnealingLR

clf = TabNetMultiTaskClassifier(
    cat_idxs=[idx for idx, nCat in enumerate(nCategory) if nCat <= 2],
    cat_dims=[nCat for idx, nCat in enumerate(nCategory) if nCat <= 2],
    cat_emb_dim=5,
    optimizer_fn=torch.optim.Adam, optimizer_params=dict(lr=1e-2),
    scheduler_fn=CosineAnnealingLR, scheduler_params={"T_max": 20, "eta_min": 1e-5}, 
    # scheduler_params={"step_size":50, "gamma":0.9}, scheduler_fn=torch.optim.lr_scheduler.StepLR,
    mask_type='entmax' # "sparsemax"
)

In [151]:
clf.fit(
    X_train=X_train.values, y_train=Y_train.values,  
    eval_set=[(X_train.values, Y_train.values), (X_test.values, Y_test.values)],
    eval_name=['train', 'valid'],
    max_epochs=2000,
    patience=100,
    batch_size=128,
    virtual_batch_size=64,
    num_workers=0,
    drop_last=False,
    loss_fn=[torch.nn.functional.cross_entropy]*Y_train.shape[1],  
    from_unsupervised=unsupervised_model
)



epoch 0  | loss: 0.8306  | train_logloss: 0.47772 | valid_logloss: 0.488   |  0:00:01s
epoch 1  | loss: 0.4127  | train_logloss: 0.37586 | valid_logloss: 0.38641 |  0:00:03s
epoch 2  | loss: 0.34546 | train_logloss: 0.32953 | valid_logloss: 0.33996 |  0:00:05s
epoch 3  | loss: 0.32017 | train_logloss: 0.30274 | valid_logloss: 0.31295 |  0:00:06s
epoch 4  | loss: 0.29874 | train_logloss: 0.28914 | valid_logloss: 0.29986 |  0:00:08s
epoch 5  | loss: 0.28368 | train_logloss: 0.26872 | valid_logloss: 0.27983 |  0:00:10s
epoch 6  | loss: 0.27262 | train_logloss: 0.26351 | valid_logloss: 0.27497 |  0:00:11s
epoch 7  | loss: 0.26083 | train_logloss: 0.25024 | valid_logloss: 0.26228 |  0:00:13s
epoch 8  | loss: 0.2502  | train_logloss: 0.23981 | valid_logloss: 0.25242 |  0:00:15s
epoch 9  | loss: 0.24425 | train_logloss: 0.23222 | valid_logloss: 0.24382 |  0:00:16s
epoch 10 | loss: 0.23964 | train_logloss: 0.22757 | valid_logloss: 0.23942 |  0:00:18s
epoch 11 | loss: 0.23263 | train_logloss: 0

epoch 95 | loss: 0.17128 | train_logloss: 0.1608  | valid_logloss: 0.17588 |  0:02:42s
epoch 96 | loss: 0.17058 | train_logloss: 0.16094 | valid_logloss: 0.17637 |  0:02:44s
epoch 97 | loss: 0.16989 | train_logloss: 0.16094 | valid_logloss: 0.17618 |  0:02:45s
epoch 98 | loss: 0.17077 | train_logloss: 0.16116 | valid_logloss: 0.17635 |  0:02:47s
epoch 99 | loss: 0.16978 | train_logloss: 0.16119 | valid_logloss: 0.17657 |  0:02:49s
epoch 100| loss: 0.16809 | train_logloss: 0.16086 | valid_logloss: 0.17627 |  0:02:50s
epoch 101| loss: 0.16996 | train_logloss: 0.16082 | valid_logloss: 0.1762  |  0:02:52s
epoch 102| loss: 0.16902 | train_logloss: 0.1606  | valid_logloss: 0.17595 |  0:02:54s
epoch 103| loss: 0.1698  | train_logloss: 0.16045 | valid_logloss: 0.17594 |  0:02:55s
epoch 104| loss: 0.16992 | train_logloss: 0.16048 | valid_logloss: 0.17566 |  0:02:57s
epoch 105| loss: 0.16953 | train_logloss: 0.16023 | valid_logloss: 0.17536 |  0:02:59s
epoch 106| loss: 0.17104 | train_logloss: 0

epoch 190| loss: 0.16348 | train_logloss: 0.15357 | valid_logloss: 0.16482 |  0:05:23s
epoch 191| loss: 0.16221 | train_logloss: 0.15253 | valid_logloss: 0.16412 |  0:05:24s
epoch 192| loss: 0.16436 | train_logloss: 0.15328 | valid_logloss: 0.16495 |  0:05:26s
epoch 193| loss: 0.16345 | train_logloss: 0.15379 | valid_logloss: 0.16532 |  0:05:28s
epoch 194| loss: 0.16873 | train_logloss: 0.1543  | valid_logloss: 0.16516 |  0:05:29s
epoch 195| loss: 0.16877 | train_logloss: 0.16064 | valid_logloss: 0.17094 |  0:05:31s
epoch 196| loss: 0.16916 | train_logloss: 0.15976 | valid_logloss: 0.16977 |  0:05:33s
epoch 197| loss: 0.17013 | train_logloss: 0.15569 | valid_logloss: 0.16634 |  0:05:35s
epoch 198| loss: 0.16662 | train_logloss: 0.15514 | valid_logloss: 0.16671 |  0:05:36s
epoch 199| loss: 0.16494 | train_logloss: 0.15483 | valid_logloss: 0.16956 |  0:05:38s
epoch 200| loss: 0.17041 | train_logloss: 0.15657 | valid_logloss: 0.16895 |  0:05:40s
epoch 201| loss: 0.1671  | train_logloss: 0

epoch 285| loss: 0.16395 | train_logloss: 0.15204 | valid_logloss: 0.16483 |  0:08:03s
epoch 286| loss: 0.16149 | train_logloss: 0.15301 | valid_logloss: 0.16336 |  0:08:05s
epoch 287| loss: 0.16313 | train_logloss: 0.15205 | valid_logloss: 0.16348 |  0:08:07s
epoch 288| loss: 0.16131 | train_logloss: 0.15135 | valid_logloss: 0.16276 |  0:08:08s
epoch 289| loss: 0.16129 | train_logloss: 0.15238 | valid_logloss: 0.16357 |  0:08:10s
epoch 290| loss: 0.16016 | train_logloss: 0.1511  | valid_logloss: 0.1624  |  0:08:12s
epoch 291| loss: 0.16188 | train_logloss: 0.15079 | valid_logloss: 0.16088 |  0:08:13s
epoch 292| loss: 0.1592  | train_logloss: 0.14929 | valid_logloss: 0.1601  |  0:08:15s
epoch 293| loss: 0.15878 | train_logloss: 0.14932 | valid_logloss: 0.16005 |  0:08:17s
epoch 294| loss: 0.15697 | train_logloss: 0.14889 | valid_logloss: 0.15972 |  0:08:18s
epoch 295| loss: 0.15786 | train_logloss: 0.14903 | valid_logloss: 0.16009 |  0:08:20s
epoch 296| loss: 0.15684 | train_logloss: 0

epoch 380| loss: 0.15386 | train_logloss: 0.14632 | valid_logloss: 0.1598  |  0:10:44s
epoch 381| loss: 0.15396 | train_logloss: 0.14652 | valid_logloss: 0.16016 |  0:10:45s
epoch 382| loss: 0.15327 | train_logloss: 0.14627 | valid_logloss: 0.15982 |  0:10:47s
epoch 383| loss: 0.1529  | train_logloss: 0.14602 | valid_logloss: 0.15956 |  0:10:49s
epoch 384| loss: 0.15315 | train_logloss: 0.14573 | valid_logloss: 0.15953 |  0:10:50s
epoch 385| loss: 0.15353 | train_logloss: 0.14605 | valid_logloss: 0.15973 |  0:10:52s
epoch 386| loss: 0.15456 | train_logloss: 0.14597 | valid_logloss: 0.15996 |  0:10:54s
epoch 387| loss: 0.15427 | train_logloss: 0.14612 | valid_logloss: 0.15965 |  0:10:55s
epoch 388| loss: 0.15377 | train_logloss: 0.1463  | valid_logloss: 0.15761 |  0:10:57s
epoch 389| loss: 0.1554  | train_logloss: 0.14763 | valid_logloss: 0.15934 |  0:10:59s
epoch 390| loss: 0.15572 | train_logloss: 0.1468  | valid_logloss: 0.15807 |  0:11:01s
epoch 391| loss: 0.15794 | train_logloss: 0

epoch 475| loss: 0.15475 | train_logloss: 0.14611 | valid_logloss: 0.16056 |  0:13:24s
epoch 476| loss: 0.15466 | train_logloss: 0.14549 | valid_logloss: 0.15978 |  0:13:26s
epoch 477| loss: 0.15711 | train_logloss: 0.14657 | valid_logloss: 0.16164 |  0:13:28s
epoch 478| loss: 0.15968 | train_logloss: 0.14652 | valid_logloss: 0.16276 |  0:13:29s
epoch 479| loss: 0.1589  | train_logloss: 0.14859 | valid_logloss: 0.16314 |  0:13:31s
epoch 480| loss: 0.15807 | train_logloss: 0.14802 | valid_logloss: 0.15948 |  0:13:33s
epoch 481| loss: 0.15688 | train_logloss: 0.14634 | valid_logloss: 0.15944 |  0:13:34s
epoch 482| loss: 0.15755 | train_logloss: 0.14723 | valid_logloss: 0.15936 |  0:13:36s
epoch 483| loss: 0.15523 | train_logloss: 0.14538 | valid_logloss: 0.15847 |  0:13:38s
epoch 484| loss: 0.15811 | train_logloss: 0.1461  | valid_logloss: 0.15962 |  0:13:40s
epoch 485| loss: 0.15483 | train_logloss: 0.14503 | valid_logloss: 0.1597  |  0:13:41s
epoch 486| loss: 0.15431 | train_logloss: 0



In [152]:
clf.save_model('./MODEL/ClfModel_0.1576')

Successfully saved model at ./MODEL/ClfModel_0.1576.zip


'./MODEL/ClfModel_0.1576.zip'

<br><br><br><br><br><br><br>
# Inference

In [169]:
from sklearn.metrics import roc_auc_score

In [167]:
res = clf.predict_proba(X_test.values)
AUROCs = [roc_auc_score(Y_test.values[:,idx], res[idx][:,1]) for idx in range(Y_test.shape[1])]

In [179]:
AUROCs

[0.8118014152783195,
 0.7994980449013663,
 0.7880316482011398,
 0.7849857498029228,
 0.9050495198079231,
 0.9491614255765198,
 1.0,
 0.9999888118147237,
 0.9958181660599758,
 0.9990641711229946,
 0.92191204362257,
 0.9113363363363363,
 0.9835679460390776,
 0.9969939195190272,
 0.9680563518715701,
 0.9994828275229789,
 0.9972013744901634,
 0.999166301427292,
 0.9988548709210503,
 1.0,
 0.9991616364855801,
 0.9982722377400869,
 0.9931412200876306,
 0.8165197155545665,
 0.9948226869743122,
 0.8650846508754657,
 0.9999406985708356,
 0.9973229224762967]

In [191]:
X_test.shape

(929, 31)

In [192]:
clf.feature_importances_

array([0.11061031, 0.01071639, 0.        , 0.        , 0.067394  ,
       0.07172143, 0.02985679, 0.01490689, 0.02709706, 0.00676125,
       0.        , 0.00171039, 0.03122171, 0.00174737, 0.02425583,
       0.        , 0.05045785, 0.08060017, 0.1182835 , 0.03502563,
       0.0035272 , 0.01536798, 0.11151161, 0.00495678, 0.00292322,
       0.03191978, 0.00019516, 0.        , 0.09396729, 0.        ,
       0.05326441])