In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

In [None]:
import sys

# Put the parent directory on the import path (the project imports in the
# following cell presumably rely on it -- confirm against the repo layout).
# Guarded so that re-running this cell does not keep stacking duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
%load_ext autoreload
%autoreload 2
from common.utility import to_categorical, torch_device
from notebook_utils.generate_gaussian import generate_gaussian
from notebook_utils.eigan import Encoder, Discriminator
from notebook_utils.federated import federated
from notebook_utils.eigan_training import distributed, centralized
from notebook_utils.utility import class_plot, to_numpy
import notebook_utils.metrics as metrics

In [None]:
# Ask the project helper for the 'gpu' backend and keep whatever device object
# it returns; the bare name on the last line echoes it as the cell output.
device = torch_device(device='gpu')
device

# Training

In [None]:
# --- Experiment configuration -------------------------------------------------
# Results container; later cells store one entry per evaluated configuration.
history = {}

# Data generation: VAR1 / VAR2 are divided by 10 before being handed to
# generate_gaussian in the data-preparation cell.
VAR1 = 3
VAR2 = 3

# Number of nodes the data is prepared for and the models are trained on.
NUM_NODES = 2

# Mini-batch size of the per-node training DataLoaders.
BATCHSIZE = 512

# NOTE(review): MAX, NUM_TRIALS and PHI are not referenced by any cell visible
# in this notebook -- confirm whether they are still needed.
MAX = 1024
NUM_TRIALS = 10
PHI = 1

In [None]:
# Fixed seed for the train/validation split. Without it every run produces a
# different partition, so encoders reloaded from a pickle would be evaluated on
# a split they were not trained with.
# NOTE(review): generate_gaussian's own randomness is not controlled here --
# confirm whether it needs seeding as well.
SPLIT_SEED = 0

# One synthetic dataset per node; the second node is generated with MEAN=2.
X1, y_11, y_21 = generate_gaussian(VAR1/10, VAR2/10, 1000, 1)
X2, y_12, y_22 = generate_gaussian((VAR1)/10, (VAR2)/10, 1000, 1, MEAN=2)
X = [X1, X2]
y_1 = [y_11, y_12]   # first label set (plotted as 'ally' below)
y_2 = [y_21, y_22]   # second label set (plotted as 'adversary' below)

print('='*80)
print('VAR1: {}, VAR2: {}'.format(VAR1/10, VAR2/10))
print('='*80)
for node_idx in range(NUM_NODES):
    print('@node {}, X: {}, y_1: {}, y2: {}'.format(
        node_idx, X[node_idx].shape, y_1[node_idx].shape, y_2[node_idx].shape))

# Per-node accumulators consumed by the training and evaluation cells.
w_1 = []             # class weights for y_1, one array per node
w_2 = []             # class weights for y_2, one array per node
train_loaders = []
X_valids = []
X_trains = []
y_1_valids = []
y_1_trains = []
y_2_valids = []
y_2_trains = []
for node_idx in range(NUM_NODES):
    X_local = X[node_idx]
    y_1_local = y_1[node_idx]
    y_2_local = y_2[node_idx]

    # 80/20 split stratified on the combination of both label columns.
    X_train, X_valid, y_1_train, y_1_valid, y_2_train, y_2_valid = train_test_split(
        X_local, y_1_local, y_2_local, test_size=0.2, random_state=SPLIT_SEED,
        stratify=pd.DataFrame(
            np.concatenate((y_1_local, y_2_local), axis=1)
        ))
    print('@node {}: X_train, X_valid, y_1_train, y_1_valid, y_2_train, y_2_valid'.format(node_idx))
    print(X_train.shape, X_valid.shape, y_1_train.shape, y_1_valid.shape, y_2_train.shape, y_2_valid.shape)

    # Class weight = total number of training samples / samples in that class.
    w = np.bincount(y_1_train.flatten())
    w_1.append(sum(w)/w)
    w = np.bincount(y_2_train.flatten())
    w_2.append(sum(w)/w)
    print('@node {}: class weights => w1, w2'.format(node_idx), w_1, w_2)

    # Fit the scaler on the training part only, then apply it to validation.
    scaler = MinMaxScaler()
    X_train = scaler.fit_transform(X_train)
    X_valid = scaler.transform(X_valid)

    # Left panel: class counts of both label sets, drawn side by side.
    width = 0.35
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(13, 4))
    ax1.bar(np.unique(y_1_train.flatten()), np.bincount(y_1_train.flatten()), width, color='b')
    ax1.bar(np.unique(y_2_train.flatten())+width, np.bincount(y_2_train.flatten()), width, color='r')
    ax1.legend(['ally', 'adversary'])

    # Encode the labels with the project's to_categorical helper.
    y_1_train = to_categorical(y_1_train)
    y_2_train = to_categorical(y_2_train)
    y_1_valid = to_categorical(y_1_valid)
    y_2_valid = to_categorical(y_2_valid)

    X_train = torch.Tensor(X_train)
    y_1_train = torch.Tensor(y_1_train)
    y_2_train = torch.Tensor(y_2_train)

    X_valids.append(torch.Tensor(X_valid))
    y_1_valids.append(torch.Tensor(y_1_valid))
    y_2_valids.append(torch.Tensor(y_2_valid))

    X_trains.append(X_train)
    y_1_trains.append(y_1_train)
    y_2_trains.append(y_2_train)

    # Middle / right panels: the scaled train and validation sets.
    class_plot(X_train, np.argmax(y_1_train, axis=1), np.argmax(y_2_train, axis=1), 'normalized train set @ node {}'.format(node_idx), ax2)
    class_plot(X_valid, np.argmax(y_1_valid, axis=1), np.argmax(y_2_valid, axis=1), 'normalized valid set @ node {}'.format(node_idx), ax3)

    print('@node {}: tensor sizes =>'.format(node_idx), X_train.shape, X_valid.shape, y_1_train.shape, y_1_valid.shape, y_2_train.shape, y_2_valid.shape)

    train_loaders.append(DataLoader(TensorDataset(X_train, y_1_train, y_2_train), batch_size=BATCHSIZE, shuffle=True))

