In [1]:
import os
import pandas as pd
import numpy as np

import joblib
import sidetable

from sklearn.feature_selection import RFE, RFECV
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from itertools import combinations
from sklearn.calibration import calibration_curve, CalibrationDisplay, CalibratedClassifierCV
from sklearn.metrics import brier_score_loss

from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

import torch
from pytorch_tabnet.tab_model import TabNetClassifier

from numpy import array
from matplotlib import pyplot as plt
%matplotlib inline

from MetsDataSamplingSeed import get_mets_data, get_metric, get_calib_metric

In [2]:
def get_calib_prob(prob, label, beta):    
    return beta*prob/(beta*prob-prob+1)

In [3]:
def calibrated_plot(model_name, model, X_train, y_train, X_valid, y_valid, X_test, y_test, bins=10, method = 'sigmoid', is_tabnet=False):
    # uncalibrated
    #model.fit(X_train, y_train)
    #y_prob = model.predict_proba(X_test)[:,1]

    # calibrated
    calibrator = CalibratedClassifierCV(model, method=method, cv='prefit') #isotonic
    if is_tabnet:
        y_prob = model.predict_proba(X_test.values[:])[:,1]
        calibrator.fit(X_valid.values[:], y_valid.values[:])
        y_hat = calibrator.predict_proba(X_test.values[:])[:,1]
    else :
        y_prob = model.predict_proba(X_test)[:,1]
        calibrator.fit(X_valid, y_valid)
        y_hat = calibrator.predict_proba(X_test)[:,1]
    
    prob_true_uncalibrated, prob_pred_uncalibrated = calibration_curve(y_test, y_prob, n_bins=bins)
    prob_true_calibrated, prob_pred_calibrated = calibration_curve(y_test, y_hat, n_bins=bins)
    
    brier_score_org = brier_score_loss(y_test, y_prob)
    brier_score_cali = brier_score_loss(y_test, y_hat)
    
    # plot perfectly calibrated
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')

    # plot model reliabilities
    plt.title('Calibration Plot')
    plt.plot(prob_pred_uncalibrated, prob_true_uncalibrated, color = 'c', marker='o', label=(model_name+':'+ str(round(brier_score_org,3))))
    plt.plot(prob_pred_calibrated, prob_true_calibrated, color = 'm', marker='s', label=(model_name+'(calibrated):'+str(round(brier_score_cali,3))))
    plt.xlabel('Mean predicted value')
    plt.ylabel('Fraction of positives')
    plt.legend()
    plt.savefig('./fig/'+model_name+'_calibration.png')
    plt.show()
    
    plt.title('Probability Histogram')
    plt.hist(y_prob, bins=10, color='c', alpha=0.5, label='origin')
    plt.hist(y_hat, bins=10, color='m', alpha=0.5, label='calibrated')
    plt.xlabel('Predictied Probabilty')
    plt.ylabel('Count')
    plt.legend()
    plt.savefig('./fig/'+model_name+'_histogram.png')
    plt.show()
    
    return y_prob, y_hat, brier_score_cali

In [5]:
seeds = range(0, 300, 10)
all_ft = ['WC', 'BP', 'BPWC_add', 'BPWC_mul', 'BPWC_dif','bWC','whr','CUNBAE', 'clbe', 'G1_INT', 'ss18', 'fate', 'smoke_merge_0','regrp15','regrp18','regrp19','regrp38']

### Logistic Regression

In [16]:
# For LR
lr_a = ['sbp', 'CUNBAE', 'AVI', 'dbp']
lr_b = ['drink_0', 'drdu', 'dr_soju', 'ss04', 'exer_merge', 'smoke_merge_0']
lr_c = ['AVI', 'sbp', 'CUNBAE', 'dbp', 'drdu', 'smoke_merge_0']
lr_d = ['WC', 'BP', 'CUNBAE', 'clbe', 'smoke_merge_0']

args = {
    'penalty' : 'none',
    #'solver' : 'liblinear',
    'random_state' : 100,
    #'C': 1.0
}

prob_all = None
prob_idx = [0]
res_all = pd.DataFrame()

for i, s in enumerate(seeds) :
    print('sampling: ', i+1)
    X_train, y_train, X_valid, y_valid, X_test, y_test, _, _, valid_info, test_info, beta, tau, _ = get_mets_data(
        one_hot=True, 
        resampling = False,
        feature_set = None, # 7 = 'anthropometric', 8 = 'lifestyle', 9 = 'anthropometric+lifestyle', 10 = 'anthropometric+lifestyle+synthesis'
        set_feature = lr_d,
        add_feature= False,
        is_tabnet = False,
        is_eval = False,
        seed = s
    )
    sc = StandardScaler()
    sX_train = sc.fit_transform(X_train)
    sX_valid = sc.transform(X_valid)
    sX_test = sc.transform(X_test)

    X_train_lr = pd.DataFrame(sX_train, columns=X_train.columns)
    X_valid_lr = pd.DataFrame(sX_valid, columns=X_valid.columns)
    X_test_lr = pd.DataFrame(sX_test, columns=X_test.columns)
    
    lr = LogisticRegression(**args)
    lr.fit(X_train_lr, y_train)
    prob = lr.predict_proba(X_valid_lr)
    
    res = get_metric(prob, y_valid, 0.5)
    res = pd.DataFrame.from_dict(res, orient='index')
    res_all = pd.concat([res_all,res], axis=1)

sampling:  1
sampling:  2
sampling:  3
sampling:  4
sampling:  5
sampling:  6
sampling:  7
sampling:  8
sampling:  9
sampling:  10
sampling:  11
sampling:  12
sampling:  13
sampling:  14
sampling:  15
sampling:  16
sampling:  17
sampling:  18
sampling:  19
sampling:  20
sampling:  21
sampling:  22
sampling:  23
sampling:  24
sampling:  25
sampling:  26
sampling:  27
sampling:  28
sampling:  29
sampling:  30


In [17]:
res_all_lr = res_all
res_all_lr.transpose().describe()

Unnamed: 0,acc,bac,recall,ppv,npv,sepecificity,f1,auc
count,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0
mean,0.794622,0.804579,0.818282,0.382996,0.964839,0.790875,0.521681,0.887254
std,0.00466,0.006437,0.013719,0.010109,0.002753,0.005763,0.010068,0.004833
min,0.785917,0.79328,0.795167,0.368223,0.959504,0.77731,0.507396,0.878202
25%,0.79164,0.799923,0.809593,0.374969,0.962717,0.787919,0.512824,0.884452
50%,0.794127,0.80593,0.820244,0.380184,0.965337,0.791078,0.519484,0.887155
75%,0.796969,0.809584,0.828752,0.390936,0.966964,0.794395,0.5292,0.890831
max,0.804705,0.814833,0.840855,0.406164,0.969779,0.801944,0.54185,0.895939


In [18]:
res_all.transpose().to_csv('./fig/feature_d_logisticregression.csv', index=False)

### Decision Tree

In [32]:
# For DT
dt_a = ['waist', 'sbp', 'BFP', 'BAI', 'WHtR']
dt_b = ['ss18', 'ss05', 'ss23', 'ss12', 'ss14', 'ss19', 'ss20']
dt_c = ['waist', 'sbp', 'BFP', 'BAI', 'WHtR']
dt_d = ['BPWC_add', 'BPWC_mul', 'BPWC_dif']

args = {#'criterion': 'entropy',
        #'max_depth': 5,
        #'max_features': None,
        #'min_samples_leaf': 100,
        #'min_samples_split': 0.01,
        #'splitter': 'best',
        'random_state' : 100}

prob_all = None
res_all = pd.DataFrame()

for i, s in enumerate(seeds) :
    print('sampling: ', i+1)
    
    X_train_dt, y_train, X_valid_dt, y_valid, X_test_dt, y_test, _, _, valid_info, test_info, beta, tau, _ = get_mets_data(
        one_hot=True, 
        resampling = False,
        feature_set = None, # 7 = 'anthropometric', 8 = 'lifestyle', 9 = 'anthropometric+lifestyle', 10 = 'anthropometric+lifestyle+synthesis'
        set_feature = dt_d,#dt_d,
        add_feature= False,
        is_tabnet = False,
        is_eval = False,
        seed = s
    )

    dt = DecisionTreeClassifier(**args)
    dt.fit(X_train_dt, y_train)
    
    prob = dt.predict_proba(X_valid_dt)

    res = get_metric(prob, y_valid, 0.5)
    res = pd.DataFrame.from_dict(res, orient='index')
    res_all = pd.concat([res_all,res], axis=1)

sampling:  1
sampling:  2
sampling:  3
sampling:  4
sampling:  5
sampling:  6
sampling:  7
sampling:  8
sampling:  9
sampling:  10
sampling:  11
sampling:  12
sampling:  13
sampling:  14
sampling:  15
sampling:  16
sampling:  17
sampling:  18
sampling:  19
sampling:  20
sampling:  21
sampling:  22
sampling:  23
sampling:  24
sampling:  25
sampling:  26
sampling:  27
sampling:  28
sampling:  29
sampling:  30


In [33]:
res_all_dt = res_all
res_all_dt.transpose().describe()

Unnamed: 0,acc,bac,recall,ppv,npv,sepecificity,f1,auc
count,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0
mean,0.736749,0.765889,0.806009,0.318053,0.959351,0.725769,0.456004,0.792142
std,0.006895,0.007998,0.016789,0.01099,0.003365,0.008379,0.012189,0.010519
min,0.721819,0.748774,0.771111,0.297404,0.951541,0.709266,0.433355,0.770794
25%,0.733028,0.760622,0.793473,0.307956,0.957805,0.719476,0.444846,0.785111
50%,0.736107,0.765975,0.805982,0.318271,0.95924,0.724563,0.457897,0.792094
75%,0.741869,0.772215,0.816415,0.326159,0.961348,0.731399,0.465257,0.800954
max,0.748184,0.779631,0.841981,0.3388,0.967068,0.744387,0.477273,0.811304


In [34]:
res_all.transpose().to_csv('./fig/feature_d_decisiontree.csv', index=False)

