In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle as pkl
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

In [None]:
import sys
sys.path.append("..")

In [None]:
%load_ext autoreload
%autoreload 2
from common.utility import to_categorical, torch_device
from notebook_utils.generate_gaussian import generate_gaussian
from notebook_utils.eigan import Encoder, Discriminator
from notebook_utils.federated import federated
from notebook_utils.eigan_training import distributed_3, centralized_3
from notebook_utils.utility import class_plot, to_numpy
import notebook_utils.metrics as metrics

In [None]:
# Resolve the compute device through the project helper; the bare name on the
# last line echoes the resolved device as the cell output.
device = torch_device(device='gpu')
device

# Training

In [None]:
# Load the preprocessed feature matrix and the three label vectors from the checkpoint.
# The file is opened in a `with` block so the handle is closed as soon as loading ends.
# NOTE: unpickling can execute arbitrary code, so only point this at trusted files.
with open('../checkpoints/mimic/processed_data_X_y_ally_y_advr_y_advr_2.pkl', 'rb') as checkpoint:
    X, y_ally, y_advr_1, y_advr_2 = pkl.load(checkpoint)

# Reshape every label vector to a single column (n_samples, 1).
y_ally = y_ally.reshape(-1, 1)
y_advr_1 = y_advr_1.reshape(-1, 1)
y_advr_2 = y_advr_2.reshape(-1, 1)


# Class counts of each target, drawn as grouped bars around every class id.
width = 0.2
fig, (ax1) = plt.subplots(1, 1, figsize=(9, 4))
ax1.bar(np.unique(y_ally.flatten())-width, np.bincount(y_ally.flatten()), width, color='b')
ax1.bar(np.unique(y_advr_1.flatten()), np.bincount(y_advr_1.flatten()), width, color='r', hatch='o')
ax1.bar(np.unique(y_advr_2.flatten())+width, np.bincount(y_advr_2.flatten()), width, color='r', hatch='-')
ax1.legend(['ally', 'adversary 1', 'adversary 2'])

plt.show()

# Echo the array shapes as the cell output.
X.shape, y_ally.shape, y_advr_1.shape, y_advr_2.shape

In [None]:
# Partition the samples between two nodes: the first NODE_1_FRACTION of a shuffled
# index goes to node 1, the remainder to node 2.
# The shuffle uses its own seeded generator so the node assignment is the same on
# every run (the original unseeded permutation changed the partition each time).
NODE_SPLIT_SEED = 0
NODE_1_FRACTION = 0.3

rand_idx = np.random.RandomState(NODE_SPLIT_SEED).permutation(X.shape[0])
crossover = int(NODE_1_FRACTION * X.shape[0])
node_1_idx, node_2_idx = rand_idx[:crossover], rand_idx[crossover:]

X1, X2 = X[node_1_idx], X[node_2_idx]
y_ally_1, y_ally_2 = y_ally[node_1_idx], y_ally[node_2_idx]
y_advr_11, y_advr_12 = y_advr_1[node_1_idx], y_advr_1[node_2_idx]
y_advr_21, y_advr_22 = y_advr_2[node_1_idx], y_advr_2[node_2_idx]

# One panel per node showing the class counts of the ally and both adversary labels.
width = 0.2
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
node_panels = (
    (ax1, '@1', y_ally_1, y_advr_11, y_advr_21),
    (ax2, '@2', y_ally_2, y_advr_12, y_advr_22),
)
for panel_ax, panel_title, ally_labels, advr_1_labels, advr_2_labels in node_panels:
    panel_ax.bar(np.unique(ally_labels.flatten()) - width,
                 np.bincount(ally_labels.flatten()), width, color='b')
    panel_ax.bar(np.unique(advr_1_labels.flatten()),
                 np.bincount(advr_1_labels.flatten()), width, color='r', hatch='o')
    panel_ax.bar(np.unique(advr_2_labels.flatten()) + width,
                 np.bincount(advr_2_labels.flatten()), width, color='r', hatch='-')
    panel_ax.set_title(panel_title)