In [None]:
# Hyper-parameters handed to the `distributed` / `centralized` training helpers.
alpha = 1
lr_encd = 0.001      # learning rate passed in the encoder slot
lr_1 = 0.0001        # learning rate passed alongside the y_1 labels
lr_2 = 0.0001        # learning rate passed alongside the y_2 labels
n_iter_gan = 501

# Network sizes. The feature count is read from the stored per-node training
# tensors instead of `X_train`, which is only a leftover loop variable of the
# data-preparation cell and is rebound to other tensors further down.
input_size = X_trains[0].shape[1]
hidden_size = input_size*8
output_size = 2

In [None]:
# Fresh containers: `global_params` is handed to every distributed run,
# `encoders` collects the result of each run under a '<phi>_<delta>' key.
global_params, encoders = {}, {}

In [None]:
import pickle as pkl

# Restore previously saved results if there are any. On a fresh checkout the
# file does not exist yet (it is only written by the dump cell near the end),
# so fall back to the `history` created in the configuration cell instead of
# aborting the run.
# NOTE(review): pickle.load can execute arbitrary code -- only load files that
# were produced by this notebook.
try:
    with open('history_numnodes_2_phi_delta_gaussian.pkl', 'rb') as history_file:
        history = pkl.load(history_file)
except FileNotFoundError:
    print('history_numnodes_2_phi_delta_gaussian.pkl not found; keeping the current history')

In [None]:
# Two sweeps, run back to back in this order:
#   1. phi in {0.0, 0.2, ..., 1.0} with delta fixed at 2
#   2. delta in {0, 2, ..., 10} with phi fixed at 0.8
# (phi and delta are the quantities the plotting cell labels 'fraction of
# parameters shared' and 'frequency of sync'.)
# The pair (0.8, 2) occurs in both sweeps, so it is trained twice and the
# second encoder replaces the first under the key '0.8_2'.
training_grid = [(tenths / 10, 2) for tenths in range(0, 11, 2)]
training_grid += [(0.8, sync) for sync in range(0, 11, 2)]

for phi, delta in training_grid:
    print("-"*80)
    print('EIGAN Training w/ phi={} and delta={}'.format(phi, delta))
    print("-"*80)
    encoders['{}_{}'.format(phi, delta)] = distributed(
        NUM_NODES, phi, delta,
        X_trains, X_valids,
        y_1_trains, y_1_valids,
        y_2_trains, y_2_valids,
        input_size, hidden_size, output_size,
        alpha, lr_encd, lr_1, lr_2, w_1, w_2,
        train_loaders, n_iter_gan, device, global_params)

In [None]:
import pickle as pkl

# Persist every trained encoder. The context manager makes sure the file is
# flushed and closed before a later cell reads it back.
with open('encoders_numnodes_2_phi_delta_gaussian.pkl', 'wb') as encoders_file:
    pkl.dump(encoders, encoders_file)

In [None]:
import pickle as pkl

# Reload the encoders written by the dump cell; the handle is closed as soon
# as the load finishes.
# NOTE(review): pickle.load can execute arbitrary code -- only load files that
# were produced by this notebook.
with open('encoders_numnodes_2_phi_delta_gaussian.pkl', 'rb') as encoders_file:
    encoders = pkl.load(encoders_file)

