# Run GANITE on Twins dataset

In [1]:
import numpy as np
import pandas as pd
from scipy.special import expit
import tensorflow as tf
from sklearn.model_selection import train_test_split
import itertools
import logging
# switch to logging.DEBUG to debug model
logging.basicConfig(level=logging.INFO)

### Data Preprocesssing

In [2]:
from data.Twins import Twins
twins = Twins()
num_patients, num_features = twins.X.shape
opt_y = twins.one_year_mortality(twins.Y)
T = twins.treatment_assignment(twins.X)
Y = twins.observable_outcomes(opt_y, T)

### Split data into 56/24/20 train, validation, test

In [3]:
train_X, test_X, train_T, test_T, train_Y, test_Y, train_OptY, test_OptY = train_test_split(
    twins.X, T, Y, opt_y, test_size=0.2)
train_X, validate_X, train_T, validate_T, train_Y, validate_Y, train_OptY, validate_OptY = train_test_split(
    train_X, train_T, train_Y, train_OptY, test_size=0.3)

dim_outcome = test_OptY.shape[1]

### Train model

In [4]:
from api import Model
num_iterations = 1000
num_kk = 10
_alpha = 1
_mini_batch_size = 128
_h_dim = 30
ganite = Model('GANITE',  num_kk, num_iterations, _alpha, _mini_batch_size, int(_h_dim))
ganite.fit(train_X, train_Y, train_T, dim_outcome)



















Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
























100%|██████████| 1000/1000 [00:15<00:00, 62.66it/s]
100%|██████████| 1000/1000 [00:01<00:00, 626.31it/s]


### Predict outcome on new features

In [5]:
result = ganite.predict(test_X)

### Evaluate with metrics 

In [6]:
from api.metrics import PEHE
ganite.test(test_X, test_OptY, metric=PEHE)







0.08639763

## Hyperparameter tuning

Limitations of this approach:
- here I am running the model once per hyperparameter combination. Should do 100 iterations to match workings in paper.
- naive early stopping approach by using num_iterations as hyperparameter. TODO use more advanced approach e.g. https://gist.github.com/ryanpeach/9ef833745215499e77a2a92e71f89ce2
- Maybe could using K fold CV search here but using 56/24/20 split to match paper methodology

In [7]:
# hyperparameters
mini_batch_size = {32, 64, 128, 256}
alpha = {0, 0.1, 0.5, 1, 2, 5, 10}
h_dim = {num_features, np.ceil(num_features/2.), np.ceil(num_features/3.), np.ceil(num_features/4.), np.ceil(num_features/5.)}
num_iterations = {2000,1000, 500, 250, 125, 60}

In [None]:
from api import Model
from api.metrics import PEHE
import json
num_kk = 10
results = {}

for _num_iter, _alpha, _mini_batch_size, _h_dim in itertools.product(num_iterations, alpha, mini_batch_size, h_dim):
    ganite = Model('GANITE', num_kk, _num_iter, _alpha, _mini_batch_size, int(_h_dim))
    
    # fit on train set     
    ganite.fit(train_X, train_T, train_Y, dim_outcome)
    # test on validation set     
    pehe = ganite.test(validate_X, validate_OptY, metric=PEHE)
    results[(_num_iter, _alpha, _mini_batch_size, _h_dim)] = pehe
    
with open('results.json', 'w') as f:
    f.write(json.dumps({str(k):str(v) for k,v in results.items()}))

100%|██████████| 1000/1000 [00:13<00:00, 74.27it/s]
100%|██████████| 1000/1000 [00:01<00:00, 603.76it/s]
100%|██████████| 1000/1000 [00:13<00:00, 73.88it/s]
100%|██████████| 1000/1000 [00:01<00:00, 651.63it/s]
100%|██████████| 1000/1000 [00:13<00:00, 73.14it/s]
100%|██████████| 1000/1000 [00:01<00:00, 687.21it/s]
100%|██████████| 1000/1000 [00:14<00:00, 71.10it/s]
100%|██████████| 1000/1000 [00:01<00:00, 677.99it/s]
100%|██████████| 1000/1000 [00:14<00:00, 70.00it/s]
100%|██████████| 1000/1000 [00:01<00:00, 682.79it/s]
100%|██████████| 1000/1000 [00:15<00:00, 65.94it/s]
100%|██████████| 1000/1000 [00:01<00:00, 627.02it/s]
100%|██████████| 1000/1000 [00:14<00:00, 66.90it/s]
100%|██████████| 1000/1000 [00:01<00:00, 684.27it/s]
100%|██████████| 1000/1000 [00:15<00:00, 65.01it/s]
100%|██████████| 1000/1000 [00:01<00:00, 649.14it/s]
100%|██████████| 1000/1000 [00:16<00:00, 61.68it/s]
100%|██████████| 1000/1000 [00:01<00:00, 615.90it/s]
100%|██████████| 1000/1000 [00:18<00:00, 54.04it/s]
100