plt.show()

In [None]:
# Experiment-wide constants, grouped by topic.

# Federation layout and mini-batch size.
NUM_NODES, BATCHSIZE = 2, 512

# Remaining scalar settings.
NUM_TRIALS = 10
MAX = 1024
VAR1, VAR2 = 4, 1
PHI = 1

# Container for evaluation results, filled by later cells.
history = dict()

In [None]:
NUM_NODES = 2
BATCHSIZE = 512
SPLIT_SEED = 0  # fixed seed so the train/valid membership is identical on every run


def _inverse_frequency(labels):
    """Return ``total / count`` per class id, so rarer classes get larger weights."""
    counts = np.bincount(labels.flatten())
    return counts.sum() / counts


def _plot_label_counts(ax, ally, advr_1, advr_2, title, width=0.2):
    """Draw grouped bars with the class counts of the ally and both adversary labels."""
    ax.bar(np.unique(ally.flatten()) - width, np.bincount(ally.flatten()), width, color='b')
    ax.bar(np.unique(advr_1.flatten()), np.bincount(advr_1.flatten()), width, color='r', hatch='o')
    ax.bar(np.unique(advr_2.flatten()) + width, np.bincount(advr_2.flatten()), width, color='r', hatch='-')
    ax.legend(['ally', 'adversary 1', 'adversary 2'])
    ax.set_title(title)


# Rebind the pooled names to per-node lists (index 0 -> node 1, index 1 -> node 2).
X = [X1, X2]
y_ally = [y_ally_1, y_ally_2]
y_advr_1 = [y_advr_11, y_advr_12]
y_advr_2 = [y_advr_21, y_advr_22]

# A named index is used instead of `_`, which IPython reserves for the last output.
for node_idx in range(NUM_NODES):
    print('@node {}, X: {}, y_ally: {}, y_advr_1: {}, y_advr_2: {}'.format(
        node_idx, X[node_idx].shape, y_ally[node_idx].shape,
        y_advr_1[node_idx].shape, y_advr_2[node_idx].shape))