### Random Forest

In [18]:
# For RF
rf_a = ['sbp', 'AVI', 'waist', 'dbp', 'bmi', 'BFP', 'weight', 'CUNBAE']
rf_b = ['ss18', 'ss04', 'ss24', 'ss20', 'ss05', 'ss14']
rf_c = ['AVI', 'sbp', 'waist', 'dbp', 'BFP', 'CUNBAE', 'ss18', 'ss20']
rf_d = ['bWC', 'WC', 'BPWC_dif', 'CUNBAE', 'G1_INT', 'regrp15', 'ss18', 'fate']
args = {#'bootstrap': True,
    #'max_depth': 6,
    #'min_samples_leaf': 4,
    #'min_samples_split': 10,
    #'n_estimators': 300,
    'random_state' : 100}

prob_all = None
res_all = pd.DataFrame()

for i, s in enumerate(seeds) :
    print('sampling: ', i+1)
    X_train_rf, y_train, X_valid_rf, y_valid, X_test_rf, y_test, _, _, valid_info, test_info, beta, tau, _ = get_mets_data(
        one_hot=True, 
        resampling = False,
        feature_set = None, # 7 = 'anthropometric', 8 = 'lifestyle', 9 = 'anthropometric+lifestyle', 10 = 'anthropometric+lifestyle+synthesis'
        set_feature = rf_d,
        add_feature= False,
        is_tabnet = False,
        is_eval = False,
        seed = s
    )

    rf = RandomForestClassifier(**args)
    rf.fit(X_train_rf, y_train)
    
    prob = rf.predict_proba(X_valid_rf)
    
    res = get_metric(prob, y_valid, 0.5)
    res = pd.DataFrame.from_dict(res, orient='index')
    res_all = pd.concat([res_all,res], axis=1)

sampling:  1
sampling:  2
sampling:  3
sampling:  4
sampling:  5
sampling:  6
sampling:  7
sampling:  8
sampling:  9
sampling:  10
sampling:  11
sampling:  12
sampling:  13
sampling:  14
sampling:  15
sampling:  16
sampling:  17
sampling:  18
sampling:  19
sampling:  20
sampling:  21
sampling:  22
sampling:  23
sampling:  24
sampling:  25
sampling:  26
sampling:  27
sampling:  28
sampling:  29
sampling:  30


In [19]:
res_all_rf = res_all
res_all_rf.transpose().describe()

Unnamed: 0,acc,bac,recall,ppv,npv,sepecificity,f1,auc
count,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0
mean,0.772361,0.810826,0.863787,0.361402,0.972286,0.757865,0.509506,0.885146
std,0.005215,0.006654,0.013571,0.009956,0.00282,0.006171,0.01052,0.004691
min,0.763025,0.795776,0.834884,0.346062,0.966853,0.745288,0.493471,0.875281
25%,0.769222,0.806202,0.854321,0.353336,0.970176,0.75375,0.500608,0.881982
50%,0.77234,0.811039,0.862745,0.360157,0.972622,0.757876,0.508677,0.885041
75%,0.775182,0.815148,0.870481,0.366317,0.974008,0.761506,0.517494,0.888348
max,0.785128,0.824404,0.898534,0.380834,0.978365,0.771319,0.531229,0.894162


In [20]:
res_all.transpose().to_csv('./fig/feature_d_randomforest.csv', index=False)

### Xgboost

In [21]:
# For Xgb
xgb_a = ['sbp', 'AVI', 'BFP', 'dbp', 'whr', 'CUNBAE', 'bmi']
xgb_b = ['sm_total', 'w087', 'eat5_0', 'dr_soju', 'w026', 'w018']
xgb_c = ['AVI', 'sbp', 'BFP', 'dbp', 'whr', 'CUNBAE', 'bmi']
xgb_d = ['BPWC_add', 'WC', 'bWC', 'whr', 'BP', 'regrp18']

args = {#'learning_rate': 0.05, 
        #'max_depth': 4, 
        #'n_estimators': 200, 
        #'subsample': 0.6,
        'random_state' : 100}

prob_all = None
res_all = pd.DataFrame()

for i, s in enumerate(seeds) :
    print('sampling: ', i+1)
    X_train_xgb, y_train, X_valid_xgb, y_valid, X_test_xgb, y_test, _, _, valid_info, test_info, beta, tau, _ = get_mets_data(
        one_hot=True, 
        resampling = False,
        feature_set = None, # 7 = 'anthropometric', 8 = 'lifestyle', 9 = 'anthropometric+lifestyle', 10 = 'anthropometric+lifestyle+synthesis'
        set_feature = xgb_d,
        add_feature= False,
        is_tabnet = False,
        is_eval = False,
        seed = s
    )
    
    xgb = XGBClassifier(**args)
    xgb.fit(X_train_xgb, y_train)

    prob = xgb.predict_proba(X_valid_xgb)
        
    res = get_metric(xgb.predict_proba(X_valid_xgb), y_valid, 0.5)
    res = pd.DataFrame.from_dict(res, orient='index')
    res_all = pd.concat([res_all,res], axis=1)

sampling:  1
sampling:  2
sampling:  3
sampling:  4
sampling:  5
sampling:  6
sampling:  7
sampling:  8
sampling:  9
sampling:  10
sampling:  11
sampling:  12
sampling:  13
sampling:  14
sampling:  15
sampling:  16
sampling:  17
sampling:  18
sampling:  19
sampling:  20
sampling:  21
sampling:  22
sampling:  23
sampling:  24
sampling:  25
sampling:  26
sampling:  27
sampling:  28
sampling:  29
sampling:  30


In [22]:
res_all_xgb = res_all
res_all_xgb.transpose().describe()

Unnamed: 0,acc,bac,recall,ppv,npv,sepecificity,f1,auc
count,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0
mean,0.773634,0.808695,0.856961,0.362039,0.971026,0.760428,0.508937,0.883715
std,0.005226,0.005784,0.011325,0.010229,0.002431,0.00611,0.010442,0.004686
min,0.765867,0.797865,0.835414,0.34217,0.966208,0.750728,0.487583,0.874141
25%,0.770287,0.804495,0.848738,0.355457,0.969297,0.755533,0.500819,0.880732
50%,0.772261,0.806724,0.855061,0.359409,0.970863,0.760052,0.508442,0.883603
75%,0.776287,0.813427,0.863973,0.36764,0.972902,0.763606,0.515825,0.886578
max,0.788759,0.819584,0.885613,0.38408,0.977074,0.777086,0.529311,0.892713


In [23]:
res_all.transpose().to_csv('./fig/feature_d_xgboost.csv', index=False)

### TabNet

In [24]:
def train_model(model, X_train, y_train) :
    model.fit(
    X_train=X_train.values[:], y_train=y_train.values[:],
    #eval_set=[(X_train.values[:], y_train.values[:]), (X_valid.values[:], y_valid.values[:])],
    #eval_name=['train', 'valid'],
    eval_metric=['auc'],
    max_epochs=100, 
    patience=50,
    batch_size=1024, 
    virtual_batch_size=128,
    num_workers=0,
    weights=1,
    drop_last=False)
    return model

def train_with_valid(model, X_train, y_train, X_valid, y_valid) :
    model.fit(
    X_train=X_train.values[:], y_train=y_train.values[:],
    eval_set=[(X_train.values[:], y_train.values[:]), (X_valid.values[:], y_valid.values[:])],
    eval_name=['train', 'valid'],
    eval_metric=['auc'],
    max_epochs=100, 
    patience=50,
    batch_size=1024, 
    virtual_batch_size=128,
    num_workers=0,
    weights=1,
    drop_last=False)
    return model

In [25]:
all_ft

['WC',
 'BP',
 'BPWC_add',
 'BPWC_mul',
 'BPWC_dif',
 'bWC',
 'whr',
 'CUNBAE',
 'clbe',
 'G1_INT',
 'ss18',
 'fate',
 'smoke_merge_0',
 'regrp15',
 'regrp18',
 'regrp19',
 'regrp38']

In [36]:
# For TabNet
tn_a = ['waist', 'sbp', 'BRI', 'dbp', 'AVI', 'BFP', 'bmi', 'sex']
tn_b = ['smoke_merge', 'w014', 'exer_merge', 'w095', 'dr_fruwine', 'w099', 'w097', 'w074']
tn_c = ['AVI', 'sbp', 'dbp', 'BRI']
tn_d = ['bWC', 'BP', 'regrp38', 'regrp19']
all_tn = ['WC','BP','BPWC_add','BPWC_mul','BPWC_dif',
          'bWC','whr','CUNBAE','clbe','G1_INT','ss18','fate','smoke_merge',
          'regrp15','regrp18','regrp19','regrp38']
args = {#'gamma': 0.7, 'momentum': 0.03, 'n_independent': 3, 'n_shared': 4, 'n_steps': 2, 
    'seed': 100}

prob_all = None
res_all = pd.DataFrame()

for i, s in enumerate(seeds) :
    print('sampling: ', i+1)
    X_train_tn, y_train, X_valid_tn, y_valid, X_test_tn, y_test, cat_idxs, cat_dims, valid_info, test_info, beta, tau, cat_col = get_mets_data(
        one_hot=False, 
        resampling = False,
        feature_set = None, # 7 = 'anthropometric', 8 = 'lifestyle', 9 = 'anthropometric+lifestyle', 10 = 'anthropometric+lifestyle+synthesis'
        set_feature = tn_d,
        add_feature= False,
        is_tabnet = True,
        is_eval = False,
        seed = s
    )
    
    estimator = TabNetClassifier(cat_idxs=cat_idxs, cat_dims=cat_dims, **args)
    tn = train_with_valid(estimator, X_train_tn, y_train, X_valid_tn, y_valid)
    
    prob = tn.predict_proba(X_valid_tn.values[:])
    
    res = get_metric(prob, y_valid, 0.5)
    res = pd.DataFrame.from_dict(res, orient='index')
    res_all = pd.concat([res_all,res], axis=1)

sampling:  1




