In [1]:
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import mean_squared_error, roc_auc_score

In [2]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): none of the cells visible in this notebook move the model or
# the tensors to DEVICE, so this value is only reported, not applied.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Training on {DEVICE} using PyTorch {torch.__version__}")
print('main3_mlp_central')

Training on cpu using PyTorch 1.13.1+cpu
main3_mlp_central


In [3]:
# Load the pre-built feature and label tensors from disk.
# NOTE(review): torch.load unpickles the file, so only load files from a
# trusted source. No map_location is passed, so the tensors are restored on
# whatever device they were saved from.
X = torch.load('./dataset3/X.pt')
y = torch.load('./dataset3/y.pt')

In [5]:
def train(net, optimizer, x, y, x_test, y_test, num_epoch=64, batch_size=128):
  """Fit ``net`` with shuffled mini-batches on an MSE objective.

  After every epoch the MSE on the held-out pair ``(x_test, y_test)`` is
  computed without gradients and appended to the returned history.

  Args:
    net: model called as ``net(batch)``; its parameters are expected to be
      the ones managed by ``optimizer``.
    optimizer: optimizer stepped once per mini-batch.
    x: training inputs, indexed along the first dimension.
    y: training targets aligned with ``x``; must have the same shape as the
      model output for a batch.
    x_test: held-out inputs, or ``None`` to skip the per-epoch evaluation.
    y_test: held-out targets, or ``None`` to skip the per-epoch evaluation.
    num_epoch: number of full passes over ``x``.
    batch_size: maximum number of rows per mini-batch (the last batch of an
      epoch may be smaller).

  Returns:
    list[float]: one held-out MSE value per epoch (empty when no held-out
    data was supplied or ``num_epoch`` is 0).
  """
  # Default reduction is 'mean', so the loss is already a scalar.
  criterion = nn.MSELoss()
  is_module = isinstance(net, nn.Module)
  can_evaluate = x_test is not None and y_test is not None
  num_samples = x.size()[0]
  test_loss = []

  for _ in range(num_epoch):
    # A fresh permutation per epoch gives differently composed batches.
    permutation = torch.randperm(num_samples)
    for start in range(0, num_samples, batch_size):
      indices = permutation[start:start + batch_size]
      y_pred = net(x[indices])

      loss = criterion(y_pred, y[indices])
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    if can_evaluate:
      # Evaluate in eval mode so layers such as dropout do not perturb the
      # measurement, then put the model back in the mode it was in.
      was_training = is_module and net.training
      if is_module:
        net.eval()
      try:
        with torch.no_grad():
          test_loss.append(criterion(net(x_test), y_test).item())
      finally:
        if is_module:
          net.train(was_training)

  return test_loss


def test(net, x_test, y_test):
  """Score ``net`` on held-out data.

  The forward pass runs without gradient tracking and, for ``nn.Module``
  models, in eval mode; the model's previous train/eval mode is restored
  afterwards. Predictions and targets are copied to host memory before being
  handed to scikit-learn, so tensors living on an accelerator are accepted.

  Args:
    net: model called as ``net(x_test)``.
    x_test: held-out inputs.
    y_test: held-out targets; passed as ``y_true`` to ``roc_auc_score``, so
      presumably class labels -- confirm against the dataset.

  Returns:
    tuple: ``(mse, auc)`` as computed by ``mean_squared_error`` and
    ``roc_auc_score``.
  """
  is_module = isinstance(net, nn.Module)
  was_training = is_module and net.training
  if is_module:
    net.eval()
  try:
    with torch.no_grad():
      y_pred = net(x_test)
  finally:
    if is_module:
      net.train(was_training)

  # .numpy() only works on CPU tensors, hence the explicit .cpu().
  y_pred = y_pred.detach().cpu().numpy()
  y_true = y_test.detach().cpu().numpy()

  mse = mean_squared_error(y_true, y_pred)
  auc = roc_auc_score(y_true, y_pred)

  return mse, auc

Local model

In [6]:
from sklearn.model_selection import StratifiedKFold


