In [4]:
#!/usr/bin/env python
# coding: utf-8
%load_ext autoreload
%autoreload 2

import os
import random

import numpy as np
import scipy
import torch
import torch.nn as nn
import torchbnn as bnn
from scipy.stats import sem
from torchbnn.utils import freeze, unfreeze

from model import *
from datasets import IHDP
from evaluation import *

# --- Experiment selection ----------------------------------------------------
# Replicate to run: there are 10 replicates, with ids from 1 to 10.
source_id_to_run = 1
# Number of sources to use: IHDP has at most 6 sources.
num_source_to_run = 4

print('n_sources {}, replicate {}'.format(num_source_to_run, source_id_to_run))

# --- Device selection --------------------------------------------------------
device_id = 0
print('PyTorch version', torch.__version__)
if not torch.cuda.is_available():
  print('Use CPU')
  device = torch.device("cpu")
else:
  # Make CUDA float tensors the default so newly created tensors live on the GPU.
  torch.set_default_tensor_type('torch.cuda.FloatTensor')
  torch.cuda.set_device(device_id)
  print('Use ***GPU***')
  # Total device memory, converted from bytes to GB.
  print(torch.cuda.get_device_properties(device_id).total_memory/1024/1024/1024,'GB')
  device = torch.device("cuda:{}".format(device_id))

# --- Reproducibility ---------------------------------------------------------
# Seed every random number generator this script touches with the same value.
RND_SEED = 2023
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
  _seed_fn(RND_SEED)
torch.backends.cudnn.deterministic = True

# --- Configuration -----------------------------------------------------------
# Iteration budget for each training stage.
training_iter_z = training_iter_zhat = training_iter_y = 20000
learning_rate = 1e-4
# Passed to the trainers as `display_per_iters` (progress print interval).
display_per_iters = 100
hidden_size = 20
# Directory that receives the saved test statistics.
output_dir = 'save_outputs'

# Load the IHDP dataset and cache the sizes it reports; the per-split sizes
# are used later to build the per-source index ranges.
dataset = IHDP()
source_size, train_size, test_size, val_size = (
  dataset.source_size,
  dataset.train_size,
  dataset.test_size,
  dataset.val_size,
)
# Number of sources reported by the dataset.
M = dataset.n_sources

# Train the federated models on the selected replicate and evaluate them.
#
# For the chosen number of sources `m` and replicate `source_id_to_run`:
#   1. slice the train/test arrays and move them to `device`;
#   2. train P(Z|X,Y,W) with trainZ_FedGrads;
#   3. train P(Y|X,Z,W) together with P(Zr~|X,Zr) with trainY_FedGrads;
#   4. predict both potential outcomes on the test split and store
#      (absolute ATE error, PEHE) in an .npz file under `output_dir`.

# Number of leading feature columns handed to the trainers as Z; the
# remaining columns are handed over as X.
DIM_Z = 4

# Fail fast: create the results directory before training, otherwise np.savez
# raises FileNotFoundError only after both long training stages have finished.
os.makedirs(output_dir, exist_ok=True)