epoch 0  | loss: 0.69133 | train_auc: 0.65511 | valid_auc: 0.65726 |  0:00:00s
epoch 1  | loss: 0.48603 | train_auc: 0.76084 | valid_auc: 0.76587 |  0:00:01s
epoch 2  | loss: 0.45039 | train_auc: 0.80735 | valid_auc: 0.80332 |  0:00:01s
epoch 3  | loss: 0.43453 | train_auc: 0.81288 | valid_auc: 0.81305 |  0:00:02s
epoch 4  | loss: 0.42724 | train_auc: 0.86337 | valid_auc: 0.86604 |  0:00:02s
epoch 5  | loss: 0.42434 | train_auc: 0.87254 | valid_auc: 0.87701 |  0:00:03s
epoch 6  | loss: 0.42927 | train_auc: 0.85049 | valid_auc: 0.85186 |  0:00:03s
epoch 7  | loss: 0.42642 | train_auc: 0.87092 | valid_auc: 0.87729 |  0:00:04s
epoch 8  | loss: 0.42199 | train_auc: 0.87737 | valid_auc: 0.88393 |  0:00:04s
epoch 9  | loss: 0.4198  | train_auc: 0.88023 | valid_auc: 0.88517 |  0:00:05s
epoch 10 | loss: 0.41929 | train_auc: 0.87903 | valid_auc: 0.88444 |  0:00:06s
epoch 11 | loss: 0.42083 | train_auc: 0.88359 | valid_auc: 0.88698 |  0:00:06s
epoch 12 | loss: 0.41411 | train_auc: 0.88575 | vali



epoch 0  | loss: 0.69993 | train_auc: 0.63875 | valid_auc: 0.63336 |  0:00:00s
epoch 1  | loss: 0.48201 | train_auc: 0.7576  | valid_auc: 0.75736 |  0:00:01s
epoch 2  | loss: 0.44222 | train_auc: 0.81477 | valid_auc: 0.81803 |  0:00:01s
epoch 3  | loss: 0.44032 | train_auc: 0.84583 | valid_auc: 0.85106 |  0:00:02s
epoch 4  | loss: 0.43043 | train_auc: 0.85717 | valid_auc: 0.86323 |  0:00:02s
epoch 5  | loss: 0.43156 | train_auc: 0.86371 | valid_auc: 0.86776 |  0:00:03s
epoch 6  | loss: 0.43077 | train_auc: 0.87053 | valid_auc: 0.87466 |  0:00:03s
epoch 7  | loss: 0.43174 | train_auc: 0.87489 | valid_auc: 0.87728 |  0:00:04s
epoch 8  | loss: 0.4213  | train_auc: 0.87591 | valid_auc: 0.87912 |  0:00:05s
epoch 9  | loss: 0.41736 | train_auc: 0.87662 | valid_auc: 0.8817  |  0:00:05s
epoch 10 | loss: 0.42228 | train_auc: 0.88003 | valid_auc: 0.88417 |  0:00:06s
epoch 11 | loss: 0.4226  | train_auc: 0.87829 | valid_auc: 0.88506 |  0:00:06s
epoch 12 | loss: 0.42747 | train_auc: 0.8824  | vali



epoch 0  | loss: 0.694   | train_auc: 0.68727 | valid_auc: 0.66727 |  0:00:00s
epoch 1  | loss: 0.47163 | train_auc: 0.82752 | valid_auc: 0.8257  |  0:00:01s
epoch 2  | loss: 0.45073 | train_auc: 0.83841 | valid_auc: 0.83935 |  0:00:01s
epoch 3  | loss: 0.43685 | train_auc: 0.8606  | valid_auc: 0.8631  |  0:00:02s
epoch 4  | loss: 0.43015 | train_auc: 0.86253 | valid_auc: 0.86418 |  0:00:02s
epoch 5  | loss: 0.42853 | train_auc: 0.87372 | valid_auc: 0.87124 |  0:00:03s
epoch 6  | loss: 0.42493 | train_auc: 0.87728 | valid_auc: 0.87605 |  0:00:03s
epoch 7  | loss: 0.4253  | train_auc: 0.87331 | valid_auc: 0.87615 |  0:00:04s
epoch 8  | loss: 0.42645 | train_auc: 0.8778  | valid_auc: 0.87661 |  0:00:04s
epoch 9  | loss: 0.41454 | train_auc: 0.88469 | valid_auc: 0.883   |  0:00:05s
epoch 10 | loss: 0.42115 | train_auc: 0.88302 | valid_auc: 0.88053 |  0:00:06s
epoch 11 | loss: 0.42164 | train_auc: 0.88663 | valid_auc: 0.88466 |  0:00:06s
epoch 12 | loss: 0.42151 | train_auc: 0.88792 | vali



epoch 0  | loss: 0.70154 | train_auc: 0.61814 | valid_auc: 0.60787 |  0:00:00s
epoch 1  | loss: 0.48237 | train_auc: 0.81979 | valid_auc: 0.82177 |  0:00:01s
epoch 2  | loss: 0.44074 | train_auc: 0.86191 | valid_auc: 0.8658  |  0:00:01s
epoch 3  | loss: 0.43562 | train_auc: 0.82516 | valid_auc: 0.8252  |  0:00:02s
epoch 4  | loss: 0.42249 | train_auc: 0.84859 | valid_auc: 0.85098 |  0:00:02s
epoch 5  | loss: 0.42302 | train_auc: 0.86144 | valid_auc: 0.86276 |  0:00:03s
epoch 6  | loss: 0.42093 | train_auc: 0.87302 | valid_auc: 0.87788 |  0:00:03s
epoch 7  | loss: 0.42189 | train_auc: 0.87848 | valid_auc: 0.88328 |  0:00:04s
epoch 8  | loss: 0.42298 | train_auc: 0.88055 | valid_auc: 0.88495 |  0:00:04s
epoch 9  | loss: 0.40809 | train_auc: 0.88073 | valid_auc: 0.88517 |  0:00:05s
epoch 10 | loss: 0.41775 | train_auc: 0.88519 | valid_auc: 0.88906 |  0:00:06s
epoch 11 | loss: 0.42146 | train_auc: 0.8857  | valid_auc: 0.88912 |  0:00:06s
epoch 12 | loss: 0.42435 | train_auc: 0.88605 | vali



epoch 0  | loss: 0.69867 | train_auc: 0.67839 | valid_auc: 0.6704  |  0:00:00s
epoch 1  | loss: 0.48597 | train_auc: 0.81083 | valid_auc: 0.81124 |  0:00:01s
epoch 2  | loss: 0.43274 | train_auc: 0.84455 | valid_auc: 0.84382 |  0:00:01s
epoch 3  | loss: 0.43607 | train_auc: 0.85189 | valid_auc: 0.85621 |  0:00:02s
epoch 4  | loss: 0.4323  | train_auc: 0.86892 | valid_auc: 0.86988 |  0:00:02s
epoch 5  | loss: 0.41901 | train_auc: 0.87239 | valid_auc: 0.87269 |  0:00:03s
epoch 6  | loss: 0.42195 | train_auc: 0.87863 | valid_auc: 0.87963 |  0:00:03s
epoch 7  | loss: 0.41966 | train_auc: 0.87853 | valid_auc: 0.87928 |  0:00:04s
epoch 8  | loss: 0.41764 | train_auc: 0.88129 | valid_auc: 0.87932 |  0:00:04s
epoch 9  | loss: 0.42011 | train_auc: 0.88134 | valid_auc: 0.87955 |  0:00:05s
epoch 10 | loss: 0.40845 | train_auc: 0.8839  | valid_auc: 0.88249 |  0:00:06s
epoch 11 | loss: 0.41881 | train_auc: 0.88598 | valid_auc: 0.88322 |  0:00:06s
epoch 12 | loss: 0.40676 | train_auc: 0.88792 | vali



epoch 0  | loss: 0.69116 | train_auc: 0.71351 | valid_auc: 0.71606 |  0:00:00s
epoch 1  | loss: 0.47802 | train_auc: 0.83633 | valid_auc: 0.83192 |  0:00:01s
epoch 2  | loss: 0.44324 | train_auc: 0.85137 | valid_auc: 0.8498  |  0:00:01s
epoch 3  | loss: 0.43142 | train_auc: 0.84367 | valid_auc: 0.83952 |  0:00:02s
epoch 4  | loss: 0.42303 | train_auc: 0.85634 | valid_auc: 0.85704 |  0:00:02s
epoch 5  | loss: 0.4265  | train_auc: 0.87008 | valid_auc: 0.86979 |  0:00:03s
epoch 6  | loss: 0.41894 | train_auc: 0.8775  | valid_auc: 0.87794 |  0:00:03s
epoch 7  | loss: 0.42281 | train_auc: 0.87779 | valid_auc: 0.87929 |  0:00:04s
epoch 8  | loss: 0.43086 | train_auc: 0.87888 | valid_auc: 0.88063 |  0:00:04s
epoch 9  | loss: 0.41809 | train_auc: 0.88098 | valid_auc: 0.88212 |  0:00:05s
epoch 10 | loss: 0.41261 | train_auc: 0.88255 | valid_auc: 0.8839  |  0:00:06s
epoch 11 | loss: 0.42274 | train_auc: 0.88547 | valid_auc: 0.88784 |  0:00:06s
epoch 12 | loss: 0.4264  | train_auc: 0.88619 | vali