# Seed shared by fl_split (StratifiedKFold) and run_local (RandomOverSampler).
# NOTE(review): drawn from NumPy's global RNG and no np.random.seed call is
# visible in this notebook, so presumably the value differs between kernel
# sessions -- confirm whether that is intended.
random_state = np.random.randint(1000)
def fl_split(split, x_train, y_train, n_splits=50):
    """Partition a training set among clients in stratified chunks.

    The data is cut into ``n_splits`` stratified, shuffled chunks (the test
    parts of a ``StratifiedKFold``). Client ``k`` receives
    ``floor(split[k] * n_splits)`` consecutive chunks; chunks left over after
    all clients are served are not used.

    Args:
        split: one fraction per client, e.g. ``[0.8, 0.2]``.
        x_train: feature tensor, indexed along the first dimension.
        y_train: label tensor aligned with ``x_train``; used for
            stratification.
        n_splits: number of chunks the data is cut into (granularity of the
            fractions).

    Returns:
        tuple: ``(x_split, y_split)``, two lists with one tensor per client.

    Raises:
        ValueError: if a fraction maps to zero (or fewer) chunks, or if the
            fractions together ask for more than ``n_splits`` chunks.
    """
    # Floor with a small tolerance: a product such as 0.58 * 50 can land just
    # below the intended integer in floating point, and a plain int() would
    # then silently drop one chunk.
    counts = [int(s * n_splits + 1e-9) for s in split]

    for s, count in zip(split, counts):
        if count < 1:
            raise ValueError(
                f"split fraction {s} maps to {count} of {n_splits} chunks; "
                "every client needs at least one chunk"
            )
    if sum(counts) > n_splits:
        raise ValueError(
            f"split {list(split)} requests {sum(counts)} chunks "
            f"but only {n_splits} are available"
        )

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # The test part of every fold is one stratified chunk of the data.
    x_eq = []
    y_eq = []
    for _, test_index in skf.split(x_train, y_train):
        x_eq.append(x_train[test_index])
        y_eq.append(y_train[test_index])

    x_split = []
    y_split = []

    acc = 0
    for count in counts:
        x_split.append(torch.cat(x_eq[acc:acc + count], 0))
        y_split.append(torch.cat(y_eq[acc:acc + count], 0))
        acc += count

    return x_split, y_split

In [9]:
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler
from net_archs import MLP
import datetime

def run_local(split, X=X, y=y):
    """Train one independent MLP per simulated client and print its metrics.

    The data is cut into 5 stratified, shuffled cross-validation folds (seed
    42). For each fold the training part is divided among clients by
    ``fl_split``; every client partition is oversampled with
    ``RandomOverSampler``, used to fit a fresh ``MLP`` for 10 epochs, and
    scored on the fold's held-out part with ``test``.

    Args:
        split: forwarded to ``fl_split``; one entry per client.
        X: feature tensor; defaults to the notebook-level ``X`` as it was
            when this cell was executed.
        y: label tensor; defaults to the notebook-level ``y`` as it was when
            this cell was executed.

    Returns:
        None. Prints the split, the elapsed wall-clock time (training and
        evaluation together) and the flat lists of MSE and AUC values, in
        fold-then-client order.
    """
    mses, aucs = [], []

    kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    start = datetime.datetime.now()

    for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X, y)):
        print(f"Fold {fold_idx + 1}:")
        fold_x_train, fold_y_train = X[train_idx], y[train_idx]
        fold_x_test, fold_y_test = X[test_idx], y[test_idx]

        client_xs, client_ys = fl_split(split, fold_x_train, fold_y_train)

        for client_x, client_y in zip(client_xs, client_ys):
            # Balance the classes inside this client's partition only.
            sampler = RandomOverSampler(random_state=random_state)
            resampled_x, resampled_y = sampler.fit_resample(client_x, client_y)

            x_tensor = torch.from_numpy(resampled_x)
            y_tensor = torch.from_numpy(resampled_y).reshape(-1, 1)
            # Column shape to match the (N, 1) training targets.
            fold_y_test = fold_y_test.reshape(-1, 1)

            # Every client starts from a freshly initialised model.
            model = MLP(10, 1, layer_size=64, num_of_layers=2)
            optimizer = torch.optim.Adam(
                model.parameters(), lr=0.00005, weight_decay=0.00001
            )
            train(
                model, optimizer, x_tensor, y_tensor, fold_x_test, fold_y_test,
                num_epoch=10, batch_size=32,
            )
            mse, auc = test(model, fold_x_test, fold_y_test)

            aucs.append(auc)
            mses.append(mse)

    elapsed = datetime.datetime.now() - start

    print('split',split)
    print('Training time: ', elapsed)
    print('mse',mses)
    print('auc',aucs)

In [10]:
# Baseline: a single client; fl_split assigns it all 50 stratified chunks.
split = [1]
run_local(split)

Fold 1:


KeyboardInterrupt: 

In [15]:
# Two clients with equal shares (25 of the 50 chunks each).
split = [0.5, 0.5]
run_local(split)

