In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle as pkl
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

In [None]:
import sys

# Make the repository root importable so the `common` and `notebook_utils`
# packages resolve. Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
%load_ext autoreload
%autoreload 2
from common.utility import to_categorical, torch_device
from notebook_utils.generate_gaussian import generate_gaussian
from notebook_utils.eigan import Encoder, Discriminator
from notebook_utils.federated import federated
from notebook_utils.eigan_training import distributed_3, centralized_3
from notebook_utils.utility import class_plot, to_numpy
import notebook_utils.metrics as metrics

In [None]:
# Resolve the requested device name through the project helper; the resulting
# object is what the training and metric calls below receive as `device`.
device = torch_device(device='gpu')
device

# Training

In [None]:
# NOTE: pickle can execute arbitrary code while loading; only point this at a
# checkpoint file you produced yourself or otherwise trust.
with open('../checkpoints/mimic/processed_data_X_y_ally_y_advr_y_advr_2.pkl', 'rb') as data_file:
    X_all, y_ally_all, y_advr_1_all, y_advr_2_all = pkl.load(data_file)

# Labels are kept as column vectors so they can be concatenated column-wise later.
y_ally_all = y_ally_all.reshape(-1, 1)
y_advr_1_all = y_advr_1_all.reshape(-1, 1)
y_advr_2_all = y_advr_2_all.reshape(-1, 1)

# Grouped bar chart of per-class sample counts for the three label sets.
width = 0.2
fig, ax1 = plt.subplots(1, 1, figsize=(9, 4))
for labels, offset, name, style in (
        (y_ally_all, -width, 'ally', {'color': 'b'}),
        (y_advr_1_all, 0.0, 'adversary 1', {'color': 'r', 'hatch': 'o'}),
        (y_advr_2_all, width, 'adversary 2', {'color': 'r', 'hatch': '-'}),
):
    counts = np.bincount(labels.flatten())
    # x positions are derived from the bincount length so x and height always
    # have the same size, even if some class id has no samples.
    ax1.bar(np.arange(counts.size) + offset, counts, width, label=name, **style)
# Labels are attached to each bar group above; passing a bare list of labels
# to legend() would bind them to individual rectangles instead.
ax1.legend()

plt.show()

X_all.shape, y_ally_all.shape, y_advr_1_all.shape, y_advr_2_all.shape

In [None]:
# Mini-batch size used for every DataLoader built in this notebook.
BATCHSIZE = 512

# Iteration budgets: `n_iter_gan` is handed to the EIGAN training routines,
# `n_iter` to the metric/evaluation routines.
n_iter_gan = 501
n_iter = 501

# Forwarded to the distributed EIGAN training call.
# NOTE(review): phi is later plotted on an axis labelled 'fraction of
# parameters shared'; the meaning of delta is not visible here -- confirm
# against notebook_utils.eigan_training.
phi, delta = 0.8, 2

In [None]:
def _plot_label_counts(ax, y_ally, y_advr_1, y_advr_2, title, width=0.2):
    """Draw grouped bars of per-class sample counts for the three label sets on `ax`."""
    groups = (
        (y_ally, -width, 'ally', {'color': 'b'}),
        (y_advr_1, 0.0, 'adversary 1', {'color': 'r', 'hatch': 'o'}),
        (y_advr_2, width, 'adversary 2', {'color': 'r', 'hatch': '-'}),
    )
    for labels, offset, name, style in groups:
        counts = np.bincount(labels.flatten())
        # x positions come from the bincount length so that a class missing
        # from this split cannot make x and height disagree in size.
        ax.bar(np.arange(counts.size) + offset, counts, width, label=name, **style)
    # Labels are attached per bar group; a bare label list passed to legend()
    # would be matched against individual rectangles instead.
    ax.legend()
    ax.set_title(title)


def _best_score(result, metric_idx=2):
    """Return the larger of the 'logistic' and 'mlp' entries of `result[metric_idx]`."""
    scores = result[metric_idx]
    return max(scores['logistic'], scores['mlp'])


# Predictor families requested from every metrics.* call below.
PREDICTORS = ['mlp', 'logistic']

# Set once up front so that every figure produced by this cell uses the same font size.
plt.rcParams.update({'font.size': 14})

# Created once, outside the loop: results for every NUM_NODES accumulate here.
# (Re-creating it per iteration would keep only the last node count.)
history = {}

