In [1]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import mean_squared_error, roc_auc_score
import datetime
from net_archs import MLP

In [2]:
# Define the train and test functions
def train(net, optimizer, x, y, x_test, y_test, num_epoch=64, batch_size=128):
  """Train `net` with mini-batch gradient descent on an MSE objective.

  Each epoch draws a fresh random permutation of the rows of `x` and takes one
  optimizer step per mini-batch of up to `batch_size` rows.

  Args:
    net: torch module mapping a batch of rows of `x` to predictions shaped like `y`.
    optimizer: optimizer already bound to `net.parameters()`.
    x, y: training inputs and targets, indexed along dim 0.
    x_test, y_test: held-out inputs/targets. When both are given, the test MSE is
      recorded after every epoch; pass None to skip per-epoch evaluation.
    num_epoch: number of passes over the training data.
    batch_size: rows per mini-batch (the last batch of an epoch may be smaller).

  Returns:
    list[float]: test MSE after each epoch (empty when x_test/y_test is None).
    The network is left in the train/eval mode it had on entry.
  """
  criterion = nn.MSELoss()  # stateless, so build it once rather than per batch
  test_loss = []
  n_rows = x.size(0)
  was_training = net.training

  for _ in range(num_epoch):
    # Evaluation below (or a previous caller) may have left the net in eval mode.
    net.train()
    # Mini-batch SGD over a fresh shuffle of the rows.
    permutation = torch.randperm(n_rows)
    for i in range(0, n_rows, batch_size):
      indices = permutation[i:i + batch_size]
      x_mini, y_mini = x[indices], y[indices]
      y_pred = net(x_mini)
      loss = criterion(y_pred, y_mini)  # already a scalar (mean reduction)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    if x_test is not None and y_test is not None:
      net.eval()
      with torch.no_grad():
        test_loss.append(criterion(net(x_test), y_test).item())

  net.train(was_training)
  return test_loss

def test(net, x_test, y_test):
  """Score `net` on a held-out set.

  Runs one forward pass over all of `x_test` with gradients disabled and the
  network in eval mode; the network's previous train/eval mode is restored.

  Args:
    net: torch module producing one prediction per row of `x_test`.
    x_test: held-out inputs.
    y_test: held-out targets aligned with `x_test`, as a tensor or array-like
      (roc_auc_score is applied, so presumably binary labels — confirm).

  Returns:
    (mse, auc): sklearn `mean_squared_error` and `roc_auc_score` of the raw
    predictions against `y_test`.
  """
  was_training = net.training
  net.eval()  # disable dropout etc. while scoring
  try:
    with torch.no_grad():
      y_pred = net(x_test).cpu().numpy()
  finally:
    net.train(was_training)

  if torch.is_tensor(y_test):
    y_true = y_test.detach().cpu().numpy()
  else:
    y_true = np.asarray(y_test)

  mse = mean_squared_error(y_true, y_pred)
  auc = roc_auc_score(y_true, y_pred)
  return mse, auc


In [3]:
# Local models: load the full feature matrix and labels.
# These files hold pickled Python objects rather than plain tensors (presumably
# pandas DataFrame/Series — confirm). torch >= 2.6 defaults torch.load to
# weights_only=True, which refuses such objects, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary code — only load trusted files.
X = torch.load('./dataset2/X_mlp.pt', weights_only=False)
y = torch.load('./dataset2/y_mlp.pt', weights_only=False)

In [4]:
from sklearn.model_selection import StratifiedKFold

# Seed shared by fl_split's StratifiedKFold and run_local's RandomOverSampler.
# Fixed (instead of an unrecorded np.random.randint(1000) draw) so Restart & Run All
# reproduces the same client partitions; change the value to explore other draws.
random_state = 42
# Also seed torch so MLP weight init and the mini-batch shuffles are reproducible.
torch.manual_seed(random_state)

