In [2]:
from sklearn.metrics import mean_squared_error, roc_auc_score
import numpy as np
import torch
from trees import construct_tree, TreeDataset, do_fl_partitioning
from sklearn.model_selection import train_test_split

In [3]:
from sklearn.model_selection import StratifiedKFold

random_state = np.random.randint(1000)

def fl_split(split, x_train, y_train):

    x_eq = []
    y_eq = []

    n_splits = 50
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    skf.get_n_splits(x_train, y_train)

    for i, (_, test_index) in enumerate(skf.split(x_train, y_train)):
        x_eq.append(x_train[test_index])
        y_eq.append(y_train[test_index])

    x_split = []
    y_split = []

    acc = 0
    for s in split:
        x_split.append(torch.cat(x_eq[acc:acc+int(s*n_splits)], 0))
        y_split.append(torch.cat(y_eq[acc:acc+int(s*n_splits)], 0))
        acc += int(s*n_splits)

    return x_split, y_split

In [4]:
# Local models
X = torch.load('./dataset2/X_mlp.pt')
y = torch.load('./dataset2/y_mlp.pt')

torch.Size([307511, 1])


In [5]:
from imblearn.over_sampling import RandomOverSampler

def run_local(split, X=X, y=y):
    mses = []
    aucs = []
    kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)


    X = X.to_numpy()
    y = y.to_numpy()

    for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X, y)):
        print(f"Fold {fold_idx + 1}:")


        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        X_train = torch.tensor(X_train)
        X_test = torch.tensor(X_test)
        y_train = torch.reshape(torch.tensor(y_train), (-1, 1))
        y_test = torch.reshape(torch.tensor(y_test), (-1, 1))


        x_split, y_split = fl_split(split, X_train, y_train)

        for X_train, y_train in zip(x_split, y_split):
            ros = RandomOverSampler(random_state=random_state)
            X_train, y_train = ros.fit_resample(X_train, y_train)

            X_train = torch.from_numpy(X_train)
            y_train = torch.reshape(torch.from_numpy(y_train), (-1, 1))
            y_test = torch.reshape(y_test, (-1, 1))

            global_tree = construct_tree(X_train, y_train, learning_rate=0.1, max_depth=5, n_estimators=100,
               min_child_weight=3, subsample=0.8,colsample_bytree=1, gamma=0)

            preds_test = global_tree.predict_proba(X_test)
            pred = preds_test[:, 1]
            auc = roc_auc_score(y_test, pred)
            mse = mean_squared_error(y_test, pred)

            aucs.append(auc)
            mses.append(mse)
    print('split',split)
    print('mse',mses)
    print('auc',aucs)


In [12]:
split = [1]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [1]
mse [0.18691278, 0.18647519, 0.1879661, 0.18585135, 0.18737356]
auc [0.764730865536986, 0.7721975240981936, 0.7641721528742123, 0.7710588994639431, 0.7615063977655927]


In [11]:
split = [0.5, 0.5]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.5, 0.5]
mse [0.18187952, 0.1818471, 0.1825949, 0.18232845, 0.18457149, 0.18219565, 0.18000978, 0.18148728, 0.18249567, 0.18114905]
auc [0.7598702787637557, 0.7596599985672106, 0.766072718271404, 0.7654489237243616, 0.7584456015142238, 0.7595511666726427, 0.7674728173536456, 0.7659134075785748, 0.756278178816888, 0.7585961860016597]


In [25]:
split = [0.6, 0.4]
run_local(split)

split = [0.8, 0.2]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.6, 0.4]
mse [0.1839341, 0.17845273, 0.18348819, 0.1797912, 0.18504862, 0.18115424, 0.1829614, 0.17642435, 0.18340926, 0.1797918]
auc [0.7609034563889994, 0.7584871524706338, 0.7670344604601811, 0.7659658324973615, 0.7599986184844043, 0.7582089341416589, 0.7682921953934009, 0.7665866666538419, 0.7584937176575772, 0.7568494326657296]
Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.8, 0.2]
mse [0.18587707, 0.16642405, 0.1856349, 0.17048632, 0.18721914, 0.17167601, 0.1846936, 0.16730422, 0.18633696, 0.16871148]
auc [0.7628909636905434, 0.7519127952763689, 0.7701396803821989, 0.7573211803422727, 0.7606291389248058, 0.7498997840108308, 0.7694635802582276, 0.7595797820001877, 0.7599231142752971, 0.7471805601874744]


In [26]:
split = [0.34, 0.33, 0.33]
run_local(split)

split = [0.6, 0.2, 0.2]
run_local(split)

