In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
import pandas as pd
import random
from pathlib import Path
from time import sleep
import gpytorch
from scipy import stats
import scipy
from scipy import stats
from load_datasets import IHDP
from model_train import train_model
from model_utils import *
from evaluation import Evaluation

# Pick the compute device once; when CUDA is present, newly created float tensors
# default to the GPU (torch.cuda.FloatTensor) and the card's total memory is reported.
use_gpu = torch.cuda.is_available()
if not use_gpu:
  print('Use CPU')
else:
  torch.set_default_tensor_type('torch.cuda.FloatTensor')
  print('Use ***GPU***')
  total_gib = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
  print(total_gib, 'GB')
device = torch.device("cuda:0" if use_gpu else "cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Use ***GPU***
14.755615234375 GB


In [None]:
# Experiment configuration.
n_iterations = 2000
learning_rate = 1e-3
display_per_iter = 200
SEED = 0  # fixed seed so a Restart & Run All reproduces the reported numbers


def _unpack_split(split):
  """Unpack one data split into numpy arrays and float tensors on `device`.

  The indexing mirrors how this notebook reads the splits yielded by
  `dataset.get_train_valid_test_combined`: split[0][0] is the covariates,
  split[0][1] the treatment, split[0][2] the observed outcome and split[1][0]
  the counterfactual outcome (layout inferred from usage -- confirm against
  load_datasets).

  Returns:
    (w, y, y_cf, x, X, Yobs, W): flat numpy treatment / outcome / counterfactual
    vectors, the numpy covariates, and torch tensors X (N, d), Yobs (N, 1), W (N, 1).
  """
  w = split[0][1].reshape(-1)
  y = split[0][2].reshape(-1)
  y_cf = split[1][0].reshape(-1)
  x = split[0][0][:, :]
  n = len(y)
  X = torch.from_numpy(x.reshape(n, -1)).float().to(device)
  Yobs = torch.from_numpy(y.reshape(-1, 1)).float().to(device)
  W = torch.from_numpy(w.reshape(-1, 1)).float().to(device)
  return w, y, y_cf, x, X, Yobs, W


def _evaluator_and_predictions(w, y, y_cf, y_mis_pred):
  """Build an Evaluation on the true potential outcomes and return it with (y0pred, y1pred).

  The observed outcome y is placed in the arm selected by w (w=1 -> y1, w=0 -> y0);
  the other arm is filled with y_cf for the ground truth and with the model's
  predicted missing outcome y_mis_pred for the prediction.
  """
  y0 = (1 - w) * y + w * y_cf
  y1 = w * y + (1 - w) * y_cf
  y_mis_pred = y_mis_pred.reshape(-1)
  y0pred = (1 - w) * y + w * y_mis_pred
  y1pred = w * y + (1 - w) * y_mis_pred
  return Evaluation(m0=y0, m1=y1), y0pred, y1pred


def _summary_stats(w, y, y_cf, y_mis_pred, err_ate_per_source):
  """Return [pehe, absolute ATE error, mean of the per-source absolute ATE errors]."""
  evaluator, y0pred, y1pred = _evaluator_and_predictions(w, y, y_cf, y_mis_pred)
  return [evaluator.pehe(y0pred=y0pred, y1pred=y1pred),
          evaluator.absolute_err_ate(y0pred=y0pred, y1pred=y1pred),
          np.mean(err_ate_per_source)]


# Seed every RNG in play so training is repeatable across re-runs of this cell.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = IHDP()
source_size = dataset.source_size
train_size = dataset.train_size
test_size = dataset.test_size
val_size = dataset.val_size

stats_federated = []
for n_sources in [1, 2, 3]:
  stats_train_all = []
  stats_test_all = []

  # NOTE(review): the validation split (`valid`) is yielded but not used for training or evaluation.
  for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test_combined(n_sources)):
    print('***************************************')
    print('***** n_source {},  Replicate #{} *****'.format(n_sources, i+1))
    print('***************************************')

    w, y, y_cf, x, X, Yobs, W = _unpack_split(train)
    wte, yte, y_cfte, xte, Xte, Yobste, Wte = _unpack_split(test)
    N, Nte = len(y), len(yte)

    model_server, model_sources = train_model(Yobs=Yobs, X=X, W=W, Yobste=Yobste, Xte=Xte, Wte=Wte,
                                              y=y, y_cf=y_cf, x=x, w=w, yte=yte, y_cfte=y_cfte, xte=xte, wte=wte,
                                              n_sources=n_sources, train_size=train_size, test_size=test_size,
                                              n_iterations=n_iterations, learning_rate=learning_rate, display_per_iter=display_per_iter,
                                              bin_feats_4moments=False, cont_feats=contfeats, bin_feats=binfeats)

    # EVALUATION: each source model predicts the missing outcome on its own contiguous block of rows.
    Ymis_pred_tr, Ymis_pred_te = [], []
    err_ate_source_tr, err_ate_source_te = [], []
    for s in range(len(model_sources)):
      model_source = model_sources[s]
      rows_tr = slice(s*train_size, (s+1)*train_size)
      rows_te = slice(s*test_size, (s+1)*test_size)
      _, _, meanYmis_tr, _ = detach_to_numpy(model_source.predictATE(Yobs[rows_tr, :], X[rows_tr, :],
                                                                     W[rows_tr, :], idx_source=s))
      _, _, meanYmis_te, _ = detach_to_numpy(model_source.predictATE(Yobste[rows_te, :], Xte[rows_te, :],
                                                                     Wte[rows_te, :], idx_source=s))
      Ymis_pred_tr.append(meanYmis_tr.reshape(-1))
      Ymis_pred_te.append(meanYmis_te.reshape(-1))

      evaluator, y0pred, y1pred = _evaluator_and_predictions(w[rows_tr], y[rows_tr], y_cf[rows_tr], meanYmis_tr)
      err_ate_source_tr.append(evaluator.absolute_err_ate(y0pred=y0pred, y1pred=y1pred))
      evaluator, y0pred, y1pred = _evaluator_and_predictions(wte[rows_te], yte[rows_te], y_cfte[rows_te], meanYmis_te)
      err_ate_source_te.append(evaluator.absolute_err_ate(y0pred=y0pred, y1pred=y1pred))

    Ymis_pred_tr = np.concatenate(Ymis_pred_tr)
    Ymis_pred_te = np.concatenate(Ymis_pred_te)
    # The per-source blocks must tile the pooled data exactly; otherwise predictions are misaligned.
    assert len(Ymis_pred_tr) == N, 'train predictions ({}) do not cover all {} rows'.format(len(Ymis_pred_tr), N)
    assert len(Ymis_pred_te) == Nte, 'test predictions ({}) do not cover all {} rows'.format(len(Ymis_pred_te), Nte)

    stats_tr = _summary_stats(w, y, y_cf, Ymis_pred_tr, err_ate_source_tr)
    stats_train_all.append(stats_tr)
    print('Last result of replicate #{}'.format(i+1))
    print('Train: pehe {}, err_ate {}, err_ate(s) {}'.format(stats_tr[0], stats_tr[1], stats_tr[2]))

    stats_te = _summary_stats(wte, yte, y_cfte, Ymis_pred_te, err_ate_source_te)
    stats_test_all.append(stats_te)
    print('Test: pehe {}, err_ate {}, err_ate(s) {}'.format(stats_te[0], stats_te[1], stats_te[2]))

    print('=================================================================')
  print('PERFORMANCE, n_sources {}'.format(n_sources))
  print('Train:', np.mean(stats_train_all,axis=0), stats.sem(stats_train_all,axis=0))
  print('Test:', np.mean(stats_test_all,axis=0), stats.sem(stats_test_all,axis=0))
  stats_federated.append(stats_test_all)