test_stats_lst = []
for m in [num_source_to_run]:
  test_stats = []
  for i, (data_train, data_test, data_val) in enumerate(dataset.get_train_test_val(m_sources=m)):
    # Replicate ids are 1-based while enumerate() is 0-based.
    if i != source_id_to_run - 1:
      continue

    # Training uses the first m sources; testing covers all M sources.
    n_train = m * train_size
    n_test = M * test_size
    source_ranges_train = [(idx, idx + train_size) for idx in range(0, n_train, train_size)]
    source_ranges_test = [(idx, idx + test_size) for idx in range(0, n_test, test_size)]
    print(source_ranges_train)
    print(source_ranges_test)
    print('======================================================================================')
    print('# Source {}, Replicate: {}'.format(m, i+1))
    print('======================================================================================')

    # Training data
    # NOTE(review): the meaning of each positional entry of data_train is
    # inferred from the variable names only -- confirm against datasets.IHDP.
    Wtr = data_train[0][1].reshape(-1)[:n_train]
    Ytr = data_train[0][2].reshape(-1)[:n_train]
    Y_cftr = data_train[1][0].reshape(-1)[:n_train]
    mutr = np.concatenate((data_train[1][1], data_train[1][2]), axis=1)[:n_train]
    Xtr = data_train[0][0][:n_train]

    Ttr = len(Ytr)
    xtr = torch.from_numpy(Xtr.reshape(Ttr, -1)).float().to(device)
    ytr = torch.from_numpy(Ytr.reshape(-1, 1)).float().to(device)
    wtr = torch.from_numpy(Wtr.reshape(-1, 1)).float().to(device)

    # Testing data (same layout as the training data, plus one extra array)
    Wte = data_test[0][1].reshape(-1)[:n_test]
    Yte = data_test[0][2].reshape(-1)[:n_test]
    Y_cfte = data_test[1][0].reshape(-1)[:n_test]
    mute = np.concatenate((data_test[1][1], data_test[1][2]), axis=1)[:n_test]
    Xte = data_test[0][0][:n_test]
    Xte_orgi = data_test[1][3][:n_test]

    Tte = len(Yte)
    xte = torch.from_numpy(Xte.reshape(Tte, -1)).float().to(device)
    yte = torch.from_numpy(Yte.reshape(-1, 1)).float().to(device)
    wte = torch.from_numpy(Wte.reshape(-1, 1)).float().to(device)

    # Split the feature matrices into the Z block and the X block.
    ztr, xtr_rest = xtr[:, :DIM_Z], xtr[:, DIM_Z:]
    zte, xte_rest = xte[:, :DIM_Z], xte[:, DIM_Z:]

    # Train
    print('*** P(Z|X,Y,W)')
    model_server_z, model_sources_z = trainZ_FedGrads(train_x=xtr_rest,
                                                      train_w=wtr.reshape(-1),
                                                      train_y=ytr.reshape(-1),
                                                      train_z=ztr,
                                                      n_sources=m, source_ranges=source_ranges_train,
                                                      hidden_size=hidden_size,
                                                      training_iter=training_iter_z, learning_rate=learning_rate,
                                                      display_per_iters=display_per_iters)

    # NOTE: there is no separate training stage for P(Zr~|X,Zr) here; the
    # z-hat models are taken from the joint trainY_FedGrads result below.
    print('*** P(Y|X,Z,W), P(Zr~|X,Zr)')
    model_server_zhaty, model_sources_zhaty = trainY_FedGrads(train_x=xtr_rest,
                                                      train_w=wtr.reshape(-1),
                                                      train_y=ytr.reshape(-1),
                                                      model_z=model_sources_z,
                                                      dim_z=ztr.shape[1],
                                                      n_sources=m, source_ranges=source_ranges_train,
                                                      hidden_size=hidden_size,
                                                      training_iter=training_iter_y, learning_rate=learning_rate,
                                                      display_per_iters=display_per_iters)
    model_server_zhat = model_server_zhaty.model_zhat
    model_sources_zhat = [model.model_zhat for model in model_sources_zhaty]
    model_server_y = model_server_zhaty.model_y
    model_sources_y = [model.model_y for model in model_sources_zhaty]

    # Test
    y0pred, y1pred = pred_y0y1(model_server_zhat=model_server_zhat, model_server_y=model_server_y,
                              test_x=xte_rest, test_z=zte,
                              test_w=wte, test_y=yte, n_sources=m,
                              source_ranges_test=source_ranges_test, idx_sources_to_test=list(range(M)))

    # Named `evaluator` rather than `eval` so the builtin eval() is not
    # shadowed for the rest of the session.
    # NOTE(review): the two columns of `mute` are presumably the reference
    # outcomes under control / treatment -- confirm against Evaluation.
    evaluator = Evaluation(mute[:,0], mute[:,1])
    y0pred = y0pred.detach().cpu().numpy()
    y1pred = y1pred.detach().cpu().numpy()
    test_stats.append((evaluator.absolute_err_ate(y0pred,y1pred), evaluator.pehe(y0pred, y1pred)))

    np.savez('{}/ihdp_test_stats_m{}_replicate{}.npz'.format(output_dir, m,i+1), test_stats=np.asarray(test_stats))

    # Only one replicate is selected; stop here so the remaining replicates
    # are not iterated and data_train/data_test/i keep referring to the
    # replicate that was actually used.
    break
  test_stats = np.asarray(test_stats)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