# Per-node accumulators consumed by the training and evaluation cells.
w_ally = []
w_advr_1 = []
w_advr_2 = []
train_loaders = []
X_valids = []
X_trains = []
y_ally_valids = []
y_ally_trains = []
y_advr_1_valids = []
y_advr_1_trains = []
y_advr_2_valids = []
y_advr_2_trains = []
for node_idx in range(NUM_NODES):
    X_local = X[node_idx]
    y_ally_local = y_ally[node_idx]
    y_advr_1_local = y_advr_1[node_idx]
    y_advr_2_local = y_advr_2[node_idx]

    # Stratify on the joint (ally, adversary 1, adversary 2) label so that every
    # label combination keeps its proportion in both splits.
    joint_labels = pd.DataFrame(
        np.concatenate((y_ally_local, y_advr_1_local, y_advr_2_local), axis=1))
    (X_train, X_valid,
     y_ally_train, y_ally_valid,
     y_advr_1_train, y_advr_1_valid,
     y_advr_2_train, y_advr_2_valid) = train_test_split(
        X_local, y_ally_local, y_advr_1_local, y_advr_2_local,
        test_size=0.2, stratify=joint_labels, random_state=SPLIT_SEED)
    print('@node {}: X_train, X_valid, y_ally_train, y_ally_valid,'
          'y_advr_1_train, y_advr_1_valid, y_advr_2_train, y_advr_2_valid'.format(node_idx))
    print(X_train.shape, X_valid.shape,
          y_ally_train.shape, y_ally_valid.shape,
          y_advr_1_train.shape, y_advr_1_valid.shape,
          y_advr_2_train.shape, y_advr_2_valid.shape)

    # Class weights are computed from the training split only.
    w_ally.append(_inverse_frequency(y_ally_train))
    w_advr_1.append(_inverse_frequency(y_advr_1_train))
    w_advr_2.append(_inverse_frequency(y_advr_2_train))
    print('@node {}: class weights => w_ally, w_advr_1, w_advr_2'.format(node_idx), w_ally, w_advr_1, w_advr_2)

    # Fit the scaler on the training split only, then apply it to the validation split.
    scaler = MinMaxScaler()
    X_train = scaler.fit_transform(X_train)
    X_valid = scaler.transform(X_valid)

    # Class balance of the train / valid splits for this node.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
    _plot_label_counts(ax1, y_ally_train, y_advr_1_train, y_advr_2_train,
                       'train@{}'.format(node_idx+1))
    _plot_label_counts(ax2, y_ally_valid, y_advr_1_valid, y_advr_2_valid,
                       'valid@{}'.format(node_idx+1))

    # Encode the integer labels with the project's to_categorical helper.
    y_ally_train = to_categorical(y_ally_train)
    y_ally_valid = to_categorical(y_ally_valid)
    y_advr_1_train = to_categorical(y_advr_1_train)
    y_advr_2_train = to_categorical(y_advr_2_train)
    y_advr_1_valid = to_categorical(y_advr_1_valid)
    y_advr_2_valid = to_categorical(y_advr_2_valid)

    # The training split becomes tensors here; the validation split is converted
    # only when it is appended below.
    X_train = torch.Tensor(X_train)
    y_ally_train = torch.Tensor(y_ally_train)
    y_advr_1_train = torch.Tensor(y_advr_1_train)
    y_advr_2_train = torch.Tensor(y_advr_2_train)

    X_valids.append(torch.Tensor(X_valid))
    y_ally_valids.append(torch.Tensor(y_ally_valid))
    y_advr_1_valids.append(torch.Tensor(y_advr_1_valid))
    y_advr_2_valids.append(torch.Tensor(y_advr_2_valid))

    X_trains.append(X_train)
    y_ally_trains.append(y_ally_train)
    y_advr_1_trains.append(y_advr_1_train)
    y_advr_2_trains.append(y_advr_2_train)

    print('@node {}: tensor sizes =>'.format(node_idx), X_train.shape, X_valid.shape,
          y_ally_train.shape, y_ally_valid.shape,
          y_advr_1_train.shape, y_advr_1_valid.shape, y_advr_2_train.shape, y_advr_2_valid.shape)

    train_loaders.append(DataLoader(
        TensorDataset(X_train, y_ally_train, y_advr_1_train, y_advr_2_train),
        batch_size=BATCHSIZE, shuffle=True))
    

In [None]:
# Hyper-parameters forwarded to the EIGAN training helpers.
alpha = 1
lr_encd = 0.0001
lr_ally = 0.00001
lr_advr_1 = 0.00001
lr_advr_2 = 0.00001
n_iter_gan = 501

# Network sizes. The feature count is read from the per-node training list instead of
# the `X_train` variable leaked by the preprocessing loop, so this cell no longer
# depends on which cell last assigned `X_train`.
input_size = X_trains[0].shape[1]
hidden_size = input_size*8
output_size = 2

In [None]:
# Fresh, independent containers: shared training state, trained encoders keyed by
# run, and evaluation results.
global_params, encoders, history = {}, {}, {}

In [None]:
# Restore previously saved evaluation results; this replaces whatever `history` holds.
# The `with` block closes the file handle once loading ends.
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
with open('mimic_var_phi_delta_1.pkl', 'rb') as history_file:
    history = pkl.load(history_file)