epoch 0  | loss: 0.68788 | train_auc: 0.64793 | valid_auc: 0.64305 |  0:00:00s
epoch 1  | loss: 0.46198 | train_auc: 0.81992 | valid_auc: 0.81704 |  0:00:01s
epoch 2  | loss: 0.43615 | train_auc: 0.84619 | valid_auc: 0.84175 |  0:00:01s
epoch 3  | loss: 0.43112 | train_auc: 0.86566 | valid_auc: 0.86815 |  0:00:02s
epoch 4  | loss: 0.4212  | train_auc: 0.86848 | valid_auc: 0.86822 |  0:00:02s
epoch 5  | loss: 0.42811 | train_auc: 0.87453 | valid_auc: 0.87673 |  0:00:03s
epoch 6  | loss: 0.41386 | train_auc: 0.87962 | valid_auc: 0.88299 |  0:00:03s
epoch 7  | loss: 0.41401 | train_auc: 0.88054 | valid_auc: 0.88358 |  0:00:04s
epoch 8  | loss: 0.42599 | train_auc: 0.88439 | valid_auc: 0.88792 |  0:00:04s
epoch 9  | loss: 0.40348 | train_auc: 0.88692 | valid_auc: 0.8899  |  0:00:05s
epoch 10 | loss: 0.40669 | train_auc: 0.88919 | valid_auc: 0.89253 |  0:00:06s
epoch 11 | loss: 0.41186 | train_auc: 0.88855 | valid_auc: 0.89148 |  0:00:06s
epoch 12 | loss: 0.41387 | train_auc: 0.88934 | vali



sampling:  8




epoch 0  | loss: 0.68857 | train_auc: 0.56324 | valid_auc: 0.54878 |  0:00:00s
epoch 1  | loss: 0.47614 | train_auc: 0.773   | valid_auc: 0.78704 |  0:00:01s
epoch 2  | loss: 0.44058 | train_auc: 0.82724 | valid_auc: 0.82958 |  0:00:01s
epoch 3  | loss: 0.43758 | train_auc: 0.85341 | valid_auc: 0.86078 |  0:00:02s
epoch 4  | loss: 0.43203 | train_auc: 0.8597  | valid_auc: 0.8666  |  0:00:02s
epoch 5  | loss: 0.42071 | train_auc: 0.86121 | valid_auc: 0.86571 |  0:00:03s
epoch 6  | loss: 0.43569 | train_auc: 0.86675 | valid_auc: 0.8746  |  0:00:03s
epoch 7  | loss: 0.42799 | train_auc: 0.87529 | valid_auc: 0.88515 |  0:00:04s
epoch 8  | loss: 0.42777 | train_auc: 0.87712 | valid_auc: 0.88507 |  0:00:04s
epoch 9  | loss: 0.43057 | train_auc: 0.87634 | valid_auc: 0.8852  |  0:00:05s
epoch 10 | loss: 0.42506 | train_auc: 0.88092 | valid_auc: 0.89055 |  0:00:06s
epoch 11 | loss: 0.427   | train_auc: 0.88058 | valid_auc: 0.88857 |  0:00:06s
epoch 12 | loss: 0.42328 | train_auc: 0.88171 | vali



epoch 0  | loss: 0.69785 | train_auc: 0.54804 | valid_auc: 0.56854 |  0:00:00s
epoch 1  | loss: 0.46273 | train_auc: 0.6042  | valid_auc: 0.63979 |  0:00:01s
epoch 2  | loss: 0.4367  | train_auc: 0.7574  | valid_auc: 0.76828 |  0:00:01s
epoch 3  | loss: 0.43552 | train_auc: 0.80032 | valid_auc: 0.80065 |  0:00:02s
epoch 4  | loss: 0.42726 | train_auc: 0.85825 | valid_auc: 0.85976 |  0:00:02s
epoch 5  | loss: 0.42726 | train_auc: 0.86496 | valid_auc: 0.85533 |  0:00:03s
epoch 6  | loss: 0.41752 | train_auc: 0.87709 | valid_auc: 0.87032 |  0:00:03s
epoch 7  | loss: 0.41847 | train_auc: 0.87796 | valid_auc: 0.872   |  0:00:04s
epoch 8  | loss: 0.42417 | train_auc: 0.88294 | valid_auc: 0.87804 |  0:00:05s
epoch 9  | loss: 0.41934 | train_auc: 0.88272 | valid_auc: 0.87871 |  0:00:05s
epoch 10 | loss: 0.42933 | train_auc: 0.88422 | valid_auc: 0.8805  |  0:00:06s
epoch 11 | loss: 0.41939 | train_auc: 0.88637 | valid_auc: 0.87979 |  0:00:06s
epoch 12 | loss: 0.42126 | train_auc: 0.88811 | vali



epoch 0  | loss: 0.69368 | train_auc: 0.55542 | valid_auc: 0.55218 |  0:00:00s
epoch 1  | loss: 0.47135 | train_auc: 0.7855  | valid_auc: 0.78026 |  0:00:01s
epoch 2  | loss: 0.43832 | train_auc: 0.84068 | valid_auc: 0.83877 |  0:00:01s
epoch 3  | loss: 0.42936 | train_auc: 0.85798 | valid_auc: 0.85908 |  0:00:02s
epoch 4  | loss: 0.42912 | train_auc: 0.86132 | valid_auc: 0.86244 |  0:00:02s
epoch 5  | loss: 0.42774 | train_auc: 0.86323 | valid_auc: 0.86645 |  0:00:03s
epoch 6  | loss: 0.42412 | train_auc: 0.87358 | valid_auc: 0.8745  |  0:00:03s
epoch 7  | loss: 0.42041 | train_auc: 0.87666 | valid_auc: 0.87578 |  0:00:04s
epoch 8  | loss: 0.41799 | train_auc: 0.88145 | valid_auc: 0.8804  |  0:00:05s
epoch 9  | loss: 0.42008 | train_auc: 0.88135 | valid_auc: 0.88041 |  0:00:05s
epoch 10 | loss: 0.42249 | train_auc: 0.88394 | valid_auc: 0.88257 |  0:00:06s
epoch 11 | loss: 0.40797 | train_auc: 0.88549 | valid_auc: 0.88234 |  0:00:06s
epoch 12 | loss: 0.41461 | train_auc: 0.88743 | vali



epoch 0  | loss: 0.72443 | train_auc: 0.57727 | valid_auc: 0.5737  |  0:00:00s
epoch 1  | loss: 0.49279 | train_auc: 0.7396  | valid_auc: 0.7427  |  0:00:01s
epoch 2  | loss: 0.46417 | train_auc: 0.82071 | valid_auc: 0.82126 |  0:00:01s
epoch 3  | loss: 0.4411  | train_auc: 0.84404 | valid_auc: 0.84779 |  0:00:02s
epoch 4  | loss: 0.43254 | train_auc: 0.863   | valid_auc: 0.86874 |  0:00:02s
epoch 5  | loss: 0.42497 | train_auc: 0.86388 | valid_auc: 0.8671  |  0:00:03s
epoch 6  | loss: 0.42842 | train_auc: 0.87503 | valid_auc: 0.87962 |  0:00:03s
epoch 7  | loss: 0.42589 | train_auc: 0.87595 | valid_auc: 0.8781  |  0:00:04s
epoch 8  | loss: 0.42662 | train_auc: 0.8785  | valid_auc: 0.88089 |  0:00:05s
epoch 9  | loss: 0.42462 | train_auc: 0.88146 | valid_auc: 0.88433 |  0:00:05s
epoch 10 | loss: 0.41718 | train_auc: 0.8804  | valid_auc: 0.88338 |  0:00:06s
epoch 11 | loss: 0.42411 | train_auc: 0.88345 | valid_auc: 0.88817 |  0:00:06s
epoch 12 | loss: 0.42722 | train_auc: 0.88227 | vali



epoch 0  | loss: 0.70806 | train_auc: 0.52955 | valid_auc: 0.52022 |  0:00:00s
epoch 1  | loss: 0.4932  | train_auc: 0.64513 | valid_auc: 0.63505 |  0:00:01s
epoch 2  | loss: 0.45759 | train_auc: 0.6919  | valid_auc: 0.67986 |  0:00:01s
epoch 3  | loss: 0.44817 | train_auc: 0.78518 | valid_auc: 0.77664 |  0:00:02s
epoch 4  | loss: 0.4376  | train_auc: 0.85176 | valid_auc: 0.85241 |  0:00:02s
epoch 5  | loss: 0.4388  | train_auc: 0.86331 | valid_auc: 0.86577 |  0:00:03s
epoch 6  | loss: 0.43702 | train_auc: 0.86437 | valid_auc: 0.86585 |  0:00:03s
epoch 7  | loss: 0.43722 | train_auc: 0.86646 | valid_auc: 0.8698  |  0:00:04s
epoch 8  | loss: 0.4365  | train_auc: 0.8751  | valid_auc: 0.88191 |  0:00:05s
epoch 9  | loss: 0.42797 | train_auc: 0.87456 | valid_auc: 0.88009 |  0:00:05s
epoch 10 | loss: 0.43419 | train_auc: 0.88078 | valid_auc: 0.88669 |  0:00:06s
epoch 11 | loss: 0.42931 | train_auc: 0.88226 | valid_auc: 0.88836 |  0:00:06s
epoch 12 | loss: 0.42283 | train_auc: 0.88109 | vali



epoch 0  | loss: 0.70757 | train_auc: 0.59026 | valid_auc: 0.58452 |  0:00:00s
epoch 1  | loss: 0.46325 | train_auc: 0.79487 | valid_auc: 0.78511 |  0:00:01s
epoch 2  | loss: 0.42753 | train_auc: 0.81281 | valid_auc: 0.80848 |  0:00:01s
epoch 3  | loss: 0.42228 | train_auc: 0.85798 | valid_auc: 0.85172 |  0:00:02s
epoch 4  | loss: 0.42155 | train_auc: 0.87008 | valid_auc: 0.86482 |  0:00:02s
epoch 5  | loss: 0.43134 | train_auc: 0.87311 | valid_auc: 0.86789 |  0:00:03s
epoch 6  | loss: 0.42688 | train_auc: 0.87799 | valid_auc: 0.87151 |  0:00:03s
epoch 7  | loss: 0.42343 | train_auc: 0.87735 | valid_auc: 0.87226 |  0:00:04s
epoch 8  | loss: 0.41525 | train_auc: 0.88111 | valid_auc: 0.87686 |  0:00:04s
epoch 9  | loss: 0.41695 | train_auc: 0.88388 | valid_auc: 0.87994 |  0:00:05s
epoch 10 | loss: 0.41588 | train_auc: 0.88429 | valid_auc: 0.87952 |  0:00:06s
epoch 11 | loss: 0.42285 | train_auc: 0.88593 | valid_auc: 0.88097 |  0:00:06s
epoch 12 | loss: 0.41699 | train_auc: 0.88736 | vali



