In [1]:
# NOTE: this clashes with smartnoise! dunno how to deal with, but we use this opacus
# for fttransformer, so we need to install it here
# ! pip install opacus==1.5.2

In [2]:
import numpy as np
import pandas as pd
import torch

from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, precision_recall_curve, classification_report
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

import ydnpd

from ydnpd.datasets.loader import load_dataset, split_train_eval_datasets

from ydnpd.pretraining.ft_transformer import FTTransformerModel
from ydnpd.pretraining.utils import preprocess_acs_for_ft_transformer, preprocess_acs_for_classification, print_model_performance

from pathlib import Path

LLM_PATH = "llm_datasets"
DATA_PATH = 'ydnpd/datasets/data'


In [3]:
LLM_EXPERIMENTS = {
    experiment_name:
    [
        (f"{experiment_name}/{path.stem}", LLM_PATH)
        for path in Path(f"{LLM_PATH}/{experiment_name}").glob("*.csv")
        ]
        for experiment_name in ["acs"]
}


EXPERIMENTS = {
    experiment_name:
    [
        (f"{experiment_name}/{path.stem}", DATA_PATH)
        for path in Path(f"{DATA_PATH}/{experiment_name}").glob("*.csv")
        ]
        for experiment_name in ["acs"]
}

In [4]:
real_dataset, real_processed_schema, real_domain = load_dataset('acs/national', path=DATA_PATH)
llm_dataset, llm_processed_schema, llm_domain = load_dataset(LLM_EXPERIMENTS['acs'][0][0], path=LLM_PATH)

### Baseline, non-private

In [5]:
real_df = preprocess_acs_for_classification(real_dataset)
X = real_df.drop('y', axis=1)
y = real_df['y']

X_train_real, X_test_real, y_train_real, y_test_real = train_test_split(X, y, test_size=0.2)

# naive classifier
print(f"Base accuracy: {y_train_real.sum() / len(y_train_real)}")
print(classification_report(y_train_real, np.zeros_like(y_train_real)))
print('AUC:', roc_auc_score(y_train_real, np.zeros_like(y_train_real)))

# random forest
print('Random Forest')
rf = RandomForestClassifier()
rf.fit(X_train_real, y_train_real)
y_pred_real = rf.predict(X_test_real)
print(classification_report(y_test_real, y_pred_real))
print('AUC:', roc_auc_score(y_test_real, y_pred_real))
print()

# xgboost
print('XGBoost')
xgb = XGBClassifier()
xgb.fit(X_train_real, y_train_real)
y_pred_real = xgb.predict(X_test_real)
print(classification_report(y_test_real, y_pred_real))
print('AUC:', roc_auc_score(y_test_real, y_pred_real))
print()


Base accuracy: 0.5007688759034292
              precision    recall  f1-score   support

           0       0.50      1.00      0.67      6493
           1       0.00      0.00      0.00      6513

    accuracy                           0.50     13006
   macro avg       0.25      0.50      0.33     13006
weighted avg       0.25      0.50      0.33     13006

AUC: 0.5
Random Forest


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.66      0.67      0.66      1636
           1       0.66      0.65      0.66      1616

    accuracy                           0.66      3252
   macro avg       0.66      0.66      0.66      3252
weighted avg       0.66      0.66      0.66      3252

AUC: 0.6592434457382169

XGBoost
              precision    recall  f1-score   support

           0       0.68      0.68      0.68      1636
           1       0.68      0.68      0.68      1616

    accuracy                           0.68      3252
   macro avg       0.68      0.68      0.68      3252
weighted avg       0.68      0.68      0.68      3252

AUC: 0.6786603706214142



### Privacy, with FTTransformer

In [6]:
(
    X_cat_train,
    X_cont_train,
    X_cat_valid,
    X_cont_valid,
    y_train,
    y_valid,
    cat_cardinalities,
    config
) = preprocess_acs_for_ft_transformer(real_dataset)

(
    X_cat_train_llm,
    X_cont_train_llm,
    X_cat_valid_llm,
    X_cont_valid_llm,
    y_train_llm,
    y_valid_llm,
    cat_cardinalities_llm,
    _
) = preprocess_acs_for_ft_transformer(llm_dataset, config=config)