In [None]:
# Train one distributed EIGAN encoder per value of phi (0.0, 0.2, ..., 1.0) at a fixed
# delta. Each result is stored in `encoders` under the key '<phi>_<delta>'.
delta = 2
for phi_tenths in range(0, 11, 2):
    # Derive phi in a separate name rather than reassigning the loop variable.
    phi = phi_tenths / 10
    print("-"*80)
    print('EIGAN Training w/ phi={} and delta={}'.format(phi, delta))
    print("-"*80)
    run_key = '{}_{}'.format(phi, delta)
    # Fix: the ally training labels must be the per-node list `y_ally_trains`; the
    # original passed `y_ally_train`, which is only the last node's tensor.
    encoders[run_key] = distributed_3(
        NUM_NODES, phi, delta,
        X_trains, X_valids,
        y_ally_trains, y_ally_valids,
        y_advr_1_trains, y_advr_1_valids,
        y_advr_2_trains, y_advr_2_valids,
        input_size, hidden_size, [2, 2, 2],
        alpha, lr_encd, lr_ally, lr_advr_1, lr_advr_2, w_ally, w_advr_1, w_advr_2,
        train_loaders, n_iter_gan, device, global_params)

In [None]:
import pickle as pkl

# Persist the trained distributed encoders. The `with` block guarantees the file is
# flushed and closed, which the original bare `open(...)` did not.
with open('encoders_mimic_num_nodes_2_phi_var_delta_2.pkl', 'wb') as encoders_file:
    pkl.dump(encoders, encoders_file)

In [None]:
import pickle as pkl

# Reload the distributed encoders saved by the phi sweep; the handle is closed on exit.
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
with open('encoders_mimic_num_nodes_2_phi_var_delta_2.pkl', 'rb') as encoders_file:
    encoders = pkl.load(encoders_file)

In [None]:
def _pool(per_node_tensors):
    """Concatenate the per-node tensors along dim 0 and move the result to ``device``."""
    return torch.cat(per_node_tensors, dim=0).to(device)


# Pooled (centralized) view of every node's data, used by the centralized runs.
X_train, X_valid = _pool(X_trains), _pool(X_valids)
y_ally_train, y_ally_valid = _pool(y_ally_trains), _pool(y_ally_valids)
y_advr_1_train, y_advr_1_valid = _pool(y_advr_1_trains), _pool(y_advr_1_valids)
y_advr_2_train, y_advr_2_valid = _pool(y_advr_2_trains), _pool(y_advr_2_valids)

In [None]:
# Mini-batch loader over the pooled training tensors.
pooled_dataset = TensorDataset(X_train, y_ally_train, y_advr_1_train, y_advr_2_train)
train_loader = DataLoader(pooled_dataset, batch_size=BATCHSIZE, shuffle=True)

# Positional arguments for centralized_3, grouped by role and unpacked in order.
split_tensors = (
    X_train, X_valid,
    y_ally_train, y_ally_valid,
    y_advr_1_train, y_advr_1_valid,
    y_advr_2_train, y_advr_2_valid,
)
optimiser_settings = (alpha, lr_encd, lr_ally, lr_advr_1, lr_advr_2)
# NOTE(review): only node 0's class weights are used for the pooled data — confirm intended.
node_0_weights = (w_ally[0], w_advr_1[0], w_advr_2[0])

encoder = centralized_3(
    *split_tensors,
    input_size, hidden_size, [2, 2, 2],
    *optimiser_settings, *node_0_weights,
    train_loader, n_iter_gan, device)

In [None]:
import pickle as pkl

# Persist the centralized encoder. The `with` block guarantees the file is flushed
# and closed, which the original bare `open(...)` did not.
with open('encoder_mimic_num_nodes_2_phi_delta.pkl', 'wb') as encoder_file:
    pkl.dump(encoder, encoder_file)

In [None]:
import pickle as pkl

# Reload the centralized encoder; the handle is closed on exit.
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
with open('encoder_mimic_num_nodes_2_phi_delta.pkl', 'rb') as encoder_file:
    encoder = pkl.load(encoder_file)

