In [1]:
# NOTE: this clashes with smartnoise! dunno how to deal with, but we use this opacus
# for fttransformer, so we need to install it here
# ! pip install opacus==1.5.2

In [18]:
import numpy as np
import pandas as pd
import torch

from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, precision_recall_curve, classification_report
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

import ydnpd

from ydnpd.datasets.loader import load_dataset, split_train_eval_datasets

from ydnpd.pretraining.ft_transformer import FTTransformerModel
from ydnpd.pretraining.utils import preprocess_acs_for_ft_transformer, preprocess_acs_for_classification, print_model_performance

from pathlib import Path

LLM_PATH = "llm_datasets"
DATA_PATH = 'ydnpd/datasets/data'


In [19]:
LLM_EXPERIMENTS = {
    experiment_name:
    [
        (f"{experiment_name}/{path.stem}", LLM_PATH)
        for path in Path(f"{LLM_PATH}/{experiment_name}").glob("*.csv")
        ]
        for experiment_name in ["acs"]
}


EXPERIMENTS = {
    experiment_name:
    [
        (f"{experiment_name}/{path.stem}", DATA_PATH)
        for path in Path(f"{DATA_PATH}/{experiment_name}").glob("*.csv")
        ]
        for experiment_name in ["acs"]
}

In [21]:
real_dataset, real_processed_schema, real_domain = load_dataset('acs/national', path=DATA_PATH)
llm_dataset, llm_processed_schema, llm_domain = load_dataset(LLM_EXPERIMENTS['acs'][0][0], path=LLM_PATH)

In [23]:
LLM_EXPERIMENTS['acs'][0][0]

'acs/csv-claude'

### Baseline, non-private

In [22]:
real_df = preprocess_acs_for_classification(real_dataset)
X = real_df.drop('y', axis=1)
y = real_df['y']

X_train_real, X_test_real, y_train_real, y_test_real = train_test_split(X, y, test_size=0.2)

# naive classifier
print(f"Base accuracy: {y_train_real.sum() / len(y_train_real)}")
print(classification_report(y_train_real, np.zeros_like(y_train_real)))
print('AUC:', roc_auc_score(y_train_real, np.zeros_like(y_train_real)))

# random forest
print('Random Forest')
rf = RandomForestClassifier()
rf.fit(X_train_real, y_train_real)
y_pred_real = rf.predict(X_test_real)
print(classification_report(y_test_real, y_pred_real))
print('AUC:', roc_auc_score(y_test_real, y_pred_real))
print()

# xgboost
print('XGBoost')
xgb = XGBClassifier()
xgb.fit(X_train_real, y_train_real)
y_pred_real = xgb.predict(X_test_real)
print(classification_report(y_test_real, y_pred_real))
print('AUC:', roc_auc_score(y_test_real, y_pred_real))
print()


Base accuracy: 0.4996924496386283
              precision    recall  f1-score   support

           0       0.50      1.00      0.67      6507
           1       0.00      0.00      0.00      6499

    accuracy                           0.50     13006
   macro avg       0.25      0.50      0.33     13006
weighted avg       0.25      0.50      0.33     13006

AUC: 0.5
Random Forest


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           0       0.67      0.68      0.68      1622
           1       0.68      0.67      0.67      1630

    accuracy                           0.67      3252
   macro avg       0.67      0.67      0.67      3252
weighted avg       0.67      0.67      0.67      3252

AUC: 0.6746764200827579

XGBoost
              precision    recall  f1-score   support

           0       0.70      0.69      0.70      1622
           1       0.70      0.71      0.70      1630

    accuracy                           0.70      3252
   macro avg       0.70      0.70      0.70      3252
weighted avg       0.70      0.70      0.70      3252

AUC: 0.6998600531041734



### Privacy, with FTTransformer

In [6]:
(
    X_cat_train,
    X_cont_train,
    X_cat_valid,
    X_cont_valid,
    y_train,
    y_valid,
    cat_cardinalities,
    config
) = preprocess_acs_for_ft_transformer(real_dataset)

(
    X_cat_train_llm,
    X_cont_train_llm,
    X_cat_valid_llm,
    X_cont_valid_llm,
    y_train_llm,
    y_valid_llm,
    cat_cardinalities_llm,
    _
) = preprocess_acs_for_ft_transformer(llm_dataset, config=config)