epoch 0  | loss: 0.69834 | train_auc: 0.6212  | valid_auc: 0.61375 |  0:00:00s
epoch 1  | loss: 0.47386 | train_auc: 0.71817 | valid_auc: 0.71426 |  0:00:01s
epoch 2  | loss: 0.45037 | train_auc: 0.79375 | valid_auc: 0.78612 |  0:00:01s
epoch 3  | loss: 0.44604 | train_auc: 0.83069 | valid_auc: 0.82761 |  0:00:02s
epoch 4  | loss: 0.42748 | train_auc: 0.86265 | valid_auc: 0.86334 |  0:00:02s
epoch 5  | loss: 0.43152 | train_auc: 0.8667  | valid_auc: 0.86508 |  0:00:03s
epoch 6  | loss: 0.43177 | train_auc: 0.86267 | valid_auc: 0.86015 |  0:00:03s
epoch 7  | loss: 0.43121 | train_auc: 0.8723  | valid_auc: 0.86962 |  0:00:04s
epoch 8  | loss: 0.42977 | train_auc: 0.87651 | valid_auc: 0.87623 |  0:00:04s
epoch 9  | loss: 0.42653 | train_auc: 0.87293 | valid_auc: 0.87139 |  0:00:05s
epoch 10 | loss: 0.42545 | train_auc: 0.88051 | valid_auc: 0.87947 |  0:00:06s
epoch 11 | loss: 0.42359 | train_auc: 0.88478 | valid_auc: 0.88451 |  0:00:06s
epoch 12 | loss: 0.42911 | train_auc: 0.88604 | vali



epoch 0  | loss: 0.67918 | train_auc: 0.59221 | valid_auc: 0.59959 |  0:00:00s
epoch 1  | loss: 0.46989 | train_auc: 0.72658 | valid_auc: 0.72326 |  0:00:01s
epoch 2  | loss: 0.43928 | train_auc: 0.81637 | valid_auc: 0.83263 |  0:00:01s
epoch 3  | loss: 0.42924 | train_auc: 0.84549 | valid_auc: 0.85336 |  0:00:02s
epoch 4  | loss: 0.43191 | train_auc: 0.85771 | valid_auc: 0.86641 |  0:00:02s
epoch 5  | loss: 0.42267 | train_auc: 0.86869 | valid_auc: 0.87651 |  0:00:03s
epoch 6  | loss: 0.42261 | train_auc: 0.86061 | valid_auc: 0.86869 |  0:00:03s
epoch 7  | loss: 0.42452 | train_auc: 0.87102 | valid_auc: 0.88054 |  0:00:04s
epoch 8  | loss: 0.42213 | train_auc: 0.87268 | valid_auc: 0.88267 |  0:00:05s
epoch 9  | loss: 0.41894 | train_auc: 0.87333 | valid_auc: 0.88286 |  0:00:05s
epoch 10 | loss: 0.41747 | train_auc: 0.8798  | valid_auc: 0.88966 |  0:00:06s
epoch 11 | loss: 0.41828 | train_auc: 0.87981 | valid_auc: 0.88872 |  0:00:06s
epoch 12 | loss: 0.42494 | train_auc: 0.88244 | vali



epoch 0  | loss: 0.6788  | train_auc: 0.62982 | valid_auc: 0.61903 |  0:00:00s
epoch 1  | loss: 0.48966 | train_auc: 0.74415 | valid_auc: 0.73522 |  0:00:01s
epoch 2  | loss: 0.44845 | train_auc: 0.8299  | valid_auc: 0.83322 |  0:00:01s
epoch 3  | loss: 0.43053 | train_auc: 0.84335 | valid_auc: 0.84352 |  0:00:02s
epoch 4  | loss: 0.4344  | train_auc: 0.85803 | valid_auc: 0.86351 |  0:00:02s
epoch 5  | loss: 0.42565 | train_auc: 0.86911 | valid_auc: 0.87303 |  0:00:03s
epoch 6  | loss: 0.42364 | train_auc: 0.87205 | valid_auc: 0.87541 |  0:00:04s
epoch 7  | loss: 0.43016 | train_auc: 0.87247 | valid_auc: 0.87658 |  0:00:04s
epoch 8  | loss: 0.42641 | train_auc: 0.87528 | valid_auc: 0.88015 |  0:00:05s
epoch 9  | loss: 0.42127 | train_auc: 0.87905 | valid_auc: 0.88693 |  0:00:05s
epoch 10 | loss: 0.42086 | train_auc: 0.8767  | valid_auc: 0.881   |  0:00:06s
epoch 11 | loss: 0.4246  | train_auc: 0.88173 | valid_auc: 0.88853 |  0:00:06s
epoch 12 | loss: 0.42826 | train_auc: 0.88382 | vali



epoch 0  | loss: 0.69136 | train_auc: 0.68744 | valid_auc: 0.68325 |  0:00:00s
epoch 1  | loss: 0.48    | train_auc: 0.75525 | valid_auc: 0.75734 |  0:00:01s
epoch 2  | loss: 0.44969 | train_auc: 0.85369 | valid_auc: 0.8515  |  0:00:01s
epoch 3  | loss: 0.4301  | train_auc: 0.86577 | valid_auc: 0.86552 |  0:00:02s
epoch 4  | loss: 0.41799 | train_auc: 0.86327 | valid_auc: 0.86228 |  0:00:02s
epoch 5  | loss: 0.42811 | train_auc: 0.86894 | valid_auc: 0.86769 |  0:00:03s
epoch 6  | loss: 0.42044 | train_auc: 0.87129 | valid_auc: 0.86974 |  0:00:03s
epoch 7  | loss: 0.41726 | train_auc: 0.87823 | valid_auc: 0.87629 |  0:00:04s
epoch 8  | loss: 0.41718 | train_auc: 0.88073 | valid_auc: 0.87915 |  0:00:04s
epoch 9  | loss: 0.41425 | train_auc: 0.88045 | valid_auc: 0.87871 |  0:00:05s
epoch 10 | loss: 0.4221  | train_auc: 0.88468 | valid_auc: 0.88284 |  0:00:06s
epoch 11 | loss: 0.41885 | train_auc: 0.8851  | valid_auc: 0.8823  |  0:00:06s
epoch 12 | loss: 0.42143 | train_auc: 0.88721 | vali



epoch 0  | loss: 0.69175 | train_auc: 0.75782 | valid_auc: 0.76668 |  0:00:00s
epoch 1  | loss: 0.475   | train_auc: 0.82044 | valid_auc: 0.82025 |  0:00:01s
epoch 2  | loss: 0.43522 | train_auc: 0.86558 | valid_auc: 0.86628 |  0:00:01s
epoch 3  | loss: 0.4253  | train_auc: 0.85903 | valid_auc: 0.85996 |  0:00:02s
epoch 4  | loss: 0.42978 | train_auc: 0.87276 | valid_auc: 0.87244 |  0:00:02s
epoch 5  | loss: 0.43081 | train_auc: 0.86907 | valid_auc: 0.86889 |  0:00:03s
epoch 6  | loss: 0.42648 | train_auc: 0.87972 | valid_auc: 0.87924 |  0:00:03s
epoch 7  | loss: 0.42403 | train_auc: 0.87707 | valid_auc: 0.87419 |  0:00:04s
epoch 8  | loss: 0.41817 | train_auc: 0.88087 | valid_auc: 0.87881 |  0:00:05s
epoch 9  | loss: 0.4194  | train_auc: 0.88212 | valid_auc: 0.87884 |  0:00:05s
epoch 10 | loss: 0.42148 | train_auc: 0.88485 | valid_auc: 0.8828  |  0:00:06s
epoch 11 | loss: 0.41712 | train_auc: 0.88486 | valid_auc: 0.88157 |  0:00:06s
epoch 12 | loss: 0.42685 | train_auc: 0.88599 | vali



epoch 0  | loss: 0.68123 | train_auc: 0.63924 | valid_auc: 0.62702 |  0:00:00s
epoch 1  | loss: 0.46974 | train_auc: 0.80721 | valid_auc: 0.79093 |  0:00:01s
epoch 2  | loss: 0.44599 | train_auc: 0.80958 | valid_auc: 0.80016 |  0:00:01s
epoch 3  | loss: 0.43908 | train_auc: 0.83966 | valid_auc: 0.83204 |  0:00:02s
epoch 4  | loss: 0.43377 | train_auc: 0.85646 | valid_auc: 0.8433  |  0:00:02s
epoch 5  | loss: 0.42594 | train_auc: 0.86345 | valid_auc: 0.85271 |  0:00:03s
epoch 6  | loss: 0.42685 | train_auc: 0.87486 | valid_auc: 0.86489 |  0:00:03s
epoch 7  | loss: 0.41554 | train_auc: 0.88192 | valid_auc: 0.87303 |  0:00:04s
epoch 8  | loss: 0.42934 | train_auc: 0.88298 | valid_auc: 0.87316 |  0:00:04s
epoch 9  | loss: 0.42121 | train_auc: 0.88384 | valid_auc: 0.87448 |  0:00:05s
epoch 10 | loss: 0.41288 | train_auc: 0.88495 | valid_auc: 0.87525 |  0:00:06s
epoch 11 | loss: 0.41438 | train_auc: 0.88968 | valid_auc: 0.88037 |  0:00:06s
epoch 12 | loss: 0.4183  | train_auc: 0.88712 | vali



