In [1]:
import numpy as np
import os
import torch
import pandas as pd
import yaml

# scikit-learn deprecated its vendored joblib (sklearn.externals.joblib) in 0.21
# and removed it in 0.23, so prefer the standalone package and only fall back to
# the vendored module on old installs that do not provide joblib separately.
try:
    import joblib
except ImportError:
    from sklearn.externals import joblib

from pytorch_utils.cfvae_models import CFVAEModel

In [2]:
# Experiment selection: the outcome to predict and the attribute treated as sensitive.
outcome = 'los'
sensitive_variable = 'age'

# Root folder from which every other path in this notebook is derived.
data_path = 'data/'

# Inputs: features live under features/0/<sensitive_variable>_excluded, labels under labels/.
# (Earlier layout, kept for reference: <data_path>/<sensitive_variable>_excluded/features)
features_path = os.path.join(
    data_path, 'features', '0', sensitive_variable + '_excluded'
)
label_path = os.path.join(data_path, 'labels')
# (Earlier grid-config location, kept for reference: <data_path>/config/grid/baseline)

# Outputs: per-outcome "scratch" folders for checkpoints and performance.
checkpoints_path, performance_path = (
    os.path.join(data_path, category, 'scratch', outcome)
    for category in ('checkpoints', 'performance')
)

In [3]:
# YAML file holding the default CFVAE configuration for this outcome / sensitive variable.
config_load_path = os.path.join(data_path, 'config', 'defaults', 'cfvae', outcome, sensitive_variable, 'model_config.yaml')
# Folder containing the CFVAE checkpoint file(s) for this outcome / sensitive variable (sub-folder "1").
checkpoint_load_path = os.path.join(data_path, 'checkpoints', 'cfvae_default', outcome, sensitive_variable, str(1))
# os.listdir returns entries in arbitrary order, so sort before indexing to make
# the displayed checkpoint name reproducible across runs and machines.
sorted(os.listdir(checkpoint_load_path))[0]

'1552683762.4261506.chk'

In [4]:
# Create the two output folders; nothing happens if they already exist.
for output_dir in (checkpoints_path, performance_path):
    os.makedirs(output_dir, exist_ok=True)

In [5]:
# Use CUDA when torch reports it as available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
# NOTE: joblib.load unpickles these files, so only point the paths at trusted data.
features_file = os.path.join(features_path, 'features.pkl')
labels_file = os.path.join(label_path, 'label_dict.pkl')

features_dict = joblib.load(features_file)
master_label_dict = joblib.load(labels_file)

In [7]:
# Per-split dictionaries: the 'features' entry of each split, the labels for the
# chosen outcome, and the values of the chosen sensitive variable.
data_dict = {split: split_features['features'] for split, split_features in features_dict.items()}
label_dict = {split: split_labels[outcome] for split, split_labels in master_label_dict.items()}
group_dict = {split: split_labels[sensitive_variable] for split, split_labels in master_label_dict.items()}

In [8]:
# Read <label_path>/<sensitive_variable>_map.csv into a DataFrame.
group_map_file = os.path.join(label_path, '{}_map.csv'.format(sensitive_variable))
group_map = pd.read_csv(group_map_file)

In [9]:
# # with open(os.path.join(config_path, '{}.yaml'.format(grid_element)), 'r') as fp:
# #     config_dict = yaml.load(fp)
    
# CFVAE
# config_dict = {
#     # Standard parameters
#     'input_dim' : data_dict['train'].shape[1],
#     'num_groups' : len(np.unique(group_dict['train'])),
#     'lr' : 1e-3,
#     'lr_final_classifier' : 1e-3,
#     'gamma' : 0.99,
#     'num_epochs' : 10,
#     'iters_per_epoch' : 100,
#     'output_dim' : 2,
#     'batch_size' : 256,
#     'sparse' : True,
#     'sparse_mode' : 'binary',
#     # Parameters corresponding to the size of the VAE
#     'group_embed_dim' : 64,
#     'latent_dim' : 64,
#     'num_hidden' : 2,
#     'drop_prob' : 0.0,
#     'resnet' : False,
#     'normalize' : False,
#     # Parameters corresponding to the size of classifier
#     'hidden_dim_classifier' : 128,
#     'num_hidden_classifier' : 1,
#     'drop_prob_classifier' : 0.0,
#     'resnet_classifier' : False,
#     'normalize_classifier' : False,

#     # Lambda
#     'lambda_reconstruction' : 1e3,
#     'lambda_mmd' : 1e4,
#     'lambda_kl' : 0.0,
#     'lambda_classification' : 1e1,
#     'lambda_mmd_group' : 1e3
# }
# Load the model configuration from the YAML file at config_load_path.
# yaml.safe_load replaces the bare yaml.load(fp): calling yaml.load without an
# explicit Loader is deprecated since PyYAML 5.1 and raises a TypeError in
# PyYAML 6. The safe loader only builds plain Python types (numbers, strings,
# booleans, None, lists, dicts), so the file cannot trigger arbitrary object
# construction.
with open(config_load_path, 'r') as fp:
    config_dict = yaml.safe_load(fp)

In [10]:
# Display the configuration that was loaded from config_load_path.
config_dict

{'batch_size': 512,
 'drop_prob': 0.25,
 'drop_prob_classifier': 0.25,
 'gamma': None,
 'group_embed_dim': 64,
 'hidden_dim_classifier': 128,
 'input_dim': 368113,
 'iters_per_epoch': 100,
 'lambda_classification': 10.0,
 'lambda_kl': 0.0,
 'lambda_mmd': 10000.0,
 'lambda_mmd_group': 1000.0,
 'lambda_reconstruction': 1000.0,
 'latent_dim': 128,
 'lr': 0.0001,
 'lr_final_classifier': 0.001,
 'normalize': False,
 'normalize_classifier': True,
 'num_epochs': 20,
 'num_groups': 4,
 'num_hidden': 1,
 'num_hidden_classifier': 2,
 'output_dim': 2,
 'resnet': False,
 'resnet_classifier': False,
 'sparse': True,
 'sparse_mode': 'binary'}