In [None]:
# Baselines: the same evaluation as the encoded runs, but with no encoder (None).
# Fix: the original referenced y_1_*/y_2_*/w_1/w_2, which are not defined anywhere in
# this notebook and raised NameError; they are mapped onto the names the notebook uses.
print("-"*80)
print('ALLY: BASELINE')
print("-"*80)
history['baseline_ally'] = metrics.centralized(None,
                                             input_size, hidden_size, output_size,
                                             X_train, X_valid, y_ally_train, y_ally_valid,
                                             w_ally[0], device)

# NOTE(review): the old `y_2_*` / `w_2` names are assumed to mean adversary 1 (the
# second target after the ally). Switch to the advr_2 variables if that is wrong.
print("-"*80)
print('ADVERSARY: BASELINE')
print("-"*80)
history['baseline_advr'] = metrics.centralized(None,
                                             input_size, hidden_size, output_size,
                                             X_train, X_valid, y_advr_1_train, y_advr_1_valid,
                                             w_advr_1[0], device)

In [None]:
# Evaluate the centralized encoder against each target in turn.
# Each row: (banner text, history key, train labels, valid labels, class weights).
centralized_targets = (
    ('ALLY: CENTRALIZED', 'centralized_ally', y_ally_train, y_ally_valid, w_ally[0]),
    ('ADVR 1: CENTRALIZED', 'centralized_advr_1', y_advr_1_train, y_advr_1_valid, w_advr_1[0]),
    ('ADVR 2: CENTRALIZED', 'centralized_advr_2', y_advr_2_train, y_advr_2_valid, w_advr_2[0]),
)

rule = "-"*80
for banner, history_key, labels_train, labels_valid, class_weights in centralized_targets:
    print(rule)
    print(banner)
    print(rule)
    history[history_key] = metrics.centralized(
        encoder,
        input_size, hidden_size, output_size,
        X_train, X_valid, labels_train, labels_valid,
        class_weights, device, ['logistic', 'mlp'])

In [None]:
NUM_NODES = 2


def _on_device(tensors):
    """Return a new list with every tensor moved to ``device``."""
    return [tensor.to(device) for tensor in tensors]


# Move every per-node split to the compute device before the distributed evaluation.
X_trains, X_valids = _on_device(X_trains), _on_device(X_valids)
y_ally_trains, y_ally_valids = _on_device(y_ally_trains), _on_device(y_ally_valids)
y_advr_1_trains, y_advr_1_valids = _on_device(y_advr_1_trains), _on_device(y_advr_1_valids)
y_advr_2_trains, y_advr_2_valids = _on_device(y_advr_2_trains), _on_device(y_advr_2_valids)

In [None]:
# Evaluate every distributed encoder against the ally and both adversary targets.
# Each row: (banner label, history-key prefix, per-node train labels,
#            per-node valid labels, class weights).
# NOTE: the ally prefix is spelled 'decentralize' (no trailing 'd'). It is kept as-is
# because the plotting cell looks results up under exactly that key.
# NOTE(review): only node 0's class weights are passed although the weight lists hold
# one entry per node — confirm intended.
distributed_targets = (
    ('ALLY', 'decentralize_ally', y_ally_trains, y_ally_valids, w_ally[0]),
    ('ADVERSARY 1', 'decentralized_advr_1', y_advr_1_trains, y_advr_1_valids, w_advr_1[0]),
    # Fix: this banner used to print 'ADVERSARY', which hid which adversary was evaluated.
    ('ADVERSARY 2', 'decentralized_advr_2', y_advr_2_trains, y_advr_2_valids, w_advr_2[0]),
)

for key, encd in encoders.items():
    for banner, key_prefix, labels_trains, labels_valids, class_weights in distributed_targets:
        print("-"*80)
        print('{}: {}'.format(banner, key))
        print("-"*80)
        history['{}_{}'.format(key_prefix, key)] = metrics.distributed(
            encd, NUM_NODES,
            input_size, hidden_size, output_size,
            X_trains, X_valids, labels_trains, labels_valids,
            class_weights, device, ['mlp', 'logistic'])

In [None]:
import pickle as pkl