epoch 0  | loss: 0.69442 | train_auc: 0.74544 | valid_auc: 0.7314  |  0:00:00s
epoch 1  | loss: 0.46872 | train_auc: 0.8442  | valid_auc: 0.83197 |  0:00:01s
epoch 2  | loss: 0.43395 | train_auc: 0.82151 | valid_auc: 0.8133  |  0:00:01s
epoch 3  | loss: 0.43143 | train_auc: 0.85409 | valid_auc: 0.84331 |  0:00:02s
epoch 4  | loss: 0.42429 | train_auc: 0.86749 | valid_auc: 0.85917 |  0:00:02s
epoch 5  | loss: 0.42188 | train_auc: 0.87267 | valid_auc: 0.86339 |  0:00:03s
epoch 6  | loss: 0.419   | train_auc: 0.87429 | valid_auc: 0.86273 |  0:00:03s
epoch 7  | loss: 0.42038 | train_auc: 0.8763  | valid_auc: 0.8654  |  0:00:04s
epoch 8  | loss: 0.41839 | train_auc: 0.88008 | valid_auc: 0.86903 |  0:00:05s
epoch 9  | loss: 0.41482 | train_auc: 0.88193 | valid_auc: 0.8705  |  0:00:05s
epoch 10 | loss: 0.41479 | train_auc: 0.88077 | valid_auc: 0.87132 |  0:00:06s
epoch 11 | loss: 0.41333 | train_auc: 0.88557 | valid_auc: 0.87602 |  0:00:06s
epoch 12 | loss: 0.40635 | train_auc: 0.8874  | vali



epoch 0  | loss: 0.69419 | train_auc: 0.48899 | valid_auc: 0.50816 |  0:00:00s
epoch 1  | loss: 0.49249 | train_auc: 0.49932 | valid_auc: 0.48605 |  0:00:01s
epoch 2  | loss: 0.45937 | train_auc: 0.75382 | valid_auc: 0.76285 |  0:00:01s
epoch 3  | loss: 0.43828 | train_auc: 0.81499 | valid_auc: 0.81414 |  0:00:02s
epoch 4  | loss: 0.43824 | train_auc: 0.85022 | valid_auc: 0.86162 |  0:00:02s
epoch 5  | loss: 0.4378  | train_auc: 0.85109 | valid_auc: 0.85992 |  0:00:03s
epoch 6  | loss: 0.43185 | train_auc: 0.86883 | valid_auc: 0.87834 |  0:00:03s
epoch 7  | loss: 0.4263  | train_auc: 0.86785 | valid_auc: 0.87884 |  0:00:04s
epoch 8  | loss: 0.4292  | train_auc: 0.87644 | valid_auc: 0.88597 |  0:00:04s
epoch 9  | loss: 0.42786 | train_auc: 0.87633 | valid_auc: 0.88528 |  0:00:05s
epoch 10 | loss: 0.4268  | train_auc: 0.88124 | valid_auc: 0.89074 |  0:00:06s
epoch 11 | loss: 0.41912 | train_auc: 0.88148 | valid_auc: 0.89208 |  0:00:06s
epoch 12 | loss: 0.42457 | train_auc: 0.88248 | vali



epoch 0  | loss: 0.70509 | train_auc: 0.6087  | valid_auc: 0.62351 |  0:00:00s
epoch 1  | loss: 0.4933  | train_auc: 0.70469 | valid_auc: 0.71153 |  0:00:01s
epoch 2  | loss: 0.45458 | train_auc: 0.79012 | valid_auc: 0.80207 |  0:00:01s
epoch 3  | loss: 0.43705 | train_auc: 0.84009 | valid_auc: 0.8516  |  0:00:02s
epoch 4  | loss: 0.43193 | train_auc: 0.84345 | valid_auc: 0.85797 |  0:00:02s
epoch 5  | loss: 0.42999 | train_auc: 0.85477 | valid_auc: 0.868   |  0:00:03s
epoch 6  | loss: 0.42996 | train_auc: 0.8634  | valid_auc: 0.87744 |  0:00:03s
epoch 7  | loss: 0.42589 | train_auc: 0.86932 | valid_auc: 0.88318 |  0:00:04s
epoch 8  | loss: 0.42368 | train_auc: 0.87271 | valid_auc: 0.88616 |  0:00:05s
epoch 9  | loss: 0.42657 | train_auc: 0.87721 | valid_auc: 0.88944 |  0:00:05s
epoch 10 | loss: 0.43361 | train_auc: 0.87948 | valid_auc: 0.89214 |  0:00:06s
epoch 11 | loss: 0.42995 | train_auc: 0.88229 | valid_auc: 0.89438 |  0:00:06s
epoch 12 | loss: 0.42574 | train_auc: 0.88328 | vali



epoch 0  | loss: 0.7056  | train_auc: 0.56987 | valid_auc: 0.55601 |  0:00:00s
epoch 1  | loss: 0.47256 | train_auc: 0.75682 | valid_auc: 0.75318 |  0:00:01s
epoch 2  | loss: 0.44069 | train_auc: 0.77771 | valid_auc: 0.76996 |  0:00:01s
epoch 3  | loss: 0.42963 | train_auc: 0.8379  | valid_auc: 0.83975 |  0:00:02s
epoch 4  | loss: 0.42151 | train_auc: 0.87139 | valid_auc: 0.87064 |  0:00:02s
epoch 5  | loss: 0.42397 | train_auc: 0.8672  | valid_auc: 0.86543 |  0:00:03s
epoch 6  | loss: 0.41859 | train_auc: 0.87843 | valid_auc: 0.87776 |  0:00:03s
epoch 7  | loss: 0.42508 | train_auc: 0.87997 | valid_auc: 0.88021 |  0:00:04s
epoch 8  | loss: 0.41231 | train_auc: 0.88318 | valid_auc: 0.88243 |  0:00:04s
epoch 9  | loss: 0.41639 | train_auc: 0.88449 | valid_auc: 0.88444 |  0:00:05s
epoch 10 | loss: 0.41853 | train_auc: 0.88665 | valid_auc: 0.88635 |  0:00:06s
epoch 11 | loss: 0.41354 | train_auc: 0.88748 | valid_auc: 0.88652 |  0:00:06s
epoch 12 | loss: 0.41704 | train_auc: 0.88782 | vali



epoch 0  | loss: 0.67726 | train_auc: 0.76177 | valid_auc: 0.76273 |  0:00:00s
epoch 1  | loss: 0.4597  | train_auc: 0.83363 | valid_auc: 0.82139 |  0:00:01s
epoch 2  | loss: 0.42504 | train_auc: 0.86342 | valid_auc: 0.85102 |  0:00:01s
epoch 3  | loss: 0.42235 | train_auc: 0.86503 | valid_auc: 0.85878 |  0:00:02s
epoch 4  | loss: 0.42599 | train_auc: 0.86819 | valid_auc: 0.86076 |  0:00:02s
epoch 5  | loss: 0.4225  | train_auc: 0.87213 | valid_auc: 0.86221 |  0:00:03s
epoch 6  | loss: 0.41522 | train_auc: 0.8759  | valid_auc: 0.86469 |  0:00:03s
epoch 7  | loss: 0.41547 | train_auc: 0.87974 | valid_auc: 0.8701  |  0:00:04s
epoch 8  | loss: 0.42346 | train_auc: 0.87981 | valid_auc: 0.86999 |  0:00:05s
epoch 9  | loss: 0.42017 | train_auc: 0.88311 | valid_auc: 0.87128 |  0:00:05s
epoch 10 | loss: 0.4185  | train_auc: 0.88343 | valid_auc: 0.87244 |  0:00:06s
epoch 11 | loss: 0.42214 | train_auc: 0.88536 | valid_auc: 0.87485 |  0:00:06s
epoch 12 | loss: 0.4165  | train_auc: 0.8882  | vali



epoch 0  | loss: 0.68823 | train_auc: 0.62631 | valid_auc: 0.63519 |  0:00:00s
epoch 1  | loss: 0.48896 | train_auc: 0.71402 | valid_auc: 0.70516 |  0:00:01s
epoch 2  | loss: 0.44925 | train_auc: 0.79606 | valid_auc: 0.79058 |  0:00:01s
epoch 3  | loss: 0.43226 | train_auc: 0.82719 | valid_auc: 0.8242  |  0:00:02s
epoch 4  | loss: 0.41955 | train_auc: 0.82431 | valid_auc: 0.82739 |  0:00:02s
epoch 5  | loss: 0.42893 | train_auc: 0.8689  | valid_auc: 0.86532 |  0:00:03s
epoch 6  | loss: 0.42609 | train_auc: 0.87273 | valid_auc: 0.87045 |  0:00:03s
epoch 7  | loss: 0.41921 | train_auc: 0.87823 | valid_auc: 0.8753  |  0:00:04s
epoch 8  | loss: 0.41122 | train_auc: 0.87825 | valid_auc: 0.87492 |  0:00:04s
epoch 9  | loss: 0.41892 | train_auc: 0.88269 | valid_auc: 0.87698 |  0:00:05s
epoch 10 | loss: 0.42255 | train_auc: 0.88578 | valid_auc: 0.88024 |  0:00:06s
epoch 11 | loss: 0.41611 | train_auc: 0.88751 | valid_auc: 0.88391 |  0:00:06s
epoch 12 | loss: 0.41484 | train_auc: 0.8885  | vali



epoch 0  | loss: 0.69782 | train_auc: 0.57978 | valid_auc: 0.57868 |  0:00:00s
epoch 1  | loss: 0.47078 | train_auc: 0.751   | valid_auc: 0.74759 |  0:00:01s
epoch 2  | loss: 0.44591 | train_auc: 0.68058 | valid_auc: 0.68107 |  0:00:01s
epoch 3  | loss: 0.42491 | train_auc: 0.81749 | valid_auc: 0.8115  |  0:00:02s
epoch 4  | loss: 0.42424 | train_auc: 0.80693 | valid_auc: 0.80067 |  0:00:02s
epoch 5  | loss: 0.41621 | train_auc: 0.85808 | valid_auc: 0.85273 |  0:00:03s
epoch 6  | loss: 0.42768 | train_auc: 0.87258 | valid_auc: 0.8689  |  0:00:03s
epoch 7  | loss: 0.43013 | train_auc: 0.87535 | valid_auc: 0.8705  |  0:00:04s
epoch 8  | loss: 0.42191 | train_auc: 0.8793  | valid_auc: 0.87535 |  0:00:04s
epoch 9  | loss: 0.42493 | train_auc: 0.88279 | valid_auc: 0.8799  |  0:00:05s
epoch 10 | loss: 0.41744 | train_auc: 0.88434 | valid_auc: 0.88239 |  0:00:06s
epoch 11 | loss: 0.4261  | train_auc: 0.88408 | valid_auc: 0.88347 |  0:00:06s
epoch 12 | loss: 0.41716 | train_auc: 0.88486 | vali