column SEX is categorical
column MSP is categorical
column RAC1P is categorical
column PINCP_DECILE is categorical
column EDU is categorical
column HOUSING_TYPE is categorical
{'SEX', 'EDU', 'HOUSING_TYPE', 'RAC1P', 'PINCP_DECILE', 'MSP'}
Index(['SEX', 'MSP', 'RAC1P', 'PINCP_DECILE', 'EDU', 'HOUSING_TYPE'], dtype='object')
[0 1]
[4 0 5 2 3 1]
[1 0 6 4 7 2 5 3]
[6 7 9 2 3 0 4 8 5 1]
[ 3  7  9  4 11  6  8  5  0  2 10  1]
[0]
cat training features shape: (13006, 6)
cont training features shape: (13006, 0)
cat val features shape: (3252, 6)
cont val features shape: (3252, 0)
training targets shape: (13006,)
val targets shape: (3252,)
cat feature cards: [2, 6, 8, 10, 12, 1]
config: {'SEX': {'type': 'categorical'}, 'MSP': {'type': 'categorical'}, 'RAC1P': {'type': 'categorical'}, 'PINCP_DECILE': {'type': 'categorical'}, 'EDU': {'type': 'categorical'}, 'HOUSING_TYPE': {'type': 'categorical'}}
{'SEX', 'EDU', 'HOUSING_TYPE', 'RAC1P', 'PINCP_DECILE', 'MSP'}
Index(['SEX', 'MSP', 'RAC1P', 'PINCP_DE

In [7]:
print(torch.isnan(X_cat_train).any())
print(torch.isnan(X_cont_train).any())
print(torch.isnan(y_train).any())

tensor(False)
tensor(False)
tensor(False)


#### Baseline, Non-private

In [24]:
classifier = FTTransformerModel(
    dim = 32,
    dim_out = 2,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    batch_size=128,
    num_epochs=20,
    lr=3e-4,
    load_best_model_when_trained=True,
    verbose=True)

classifier.fit(X_cat_train, X_cont_train, y_train.flatten(), cat_cardinalities, X_cont_train.shape[1], use_class_weights=True)

epoch 1/20 loss: 0.8118:   0%|          | 0/92 [00:00<?, ?batch/s]

epoch 1/20 loss: 0.5477: 100%|██████████| 92/92 [00:01<00:00, 68.57batch/s]


val loss - new best: 0.6284480690956116
val accuracy - new best: 0.6571867794004612
val AUC - new best: 0.6988380516768705
epoch 1, validation loss: 0.6284480690956116, epochs without improvement: 0/10
reached 1.
epoch 1, val loss: 0.6284480690956116, epochs without improvement: 0/10
val accuracy - new best: 0.6571867794004612
val auc - new best: 0.6988380516768705


epoch 2/20 loss: 0.6088: 100%|██████████| 92/92 [00:01<00:00, 69.16batch/s]


val loss - new best: 0.6220353841781616
val accuracy - new best: 0.6694850115295926
val AUC - new best: 0.7133553033423771
epoch 2, validation loss: 0.6220353841781616, epochs without improvement: 0/10


epoch 3/20 loss: 0.6430: 100%|██████████| 92/92 [00:01<00:00, 69.11batch/s]


val loss - new best: 0.6173290014266968
val accuracy - new best: 0.6679477325134512
val AUC - new best: 0.7188998896764631
epoch 3, validation loss: 0.6173290014266968, epochs without improvement: 0/10


epoch 4/20 loss: 0.6142: 100%|██████████| 92/92 [00:01<00:00, 69.05batch/s]


epoch 4, validation loss: 0.6218641996383667, epochs without improvement: 1/10


epoch 5/20 loss: 0.6050: 100%|██████████| 92/92 [00:01<00:00, 69.02batch/s]


val loss - new best: 0.6127835512161255
val accuracy - new best: 0.6740968485780169
val AUC - new best: 0.7241011946192417
epoch 5, validation loss: 0.6127835512161255, epochs without improvement: 0/10
reached 5.
epoch 5, val loss: 0.6127835512161255, epochs without improvement: 0/10
val accuracy - new best: 0.6740968485780169
val auc - new best: 0.7241011946192417


epoch 6/20 loss: 0.5662: 100%|██████████| 92/92 [00:01<00:00, 69.00batch/s]


epoch 6, validation loss: 0.6134600043296814, epochs without improvement: 1/10


epoch 7/20 loss: 0.6902: 100%|██████████| 92/92 [00:01<00:00, 68.99batch/s]


epoch 7, validation loss: 0.6189151406288147, epochs without improvement: 2/10


epoch 8/20 loss: 0.6117: 100%|██████████| 92/92 [00:01<00:00, 68.96batch/s]


val loss - new best: 0.6111562252044678
val accuracy - new best: 0.6748654880860876
val AUC - new best: 0.7302068684687756
epoch 8, validation loss: 0.6111562252044678, epochs without improvement: 0/10


epoch 9/20 loss: 0.5377: 100%|██████████| 92/92 [00:01<00:00, 68.97batch/s]


epoch 9, validation loss: 0.6139976382255554, epochs without improvement: 1/10
reached 9.
epoch 9, val loss: 0.6139976382255554, epochs without improvement: 1/10
val accuracy - new best: 0.6748654880860876
val auc - new best: 0.7302068684687756


epoch 10/20 loss: 0.6116: 100%|██████████| 92/92 [00:01<00:00, 59.38batch/s]


val loss - new best: 0.6090763211250305
val accuracy - new best: 0.6717909300538047
val AUC - new best: 0.7349039048850126
epoch 10, validation loss: 0.6090763211250305, epochs without improvement: 0/10


epoch 11/20 loss: 0.6000: 100%|██████████| 92/92 [00:01<00:00, 68.81batch/s]


epoch 11, validation loss: 0.6124808192253113, epochs without improvement: 1/10


epoch 12/20 loss: 0.6435: 100%|██████████| 92/92 [00:01<00:00, 68.81batch/s]


val loss - new best: 0.6073511242866516
val accuracy - new best: 0.6733282090699462
val AUC - new best: 0.7338172417221835
epoch 12, validation loss: 0.6073511242866516, epochs without improvement: 0/10


epoch 13/20 loss: 0.5430: 100%|██████████| 92/92 [00:01<00:00, 68.78batch/s]


epoch 13, validation loss: 0.6100810766220093, epochs without improvement: 1/10
reached 13.
epoch 13, val loss: 0.6100810766220093, epochs without improvement: 1/10
val accuracy - new best: 0.6733282090699462
val auc - new best: 0.7338172417221835


epoch 14/20 loss: 0.5712: 100%|██████████| 92/92 [00:01<00:00, 68.76batch/s]


val loss - new best: 0.6059876084327698
val accuracy - new best: 0.6787086856264412
val AUC - new best: 0.7351004038883129
epoch 14, validation loss: 0.6059876084327698, epochs without improvement: 0/10


epoch 15/20 loss: 0.5720: 100%|██████████| 92/92 [00:01<00:00, 68.79batch/s]


epoch 15, validation loss: 0.6081082224845886, epochs without improvement: 1/10


epoch 16/20 loss: 0.6315: 100%|██████████| 92/92 [00:01<00:00, 68.82batch/s]


epoch 16, validation loss: 0.6128257513046265, epochs without improvement: 2/10


epoch 17/20 loss: 0.5598: 100%|██████████| 92/92 [00:01<00:00, 68.82batch/s]


epoch 17, validation loss: 0.6059964299201965, epochs without improvement: 3/10
reached 17.
epoch 17, val loss: 0.6059964299201965, epochs without improvement: 3/10
val accuracy - new best: 0.6787086856264412
val auc - new best: 0.7351004038883129


epoch 18/20 loss: 0.5998: 100%|██████████| 92/92 [00:01<00:00, 68.78batch/s]


epoch 18, validation loss: 0.6074085235595703, epochs without improvement: 4/10


epoch 19/20 loss: 0.6096: 100%|██████████| 92/92 [00:01<00:00, 68.80batch/s]


epoch 19, validation loss: 0.6071377992630005, epochs without improvement: 5/10


epoch 20/20 loss: 0.6160: 100%|██████████| 92/92 [00:01<00:00, 68.80batch/s]


val loss - new best: 0.6004096269607544
val accuracy - new best: 0.6840891621829363
val AUC - new best: 0.7425152819405578
epoch 20, validation loss: 0.6004096269607544, epochs without improvement: 0/10


In [25]:
print_model_performance(classifier, X_cat_valid, X_cont_valid, y_valid)


AUC: 0.7354
Optimal Threshold: 0.3765
              precision    recall  f1-score   support

         0.0       0.73      0.55      0.63      1640
         1.0       0.64      0.80      0.71      1612

    accuracy                           0.67      3252
   macro avg       0.69      0.67      0.67      3252
weighted avg       0.69      0.67      0.67      3252




### Just Private, no pretraining

In [26]:
classifier = FTTransformerModel(
    dim = 32,
    dim_out = 2,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    batch_size=128,
    num_epochs=20,
    lr=3e-4,
    load_best_model_when_trained=True,
    verbose=True,
    dp=True, # turn on privacy!
    epsilon=3.0) # turn on privacy!

classifier.fit(X_cat_train, X_cont_train, y_train.flatten(), cat_cardinalities, X_cont_train.shape[1], use_class_weights=True)



  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 1/20 loss: 0.7498: 100%|██████████| 92/92 [00:05<00:00, 16.71batch/s]


val loss - new best: 0.7142237424850464
val accuracy - new best: 0.5026902382782475
val AUC - new best: 0.5096024091251296
epoch 1, validation loss: 0.7142237424850464, epochs without improvement: 0/10
reached 1.
epoch 1, val loss: 0.7142237424850464, epochs without improvement: 0/10
val accuracy - new best: 0.5026902382782475
val auc - new best: 0.5096024091251296


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 2/20 loss: 0.7207: 100%|██████████| 92/92 [00:05<00:00, 16.76batch/s]


epoch 2, validation loss: 0.7142502069473267, epochs without improvement: 1/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 3/20 loss: 0.6828: 100%|██████████| 92/92 [00:05<00:00, 16.78batch/s]


val loss - new best: 0.700144350528717
val accuracy - new best: 0.5318985395849347
val AUC - new best: 0.5469987736568227
epoch 3, validation loss: 0.700144350528717, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 4/20 loss: 0.6710: 100%|██████████| 92/92 [00:05<00:00, 16.78batch/s]


epoch 4, validation loss: 0.702039897441864, epochs without improvement: 1/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 5/20 loss: 0.7091: 100%|██████████| 92/92 [00:05<00:00, 16.69batch/s]


val loss - new best: 0.6839051246643066
val accuracy - new best: 0.579554189085319
val AUC - new best: 0.6192512204245326
epoch 5, validation loss: 0.6839051246643066, epochs without improvement: 0/10
reached 5.
epoch 5, val loss: 0.6839051246643066, epochs without improvement: 0/10
val accuracy - new best: 0.579554189085319
val auc - new best: 0.6192512204245326


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 6/20 loss: 0.6602: 100%|██████████| 92/92 [00:05<00:00, 16.66batch/s]


val loss - new best: 0.6824221014976501
val accuracy - new best: 0.6133743274404304
val AUC - new best: 0.6613943379877556
epoch 6, validation loss: 0.6824221014976501, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 7/20 loss: 0.8475: 100%|██████████| 92/92 [00:05<00:00, 16.76batch/s]


epoch 7, validation loss: 0.76437908411026, epochs without improvement: 1/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 8/20 loss: 0.9171: 100%|██████████| 92/92 [00:05<00:00, 16.77batch/s]


epoch 8, validation loss: 0.8422108292579651, epochs without improvement: 2/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 9/20 loss: 1.0259: 100%|██████████| 92/92 [00:05<00:00, 16.78batch/s]


epoch 9, validation loss: 0.9364686608314514, epochs without improvement: 3/10
reached 9.
epoch 9, val loss: 0.9364686608314514, epochs without improvement: 3/10
val accuracy - new best: 0.6133743274404304
val auc - new best: 0.6613943379877556


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 10/20 loss: 0.9229: 100%|██████████| 92/92 [00:05<00:00, 16.79batch/s]


epoch 10, validation loss: 0.9742581844329834, epochs without improvement: 4/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 11/20 loss: 0.8237: 100%|██████████| 92/92 [00:05<00:00, 16.79batch/s]


epoch 11, validation loss: 0.9746822714805603, epochs without improvement: 5/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 12/20 loss: 0.8402: 100%|██████████| 92/92 [00:05<00:00, 16.79batch/s]


epoch 12, validation loss: 0.9871757626533508, epochs without improvement: 6/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 13/20 loss: 0.7875: 100%|██████████| 92/92 [00:05<00:00, 16.80batch/s]


epoch 13, validation loss: 0.9558269381523132, epochs without improvement: 7/10
reached 13.
epoch 13, val loss: 0.9558269381523132, epochs without improvement: 7/10
val accuracy - new best: 0.6133743274404304
val auc - new best: 0.6613943379877556


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 14/20 loss: 0.8751: 100%|██████████| 92/92 [00:05<00:00, 16.78batch/s]


epoch 14, validation loss: 0.9662612080574036, epochs without improvement: 8/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 15/20 loss: 1.0199: 100%|██████████| 92/92 [00:05<00:00, 16.80batch/s]


epoch 15, validation loss: 0.9597510695457458, epochs without improvement: 9/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 16/20 loss: 1.0232: 100%|██████████| 92/92 [00:05<00:00, 16.80batch/s]

epoch 16, validation loss: 0.9574492573738098, epochs without improvement: 10/10
stopping early at epoch 16. no improvement in validation loss for 10 consecutive epochs.
epoch 16, validation loss: 0.9574492573738098, epochs without improvement: 10/10
val accuracy - new best: 0.6133743274404304
val auc - new best: 0.6613943379877556





In [27]:
print_model_performance(classifier, X_cat_valid, X_cont_valid, y_valid)


AUC: 0.6912
Optimal Threshold: 0.0638
              precision    recall  f1-score   support

         0.0       0.72      0.49      0.58      1640
         1.0       0.61      0.80      0.69      1612

    accuracy                           0.65      3252
   macro avg       0.66      0.65      0.64      3252
weighted avg       0.66      0.65      0.64      3252




### With Pretraining, Private

In [28]:
EPOCHS = [1,3,9]
BATCH_SIZES = [4, 32, 128]
LRS = [3e-4, 3e-5]

classifier = FTTransformerModel(
    dim = 32,
    dim_out = 2,
    depth = 6,
    heads = 8,
    attn_dropout = 0.1,
    ff_dropout = 0.1,
    batch_size=128,
    num_epochs=20,
    lr=3e-4,
    load_best_model_when_trained=True,
    verbose=True,
    epsilon=3.0,
    partial_dp=True,
    partial_pretrain_config={
        'X_cat_pre': X_cat_train_llm,
        'X_cont_pre': X_cont_train_llm,
        'y_pre': y_train_llm,
        'categories': cat_cardinalities,
        'num_continuous': X_cont_train_llm.shape[1],
        'pre_epochs': 3,
        'pre_batch_size': 4,
        'pre_lr': 3e-4,
    }
)


classifier.fit(X_cat_train, X_cont_train, y_train.flatten(), cat_cardinalities, X_cont_train.shape[1])

X_cat_pre torch.Size([13969, 6])
X_cont_pre torch.Size([13969, 0])
y_pre torch.Size([13969])


pretraining Epoch 1/3 loss: 0.0499: 100%|██████████| 3143/3143 [00:41<00:00, 75.92batch/s]


pretraining val loss - new best: 0.106063112616539
pretraining epoch 1, validation Loss: 0.106063112616539, epochs without improvement: 0/10


pretraining Epoch 2/3 loss: 0.0423: 100%|██████████| 3143/3143 [00:41<00:00, 75.94batch/s]


pretraining val loss - new best: 0.09852214902639389
pretraining epoch 2, validation Loss: 0.09852214902639389, epochs without improvement: 0/10


pretraining Epoch 3/3 loss: 0.0392: 100%|██████████| 3143/3143 [00:41<00:00, 76.03batch/s]


pretraining val loss - new best: 0.09302705526351929
pretraining epoch 3, validation Loss: 0.09302705526351929, epochs without improvement: 0/10


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
epoch 1/20 loss: 1.7484:  38%|███▊      | 35/92 [00:02<00:03, 14.74batch/s]


KeyboardInterrupt: 

In [None]:
print_model_performance(classifier, X_cat_valid, X_cont_valid, y_valid)


AUC: 0.6827
Optimal Threshold: 0.0635
              precision    recall  f1-score   support

         0.0       0.70      0.49      0.58      1640
         1.0       0.60      0.79      0.68      1612

    accuracy                           0.64      3252
   macro avg       0.65      0.64      0.63      3252
weighted avg       0.65      0.64      0.63      3252