***********************
***** n_source 1 *****
***********************
Do not use one-hot encoding, d_x = 25
***** Replicate #1 *****
Iter: 0
Train: pehe 3.9409746159235586, err_ate 2.922787977950478, err_ate(s) 2.922787977950478
Test: pehe 3.0055245767053234, err_ate 2.6039411695321135, err_ate(s) 2.6039411695321135
Iter: 200
Train: pehe 1.3406945381488093, err_ate 0.572168888508632, err_ate(s) 0.572168888508632
Test: pehe 1.3084042701448095, err_ate 0.8175705645743414, err_ate(s) 0.8175705645743414
Iter: 400
Train: pehe 1.6573900456331145, err_ate 1.2330743481885902, err_ate(s) 1.2330743481885902
Test: pehe 1.056607479121483, err_ate 0.36045566277348184, err_ate(s) 0.36045566277348184
Iter: 600
Train: pehe 1.0892356815026343, err_ate 0.25935542696363134, err_ate(s) 0.25935542696363134
Test: pehe 1.1138640476591621, err_ate 0.20536668225959875, err_ate(s) 0.20536668225959875
Iter: 800
Train: pehe 1.099049487854554, err_ate 0.31210818163450416, err_ate(s) 0.31210818163450416
Test: pehe

In [None]:
# Collapse the per-replicate test metrics [pehe, err_ate, mean per-source err_ate]
# into one row per n_sources setting (mean and standard error over replicates)
# and persist both arrays to disk.
STATS_OUTPUT_PATH = Path('save_outputs/save_ihdp/stats_errors_test_federated.npz')

stats_test_federated_mean = np.asarray([np.mean(per_replicate, axis=0) for per_replicate in stats_federated])
stats_test_federated_stderr = np.asarray([stats.sem(per_replicate, axis=0) for per_replicate in stats_federated])

# np.savez does not create missing directories; create them so a long run isn't lost at the final save.
STATS_OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
np.savez(STATS_OUTPUT_PATH,
         stats_test_federated_mean=stats_test_federated_mean,
         stats_test_federated_stderr=stats_test_federated_stderr)
stats_test_federated_mean

array([[2.74792476, 1.28331512, 1.28331512],
       [2.38807871, 0.87869454, 1.1924641 ],
       [2.48809336, 0.6422931 , 0.96383898]])

In [None]:
# Standard error over replicates (stats.sem) of the test metrics, one row per n_sources setting;
# columns follow the stats_te order: [pehe, err_ate, mean per-source err_ate].
stats_test_federated_stderr

array([[0.65060266, 0.44864847, 0.44864847],
       [0.35403448, 0.20386732, 0.15815424],
       [0.60119304, 0.20824323, 0.16676626]])