In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
import pandas as pd
import random
from time import sleep
import scipy
from scipy.stats import sem
from scipy import stats
from load_datasets import SynData5SourcesDiff
from model_train import *
from evaluation import Evaluation

# Probe for CUDA once and derive both the default tensor type and `device` from it.
has_gpu = torch.cuda.is_available()
if not has_gpu:
  print('Use CPU')
else:
  # Make tensors created without an explicit device default to CUDA floats.
  torch.set_default_tensor_type('torch.cuda.FloatTensor')
  print('Use ***GPU***')
  # Report the total memory of GPU 0, converted from bytes to GiB.
  total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
  print(total_memory_gb, 'GB')
# Device handle used by the later cells when moving tensors with `.to(device)`.
device = torch.device("cuda:0" if has_gpu else "cpu")

Use ***GPU***
11.17303466796875 GB


In [None]:
import os  # cell-local import: only needed to create the output directory below

# Configuration
training_iter = 8000
D = 400
learning_rate_w = 1e-3
learning_rate_y = 1e-3
learning_rate_zy = 1e-2
reg_beta_w = 1e-2
reg_beta_y = 1e-1
reg_sig_y = 1e-1
reg_beta_zy = 1e-1
transfer_flag = FLAGS_LEARN_TRANSFER
display_per_iters=100
z_dim = 60
is_binary_outcome = False
use_mh = True
n_samples = 100              # number of samples requested from sampleW / sampleY / testTEs
idx_sources_to_test = [0]    # indices of the sources that are sampled from and evaluated
m_sources_list = [1, 2]      # change to [1,2,3,4,5] to run all
output_dir = 'save_outputs/save_5sources_synthetic2'
seed = 0

# Seed every RNG this notebook imports so that Restart & Run All is repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Create the output directory up front: np.savez does not create missing
# directories, and failing here is far cheaper than failing after training.
os.makedirs(output_dir, exist_ok=True)


def _source_ranges(size, n_sources):
  """Return [(start, start + size), ...] for n_sources consecutive chunks of length size."""
  return [(start, start + size) for start in range(0, n_sources * size, size)]


def _to_column_tensor(values):
  """Convert a numpy array into a float column tensor of shape (-1, 1) on `device`."""
  return torch.from_numpy(values.reshape(-1, 1)).float().to(device)


def _summarize_and_save(stats_per_setting, path):
  """Aggregate replicate statistics per setting and write them to an .npz file.

  For every entry of `stats_per_setting` (the list of per-replicate statistics
  collected for one number of sources) this takes the mean and the standard
  error (scipy.stats.sem) along axis 0, stacks the results into arrays and
  saves them to `path` under the keys 'stats_mean' and 'stats_stderr'.

  Returns the (stats_mean, stats_stderr) arrays.
  """
  stats_mean = np.array([np.mean(replicate_stats, axis=0) for replicate_stats in stats_per_setting])
  stats_stderr = np.array([sem(replicate_stats, axis=0) for replicate_stats in stats_per_setting])
  np.savez(path, stats_mean=stats_mean, stats_stderr=stats_stderr)
  return stats_mean, stats_stderr