n_sources 4, replicate 1
PyTorch version 2.0.1+cu118
Use ***GPU***
31.7393798828125 GB
[(0, 80), (80, 160), (160, 240), (240, 320)]
[(0, 24), (24, 48), (48, 72), (72, 96), (96, 120), (120, 144)]
# Source 4, Replicate: 1
*** P(Z|X,Y,W)
Source 0, Iter 100/20000 - Loss: 169.851
Source 1, Iter 100/20000 - Loss: 162.661
Source 2, Iter 100/20000 - Loss: 165.842
Source 3, Iter 100/20000 - Loss: 188.615
Source 0, Iter 200/20000 - Loss: 155.933
Source 1, Iter 200/20000 - Loss: 166.665
Source 2, Iter 200/20000 - Loss: 169.244
Source 3, Iter 200/20000 - Loss: 188.731
Source 0, Iter 300/20000 - Loss: 161.957
Source 1, Iter 300/20000 - Loss: 155.402
Source 2, Iter 300/20000 - Loss: 166.426
Source 3, Iter 300/20000 - Loss: 171.273
Source 0, Iter 400/20000 - Loss: 158.441
Source 1, Iter 400/20000 - Loss: 150.909
Source 2, Iter 400/20000 - Loss: 169.689
Source 3, Iter 400/20000 - Loss: 169.295
Source 0, Iter 500/20

Source 0, Iter 4800/20000 - Loss: 104.008
Source 1, Iter 4800/20000 - Loss: 95.695
Source 2, Iter 4800/20000 - Loss: 116.920
Source 3, Iter 4800/20000 - Loss: 108.479
Source 0, Iter 4900/20000 - Loss: 103.812
Source 1, Iter 4900/20000 - Loss: 93.781
Source 2, Iter 4900/20000 - Loss: 108.262
Source 3, Iter 4900/20000 - Loss: 119.538
Source 0, Iter 5000/20000 - Loss: 112.807
Source 1, Iter 5000/20000 - Loss: 93.199
Source 2, Iter 5000/20000 - Loss: 104.335
Source 3, Iter 5000/20000 - Loss: 127.546
Source 0, Iter 5100/20000 - Loss: 110.574
Source 1, Iter 5100/20000 - Loss: 84.782
Source 2, Iter 5100/20000 - Loss: 102.008
Source 3, Iter 5100/20000 - Loss: 108.250
Source 0, Iter 5200/20000 - Loss: 106.485
Source 1, Iter 5200/20000 - Loss: 95.886
Source 2, Iter 5200/20000 - Loss: 104.480
Source 3, Iter 5200/20000 - Loss: 123.004
Source 0, Iter 5300/20000 - Loss: 92.945
Source 1, Iter 5300/20000 - Loss: 90.293
Source 2, Iter 5300/20000 - Loss: 107.394
Source 3, Iter 5300/20000 - Loss: 106.240

Source 0, Iter 9800/20000 - Loss: 94.549
Source 1, Iter 9800/20000 - Loss: 81.609
Source 2, Iter 9800/20000 - Loss: 92.009
Source 3, Iter 9800/20000 - Loss: 95.019
Source 0, Iter 9900/20000 - Loss: 89.993
Source 1, Iter 9900/20000 - Loss: 83.823
Source 2, Iter 9900/20000 - Loss: 249.471
Source 3, Iter 9900/20000 - Loss: 115.126
Source 0, Iter 10000/20000 - Loss: 91.486
Source 1, Iter 10000/20000 - Loss: 113.940
Source 2, Iter 10000/20000 - Loss: 86.864
Source 3, Iter 10000/20000 - Loss: 85.537
Source 0, Iter 10100/20000 - Loss: 89.925
Source 1, Iter 10100/20000 - Loss: 79.453
Source 2, Iter 10100/20000 - Loss: 93.763
Source 3, Iter 10100/20000 - Loss: 119.222
Source 0, Iter 10200/20000 - Loss: 93.676
Source 1, Iter 10200/20000 - Loss: 77.001
Source 2, Iter 10200/20000 - Loss: 89.050
Source 3, Iter 10200/20000 - Loss: 92.416
Source 0, Iter 10300/20000 - Loss: 84.487
Source 1, Iter 10300/20000 - Loss: 71.739
Source 2, Iter 10300/20000 - Loss: 85.651
Source 3, Iter 10300/20000 - Loss: 125