round 1
torch.Size([60000, 10])
auc 0.8120392322062427
auc 0.8120392322062427
torch.Size([60000, 10])


KeyboardInterrupt: 

In [None]:
# Two clients, 60/40 (30 and 20 of the 50 chunks).
split = [0.6, 0.4]
run_local(split)

In [33]:
# Two clients, 80/20 (40 and 10 of the 50 chunks).
# NOTE(review): the saved output under this cell reports [0.1, 0.9], so it
# appears to come from a different run than the code shown here.
split = [0.8, 0.2]
run_local(split)

round 1
torch.Size([12000, 10])
auc 0.7007957910515759
auc 0.7580617013615889
auc 0.7830541696364932
auc 0.7939972150987208
auc 0.7994540417808672
auc 0.7996855045521012
auc 0.8011624574733084
auc 0.8018834823757631
auc 0.8018540902778287
auc 0.8023464079182312
auc 0.8015767023535723
auc 0.8017861210513554
auc 0.8015197551638241
auc 0.8009521202724648
auc 0.8011946050804242
auc 0.8011946050804242
torch.Size([108000, 10])
auc 0.8301970729117798
auc 0.8342739710538193
auc 0.8357062463107455
auc 0.8366664307180619
auc 0.8359890448529382
auc 0.8356839740485179
auc 0.8374738002238143
auc 0.8370032329197014
auc 0.837902442624866
auc 0.8379007546679845
auc 0.8380593458895257
auc 0.8375183776305466
auc 0.8377611803632568
auc 0.837226229560952
auc 0.8378465975579124
auc 0.8378465975579124
[0.1, 0.9]
[0.8011946050804242, 0.8378465975579124]
client mse [0.15911835 0.14056116]
client auc [0.80119461 0.8378466 ]


In [None]:
# Three clients with near-equal shares. fl_split truncates fraction * 50, so
# this yields 17/16/16 chunks and one of the 50 chunks stays unassigned.
split = [0.34, 0.33, 0.33]
run_local(split)

In [51]:
# Three clients: one large (30 chunks) and two small (10 chunks each).
split = [0.6, 0.2, 0.2]
run_local(split)

round 1
[0.6, 0.2, 0.2]
[0.8336498340372493, 0.8250584821428572, 0.8171984375]
client mse [0.14485916 0.15435468 0.1622122 ]
client auc [0.83364983 0.82505848 0.81719844]


In [50]:
# Three clients: one dominant (40 chunks) and two small (5 chunks each).
split = [0.8, 0.1, 0.1]
run_local(split)

round 1
[0.8, 0.1, 0.1]
[0.8389191512624855, 0.7945321428571429, 0.8101535714285715]
client mse [0.15591994 0.18518001 0.15593214]
client auc [0.83891915 0.79453214 0.81015357]


In [None]:
# Five clients with equal shares (10 of the 50 chunks each).
split = [0.2, 0.2, 0.2, 0.2, 0.2]
run_local(split)

In [55]:
# Five clients: one large (30 chunks) and four small (5 chunks each).
split = [0.6, 0.1, 0.1, 0.1, 0.1]
run_local(split)

round 1
[0.6, 0.1, 0.1, 0.1, 0.1]
[0.8328602787581267, 0.8208732142857144, 0.8035857142857143, 0.7924803571428571, 0.8064232142857144]
client mse [0.14502504 0.16205046 0.18040295 0.17702585 0.15839897]
client auc [0.83286028 0.82087321 0.80358571 0.79248036 0.80642321]


In [54]:
# Five clients: one dominant (40 chunks) and four tiny ones. 0.05 * 50 = 2.5
# is truncated to 2 chunks by fl_split, so only 48 of the 50 chunks are used.
split = [0.8, 0.05, 0.05, 0.05, 0.05]
run_local(split)

round 1
[0.8, 0.05, 0.05, 0.05, 0.05]
[0.8398367737569369, 0.8553962053571429, 0.8191517857142857, 0.8169140625, 0.7640457589285714]
client mse [0.14810213 0.16924135 0.16573179 0.16033974 0.14372839]
client auc [0.83983677 0.85539621 0.81915179 0.81691406 0.76404576]


In [11]:
# Ten clients with equal shares (5 of the 50 chunks each).
split = [0.1, 0.1, 0.1, 0.1, 0.1,
        0.1, 0.1, 0.1, 0.1, 0.1]
run_local(split)

Fold 1:


KeyboardInterrupt: 