In [11]:
# Settings for the final-classifier stage. Any key that already exists in
# config_dict takes the value given here once the two are merged below.
config_dict_final_classifier = dict(
    lr_final_classifier=1e-3,
    lambda_final_classifier_cf=0e0,
    lambda_clp=1e1,
    num_epochs=10,
    # num_iters_per_epoch=None,
)

# Build a new dict: the loaded configuration first, then the overrides on top.
config_dict = dict(config_dict, **config_dict_final_classifier)

In [12]:
# For the 'gender' experiment only, keep the rows whose group value is below 2.
if sensitive_variable == 'gender':
    # Compute the boolean row masks once from the unfiltered groups, then apply
    # the same mask to the features, labels and groups of each split.
    keep_rows = {split: groups < 2 for split, groups in group_dict.items()}
    data_dict = {split: rows[keep_rows[split]] for split, rows in data_dict.items()}
    label_dict = {split: rows[keep_rows[split]] for split, rows in label_dict.items()}
    group_dict = {split: rows[keep_rows[split]] for split, rows in group_dict.items()}

In [13]:
# Instantiate the CFVAE model from the configuration dictionary built above
# (loaded YAML config plus the final-classifier overrides).
model = CFVAEModel(config_dict)

In [14]:
# Load model weights from a checkpoint file in checkpoint_load_path.
# os.listdir makes no ordering guarantee, so the entries are sorted before one
# is picked; this keeps the loaded checkpoint the same on every run.
checkpoint_files = sorted(os.listdir(checkpoint_load_path))
if not checkpoint_files:
    # Fail with an explicit message instead of an opaque IndexError.
    raise FileNotFoundError('No checkpoint found in {}'.format(checkpoint_load_path))
model.load_weights(os.path.join(checkpoint_load_path, checkpoint_files[0]))

In [15]:
# from collections import OrderedDict
# new_dict = OrderedDict()
# for key, value in model.model.state_dict().items():
#     if 'classifier' in key:
#         new_dict[key.split('classifier.')[1]] = value
        
# model.final_classifier.load_state_dict(new_dict)

In [16]:
%%time
# Run the final-classifier training on the per-split features, labels and
# sensitive-variable values; the returned value is kept in `result`.
result = model.train_final_classifier(data_dict, label_dict, group_dict)
# result = model.train_final_classifier(small_data_dict, small_label_dict, small_group_dict)

Epoch 0/9
----------
Phase: train:
 loss: 2.989449, classification: 0.533988, classification_cf: 0.526200, clp: 0.245713,
Factual
 auc: 0.628520, auprc: 0.272385, brier: 0.174093,
Counterfactual
 auc: 0.636100, auprc: 0.265299, brier: 0.170472,
Phase: val:
 loss: 0.515817, classification: 0.514401, classification_cf: 0.499870, clp: 0.000141,
Factual
 auc: 0.725882, auprc: 0.414974, brier: 0.164804,
Counterfactual
 auc: 0.759414, auprc: 0.451105, brier: 0.158066,
Best model updated
Epoch 1/9
----------
Phase: train:
 loss: 0.618554, classification: 0.510718, classification_cf: 0.502379, clp: 0.010784,
Factual
 auc: 0.715831, auprc: 0.371831, brier: 0.163706,
Counterfactual
 auc: 0.727256, auprc: 0.367994, brier: 0.159802,
Phase: val:
 loss: 0.510559, classification: 0.510252, classification_cf: 0.493113, clp: 0.000031,
Factual
 auc: 0.746304, auprc: 0.456260, brier: 0.163133,
Counterfactual
 auc: 0.778634, auprc: 0.497388, brier: 0.155320,
Best model updated
Epoch 2/9
----------
Phase: 

In [17]:
# Evaluate the final classifier on the train, val and test splits.
eval_phases = ['train', 'val', 'test']
result_eval = model.predict_final_classifier(data_dict, label_dict, group_dict, phases=eval_phases)

# Show element 1 of the returned value (rendered below as per-phase metrics).
result_eval[1]

{'train': {'auc': [0.8582821340541957],
  'auprc': [0.6876788878264035],
  'brier': [0.14644016707069957],
  'loss': [0.4619004347920418],
  'classification': [0.4618932661414146],
  'classification_cf': [0.44733342707157137],
  'clp': [0.018759957319125532]},
 'val': {'auc': [0.7949950148590735],
  'auprc': [0.538997197176007],
  'brier': [0.1498865325039038],
  'loss': [0.4725887534710077],
  'classification': [0.4725667524796266],
  'classification_cf': [0.44565975207548875],
  'clp': [0.017238417067206822]},
 'test': {'auc': [0.7938179056140078],
  'auprc': [0.5264079380053168],
  'brier': [0.14699792897199132],
  'loss': [0.46595505338448745],
  'classification': [0.4659344955132558],
  'classification_cf': [0.4407166391611099],
  'clp': [0.016661223024129868]}}

In [18]:
# result_final_classifier = model.train(data_dict, label_dict, group_dict)
# result_eval = model.predict(data_dict, label_dict, group_dict, phases = ['val', 'test'])

In [19]:
# model.process_result_dict(result)

In [20]:
# model.process_result_dict(result_eval[1])