Source 0, Iter 14700/20000 - Loss: 91.169
Source 1, Iter 14700/20000 - Loss: 67.549
Source 2, Iter 14700/20000 - Loss: 85.349
Source 3, Iter 14700/20000 - Loss: 74.687
Source 0, Iter 14800/20000 - Loss: 81.234
Source 1, Iter 14800/20000 - Loss: 67.077
Source 2, Iter 14800/20000 - Loss: 69.076
Source 3, Iter 14800/20000 - Loss: 76.156
Source 0, Iter 14900/20000 - Loss: 108.595
Source 1, Iter 14900/20000 - Loss: 78.772
Source 2, Iter 14900/20000 - Loss: 80.912
Source 3, Iter 14900/20000 - Loss: 124.625
Source 0, Iter 15000/20000 - Loss: 84.244
Source 1, Iter 15000/20000 - Loss: 73.898
Source 2, Iter 15000/20000 - Loss: 84.131
Source 3, Iter 15000/20000 - Loss: 79.399
Source 0, Iter 15100/20000 - Loss: 79.283
Source 1, Iter 15100/20000 - Loss: 81.149
Source 2, Iter 15100/20000 - Loss: 100.102
Source 3, Iter 15100/20000 - Loss: 74.310
Source 0, Iter 15200/20000 - Loss: 91.097
Source 1, Iter 15200/20000 - Loss: 97.570
Source 2, Iter 15200/20000 - Loss: 76.894
Source 3, Iter 15200/20000 - Lo

Source 0, Iter 19600/20000 - Loss: 88.174
Source 1, Iter 19600/20000 - Loss: 57.531
Source 2, Iter 19600/20000 - Loss: 82.334
Source 3, Iter 19600/20000 - Loss: 70.268
Source 0, Iter 19700/20000 - Loss: 84.931
Source 1, Iter 19700/20000 - Loss: 85.953
Source 2, Iter 19700/20000 - Loss: 60.082
Source 3, Iter 19700/20000 - Loss: 72.576
Source 0, Iter 19800/20000 - Loss: 76.036
Source 1, Iter 19800/20000 - Loss: 45.427
Source 2, Iter 19800/20000 - Loss: 68.340
Source 3, Iter 19800/20000 - Loss: 73.143
Source 0, Iter 19900/20000 - Loss: 77.798
Source 1, Iter 19900/20000 - Loss: 55.157
Source 2, Iter 19900/20000 - Loss: 128.416
Source 3, Iter 19900/20000 - Loss: 77.603
Source 0, Iter 20000/20000 - Loss: 72.116
Source 1, Iter 20000/20000 - Loss: 55.382
Source 2, Iter 20000/20000 - Loss: 66.722
Source 3, Iter 20000/20000 - Loss: 78.547
*** P(Y|X,Z,W), P(Zr~|X,Zr)
Source 0, Iter 100/20000 - Loss: 499.322
Source 1, Iter 100/20000 - Loss: 538.988
Source 2, Iter 100/20000 - Loss: 458.111
Source 3

Source 0, Iter 4500/20000 - Loss: 151.968
Source 1, Iter 4500/20000 - Loss: 171.148
Source 2, Iter 4500/20000 - Loss: 128.054
Source 3, Iter 4500/20000 - Loss: 154.108
Source 0, Iter 4600/20000 - Loss: 153.082
Source 1, Iter 4600/20000 - Loss: 144.133
Source 2, Iter 4600/20000 - Loss: 146.206
Source 3, Iter 4600/20000 - Loss: 159.164
Source 0, Iter 4700/20000 - Loss: 138.157
Source 1, Iter 4700/20000 - Loss: 142.573
Source 2, Iter 4700/20000 - Loss: 162.579
Source 3, Iter 4700/20000 - Loss: 140.000
Source 0, Iter 4800/20000 - Loss: 142.949
Source 1, Iter 4800/20000 - Loss: 158.812
Source 2, Iter 4800/20000 - Loss: 131.069
Source 3, Iter 4800/20000 - Loss: 115.721
Source 0, Iter 4900/20000 - Loss: 141.240
Source 1, Iter 4900/20000 - Loss: 145.018
Source 2, Iter 4900/20000 - Loss: 136.883
Source 3, Iter 4900/20000 - Loss: 124.150
Source 0, Iter 5000/20000 - Loss: 153.092
Source 1, Iter 5000/20000 - Loss: 156.286
Source 2, Iter 5000/20000 - Loss: 195.094
Source 3, Iter 5000/20000 - Loss: 

