In [1]:
import os

import numpy as np
import pandas as pd
import torch
import yaml

# `sklearn.externals.joblib` was deprecated in scikit-learn 0.21 and removed in 0.23, so importing it
# fails on current installs. Prefer the standalone joblib package; fall back for very old sklearn.
try:
    import joblib
except ImportError:
    from sklearn.externals import joblib

from pytorch_utils.cfvae_models import CFVAEModel

In [2]:
# Experiment selection: the prediction target and the attribute treated as sensitive.
outcome = 'los'
sensitive_variable = 'age'
data_path = 'data/'

# Inputs: features built with the sensitive attribute excluded (under sub-directory '0'),
# plus the shared label tables.
features_path = os.path.join(
    data_path,
    'features',
    str(0),
    '{}_excluded'.format(sensitive_variable),
)
label_path = os.path.join(data_path, 'labels')

# Outputs of this run: per-outcome 'scratch' directories for checkpoints and performance results.
checkpoints_path, performance_path = (
    os.path.join(data_path, kind, 'scratch', outcome) for kind in ('checkpoints', 'performance')
)

In [3]:
# Locations of the pre-trained CFVAE config and checkpoint for this outcome / sensitive attribute.
config_load_path = os.path.join(data_path, 'config', 'defaults', 'cfvae', outcome, sensitive_variable, 'model_config.yaml')
checkpoint_load_path = os.path.join(data_path, 'checkpoints', 'cfvae_default', outcome, sensitive_variable, str(1))

# os.listdir returns entries in arbitrary order, so indexing it directly picks an unpredictable
# entry (possibly not a checkpoint at all) when the directory holds more than one file.
# Keep only `.chk` files and sort them so the choice is reproducible.
# NOTE(review): if several checkpoints exist, the lexicographically first is used — confirm intent.
checkpoint_files = sorted(name for name in os.listdir(checkpoint_load_path) if name.endswith('.chk'))
assert checkpoint_files, 'no .chk checkpoint found in {}'.format(checkpoint_load_path)
checkpoint_files[0]

'1552683762.4261506.chk'

In [4]:
# Create this run's output directories; exist_ok makes re-running the cell harmless.
for output_dir in (checkpoints_path, performance_path):
    os.makedirs(output_dir, exist_ok=True)

In [5]:
# Prefer the GPU when one is visible to torch, otherwise run on CPU.
# NOTE(review): `device` is not referenced again in the visible cells; presumably CFVAEModel
# selects its own device — confirm.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
# Load the pre-computed features and the master label dictionary.
# joblib.load unpickles its input, which can execute arbitrary code — only load trusted files.
features_file = os.path.join(features_path, 'features.pkl')
label_file = os.path.join(label_path, 'label_dict.pkl')
features_dict = joblib.load(features_file)
master_label_dict = joblib.load(label_file)

In [7]:
def _select_per_split(per_split, key):
    """Return ``{split: per_split[split][key]}`` for every split in ``per_split``."""
    return {split: split_values[key] for split, split_values in per_split.items()}


# Per-split feature matrices, outcome labels and sensitive-group codes.
data_dict = _select_per_split(features_dict, 'features')
label_dict = _select_per_split(master_label_dict, outcome)
group_dict = _select_per_split(master_label_dict, sensitive_variable)

In [8]:
# Read '<sensitive_variable>_map.csv' from the label directory.
# NOTE(review): presumably maps group codes to readable labels, and it is not referenced again
# in the visible cells — confirm it is still needed.
group_map_file = os.path.join(label_path, '{}_map.csv'.format(sensitive_variable))
group_map = pd.read_csv(group_map_file)

In [9]:
# Load the default CFVAE hyperparameters for this outcome / sensitive attribute from YAML.
# yaml.load without an explicit Loader is deprecated since PyYAML 5.1 and a TypeError in PyYAML 6.
# The config only holds plain scalars (ints, floats, bools, strings, None), so safe_load is enough
# and avoids constructing arbitrary Python objects.
with open(config_load_path, 'r') as fp:
    config_dict = yaml.safe_load(fp)

In [10]:
# Display the loaded configuration for reference (bare last expression -> rich display).
config_dict

{'batch_size': 512,
 'drop_prob': 0.25,
 'drop_prob_classifier': 0.25,
 'gamma': None,
 'group_embed_dim': 64,
 'hidden_dim_classifier': 128,
 'input_dim': 368113,
 'iters_per_epoch': 100,
 'lambda_classification': 10.0,
 'lambda_kl': 0.0,
 'lambda_mmd': 10000.0,
 'lambda_mmd_group': 1000.0,
 'lambda_reconstruction': 1000.0,
 'latent_dim': 128,
 'lr': 0.0001,
 'lr_final_classifier': 0.001,
 'normalize': False,
 'normalize_classifier': True,
 'num_epochs': 20,
 'num_groups': 4,
 'num_hidden': 1,
 'num_hidden_classifier': 2,
 'output_dim': 2,
 'resnet': False,
 'resnet_classifier': False,
 'sparse': True,
 'sparse_mode': 'binary'}