epoch 0  | loss: 0.69208 | train_auc: 0.71665 | valid_auc: 0.72953 |  0:00:00s
epoch 1  | loss: 0.4898  | train_auc: 0.77311 | valid_auc: 0.77425 |  0:00:01s
epoch 2  | loss: 0.44878 | train_auc: 0.75776 | valid_auc: 0.74975 |  0:00:01s
epoch 3  | loss: 0.43531 | train_auc: 0.83877 | valid_auc: 0.83553 |  0:00:02s
epoch 4  | loss: 0.44271 | train_auc: 0.85961 | valid_auc: 0.86162 |  0:00:02s
epoch 5  | loss: 0.42543 | train_auc: 0.86815 | valid_auc: 0.87238 |  0:00:03s
epoch 6  | loss: 0.43535 | train_auc: 0.87139 | valid_auc: 0.87425 |  0:00:03s
epoch 7  | loss: 0.42054 | train_auc: 0.87212 | valid_auc: 0.87412 |  0:00:04s
epoch 8  | loss: 0.42545 | train_auc: 0.88129 | valid_auc: 0.88297 |  0:00:04s
epoch 9  | loss: 0.41988 | train_auc: 0.87899 | valid_auc: 0.88251 |  0:00:05s
epoch 10 | loss: 0.41942 | train_auc: 0.88193 | valid_auc: 0.88615 |  0:00:06s
epoch 11 | loss: 0.42173 | train_auc: 0.88266 | valid_auc: 0.88645 |  0:00:06s
epoch 12 | loss: 0.42694 | train_auc: 0.8839  | vali



epoch 0  | loss: 0.67844 | train_auc: 0.65625 | valid_auc: 0.65475 |  0:00:00s
epoch 1  | loss: 0.4651  | train_auc: 0.7225  | valid_auc: 0.71145 |  0:00:01s
epoch 2  | loss: 0.43594 | train_auc: 0.76167 | valid_auc: 0.7448  |  0:00:01s
epoch 3  | loss: 0.43132 | train_auc: 0.82755 | valid_auc: 0.8173  |  0:00:02s
epoch 4  | loss: 0.42955 | train_auc: 0.85587 | valid_auc: 0.84641 |  0:00:02s
epoch 5  | loss: 0.43567 | train_auc: 0.856   | valid_auc: 0.84581 |  0:00:03s
epoch 6  | loss: 0.42058 | train_auc: 0.86097 | valid_auc: 0.85578 |  0:00:03s
epoch 7  | loss: 0.41946 | train_auc: 0.87427 | valid_auc: 0.86681 |  0:00:04s
epoch 8  | loss: 0.42042 | train_auc: 0.87547 | valid_auc: 0.86718 |  0:00:04s
epoch 9  | loss: 0.42127 | train_auc: 0.8798  | valid_auc: 0.87209 |  0:00:05s
epoch 10 | loss: 0.4195  | train_auc: 0.88263 | valid_auc: 0.87451 |  0:00:05s
epoch 11 | loss: 0.42339 | train_auc: 0.88398 | valid_auc: 0.87722 |  0:00:06s
epoch 12 | loss: 0.42356 | train_auc: 0.88541 | vali



epoch 0  | loss: 0.71089 | train_auc: 0.58354 | valid_auc: 0.60805 |  0:00:00s
epoch 1  | loss: 0.4708  | train_auc: 0.78845 | valid_auc: 0.78827 |  0:00:01s
epoch 2  | loss: 0.43632 | train_auc: 0.82196 | valid_auc: 0.81516 |  0:00:01s
epoch 3  | loss: 0.42823 | train_auc: 0.85676 | valid_auc: 0.84966 |  0:00:02s
epoch 4  | loss: 0.42449 | train_auc: 0.86224 | valid_auc: 0.858   |  0:00:02s
epoch 5  | loss: 0.42197 | train_auc: 0.87756 | valid_auc: 0.8753  |  0:00:03s
epoch 6  | loss: 0.4218  | train_auc: 0.87779 | valid_auc: 0.87366 |  0:00:03s
epoch 7  | loss: 0.41835 | train_auc: 0.88406 | valid_auc: 0.88209 |  0:00:04s
epoch 8  | loss: 0.42317 | train_auc: 0.881   | valid_auc: 0.87975 |  0:00:04s
epoch 9  | loss: 0.42456 | train_auc: 0.88382 | valid_auc: 0.88258 |  0:00:05s
epoch 10 | loss: 0.41665 | train_auc: 0.88781 | valid_auc: 0.88626 |  0:00:06s
epoch 11 | loss: 0.41792 | train_auc: 0.88823 | valid_auc: 0.88653 |  0:00:06s
epoch 12 | loss: 0.4131  | train_auc: 0.88918 | vali



epoch 0  | loss: 0.71283 | train_auc: 0.52785 | valid_auc: 0.53442 |  0:00:00s
epoch 1  | loss: 0.46992 | train_auc: 0.66395 | valid_auc: 0.67538 |  0:00:01s
epoch 2  | loss: 0.4478  | train_auc: 0.76282 | valid_auc: 0.77301 |  0:00:01s
epoch 3  | loss: 0.43766 | train_auc: 0.84567 | valid_auc: 0.85253 |  0:00:02s
epoch 4  | loss: 0.4347  | train_auc: 0.86095 | valid_auc: 0.86418 |  0:00:02s
epoch 5  | loss: 0.43358 | train_auc: 0.87386 | valid_auc: 0.87727 |  0:00:03s
epoch 6  | loss: 0.42093 | train_auc: 0.873   | valid_auc: 0.87715 |  0:00:03s
epoch 7  | loss: 0.42298 | train_auc: 0.87259 | valid_auc: 0.87635 |  0:00:04s
epoch 8  | loss: 0.42134 | train_auc: 0.88072 | valid_auc: 0.88561 |  0:00:04s
epoch 9  | loss: 0.42146 | train_auc: 0.8842  | valid_auc: 0.88901 |  0:00:05s
epoch 10 | loss: 0.42168 | train_auc: 0.88596 | valid_auc: 0.88939 |  0:00:06s
epoch 11 | loss: 0.41842 | train_auc: 0.88673 | valid_auc: 0.88986 |  0:00:06s
epoch 12 | loss: 0.42495 | train_auc: 0.88494 | vali



In [37]:
#prob_all_tn_a = prob_all
#prob_all_tn_b = prob_all
#prob_all_tn_c = prob_all
#prob_all_tn_d = prob_all # prob_all_tn_a, prob_all_tn_b,prob_all_tn_c
res_all_tn = res_all
res_all_tn.transpose().describe()

Unnamed: 0,acc,bac,recall,ppv,npv,sepecificity,f1,auc
count,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0
mean,0.76193,0.813331,0.884087,0.35354,0.975928,0.742575,0.504652,0.891294
std,0.018106,0.007056,0.021403,0.017182,0.003991,0.023483,0.016131,0.004733
min,0.688664,0.791922,0.845238,0.294659,0.968656,0.650356,0.447928,0.881317
25%,0.758802,0.809521,0.871307,0.342628,0.973209,0.736059,0.495456,0.88937
50%,0.763893,0.81428,0.885122,0.356172,0.97629,0.74557,0.507859,0.891457
75%,0.773248,0.819199,0.896949,0.364483,0.977896,0.756775,0.516569,0.894948
max,0.782444,0.822543,0.933489,0.380039,0.984787,0.772843,0.526926,0.899301


In [38]:
res_all.transpose().to_csv('./fig/feature_d_tabnet.csv', index=False)

In [318]:
get_metric(tn.predict_proba(X_valid_tn.values[:]), y_valid, 0.5)

{'acc': 0.7630175658720201,
 'bac': 0.8120140086436312,
 'recall': 0.8788235294117647,
 'ppv': 0.34663573085846866,
 'npv': 0.9755981994787964,
 'sepecificity': 0.7452044878754976,
 'f1': 0.49717138103161396,
 'auc': 0.8853640544165549}

In [869]:
get_metric(tn.predict_proba(X_test_tn.values[:]), y_test, 0.5)

{'acc': 0.760446075663467,
 'bac': 0.8122309292613555,
 'recall': 0.8841893252769386,
 'ppv': 0.3569105691056911,
 'npv': 0.9751297577854672,
 'sepecificity': 0.7402725332457725,
 'f1': 0.5085432956849117,
 'auc': 0.8867687835535003}

In [876]:
ds = {'pred':tn.predict(X_test_tn.values[:]), 'prob':prob_tn_cali}
pd.DataFrame(data=ds)

Unnamed: 0,pred,prob
0,0,0.015244
1,0,0.015244
2,1,0.257028
3,0,0.013536
4,0,0.040385
...,...,...
7079,0,0.003861
7080,0,0.000000
7081,0,0.000000
7082,0,0.013536


----

### Ensamble Model

In [359]:
prob_lr_train = lr.predict_proba(X_train_lr)
prob_lr_valid = lr.predict_proba(X_valid_lr)
prob_lr_test = lr.predict_proba(X_test_lr)

prob_dt_train = dt.predict_proba(X_train_dt)
prob_dt_valid = dt.predict_proba(X_valid_dt)
prob_dt_test = dt.predict_proba(X_test_dt)

prob_rf_train = rf.predict_proba(X_train_rf)
prob_rf_valid = rf.predict_proba(X_valid_rf)
prob_rf_test = rf.predict_proba(X_test_rf)

prob_xgb_train = xgb.predict_proba(X_train_xgb)
prob_xgb_valid = xgb.predict_proba(X_valid_xgb)
prob_xgb_test = xgb.predict_proba(X_test_xgb)