split = [0.8, 0.1, 0.1]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.34, 0.33, 0.33]
mse [0.17733523, 0.1781404, 0.17489931, 0.17742357, 0.17637573, 0.17592822, 0.17897855, 0.17868803, 0.1773661, 0.17760807, 0.1756116, 0.1755328, 0.17829342, 0.17720734, 0.17784882]
auc [0.7569948267466521, 0.7560774015512102, 0.7578217282910402, 0.7628678924286694, 0.7644249581159063, 0.7632458516547578, 0.7546045250406915, 0.7578436322773843, 0.7557860967127535, 0.7607028957553681, 0.7622599436303875, 0.7621275685729855, 0.753988758460113, 0.7511642483998529, 0.7527424536269156]
Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.6, 0.2, 0.2]
mse [0.1839341, 0.16973132, 0.16642405, 0.18348819, 0.16862018, 0.17048632, 0.18504862, 0.17051022, 0.17167601, 0.1829614, 0.1671214, 0.16730422, 0.18340926, 0.1688321, 0.16871148]
auc [0.7609034563889994, 0.7512070396058695, 0.7519127952763689, 0.7670344604601811, 0.7582224162091465, 0.7573211803422727, 0.7599986184844043, 0.7500605446181712, 0.7498997840108308, 0.7682921953934009, 0.

In [None]:
split = [0.2, 0.2, 0.2, 0.2, 0.2]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.2, 0.2, 0.2, 0.2, 0.2]
mse [0.17004369, 0.16957828, 0.17139108, 0.16973132, 0.16642405, 0.16918348, 0.16932978, 0.16839217, 0.16862018, 0.17048632, 0.17107102, 0.17121184, 0.17241742, 0.17051022, 0.17167601, 0.17032689, 0.16925338, 0.17020398, 0.1671214, 0.16730422, 0.16892503, 0.16981256, 0.16873291, 0.1688321, 0.16871148]
auc [0.7502232864477747, 0.7503257743537601, 0.7538829520036556, 0.7512070396058695, 0.7519127952763689, 0.7575358923754465, 0.7557825378316807, 0.7576836856883873, 0.7582224162091465, 0.7573211803422727, 0.746843435470192, 0.7489204344449742, 0.7514107641475185, 0.7500605446181712, 0.7498997840108308, 0.755083908814912, 0.7538019172037895, 0.7577933359185987, 0.7572363122503828, 0.7595797820001877, 0.7490823439403487, 0.7431076434523419, 0.7460333090962488, 0.7516893258558357, 0.7471805601874744]
Fold 1:
Fold 2:
Fold 3:
Fold 4:


In [8]:
split = [0.6, 0.1, 0.1, 0.1, 0.1]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.6, 0.1, 0.1, 0.1, 0.1]
mse [0.1840823, 0.15291038, 0.15634201, 0.1568086, 0.15345518, 0.18427986, 0.15426579, 0.15714166, 0.15527886, 0.15559772, 0.18486944, 0.15500705, 0.16014992, 0.15341276, 0.15883832, 0.18230191, 0.15305787, 0.15559919, 0.15113972, 0.15510686, 0.1843711, 0.1559257, 0.15480421, 0.15460163, 0.15617211]
auc [0.761281442772655, 0.7447622925015773, 0.734924719953253, 0.7392789374929398, 0.7433313519372955, 0.7686647272367919, 0.7442021632546385, 0.7471008024208087, 0.745671897420294, 0.7434100824383273, 0.7594430091062647, 0.7411770644685252, 0.7353550520908507, 0.7390056856776643, 0.7372683799419396, 0.7694793333834569, 0.7436199495483187, 0.7455740388781217, 0.7429565032237175, 0.7470512648624921, 0.7577400720443641, 0.7365887120307868, 0.7362952219029144, 0.7376093912850983, 0.7395292045646087]


In [9]:
split = [0.8, 0.05, 0.05, 0.05, 0.05]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.8, 0.05, 0.05, 0.05, 0.05]
mse [0.18580565, 0.13063139, 0.13016546, 0.13548869, 0.12709536, 0.18604146, 0.12823439, 0.1280398, 0.1295744, 0.12848258, 0.18662998, 0.12989433, 0.12704177, 0.12884563, 0.13227504, 0.184332, 0.125984, 0.12511966, 0.12645432, 0.12763833, 0.18645114, 0.13123527, 0.12888236, 0.122652225, 0.12995258]
auc [0.7631850168270825, 0.7131930286208418, 0.7225409573121012, 0.7178783676474293, 0.7157558087196886, 0.7710244524163619, 0.7232603532935797, 0.7277091131633516, 0.7247907523098749, 0.7255132728540861, 0.7622370175251381, 0.7135720405610557, 0.7237150226159055, 0.7175289445418565, 0.7164624700761424, 0.7697822515180953, 0.7240499118998811, 0.7236334943148122, 0.722193125371062, 0.7274982129447405, 0.7589120340250404, 0.7213911819298757, 0.7172070278959455, 0.7210466099244226, 0.7128145118844096]