In [None]:
# Pool the per-node tensors into single train / validation tensors (stacked
# along dim 0) and move each of them to the selected device.
(X_train, X_valid,
 y_1_train, y_1_valid,
 y_2_train, y_2_valid) = (
    torch.cat(per_node, dim=0).to(device)
    for per_node in (X_trains, X_valids,
                     y_1_trains, y_1_valids,
                     y_2_trains, y_2_valids)
)

In [None]:
# Centralized reference run on the pooled data.
# NOTE(review): this uses batch_size=128 and the class weights of node 0 only
# (w_1[0], w_2[0]) -- confirm that both are intended for the pooled dataset.
train_loader = DataLoader(
    TensorDataset(X_train, y_1_train, y_2_train),
    batch_size=128,
    shuffle=True,
)

encoder = centralized(
    X_train, X_valid,
    y_1_train, y_1_valid,
    y_2_train, y_2_valid,
    input_size, hidden_size, output_size,
    alpha, lr_encd, lr_1, lr_2,
    w_1[0], w_2[0],
    train_loader, n_iter_gan, device,
)

In [None]:
import pickle as pkl

# Persist the centrally trained encoder. The context manager makes sure the
# file is flushed and closed before a later cell reads it back.
with open('encoder_numnodes_2_phi_delta_gaussian.pkl', 'wb') as encoder_file:
    pkl.dump(encoder, encoder_file)

In [None]:
import pickle as pkl

# Reload the centrally trained encoder; the handle is closed as soon as the
# load finishes.
# NOTE(review): pickle.load can execute arbitrary code -- only load files that
# were produced by this notebook.
with open('encoder_numnodes_2_phi_delta_gaussian.pkl', 'rb') as encoder_file:
    encoder = pkl.load(encoder_file)

In [None]:
# Baseline scores: the same evaluation as the centralized one further down,
# but with None passed where the trained encoder would go (presumably the raw
# features are scored -- confirm in metrics.centralized).
# Each row: banner text, history key, train labels, validation labels, weights.
for banner, history_key, y_fit, y_eval, class_weights in (
    ('ALLY: BASELINE', 'baseline_ally', y_1_train, y_1_valid, w_1[0]),
    ('ADVERSARY: BASELINE', 'baseline_advr', y_2_train, y_2_valid, w_2[0]),
):
    print("-"*80)
    print(banner)
    print("-"*80)
    history[history_key] = metrics.centralized(
        None,
        input_size, hidden_size, output_size,
        X_train, X_valid, y_fit, y_eval,
        class_weights, device)

In [None]:
NUM_NODES = 2

# Rebuild every per-node list with its tensors moved to the selected device.
(X_trains, X_valids,
 y_1_trains, y_1_valids,
 y_2_trains, y_2_valids) = (
    [tensor.to(device) for tensor in per_node]
    for per_node in (X_trains, X_valids,
                     y_1_trains, y_1_valids,
                     y_2_trains, y_2_valids)
)

In [None]:
# Score the centrally trained encoder on both label sets.
# Each row: banner text, history key, train labels, validation labels, weights.
for banner, history_key, y_fit, y_eval, class_weights in (
    ('ALLY: CENTRALIZED', 'centralized_ally', y_1_train, y_1_valid, w_1[0]),
    ('ADVERSARY: CENTRALIZED', 'centralized_advr', y_2_train, y_2_valid, w_2[0]),
):
    print("-"*80)
    print(banner)
    print("-"*80)
    history[history_key] = metrics.centralized(
        encoder,
        input_size, hidden_size, output_size,
        X_train, X_valid, y_fit, y_eval,
        class_weights, device)

In [None]:
# Evaluate the encoders of both sweeps, in the order they were trained:
#   1. phi in {0.0, 0.2, ..., 1.0} with delta fixed at 2
#   2. delta in {0, 2, ..., 10} with phi fixed at 0.8
# The pair (0.8, 2) is part of both sweeps and is therefore evaluated twice;
# the later result replaces the earlier one in `history`.
evaluation_grid = [(tenths / 10, 2) for tenths in range(0, 11, 2)]
evaluation_grid += [(0.8, sync) for sync in range(0, 11, 2)]

