## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.exceptions import ConvergenceWarning

import warnings
warnings.filterwarnings("ignore", category=ConvergenceWarning)

from fdshapley import FederatedShapley

## Load data (old)

## Corrupt data

## Load data

In [None]:
# Number of participants (hard-coded).
# NOTE(review): presumably equals len(xs_train) — confirm against the pickle.
N = 20
# NOTE(security): pickle.load can execute arbitrary code; only open a
# 'dico_data.pkl' that was produced by a trusted source.
with open('dico_data.pkl', 'rb') as f:
    loaded_dict = pickle.load(f)
# The five datasets are unpacked positionally, so this relies on the
# insertion order of the pickled dict (and on it holding exactly five entries).
xs_train, ys_train, ys_train_corrupted, X_test, y_test = loaded_dict.values()

In [None]:
# Keyword settings for a model.
# NOTE(review): these look like sklearn LogisticRegression arguments, but the
# dict is not used in the visible cells — confirm where it is consumed.
params = dict(
    max_iter=1,
    warm_start=True,
    fit_intercept=False,
)

In [None]:
# One [features, labels] pair per entry of xs_train, first with the original
# labels and then with the corrupted ones.
data_train = [list(pair) for pair in zip(xs_train, ys_train)]
data_train_corrupted = [list(pair) for pair in zip(xs_train, ys_train_corrupted)]
# The held-out set is kept as a single [features, labels] pair.
data_test = [X_test, y_test]

## Run Original Data Shapley

In [None]:
# Selects which training labels and which checkpoint file the next cell uses.
corrupted = False

# Passed as the first two arguments of originalDataShapley in the next cell.
# NOTE(review): presumably the number of new iterations and a truncation
# setting — confirm against fdshapley.
Trepeat, trunc = 300, 5

In [None]:
# Pick the training data and the checkpoint file that belongs to it.
if corrupted:
    file = 'res_orig_corrupted.pkl'
    my_data_train = data_train_corrupted
else:
    file = 'res_orig.pkl'
    my_data_train = data_train

# Resume from the previous checkpoint when one exists; otherwise start cold so
# that the very first run (no checkpoint on disk yet) does not crash.
# NOTE(security): pickle.load can execute arbitrary code; only load checkpoint
# files written by this notebook.
try:
    with open(file, 'rb') as f:
        warm_start = pickle.load(f)
except FileNotFoundError:
    # NOTE(review): assumes originalDataShapley accepts warm_start=None for a
    # cold start — TODO confirm against fdshapley.
    warm_start = None

# Iterations already accumulated in the checkpoint (0 on a cold start).
previous_iterations = 0 if warm_start is None else warm_start["Tprev"]

fed = FederatedShapley(my_data_train, data_test)
res_orig = fed.originalDataShapley(Trepeat, trunc, warm_start)

# Persist the refreshed estimate together with the updated iteration count so
# the next execution of this cell can continue from here.
with open(file, 'wb') as f:
    pickle.dump({"s_hat": res_orig, "Tprev": previous_iterations + Trepeat}, f)

In [None]:
# Reload the checkpoint for the original labels, print its stored "Tprev"
# counter and plot the stored values.
# NOTE: plot_simple is defined in a later cell of this notebook; that cell has
# to be executed before this one.
with open('res_orig.pkl', mode='rb') as checkpoint_file:
    warm_start = pickle.load(checkpoint_file)
print(warm_start["Tprev"])
plot_simple(
    x=warm_start['s_hat'],
    legend="Data Shapley value",
    filename="original_shapley",
)

In [None]:
# Reload the checkpoint for the corrupted labels, print its stored "Tprev"
# counter and plot the stored values.
# NOTE: plot_simple is defined in a later cell of this notebook; that cell has
# to be executed before this one.
with open('res_orig_corrupted.pkl', mode='rb') as checkpoint_file:
    warm_start = pickle.load(checkpoint_file)
print(warm_start["Tprev"])
plot_simple(
    x=warm_start['s_hat'],
    legend="Data Shapley value",
    filename="original_shapley_corrupted",
)