def fl_split(split, x_train, y_train, n_splits=50, seed=None):
    """Partition a training set among federated clients in stratified chunks.

    The data is first cut into `n_splits` shuffled stratified folds. Client k then
    receives the next floor(split[k] * n_splits) consecutive folds. Any folds left
    over after the last client (when the shares do not map to whole folds) are
    assigned to no one.

    Args:
        split: sequence of client shares, e.g. [0.6, 0.2, 0.2].
        x_train: features, indexable along dim 0 by an integer index array.
        y_train: labels used for stratification, same length as x_train
            (may be an (N, 1) column).
        n_splits: number of stratified folds; shares are resolved in steps of
            1 / n_splits.
        seed: StratifiedKFold random_state. Defaults to the notebook-level
            `random_state`.

    Returns:
        (x_split, y_split): lists with one tensor per client, in `split` order.

    Raises:
        ValueError: if a share covers less than one fold, or the shares need more
            than `n_splits` folds in total.
    """
    if seed is None:
        seed = random_state  # notebook-level global, kept for backward compatibility

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # StratifiedKFold wants 1-D labels; y_train is typically an (N, 1) column.
    labels = np.asarray(y_train).reshape(-1)
    folds = [test_index for _, test_index in skf.split(x_train, labels)]

    # Whole folds per client. The epsilon stops float products such as
    # 0.58 * 50 == 28.999999999999996 from being truncated to 28.
    counts = [int(s * n_splits + 1e-9) for s in split]
    if any(c < 1 for c in counts):
        raise ValueError(
            f"every share must cover at least one of {n_splits} folds, got split={split}")
    if sum(counts) > n_splits:
        raise ValueError(
            f"split={split} needs {sum(counts)} folds but only {n_splits} exist")

    x_split = []
    y_split = []
    start = 0
    for count in counts:
        # Same rows, in the same order, as concatenating the per-fold slices.
        idx = np.concatenate(folds[start:start + count])
        x_split.append(x_train[idx])
        y_split.append(y_train[idx])
        start += count

    return x_split, y_split

In [11]:
from imblearn.over_sampling import RandomOverSampler

def run_local(split, X=X, y=y, n_folds=5, num_epoch=10, batch_size=32,
              lr=0.000005, weight_decay=0.00001):
    """Local-only baseline: each client trains its own MLP on its share of the data.

    For each of `n_folds` stratified outer folds, the training portion is divided
    among clients with `fl_split(split, ...)`. Each client's share is oversampled
    with RandomOverSampler, a fresh MLP is trained on it, and the MLP is scored on
    the fold's full held-out test set. MSE/AUC lists are printed with one entry per
    (fold, client) pair, in fold-major order.

    Args:
        split: client shares passed to `fl_split`, e.g. [0.6, 0.2, 0.2].
        X, y: feature matrix and labels (array-like, converted with np.asarray).
            The defaults are bound to the globals as they were when this cell ran.
        n_folds: number of outer StratifiedKFold folds.
        num_epoch, batch_size, lr, weight_decay: training hyper-parameters; the
            defaults are the values previously hard-coded here.
    """
    mses = []
    aucs = []

    kf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    X = np.asarray(X)
    y = np.asarray(y)

    start = datetime.datetime.now()
    for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X, y)):
        print(f"Fold {fold_idx + 1}:")
        X_train = torch.tensor(X[train_idx])
        X_test = torch.tensor(X[test_idx])
        y_train = torch.reshape(torch.tensor(y[train_idx]), (-1, 1))
        y_test = torch.reshape(torch.tensor(y[test_idx]), (-1, 1))

        x_split, y_split = fl_split(split, X_train, y_train)

        for X_client, y_client in zip(x_split, y_split):
            # Oversample within this client's share only; the shared test set is not resampled.
            ros = RandomOverSampler(random_state=random_state)
            X_res, y_res = ros.fit_resample(X_client, y_client)
            X_res = torch.from_numpy(X_res)
            y_res = torch.reshape(torch.from_numpy(y_res), (-1, 1))

            # Input width follows the data instead of a hard-coded 120.
            model = MLP(X_res.shape[1], 1, layer_size=64, num_of_layers=2, dropout=False)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
            train(model, optimizer, X_res, y_res, X_test, y_test,
                  num_epoch=num_epoch, batch_size=batch_size)

            # Test local MLP
            mse, auc = test(model, X_test, y_test)
            mses.append(mse)
            aucs.append(auc)

    elapsed = datetime.datetime.now() - start

    print('split',split)
    print('Training time: ', elapsed)
    print('mse',mses)
    print('auc',aucs)

In [12]:
# 1 client: all 50 stratified folds of each training portion.
split = [1]
run_local(split)

Fold 1:


KeyboardInterrupt: 

In [None]:
# 2 clients: 25 + 25 of the 50 stratified folds.
split = [0.5, 0.5]
run_local(split)

In [None]:
# 2 clients: 45 + 5 of the 50 stratified folds.
split = [0.9, 0.1]
run_local(split)

In [22]:
# 3 clients: 17 + 16 + 16 folds (16.5 truncates to 16, so 1 of the 50 folds is unused).
split = [0.34, 0.33, 0.33]
run_local(split)

num of clients 3
round 1
split [0.34, 0.33, 0.33]
mse [0.1991727, 0.19995187, 0.20034055]
auc [0.7564582860706495, 0.7551584450042247, 0.7475259886511583]


