In [1]:
from utils_prediction.database import (gbq_connect, gbq_query)
from utils_prediction.dataloader.mimic4 import dataloader
from utils_prediction.preprocessor import (fill_missing,discretizer,binary_discretizer,one_hot_encoder,prune_features)

from utils_prediction.nn.group_fairness import group_regularized_model
from utils_prediction.nn.robustness import group_robust_model

from sklearn.pipeline import Pipeline

import os
import pyarrow as pa
import pyarrow.parquet as pq

#### Get example features from GBQ and save to disk

In [2]:
## Establish connection with GBQ
# The service-account key path can be overridden with the GBQ_SERVICE_ACCOUNT_JSON
# environment variable; the original hard-coded path remains the fallback so the
# cell behaves as before when the variable is unset.
c = gbq_connect(
    service_account_json_path = os.environ.get(
        'GBQ_SERVICE_ACCOUNT_JSON',
        '/hpf/projects/lsung/creds/gbq/mimic.json'
        ),
    project_id = 'mimic-iv-ches'
    )

## Grab example data
df = gbq_query(c, """
    select * from `mimic-iv-ches.demo.mimic4_slice`
    """, verbose = False)

## Save to disk
path = 'data/analysis_id=demo'

# exist_ok=True replaces the exists()-then-makedirs() pattern: no check/create race,
# and the cell can be re-run safely when the directory is already there.
os.makedirs(path, exist_ok=True)

# Write the queried frame as a parquet file inside the analysis directory.
pq.write_table(
    pa.Table.from_pandas(df),
    f"{path}/features.parquet"
    )

Google Big Query Connection Established


#### Load data & split into train,val,test
- retain group info

In [3]:
## Load the 'demo' features saved under 'data', then split them while
## keeping the group variable available for later use.
data = (
    dataloader(analysis_id='demo', features_fpath='data')
    .load_features()
    .split(retain_group_var=True)
)

In [4]:
# Sizes of the train / val / test feature sets, in that order.
tuple(len(split) for split in (data.X_train, data.X_val, data.X_test))

(1400, 300, 300)

#### Preprocessing pipeline

In [5]:
## Pipeline
# Steps are listed in the order they are applied; each entry is (step name, transformer).
preprocessing_steps = [
    (
        'fill missings',
        fill_missing(config={'count': 0, 'marital_status': 'None'}),
    ),
    (
        'prune features',
        prune_features(special_cols={'count': 0}),
    ),
    (
        'discretize counts',
        binary_discretizer(feature_tags_to_include=['count']),
    ),
    (
        'discretize measurements',
        discretizer(feature_tags_to_include=['measurement']),
    ),
    (
        # NOTE(review): 'count' and 'group_var' tagged features are excluded here,
        # presumably because counts are already binarized above and the group
        # variable must stay intact for to_torch -- confirm.
        'one hot encode',
        one_hot_encoder(feature_tags_to_exclude=['count', 'group_var']),
    ),
]
pipe = Pipeline(steps=preprocessing_steps)

#### Preprocess data

In [6]:
# Fit the pipeline on the training features only ...
data.X_train = pipe.fit_transform(data.X_train)

# ... then apply the already-fitted pipeline to the validation and test features.
for split_name in ("X_val", "X_test"):
    setattr(data, split_name, pipe.transform(getattr(data, split_name)))

#### Generate torch dataloaders
- use retained group info to define attribute

In [7]:
# Build the torch dataloaders, using the 'group_var' column as the group
# attribute and leaving the groups unbalanced.
loaders = data.to_torch(group_var_name='group_var', balance_groups=False)

['2008 - 2010' '2011 - 2013' '2014 - 2016' '2017 - 2019']


#### Group Coral Model

In [8]:
## Group CORAL: pick the model class, configure it, then train.
model_class = group_regularized_model(model_type="group_coral")
m = model_class(
    num_epochs=10,
    num_hidden=2,
    # input width is read from the feature tensor of the first training batch
    input_dim=next(iter(loaders['train']))['features'].shape[1],
    drop_prob=0.5,
    lambda_group_regularization=100,
    group_regularization_mode="group",
)
# The value returned by train() is the cell output.
m.train(loaders)

