In [None]:
%load_ext autoreload
%autoreload 2
import os
from time import sleep
import numpy as np
import torch
import scipy
import pandas as pd
import random
from scipy import stats
from load_datasets import TWINS
from models_train import *
from evaluation import Evaluation

In [None]:
# Report the available hardware and pick the device every tensor below is moved to.
if not torch.cuda.is_available():
  print('Use CPU')
  device = torch.device("cpu")
else:
  # torch.set_default_tensor_type('torch.cuda.FloatTensor')
  print('Use ***GPU***')
  # Total memory of GPU 0, converted from bytes to GB.
  print(torch.cuda.get_device_properties(0).total_memory/1024/1024/1024,'GB')
  device = torch.device("cuda:0")

Use ***GPU***
14.726318359375 GB


# Train the model

In [None]:
# Configuration
OUTPUT_DIR = 'save_outputs/save_twins'  # every .npz result file is written here
dim_z = 30                              # passed to trainZY as dim_z
training_iter = 5000                    # optimisation iterations, shared by all three models
transfer_flag = FLAGS_LEARN_TRANSFER    # transfer mode handed to trainW / trainY / trainZY
n_samples = 100                         # number of samples requested from each model's sampler
# Per-model regularisation strengths and learning rates.
reg_alpha_w, lr_w = 1e-3, 1e-1                        # P(W|X)
reg_alpha_y, reg_noise_y, lr_y =1e-2, 1e-2, 1e-1      # P(Y|X,W)
reg_alpha_zy, lr_zy =1e-1, 1e-1                       # P(Z|Y,X,W) and P(Y|W,Z)
N_target_train = 100                    # number of leading target rows used for training

# Lists to store results (one entry is appended per delta)
learn_transfer_trans_factor_lst = []
learn_transfer_test_stats_lst = []
learn_transfer_test_stats_w_lst = []
learn_transfer_test_stats_y_lst = []

# Load data
dataset = TWINS()
m = 1 # number of sources

# Create output directory.
# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Number of trailing target rows that are excluded from the test split.
# NOTE(review): presumably reserved as a validation hold-out -- confirm against load_datasets.
N_TARGET_HOLDOUT = 100


def _split_columns(data):
  """Split a 2-D numpy array into the tuple (W, Y, Y_cf, mu, X) by column position.

  Columns 0, 1 and 2 are returned as flat vectors, columns 3-4 as a two-column
  array and all remaining columns as the feature matrix.
  NOTE(review): presumably treatment, factual outcome, counterfactual outcome,
  the two potential-outcome means and the covariates -- confirm against load_datasets.
  """
  return (data[:,0].reshape(-1),
          data[:,1].reshape(-1),
          data[:,2].reshape(-1),
          np.concatenate((data[:,3:4],data[:,4:5]),axis=1),
          data[:,5:])


def _feature_tensor(features, n_rows):
  """Return `features` as a float tensor of shape (n_rows, -1) on `device`."""
  return torch.from_numpy(features.reshape(n_rows,-1)).float().to(device)


def _column_tensor(values):
  """Return `values` as a float column tensor of shape (-1, 1) on `device`."""
  return torch.from_numpy(values.reshape(-1,1)).float().to(device)


def _to_numpy(tensor):
  """Move a tensor to the CPU, detach it from the graph and convert it to numpy."""
  return tensor.cpu().detach().numpy()


def _save_summary(results_per_delta, filename):
  """Aggregate per-replicate results for every delta and save them under OUTPUT_DIR.

  Args:
    results_per_delta: list with one entry per delta, each a list of per-replicate rows.
    filename: name of the .npz file written inside OUTPUT_DIR.

  Returns:
    (mean, stderr) arrays with one row per delta. stderr is scipy's standard
    error of the mean, which is NaN when a delta has a single replicate.
  """
  mean = np.array([np.mean(rows,axis=0) for rows in results_per_delta])
  stderr = np.array([stats.sem(rows,axis=0) for rows in results_per_delta])
  np.savez(OUTPUT_DIR + '/' + filename, stats_mean=mean, stats_stderr=stderr)
  return mean, stderr