for NUM_NODES in range(2, 11):
    # ---- split the dataset into NUM_NODES contiguous, randomly sized shards ----
    n_samples = X_all.shape[0]
    # Distinct cut points in [1, n_samples) guarantee that no shard is empty.
    # NOTE(review): very small shards can still make the stratified split below fail.
    cut_points = np.sort(np.random.choice(np.arange(1, n_samples), size=NUM_NODES-1, replace=False))
    crossovers = [0] + cut_points.tolist() + [n_samples]
    bounds = list(zip(crossovers[:-1], crossovers[1:]))

    X = [X_all[lo:hi] for lo, hi in bounds]
    y_ally = [y_ally_all[lo:hi] for lo, hi in bounds]
    y_advr_1 = [y_advr_1_all[lo:hi] for lo, hi in bounds]
    y_advr_2 = [y_advr_2_all[lo:hi] for lo, hi in bounds]

    for node_idx in range(NUM_NODES):
        print('@node {}, X: {}, y_ally: {}, y_advr_1: {}, y_advr_2: {}'.format(
            node_idx, X[node_idx].shape, y_ally[node_idx].shape,
            y_advr_1[node_idx].shape, y_advr_2[node_idx].shape))

    # ---- per-node preprocessing: split, class weights, scaling, one-hot, tensors ----
    w_ally = []
    w_advr_1 = []
    w_advr_2 = []
    train_loaders = []
    X_valids = []
    X_trains = []
    y_ally_valids = []
    y_ally_trains = []
    y_advr_1_valids = []
    y_advr_1_trains = []
    y_advr_2_valids = []
    y_advr_2_trains = []
    for node_idx in range(NUM_NODES):
        X_local = X[node_idx]
        y_ally_local = y_ally[node_idx]
        y_advr_1_local = y_advr_1[node_idx]
        y_advr_2_local = y_advr_2[node_idx]

        # Stratify on the joint (ally, adversary 1, adversary 2) label.
        X_train, X_valid, y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, \
        y_advr_2_train, y_advr_2_valid = train_test_split(
            X_local, y_ally_local, y_advr_1_local, y_advr_2_local, test_size=0.2, stratify=pd.DataFrame(
                np.concatenate((y_ally_local, y_advr_1_local, y_advr_2_local), axis=1)
            ))
        print('@node {}: X_train, X_valid, y_ally_train, y_ally_valid,'
              'y_advr_1_train, y_advr_1_valid, y_advr_2_train, y_advr_2_valid'.format(node_idx))
        print(X_train.shape, X_valid.shape,
              y_ally_train.shape, y_ally_valid.shape,
              y_advr_1_train.shape, y_advr_1_valid.shape,
              y_advr_2_train.shape, y_advr_2_valid.shape)

        # Inverse-frequency class weights computed from the training labels.
        w = np.bincount(y_ally_train.flatten())
        w_ally.append(sum(w)/w)
        w = np.bincount(y_advr_1_train.flatten())
        w_advr_1.append(sum(w)/w)
        w = np.bincount(y_advr_2_train.flatten())
        w_advr_2.append(sum(w)/w)
        print('@node {}: class weights => w_ally, w_advr_1, w_advr_2'.format(node_idx), w_ally, w_advr_1, w_advr_2)

        # The scaler is fitted on the node's training split only.
        scaler = MinMaxScaler()
        X_train = scaler.fit_transform(X_train)
        X_valid = scaler.transform(X_valid)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
        _plot_label_counts(ax1, y_ally_train, y_advr_1_train, y_advr_2_train,
                           'train@{}'.format(node_idx+1))
        _plot_label_counts(ax2, y_ally_valid, y_advr_1_valid, y_advr_2_valid,
                           'valid@{}'.format(node_idx+1))
        # Render and release the figure now; dozens of figures are created
        # over the whole experiment and would otherwise all stay open.
        plt.show()
        plt.close(fig)

        y_ally_train = to_categorical(y_ally_train)
        y_ally_valid = to_categorical(y_ally_valid)
        y_advr_1_train = to_categorical(y_advr_1_train)
        y_advr_2_train = to_categorical(y_advr_2_train)
        y_advr_1_valid = to_categorical(y_advr_1_valid)
        y_advr_2_valid = to_categorical(y_advr_2_valid)

        X_train = torch.Tensor(X_train)
        y_ally_train = torch.Tensor(y_ally_train)
        y_advr_1_train = torch.Tensor(y_advr_1_train)
        y_advr_2_train = torch.Tensor(y_advr_2_train)

        X_valids.append(torch.Tensor(X_valid))
        y_ally_valids.append(torch.Tensor(y_ally_valid))
        y_advr_1_valids.append(torch.Tensor(y_advr_1_valid))
        y_advr_2_valids.append(torch.Tensor(y_advr_2_valid))

        X_trains.append(X_train)
        y_ally_trains.append(y_ally_train)
        y_advr_1_trains.append(y_advr_1_train)
        y_advr_2_trains.append(y_advr_2_train)

        print('@node {}: tensor sizes =>'.format(node_idx), X_train.shape, X_valid.shape,
              y_ally_train.shape, y_ally_valid.shape,
              y_advr_1_train.shape, y_advr_1_valid.shape, y_advr_2_train.shape, y_advr_2_valid.shape)

        train_loaders.append(DataLoader(TensorDataset(X_train, y_ally_train, y_advr_1_train, y_advr_2_train),
                                        batch_size=BATCHSIZE, shuffle=True))

    # ---- model / optimiser settings ----
    alpha = 1
    lr_encd = 0.0001
    lr_ally = 0.00001
    lr_advr_1 = 0.00001
    lr_advr_2 = 0.00001

    input_size = X_all.shape[1]
    hidden_size = input_size*8
    output_size = 2

    global_params = {}
    encoders = {}

    # ---- distributed EIGAN training ----
    print("-"*80)
    print('EIGAN Training w/ phi={} and delta={}'.format(phi, delta))
    print("-"*80)
    # Every data argument is the per-node *list*. Passing the single tensor
    # `y_ally_train` (left over from the last node above) here would be wrong.
    encoders['{}_{}'.format(phi, delta)] = distributed_3(NUM_NODES, phi, delta,
                   X_trains, X_valids,
                   y_ally_trains, y_ally_valids,
                   y_advr_1_trains, y_advr_1_valids,
                   y_advr_2_trains, y_advr_2_valids,
                   input_size, hidden_size, [2]*NUM_NODES,
                   alpha, lr_encd, lr_ally, lr_advr_1, lr_advr_2,w_ally, w_advr_1, w_advr_2,
                   train_loaders, n_iter_gan, device, global_params)

    with open('encoders_mimic_num_nodes{}_phi{}_delta{}.pkl'.format(NUM_NODES, phi, delta), 'wb') as encoders_file:
        pkl.dump(encoders, encoders_file)

    # ---- centralized comparison on the pooled dataset ----
    # Split the full arrays (not the Python lists of per-node shards) and apply
    # the same preprocessing as the per-node pipeline above.
    # NOTE(review): assumes centralized_3 / metrics.centralized expect scaled
    # features and one-hot labels like their distributed counterparts, which
    # matches the two-unit output sizes passed below -- confirm.
    X_train, X_valid, y_ally_train, y_ally_valid, \
        y_advr_1_train, y_advr_1_valid, y_advr_2_train, y_advr_2_valid = train_test_split(
            X_all, y_ally_all, y_advr_1_all, y_advr_2_all, test_size=0.3)

    scaler = MinMaxScaler()
    X_train = torch.Tensor(scaler.fit_transform(X_train)).to(device)
    X_valid = torch.Tensor(scaler.transform(X_valid)).to(device)
    y_ally_train = torch.Tensor(to_categorical(y_ally_train)).to(device)
    y_ally_valid = torch.Tensor(to_categorical(y_ally_valid)).to(device)
    y_advr_1_train = torch.Tensor(to_categorical(y_advr_1_train)).to(device)
    y_advr_1_valid = torch.Tensor(to_categorical(y_advr_1_valid)).to(device)
    y_advr_2_train = torch.Tensor(to_categorical(y_advr_2_train)).to(device)
    y_advr_2_valid = torch.Tensor(to_categorical(y_advr_2_valid)).to(device)

    # The per-node tensors are moved only now, after distributed training has
    # consumed them in their original placement.
    X_trains = [t.to(device) for t in X_trains]
    X_valids = [t.to(device) for t in X_valids]
    y_ally_trains = [t.to(device) for t in y_ally_trains]
    y_ally_valids = [t.to(device) for t in y_ally_valids]
    y_advr_1_trains = [t.to(device) for t in y_advr_1_trains]
    y_advr_1_valids = [t.to(device) for t in y_advr_1_valids]
    y_advr_2_trains = [t.to(device) for t in y_advr_2_trains]
    y_advr_2_valids = [t.to(device) for t in y_advr_2_valids]

    train_loader = DataLoader(TensorDataset(X_train, y_ally_train, y_advr_1_train, y_advr_2_train),
                              batch_size=BATCHSIZE, shuffle=True)

    # NOTE(review): node 0's class weights are reused for the pooled data here
    # and in the metric calls below -- confirm this is intended.
    encoder = centralized_3(X_train, X_valid,
                      y_ally_train, y_ally_valid,
                      y_advr_1_train, y_advr_1_valid,
                      y_advr_2_train, y_advr_2_valid,
                      input_size, hidden_size, [2, 2, 2],
                      alpha, lr_encd, lr_ally, lr_advr_1, lr_advr_2, w_ally[0], w_advr_1[0], w_advr_2[0],
                      train_loader, n_iter_gan, device)

    with open('encoder_mimic_num_nodes{}_central_compare.pkl'.format(NUM_NODES), 'wb') as encoder_file:
        pkl.dump(encoder, encoder_file)

    # ---- evaluation: baseline (no encoder) and centralized encoder ----
    pooled_targets = (
        ('ally', 'ALLY', y_ally_train, y_ally_valid, w_ally[0]),
        ('advr_1', 'ADVR 1', y_advr_1_train, y_advr_1_valid, w_advr_1[0]),
        ('advr_2', 'ADVR 2', y_advr_2_train, y_advr_2_valid, w_advr_2[0]),
    )
    for prefix, banner, encd in (('baseline', 'BASELINE', None),
                                 ('centralized', 'CENTRALIZED', encoder)):
        for tag, title, y_tr, y_va, weight in pooled_targets:
            print("-"*80)
            print('{}: {}'.format(title, banner))
            print("-"*80)
            history['{}_{}_{}'.format(prefix, tag, NUM_NODES)] = metrics.centralized(
                encd, input_size, hidden_size, output_size,
                X_train, X_valid, y_tr, y_va,
                weight, device, PREDICTORS, n_iter)

    # ---- evaluation: distributed encoders ----
    node_targets = (
        ('ally', 'ALLY', y_ally_trains, y_ally_valids, w_ally[0]),
        ('advr_1', 'ADVR 1', y_advr_1_trains, y_advr_1_valids, w_advr_1[0]),
        ('advr_2', 'ADVR 2', y_advr_2_trains, y_advr_2_valids, w_advr_2[0]),
    )
    for key, encd in encoders.items():
        for tag, title, y_trs, y_vas, weight in node_targets:
            print("-"*80)
            print('{}: {}'.format(title, key))
            print("-"*80)
            history['decentralized_{}_{}'.format(tag, NUM_NODES)] = metrics.distributed(
                encd, NUM_NODES, input_size, hidden_size, output_size,
                X_trains, X_valids, y_trs, y_vas,
                weight, device, PREDICTORS, n_iter)

    # ---- comparison figure for this NUM_NODES ----
    best = {
        name: _best_score(history['{}_{}'.format(name, NUM_NODES)])
        for name in ('centralized_ally', 'centralized_advr_1', 'centralized_advr_2',
                     'decentralized_ally', 'decentralized_advr_1', 'decentralized_advr_2')
    }
    dist_x = [phi]

    fig, ax1 = plt.subplots(1, 1, figsize=(5, 4))
    # Dashed lines: centralized encoder (E). Bars: distributed encoder (D).
    ax1.hlines(y=best['centralized_ally'], xmin=-0.1, xmax=1.1, color='b', linestyle='dashed', label='E-ally')
    ax1.hlines(y=best['centralized_advr_1'], xmin=-0.1, xmax=1.1, color='r', linestyle='dashed', label='E-advr 1')
    ax1.hlines(y=best['centralized_advr_2'], xmin=-0.1, xmax=1.1, color='c', linestyle='dashed', label='E-advr 2')
    ax1.bar(np.array(dist_x)-0.05, [best['decentralized_ally']], width=0.05, color='b', label='D-ally')
    ax1.bar(np.array(dist_x), [best['decentralized_advr_1']], width=0.05, color='r', label='D-advr 1')
    ax1.bar(np.array(dist_x)+0.05, [best['decentralized_advr_2']], width=0.05, color='c', label='D-advr 2')
    ax1.set_xticks(dist_x)
    ax1.set_xlim(left=-0.1, right=1.1)
    ax1.legend(loc='lower right')
    ax1.set_xlabel('fraction of parameters shared')
    ax1.set_ylabel('accuracy')
    ax1.set_title('(b)', y=-0.3)
    ax1.grid()

    fig.subplots_adjust(wspace=0.3)
    # One file per node count; a fixed name would be overwritten on every iteration.
    plt.savefig('mimic_distributed_eigan_comparison_num_nodes{}.png'.format(NUM_NODES),
                bbox_inches='tight', dpi=300)
    plt.show()
    plt.close(fig)