# Replace `encoders` with the set stored in the phi=0.8 / varying-delta file; the
# handle is closed on exit.
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
with open('encoders_mimic_num_nodes_2_phi_0.8_delta_var.pkl', 'rb') as encoders_file:
    encoders = pkl.load(encoders_file)

In [None]:
# Evaluate every encoder currently in `encoders` against the ally and both adversary
# targets.
# Each row: (banner label, history-key prefix, per-node train labels,
#            per-node valid labels, class weights).
# NOTE: the ally prefix is spelled 'decentralize' (no trailing 'd'). It is kept as-is
# because the plotting cell looks results up under exactly that key.
# NOTE(review): only node 0's class weights are passed although the weight lists hold
# one entry per node — confirm intended.
distributed_targets = (
    ('ALLY', 'decentralize_ally', y_ally_trains, y_ally_valids, w_ally[0]),
    ('ADVERSARY 1', 'decentralized_advr_1', y_advr_1_trains, y_advr_1_valids, w_advr_1[0]),
    # Fix: this banner used to print 'ADVERSARY', which hid which adversary was evaluated.
    ('ADVERSARY 2', 'decentralized_advr_2', y_advr_2_trains, y_advr_2_valids, w_advr_2[0]),
)

for key, encd in encoders.items():
    for banner, key_prefix, labels_trains, labels_valids, class_weights in distributed_targets:
        print("-"*80)
        print('{}: {}'.format(banner, key))
        print("-"*80)
        history['{}_{}'.format(key_prefix, key)] = metrics.distributed(
            encd, NUM_NODES,
            input_size, hidden_size, output_size,
            X_trains, X_valids, labels_trains, labels_valids,
            class_weights, device, ['mlp', 'logistic'])

In [None]:
# Persist the evaluation results. The `with` block guarantees the file is flushed and
# closed, which the original bare `open(...)` did not.
with open('mimic_var_phi_delta_1.pkl', 'wb') as history_file:
    pkl.dump(history, history_file)

In [None]:
import pickle as pkl

In [None]:
# Restore evaluation results from disk; this replaces whatever `history` holds.
# NOTE(review): this reads 'mimic_var_phi_delta.pkl', while the dump a few cells earlier
# writes 'mimic_var_phi_delta_1.pkl' — confirm the intended file.
# NOTE: unpickling can execute arbitrary code, so only load trusted files.
with open('mimic_var_phi_delta.pkl', 'rb') as history_file:
    history = pkl.load(history_file)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Comparison figure: centralized EIGAN (dashed lines) against distributed EIGAN (bars),
# (a) across sync frequencies at phi=0.8 and (b) across phi at delta=2.

# Fix: set the font size BEFORE any axes or text exist. Matplotlib resolves font sizes
# when text objects are created, so the original call at the end of the cell did not
# affect this figure on a fresh kernel.
plt.rcParams.update({'font.size': 14})


def _best_score(history_key):
    """Return the larger of the 'logistic' and 'mlp' values at index 2 of a history entry."""
    scores = history[history_key][2]
    return max(scores['logistic'], scores['mlp'])


delta = 2
baseline_ally = []
baseline_advr = []
# Baseline lists stay empty; the original (disabled) fill logic was:
# tmp = history['baseline_ally'][3]
# baseline_ally.append(max(tmp['svm'], tmp['logistic'], tmp['mlp']))
# tmp = history['baseline_advr'][3]
# baseline_advr.append(max(tmp['logistic'], tmp['mlp']))
eigan_ally = [_best_score('centralized_ally')]
eigan_advr_1 = [_best_score('centralized_advr_1')]
eigan_advr_2 = [_best_score('centralized_advr_2')]

# The left panel is ax2 ("(a)") and the right panel is ax1 ("(b)").
fig, (ax2, ax1) = plt.subplots(1, 2, figsize=(10, 4))

