In [None]:
%load_ext autoreload
%autoreload 2
import os
from time import sleep
import numpy as np
import torch
import scipy
import pandas as pd
import random
from scipy import stats
from load_datasets import SynDataOneSrc
from models_train import *
from evaluation import Evaluation

In [None]:
# Choose the compute device once, then report what was selected
# (plus total GPU memory in GiB when CUDA is available).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
  # torch.set_default_tensor_type('torch.cuda.FloatTensor')
  print('Use ***GPU***')
  print(torch.cuda.get_device_properties(0).total_memory / 1024**3, 'GB')
else:
  print('Use CPU')

Use ***GPU***
14.726318359375 GB


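# Train the model

Train $P(W|X)$, $P(Y|X,W)$ and $P(Z|Y,X,W)$ / $P(Y|W,Z)$ on the source data plus the first `N_target_train` target rows, then evaluate them on the remaining target rows.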

In [None]:
# Configuration
OUTPUT_DIR = 'save_outputs/save_one_source_F'
dim_z = [35, 40, 45, 50, 55] # Varying z-dim: one full experiment per entry (use a single value for a quick run)
training_iter = 5000
transfer_flag = FLAGS_LEARN_TRANSFER
n_samples = 100
reg_alpha_w, lr_w = 1e-1, 1e-3
reg_alpha_y, reg_noise_y, lr_y = 1e-2, 1e-1, 1e-3
reg_alpha_zy, lr_zy = 1e-1, 1e-2
N_target_train = 50 # leading target rows used for training; the remaining rows form the test set
SEED = 0

# Seed the Python, NumPy and torch RNGs so Restart & Run All reproduces the results.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Lists to store results: one entry per z-dim setting, each a list over replicates
learn_transfer_trans_factor_lst = []
learn_transfer_test_stats_lst = []
learn_transfer_test_stats_w_lst = []
learn_transfer_test_stats_y_lst = []

# Load data
source_size = 1000 # rows per source domain; checked against the loaded data below
dataset = SynDataOneSrc()
m = 1 # number of sources

# Create output directory (no-op if it already exists)
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _unpack_record(data, rows=slice(None)):
  """Split one dataset record into (W, Y, Y_cf, mu, X), keeping only `rows`.

  Indexing: X = data[0][0], W = data[0][1], Y = data[0][2], Y_cf = data[1][0],
  mu = data[1][1] and data[1][2] stacked column-wise.
  NOTE(review): record layout taken from the original unpacking code; confirm
  against load_datasets.SynDataOneSrc.get_source_target.
  """
  w = data[0][1].reshape(-1)[rows]
  y = data[0][2].reshape(-1)[rows]
  y_cf = data[1][0].reshape(-1)[rows]
  mu = np.concatenate((data[1][1], data[1][2]), axis=1)[rows]
  x = data[0][0][rows]
  return w, y, y_cf, mu, x


def _to_device(arr):
  """Convert a NumPy array to a float32 tensor on `device`."""
  return torch.from_numpy(arr).float().to(device)


def _to_numpy(t):
  """Convert a tensor (any device, possibly requiring grad) to a NumPy array."""
  return t.cpu().detach().numpy()


def _mean_and_sem(per_setting):
  """Return the mean and standard error over replicates (axis 0) for each setting."""
  means = np.array([np.mean(reps, axis=0) for reps in per_setting])
  sems = np.array([stats.sem(reps, axis=0) for reps in per_setting])
  return means, sems


# Train the models: one experiment per z-dim setting (row j of every summary <-> dim_z[j])
# NOTE(review): `discrepancy_idx=j` also changes with the z-dim index, so each setting
# may draw a different dataset; confirm this is intended for a z-dim sweep.
for j, dim_z_j in enumerate(dim_z):
  trans_factors = []
  test_stats = []
  test_stats_w = []
  test_stats_y = []
  for i, (datas, datat) in enumerate(dataset.get_source_target(discrepancy_idx=j)):
    print('======================================================================================')
    print('Dataset: {}, Replicate: {}'.format(j+1, i+1))
    print('======================================================================================')
    sleep(0.5)

    # Source data (all rows); target data split into train (first N_target_train rows) / test (rest)
    Ws, Ys, Y_cfs, mus, Xs = _unpack_record(datas)
    Wt, Yt, Y_cft, mut, Xt = _unpack_record(datat, slice(None, N_target_train))
    Wtte, Ytte, Y_cftte, mutte, Xtte = _unpack_record(datat, slice(N_target_train, None))
    Ts, Tt, Ttte = len(Ys), len(Yt), len(Ytte)

    xs, ys, ws = _to_device(Xs.reshape(Ts, -1)), _to_device(Ys.reshape(-1, 1)), _to_device(Ws.reshape(-1, 1))
    xt, yt, wt = _to_device(Xt.reshape(Tt, -1)), _to_device(Yt.reshape(-1, 1)), _to_device(Wt.reshape(-1, 1))
    xtte, ytte, wtte = _to_device(Xtte.reshape(Ttte, -1)), _to_device(Ytte.reshape(-1, 1)), _to_device(Wtte.reshape(-1, 1))

    # Training tensors stack the source rows before the target-train rows; the domain
    # ranges must describe exactly that layout, so derive them from the real row counts.
    assert Ts == m*source_size, 'expected {} source rows, got {}'.format(m*source_size, Ts)
    source_ranges = [(idx, idx+source_size) for idx in range(0, m*source_size, source_size)]
    domain_ranges = source_ranges + [(m*source_size, m*source_size+Tt)]
    # At test time every source domain is empty and all rows belong to the target domain.
    test_domain_ranges = [(0, 0)]*m + [(0, Ttte)]

    x_train = torch.cat((xs, xt), dim=0)
    y_train = torch.cat((ys, yt), dim=0)
    w_train = torch.cat((ws, wt), dim=0)

    # Train
    print('*** P(W|X)')
    model_w = trainW(train_x=x_train,
                     train_w=w_train.reshape(-1),
                     n_domains=m+1, domain_ranges=domain_ranges,
                     training_iter=training_iter, transfer_flag=transfer_flag,
                     reg_alpha=reg_alpha_w, lr=lr_w)
    w_samples = model_w.sample(x=xtte, n_samples=n_samples)

    print('*** P(Y|X,W)')
    model_y = trainY(train_x=x_train,
                     train_y=y_train,
                     train_w=w_train,
                     n_domains=m+1, domain_ranges=domain_ranges,
                     training_iter=training_iter, transfer_flag=transfer_flag,
                     reg_alpha=reg_alpha_y, reg_noise=reg_noise_y, lr=lr_y)
    y_samples_xw = model_y.sample(x=xtte, w_samples=w_samples, n_samples=n_samples)

    print('*** P(Z|Y,X,W) and P(Y|W,Z)')
    model_zy = trainZY(train_x=x_train,
                       train_y=y_train,
                       train_w=w_train,
                       n_domains=m+1, domain_ranges=domain_ranges, dim_z=dim_z_j,
                       training_iter=training_iter, transfer_flag=transfer_flag,
                       reg_alpha=reg_alpha_zy, lr=lr_zy)

    # Evaluate P(W|X): stored as [accuracy of (prediction > 0), *transfer factor]
    # reshape(-1) keeps the comparison element-wise even if pred returns a column vector.
    w_pred = model_w.pred(xtte, domain_ranges=test_domain_ranges).reshape(-1)
    accur = torch.sum((w_pred > 0)*1.0 == wtte.reshape(-1))/len(xtte)
    tf_w = _to_numpy(model_w.trans_factor())
    test_stats_w.append(np.insert(tf_w, 0, _to_numpy(accur)))

    # Evaluate P(Y|X,W): stored as [MAE, RMSE, *transfer factor].
    # One prediction feeds both metrics so they describe the same output.
    y_pred = model_y.pred(xtte, wtte, domain_ranges=test_domain_ranges).reshape(-1)
    y_err = y_pred - ytte.reshape(-1)
    mae = torch.mean(torch.abs(y_err))
    rmse = torch.sqrt(torch.mean(y_err**2))
    tf_y = _to_numpy(model_y.trans_factor())
    test_stats_y.append(np.concatenate((np.asarray([_to_numpy(mae), _to_numpy(rmse)]), tf_y)))

    # Evaluate ATE, ITE from interventional samples under do(W=0) and do(W=1)
    y_samples_do0, _ = model_zy.sample_v2(xtte, do_w=torch.zeros((Ttte, 1), device=device),
                                          y_samples=y_samples_xw, w_samples=w_samples, n_samples=n_samples)
    y_samples_do1, _ = model_zy.sample_v2(xtte, do_w=torch.ones((Ttte, 1), device=device),
                                          y_samples=y_samples_xw, w_samples=w_samples, n_samples=n_samples)
    y_do0_mean = _to_numpy(torch.mean(y_samples_do0, dim=1))
    y_do1_mean = _to_numpy(torch.mean(y_samples_do1, dim=1))

    # `evaluator` (not `eval`) so the Python builtin is not shadowed.
    evaluator = Evaluation(m0=mutte[:, 0], m1=mutte[:, 1])
    abs_err = evaluator.absolute_err_ate(y_do0_mean, y_do1_mean)
    pehe = evaluator.pehe(y_do0_mean, y_do1_mean)

    test_stats.append((abs_err, pehe))
    trans_factors.append(np.concatenate((tf_w, tf_y, _to_numpy(model_zy.trans_factor()))))

    # Save output of replicate i+1 (cumulative over the replicates of this setting)
    np.savez(os.path.join(OUTPUT_DIR, 'graphs_learn_transfer_varying_zdim_{}_replicate_{}.npz'.format(j+1, i+1)),
             trans_factors=trans_factors,
             test_stats=test_stats,
             test_stats_w=test_stats_w,
             test_stats_y=test_stats_y)

  learn_transfer_trans_factor_lst.append(trans_factors)
  learn_transfer_test_stats_lst.append(test_stats)
  learn_transfer_test_stats_w_lst.append(test_stats_w)
  learn_transfer_test_stats_y_lst.append(test_stats_y)

  # Save summaries for every setting finished so far (overwritten after each setting)
  learn_transfer_stats_mean, learn_transfer_stats_stderr = _mean_and_sem(learn_transfer_test_stats_lst)
  np.savez(os.path.join(OUTPUT_DIR, 'graphs_learn_transfer_stats_varying_zdim.npz'),
           stats_mean=learn_transfer_stats_mean,
           stats_stderr=learn_transfer_stats_stderr)

  learn_transfer_trans_factor_mean, learn_transfer_trans_factor_stderr = _mean_and_sem(learn_transfer_trans_factor_lst)
  np.savez(os.path.join(OUTPUT_DIR, 'graphs_learn_transfer_trans_factor_varying_zdim.npz'),
           stats_mean=learn_transfer_trans_factor_mean,
           stats_stderr=learn_transfer_trans_factor_stderr)

  learn_transfer_stats_w_mean, learn_transfer_stats_w_stderr = _mean_and_sem(learn_transfer_test_stats_w_lst)
  np.savez(os.path.join(OUTPUT_DIR, 'graphs_w_learn_transfer_stats_varying_zdim.npz'),
           stats_mean=learn_transfer_stats_w_mean,
           stats_stderr=learn_transfer_stats_w_stderr)

  learn_transfer_stats_y_mean, learn_transfer_stats_y_stderr = _mean_and_sem(learn_transfer_test_stats_y_lst)
  np.savez(os.path.join(OUTPUT_DIR, 'graphs_y_learn_transfer_stats_varying_zdim.npz'),
           stats_mean=learn_transfer_stats_y_mean,
           stats_stderr=learn_transfer_stats_y_stderr)

Dataset: 1, Replicate: 1
*** P(W|X)
Iter 100/5000 - Loss: 2130.879   Transfer Ratio: [0.478]
Iter 200/5000 - Loss: 1519.055   Transfer Ratio: [0.461]
Iter 300/5000 - Loss: 1141.298   Transfer Ratio: [0.45]
Iter 400/5000 - Loss: 893.074   Transfer Ratio: [0.443]
Iter 500/5000 - Loss: 725.176   Transfer Ratio: [0.439]
Iter 600/5000 - Loss: 611.561   Transfer Ratio: [0.438]
Iter 700/5000 - Loss: 534.223   Transfer Ratio: [0.439]
Iter 800/5000 - Loss: 479.947   Transfer Ratio: [0.442]
Iter 900/5000 - Loss: 440.229   Transfer Ratio: [0.447]
Iter 1000/5000 - Loss: 410.026   Transfer Ratio: [0.452]
Iter 1100/5000 - Loss: 386.351   Transfer Ratio: [0.459]
Iter 1200/5000 - Loss: 367.381   Transfer Ratio: [0.467]
Iter 1300/5000 - Loss: 351.947   Transfer Ratio: [0.475]
Iter 1400/5000 - Loss: 339.253   Transfer Ratio: [0.484]
Iter 1500/5000 - Loss: 328.709   Transfer Ratio: [0.493]
Iter 1600/5000 - Loss: 319.855   Transfer Ratio: [0.503]
Iter 1700/5000 - Loss: 312.318   Transfer Ratio: [0.513]
It

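### Print $\epsilon_{\text{ATE}}$ and $\sqrt{\epsilon_{\text{PEHE}}}$

Mean over replicates of the absolute ATE error and the PEHE computed by `Evaluation`; one row per z-dimension setting.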

In [None]:
# Absolute ATE error and PEHE: mean and standard error over replicates,
# one row per z-dim setting that has been run (row j <-> dim_z[j]).
learn_transfer_stats_mean = np.array([np.mean(setting_stats, axis=0) for setting_stats in learn_transfer_test_stats_lst])
learn_transfer_stats_stderr = np.array([stats.sem(setting_stats, axis=0) for setting_stats in learn_transfer_test_stats_lst])
# Show mean next to stderr with labeled rows/columns; the raw arrays above are kept for later cells.
pd.DataFrame(
    np.concatenate((learn_transfer_stats_mean, learn_transfer_stats_stderr), axis=1),
    index=pd.Index(dim_z[:len(learn_transfer_stats_mean)], name='dim_z'),
    columns=['abs_err_ate_mean', 'pehe_mean', 'abs_err_ate_stderr', 'pehe_stderr'],
)

array([[0.11782894, 1.09099345],
       [0.1497903 , 1.16562962],
       [0.30114955, 1.24740887],
       [0.26461714, 1.26342214],
       [0.23610094, 1.26582838]])

In [None]:
# NOTE(review): this cell recomputes the same ATE/PEHE summary as the cell
# under the "Print epsilon_ATE" header; it can likely be dropped.
learn_transfer_stats_mean = np.array(
  [np.mean(replicate_stats, axis=0) for replicate_stats in learn_transfer_test_stats_lst]
) # learn transfer
learn_transfer_stats_stderr = np.array(
  [stats.sem(replicate_stats, axis=0) for replicate_stats in learn_transfer_test_stats_lst]
) # learn transfer
# To also persist as text, under the same directory the training cell writes to:
# np.savetxt(OUTPUT_DIR + '/graphs_learn_transfer_stats_varying_zdim.txt', np.concatenate((learn_transfer_stats_mean, learn_transfer_stats_stderr), axis=1))
learn_transfer_stats_mean

array([[0.11782675, 1.09099268],
       [0.13794295, 1.17114677],
       [0.223591  , 1.20371401],
       [0.39212068, 1.30781013],
       [0.29160374, 1.25623264]])

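### Print transfer factors

Mean over replicates of the transfer factors of the $P(W|X)$, $P(Y|X,W)$ and $P(Z|Y,X,W)$ / $P(Y|W,Z)$ models, in that column order; one row per z-dimension setting.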

In [None]:
# Transfer factors (W model, Y model, ZY model): mean and standard error
# over replicates, one row per setting.
learn_transfer_trans_factor_mean = np.array(
  [np.mean(replicate_factors, axis=0) for replicate_factors in learn_transfer_trans_factor_lst]
)
learn_transfer_trans_factor_stderr = np.array(
  [stats.sem(replicate_factors, axis=0) for replicate_factors in learn_transfer_trans_factor_lst]
)
learn_transfer_trans_factor_mean

array([[0.7459139 , 0.87398607, 0.8885415 ],
       [0.7467107 , 0.78075504, 0.81378233],
       [0.73397964, 0.67152345, 0.7486127 ],
       [0.72342414, 0.578894  , 0.6908272 ],
       [0.7242181 , 0.5084177 , 0.63143694]], dtype=float32)

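### Print prediction accuracy of $P(W|X)$

First column: test accuracy of $P(W|X)$; remaining column(s): its transfer factor (means over replicates).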

In [None]:
# P(W|X) rows are [accuracy, *transfer factor]; summarize each setting with
# the mean and standard error over its replicates.
learn_transfer_stats_w_mean = np.array(
  [np.mean(replicate_w, axis=0) for replicate_w in learn_transfer_test_stats_w_lst]
)
learn_transfer_stats_w_stderr = np.array(
  [stats.sem(replicate_w, axis=0) for replicate_w in learn_transfer_test_stats_w_lst]
)
learn_transfer_stats_w_mean

array([[0.89463156, 0.7459139 ],
       [0.8930526 , 0.7467107 ],
       [0.89231586, 0.73397964],
       [0.89147365, 0.72342414],
       [0.89263153, 0.7242181 ]], dtype=float32)

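### Print prediction error of $P(Y|W,X)$

First two columns: test MAE and RMSE of $P(Y|W,X)$; remaining column(s): its transfer factor (means over replicates).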

In [None]:
# P(Y|X,W) rows are [MAE, RMSE, *transfer factor]; summarize each setting
# with the mean and standard error over its replicates.
learn_transfer_stats_y_mean = np.array(
  [np.mean(replicate_y, axis=0) for replicate_y in learn_transfer_test_stats_y_lst]
)
learn_transfer_stats_y_stderr = np.array(
  [stats.sem(replicate_y, axis=0) for replicate_y in learn_transfer_test_stats_y_lst]
)
learn_transfer_stats_y_mean

array([[1.5727255 , 2.3868768 , 0.87398607],
       [1.5383657 , 2.3008652 , 0.78075504],
       [1.542381  , 2.2997823 , 0.67152345],
       [1.568529  , 2.335389  , 0.578894  ],
       [1.5976369 , 2.3778934 , 0.5084177 ]], dtype=float32)