Source 0, Iter 9400/20000 - Loss: 211.429
Source 1, Iter 9400/20000 - Loss: 162.747
Source 2, Iter 9400/20000 - Loss: 121.363
Source 3, Iter 9400/20000 - Loss: 137.401
Source 0, Iter 9500/20000 - Loss: 152.100
Source 1, Iter 9500/20000 - Loss: 124.668
Source 2, Iter 9500/20000 - Loss: 110.955
Source 3, Iter 9500/20000 - Loss: 124.756
Source 0, Iter 9600/20000 - Loss: 147.994
Source 1, Iter 9600/20000 - Loss: 159.173
Source 2, Iter 9600/20000 - Loss: 124.577
Source 3, Iter 9600/20000 - Loss: 147.674
Source 0, Iter 9700/20000 - Loss: 137.621
Source 1, Iter 9700/20000 - Loss: 139.371
Source 2, Iter 9700/20000 - Loss: 147.096
Source 3, Iter 9700/20000 - Loss: 114.355
Source 0, Iter 9800/20000 - Loss: 163.555
Source 1, Iter 9800/20000 - Loss: 132.647
Source 2, Iter 9800/20000 - Loss: 123.538
Source 3, Iter 9800/20000 - Loss: 123.065
Source 0, Iter 9900/20000 - Loss: 135.470
Source 1, Iter 9900/20000 - Loss: 132.201
Source 2, Iter 9900/20000 - Loss: 141.089
Source 3, Iter 9900/20000 - Loss: 

Source 0, Iter 14200/20000 - Loss: 151.152
Source 1, Iter 14200/20000 - Loss: 120.340
Source 2, Iter 14200/20000 - Loss: 137.074
Source 3, Iter 14200/20000 - Loss: 129.987
Source 0, Iter 14300/20000 - Loss: 141.941
Source 1, Iter 14300/20000 - Loss: 118.377
Source 2, Iter 14300/20000 - Loss: 106.215
Source 3, Iter 14300/20000 - Loss: 97.407
Source 0, Iter 14400/20000 - Loss: 119.426
Source 1, Iter 14400/20000 - Loss: 131.558
Source 2, Iter 14400/20000 - Loss: 130.458
Source 3, Iter 14400/20000 - Loss: 111.067
Source 0, Iter 14500/20000 - Loss: 152.260
Source 1, Iter 14500/20000 - Loss: 145.538
Source 2, Iter 14500/20000 - Loss: 110.475
Source 3, Iter 14500/20000 - Loss: 115.060
Source 0, Iter 14600/20000 - Loss: 129.552
Source 1, Iter 14600/20000 - Loss: 125.383
Source 2, Iter 14600/20000 - Loss: 138.930
Source 3, Iter 14600/20000 - Loss: 107.691
Source 0, Iter 14700/20000 - Loss: 115.881
Source 1, Iter 14700/20000 - Loss: 133.297
Source 2, Iter 14700/20000 - Loss: 154.154
Source 3, It

Source 0, Iter 19000/20000 - Loss: 106.009
Source 1, Iter 19000/20000 - Loss: 142.233
Source 2, Iter 19000/20000 - Loss: 134.717
Source 3, Iter 19000/20000 - Loss: 145.364
Source 0, Iter 19100/20000 - Loss: 104.062
Source 1, Iter 19100/20000 - Loss: 140.598
Source 2, Iter 19100/20000 - Loss: 140.901
Source 3, Iter 19100/20000 - Loss: 184.427
Source 0, Iter 19200/20000 - Loss: 129.521
Source 1, Iter 19200/20000 - Loss: 117.353
Source 2, Iter 19200/20000 - Loss: 128.614
Source 3, Iter 19200/20000 - Loss: 118.893
Source 0, Iter 19300/20000 - Loss: 109.447
Source 1, Iter 19300/20000 - Loss: 109.382
Source 2, Iter 19300/20000 - Loss: 164.110
Source 3, Iter 19300/20000 - Loss: 143.538
Source 0, Iter 19400/20000 - Loss: 141.295
Source 1, Iter 19400/20000 - Loss: 141.018
Source 2, Iter 19400/20000 - Loss: 167.867
Source 3, Iter 19400/20000 - Loss: 115.147
Source 0, Iter 19500/20000 - Loss: 108.336
Source 1, Iter 19500/20000 - Loss: 120.407
Source 2, Iter 19500/20000 - Loss: 132.568
Source 3, I