100%|██████████| 1000/1000 [00:17<00:00, 57.08it/s]
100%|██████████| 1000/1000 [00:01<00:00, 608.48it/s]
100%|██████████| 1000/1000 [00:18<00:00, 53.13it/s]
100%|██████████| 1000/1000 [00:01<00:00, 564.95it/s]
100%|██████████| 1000/1000 [00:14<00:00, 68.73it/s]
100%|██████████| 1000/1000 [00:01<00:00, 557.05it/s]
100%|██████████| 1000/1000 [00:15<00:00, 65.88it/s]
100%|██████████| 1000/1000 [00:01<00:00, 649.08it/s]
100%|██████████| 1000/1000 [00:14<00:00, 67.37it/s]
100%|██████████| 1000/1000 [00:01<00:00, 617.48it/s]
100%|██████████| 1000/1000 [00:16<00:00, 62.28it/s]
100%|██████████| 1000/1000 [00:01<00:00, 641.39it/s]
100%|██████████| 1000/1000 [00:15<00:00, 64.33it/s]
100%|██████████| 1000/1000 [00:01<00:00, 620.22it/s]
100%|██████████| 1000/1000 [00:15<00:00, 63.55it/s]
100%|██████████| 1000/1000 [00:01<00:00, 611.67it/s]
100%|██████████| 1000/1000 [00:15<00:00, 64.22it/s]
100%|██████████| 1000/1000 [00:01<00:00, 652.59it/s]
100%|██████████| 1000/1000 [00:16<00:00, 60.73it/s]
100

100%|██████████| 2000/2000 [00:28<00:00, 71.24it/s]
100%|██████████| 2000/2000 [00:03<00:00, 657.22it/s]
100%|██████████| 2000/2000 [00:29<00:00, 67.25it/s]
100%|██████████| 2000/2000 [00:02<00:00, 679.62it/s]
100%|██████████| 2000/2000 [00:30<00:00, 65.52it/s]
100%|██████████| 2000/2000 [00:03<00:00, 636.34it/s]
100%|██████████| 2000/2000 [00:31<00:00, 63.66it/s]
100%|██████████| 2000/2000 [00:03<00:00, 666.44it/s]
100%|██████████| 2000/2000 [00:26<00:00, 76.86it/s]
100%|██████████| 2000/2000 [00:02<00:00, 740.33it/s]
100%|██████████| 2000/2000 [00:25<00:00, 77.42it/s]
100%|██████████| 2000/2000 [00:02<00:00, 776.10it/s]
100%|██████████| 2000/2000 [00:27<00:00, 73.66it/s]
100%|██████████| 2000/2000 [00:02<00:00, 735.41it/s]
100%|██████████| 2000/2000 [00:27<00:00, 72.58it/s]
100%|██████████| 2000/2000 [00:02<00:00, 732.54it/s]
100%|██████████| 2000/2000 [00:28<00:00, 71.25it/s]
100%|██████████| 2000/2000 [00:02<00:00, 723.73it/s]
100%|██████████| 2000/2000 [00:29<00:00, 68.14it/s]
100

100%|██████████| 2000/2000 [00:30<00:00, 66.29it/s]
100%|██████████| 2000/2000 [00:03<00:00, 656.96it/s]
100%|██████████| 2000/2000 [00:28<00:00, 69.91it/s]
100%|██████████| 2000/2000 [00:02<00:00, 708.99it/s]
100%|██████████| 2000/2000 [00:28<00:00, 70.33it/s]
100%|██████████| 2000/2000 [00:02<00:00, 715.91it/s]
100%|██████████| 2000/2000 [00:29<00:00, 66.72it/s]
100%|██████████| 2000/2000 [00:03<00:00, 664.83it/s]
100%|██████████| 2000/2000 [00:30<00:00, 64.73it/s]
100%|██████████| 2000/2000 [00:02<00:00, 697.20it/s]
 29%|██▉       | 582/2000 [00:09<00:25, 55.15it/s]

### Train on optimal hyperparameters and predict on test set (not validation set)

In [None]:
for k in results:
    if results[k] == min(results.values()):
        opt_hyperparameters = k
        
model = Model('GANITE', num_kk, *opt_hyperparameters)
model.fit(train_X, train_Y, train_T, dim_outcome)
hat_Y = model.predict(test_X)
pehe = model.test(test_X, test_OptY, metric=PEHE)
