In [None]:
#!/usr/bin/env python
# coding: utf-8
import os
import random

import numpy as np
import torch
import torch.nn as nn
import scipy
import pandas as pd
from scipy.stats import sem
import torchbnn as bnn
from torchbnn.utils import freeze, unfreeze

from model import *
from datasets import SynData50Sources

# Experiment selection:
#   source_id_to_run  - which replicate to run (replicates are numbered 1..10)
#   num_source_to_run - how many sources take part (synthetic data has at most 50)
source_id_to_run = 1
num_source_to_run = 10

print(f'n_sources {num_source_to_run}, replicate {source_id_to_run}')


# Pick the compute device: GPU `device_id` when CUDA is available, else CPU.
device_id = 0
print('PyTorch version', torch.__version__)
use_cuda = torch.cuda.is_available()
if not use_cuda:
    print('Use CPU')
else:
    # Newly created float tensors default to CUDA float32.
    # NOTE(review): set_default_tensor_type is deprecated in newer PyTorch
    # (set_default_dtype / set_default_device); kept as-is for this 2.0.x setup.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(device_id)
    print('Use ***GPU***')
    # Total device memory reported in GiB.
    total_gb = torch.cuda.get_device_properties(device_id).total_memory / 1024 ** 3
    print(total_gb, 'GB')
device = torch.device(f"cuda:{device_id}" if use_cuda else "cpu")


# Seed every random number generator the experiment touches, then ask cuDNN
# for deterministic kernels, so reruns are reproducible.
RND_SEED = 2023
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(RND_SEED)
del _seed_fn
torch.backends.cudnn.deterministic = True

# Configuration.
# Iteration budgets for the Z, Z-hat and Y training stages.
# (training_iter_zhat is only referenced by the commented-out
# trainZhat_FedGrads call further down.)
training_iter_z = training_iter_zhat = training_iter_y = 10000
learning_rate = 1e-3
display_per_iters = 100   # progress-print interval handed to the trainers
hidden_size = 10
output_dir = 'save_outputs'

# Load the synthetic multi-source dataset and cache its split sizes.
# NOTE(review): the sizes are used below as per-source row counts when
# building source ranges; confirm against SynData50Sources.
dataset = SynData50Sources()
source_size, train_size, test_size, val_size = (
    dataset.source_size,
    dataset.train_size,
    dataset.test_size,
    dataset.val_size,
)
M = dataset.n_sources  # number of sources the dataset provides

# Train the federated Z and (Z-hat, Y) models on the requested replicate,
# evaluate them on all M test sources, and save (ATE abs. error, PEHE).

# Leading columns of X handed to the models as Z (train_z / test_z below);
# the remaining columns are passed as the covariates (train_x / test_x).
dim_z = 10

# np.savez does not create missing parent directories, so make sure the
# results directory exists before the first replicate is written.
os.makedirs(output_dir, exist_ok=True)