## Run FL Shapley

In [None]:
# Second argument of federatedSVEstimation in the cells below.
# NOTE(review): presumably the number of federated rounds — confirm.
T = 300

In [None]:
# Federated Shapley estimation on the training data with the original labels.
# NOTE(review): 0.1 is presumably a step size / learning rate — confirm
# against fdshapley.FederatedShapley.federatedSVEstimation.
fed = FederatedShapley(data_train, data_test)
res, log = fed.federatedSVEstimation(0.1, T)

In [None]:
# Displayed as the cell output.
# NOTE(review): presumably the utility (test score) of the final weights
# fed.w — confirm against fdshapley.
fed.u(fed.w)

In [None]:
def plot_simple(x, legend, filename, n_highlight=5):
    """Draw one bar per participant, save the figure as PNG and show it.

    The number of bars is taken from ``len(x)`` rather than from a global
    participant count, so the function works for any number of participants.

    Args:
        x: Sequence of values, one per participant.
        legend: Label of the y axis.
        filename: Output path without extension; ".png" is appended.
        n_highlight: Number of leading participants drawn in orange; the
            remaining ones are drawn in blue. Defaults to 5.
            NOTE(review): the first participants are presumably the ones with
            corrupted labels — confirm.
    """
    n_bars = len(x)
    # Clamp so that inputs shorter than n_highlight still get one colour per bar.
    n_orange = max(0, min(n_highlight, n_bars))
    colors = ["tab:orange"] * n_orange + ["tab:blue"] * (n_bars - n_orange)

    plt.figure(figsize=(10, 7))
    plt.bar(range(n_bars), x, color=colors)
    plt.xticks(range(n_bars))
    plt.xlabel("Participant id")
    plt.ylabel(legend)

    # Save before show(): showing may release the figure depending on backend.
    plt.savefig(filename + ".png")
    plt.show()

In [None]:
def plot(res, log, legend, filename):
    """Draw one bar per participant, coloured by ``log["first"]``, and save it.

    Bars are coloured with the ``autumn`` colormap on a 0..10 scale; values of
    ``log["first"]`` above 10 are drawn with the colour of the top of the
    scale, which is why the last colorbar label reads '>10'.

    Args:
        res: Sequence of values, one per participant (bar heights).
        log: Mapping with a "first" entry holding one number per participant,
            labelled on the colorbar as the first round of participation.
        legend: Label of the y axis.
        filename: Output path without extension; ".png" is appended.
    """
    n_bars = len(res)

    plt.figure(figsize=(10, 7))

    cmap = cm.autumn
    norm = Normalize(vmin=0, vmax=10)
    colors = [cmap(norm(first_round)) for first_round in log["first"]]

    plt.bar(range(n_bars), res, color=colors)

    # Stand-alone mappable that only feeds the colorbar; it shares the exact
    # colormap and normalisation used for the bars.
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # The mappable is not attached to any Axes, so the Axes to take space from
    # must be given explicitly (recent matplotlib raises otherwise).
    cbar = plt.colorbar(sm, ax=plt.gca())
    # Pin the tick positions before relabelling so the six labels always line
    # up with 0, 2, ..., 10 regardless of the default locator.
    cbar.set_ticks([0, 2, 4, 6, 8, 10])
    cbar.set_ticklabels([0, 2, 4, 6, 8, '>10'])
    cbar.set_label('First round of participation', rotation=270)

    plt.xticks(range(n_bars))
    plt.xlabel("Participant id")
    plt.ylabel(legend)

    # Save before show(): showing may release the figure depending on backend.
    plt.savefig(filename + ".png")
    plt.show()

In [None]:
# Plot the latest federated estimate and save it as "fdshap_3.png".
plot(
    res=res,
    log=log,
    legend="Federated Data Shapley value",
    filename="fdshap_3",
)

In [None]:
# Same estimation as above, but on the training data with corrupted labels.
# NOTE(review): 0.1 is presumably a step size / learning rate — confirm
# against fdshapley.FederatedShapley.federatedSVEstimation.
fed = FederatedShapley(data_train_corrupted, data_test)
res, log = fed.federatedSVEstimation(0.1, T)