# ---- panel (b): sweep over phi at fixed delta ----
dist_x = []
dist_ally = []
dist_advr_1 = []
dist_advr_2 = []
for phi_tenths in range(2, 11, 2):
    phi = phi_tenths / 10
    dist_x.append(phi)
    # The ally key really is 'decentralize_' (no 'd'); it matches how results are stored.
    dist_ally.append(_best_score('decentralize_{}_{}_{}'.format('ally', phi, delta)))
    dist_advr_1.append(_best_score('decentralized_{}_{}_{}'.format('advr_1', phi, delta)))
    dist_advr_2.append(_best_score('decentralized_{}_{}_{}'.format('advr_2', phi, delta)))

e_ally = ax1.hlines(y=eigan_ally[0], xmin=0.1, xmax=1.1, color='b', linestyle='dashed')
e_advr_1 = ax1.hlines(y=eigan_advr_1[0], xmin=0.1, xmax=1.1, color='r', linestyle='dashed')
e_advr_2 = ax1.hlines(y=eigan_advr_2[0], xmin=0.1, xmax=1.1, color='c', linestyle='dashed')
d_ally = ax1.bar(np.array(dist_x)-0.04, dist_ally, width=0.04, color='b')
d_advr_1 = ax1.bar(np.array(dist_x), dist_advr_1, width=0.04, color='r')
d_advr_2 = ax1.bar(np.array(dist_x)+0.04, dist_advr_2, width=0.04, color='c')
ax1.set_xticks(dist_x)
ax1.set_xlim(left=0.1, right=1.1)
# Fix: a missing comma merged 'E-advr 2' and 'D-ally' into one label, shifting every
# later entry. Handles are passed explicitly so each label is tied to its artist.
ax1.legend([e_ally, e_advr_1, e_advr_2, d_ally, d_advr_1, d_advr_2],
           ['E-ally', 'E-advr 1', 'E-advr 2', 'D-ally', 'D-advr 1', 'D-advr 2'])
ax1.set_xlabel('fraction of parameters shared')
ax1.set_ylabel('accuracy')
ax1.set_title('(b)', y=-0.3)
ax1.grid()

# ---- panel (a): sweep over delta at fixed phi ----
phi = 0.8
dist_x = []
dist_ally = []
dist_advr_1 = []
dist_advr_2 = []
for delta in range(2, 11, 2):
    dist_x.append(delta)
    dist_ally.append(_best_score('decentralize_{}_{}_{}'.format('ally', phi, delta)))
    dist_advr_1.append(_best_score('decentralized_{}_{}_{}'.format('advr_1', phi, delta)))
    dist_advr_2.append(_best_score('decentralized_{}_{}_{}'.format('advr_2', phi, delta)))

ax2.hlines(y=eigan_ally[0], xmin=1, xmax=11, color='b', linestyle='dashed')
ax2.hlines(y=eigan_advr_1[0], xmin=1, xmax=11, color='r', linestyle='dashed')
ax2.hlines(y=eigan_advr_2[0], xmin=1, xmax=11, color='c', linestyle='dashed')
ax2.bar(np.array(dist_x)-0.4, dist_ally, width=0.4, color='b')
ax2.bar(np.array(dist_x), dist_advr_1, width=0.4, color='r')
ax2.bar(np.array(dist_x)+0.4, dist_advr_2, width=0.4, color='c')
ax2.set_xticks(dist_x)
ax2.set_xlim(left=1, right=11)
ax2.set_xlabel('frequency of sync')
ax2.set_ylabel('accuracy')
ax2.set_title('(a)', y=-.3)
ax2.grid()

fig.subplots_adjust(wspace=0.3)
plt.savefig('distributed_eigan_comparison_var_phi_delta.png', bbox_inches='tight', dpi=300)

In [None]:
# Persist the evaluation results. The `with` block guarantees the file is flushed and
# closed, which the original bare `open(...)` did not.
with open('mimic_var_phi_delta.pkl', 'wb') as history_file:
    pkl.dump(history, history_file)