In [46]:
# Ten clients: one large (20 chunks) and nine small ones. fl_split truncates
# fraction * 50, so 0.07 and 0.06 both map to 3 chunks (47 of 50 used); the
# saved output below shows the same size for all nine small clients.
split = [0.4, 0.07, 0.07, 0.07, 0.07, 0.07, 0.07, 0.06, 0.06, 0.06]
run_local(split)

round 1
torch.Size([48000, 10])
torch.Size([7200, 10])
torch.Size([7200, 10])
torch.Size([7200, 10])
torch.Size([7200, 10])
torch.Size([7200, 10])
torch.Size([7200, 10])
torch.Size([7200, 10])
torch.Size([7200, 10])
torch.Size([7200, 10])
[0.4, 0.07, 0.07, 0.07, 0.07, 0.07, 0.07, 0.06, 0.06, 0.06]
[0.8374646617750832, 0.8068901697685065, 0.8109559507577808, 0.7961929563492064, 0.8180406746031745, 0.8311259920634921, 0.8289434523809525, 0.7926860119047618, 0.8125049603174602, 0.823874007936508]
client mse [0.14670356 0.17391726 0.16069983 0.15369567 0.16048577 0.15651706
 0.16144638 0.15831232 0.15309429 0.1551771 ]
client auc [0.83746466 0.80689017 0.81095595 0.79619296 0.81804067 0.83112599
 0.82894345 0.79268601 0.81250496 0.82387401]


In [57]:
# Ten clients: one large (30 chunks) and nine small ones. After truncation in
# fl_split both 0.05 and 0.04 map to 2 chunks, so 48 of the 50 chunks are used.
split = [0.6, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
run_local(split)

round 1
[0.6, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
[0.8328513351755795, 0.7889843750000001, 0.7977455357142856, 0.7480468750000001, 0.7957756696428571, 0.8151897321428572, 0.8558872767857143, 0.8404575892857142, 0.8214787946428571, 0.7500613839285714]
client mse [0.14797236 0.16713153 0.15241122 0.18521674 0.17096385 0.17195275
 0.17783752 0.1654726  0.16510616 0.14491653]
client auc [0.83285134 0.78898438 0.79774554 0.74804688 0.79577567 0.81518973
 0.85588728 0.84045759 0.82147879 0.75006138]


In [58]:
# Ten clients: one dominant (40 chunks) and nine minimal ones. After
# truncation in fl_split both 0.03 and 0.02 map to a single chunk, so 49 of
# the 50 chunks are used.
split = [0.8, 0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
run_local(split)

round 1
[0.8, 0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
[0.8369004819305601, 0.7663392857142858, 0.7092410714285714, 0.7622321428571429, 0.8107142857142857, 0.73390625, 0.7917410714285715, 0.7628571428571429, 0.8219642857142857, 0.7046428571428571]
client mse [0.15644191 0.18714832 0.19965808 0.20672046 0.18479377 0.19478165
 0.17306052 0.19586898 0.18460417 0.18981056]
client auc [0.83690048 0.76633929 0.70924107 0.76223214 0.81071429 0.73390625
 0.79174107 0.76285714 0.82196429 0.70464286]


In [None]:
# split = [0.15, 0.15, 0.15, 0.15, 0.15,
#          0.05, 0.05, 0.05, 0.05, 0.05]
# run_local(split)

In [None]:
# split = [0.05, 0.05, 0.05, 0.05, 0.05,
#          0.05, 0.05, 0.05, 0.05, 0.05,
#          0.05, 0.05, 0.05, 0.05, 0.05,
#          0.05, 0.05, 0.05, 0.05, 0.05]
# run_local(split)

In [None]:
# split = [0.08, 0.08, 0.08, 0.08, 0.08,
#          0.08, 0.08, 0.08, 0.08, 0.08,
#          0.02, 0.02, 0.02, 0.02, 0.02,
#          0.02, 0.02, 0.02, 0.02, 0.02]
# run_local(split)

In [None]:
# split = [0.08, 0.08, 0.08, 0.08, 0.08,
#          0.08, 0.08, 0.08, 0.08, 0.08,
#          0.02, 0.02, 0.02, 0.02, 0.02,
#          0.02, 0.02, 0.02, 0.02, 0.02]
# run_local(split) # epoch 15

In [None]:
# split = [0.05, 0.05, 0.05, 0.05, 0.05,
#          0.05, 0.05, 0.05, 0.05, 0.05,
#          0.05, 0.05, 0.05, 0.05, 0.05,
#          0.05, 0.05, 0.05, 0.05, 0.05]
# run_local(split) # epoch 15