# One entry per value of m; each entry is the list of per-replicate statistics.
learn_transfer_test_stats_lst = []
learn_transfer_test_stats_w_lst = []
learn_transfer_test_stats_y_lst = []
for m in m_sources_list:
  # Load data
  dataset = SynData5SourcesDiff()
  source_size = dataset.source_size
  train_size = dataset.train_size
  test_size = dataset.test_size
  val_size = dataset.val_size

  loss_lst = []
  test_stats = []
  test_stats_w = []
  test_stats_y = []
  for i, (data_train, data_test, data_val) in enumerate(dataset.get_train_test_val(m_sources=m)):
    # (start, end) row ranges of each source inside the concatenated splits.
    # The validation ranges are not used in this cell; kept in case later cells read them.
    source_ranges_train = _source_ranges(train_size, m)
    source_ranges_test = _source_ranges(test_size, m)
    source_ranges_val = _source_ranges(val_size, m)
    print('======================================================================================')
    print('# Num of sources: {}, Replicate: {}'.format(m, i+1))
    print('======================================================================================')
    sleep(0.5)

    # Training data
    Xtr = data_train[0][0]
    Wtr = data_train[0][1].reshape(-1)
    Ytr = data_train[0][2].reshape(-1)
    Y_cftr = data_train[1][0].reshape(-1)
    mutr = np.concatenate((data_train[1][1], data_train[1][2]), axis=1)
    Ttr = len(Ytr)
    xtr = torch.from_numpy(Xtr.reshape(Ttr,-1)).float().to(device)
    ytr = _to_column_tensor(Ytr)
    wtr = _to_column_tensor(Wtr)

    # Testing data
    Xte = data_test[0][0]
    Wte = data_test[0][1].reshape(-1)
    Yte = data_test[0][2].reshape(-1)
    Y_cfte = data_test[1][0].reshape(-1)
    mute = np.concatenate((data_test[1][1], data_test[1][2]), axis=1)
    Tte = len(Yte)
    xte = torch.from_numpy(Xte.reshape(Tte,-1)).float().to(device)
    yte = _to_column_tensor(Yte)
    wte = _to_column_tensor(Wte)
    y_cfte = _to_column_tensor(Y_cfte)

    # Train
    print('*** P(W|X)')
    model_server_w, model_sources_w, omega_w = trainW(train_x=xtr,
                                                      train_w=wtr.reshape(-1),
                                                      n_sources=m, source_ranges=source_ranges_train, D=D,
                                                      training_iter=training_iter, learning_rate=learning_rate_w,
                                                      reg=reg_beta_w, display_per_iters=display_per_iters,
                                                      transfer_flag=transfer_flag)
    w_samples = sampleW(model_sources=model_sources_w, x=xte, n_sources=m,
                        source_ranges=source_ranges_test, n_samples=n_samples,
                        idx_sources_to_test=idx_sources_to_test)

    print('*** P(Y|X,W)')
    model_server_y, model_sources_y, omega_y = trainY(train_x=xtr,
                                                      train_y=ytr,
                                                      train_w=wtr,
                                                      n_sources=m, source_ranges=source_ranges_train,
                                                      D=D, is_binary=is_binary_outcome,
                                                      training_iter=training_iter, learning_rate=learning_rate_y,
                                                      reg_beta=reg_beta_y, reg_sig=reg_sig_y,
                                                      display_per_iters=display_per_iters,
                                                      transfer_flag=transfer_flag)

    y_samples_xw = sampleY(model_sources=model_sources_y, x=xte, w_samples=w_samples, n_sources=m,
                           source_ranges=source_ranges_test, n_samples=n_samples,
                           idx_sources_to_test=idx_sources_to_test)

    print('*** P(Z|Y,X,W) and P(Y|W,Z)')
    model_server_zy, model_sources_zy, omega_z, omega_xy = trainZY(train_x=xtr,
                                                                 train_y=ytr,
                                                                 train_w=wtr,
                                                                 n_sources=m, source_ranges=source_ranges_train,
                                                                 feats_binary=None,
                                                                 feats_continuous=None,
                                                                 is_binary=is_binary_outcome, dim_z=z_dim, D=D,
                                                                 training_iter=training_iter,
                                                                 display_per_iters=display_per_iters,
                                                                 transfer_flag=transfer_flag,
                                                                 reg_beta=reg_beta_zy,
                                                                 learning_rate=learning_rate_zy)

    # Evaluate P(W|X)
    test_stat_w = testW(model_sources=model_sources_w, test_x=xte, test_w=wte.reshape(-1),
                        n_sources=m, source_ranges=source_ranges_test,
                        idx_sources_to_test=idx_sources_to_test)
    test_stats_w.append(test_stat_w)

    # Evaluate P(Y|X,W)
    test_stat_y = testY(model_sources=model_sources_y, test_x=xte, test_y=yte, test_w=wte,
                        n_sources=m, source_ranges=source_ranges_test,
                        idx_sources_to_test=idx_sources_to_test)
    test_stats_y.append(test_stat_y)

    # Evaluate ATE, ITE
    test_stats_TEs = testTEs(model_sources=model_sources_zy,
                             xte=xte, wte=wte, yte=yte, y_cfte=y_cfte, mute=mute,
                             w_samples=w_samples, y_samples=y_samples_xw,
                             n_sources=m, source_ranges=source_ranges_test, n_samples=n_samples,
                             idx_sources_to_test=idx_sources_to_test,
                             is_binary=is_binary_outcome, use_mh=use_mh)
    test_stats.append(test_stats_TEs)

  learn_transfer_test_stats_lst.append(test_stats)
  learn_transfer_test_stats_w_lst.append(test_stats_w)
  learn_transfer_test_stats_y_lst.append(test_stats_y)

  # Checkpoint after every m: each file holds the cumulative statistics of all
  # source-count settings finished so far, so an interrupted run keeps its results.
  learn_transfer_stats_mean, learn_transfer_stats_stderr = _summarize_and_save(
      learn_transfer_test_stats_lst,
      os.path.join(output_dir, 'graphs_5sources_learn_transfer_stats3_n-sources_{}.npz'.format(m)))
  print(learn_transfer_stats_mean)

  learn_transfer_stats_w_mean, learn_transfer_stats_w_stderr = _summarize_and_save(
      learn_transfer_test_stats_w_lst,
      os.path.join(output_dir, 'graphs_5sources_w_learn_transfer_stats3_n-sources_{}.npz'.format(m)))

  learn_transfer_stats_y_mean, learn_transfer_stats_y_stderr = _summarize_and_save(
      learn_transfer_test_stats_y_lst,
      os.path.join(output_dir, 'graphs_5sources_y_learn_transfer_stats3_n-sources_{}.npz'.format(m)))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Source 0, Iter 4700/8000 - Loss: 49.128