In [None]:
# Echo element 3 of the centralized ally result as the cell output.
# NOTE(review): the plotting cell reads index 2 of the same entry; the meaning of
# each index is defined by metrics.centralized — confirm which one is wanted here.
history['centralized_ally'][3]

In [None]:
# Box-plot comparison of three series ('Bertran', 'EIGAN', 'Unencoded') on ax1 and ax2.
# NOTE(review): this cell does not look runnable as written:
#   * `bp1` is read on the first line but is only assigned further down (for ax2);
#     the ax1 `bp1 = ax1.boxplot(...)` call appears to be missing.
#   * `bertran_advr` and `eigan_advr` are not defined in the visible notebook.
#   * `fig`, `ax1` and `ax2` are whatever an earlier cell left behind.
#   * each boxplot is given 4 positions, while the `eigan_ally` / `baseline_ally`
#     lists built by the plotting cell hold 1 and 0 entries — presumably this cell
#     was copied from another notebook; confirm or remove.
plt.setp(bp1['boxes'], color='cyan')
plt.setp(bp1['medians'], color='black')
# ax1: EIGAN boxes, drawn 0.05 to the left of each tick.
bp2 = ax1.boxplot(eigan_ally, vert=1, whis=1.5, widths=0.05, 
                  positions=[0.15, 0.35, 0.55, 0.75], showfliers=False, patch_artist=True)
plt.setp(bp2['boxes'], color='red')
plt.setp(bp2['medians'], color='black')
# ax1: unencoded baseline boxes, drawn 0.05 to the right of each tick.
bp3 = ax1.boxplot(baseline_ally, vert=1, whis=1.5, widths=0.05, 
                  positions=[0.25, 0.45, 0.65, 0.85], showfliers=False, patch_artist=True)
plt.setp(bp3['boxes'], color='orange')
plt.setp(bp3['medians'], color='black')
# One legend entry per series, using the first box of each as the handle.
ax1.legend([bp1["boxes"][0], bp2["boxes"][0], bp3["boxes"][0]], 
           ['Bertran', 'EIGAN', 'Unencoded'], loc='upper right', prop={'size':10})
ax1.set_xlim(left=0, right=1)
ax1.set_title('(b)', y=-0.3)
ax1.set_xlabel('variance along ally label')
ax1.set_ylabel('accuracy')
ax1.set_xticks([0.2, 0.4, 0.6, 0.8])
ax1.grid()

# ax2: the same three series for the adversary results, centred on the ticks.
bp1 = ax2.boxplot(bertran_advr, vert=1, whis=1.5, widths=0.05, 
                  positions=[0.2, 0.4, 0.6, 0.8], showfliers=False, patch_artist=True)
plt.setp(bp1['boxes'], color='cyan')
plt.setp(bp1['medians'], color='black')
bp2 = ax2.boxplot(eigan_advr, vert=1, whis=1.5, widths=0.05, 
                  positions=[0.15, 0.35, 0.55, 0.75], showfliers=False, patch_artist=True)
plt.setp(bp2['boxes'], color='red')
plt.setp(bp2['medians'], color='black')
bp3 = ax2.boxplot(baseline_advr, vert=1, whis=1.5, widths=0.05, 
                  positions=[0.25, 0.45, 0.65, 0.85], showfliers=False, patch_artist=True)
plt.setp(bp3['boxes'], color='orange')
plt.setp(bp3['medians'], color='black')
ax2.set_xlim(left=0, right=1)
ax2.set_title('(a)', y=-0.3)
ax2.set_xlabel('variance along ally label')
ax2.set_ylabel('accuracy')
ax2.set_xticks([0.2, 0.4, 0.6, 0.8])
ax2.grid()

# Widen the gap between the two panels, then save and display the figure.
fig.subplots_adjust(wspace=0.3)
plt.savefig('0.4_advr_varying_ally_comparison.png', bbox_inches='tight', dpi=300)
plt.show()