prob_tn_train = tn.predict_proba(X_train_tn.values[:])
prob_tn_valid = tn.predict_proba(X_valid_tn.values[:])
prob_tn_test = tn.predict_proba(X_test_tn.values[:])

##### with Tabnet

In [360]:
X_train_meta = np.vstack((prob_lr_train[:,1],prob_dt_train[:,1],prob_rf_train[:,1],prob_xgb_train[:,1],prob_tn_train[:,1]))
X_valid_meta = np.vstack((prob_lr_valid[:,1],prob_dt_valid[:,1],prob_rf_valid[:,1],prob_xgb_valid[:,1],prob_tn_valid[:,1]))
X_test_meta = np.vstack((prob_lr_test[:,1],prob_dt_test[:,1],prob_rf_test[:,1],prob_xgb_test[:,1],prob_tn_test[:,1]))

X_train_meta = X_train_meta.T
X_valid_meta = X_valid_meta.T
X_test_meta = X_test_meta.T

X_train_meta = pd.DataFrame(X_train_meta, columns={'lr','dt','rf','xgb','tn'})
X_valid_meta = pd.DataFrame(X_valid_meta, columns={'lr','dt','rf','xgb','tn'})
X_test_meta = pd.DataFrame(X_test_meta, columns={'lr','dt','rf','xgb','tn'})

without Tabnet

In [270]:
X_train_meta = np.vstack((prob_lr_train[:,1],prob_dt_train[:,1],prob_rf_train[:,1],prob_xgb_train[:,1]))
X_valid_meta = np.vstack((prob_lr_valid[:,1],prob_dt_valid[:,1],prob_rf_valid[:,1],prob_xgb_valid[:,1]))

X_train_meta = X_train_meta.T
X_valid_meta = X_valid_meta.T

X_train_meta = pd.DataFrame(X_train_meta, columns={'lr','dt','rf','xgb'})
X_valid_meta = pd.DataFrame(X_valid_meta, columns={'lr','dt','rf','xgb'})

#### Calibration Probability

In [855]:
X_test_meta = np.vstack((prob_lr_cali, prob_dt_cali, prob_rf_cali, prob_xgb_cali, prob_tn_cali)).T
X_test_meta = pd.DataFrame(X_test_meta, columns={'lr','dt','rf','xgb','tn'})

In [859]:
prob = X_test_meta.mean(axis=1)
get_metric2(prob, y_test, tau)

{'acc': 0.7531055900621118,
 'bac': 0.8130195724033098,
 'recall': 0.8962739174219537,
 'ppv': 0.3509463722397476,
 'npv': 0.9773526824978013,
 'sepecificity': 0.7297652273846659,
 'f1': 0.5043921790875602,
 'auc': 0.8887942902897858}

#### soft

In [54]:
from sklearn.metrics import confusion_matrix, auc, roc_curve

def get_metric2(prob, label, threshold):
    #prob = prob[:,1]
    prd = np.where(prob>=threshold, 1, 0)
    
    tn, fp, fn, tp = confusion_matrix(label, prd, labels=[0,1]).ravel()
    
    accuracy = (tn+tp)/(tn+fp+fn+tp)
    sensitivity = tp/(fn+tp)
    specificity = tn/(fp+tn)
    ppv = tp/(tp+fp)
    npv = tn/(tn+fn)
    f1 = 2*(sensitivity*ppv)/(sensitivity+ppv)
    balaced_accuracy = 0.5*(sensitivity+specificity)
    fpr, tpr, thresholds = roc_curve(label, prob, pos_label=1)
    auc_roc = auc(fpr,tpr)
    
    res = {'acc' : accuracy,
           'bac' : balaced_accuracy,
           'recall': sensitivity,
           'ppv':ppv,
           'npv':npv,
           'sepecificity':specificity,
           'f1':f1,
           'auc':auc_roc}
    
    return res

In [858]:
prob = X_test_meta.mean(axis=1)
get_metric2(prob, y_test, 0.5)

{'acc': 0.7531055900621118,
 'bac': 0.8130195724033098,
 'recall': 0.8962739174219537,
 'ppv': 0.3509463722397476,
 'npv': 0.9773526824978013,
 'sepecificity': 0.7297652273846659,
 'f1': 0.5043921790875602,
 'auc': 0.8887942902897858}

#### Linear Regression

In [363]:
from sklearn.linear_model import LinearRegression
args = {
    #'penalty' : 'l1',
    #'solver' : 'liblinear',
    #'random_state' : 100
}
meta = LinearRegression(**args)
meta.fit(X_train_meta, y_train)

In [357]:
prob = meta.predict(X_valid_meta)
get_metric2(prob, y_valid, 0.5)

{'acc': 0.758450847829838,
 'bac': 0.818430677200757,
 'recall': 0.9007739401640291,
 'ppv': 0.34909123466738295,
 'npv': 0.9792577210054814,
 'sepecificity': 0.736087414237485,
 'f1': 0.503177931924504,
 'auc': 0.8980555579935069}

In [364]:
prob = meta.predict(X_test_meta)
get_metric2(prob, y_test, 0.5)

{'acc': 0.7446357989836251,
 'bac': 0.8038799093242255,
 'recall': 0.8862034239677744,
 'ppv': 0.3416149068322981,
 'npv': 0.9749334516415262,
 'sepecificity': 0.7215563946806764,
 'f1': 0.49313533202577753,
 'auc': 0.8813449192781583}

#### Random Forest

In [367]:
args = {'bootstrap': True,
        'max_depth': 6,
        'min_samples_leaf': 4,
        'min_samples_split': 10,
        'n_estimators': 300,
        'random_state' : 100}

meta = RandomForestClassifier(**args)
meta.fit(X_train_meta, y_train)

In [262]:
prob = meta.predict_proba(X_valid_meta)
get_metric(prob, y_valid, 0.5)

{'acc': 0.7797992471769134,
 'bac': 0.8022831108556343,
 'recall': 0.8329411764705882,
 'ppv': 0.35939086294416245,
 'npv': 0.9677712210621879,
 'sepecificity': 0.7716250452406804,
 'f1': 0.5021276595744681,
 'auc': 0.8410172233931574}

In [368]:
prob = meta.predict(X_test_meta)
get_metric2(prob, y_test, 0.5)

{'acc': 0.7632693393562959,
 'bac': 0.8046010961974339,
 'recall': 0.8620342396777442,
 'ppv': 0.3572621035058431,
 'npv': 0.9707764505119454,
 'sepecificity': 0.7471679527171237,
 'f1': 0.505163765122455,
 'auc': 0.8046010961974338}

#### Tabnet (Feature + Pred)

In [369]:
X_train_tn.reset_index(drop=True, inplace=True)
X_train_meta.reset_index(drop=True, inplace=True)
X_valid_tn.reset_index(drop=True, inplace=True)
X_valid_meta.reset_index(drop=True, inplace=True)
X_test_tn.reset_index(drop=True, inplace=True)
X_test_meta.reset_index(drop=True, inplace=True)

In [273]:
X_train_meta = pd.concat([X_train_tn,X_train_meta], axis=1)
X_valid_meta = pd.concat([X_valid_tn,X_valid_meta], axis=1)
X_test_meta = pd.concat([X_test_tn,X_valid_meta], axis=1)

In [371]:
args = {'seed':100}

estimator = TabNetClassifier(cat_idxs=cat_idxs, cat_dims=cat_dims, **args)
#meta = train_model_with_valid(estimator, X_train_meta, y_train, X_valid_meta, y_valid)
meta = train_model(estimator, X_train_meta, y_train)



epoch 0  | loss: 0.43698 |  0:00:00s
epoch 1  | loss: 0.39559 |  0:00:00s
epoch 2  | loss: 0.3919  |  0:00:01s
epoch 3  | loss: 0.38764 |  0:00:01s
epoch 4  | loss: 0.3846  |  0:00:02s
epoch 5  | loss: 0.38922 |  0:00:02s
epoch 6  | loss: 0.38507 |  0:00:03s
epoch 7  | loss: 0.3876  |  0:00:03s
epoch 8  | loss: 0.37832 |  0:00:03s
epoch 9  | loss: 0.37751 |  0:00:04s
epoch 10 | loss: 0.38024 |  0:00:04s
epoch 11 | loss: 0.38144 |  0:00:05s
epoch 12 | loss: 0.38836 |  0:00:05s
epoch 13 | loss: 0.38068 |  0:00:05s
epoch 14 | loss: 0.38547 |  0:00:06s
epoch 15 | loss: 0.3851  |  0:00:06s
epoch 16 | loss: 0.37948 |  0:00:07s
epoch 17 | loss: 0.38141 |  0:00:07s
epoch 18 | loss: 0.38486 |  0:00:07s
epoch 19 | loss: 0.37353 |  0:00:08s
epoch 20 | loss: 0.38684 |  0:00:08s
epoch 21 | loss: 0.3712  |  0:00:09s
epoch 22 | loss: 0.37168 |  0:00:09s
epoch 23 | loss: 0.37071 |  0:00:09s
epoch 24 | loss: 0.37599 |  0:00:10s
epoch 25 | loss: 0.37521 |  0:00:10s
epoch 26 | loss: 0.37713 |  0:00:11s
e

In [278]:
prob = meta.predict_proba(X_valid_meta.values[:])
get_metric(prob,y_valid,0.5)

{'acc': 0.7161229611041405,
 'bac': 0.8028792233505779,
 'recall': 0.9211764705882353,
 'ppv': 0.30997624703087884,
 'npv': 0.9825974025974026,
 'sepecificity': 0.6845819761129207,
 'f1': 0.4638625592417061,
 'auc': 0.8779280832854314}

In [372]:
prob = meta.predict_proba(X_test_meta.values[:])
get_metric(prob,y_test,0.5)

{'acc': 0.7780914737436476,
 'bac': 0.78287695364845,
 'recall': 0.7895266868076536,
 'ppv': 0.36516068933395435,
 'npv': 0.957666599149281,
 'sepecificity': 0.7762272204892464,
 'f1': 0.4993630573248408,
 'auc': 0.8667796559168159}