cpu
Epoch 0/9
----------


  return torch._C._cuda_getDeviceCount() > 0


Phase: train:
                 metric  performance
0                   auc     0.524782
1                 auprc     0.079076
2                 brier     0.192339
3              loss_bce     0.570713
4       specificity_0.5     0.749863
5         precision_0.5     0.078960
6            recall_0.5     0.280682
0                  loss     0.983325
1            supervised     0.570713
2  group_regularization     0.004126
Phase: val:
                 metric  performance
0                   auc     0.570798
1                 auprc     0.094337
2                 brier     0.080998
3              loss_bce     0.311385
4       specificity_0.5     1.000000
5         precision_0.5     0.000000
6            recall_0.5     0.000000
0                  loss     0.442179
1            supervised     0.311385
2  group_regularization     0.001308
Best model updated
Epoch 1/9
----------
Phase: train:
                 metric  performance
0                   auc     0.544395
1                 auprc     0.07

{'performance':      phase  epoch                metric  performance
 0    train      0                   auc     0.524782
 1    train      0                 auprc     0.079076
 2    train      0                 brier     0.192339
 3    train      0              loss_bce     0.570713
 4    train      0       specificity_0.5     0.749863
 ..     ...    ...                   ...          ...
 195    val      9         precision_0.5     0.000000
 196    val      9            recall_0.5     0.000000
 197    val      9                  loss     0.268739
 198    val      9            supervised     0.243355
 199    val      9  group_regularization     0.000254
 
 [200 rows x 4 columns]}

#### Group MMD Model

In [9]:
# NOTE(review): the section header reads "Group MMD Model" but the model type
# requested here is "mean_prediction" -- confirm which one is intended.
model_class = group_regularized_model(model_type="mean_prediction")
m = model_class(
    num_epochs=10,
    # input width is read from the feature tensor of the first training batch
    input_dim=next(iter(loaders['train']))['features'].shape[1],
    lambda_group_regularization=1,
    group_regularization_mode="group",
)

cpu


In [10]:
# Train the model configured in the previous cell. The returned dict (shown as the
# cell output) holds a 'performance' table of per-epoch metrics.
m.train(loaders)

Epoch 0/9
----------
Phase: train:
                 metric  performance
0                   auc     0.511476
1                 auprc     0.071333
2                 brier     0.145519
3              loss_bce     0.463508
4       specificity_0.5     0.865894
5         precision_0.5     0.062169
6            recall_0.5     0.113137
0                  loss     0.491889
1            supervised     0.463508
2  group_regularization     0.028382
Phase: val:
                 metric  performance
0                   auc     0.469261
1                 auprc     0.139968
2                 brier     0.075727
3              loss_bce     0.296358
4       specificity_0.5     1.000000
5         precision_0.5     0.000000
6            recall_0.5     0.000000
0                  loss     0.311086
1            supervised     0.296358
2  group_regularization     0.014728
Best model updated
Epoch 1/9
----------
Phase: train:
                 metric  performance
0                   auc     0.555899
1          

{'performance':      phase  epoch                metric  performance
 0    train      0                   auc     0.511476
 1    train      0                 auprc     0.071333
 2    train      0                 brier     0.145519
 3    train      0              loss_bce     0.463508
 4    train      0       specificity_0.5     0.865894
 ..     ...    ...                   ...          ...
 195    val      9         precision_0.5     0.000000
 196    val      9            recall_0.5     0.000000
 197    val      9                  loss     0.331923
 198    val      9            supervised     0.260788
 199    val      9  group_regularization     0.071135
 
 [200 rows x 4 columns]}

#### Group Adversarial Model
- using final layer activations as inputs to discriminator instead of logits of classifier output

#### Gradient reversal between features & discriminator
- Objective: min L_cls + lambda * L_adv, min L_adv

In [11]:
## Adversarial model, gradient-reversal variant
model_class = group_regularized_model(model_type="adversarial")
m = model_class(
    num_epochs=10,
    # input width is read from the feature tensor of the first training batch
    input_dim=next(iter(loaders['train']))['features'].shape[1],
    output_dim_discriminator=4,  # number of groups
    lambda_group_regularization=0.1,
    use_layer_activations=True,
    reverse_gradients=True,
)