In [None]:
# Plot the latest federated estimate and save it as
# "fdshap_withcorruption.png".
plot(
    res=res,
    log=log,
    legend="Federated Data Shapley value",
    filename="fdshap_withcorruption",
)

## Run Reweighted FL Shapley

In [None]:
# Overrides the earlier T = 300 for every cell executed after this one.
# NOTE(review): presumably the number of federated rounds — confirm.
T = 50

In [None]:
# Aggregation spec handed to federatedSVEstimation(aggregation=...): a rule
# name followed by its numeric parameter.
agg = ("exp_acc", 50)

In [None]:
# Federated Shapley estimation on the original labels, using the aggregation
# spec `agg` instead of the estimator's default.
# NOTE(review): 0.1 is presumably a step size / learning rate — confirm.
fed = FederatedShapley(data_train, data_test)
res, log = fed.federatedSVEstimation(0.1, T, aggregation=agg)

In [None]:
# Displayed as the cell output.
# NOTE(review): presumably the utility (test score) of the final weights
# fed.w — confirm against fdshapley.
fed.u(fed.w)

In [None]:
# Plot the latest estimate; "tmp.png" is a scratch output that the other
# "tmp" plot cell overwrites.
plot(
    res=res,
    log=log,
    legend="Federated Data Shapley value",
    filename="tmp",
)

In [None]:
# Federated Shapley estimation on the corrupted labels, using the aggregation
# spec `agg` instead of the estimator's default.
# NOTE(review): 0.1 is presumably a step size / learning rate — confirm.
fed = FederatedShapley(data_train_corrupted, data_test)
res, log = fed.federatedSVEstimation(0.1, T, aggregation=agg)

In [None]:
# Plot the latest estimate; "tmp.png" is a scratch output that the other
# "tmp" plot cell overwrites.
plot(
    res=res,
    log=log,
    legend="Federated Data Shapley value",
    filename="tmp",
)

In [None]:
# Positions in log["all_participants"] whose entry contains participant 1.
rounds = [
    round_index
    for round_index, members in enumerate(log["all_participants"])
    if 1 in members
]

In [None]:
# Results of the repeated runs, keyed by aggregation rule: a bare rule name for
# the parameter-free rules, a (rule name, parameter) tuple otherwise.
dico = {}
dico_var = {}

# Independent repetitions per configuration.
nrepeat = 15

# Rules that take no extra parameter.
for method in ("sum", "normalize"):
    aggregation = (method,)
    vals = []
    for _run in range(nrepeat):
        # A fresh estimator per repetition keeps the runs independent.
        fed = FederatedShapley(data_train, data_test)
        res, _ = fed.federatedSVEstimation(0.1, T, aggregation=aggregation)
        vals.append(res)
    dico[method] = vals

# Rules that take a numeric parameter `a`, swept over three values.
for method in ("exp_acc", "linear_acc"):
    print(method)
    for a in (1, 10, 50):
        aggregation = (method, a)
        vals = []
        for _run in range(nrepeat):
            fed = FederatedShapley(data_train, data_test)
            res, _ = fed.federatedSVEstimation(0.1, T, aggregation=aggregation)
            vals.append(res)
        dico[aggregation] = vals



In [None]:
# Reference values: the estimate stored in the original-labels checkpoint,
# rescaled in place so that it sums to one.
with open('res_orig.pkl', mode='rb') as checkpoint_file:
    warm_start = pickle.load(checkpoint_file)
true_shapley = warm_start['s_hat'].copy()
true_shapley /= np.sum(true_shapley)

# For every configuration: L1 distance of each repeated run to the reference.
dico_scores = {
    config: [np.sum(np.abs(run - true_shapley)) for run in runs]
    for config, runs in dico.items()
}

# (mean, standard deviation) of those distances per configuration.
dico_res = {
    config: (np.mean(distances), np.std(distances))
    for config, distances in dico_scores.items()
}

# Displayed as the cell output.
dico_res