# One entry per value of m: the array of (ATE abs. error, PEHE) rows.
test_stats_lst = []
for m in [num_source_to_run]:
    loss_lst = []  # NOTE(review): never filled in this cell
    test_stats = []
    for i, (data_train, data_test, data_val) in enumerate(dataset.get_train_test_val(m_sources=m)):
        # Only the requested replicate is trained and evaluated.
        if i != source_id_to_run - 1:
            continue

        # [start, end) row range of each source in the stacked arrays:
        # the first m sources for training/validation, all M for testing.
        source_ranges_train = [(idx, idx + train_size) for idx in range(0, m * train_size, train_size)]
        source_ranges_test = [(idx, idx + test_size) for idx in range(0, M * test_size, test_size)]
        source_ranges_val = [(idx, idx + val_size) for idx in range(0, m * val_size, val_size)]  # currently unused
        print(source_ranges_train)
        print(source_ranges_test)
        print('======================================================================================')
        print('# Source {}, Replicate: {}'.format(m, i + 1))
        print('======================================================================================')

        # NOTE(review): field layout inferred from the variable names only —
        # presumably data_*[0] = (X, W, Y) and data_*[1] = (Y_cf, mu0, mu1, X_orig);
        # confirm against SynData50Sources.get_train_test_val.

        # Training data: rows of the first m sources.
        n_train_rows = m * train_size
        Wtr = data_train[0][1].reshape(-1)[:n_train_rows]
        Ytr = data_train[0][2].reshape(-1)[:n_train_rows]
        Y_cftr = data_train[1][0].reshape(-1)[:n_train_rows]
        mutr = np.concatenate((data_train[1][1], data_train[1][2]), axis=1)[:n_train_rows]
        Xtr = data_train[0][0][:n_train_rows]
        Ttr = len(Ytr)
        xtr = torch.from_numpy(Xtr.reshape(Ttr, -1)).float().to(device)
        ytr = torch.from_numpy(Ytr.reshape(-1, 1)).float().to(device)
        wtr = torch.from_numpy(Wtr.reshape(-1, 1)).float().to(device)

        # Testing data: rows of all M sources.
        n_test_rows = M * test_size
        Wte = data_test[0][1].reshape(-1)[:n_test_rows]
        Yte = data_test[0][2].reshape(-1)[:n_test_rows]
        Y_cfte = data_test[1][0].reshape(-1)[:n_test_rows]
        mute = np.concatenate((data_test[1][1], data_test[1][2]), axis=1)[:n_test_rows]
        Xte = data_test[0][0][:n_test_rows]
        Xte_orgi = data_test[1][3][:n_test_rows]
        Tte = len(Yte)
        xte = torch.from_numpy(Xte.reshape(Tte, -1)).float().to(device)
        yte = torch.from_numpy(Yte.reshape(-1, 1)).float().to(device)
        wte = torch.from_numpy(Wte.reshape(-1, 1)).float().to(device)

        # Train
        print('*** P(Z|X,Y,W)')
        model_server_z, model_sources_z = trainZ_FedGrads(train_x=xtr[:, dim_z:],
                                                          train_w=wtr.reshape(-1),
                                                          train_y=ytr.reshape(-1),
                                                          train_z=xtr[:, :dim_z],
                                                          n_sources=m, source_ranges=source_ranges_train,
                                                          hidden_size=hidden_size,
                                                          training_iter=training_iter_z, learning_rate=learning_rate,
                                                          display_per_iters=display_per_iters)

        # print('*** P(Zr~|X,Zr)')
        # model_server_zhat, model_sources_zhat = trainZhat_FedGrads(train_x=xtr[:,10:],
        #                                                             train_y=ytr,
        #                                                             train_w=wtr,
        #                                                             model_z=model_sources_z,
        #                                                             dim_z=xtr[:,:10].shape[1],
        #                                                             n_sources=m, source_ranges=source_ranges_train,
        #                                                             training_iter=training_iter_zhat, learning_rate=learning_rate,
        #                                                             display_per_iters=display_per_iters)

        # Joint stage: the returned models expose both the Z-hat and Y sub-models.
        print('*** P(Y|X,Z,W), P(Zr~|X,Zr)')
        model_server_zhaty, model_sources_zhaty = trainY_FedGrads(train_x=xtr[:, dim_z:],
                                                                  train_w=wtr.reshape(-1),
                                                                  train_y=ytr.reshape(-1),
                                                                  model_z=model_sources_z,
                                                                  dim_z=xtr[:, :dim_z].shape[1],
                                                                  n_sources=m, source_ranges=source_ranges_train,
                                                                  hidden_size=hidden_size,
                                                                  training_iter=training_iter_y, learning_rate=learning_rate,
                                                                  display_per_iters=display_per_iters)
        model_server_zhat = model_server_zhaty.model_zhat
        model_sources_zhat = [src_model.model_zhat for src_model in model_sources_zhaty]
        model_server_y = model_server_zhaty.model_y
        model_sources_y = [src_model.model_y for src_model in model_sources_zhaty]
        # model_server_y, model_sources_y = trainY_FedParams(train_x=xtr[:,10:],
        #                                                   train_w=wtr.reshape(-1),
        #                                                   train_y=ytr.reshape(-1),
        #                                                   model_z=model_sources_z,
        #                                                   dim_z=10,
        #                                                   n_sources=m, source_ranges=source_ranges_train,
        #                                                   training_iter=100, num_agg=200,
        #                                                   learning_rate=learning_rate,
        #                                                   display_per_iters=display_per_iters)

        # Test: predict both potential outcomes on every one of the M test sources.
        y0pred, y1pred = pred_y0y1(model_server_zhat=model_server_zhat, model_server_y=model_server_y,
                                   test_x=xte[:, dim_z:], test_z=xte[:, :dim_z],
                                   test_w=wte, test_y=yte, n_sources=m,
                                   source_ranges_test=source_ranges_test, idx_sources_to_test=list(range(M)))

        # `evaluator`, not `eval`, so the builtin is not shadowed in the kernel.
        evaluator = Evaluation(mute[:, 0], mute[:, 1])
        y0pred = y0pred.detach().cpu().numpy()
        y1pred = y1pred.detach().cpu().numpy()
        test_stats.append((evaluator.absolute_err_ate(y0pred, y1pred), evaluator.pehe(y0pred, y1pred)))

        out_path = os.path.join(output_dir, 'synthetic_test_stats_m{}_replicate{}.npz'.format(m, i + 1))
        np.savez(out_path, test_stats=np.asarray(test_stats))

        # The requested replicate is done; stop rather than generating the
        # remaining replicates only to skip them.
        break
    test_stats = np.asarray(test_stats)
    test_stats_lst.append(test_stats)

