In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle as pkl
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

In [None]:
import sys

# Make the parent directory importable (the `common` / `notebook_utils` packages imported
# below presumably live there). Guarded so re-running this cell does not keep appending
# duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
%load_ext autoreload
%autoreload 2
from common.utility import to_categorical, torch_device
from notebook_utils.generate_gaussian import generate_gaussian
from notebook_utils.eigan import Encoder, Discriminator
from notebook_utils.federated import federated
from notebook_utils.eigan_training import distributed, centralized
from notebook_utils.utility import class_plot, to_numpy
import notebook_utils.metrics as metrics

In [None]:
# Resolve the torch device through the project helper `torch_device`, requesting a GPU.
device = torch_device(device='gpu')
device

# Training

In [None]:
# ---- Experiment configuration ----
MAX = 2048         # samples generated per setting; each node gets MAX // <number of nodes>
NUM_TRIALS = 10    # NOTE(review): not read by the training cell, which hard-codes runs [0, 8, 9]
VAR1, VAR2 = 1, 1  # offsets for generate_gaussian's first two args: (VAR + k + 1) / 10 for node index k
NUM_NODES = 2      # NOTE: rebound by `for NUM_NODES in range(2, 11)` in the training cell
PHI = 1            # NOTE(review): not used by the visible cells; the training code uses lowercase `phi`
BATCHSIZE = 512    # DataLoader batch size

# Parameters passed to `distributed`; phi is plotted as the "fraction of parameters shared".
phi, delta = 0.8, 10

In [None]:
# Train D-EIGAN (distributed) and EIGAN (centralized) encoders for 2..10 nodes on non-IID
# Gaussian node data, evaluate them with the probe metrics, and store one history dict
# per run in `master`.

# Training hyper-parameters (identical for every run and node count).
alpha = 1
lr_encd = 0.0001
lr_1 = 0.0001
lr_2 = 0.0001
n_iter_gan = 501
output_size = 2

# Set once up front so every figure produced below uses the same font size.
plt.rcParams.update({'font.size': 14})


def _print_header(title):
    """Print `title` framed by two dashed rules (console section separator)."""
    print('-' * 80)
    print(title)
    print('-' * 80)