In [6]:
split = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
mse [0.15504995, 0.15715306, 0.15675975, 0.15656984, 0.15398732, 0.15538879, 0.15291038, 0.15634201, 0.1568086, 0.15345518, 0.15369038, 0.15677075, 0.15321308, 0.1533349, 0.15410392, 0.15808626, 0.15426579, 0.15714166, 0.15527886, 0.15559772, 0.1584116, 0.15745096, 0.15269253, 0.1574338, 0.15727657, 0.1570961, 0.15500705, 0.16014992, 0.15341276, 0.15883832, 0.15491414, 0.1523217, 0.15371859, 0.15612243, 0.1516785, 0.15433908, 0.15305787, 0.15559919, 0.15113972, 0.15510686, 0.15760653, 0.15534918, 0.15423095, 0.15632722, 0.15603374, 0.15256512, 0.1559257, 0.15480421, 0.15460163, 0.15617211]
auc [0.7420307499697999, 0.7394635311448419, 0.7366081513607028, 0.740678288648079, 0.7389398042122799, 0.7407413160651926, 0.7447622925015773, 0.734924719953253, 0.7392789374929398, 0.7433313519372955, 0.7427347143964987, 0.7459367401586297, 0.7416516442876637, 0.7435775475643653, 0.7437911677086013, 0.7

In [7]:
split = [0.6, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.6, 0.05, 0.05, 0.05, 0.05, 0.04, 0.04, 0.04, 0.04, 0.04]
mse [0.1840823, 0.13145703, 0.12907581, 0.12597279, 0.1333825, 0.12532897, 0.13063139, 0.13016546, 0.13548869, 0.12709536, 0.18427986, 0.1272232, 0.1325394, 0.1313441, 0.13163868, 0.13286203, 0.12823439, 0.1280398, 0.1295744, 0.12848258, 0.18486944, 0.12814917, 0.12589025, 0.13117766, 0.13102429, 0.12906483, 0.12989433, 0.12704177, 0.12884563, 0.13227504, 0.18230191, 0.12934946, 0.123277955, 0.1310159, 0.12752825, 0.123551734, 0.125984, 0.12511966, 0.12645432, 0.12763833, 0.1843711, 0.12840201, 0.12579758, 0.13099883, 0.12982424, 0.12692142, 0.13123527, 0.12888236, 0.122652225, 0.12995258]
auc [0.761281442772655, 0.7194075105739468, 0.7163536438539301, 0.7259611614315169, 0.7109681937487561, 0.7155727985459218, 0.7131930286208418, 0.7225409573121012, 0.7178783676474293, 0.7157558087196886, 0.7686647272367919, 0.7263017306653411, 0.7249622910900739, 0.7264896264049453, 0.72553197212

In [10]:
split = [0.8, 0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
run_local(split)

Fold 1:
Fold 2:
Fold 3:
Fold 4:
Fold 5:
split [0.8, 0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
mse [0.18580565, 0.111047186, 0.103074074, 0.10886639, 0.11042133, 0.11146707, 0.10799031, 0.10603186, 0.10563887, 0.11168817, 0.18604146, 0.11076831, 0.10639225, 0.105606735, 0.107726626, 0.10813952, 0.10547369, 0.11036399, 0.10318935, 0.10556593, 0.18662998, 0.10850151, 0.1076147, 0.105861165, 0.1087903, 0.110289454, 0.10785766, 0.110663615, 0.10987626, 0.10900856, 0.184332, 0.104783095, 0.104767144, 0.10600617, 0.10387945, 0.109129295, 0.10089764, 0.10673631, 0.10669229, 0.10479591, 0.18645114, 0.11030923, 0.10608784, 0.109484546, 0.10934329, 0.10034522, 0.10369458, 0.10835866, 0.10635614, 0.10560783]
auc [0.7631850168270825, 0.6944209647945253, 0.6988181072381267, 0.7104003342652878, 0.7033554418229955, 0.696955217350275, 0.7170380430532921, 0.6958680981593999, 0.7036861144499522, 0.6914582789847658, 0.7710244524163619, 0.700412602208063, 0.7176728013547118, 0.7046689242227474,