In [23]:
# 3 clients: 30 + 10 + 10 of the 50 stratified folds.
split = [0.6, 0.2, 0.2]
run_local(split)

num of clients 3
round 1
split [0.6, 0.2, 0.2]
mse [0.19718456, 0.20393908, 0.2028204]
auc [0.7551701922248331, 0.7423441297625434, 0.7384318178252336]


In [24]:
# 3 clients: 40 + 5 + 5 of the 50 stratified folds.
split = [0.8, 0.1, 0.1]
run_local(split)

num of clients 3
round 1
split [0.8, 0.1, 0.1]
mse [0.20017248, 0.22496118, 0.21737762]
auc [0.7551153516392626, 0.7305401114825929, 0.7225911287469904]


In [None]:
# 5 equal clients: 10 stratified folds each.
split = [0.2, 0.2, 0.2, 0.2, 0.2]
run_local(split)

In [30]:
# 5 clients: 20 + 4 x 7 folds (7.5 truncates to 7, so 2 of the 50 folds are unused).
split = [0.4, 0.15, 0.15, 0.15, 0.15]
run_local(split)

num of clients 5
round 1
split [0.4, 0.15, 0.15, 0.15, 0.15]
mse [0.19620517, 0.20827086, 0.20444717, 0.20475082, 0.2095296]
auc [0.7574899763382419, 0.7303446238587146, 0.7428862202987274, 0.736773950689735, 0.7343271497699912]


In [None]:
# 5 clients: 30 + 4 x 5 of the 50 stratified folds.
split = [0.6, 0.1, 0.1, 0.1, 0.1]
run_local(split)

In [26]:
# 5 clients: 40 + 4 x 2 folds (2.5 truncates to 2, so 2 of the 50 folds are unused).
split = [0.8, 0.05, 0.05, 0.05, 0.05]
run_local(split)

num of clients 5
round 1
split [0.8, 0.05, 0.05, 0.05, 0.05]
mse [0.19705181, 0.23927695, 0.2413147, 0.2523718, 0.24482094]
auc [0.7539705386997876, 0.6783518205931999, 0.707302914199466, 0.6705114808563084, 0.685694254659772]


In [None]:
# 10 equal clients: 5 stratified folds each.
split = [0.1, 0.1, 0.1, 0.1, 0.1,
        0.1, 0.1, 0.1, 0.1, 0.1]
run_local(split)

In [27]:
# 10 clients: 20 + 6 x 3 + 3 x 3 folds (3.5 truncates to 3, so 3 of the 50 folds are unused).
split = [0.4, 0.07, 0.07, 0.07, 0.07, 0.07, 0.07, 0.06, 0.06, 0.06]
run_local(split)

num of clients 10
round 1
split [0.4, 0.07, 0.07, 0.07, 0.07, 0.07, 0.07, 0.06, 0.06, 0.06]
mse [0.19559278, 0.22652367, 0.21584006, 0.24114326, 0.23175684, 0.22410439, 0.22891557, 0.2314748, 0.2125626, 0.22369927]
auc [0.7563601929952476, 0.6996476114347221, 0.695546964036976, 0.7245670824363681, 0.7076945754716981, 0.7029320865518551, 0.7201231480308978, 0.7172596199506142, 0.7429255176016208, 0.7247965999746738]


In [29]:
# 10 clients: 30 + 4 x 2 + 5 x 2 folds (2.5 truncates to 2, so 2 of the 50 folds are unused).
split = [0.6, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
run_local(split)

num of clients 10
round 1
split [0.6, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
mse [0.19848773, 0.2318051, 0.24397711, 0.25491175, 0.24122839, 0.2358786, 0.26125994, 0.24775077, 0.24984026, 0.23055176]
auc [0.756138908248991, 0.6908653287963632, 0.7077673284569836, 0.7001848725986656, 0.7277884950298743, 0.7217444113995838, 0.6683914744259573, 0.6938728576659611, 0.6898293277603622, 0.6920866489832007]


In [28]:
# 10 clients: 40 + 2 x 1 + 7 x 1 folds (1.5 truncates to 1, so 1 of the 50 folds is unused).
split = [0.8, 0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
run_local(split)

num of clients 10
round 1
split [0.8, 0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
mse [0.19804782, 0.26459602, 0.26391208, 0.24043137, 0.23864369, 0.24354953, 0.24618284, 0.24132615, 0.2361172, 0.26350534]
auc [0.754269641993694, 0.6838232010645804, 0.6571863640829158, 0.7130723682447822, 0.6581017960328306, 0.627990783163197, 0.6650054925916994, 0.5856844305120168, 0.6514928239066169, 0.6774107118934705]
