In [2]:
import math
import numpy as np
import pandas as pd
import torch
import torch.distributions as td
from torch import nn, optim
from torch.nn import functional as F
from tqdm import tqdm
from tqdm.auto import trange
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import mean_squared_error 
from SinkhornDistance import SinkhornDistance
from utils import *

        

In [3]:
import pandas as pd
import numpy as np
from tqdm import tqdm
path = 'csv_files/' #path of IHDP files, downloaded from https://www.fredjo.com/
def processed_data(i):
    train_data = pd.read_csv(path+'1000_train'+str(i)+'.csv').values
    train_data[:,13] = train_data[:,13]-1 #processing according to CEVAE
    test_data = pd.read_csv(path+'1000_test'+str(i)+'.csv').values
    test_data[:,13] = test_data[:,13]-1 #processing according to CEVAE
    train_label = np.zeros((len(train_data),3))
    t= pd.read_csv(path+'1000_train_t'+str(i)+'.csv').values.flatten()
    train_label[:,0] = t
    yf = pd.read_csv(path+'1000_train_yf'+str(i)+'.csv').values.flatten()
    ycf = pd.read_csv(path+'1000_train_ycf'+str(i)+'.csv').values.flatten()
    train_label[:,1] = np.where(t==1,yf, ycf) #treatment effect
    train_label[:,2] = np.where(t==0,yf, ycf) #non-treatment effect
    train_ite = np.zeros((len(train_data),1))
    train_ite = np.where(t==1,yf-ycf, ycf-yf)
    ite_test = pd.read_csv(path+'1000_test_ite'+str(i)+'.csv').values
    train_eval_split = int(0.8*len(train_data))
    indices = np.random.permutation(train_data.shape[0])
    training_idx, eval_idx = indices[:train_eval_split], indices[train_eval_split:]
    return train_data[training_idx,:],test_data,train_label[training_idx,:],\
            train_data[eval_idx,:],train_ite[training_idx],train_ite[eval_idx], \
            ite_test
    

In [4]:
#specify the GPU (or CPU)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

VAE = DisentangledVAE(n_epochs=150, number_of_labels=3,input_dimension =25 ,
                      latent_dimension = 10, 
                      hidden_layer_width=500,device=device)
VAE.batch_size=32
VAE.optimizer = optim.Adam(
    list(VAE.encoder.parameters()) + list(VAE.decoder.parameters()), lr=1e-4
    # ,weight_decay=1e-5
)
#prediction loss
VAE.pred_weight = [1,1]
# KL loss
VAE.KL_weight=1
#TC loss
VAE.beta = .1
#mmd loss
VAE.gamma=.1
VAE.plot=False
#recon loss
VAE.recon_weight = .1

#early stopping for model training
VAE.early_stopper = EarlyStopper(patience=4, min_delta=0.05)

In [5]:
train_losses=[]
eval_losses=[]
test_losses=[]

#in this demo, we take all 1000 replications
for i in trange(1000, position=0, desc="replication", leave=True, colour='black',):
    train_data,test_data,train_label,eval_data, ite_train, ite_eval,ite_test = processed_data(i)   
    VAE.wasserstein = 1
    VAE.pred_weight[0] =min(1, 0.1*train_label[train_label[:,0]==1,1].std())
    VAE.pred_weight[1] =min(1, 0.1*train_label[train_label[:,0]==0,2].std())
    VAE.early_stopper = EarlyStopper(patience=3, min_delta=0.05)
    score_train,score_eval,test_score = VAE.trainer(train_data,test_data,train_label,
                                                    eval_data, ite_train, ite_eval,
                                                    ite_test)
    train_losses.append(score_train)
    eval_losses.append(score_eval)
    test_losses.append(test_score)
    if i%20==0:
        print(i, np.mean(test_losses))


replication:   0%|          | 0/1000 [00:00<?, ?it/s]

0 0.36673075087048973
20 1.0721320370996037
40 1.3064795979319856
60 1.1608240313907012
80 1.1163837855163026
100 1.1904657508623415
120 1.1690790698804232
140 1.1720435323176737
160 1.157853035685819
180 1.1205391696619005
200 1.0802124438247944
220 1.097973141688663
240 1.0771746578315031
260 1.08160120443023
280 1.0683882011764099
300 1.058945218057044
320 1.0678927871022812
340 1.0908179772283049
360 1.1241088614522767
380 1.1131061030075724
400 1.1146458952685374
420 1.109792992338218
440 1.1161702199165142
460 1.1047634975523093
480 1.1178940265120145
500 1.1046351543822504
520 1.11288130100727
540 1.1032957427532541
560 1.094191545426488
580 1.0908665840341385
600 1.110925355173489
620 1.1088563275417196
640 1.11000388824785
660 1.1269817487752671
680 1.133489319041317
700 1.1328883633227878
720 1.12911863721639
740 1.1291287195319482
760 1.128368667748966
780 1.1431606589010046
800 1.1460850674714804
820 1.155770185142447
840 1.159343890296177
860 1.1559782906990794
880 1.15527

In [6]:
#print the results
from scipy.stats import sem
results = np.array(test_losses)
print(
    '\n The average rpehe of test data are: ',
    np.mean(results),
    'The standard error of epehe in test data are: ',
    sem(results))


 The average rpehe of test data are:  1.1784522819642695 The standard error of epehe in test data are:  0.038954680888623035