In [None]:
# Display the results dictionary populated by the training/evaluation loop above.
history

In [None]:
import pickle as pkl

# Persist the collected results. The context manager guarantees the file is
# flushed and closed even if pickling fails part-way.
with open('history_mimic_num_nodes_expt.pkl', 'wb') as history_file:
    pkl.dump(history, history_file)

In [None]:
def _best_available(scores, models=('svm', 'logistic', 'mlp')):
    """Return the highest score among those `models` that are present in `scores`."""
    # The evaluation calls in this notebook request only 'mlp' and 'logistic',
    # so 'svm' may be absent; missing models are skipped instead of raising KeyError.
    return max(scores[model] for model in models if model in scores)


# Position of the metric dictionary inside each history entry.
# NOTE(review): presumably index 3 holds the f1 scores (the plot of these
# lists labels its y axis 'f1 score'), whereas the training loop reads
# index 2 -- confirm against notebook_utils.metrics.
METRIC_IDX = 3

num_nodes = []
baseline_ally = []
baseline_advr = []
eigan_ally = []
eigan_advr = []
dist_ally = []
dist_advr = []

# history key prefix -> list that collects the best score per node count.
# history stores one entry per adversary ('..._advr_1_N' / '..._advr_2_N'),
# never a plain '..._advr_N', so the '*_advr' lists track adversary 1.
# NOTE(review): confirm adversary 1 is the one intended for the summary plot.
score_series = (
    ('baseline_ally', baseline_ally),
    ('baseline_advr_1', baseline_advr),
    ('centralized_ally', eigan_ally),
    ('centralized_advr_1', eigan_advr),
    ('decentralized_ally', dist_ally),
    ('decentralized_advr_1', dist_advr),
)