n_sources 10, replicate 1
PyTorch version 2.0.1+cu118
Use ***GPU***
31.7393798828125 GB
[(0, 100), (100, 200), (200, 300), (300, 400), (400, 500), (500, 600), (600, 700), (700, 800), (800, 900), (900, 1000)]
[(0, 50), (50, 100), (100, 150), (150, 200), (200, 250), (250, 300), (300, 350), (350, 400), (400, 450), (450, 500), (500, 550), (550, 600), (600, 650), (650, 700), (700, 750), (750, 800), (800, 850), (850, 900), (900, 950), (950, 1000), (1000, 1050), (1050, 1100), (1100, 1150), (1150, 1200), (1200, 1250), (1250, 1300), (1300, 1350), (1350, 1400), (1400, 1450), (1450, 1500), (1500, 1550), (1550, 1600), (1600, 1650), (1650, 1700), (1700, 1750), (1750, 1800), (1800, 1850), (1850, 1900), (1900, 1950), (1950, 2000), (2000, 2050), (2050, 2100), (2100, 2150), (2150, 2200), (2200, 2250), (2250, 2300), (2300, 2350), (2350, 2400), (2400, 2450), (2450, 2500)]
# Source 10, Replicate: 1
*** P(Z|X,Y,W)
Source 0, Iter 100/10000 - Loss: 584.129
Source 1, Iter 100/10000 - Loss: 678.820
Source 2, I

Source 0, Iter 1900/10000 - Loss: 348.257
Source 1, Iter 1900/10000 - Loss: 403.748
Source 2, Iter 1900/10000 - Loss: 388.280
Source 3, Iter 1900/10000 - Loss: 418.678
Source 4, Iter 1900/10000 - Loss: 365.106
Source 5, Iter 1900/10000 - Loss: 403.336
Source 6, Iter 1900/10000 - Loss: 370.167
Source 7, Iter 1900/10000 - Loss: 365.122
Source 8, Iter 1900/10000 - Loss: 378.277
Source 9, Iter 1900/10000 - Loss: 395.878
Source 0, Iter 2000/10000 - Loss: 337.695
Source 1, Iter 2000/10000 - Loss: 396.597
Source 2, Iter 2000/10000 - Loss: 387.469
Source 3, Iter 2000/10000 - Loss: 414.878
Source 4, Iter 2000/10000 - Loss: 357.229
Source 5, Iter 2000/10000 - Loss: 407.818
Source 6, Iter 2000/10000 - Loss: 382.782
Source 7, Iter 2000/10000 - Loss: 352.698
Source 8, Iter 2000/10000 - Loss: 362.013
Source 9, Iter 2000/10000 - Loss: 387.444
Source 0, Iter 2100/10000 - Loss: 341.259
Source 1, Iter 2100/10000 - Loss: 400.642
Source 2, Iter 2100/10000 - Loss: 390.650
Source 3, Iter 2100/10000 - Loss: 