Source 0, Iter 4800/8000 - Loss: 42.366
Source 0, Iter 4900/8000 - Loss: 34.954
Source 0, Iter 5000/8000 - Loss: 26.925
Source 0, Iter 5100/8000 - Loss: 18.337
Source 0, Iter 5200/8000 - Loss: 9.263
Source 0, Iter 5300/8000 - Loss: -0.214
Source 0, Iter 5400/8000 - Loss: -10.014
Source 0, Iter 5500/8000 - Loss: -20.059
Source 0, Iter 5600/8000 - Loss: -30.285
Source 0, Iter 5700/8000 - Loss: -40.637
Source 0, Iter 5800/8000 - Loss: -51.069
Source 0, Iter 5900/8000 - Loss: -61.549
Source 0, Iter 6000/8000 - Loss: -72.049
Source 0, Iter 6100/8000 - Loss: -82.549
Source 0, Iter 6200/8000 - Loss: -93.037
Source 0, Iter 6300/8000 - Loss: -103.503
Source 0, Iter 6400/8000 - Loss: -113.941
Source 0, Iter 6500/8000 - Loss: -124.346
Source 0, Iter 6600/8000 - Loss: -134.718
Source 0, Iter 6700/8000 - Loss: -145.048
Source 0, Iter 6800/8000 - Loss: -155.361
Source 0, Iter 6900/8000 - Loss: -1

  return array(a, dtype, copy=False, order=order)


In [None]:
# Mean over replicates of the statistics returned by testTEs, one row per
# source-count setting completed by the experiment loop (column meaning is
# defined by testTEs, which is not visible in this notebook).
learn_transfer_stats_mean

array([[0.45892661, 1.40442073],
       [0.99484577, 1.98290643]])

In [None]:
# Standard error (scipy.stats.sem) over replicates of the statistics returned
# by testTEs, one row per source-count setting completed by the experiment loop.
learn_transfer_stats_stderr

array([[0.09295114, 0.05450865],
       [0.1629084 , 0.14476103]])