# Train the models
# for delta in [0.0, 0.5, 1.0, 1.5, 2.5]:
for delta in [0.0]: # This is an example with delta=0, uncomment the above to run all
  loss_lst = []
  trans_factors = []
  test_stats = []
  test_stats_w = []
  test_stats_y = []
  for i, (datas, datat, source_ranges) in enumerate(dataset.get_source_target(delta=delta)):
    print('======================================================================================')
    print('#Delta: {}, Replicate: {}'.format(delta, i+1))
    print('======================================================================================')
    sleep(0.5)
    # Source data to tensor
    Ws, Ys, Y_cfs, mus, Xs = _split_columns(datas)
    Ts = len(Ys)
    xs = _feature_tensor(Xs, Ts)
    ys = _column_tensor(Ys)
    ws = _column_tensor(Ws)

    # Target train data to tensor (first N_target_train target rows)
    Wt, Yt, Y_cft, mut, Xt = _split_columns(datat[:N_target_train])
    Tt = len(Yt)
    xt = _feature_tensor(Xt, Tt)
    yt = _column_tensor(Yt)
    wt = _column_tensor(Wt)

    # Target test data to tensor (rows after the training slice, minus the trailing hold-out)
    Wtte, Ytte, Y_cftte, mutte, Xtte = _split_columns(datat[N_target_train:-N_TARGET_HOLDOUT])
    Ttte = len(Ytte)
    xtte = _feature_tensor(Xtte, Ttte)
    ytte = _column_tensor(Ytte)
    wtte = _column_tensor(Wtte)

    # Train
    # The target domain is appended right after the last source range. Its length is the
    # number of target rows actually sliced (Tt), which can be smaller than N_target_train
    # when the target set is short.
    target_start = source_ranges[-1][1]
    domain_ranges =  list(source_ranges) + [(target_start, target_start+Tt)]
    # Source and target training data are stacked once and shared by the three models.
    train_x = torch.cat((xs,xt),dim=0)
    train_y = torch.cat((ys,yt),dim=0)
    train_w = torch.cat((ws,wt),dim=0)

    print('*** P(W|X)')
    model_w = trainW(train_x=train_x,
                     train_w=train_w.reshape(-1),
                     n_domains=m+1, domain_ranges=domain_ranges,
                     training_iter=training_iter, transfer_flag=transfer_flag,
                     reg_alpha=reg_alpha_w, lr=lr_w)
    w_samples = model_w.sample(x=xtte,n_samples=n_samples)

    print('*** P(Y|X,W)')
    model_y = trainY(train_x=train_x,
                     train_y=train_y,
                     train_w=train_w,
                     n_domains=m+1, domain_ranges=domain_ranges, is_binary=True,
                     training_iter=training_iter, transfer_flag=transfer_flag,
                     reg_alpha=reg_alpha_y, reg_noise=reg_noise_y, lr=lr_y)
    y_samples_xw = model_y.sample(x=xtte,w_samples=w_samples,n_samples=n_samples,is_binary=True)

    print('*** P(Z|Y,X,W) and P(Y|W,Z)')
    model_zy = trainZY(train_x=train_x,
                       train_y=train_y,
                       train_w=train_w,
                       n_domains=m+1, domain_ranges=domain_ranges, is_binary=True, dim_z=dim_z,
                       training_iter=training_iter, transfer_flag=transfer_flag,
                       reg_alpha=reg_alpha_zy, lr=lr_zy)

    # At test time every row belongs to the target domain: the m source ranges are empty.
    test_domain_ranges = [(0,0)]*m+[(0,xtte.shape[0])]

    # Evaluate P(W|X): accuracy of thresholding the prediction at 0
    # NOTE(review): assumes model_w.pred returns a flat vector aligned with wtte -- confirm.
    w_pred = model_w.pred(xtte, domain_ranges=test_domain_ranges)
    accur = torch.sum((w_pred>0)*1.0==wtte.reshape(-1))/len(xtte)
    test_stats_w.append(np.insert(_to_numpy(model_w.trans_factor()),0,_to_numpy(accur)))

    # Evaluate P(Y|X,W): one forward pass feeds both MAE and RMSE
    y_err = model_y.pred(xtte, wtte, domain_ranges=test_domain_ranges).reshape(-1) - ytte.reshape(-1)
    mae = torch.mean(torch.abs(y_err))
    rmse = torch.sqrt(torch.mean(y_err**2))
    test_stats_y.append(np.concatenate((np.asarray([_to_numpy(mae), _to_numpy(rmse)]),
                                        _to_numpy(model_y.trans_factor()))))

    # Evaluate ATE, ITE
    y_samples_do0,_ = model_zy.sample_v2(xtte, do_w=torch.zeros((xtte.shape[0],1),device=device),
                                         y_samples=y_samples_xw, w_samples=w_samples, n_samples=n_samples, is_binary=True)
    y_samples_do1,_ = model_zy.sample_v2(xtte, do_w=torch.ones((xtte.shape[0],1),device=device),
                                         y_samples=y_samples_xw, w_samples=w_samples, n_samples=n_samples, is_binary=True)

    # Average the interventional samples over the sample dimension
    y_do0_mean = torch.mean(y_samples_do0,dim=1)
    y_do1_mean = torch.mean(y_samples_do1,dim=1)
    y_do0_np = _to_numpy(y_do0_mean)
    y_do1_np = _to_numpy(y_do1_mean)

    # Y0tte = (1-Wtte)*Ytte + Wtte*Y_cftte
    # Y1tte = Wtte*Ytte + (1-Wtte)*Y_cftte
    # Named `evaluator` so the builtin eval() is not shadowed.
    evaluator = Evaluation(m0=mutte[:,0], m1=mutte[:,1])
    abs_err = evaluator.absolute_err_ate(y_do0_np, y_do1_np)
    pehe = evaluator.pehe(y_do0_np, y_do1_np)

    test_stats.append((abs_err, pehe))
    trans_factors.append(np.concatenate((_to_numpy(model_w.trans_factor()), _to_numpy(model_y.trans_factor()),
                                        _to_numpy(model_zy.trans_factor()))))

    # Save output of replicate i+1 (cumulative over the replicates seen so far for this delta)
    np.savez(OUTPUT_DIR + '/graphs_learn_transfer_delta_{}_fold_{}.npz'.format(delta,i+1),
            trans_factors=trans_factors,
            test_stats=test_stats,
            test_stats_w=test_stats_w,
            test_stats_y=test_stats_y)

  learn_transfer_trans_factor_lst.append(trans_factors)
  learn_transfer_test_stats_lst.append(test_stats)
  learn_transfer_test_stats_w_lst.append(test_stats_w)
  learn_transfer_test_stats_y_lst.append(test_stats_y)

  # Save outputs (rewritten after every delta so partial runs still leave summaries on disk)
  learn_transfer_stats_mean, learn_transfer_stats_stderr = _save_summary(
      learn_transfer_test_stats_lst, 'graphs_learn_transfer_stats.npz')

  learn_transfer_trans_factor_mean, learn_transfer_trans_factor_stderr = _save_summary(
      learn_transfer_trans_factor_lst, 'graphs_learn_transfer_trans_factor.npz')

  learn_transfer_stats_w_mean, learn_transfer_stats_w_stderr = _save_summary(
      learn_transfer_test_stats_w_lst, 'graphs_w_learn_transfer_stats.npz')

  learn_transfer_stats_y_mean, learn_transfer_stats_y_stderr = _save_summary(
      learn_transfer_test_stats_y_lst, 'graphs_y_learn_transfer_stats.npz')