4
cpu


In [12]:
## train model
# Runs the 'train' and 'val' phases; the returned dict (shown as the cell output)
# holds a 'performance' table of per-epoch metrics.
m.train(loaders,phases=['train','val'])

Epoch 0/9
----------
Phase: train:
              metric  performance
0                auc     0.472209
1              auprc     0.063448
2              brier     0.172484
3           loss_bce     0.521763
4    specificity_0.5     0.771636
5      precision_0.5     0.058976
6         recall_0.5     0.192221
0               loss     2.011852
1         supervised     0.521763
2      discriminator     1.490088
3  discriminator_alt     0.225503
Phase: val:
              metric  performance
0                auc     0.567037
1              auprc     0.121336
2              brier     0.072796
3           loss_bce     0.283589
4    specificity_0.5     1.000000
5      precision_0.5     0.000000
6         recall_0.5     0.000000
0               loss     1.898414
1         supervised     0.283589
2      discriminator     1.614825
3  discriminator_alt     0.199086
Best model updated
Epoch 1/9
----------
Phase: train:
              metric  performance
0                auc     0.546278
1              

{'performance':      phase  epoch             metric   performance
 0    train      0                auc  4.722095e-01
 1    train      0              auprc  6.344767e-02
 2    train      0              brier  1.724837e-01
 3    train      0           loss_bce  5.217633e-01
 4    train      0    specificity_0.5  7.716360e-01
 ..     ...    ...                ...           ...
 215    val      9         recall_0.5  0.000000e+00
 216    val      9               loss  1.912928e+01
 217    val      9         supervised  2.350570e-01
 218    val      9      discriminator  1.889422e+01
 219    val      9  discriminator_alt  1.058300e-08
 
 [220 rows x 4 columns]}

##### Shuffle Group Labels
- Objective: min L_cls + lambda * L_adv, min L_adv

In [13]:
## Adversarial model, shuffled-group-labels variant
model_class = group_regularized_model(model_type="adversarial")
m = model_class(
    num_epochs=10,
    # input width is read from the feature tensor of the first training batch
    input_dim=next(iter(loaders['train']))['features'].shape[1],
    output_dim_discriminator=4,  # number of groups
    lambda_group_regularization=0.1,
    use_layer_activations=True,
    shuffle_group_labels=True,
)

4
cpu


In [14]:
## train model
# Runs the 'train' and 'val' phases; the returned dict (shown as the cell output)
# holds a 'performance' table of per-epoch metrics.
m.train(loaders,phases=['train','val'])

Epoch 0/9
----------
Phase: train:
              metric  performance
0                auc     0.511055
1              auprc     0.077876
2              brier     0.166854
3           loss_bce     0.509109
4    specificity_0.5     0.791705
5      precision_0.5     0.074290
6         recall_0.5     0.200508
0               loss     0.653808
1         supervised     0.509109
2      discriminator     1.446991
3  discriminator_alt     0.235382
Phase: val:
              metric  performance
0                auc     0.658927
1              auprc     0.171869
2              brier     0.071829
3           loss_bce     0.280031
4    specificity_0.5     1.000000
5      precision_0.5     0.000000
6         recall_0.5     0.000000
0               loss     0.422253
1         supervised     0.280031
2      discriminator     1.422219
3  discriminator_alt     0.241444
Best model updated
Epoch 1/9
----------
Phase: train:
              metric  performance
0                auc     0.614473
1              

{'performance':      phase  epoch             metric  performance
 0    train      0                auc     0.511055
 1    train      0              auprc     0.077876
 2    train      0              brier     0.166854
 3    train      0           loss_bce     0.509109
 4    train      0    specificity_0.5     0.791705
 ..     ...    ...                ...          ...
 215    val      9         recall_0.5     0.000000
 216    val      9               loss     0.373863
 217    val      9         supervised     0.233848
 218    val      9      discriminator     1.400150
 219    val      9  discriminator_alt     0.246563
 
 [220 rows x 4 columns]}

##### Only changing the objective
Objective: min L_cls - lambda * L_adv, min L_adv