In [11]:
# Overrides for the final-classifier fine-tuning stage. They take precedence over the loaded YAML
# values; for example, num_epochs drops from the loaded 20 to 5.
config_dict_final_classifier = {
    'lr_final_classifier': 1e-3,
    'lambda_final_classifier_cf': 0.0,
    'lambda_clp': 1.0,
    'lambda_clp_entropy': 0.0,
    'num_epochs': 5,
}
# Build a new merged dict (overrides win) rather than mutating the loaded config in place.
config_dict = dict(config_dict, **config_dict_final_classifier)

In [12]:
if sensitive_variable == 'gender':
    # Restrict every split to records whose group code is below 2.
    # NOTE(review): presumably codes 0/1 are the two gender categories and higher codes are
    # unknown/other — confirm against the group map CSV.
    keep_rows = {split: groups < 2 for split, groups in group_dict.items()}
    data_dict = {split: values[keep_rows[split]] for split, values in data_dict.items()}
    label_dict = {split: values[keep_rows[split]] for split, values in label_dict.items()}
    group_dict = {split: values[keep_rows[split]] for split, values in group_dict.items()}

In [13]:
# Build the CFVAE model from the merged configuration (YAML defaults plus final-classifier overrides).
model = CFVAEModel(config_dict)

In [14]:
# Restore the pre-trained CFVAE weights. os.listdir order is arbitrary, so keep only `.chk` files
# and sort them so the same checkpoint is loaded on every run.
# NOTE(review): if several checkpoints exist, the lexicographically first is used — confirm intent.
checkpoint_files = sorted(name for name in os.listdir(checkpoint_load_path) if name.endswith('.chk'))
assert checkpoint_files, 'no .chk checkpoint found in {}'.format(checkpoint_load_path)
model.load_weights(os.path.join(checkpoint_load_path, checkpoint_files[0]))

In [15]:
%%time
# Fine-tune the final classifier on all splits. The %%time magic reports the cost, and the per-epoch
# log reports factual and counterfactual metrics for the train and val phases.
classifier_inputs = (data_dict, label_dict, group_dict)
result = model.train_final_classifier(*classifier_inputs)

Epoch 0/4
----------
Phase: train:
 loss: 0.705469, classification: 0.412684, classification_cf: 0.399826, clp: 0.293164, clp_entropy: 1.079765,
Factual
 auc: 0.822840, auprc: 0.517786, brier: 0.128447,
Counterfactual
 auc: 0.830639, auprc: 0.510058, brier: 0.123129,
Phase: val:
 loss: 0.397353, classification: 0.396312, classification_cf: 0.333426, clp: 0.001028, clp_entropy: 0.978255,
Factual
 auc: 0.829614, auprc: 0.571277, brier: 0.122326,
Counterfactual
 auc: 0.905752, auprc: 0.710474, brier: 0.096928,
Best model updated
Epoch 1/4
----------
Phase: train:
 loss: 0.383696, classification: 0.313177, classification_cf: 0.305847, clp: 0.070529, clp_entropy: 0.932480,
Factual
 auc: 0.916244, auprc: 0.768294, brier: 0.090700,
Counterfactual
 auc: 0.918621, auprc: 0.747279, brier: 0.087772,
Phase: val:
 loss: 0.388917, classification: 0.388097, classification_cf: 0.303736, clp: 0.000813, clp_entropy: 0.905352,
Factual
 auc: 0.835235, auprc: 0.578249, brier: 0.120467,
Counterfactual
 auc:

In [16]:
# Evaluate the fine-tuned final classifier on every split. Element 1 of the result holds a per-phase
# dict of metrics (auc, auprc, brier and the loss terms), as shown in the output below.
evaluation_phases = ['train', 'val', 'test']
result_eval = model.predict_final_classifier(
    data_dict,
    label_dict,
    group_dict,
    phases=evaluation_phases,
)
result_eval[1]

{'train': {'auc': [0.9299105703931062],
  'auprc': [0.7988801608337359],
  'brier': [0.08344975838459158],
  'loss': [0.29502871379256246],
  'classification': [0.29429263293743135],
  'classification_cf': [0.2902029851078987],
  'clp': [0.0007358566188690642],
  'clp_entropy': [0.9003165664097309]},
 'val': {'auc': [0.8353877931954911],
  'auprc': [0.5790901869558563],
  'brier': [0.1204830829907415],
  'loss': [0.3885889660853606],
  'classification': [0.38778666808054996],
  'classification_cf': [0.3040125392950498],
  'clp': [0.000798234985454579],
  'clp_entropy': [0.9043718690859855]},
 'test': {'auc': [0.8378277177807129],
  'auprc': [0.5748208666818564],
  'brier': [0.1177468596058258],
  'loss': [0.38143051243745363],
  'classification': [0.38066202860612136],
  'classification_cf': [0.2992565310918368],
  'clp': [0.0007718362817846446],
  'clp_entropy': [0.9004612352217497]}}