Source 0, Iter 3900/10000 - Loss: 330.304
Source 1, Iter 3900/10000 - Loss: 382.119
Source 2, Iter 3900/10000 - Loss: 378.544
Source 3, Iter 3900/10000 - Loss: 408.562
Source 4, Iter 3900/10000 - Loss: 341.883
Source 5, Iter 3900/10000 - Loss: 368.965
Source 6, Iter 3900/10000 - Loss: 374.138
Source 7, Iter 3900/10000 - Loss: 346.983
Source 8, Iter 3900/10000 - Loss: 357.383
Source 9, Iter 3900/10000 - Loss: 369.275
Source 0, Iter 4000/10000 - Loss: 324.477
Source 1, Iter 4000/10000 - Loss: 388.316
Source 2, Iter 4000/10000 - Loss: 382.374
Source 3, Iter 4000/10000 - Loss: 395.097
Source 4, Iter 4000/10000 - Loss: 342.207
Source 5, Iter 4000/10000 - Loss: 365.738
Source 6, Iter 4000/10000 - Loss: 359.297
Source 7, Iter 4000/10000 - Loss: 353.228
Source 8, Iter 4000/10000 - Loss: 343.149
Source 9, Iter 4000/10000 - Loss: 366.150
Source 0, Iter 4100/10000 - Loss: 329.446
Source 1, Iter 4100/10000 - Loss: 383.920
Source 2, Iter 4100/10000 - Loss: 373.860
Source 3, Iter 4100/10000 - Loss: 

Source 0, Iter 5900/10000 - Loss: 326.924
Source 1, Iter 5900/10000 - Loss: 382.164
Source 2, Iter 5900/10000 - Loss: 376.706
Source 3, Iter 5900/10000 - Loss: 394.052
Source 4, Iter 5900/10000 - Loss: 334.158
Source 5, Iter 5900/10000 - Loss: 363.892
Source 6, Iter 5900/10000 - Loss: 348.936
Source 7, Iter 5900/10000 - Loss: 345.318
Source 8, Iter 5900/10000 - Loss: 341.162
Source 9, Iter 5900/10000 - Loss: 364.738
Source 0, Iter 6000/10000 - Loss: 330.157
Source 1, Iter 6000/10000 - Loss: 383.325
Source 2, Iter 6000/10000 - Loss: 371.502
Source 3, Iter 6000/10000 - Loss: 396.630
Source 4, Iter 6000/10000 - Loss: 343.674
Source 5, Iter 6000/10000 - Loss: 364.359
Source 6, Iter 6000/10000 - Loss: 353.560
Source 7, Iter 6000/10000 - Loss: 338.077
Source 8, Iter 6000/10000 - Loss: 341.512
Source 9, Iter 6000/10000 - Loss: 366.138
Source 0, Iter 6100/10000 - Loss: 326.936
Source 1, Iter 6100/10000 - Loss: 384.062
Source 2, Iter 6100/10000 - Loss: 381.136
Source 3, Iter 6100/10000 - Loss: 

Source 0, Iter 7900/10000 - Loss: 326.572
Source 1, Iter 7900/10000 - Loss: 384.681
Source 2, Iter 7900/10000 - Loss: 370.047
Source 3, Iter 7900/10000 - Loss: 384.635
Source 4, Iter 7900/10000 - Loss: 321.826
Source 5, Iter 7900/10000 - Loss: 352.855
Source 6, Iter 7900/10000 - Loss: 345.465
Source 7, Iter 7900/10000 - Loss: 338.084
Source 8, Iter 7900/10000 - Loss: 331.805
Source 9, Iter 7900/10000 - Loss: 353.457
Source 0, Iter 8000/10000 - Loss: 323.616
Source 1, Iter 8000/10000 - Loss: 388.064
Source 2, Iter 8000/10000 - Loss: 371.036
Source 3, Iter 8000/10000 - Loss: 383.370
Source 4, Iter 8000/10000 - Loss: 325.361
Source 5, Iter 8000/10000 - Loss: 351.374
Source 6, Iter 8000/10000 - Loss: 338.602
Source 7, Iter 8000/10000 - Loss: 335.172
Source 8, Iter 8000/10000 - Loss: 337.216
Source 9, Iter 8000/10000 - Loss: 356.368
Source 0, Iter 8100/10000 - Loss: 331.916
Source 1, Iter 8100/10000 - Loss: 387.490
Source 2, Iter 8100/10000 - Loss: 371.737
Source 3, Iter 8100/10000 - Loss: 