In [15]:
## Adversarial model with neither gradient reversal nor label shuffling requested
model_class = group_regularized_model(model_type="adversarial")
m = model_class(
    num_epochs=10,
    # input width is read from the feature tensor of the first training batch
    input_dim=next(iter(loaders['train']))['features'].shape[1],
    output_dim_discriminator=4,  # number of groups
    lambda_group_regularization=0.1,
    use_layer_activations=True,
)

4
cpu


In [16]:
# Runs the 'train' and 'val' phases; the returned dict (shown as the cell output)
# holds a 'performance' table of per-epoch metrics.
m.train(loaders,phases=['train','val'])

Epoch 0/9
----------
Phase: train:
              metric  performance
0                auc     0.480099
1              auprc     0.069826
2              brier     0.187417
3           loss_bce     0.556066
4    specificity_0.5     0.746808
5      precision_0.5     0.065630
6         recall_0.5     0.225548
0               loss     0.395368
1         supervised     0.556066
2      discriminator     1.606978
3  discriminator_alt     0.201162
Phase: val:
              metric  performance
0                auc     0.438031
1              auprc     0.066942
2              brier     0.080499
3           loss_bce     0.310721
4    specificity_0.5     1.000000
5      precision_0.5     0.000000
6         recall_0.5     0.000000
0               loss     0.133595
1         supervised     0.310721
2      discriminator     1.771265
3  discriminator_alt     0.170256
Best model updated
Epoch 1/9
----------
Phase: train:
              metric  performance
0                auc     0.542796
1              

{'performance':      phase  epoch             metric   performance
 0    train      0                auc  4.800993e-01
 1    train      0              auprc  6.982569e-02
 2    train      0              brier  1.874172e-01
 3    train      0           loss_bce  5.560660e-01
 4    train      0    specificity_0.5  7.468077e-01
 ..     ...    ...                ...           ...
 215    val      9         recall_0.5  0.000000e+00
 216    val      9               loss -1.637712e+00
 217    val      9         supervised  2.394783e-01
 218    val      9      discriminator  1.877190e+01
 219    val      9  discriminator_alt  8.186854e-09
 
 [220 rows x 4 columns]}

#### Group IRM Model

In [17]:
## Group IRM: pick the model class and configure it
model_class = group_regularized_model(model_type="group_irm")
m = model_class(
    num_epochs=10,
    # input width is read from the feature tensor of the first training batch
    input_dim=next(iter(loaders['train']))['features'].shape[1],
    lambda_group_regularization=100,
)

cpu


In [18]:
## train model
# Runs the 'train' and 'val' phases; the returned dict (shown as the cell output)
# holds a 'performance' table of per-epoch metrics.
m.train(loaders,phases=['train','val'])

Epoch 0/9
----------
Phase: train:
                 metric  performance
0                   auc     0.426774
1                 auprc     0.062304
2                 brier     0.262056
3              loss_bce     0.718229
4       specificity_0.5     0.471649
5         precision_0.5     0.058461
6            recall_0.5     0.405744
0                  loss     4.959316
1            supervised     0.718229
2  group_regularization     0.042411
Phase: val:
                 metric  performance
0                   auc     0.471387
1                 auprc     0.073846
2                 brier     0.241855
3              loss_bce     0.676356
4       specificity_0.5     0.557554
5         precision_0.5     0.082090
6            recall_0.5     0.500000
0                  loss     0.790089
1            supervised     0.676356
2  group_regularization     0.001137
Best model updated
Epoch 1/9
----------
Phase: train:
                 metric  performance
0                   auc     0.420637
1          

{'performance':      phase  epoch                metric  performance
 0    train      0                   auc     0.426774
 1    train      0                 auprc     0.062304
 2    train      0                 brier     0.262056
 3    train      0              loss_bce     0.718229
 4    train      0       specificity_0.5     0.471649
 ..     ...    ...                   ...          ...
 195    val      9         precision_0.5     0.073529
 196    val      9            recall_0.5     0.454545
 197    val      9                  loss     0.749880
 198    val      9            supervised     0.681769
 199    val      9  group_regularization     0.000681
 
 [200 rows x 4 columns]}

#### Group DRO