for n_nodes in range(2, 11):
    keys = ['{}_{}'.format(prefix, n_nodes) for prefix, _bucket in score_series]
    # Skip node counts with no (or incomplete) recorded results so that all
    # lists stay aligned with `num_nodes`.
    if not all(key in history for key in keys):
        print('no complete results for num_nodes={}; skipping'.format(n_nodes))
        continue
    num_nodes.append(n_nodes)
    for key, (_prefix, bucket) in zip(keys, score_series):
        bucket.append(_best_available(history[key][METRIC_IDX]))

In [None]:
# Display the best adversary scores collected from the 'centralized_*' history entries.
eigan_advr

In [None]:
# Applied before the figure is created so that it affects this figure.
plt.rcParams.update({'font.size': 14})

# NOTE(review): ax2 is created but never drawn on, so the right-hand panel
# stays empty; the ally score lists may have been intended for it.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 4))

# Grouped bars per node count: centralized, distributed and baseline adversary scores.
node_positions = np.array(num_nodes)
ax1.bar(node_positions-0.3, eigan_advr, width=0.25, color='b', hatch='x', label='c-adversary')
ax1.bar(node_positions, dist_advr, width=0.25, color='b', label='d-adversary')
ax1.bar(node_positions+0.3, baseline_advr, width=0.25, color='b', hatch='-', label='baseline-adversary')
# Ticks sit on the node counts that were actually plotted.
ax1.set_xticks(num_nodes)
# One legend entry per bar group, taken from the labels attached above.
ax1.legend(loc='lower right')
ax1.set_xlabel('number of nodes')
ax1.set_ylabel('f1 score')
ax1.set_title('(b)', y=-0.3)
ax1.grid()

fig.subplots_adjust(wspace=0.3)
plt.savefig('num_nodes_comparison.png', bbox_inches='tight', dpi=300)