for phi, delta in evaluation_grid:
    key = '{}_{}'.format(phi, delta)
    encd = encoders[key]
    # NOTE: the two history prefixes differ ('decentralize_' vs
    # 'decentralized_'); they are kept exactly as they are because the
    # plotting cell and the saved history file use these spellings.
    for banner, history_key, y_fits, y_evals, class_weights in (
        ('ALLY: {}', 'decentralize_ally_{}', y_1_trains, y_1_valids, w_1[0]),
        ('ADVERSARY: {}', 'decentralized_advr_{}', y_2_trains, y_2_valids, w_2[0]),
    ):
        print("-"*80)
        print(banner.format(key))
        print("-"*80)
        history[history_key.format(key)] = metrics.distributed(
            encd, NUM_NODES,
            input_size, hidden_size, output_size,
            X_trains, X_valids, y_fits, y_evals,
            class_weights, device)

In [None]:
import pickle as pkl

# Persist all collected results. The context manager makes sure the file is
# flushed and closed, so a later load never sees a partially written pickle.
with open('history_numnodes_2_phi_delta_gaussian.pkl', 'wb') as history_file:
    pkl.dump(history, history_file)

In [None]:
def _best_score(entry):
    """Return the largest of the 'svm', 'logistic' and 'mlp' values found in
    element 3 of a history entry (the value plotted as 'f1 score' below)."""
    scores = entry[3]
    return max(scores['svm'], scores['logistic'], scores['mlp'])


# Set the font size BEFORE any axes or text is created; updating rcParams
# after the figure has been built is not reliably applied to that figure and
# made the output depend on whether the cell had been run before.
plt.rcParams.update({'font.size': 14})

# Reference values. Only the centralized ones are drawn (dashed lines); the
# baseline lists are kept available for inspection.
baseline_ally = [_best_score(history['baseline_ally'])]
baseline_advr = [_best_score(history['baseline_advr'])]
eigan_ally = [_best_score(history['centralized_ally'])]
eigan_advr = [_best_score(history['centralized_advr'])]

# ax2 is the left panel '(a)', ax1 the right panel '(b)'.
fig, (ax2, ax1) = plt.subplots(1, 2, figsize=(10, 4))

# ---- Panel (b): sweep over phi at delta = 2 ---------------------------------
delta = 2
dist_x = []
dist_ally = []
dist_advr = []
for tenths in range(0, 11, 2):
    phi = tenths / 10
    dist_x.append(phi)
    # Key prefixes differ on purpose: they match what the evaluation cell stores.
    dist_ally.append(_best_score(history['decentralize_{}_{}_{}'.format('ally', phi, delta)]))
    dist_advr.append(_best_score(history['decentralized_{}_{}_{}'.format('advr', phi, delta)]))

ax1.hlines(y=eigan_ally[0], xmin=-0.1, xmax=1.1, color='b', linestyle='dashed')
ax1.hlines(y=eigan_advr[0], xmin=-0.1, xmax=1.1, color='r', linestyle='dashed')
ax1.bar(np.array(dist_x)-0.025, dist_ally, width=0.05, color='b')
ax1.bar(np.array(dist_x)+0.025, dist_advr, width=0.05, color='r')
ax1.set_xticks(dist_x)
ax1.set_xlim(left=-0.1, right=1.1)
ax1.legend(['c-ally', 'c-adversary', 'd-ally', 'd-adversary'], loc='lower right')
ax1.set_xlabel('fraction of parameters shared')
ax1.set_ylabel('f1 score')
ax1.set_title('(b)', y=-0.3)
ax1.grid()

# ---- Panel (a): sweep over delta at phi = 0.8 -------------------------------
phi = 0.8
dist_x = []
dist_ally = []
dist_advr = []
for delta in range(0, 11, 2):
    dist_x.append(delta)
    dist_ally.append(_best_score(history['decentralize_{}_{}_{}'.format('ally', phi, delta)]))
    dist_advr.append(_best_score(history['decentralized_{}_{}_{}'.format('advr', phi, delta)]))

# The ally bar at delta=0 spans [-0.5, 0]; start the axis (and the reference
# lines) at -1 so that bar is not clipped, mirroring the margin of panel (b).
ax2.hlines(y=eigan_ally[0], xmin=-1, xmax=11, color='b', linestyle='dashed')
ax2.hlines(y=eigan_advr[0], xmin=-1, xmax=11, color='r', linestyle='dashed')
ax2.bar(np.array(dist_x)-0.25, dist_ally, width=0.5, color='b')
ax2.bar(np.array(dist_x)+0.25, dist_advr, width=0.5, color='r')
ax2.set_xticks(dist_x)
ax2.set_xlim(left=-1, right=11)
ax2.set_xlabel('frequency of sync')
ax2.set_ylabel('f1 score')
ax2.set_title('(a)', y=-.3)
ax2.grid()

fig.subplots_adjust(wspace=0.3)
plt.savefig('distributed_eigan_comparison.png', bbox_inches='tight', dpi=300)