In [1]:
import numpy as np
import os
import torch
import pandas as pd
import yaml

from sklearn.externals import joblib

from pytorch_utils.cfvae_models import CFVAEModel

In [2]:
# Experiment selection: the outcome to model and the attribute treated as sensitive.
outcome, sensitive_variable = 'los', 'age'

# Root directory under which every input and output of this notebook lives.
data_path = 'data/'

# Inputs: feature matrices (fold '0', sensitive attribute excluded), labels, and grid configs.
features_path = os.path.join(
    data_path, 'features', '0', sensitive_variable + '_excluded'
)
label_path = os.path.join(data_path, 'labels')
config_path = os.path.join(data_path, 'config', 'grid', 'baseline')

# Outputs: one 'scratch' sub-directory per outcome for checkpoints and for performance results.
checkpoints_path = os.path.join(data_path, 'checkpoints', 'scratch', outcome)
performance_path = os.path.join(data_path, 'performance', 'scratch', outcome)

In [3]:
# Create both output directories up front; exist_ok=True makes this cell safe to re-run.
for output_dir in (checkpoints_path, performance_path):
    os.makedirs(output_dir, exist_ok=True)

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Load the serialized feature and label containers from disk.
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary code --
# only point these paths at files you trust.
# NOTE(review): `joblib` is imported from `sklearn.externals` at the top of this notebook;
# that alias was removed in scikit-learn 0.23, so the import needs to become a plain
# `import joblib` before upgrading scikit-learn.
features_file = os.path.join(features_path, 'features.pkl')
labels_file = os.path.join(label_path, 'label_dict.pkl')

features_dict = joblib.load(features_file)
master_label_dict = joblib.load(labels_file)

In [6]:
# Build three dictionaries keyed by split name:
#   data_dict  -> the 'features' entry of each split in features_dict
#   label_dict -> the column/entry for the chosen outcome in each split's labels
#   group_dict -> the column/entry for the sensitive attribute in each split's labels
data_dict = {split: entry['features'] for split, entry in features_dict.items()}
label_dict = {split: labels[outcome] for split, labels in master_label_dict.items()}
group_dict = {split: labels[sensitive_variable] for split, labels in master_label_dict.items()}

In [7]:
# Read '<sensitive_variable>_map.csv' from the label directory.
# presumably maps group codes to readable names -- TODO confirm against the file.
group_map_file = os.path.join(label_path, sensitive_variable + '_map.csv')
group_map = pd.read_csv(group_map_file)

In [8]:
# # with open(os.path.join(config_path, '{}.yaml'.format(grid_element)), 'r') as fp:
# #     config_dict = yaml.load(fp)
    
# Configuration for the CFVAE model, written inline instead of being read from a YAML file.
config_dict = dict(
    # --- Standard parameters ---
    # Number of feature columns in the training matrix.
    input_dim=data_dict['train'].shape[1],
    # Number of distinct group values seen in the training split.
    num_groups=len(np.unique(group_dict['train'])),
    lr=1e-3,
    lr_final_classifier=1e-3,
    gamma=0.99,
    num_epochs=10,
    iters_per_epoch=100,
    output_dim=2,
    batch_size=256,
    sparse=True,
    sparse_mode='binary',

    # --- Size of the VAE ---
    group_embed_dim=64,
    latent_dim=64,
    num_hidden=2,
    drop_prob=0.0,
    resnet=False,
    normalize=False,

    # --- Size of the classifier ---
    hidden_dim_classifier=128,
    num_hidden_classifier=1,
    drop_prob_classifier=0.0,
    resnet_classifier=False,
    normalize_classifier=False,

    # --- Lambda: weights of the individual loss terms ---
    lambda_reconstruction=1e3,
    lambda_mmd=1e4,
    lambda_kl=0.0,
    lambda_classification=1e1,
    lambda_mmd_group=1e3,
)

In [9]:
# For the gender attribute only: keep, in every split, just the rows whose group code is below 2.
# NOTE(review): config_dict['num_groups'] is computed in the previous cell from the
# unfiltered groups -- confirm that is intended when this branch is taken.
if sensitive_variable == 'gender':
    # Compute one boolean mask per split from the unfiltered groups, then apply it to
    # features, labels and groups alike so the three stay row-aligned.
    keep_mask = {split: groups < 2 for split, groups in group_dict.items()}
    data_dict = {split: rows[keep_mask[split]] for split, rows in data_dict.items()}
    label_dict = {split: labels[keep_mask[split]] for split, labels in label_dict.items()}
    group_dict = {split: groups[keep_mask[split]] for split, groups in group_dict.items()}

In [10]:
# Instantiate the model from the configuration above, then print each direct child
# of the wrapped module so the architecture can be inspected.
model = CFVAEModel(config_dict)
for submodule in model.model.children():
    print(submodule)

VAEEncoder(
  (encoder): FeedforwardNet(
    (layers): ModuleList(
      (0): HiddenLinearLayer(
        (linear): LinearLayerWrapper(
          (linear): EmbeddingBagLinear(
            in_features=368117, out_features=256, bias=True
            (embed): EmbeddingBag(368117, 256, mode=sum)
          )
        )
        (dropout): Dropout(p=0.0)
      )
      (1): HiddenLinearLayer(
        (linear): LinearLayerWrapper(
          (linear): Linear(in_features=256, out_features=128, bias=True)
        )
        (dropout): Dropout(p=0.0)
      )
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (reparameterization_layer): ReparameterizationLayer()
)
ConditionalDecoder(
  (decoder): FeedforwardNet(
    (layers): ModuleList(
      (0): HiddenLinearLayer(
        (linear): LinearLayerWrapper(
          (linear): Linear(in_features=128, out_features=256, bias=True)
        )
        (dropout): Dropout(p=0.0)
      )
      (1): HiddenLinearLayer(
        (linear): Lin

In [None]:
%%time
# Run training on the per-split feature, label and group dictionaries; whatever
# model.train returns is kept in `result`.
# NOTE(review): the recorded output is dominated by per-batch shape tuples, presumably
# from debug prints inside the training code -- consider silencing them before sharing.
result = model.train(data_dict, label_dict, group_dict)

Epoch 0/9
----------
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)

(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(256, 368113)
(256, 368117)
(164, 368113)
(164, 368117)
Phase: val:
 loss: 12.539723, elbo: 4.680308, mmd: 0.000245, reconstruction: 0.003165, kl: 4.205809, cla

In [None]:
# result_final_classifier = model.train(data_dict, label_dict, group_dict)
# result_eval = model.predict(data_dict, label_dict, group_dict, phases = ['val', 'test'])

In [None]:
# model.process_result_dict(result)

In [None]:
# model.process_result_dict(result_eval[1])