In [19]:
## Group DRO: robust model of type "loss"
model_class = group_robust_model(model_type="loss")
m = model_class(
    num_epochs=10,
    # input width is read from the feature tensor of the first training batch
    input_dim=next(iter(loaders['train']))['features'].shape[1],
    num_groups=4,
    lr_lambda=1e-1,
)

cpu


In [20]:
## train model
# Runs the 'train' and 'val' phases; the returned dict (shown as the cell output)
# holds a 'performance' table of per-epoch metrics.
m.train(loaders,phases=['train','val'])

Epoch 0/9
----------
Phase: train:
            metric  performance
0              auc     0.518929
1            auprc     0.070546
2            brier     0.156061
3         loss_bce     0.483661
4  specificity_0.5     0.817290
5    precision_0.5     0.068833
6       recall_0.5     0.183059
0             loss     0.484054
Phase: val:
            metric  performance
0              auc     0.574722
1            auprc     0.105182
2            brier     0.072923
3         loss_bce     0.283061
4  specificity_0.5     1.000000
5    precision_0.5     0.000000
6       recall_0.5     0.000000
0             loss     0.288889
Best model updated
Epoch 1/9
----------
Phase: train:
            metric  performance
0              auc     0.627157
1            auprc     0.107032
2            brier     0.068522
3         loss_bce     0.262198
4  specificity_0.5     1.000000
5    precision_0.5     0.000000
6       recall_0.5     0.000000
0             loss     0.270888
Phase: val:
            metric  per

{'performance':      phase  epoch           metric  performance
 0    train      0              auc     0.518929
 1    train      0            auprc     0.070546
 2    train      0            brier     0.156061
 3    train      0         loss_bce     0.483661
 4    train      0  specificity_0.5     0.817290
 ..     ...    ...              ...          ...
 155    val      9         loss_bce     0.236448
 156    val      9  specificity_0.5     1.000000
 157    val      9    precision_0.5     0.000000
 158    val      9       recall_0.5     0.000000
 159    val      9             loss     0.310260
 
 [160 rows x 4 columns]}

#### Group DRO with IRM Penalty

In [21]:
## Group DRO with IRM penalty: robust model of type "IRM_penalty_proxy"
model_class = group_robust_model(model_type="IRM_penalty_proxy")
m = model_class(
    num_epochs=10,
    # input width is read from the feature tensor of the first training batch
    input_dim=next(iter(loaders['train']))['features'].shape[1],
    num_groups=4,
    lr_lambda=1e-1,
)

cpu


In [22]:
# Train the model configured in the previous cell. The returned dict (shown as the
# cell output) holds a 'performance' table of per-epoch metrics.
m.train(loaders)

Epoch 0/9
----------
Phase: train:
            metric  performance
0              auc     0.516736
1            auprc     0.075077
2            brier     0.106699
3         loss_bce     0.373032
4  specificity_0.5     0.975811
5    precision_0.5     0.084395
6       recall_0.5     0.028978
0             loss     0.371462
Phase: val:
            metric  performance
0              auc     0.538914
1            auprc     0.096358
2            brier     0.069749
3         loss_bce     0.271227
4  specificity_0.5     1.000000
5    precision_0.5     0.000000
6       recall_0.5     0.000000
0             loss     0.269769
Best model updated
Epoch 1/9
----------
Phase: train:
            metric  performance
0              auc     0.636562
1            auprc     0.129468
2            brier     0.065598
3         loss_bce     0.252060
4  specificity_0.5     1.000000
5    precision_0.5     0.000000
6       recall_0.5     0.000000
0             loss     0.249095
Phase: val:
            metric  per

{'performance':      phase  epoch           metric  performance
 0    train      0              auc     0.516736
 1    train      0            auprc     0.075077
 2    train      0            brier     0.106699
 3    train      0         loss_bce     0.373032
 4    train      0  specificity_0.5     0.975811
 ..     ...    ...              ...          ...
 155    val      9         loss_bce     0.237197
 156    val      9  specificity_0.5     1.000000
 157    val      9    precision_0.5     0.000000
 158    val      9       recall_0.5     0.000000
 159    val      9             loss     0.231907
 
 [160 rows x 4 columns]}