column SEX is categorical
column MSP is categorical
column RAC1P is categorical
column PINCP_DECILE is categorical
column EDU is categorical
column HOUSING_TYPE is categorical
{'SEX', 'EDU', 'HOUSING_TYPE', 'RAC1P', 'PINCP_DECILE', 'MSP'}
Index(['SEX', 'MSP', 'RAC1P', 'PINCP_DECILE', 'EDU', 'HOUSING_TYPE'], dtype='object')
[0 1]
[4 0 5 2 3 1]
[1 0 6 4 7 2 5 3]
[6 7 9 2 3 0 4 8 5 1]
[ 3  7  9  4 11  6  8  5  0  2 10  1]
[0]
cat training features shape: (13006, 6)
cont training features shape: (13006, 0)
cat val features shape: (3252, 6)
cont val features shape: (3252, 0)
training targets shape: (13006,)
val targets shape: (3252,)
cat feature cards: [2, 6, 8, 10, 12, 1]
config: {'SEX': {'type': 'categorical'}, 'MSP': {'type': 'categorical'}, 'RAC1P': {'type': 'categorical'}, 'PINCP_DECILE': {'type': 'categorical'}, 'EDU': {'type': 'categorical'}, 'HOUSING_TYPE': {'type': 'categorical'}}
{'SEX', 'EDU', 'HOUSING_TYPE', 'RAC1P', 'PINCP_DECILE', 'MSP'}
Index(['SEX', 'MSP', 'RAC1P', 'PINCP_DE

In [7]:
print(torch.isnan(X_cat_train).any())
print(torch.isnan(X_cont_train).any())
print(torch.isnan(y_train).any())

tensor(False)
tensor(False)
tensor(False)


#### Baseline, Non-private

In [8]:
classifier = FTTransformerModel(
    dim = 32,
    dim_out = 2,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    batch_size=128,
    num_epochs=20,
    lr=3e-4,
    load_best_model_when_trained=True,
    verbose=True)

classifier.fit(X_cat_train, X_cont_train, y_train.flatten(), cat_cardinalities, X_cont_train.shape[1], use_class_weights=True)

epoch 1/20 loss: 0.5513: 100%|██████████| 92/92 [00:01<00:00, 56.80batch/s]


val loss - new best: 0.6338354349136353
val accuracy - new best: 0.6594926979246734
val AUC - new best: 0.6959497530741441
epoch 1, validation loss: 0.6338354349136353, epochs without improvement: 0/10
reached 1.
epoch 1, val loss: 0.6338354349136353, epochs without improvement: 0/10
val accuracy - new best: 0.6594926979246734
val auc - new best: 0.6959497530741441


epoch 2/20 loss: 0.6333: 100%|██████████| 92/92 [00:01<00:00, 69.27batch/s]


val loss - new best: 0.6162636876106262
val accuracy - new best: 0.6694850115295926
val AUC - new best: 0.7212365705952262
epoch 2, validation loss: 0.6162636876106262, epochs without improvement: 0/10


epoch 3/20 loss: 0.5874: 100%|██████████| 92/92 [00:01<00:00, 68.95batch/s]


epoch 3, validation loss: 0.6169964671134949, epochs without improvement: 1/10


epoch 4/20 loss: 0.5242: 100%|██████████| 92/92 [00:01<00:00, 69.09batch/s]


val loss - new best: 0.6151787042617798
val accuracy - new best: 0.6641045349730976
val AUC - new best: 0.722204860864501
epoch 4, validation loss: 0.6151787042617798, epochs without improvement: 0/10


epoch 5/20 loss: 0.6598: 100%|██████████| 92/92 [00:01<00:00, 69.10batch/s]


val loss - new best: 0.6115995645523071
val accuracy - new best: 0.671022290545734
val AUC - new best: 0.7265467786000748
epoch 5, validation loss: 0.6115995645523071, epochs without improvement: 0/10
reached 5.
epoch 5, val loss: 0.6115995645523071, epochs without improvement: 0/10
val accuracy - new best: 0.671022290545734
val auc - new best: 0.7265467786000748


epoch 6/20 loss: 0.6395: 100%|██████████| 92/92 [00:01<00:00, 68.97batch/s]


epoch 6, validation loss: 0.6127369999885559, epochs without improvement: 1/10


epoch 7/20 loss: 0.5975: 100%|██████████| 92/92 [00:01<00:00, 59.45batch/s]


val loss - new best: 0.6081558465957642
val accuracy - new best: 0.671022290545734
val AUC - new best: 0.7331898653863455
epoch 7, validation loss: 0.6081558465957642, epochs without improvement: 0/10


epoch 8/20 loss: 0.5656: 100%|██████████| 92/92 [00:01<00:00, 68.99batch/s]


epoch 8, validation loss: 0.6098596453666687, epochs without improvement: 1/10


epoch 9/20 loss: 0.6305: 100%|██████████| 92/92 [00:01<00:00, 69.00batch/s]


epoch 9, validation loss: 0.6098175644874573, epochs without improvement: 2/10
reached 9.
epoch 9, val loss: 0.6098175644874573, epochs without improvement: 2/10
val accuracy - new best: 0.671022290545734
val auc - new best: 0.7331898653863455


epoch 10/20 loss: 0.6714: 100%|██████████| 92/92 [00:01<00:00, 69.38batch/s]


epoch 10, validation loss: 0.6138578653335571, epochs without improvement: 3/10


epoch 11/20 loss: 0.6691: 100%|██████████| 92/92 [00:01<00:00, 69.54batch/s]


epoch 11, validation loss: 0.6082061529159546, epochs without improvement: 4/10


epoch 12/20 loss: 0.5802: 100%|██████████| 92/92 [00:01<00:00, 69.54batch/s]


epoch 12, validation loss: 0.6099358201026917, epochs without improvement: 5/10


epoch 13/20 loss: 0.6265: 100%|██████████| 92/92 [00:01<00:00, 69.34batch/s]


epoch 13, validation loss: 0.6098206043243408, epochs without improvement: 6/10
reached 13.
epoch 13, val loss: 0.6098206043243408, epochs without improvement: 6/10
val accuracy - new best: 0.671022290545734
val auc - new best: 0.7331898653863455


epoch 14/20 loss: 0.6757: 100%|██████████| 92/92 [00:01<00:00, 69.49batch/s]


val loss - new best: 0.6078999638557434
val accuracy - new best: 0.6910069177555727
val AUC - new best: 0.7318569866049234
epoch 14, validation loss: 0.6078999638557434, epochs without improvement: 0/10


epoch 15/20 loss: 0.6940: 100%|██████████| 92/92 [00:01<00:00, 69.35batch/s]


val loss - new best: 0.6073724031448364
val accuracy - new best: 0.6810146041506533
val AUC - new best: 0.7312722245107648
epoch 15, validation loss: 0.6073724031448364, epochs without improvement: 0/10


epoch 16/20 loss: 0.5432: 100%|██████████| 92/92 [00:01<00:00, 69.48batch/s]


epoch 16, validation loss: 0.6075178384780884, epochs without improvement: 1/10


epoch 17/20 loss: 0.6156: 100%|██████████| 92/92 [00:01<00:00, 69.43batch/s]


epoch 17, validation loss: 0.6120955944061279, epochs without improvement: 2/10
reached 17.
epoch 17, val loss: 0.6120955944061279, epochs without improvement: 2/10
val accuracy - new best: 0.6810146041506533
val auc - new best: 0.7312722245107648


epoch 18/20 loss: 0.4956: 100%|██████████| 92/92 [00:01<00:00, 69.48batch/s]


epoch 18, validation loss: 0.6092701554298401, epochs without improvement: 3/10


epoch 19/20 loss: 0.5681: 100%|██████████| 92/92 [00:01<00:00, 69.41batch/s]


val loss - new best: 0.6073643565177917
val accuracy - new best: 0.6756341275941583
val AUC - new best: 0.7345937679039002
epoch 19, validation loss: 0.6073643565177917, epochs without improvement: 0/10


epoch 20/20 loss: 0.6843: 100%|██████████| 92/92 [00:01<00:00, 69.38batch/s]


val loss - new best: 0.6069748401641846
val accuracy - new best: 0.6833205226748655
val AUC - new best: 0.7323091710583011
epoch 20, validation loss: 0.6069748401641846, epochs without improvement: 0/10


In [9]:
print_model_performance(classifier, X_cat_valid, X_cont_valid, y_valid)


AUC: 0.7298
Optimal Threshold: 0.3308
              precision    recall  f1-score   support

         0.0       0.75      0.43      0.55      1640
         1.0       0.60      0.86      0.70      1612

    accuracy                           0.64      3252
   macro avg       0.68      0.64      0.63      3252
weighted avg       0.68      0.64      0.63      3252




### Just Private, no pretraining

In [None]:
classifier = FTTransformerModel(
    dim = 32,
    dim_out = 2,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    batch_size=128,
    num_epochs=20,
    lr=3e-4,
    load_best_model_when_trained=True,
    verbose=True,
    dp=True, # turn on privacy!
    epsilon=1.0) # turn on privacy!

classifier.fit(X_cat_train, X_cont_train, y_train.flatten(), cat_cardinalities, X_cont_train.shape[1], use_class_weights=True)



  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 1/20 loss: 0.6831: 100%|██████████| 92/92 [00:05<00:00, 16.51batch/s]


val loss - new best: 0.7014443874359131
val accuracy - new best: 0.528055342044581
val AUC - new best: 0.5418211432927551
epoch 1, validation loss: 0.7014443874359131, epochs without improvement: 0/10
reached 1.
epoch 1, val loss: 0.7014443874359131, epochs without improvement: 0/10
val accuracy - new best: 0.528055342044581
val auc - new best: 0.5418211432927551


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 2/20 loss: 0.6846: 100%|██████████| 92/92 [00:05<00:00, 16.81batch/s]


val loss - new best: 0.697856605052948
val accuracy - new best: 0.531129900076864
val AUC - new best: 0.5555476640293187
epoch 2, validation loss: 0.697856605052948, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 3/20 loss: 0.6390: 100%|██████████| 92/92 [00:05<00:00, 16.68batch/s]


val loss - new best: 0.6829385757446289
val accuracy - new best: 0.5757109915449654
val AUC - new best: 0.6129111682457611
epoch 3, validation loss: 0.6829385757446289, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 4/20 loss: 0.7185: 100%|██████████| 92/92 [00:05<00:00, 16.68batch/s]


epoch 4, validation loss: 0.690453052520752, epochs without improvement: 1/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 5/20 loss: 0.7713: 100%|██████████| 92/92 [00:05<00:00, 16.67batch/s]


epoch 5, validation loss: 0.7556646466255188, epochs without improvement: 2/10
reached 5.
epoch 5, val loss: 0.7556646466255188, epochs without improvement: 2/10
val accuracy - new best: 0.5757109915449654
val auc - new best: 0.6129111682457611


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 6/20 loss: 0.7056: 100%|██████████| 92/92 [00:05<00:00, 16.66batch/s]


epoch 6, validation loss: 0.8257616758346558, epochs without improvement: 3/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 7/20 loss: 1.3008: 100%|██████████| 92/92 [00:05<00:00, 16.66batch/s]


epoch 7, validation loss: 0.970580518245697, epochs without improvement: 4/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 8/20 loss: 0.7504: 100%|██████████| 92/92 [00:05<00:00, 16.68batch/s]


epoch 8, validation loss: 0.9845970869064331, epochs without improvement: 5/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 9/20 loss: 0.9158: 100%|██████████| 92/92 [00:05<00:00, 16.65batch/s]


epoch 9, validation loss: 0.998762309551239, epochs without improvement: 6/10
reached 9.
epoch 9, val loss: 0.998762309551239, epochs without improvement: 6/10
val accuracy - new best: 0.5757109915449654
val auc - new best: 0.6129111682457611


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 10/20 loss: 0.8140: 100%|██████████| 92/92 [00:05<00:00, 16.63batch/s]


epoch 10, validation loss: 0.9841516613960266, epochs without improvement: 7/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 11/20 loss: 0.9677: 100%|██████████| 92/92 [00:05<00:00, 16.61batch/s]


epoch 11, validation loss: 0.9668220281600952, epochs without improvement: 8/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 12/20 loss: 1.0395: 100%|██████████| 92/92 [00:05<00:00, 16.65batch/s]


epoch 12, validation loss: 0.9981147050857544, epochs without improvement: 9/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 13/20 loss: 0.8439: 100%|██████████| 92/92 [00:05<00:00, 16.03batch/s]

epoch 13, validation loss: 0.9763025045394897, epochs without improvement: 10/10
stopping early at epoch 13. no improvement in validation loss for 10 consecutive epochs.
epoch 13, validation loss: 0.9763025045394897, epochs without improvement: 10/10
val accuracy - new best: 0.5757109915449654
val auc - new best: 0.6129111682457611





In [11]:
print_model_performance(classifier, X_cat_valid, X_cont_valid, y_valid)


AUC: 0.6981
Optimal Threshold: 0.0666
              precision    recall  f1-score   support

         0.0       0.72      0.51      0.60      1640
         1.0       0.62      0.80      0.70      1612

    accuracy                           0.66      3252
   macro avg       0.67      0.66      0.65      3252
weighted avg       0.67      0.66      0.65      3252




### With Pretraining, Private

In [12]:
classifier = FTTransformerModel(
    dim = 32,
    dim_out = 2,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    batch_size=128,
    num_epochs=20,
    lr=3e-4,
    load_best_model_when_trained=True,
    verbose=True,
    epsilon=3.0,
    partial_dp=True,
    partial_pretrain_config={
        'X_cat_pre': X_cat_train_llm,
        'X_cont_pre': X_cont_train_llm,
        'y_pre': y_train_llm,
        'categories': cat_cardinalities,
        'num_continuous': X_cont_train_llm.shape[1],
        'pre_epochs': 3,
        'pre_batch_size': 4,
        'pre_lr': 3e-4,
    }
)

classifier.fit(X_cat_train, X_cont_train, y_train.flatten(), cat_cardinalities, X_cont_train.shape[1])

X_cat_pre torch.Size([13969, 6])
X_cont_pre torch.Size([13969, 0])
y_pre torch.Size([13969])


pretraining Epoch 1/3 loss: 0.1431: 100%|██████████| 3143/3143 [00:41<00:00, 75.62batch/s]


pretraining val loss - new best: 0.12106065452098846
pretraining epoch 1, validation Loss: 0.12106065452098846, epochs without improvement: 0/10


pretraining Epoch 2/3 loss: 0.0690: 100%|██████████| 3143/3143 [00:41<00:00, 76.09batch/s]


pretraining val loss - new best: 0.10496125370264053
pretraining epoch 2, validation Loss: 0.10496125370264053, epochs without improvement: 0/10


pretraining Epoch 3/3 loss: 0.5438: 100%|██████████| 3143/3143 [00:41<00:00, 75.66batch/s]


pretraining val loss - new best: 0.10271900147199631
pretraining epoch 3, validation Loss: 0.10271900147199631, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 1/20 loss: 1.3342: 100%|██████████| 92/92 [00:05<00:00, 16.63batch/s]


val loss - new best: 1.1961641311645508
val accuracy - new best: 0.644119907763259
val AUC - new best: 0.6864195514140825
epoch 1, validation loss: 1.1961641311645508, epochs without improvement: 0/10
reached 1.
epoch 1, val loss: 1.1961641311645508, epochs without improvement: 0/10
val accuracy - new best: 0.644119907763259
val auc - new best: 0.6864195514140825


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 2/20 loss: 1.0732: 100%|██████████| 92/92 [00:05<00:00, 16.69batch/s]


val loss - new best: 1.090497374534607
val accuracy - new best: 0.652574942352037
val AUC - new best: 0.6838686155579861
epoch 2, validation loss: 1.090497374534607, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 3/20 loss: 0.8798: 100%|██████████| 92/92 [00:05<00:00, 16.68batch/s]


val loss - new best: 1.0249730348587036
val accuracy - new best: 0.6502690238278247
val AUC - new best: 0.6853198672329626
epoch 3, validation loss: 1.0249730348587036, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 4/20 loss: 1.0021: 100%|██████████| 92/92 [00:05<00:00, 16.69batch/s]


val loss - new best: 0.9996715784072876
val accuracy - new best: 0.6487317448116833
val AUC - new best: 0.6936793609757714
epoch 4, validation loss: 0.9996715784072876, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 5/20 loss: 0.8377: 100%|██████████| 92/92 [00:05<00:00, 16.68batch/s]


val loss - new best: 0.9988818764686584
val accuracy - new best: 0.6518063028439662
val AUC - new best: 0.6901163368797852
epoch 5, validation loss: 0.9988818764686584, epochs without improvement: 0/10
reached 5.
epoch 5, val loss: 0.9988818764686584, epochs without improvement: 0/10
val accuracy - new best: 0.6518063028439662
val auc - new best: 0.6901163368797852


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 6/20 loss: 0.9448: 100%|██████████| 92/92 [00:05<00:00, 16.69batch/s]


val loss - new best: 0.9882392287254333
val accuracy - new best: 0.6464258262874711
val AUC - new best: 0.6950513028120664
epoch 6, validation loss: 0.9882392287254333, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 7/20 loss: 1.0647: 100%|██████████| 92/92 [00:05<00:00, 16.69batch/s]


val loss - new best: 0.964897871017456
val accuracy - new best: 0.6487317448116833
val AUC - new best: 0.6986273479263436
epoch 7, validation loss: 0.964897871017456, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 8/20 loss: 0.9866: 100%|██████████| 92/92 [00:05<00:00, 16.69batch/s]


epoch 8, validation loss: 0.9671794772148132, epochs without improvement: 1/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 9/20 loss: 0.7656: 100%|██████████| 92/92 [00:05<00:00, 16.68batch/s]


val loss - new best: 0.9593520760536194
val accuracy - new best: 0.6487317448116833
val AUC - new best: 0.698717311325445
epoch 9, validation loss: 0.9593520760536194, epochs without improvement: 0/10
reached 9.
epoch 9, val loss: 0.9593520760536194, epochs without improvement: 0/10
val accuracy - new best: 0.6487317448116833
val auc - new best: 0.698717311325445


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 10/20 loss: 0.8632: 100%|██████████| 92/92 [00:05<00:00, 16.67batch/s]


epoch 10, validation loss: 0.9732196927070618, epochs without improvement: 1/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 11/20 loss: 1.0922: 100%|██████████| 92/92 [00:05<00:00, 16.69batch/s]


epoch 11, validation loss: 0.97207111120224, epochs without improvement: 2/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 12/20 loss: 0.9565: 100%|██████████| 92/92 [00:05<00:00, 16.75batch/s]


val loss - new best: 0.9540778398513794
val accuracy - new best: 0.6510376633358954
val AUC - new best: 0.6999389195869261
epoch 12, validation loss: 0.9540778398513794, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 13/20 loss: 0.9866: 100%|██████████| 92/92 [00:05<00:00, 16.85batch/s]


epoch 13, validation loss: 0.9662309288978577, epochs without improvement: 1/10
reached 13.
epoch 13, val loss: 0.9662309288978577, epochs without improvement: 1/10
val accuracy - new best: 0.6510376633358954
val auc - new best: 0.6999389195869261


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 14/20 loss: 1.0169: 100%|██████████| 92/92 [00:05<00:00, 16.86batch/s]


epoch 14, validation loss: 0.9671354293823242, epochs without improvement: 2/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 15/20 loss: 1.0183: 100%|██████████| 92/92 [00:05<00:00, 16.80batch/s]


val loss - new best: 0.951413094997406
val accuracy - new best: 0.6518063028439662
val AUC - new best: 0.6980757302423803
epoch 15, validation loss: 0.951413094997406, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 16/20 loss: 0.8443: 100%|██████████| 92/92 [00:05<00:00, 16.86batch/s]


epoch 16, validation loss: 0.9523166418075562, epochs without improvement: 1/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 17/20 loss: 0.7943: 100%|██████████| 92/92 [00:05<00:00, 16.85batch/s]


epoch 17, validation loss: 0.9548022747039795, epochs without improvement: 2/10
reached 17.
epoch 17, val loss: 0.9548022747039795, epochs without improvement: 2/10
val accuracy - new best: 0.6518063028439662
val auc - new best: 0.6980757302423803


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 18/20 loss: 0.9156: 100%|██████████| 92/92 [00:05<00:00, 16.85batch/s]


epoch 18, validation loss: 0.9561578035354614, epochs without improvement: 3/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 19/20 loss: 0.8297: 100%|██████████| 92/92 [00:05<00:00, 16.87batch/s]


val loss - new best: 0.944879412651062
val accuracy - new best: 0.649500384319754
val AUC - new best: 0.705452728968688
epoch 19, validation loss: 0.944879412651062, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 20/20 loss: 0.9395: 100%|██████████| 92/92 [00:05<00:00, 16.87batch/s]

epoch 20, validation loss: 0.9521337151527405, epochs without improvement: 1/10





In [13]:
print_model_performance(classifier, X_cat_valid, X_cont_valid, y_valid)


AUC: 0.7072
Optimal Threshold: 0.0711
              precision    recall  f1-score   support

         0.0       0.74      0.40      0.52      1640
         1.0       0.58      0.86      0.69      1612

    accuracy                           0.63      3252
   macro avg       0.66      0.63      0.61      3252
weighted avg       0.66      0.63      0.60      3252