#Delta: 0.0, Replicate: 1
*** P(W|X)
Iter 100/5000 - Loss: 33115.305   Transfer Ratio: [0.097]
Iter 200/5000 - Loss: 6885.087   Transfer Ratio: [0.077]
Iter 300/5000 - Loss: 11743.854   Transfer Ratio: [0.061]
Iter 400/5000 - Loss: 20075.627   Transfer Ratio: [0.047]
Iter 500/5000 - Loss: 18280.318   Transfer Ratio: [0.038]
Iter 600/5000 - Loss: 6982.391   Transfer Ratio: [0.032]
Iter 700/5000 - Loss: 19548.268   Transfer Ratio: [0.029]
Iter 800/5000 - Loss: 15097.925   Transfer Ratio: [0.027]
Iter 900/5000 - Loss: 20046.906   Transfer Ratio: [0.024]
Iter 1000/5000 - Loss: 15301.428   Transfer Ratio: [0.022]
Iter 1100/5000 - Loss: 9490.228   Transfer Ratio: [0.02]
Iter 1200/5000 - Loss: 14960.640   Transfer Ratio: [0.019]
Iter 1300/5000 - Loss: 18906.816   Transfer Ratio: [0.016]
Iter 1400/5000 - Loss: 7600.539   Transfer Ratio: [0.015]
Iter 1500/5000 - Loss: 10073.178   Transfer Ratio: [0.014]
Iter 1600/5000 - Loss: 50302.336   Transfer Ratio: [0.013]
Iter 1700/5000 - Loss: 35414.660 

### Print $\epsilon_{\text{ATE}}, \sqrt{\epsilon_{\text{PEHE}}}$

In [None]:
# One row per delta; the columns follow the (abs_err, pehe) tuples appended per replicate.
learn_transfer_stats_mean = np.array(
    [np.mean(per_delta,axis=0) for per_delta in learn_transfer_test_stats_lst])
# Standard error of the mean across replicates (NaN when a delta has a single replicate).
learn_transfer_stats_stderr = np.array(
    [stats.sem(per_delta,axis=0) for per_delta in learn_transfer_test_stats_lst])
learn_transfer_stats_mean

array([[0.00704312, 0.30221926]])