def _best_score(result, idx=3):
    """Return the best of the 'svm' / 'logistic' / 'mlp' scores stored in ``result[idx]``.

    ``result`` is a value returned by ``metrics.centralized`` / ``metrics.distributed``.
    This cell reads element 3, which the summary plot labels "f1 score" (TODO confirm).
    """
    scores = result[idx]
    return max(scores['svm'], scores['logistic'], scores['mlp'])


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing (and flushing) the file deterministically."""
    with open(path, 'wb') as fh:
        pkl.dump(obj, fh)


master = {}  # run id -> history dict of metric results
for run in [0] + list(range(8, 10)):
    print('run: ---------------------------------------->', run)
    # Seed the global RNGs drawn from by train_test_split (no random_state), torch weight
    # init and DataLoader shuffling -- and presumably generate_gaussian -- so each run is
    # reproducible while different runs still differ.
    np.random.seed(run)
    torch.manual_seed(run)
    history = {}
    for NUM_NODES in range(2, 11):
        # ---- Non-IID data: node k uses generate_gaussian args ((VAR1+k+1)/10, (VAR2+k+1)/10).
        X, y_1, y_2 = [], [], []
        for node_idx in range(NUM_NODES):
            var_add = node_idx + 1
            data = generate_gaussian((VAR1 + var_add) / 10, (VAR2 + var_add) / 10, MAX // NUM_NODES, 1)
            X.append(data[0])
            y_1.append(data[1])
            y_2.append(data[2])

        print('=' * 80)
        print('{} NODE DATA'.format(NUM_NODES))
        print('=' * 80)
        for node_idx in range(NUM_NODES):
            print('@node {}, X: {}, y_1: {}, y2: {}'.format(
                node_idx, X[node_idx].shape, y_1[node_idx].shape, y_2[node_idx].shape))

        # ---- Per node: split, class weights, scaling, diagnostics plot, tensors, loader.
        w_1, w_2 = [], []
        train_loaders = []
        X_trains, X_valids = [], []
        y_1_trains, y_1_valids = [], []
        y_2_trains, y_2_valids = [], []
        for node_idx in range(NUM_NODES):
            X_local, y_1_local, y_2_local = X[node_idx], y_1[node_idx], y_2[node_idx]

            # Stratify on both label columns jointly so each split keeps the
            # (ally, adversary) label combinations balanced.
            X_train, X_valid, y_1_train, y_1_valid, y_2_train, y_2_valid = train_test_split(
                X_local, y_1_local, y_2_local, test_size=0.2, stratify=pd.DataFrame(
                    np.concatenate((y_1_local, y_2_local), axis=1)
                ))
            print('@node {}: X_train, X_valid, y_1_train, y_1_valid, y_2_train, y_2_valid'.format(node_idx))
            print(X_train.shape, X_valid.shape, y_1_train.shape, y_1_valid.shape, y_2_train.shape, y_2_valid.shape)

            # Inverse-frequency class weights from this node's training labels.
            counts_1 = np.bincount(y_1_train.flatten())
            counts_2 = np.bincount(y_2_train.flatten())
            w_1.append(sum(counts_1) / counts_1)
            w_2.append(sum(counts_2) / counts_2)
            print('@node {}: class weights => w1, w2'.format(node_idx), w_1[-1], w_2[-1])

            # Fit the scaler on this node's training split only, then apply it to validation.
            scaler = MinMaxScaler()
            X_train = scaler.fit_transform(X_train)
            X_valid = scaler.transform(X_valid)

            width = 0.35
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(13, 4))
            # Positions derived from the bincount length so they always match the heights
            # (np.unique would drop a class that is absent from this node's split).
            ax1.bar(np.arange(len(counts_1)), counts_1, width, color='b')
            ax1.bar(np.arange(len(counts_2)) + width, counts_2, width, color='r')
            ax1.legend(['ally', 'adversary'])

            y_1_train = to_categorical(y_1_train)
            y_2_train = to_categorical(y_2_train)
            y_1_valid = to_categorical(y_1_valid)
            y_2_valid = to_categorical(y_2_valid)

            # Plot train and valid sets from the same (pre-tensor) array types.
            class_plot(X_train, np.argmax(y_1_train, axis=1), np.argmax(y_2_train, axis=1),
                       'normalized train set @ node {}'.format(node_idx), ax2)
            class_plot(X_valid, np.argmax(y_1_valid, axis=1), np.argmax(y_2_valid, axis=1),
                       'normalized valid set @ node {}'.format(node_idx), ax3)
            # Render now and release the figure; otherwise every figure stays open until the
            # cell finishes (hundreds across all runs / node counts).
            plt.show()
            plt.close(fig)

            X_train = torch.Tensor(X_train)
            y_1_train = torch.Tensor(y_1_train)
            y_2_train = torch.Tensor(y_2_train)
            X_trains.append(X_train)
            y_1_trains.append(y_1_train)
            y_2_trains.append(y_2_train)
            X_valids.append(torch.Tensor(X_valid))
            y_1_valids.append(torch.Tensor(y_1_valid))
            y_2_valids.append(torch.Tensor(y_2_valid))

            print('@node {}: tensor sizes =>'.format(node_idx), X_train.shape, X_valid.shape,
                  y_1_train.shape, y_1_valid.shape, y_2_train.shape, y_2_valid.shape)

            train_loaders.append(DataLoader(TensorDataset(X_train, y_1_train, y_2_train),
                                            batch_size=BATCHSIZE, shuffle=True))

        input_size = X_train.shape[1]
        hidden_size = input_size * 8

        # ---- Distributed EIGAN (D-EIGAN).
        global_params = {}
        encoders = {}
        _print_header('EIGAN Training w/ phi={} and delta={}'.format(phi, delta))
        encoders['{}_{}'.format(phi, delta)] = distributed(
            NUM_NODES, phi, delta,
            X_trains, X_valids,
            y_1_trains, y_1_valids,
            y_2_trains, y_2_valids,
            input_size, hidden_size, output_size,
            alpha, lr_encd, lr_1, lr_2, w_1, w_2,
            train_loaders, n_iter_gan, device, global_params)
        # NOTE(review): the file name does not include `run`, so later runs overwrite it.
        _dump_pickle(encoders, 'encoders_num_nodes{}_phi{}_delta{}.pkl'.format(NUM_NODES, phi, delta))

        # ---- Pool all nodes' data for the centralized EIGAN and the probes.
        X_train = torch.cat(X_trains, dim=0).to(device)
        X_valid = torch.cat(X_valids, dim=0).to(device)
        y_1_train = torch.cat(y_1_trains, dim=0).to(device)
        y_1_valid = torch.cat(y_1_valids, dim=0).to(device)
        y_2_train = torch.cat(y_2_trains, dim=0).to(device)
        y_2_valid = torch.cat(y_2_valids, dim=0).to(device)

        X_trains = [t.to(device) for t in X_trains]
        X_valids = [t.to(device) for t in X_valids]
        y_1_trains = [t.to(device) for t in y_1_trains]
        y_1_valids = [t.to(device) for t in y_1_valids]
        y_2_trains = [t.to(device) for t in y_2_trains]
        y_2_valids = [t.to(device) for t in y_2_valids]

        train_loader = DataLoader(TensorDataset(X_train, y_1_train, y_2_train), batch_size=BATCHSIZE, shuffle=True)

        # NOTE(review): the pooled model uses node 0's class weights (w_1[0], w_2[0]) rather
        # than weights computed on the pooled labels -- confirm this is intended.
        encoder = centralized(X_train, X_valid,
                              y_1_train, y_1_valid,
                              y_2_train, y_2_valid,
                              input_size, hidden_size, output_size,
                              alpha, lr_encd, lr_1, lr_2, w_1[0], w_2[0],
                              train_loader, n_iter_gan, device)
        _dump_pickle(encoder, 'encoder_num_nodes{}_central_compare.pkl'.format(NUM_NODES))

        # ---- Probe metrics. The key spellings ('decentralize_ally' vs 'decentralized_advr')
        # are read back by the aggregation cell and stored in the saved pickle; keep them as-is.
        _print_header('ALLY: BASELINE')
        history['baseline_ally_{}'.format(NUM_NODES)] = metrics.centralized(
            None, input_size, hidden_size, output_size,
            X_train, X_valid, y_1_train, y_1_valid, w_1[0], device)

        _print_header('ADVERSARY: BASELINE')
        history['baseline_advr_{}'.format(NUM_NODES)] = metrics.centralized(
            None, input_size, hidden_size, output_size,
            X_train, X_valid, y_2_train, y_2_valid, w_2[0], device)

        _print_header('ALLY: CENTRALIZED')
        history['centralized_ally_{}'.format(NUM_NODES)] = metrics.centralized(
            encoder, input_size, hidden_size, output_size,
            X_train, X_valid, y_1_train, y_1_valid, w_1[0], device)

        _print_header('ADVERSARY: CENTRALIZED')
        history['centralized_advr_{}'.format(NUM_NODES)] = metrics.centralized(
            encoder, input_size, hidden_size, output_size,
            X_train, X_valid, y_2_train, y_2_valid, w_2[0], device)

        for key, encd in encoders.items():
            _print_header('ALLY: {}'.format(key))
            history['decentralize_ally_{}'.format(NUM_NODES)] = metrics.distributed(
                encd, NUM_NODES, input_size, hidden_size, output_size,
                X_trains, X_valids, y_1_trains, y_1_valids, w_1[0], device)
            _print_header('ADVERSARY: {}'.format(key))
            history['decentralized_advr_{}'.format(NUM_NODES)] = metrics.distributed(
                encd, NUM_NODES, input_size, hidden_size, output_size,
                X_trains, X_valids, y_2_trains, y_2_valids, w_2[0], device)

        # ---- Summary figure: centralized scores as dashed lines, distributed scores as bars at x = phi.
        c_ally = _best_score(history['centralized_ally_{}'.format(NUM_NODES)])
        c_advr = _best_score(history['centralized_advr_{}'.format(NUM_NODES)])
        d_ally = _best_score(history['decentralize_ally_{}'.format(NUM_NODES)])
        d_advr = _best_score(history['decentralized_advr_{}'.format(NUM_NODES)])

        fig, ax = plt.subplots(1, 1, figsize=(5, 4))
        ax.hlines(y=c_ally, xmin=-0.1, xmax=1.1, color='b', linestyle='dashed')
        ax.hlines(y=c_advr, xmin=-0.1, xmax=1.1, color='r', linestyle='dashed')
        ax.bar([phi - 0.025], [d_ally], width=0.05, color='b')
        ax.bar([phi + 0.025], [d_advr], width=0.05, color='r')
        ax.set_xticks([phi])
        ax.set_xlim(left=-0.1, right=1.1)
        ax.legend(['c-ally', 'c-adversary', 'd-ally', 'd-adversary'], loc='lower right')
        ax.set_xlabel('fraction of parameters shared')
        ax.set_ylabel('f1 score')
        ax.set_title('(b)', y=-0.3)
        ax.grid()
        fig.subplots_adjust(wspace=0.3)
        # NOTE(review): same path on every iteration -- only the last run / node count survives.
        fig.savefig('distributed_eigan_comparison.png', bbox_inches='tight', dpi=300)
        plt.show()
        plt.close(fig)
    master[run] = history


In [None]:
# Sanity check: the run ids under which the training loop stored its histories in `master`.
master.keys()

In [None]:
import pickle as pkl

In [None]:
# Persist the per-run metric histories. The `with` block closes (and flushes) the file
# instead of relying on garbage collection of an anonymous file object.
# NOTE(review): the name says "10_runs_0.6", but the training cell runs [0, 8, 9] with
# phi = 0.8. The name is kept because the next cell loads this exact path.
with open('history_distributed_noniid_numnodes_gaussian_10_runs_0.6.pkl', 'wb') as fh:
    pkl.dump(master, fh)

In [None]:
# Reload the saved histories so the analysis below can run without retraining.
# Only unpickle files produced by this notebook -- pickle can execute arbitrary code.
with open('history_distributed_noniid_numnodes_gaussian_10_runs_0.6.pkl', 'rb') as fh:
    master = pkl.load(fh)
master.keys()

In [None]:
# For every node count 2..10, average over the trained runs the best probe score
# (max of 'svm' / 'logistic' / 'mlp' in element [2] of each stored metrics result).
RUNS = [0, 8, 9]      # runs trained above; each must be a key of `master`
NUM_RUNS = len(RUNS)  # derived from RUNS so the two can never disagree
num_nodes = list(range(2, 11))

# Output array name -> history-key prefix. 'decentralize_ally' (sic) is the spelling the
# training cell used when storing the key, so it must be matched exactly here.
HISTORY_PREFIXES = {
    'baseline_ally': 'baseline_ally',
    'baseline_advr': 'baseline_advr',
    'eigan_ally': 'centralized_ally',
    'eigan_advr': 'centralized_advr',
    'dist_ally': 'decentralize_ally',
    'dist_advr': 'decentralized_advr',
}


def best_probe_score(result):
    """Return the best of the 'svm' / 'logistic' / 'mlp' scores stored in ``result[2]``."""
    scores = result[2]
    return max(scores['svm'], scores['logistic'], scores['mlp'])


# name -> array with one entry per node count: the mean over RUNS of the best probe score.
run_means = {
    name: np.mean(
        [[best_probe_score(master[run]['{}_{}'.format(prefix, n)]) for n in num_nodes]
         for run in RUNS],
        axis=0,
    )
    for name, prefix in HISTORY_PREFIXES.items()
}
baseline_ally = run_means['baseline_ally']
baseline_advr = run_means['baseline_advr']
eigan_ally = run_means['eigan_ally']
eigan_advr = run_means['eigan_advr']
dist_ally = run_means['dist_ally']
dist_advr = run_means['dist_advr']

In [None]:
# Tick positions for the node-count axis of the summary figure (2 through 10 nodes).
dist_x = [n_nodes for n_nodes in range(2, 11)]

In [None]:
# Inspect the run-averaged best adversary probe score on unencoded data, one entry per node count.
baseline_advr

In [None]:
# Summary figure: (a) adversary and (b) ally scores vs. number of nodes for EIGAN,
# D-EIGAN and unencoded data (run-averaged), plus (c) one example node's data.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))

node_x = np.array(num_nodes)
bar_offsets = (-0.3, 0.0, 0.3)
bar_colors = ('r', 'maroon', 'orange')  # EIGAN, D-EIGAN, unencoded
panels = [
    (ax1, (eigan_advr, dist_advr, baseline_advr), 0.4, 'adversary accuracy', '(a)'),
    (ax2, (eigan_ally, dist_ally, baseline_ally), 0.6, 'ally accuracy', '(b)'),
]
for ax, series, y_bottom, y_label, panel_title in panels:
    for offset, values, color in zip(bar_offsets, series, bar_colors):
        ax.bar(node_x + offset, values, width=0.2, color=color)
    ax.set_xticks(dist_x)
    ax.set_ylim(top=1.0, bottom=y_bottom)
    ax.set_xlabel('number of nodes')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title, y=-0.3)
    ax.grid()
ax2.legend(['EIGAN', 'D-EIGAN', 'Unencoded'], loc='upper right', prop={'size': 10})

# (c) Regenerate the example node's data explicitly. It is the last node of the 10-node
# setup in the training cell (VAR_ADD == 10, MAX // 10 samples). This panel previously
# read `data` leaked from the training loop, which fails after a kernel restart when
# `master` is loaded from disk. NOTE: this is a fresh random draw, not the exact same sample.
demo_nodes = 10
demo_data = generate_gaussian((VAR1 + demo_nodes) / 10, (VAR2 + demo_nodes) / 10, MAX // demo_nodes, 1)
X_demo, y_1_demo, y_2_demo = demo_data[0], demo_data[1], demo_data[2]

markers = ['o', 'x', 'o', 'x']  # marker chosen by adversary label j
colors = ['b', 'r']             # color chosen by ally label i
for i in range(2):
    for j in range(2):
        idx = np.intersect1d(np.where(y_1_demo == i)[0], np.where(y_2_demo == j)[0])
        points = X_demo[idx]
        ax3.scatter(points[:, 0], points[:, 1], c=colors[i], marker=markers[2 * i + j])
ax3.set_xlim(left=0, right=3)
ax3.set_ylim(top=3, bottom=0)
ax3.set_yticks([0, 1, 2, 3])
ax3.set_ylabel("ally (reds vs blues)")
ax3.set_xlabel("adversary (x's vs o's)")
ax3.set_title('(c)', y=-0.3)

plt.rcParams.update({'font.size': 14})
fig.subplots_adjust(wspace=0.3)
fig.savefig('figure_distributed_noniid_numnodes_gaussian.png', bbox_inches='tight', dpi=300)

# Metrics: change in ally/adversary loss as the number of adversaries/allies increases

In [None]:
# ally_loss stays constant
# Relative reduction of the final adversary validation loss ('advr_valid') for settings
# 1..10, measured against the last 'y_valid' loss of 'admission_type' in the reference
# n_ind_training_history checkpoint.
# NOTE(review): the checkpoints are "n_advr_ind_gan_*", so the keys of `h` presumably count
# adversaries even though the list is named `num_ally` -- confirm.
import matplotlib  # pickle (pkl) and numpy (np) come from the notebook's import cell

matplotlib.rcParams.update({'font.size': 12})

ckpt_dir = '../checkpoints/mimic_centralized/'

# Setting index -> timestamp of the history file covering it; entry `idx - 1` of that file is used.
advr_history_stamps = {
    1: '02_12_2020_19_30_38', 2: '02_12_2020_19_30_38',
    3: '02_12_2020_20_01_01', 4: '02_12_2020_20_01_01',
    5: '02_12_2020_23_07_22', 6: '02_12_2020_23_07_22',
    7: '02_13_2020_13_54_53', 8: '02_13_2020_13_54_53', 9: '02_13_2020_13_54_53',
    10: '02_13_2020_18_31_15',
}

# Load each distinct file once, closing the handles deterministically.
loaded = {}
for stamp in dict.fromkeys(advr_history_stamps.values()):
    with open(ckpt_dir + 'n_advr_ind_gan_training_history_{}.pkl'.format(stamp), 'rb') as fh:
        loaded[stamp] = pkl.load(fh)
h = {idx: loaded[stamp] for idx, stamp in advr_history_stamps.items()}

with open(ckpt_dir + 'n_ind_training_history_02_02_2020_16_15_58.pkl', 'rb') as fh:
    orig = pkl.load(fh)
loss_ind = orig['admission_type']['y_valid'][-1]

num_ally = list(h)
loss = np.array([hist[idx - 1]['advr_valid'][-1] for idx, hist in h.items()])
rel_change = (loss_ind - loss) / loss_ind
rel_change.mean(), rel_change.std(), loss.mean(), loss.std()

In [None]:
# Reference loss: the last 'y_valid' entry for 'admission_type' in the n_ind_training_history checkpoint.
loss_ind

In [None]:
# advr loss stays constant
# Relative increase of the final 'advr_valid' loss for settings 1..10 over the last
# 'y_valid' loss of 'admission_type' in the reference n_ind_training_history checkpoint.
# NOTE(review): the header says the adversary loss is held constant, yet this reads
# 'advr_valid' just like the previous cell -- confirm the intended key.
# NOTE(review): settings 2 and 8 map to the same checkpoint (02_05_2020_00_23_29), as in
# the original mapping -- confirm.
import matplotlib  # pickle (pkl) and numpy (np) come from the notebook's import cell

matplotlib.rcParams.update({'font.size': 12})

ckpt_dir = '../checkpoints/mimic_centralized/'

# Setting index -> timestamp of the history file covering it; entry `idx - 1` of that file is used.
gan_history_stamps = {
    1: '02_03_2020_17_41_09',
    2: '02_05_2020_00_23_29',
    3: '02_03_2020_17_41_09',
    4: '02_03_2020_20_09_39',
    5: '02_04_2020_00_21_50',
    6: '02_04_2020_05_30_09',
    7: '02_04_2020_05_30_09',
    8: '02_05_2020_00_23_29',
    9: '02_05_2020_18_09_03',
    10: '02_04_2020_20_13_29',
}

# Load each distinct file once, closing the handles deterministically.
loaded = {}
for stamp in dict.fromkeys(gan_history_stamps.values()):
    with open(ckpt_dir + 'n_ind_gan_training_history_{}.pkl'.format(stamp), 'rb') as fh:
        loaded[stamp] = pkl.load(fh)
h = {idx: loaded[stamp] for idx, stamp in gan_history_stamps.items()}

with open(ckpt_dir + 'n_ind_training_history_02_02_2020_16_15_58.pkl', 'rb') as fh:
    orig = pkl.load(fh)
loss_ind = orig['admission_type']['y_valid'][-1]

num_ally = list(h)
loss = np.array([hist[idx - 1]['advr_valid'][-1] for idx, hist in h.items()])
rel_change = (loss - loss_ind) / loss_ind
rel_change.mean(), rel_change.std(), loss.mean(), loss.std()