In [None]:
import math
import os
import sys
import threading
from multiprocessing.pool import ThreadPool
from pathlib import Path
sys.path.append("..")

import numpy as np
import talos
import tensorflow as tf
from tensorflow.keras.optimizers import SGD
import tensorflow.keras.backend as K
from federated_library.distributions import qty_skew_distrib, label_skew_distrib, \
    feature_skew_distrib, iid_distrib
from federated_library.dataset_loader import load_tf_dataset
from federated_library.models_keras import get_model
from constants import DATASETS, NR_PARTIES, HP_GRID, SKEWS

# One worker per visible GPU; when no GPU is visible, fall back to one worker per
# CPU core. os.cpu_count() may return None, hence the `or 1`.
# NOTE(review): client_gridsearch places work on f"/{DEVICE}:{i}"; on CPU, TF
# usually exposes a single logical "/CPU:0" — confirm higher indices are valid
# (or rely on soft device placement).
DEVICE = 'GPU'
numWorkers = len(tf.config.list_physical_devices('GPU'))
if numWorkers == 0:
    DEVICE = 'CPU'
    numWorkers = os.cpu_count() or 1

print(numWorkers)

In [None]:
# Polynomial degree selecting which coefficient table reluApprox, sigmoidApprox
# and tanApprox use by default. The matching `interval` global is set by
# intervals_search_for_X_clients before models are built.
degree = 3

# ReLU approximation coefficients keyed by (degree, interval). Each entry is a
# sequence of (power, coefficient) terms of a polynomial in x / interval.
_RELU_APPROX_TERMS = {
    (3, 3): ((0, 0.7146), (1, 1.5000), (2, 0.8793)),
    (3, 5): ((0, 0.7865), (1, 2.5000), (2, 1.88)),
    (3, 7): ((0, 0.9003), (1, 3.5000), (2, 2.9013)),
    (3, 10): ((0, 1.1155), (1, 5), (2, 4.4003)),
    (3, 12): ((0, 1.2751), (1, 6), (2, 5.3803)),
    (5, 7): ((0, 0.7521), (1, 3.5000), (2, 4.3825), (4, -1.7281)),
    (5, 20): ((0, 1.3127), (1, 10), (2, 15.7631), (4, -7.6296)),
}


def reluApprox(x, interval=None, degree=None):
    """Polynomial approximation of ReLU, evaluated on x / interval.

    Keras calls this as an activation with a single tensor, in which case the
    module-level globals `interval` and `degree` select the coefficients.

    Args:
        x: input tensor / array / scalar (anything supporting `/`, `*`, `**`, `+`).
        interval: scaling interval; defaults to the global `interval`.
        degree: approximation degree; defaults to the global `degree`.

    Returns:
        The polynomial evaluated element-wise on x / interval.

    Raises:
        ValueError: no coefficients exist for the (degree, interval) pair
            (previously this silently returned None).
    """
    if interval is None:
        interval = globals()["interval"]
    if degree is None:
        degree = globals()["degree"]
    try:
        terms = _RELU_APPROX_TERMS[(degree, interval)]
    except KeyError:
        raise ValueError(
            f"no ReLU approximation for degree={degree}, interval={interval}") from None
    scaled = x / interval
    return sum(coeff * scaled ** power for power, coeff in terms)


# Sigmoid approximation coefficients keyed by (degree, interval); each entry is a
# sequence of (power, coefficient) terms of a polynomial in x / interval.
# sigmoid(x) - 0.5 is an odd function, so besides the 0.5 constant only odd
# powers appear — the intervals 8 and 12 entries use the cubic term as well.
_SIGMOID_APPROX_TERMS = {
    (3, 3): ((0, 0.5), (1, 0.6997), (3, -0.2649)),
    (3, 5): ((0, 0.5), (1, 0.9917), (3, -0.5592)),
    (3, 7): ((0, 0.5), (1, 1.1511), (3, -0.7517)),
    (3, 8): ((0, 0.5), (1, 1.2010), (3, -0.8156)),
    (3, 12): ((0, 0.5), (1, 1.2384), (3, -0.8647)),
}


def sigmoidApprox(x, interval=None, degree=None):
    """Polynomial approximation of the sigmoid, evaluated on x / interval.

    Keras calls this as an activation with a single tensor, in which case the
    module-level globals `interval` and `degree` select the coefficients.

    Args:
        x: input tensor / array / scalar (anything supporting `/`, `*`, `**`, `+`).
        interval: scaling interval; defaults to the global `interval`.
        degree: approximation degree; defaults to the global `degree`.

    Returns:
        The polynomial evaluated element-wise on x / interval.

    Raises:
        ValueError: no coefficients exist for the (degree, interval) pair
            (previously this silently returned None).
    """
    if interval is None:
        interval = globals()["interval"]
    if degree is None:
        degree = globals()["degree"]
    try:
        terms = _SIGMOID_APPROX_TERMS[(degree, interval)]
    except KeyError:
        raise ValueError(
            f"no sigmoid approximation for degree={degree}, interval={interval}") from None
    scaled = x / interval
    return sum(coeff * scaled ** power for power, coeff in terms)


# tanh approximation coefficients keyed by (degree, interval); each entry is a
# sequence of (power, coefficient) terms of a polynomial in x / interval.
# tanh is odd, so only odd powers appear — including for interval 12.
_TANH_APPROX_TERMS = {
    (3, 1): ((1, 0.9797), (3, -0.2268)),
    (3, 2): ((1, 1.7329), (3, -0.8454)),
    (3, 3): ((1, 2.1673), (3, -1.3358)),
    (3, 5): ((1, 2.5338), (3, -1.8051)),
    (3, 7): ((1, 2.6629), (3, -1.9801)),
    (3, 12): ((1, 2.7599), (3, -2.1140)),
}


def tanApprox(x, interval=None, degree=None):
    """Polynomial approximation of tanh, evaluated on x / interval.

    Keras calls this as an activation with a single tensor, in which case the
    module-level globals `interval` and `degree` select the coefficients.

    Args:
        x: input tensor / array / scalar (anything supporting `/`, `*`, `**`, `+`).
        interval: scaling interval; defaults to the global `interval`.
        degree: approximation degree; defaults to the global `degree`.

    Returns:
        The polynomial evaluated element-wise on x / interval.

    Raises:
        ValueError: no coefficients exist for the (degree, interval) pair
            (previously this returned None, printing a debug message for degree 12).
    """
    if interval is None:
        interval = globals()["interval"]
    if degree is None:
        degree = globals()["degree"]
    try:
        terms = _TANH_APPROX_TERMS[(degree, interval)]
    except KeyError:
        raise ValueError(
            f"no tanh approximation for degree={degree}, interval={interval}") from None
    scaled = x / interval
    return sum(coeff * scaled ** power for power, coeff in terms)

In [None]:
def experiment(x_train, y_train, x_val, y_val, params):
    """Talos model function: build, compile and fit one model for a single
    hyper-parameter combination.

    Args:
        x_train, y_train: the client's training data.
        x_val, y_val: passed in by talos but not used here; validation instead
            comes from `validation_split=0.1` of x_train / y_train.
        params: one grid combination; this function reads 'client_lr',
            'client_momentum', 'batch_size' and 'epochs' (get_model may read more).

    Returns:
        (history, model): the History returned by fit() and the trained model.

    Note: get_model is given the notebook-level global `ds_info`.
    """
    sgd = SGD(
        learning_rate=params['client_lr'],
        momentum=params['client_momentum'],
        nesterov=False,
        name='SGD',
    )

    net = get_model(params, ds_info)
    tracked_metrics = ['accuracy', tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
    net.compile(optimizer=sgd, loss="mean_squared_error",
                metrics=tracked_metrics, run_eagerly=False)

    # Stop when val_accuracy has not improved by at least 0.01 for 10 epochs.
    stop_early = tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy", min_delta=0.01, patience=10)

    fit_options = {
        'epochs': params['epochs'],
        'batch_size': params['batch_size'],
        'validation_split': 0.1,
        'callbacks': [stop_early],
        'verbose': 0,
    }
    history = net.fit(x=x_train, y=y_train, **fit_options)

    return history, net

In [None]:
# Guards the shared `workers` slot array: pool threads claim and release slots
# concurrently, and "find a free slot, then mark it busy" is two separate steps.
_workers_lock = threading.Lock()


def client_gridsearch(work):
    """Run a talos grid search for one client on a free device slot.

    Args:
        work: (client_number, clientData, clientDataLabels, param_grid) tuple.

    Side effects:
        - Stores the talos.Scan result in the global `scan_res[client_number]`.
        - Claims a slot in the global `workers` array (1 = free, 0 = busy) for
          the duration of the scan and always releases it, even if the scan
          raises, so later pools that reuse `workers` do not run out of slots.
        - Reads the globals DEVICE and experiment_name.
    """
    client_number, clientData, clientDataLabels, param_grid = work

    # The pool size equals len(workers), so a free slot always exists here.
    with _workers_lock:
        free = np.where(workers == 1)
        i = free[0][0]
        workers[i] = 0

    try:
        # Distribute load across DEVICEs
        with tf.device(f"/{DEVICE}:{i}"):
            print(f"training on {DEVICE}: {i}")

            scan_results = talos.Scan(x=clientData,
                                      y=clientDataLabels,
                                      params=param_grid,
                                      model=experiment,
                                      experiment_name=f"{experiment_name}_{client_number}")
            scan_res[client_number] = scan_results
            print(f"client running on {DEVICE}: {i} finished")
    finally:
        with _workers_lock:
            workers[i] = 1


def grid_search_for_X_clients(numClients, clientsData, clientsDataLabels, param_grid):
    """Grid-search every client in parallel, one pool thread per device slot.

    Resets the globals `scan_res` (one object slot per client, 0 until a scan
    fills it) and `workers` (one free slot per worker), then dispatches
    client_gridsearch for each client and returns `scan_res`.
    """
    global scan_res, workers

    scan_res = np.zeros(numClients, dtype=object)
    workers = np.ones(numWorkers)

    jobs = []
    for client_id in range(numClients):
        jobs.append((client_id, clientsData[client_id],
                     clientsDataLabels[client_id], param_grid))

    with ThreadPool(len(workers)) as pool:
        pool.map(client_gridsearch, jobs)

    return scan_res


def intervals_search_for_X_clients(numClients, clientsData, clientsDataLabels, params, intervals):
    """Repeat the per-client search once for every candidate interval.

    Args:
        numClients: number of clients.
        clientsData, clientsDataLabels: per-client data, indexed by client id.
        params: per-client talos parameter grids (params[i] is client i's grid).
        intervals: candidate values assigned, one at a time, to the global
            `interval` read by the *Approx activation functions.

    Returns:
        Object array with one entry per interval; each entry is the `scan_res`
        array of that run (0 for clients whose scan failed).

    Failures are best-effort per client: they are logged and that client's slot
    stays 0, while the other clients of the same interval still run.
    """
    global scan_res
    global workers
    global interval

    intervals_res = np.zeros(len(intervals), dtype=object)

    def _scan_client(job):
        # Catch per client: Pool.map returns on the first failure, which would
        # otherwise drop the other clients' results while their threads keep
        # running into the next interval's globals.
        try:
            client_gridsearch(job)
        except Exception as exc:
            print(f"interval {interval}: client {job[0]} failed: {exc!r}")

    for idx, inter in enumerate(intervals):
        interval = inter
        scan_res = np.zeros(numClients, dtype=object)
        # Fresh slots for every interval so a slot leaked by a failed scan
        # cannot starve the next run.
        workers = np.ones(numWorkers)

        work = [(i, clientsData[i], clientsDataLabels[i], params[i])
                for i in range(numClients)]

        with ThreadPool(len(workers)) as p:
            p.map(_scan_client, work)

        intervals_res[idx] = scan_res

    return intervals_res

In [None]:
def _distribute_clients(skew_type, skew, x_train, y_train, ds_info, display):
    """Split (x_train, y_train) across clients according to `skew_type`.

    Returns the (clientsData, clientsDataLabels) pair of the matching
    federated_library distribution; any other skew_type falls back to IID.
    """
    common = dict(decentralized=True, display=display, is_tf=True)
    if skew_type == "qty":
        return qty_skew_distrib(x_train, y_train, ds_info, skew, **common)
    if skew_type == "label":
        return label_skew_distrib(x_train, y_train, ds_info, skew, **common)
    if skew_type == "feature":
        return feature_skew_distrib(x_train, y_train, ds_info, skew, **common)
    return iid_distrib(x_train, y_train, ds_info, **common)


def _write_qty_distribution(path, clientsData, total_samples):
    """Write how many (and what fraction) of the training samples each client got."""
    with open(path, "w") as textfile:
        for i, cd in enumerate(clientsData):
            textfile.write(
                f"client: {i}, samples: {len(cd)} / {total_samples}, "
                f"percentage: {len(cd)/total_samples}\n")


def _write_label_distribution(path, clientsDataLabels, num_classes):
    """Write a per-client table with the number of samples of each class.

    Each label's class is taken as its argmax (presumably one-hot labels).
    """
    with open(path, "w") as textfile:
        header = "client_id" + "".join(f",class_{k}" for k in range(num_classes))
        textfile.write(f"{header}\n")
        for i, cdl in enumerate(clientsDataLabels):
            labelsMap = np.zeros(num_classes, dtype=int)
            for label in cdl:
                labelsMap[np.argmax(label)] += 1
            line = "".join(f",{count}" for count in labelsMap)
            textfile.write(f"{i}{line}\n")


def _client_interval_grid(hyperparams, client_results, client_id):
    """Interval-search grid for one client: the approximated activation with the
    client's best (lr, momentum, batch size) from the grid search.

    `client_results` is the client's grid-search frame sorted best-first, or
    None when that client has no results — then the full search space is used.
    """
    if client_results is None:
        print(f"client {client_id}: no grid-search results, "
              f"interval search uses the full grid")
        lrs = hyperparams['client_lr']
        moms = hyperparams['client_momentum']
        batch_sizes = hyperparams['batch_size']
    else:
        lrs = [client_results['client_lr'].iloc[0]]
        moms = [client_results['client_momentum'].iloc[0]]
        batch_sizes = [client_results['batch_size'].iloc[0]]
    return dict(
        act_fn=hyperparams['act_fn_approx'],
        client_lr=lrs,
        client_momentum=moms,
        batch_size=batch_sizes,
        epochs=hyperparams['epochs']
    )


def _best_interval_per_client(intervals_res, intervals, num_clients):
    """For each client, the interval whose scan reached the highest val_accuracy.

    Later intervals win ties; a client whose every scan failed gets 0.
    """
    best_interval_per_client = np.zeros(num_clients, dtype=int)
    for i in range(num_clients):
        best_acc = 0.0
        for j, candidate in enumerate(intervals):
            scan = intervals_res[j][i]
            # 0 placeholder (scan failed) or no val_accuracy column: skip.
            if not hasattr(scan, "data") or "val_accuracy" not in scan.data:
                continue
            # Best row of the scan, not just its first (talos rows are unsorted).
            acc = scan.data["val_accuracy"].max()
            if acc >= best_acc:
                best_interval_per_client[i] = candidate
                best_acc = acc
    return best_interval_per_client


def run(hyperparams, ds, test_dataset, ds_info, with_intervals=False, display=False,
        dataset_name=None, skew_type=None):
    """Distribute the data, grid-search every client, and optionally search the
    best activation-approximation interval per client.

    Args:
        hyperparams: dict with 'clients_set', 'skews_set', the grid lists
            ('act_fn', 'client_lr', 'client_momentum', 'batch_size', 'epochs')
            and, for the interval search, 'act_fn_approx' and 'intervals'.
        ds: (x_train, y_train) pair.
        test_dataset: currently unused.
        ds_info: dataset info dict; its 'num_clients' is updated in place
            because the same object is shared with the distribution helpers
            and (as the notebook global) with experiment().
        with_intervals: also run the interval search.
        display: forwarded to the distribution functions.
        dataset_name, skew_type: default to the notebook-level globals of the
            same name, so existing calls keep working.

    Side effects:
        Writes distribution, grid-search and interval files under
        f"{dataset_name}_non_iid_res/".
    """
    if dataset_name is None:
        dataset_name = globals()["dataset_name"]
    if skew_type is None:
        skew_type = globals()["skew_type"]

    x_train, y_train = ds
    TRAIN_SAMPLES_NUMBER = len(y_train)
    BASE_DIR = f"{dataset_name}_non_iid_res/"

    for c in hyperparams['clients_set']:
        ds_info['num_clients'] = c
        print(TRAIN_SAMPLES_NUMBER)

        for skew in hyperparams['skews_set']:
            Path(BASE_DIR).mkdir(parents=True, exist_ok=True)
            run_prefix = f"{BASE_DIR}{dataset_name}_{skew_type}_skew_{skew}_{c}clients"

            # ---- Data distribution (+ distribution report for qty / label) ----
            clientsData, clientsDataLabels = _distribute_clients(
                skew_type, skew, x_train, y_train, ds_info, display)
            if skew_type == "qty":
                _write_qty_distribution(
                    f"{run_prefix}_distribution.txt", clientsData, TRAIN_SAMPLES_NUMBER)
            elif skew_type == "label":
                _write_label_distribution(
                    f"{run_prefix}_distribution.txt", clientsDataLabels, ds_info['num_classes'])

            # ---- Grid search ----
            params_grid = dict(act_fn=hyperparams['act_fn'], client_lr=hyperparams['client_lr'],
                               client_momentum=hyperparams['client_momentum'],
                               batch_size=hyperparams['batch_size'], epochs=hyperparams['epochs'])
            res = grid_search_for_X_clients(c, clientsData, clientsDataLabels, params_grid)

            # Keyed by client id so file names and the interval grid stay aligned
            # with the right client even when some scans are missing.
            sorted_data = {}
            for client_id, scan in enumerate(res):
                if not hasattr(scan, "data"):
                    continue  # 0 placeholder: this client's scan did not finish
                sorted_data[client_id] = scan.data.sort_values(by='val_accuracy', ascending=False)
                sorted_data[client_id].to_csv(f"{BASE_DIR}{skew}_{skew_type}_{c}clts_clt{client_id}.txt")

            if not with_intervals:
                continue

            # ---- Interval search ----
            intervals_grid = [_client_interval_grid(hyperparams, sorted_data.get(i), i)
                              for i in range(c)]
            intervals_res = intervals_search_for_X_clients(
                c, clientsData, clientsDataLabels, intervals_grid, hyperparams['intervals'])
            best_interval_per_client = _best_interval_per_client(
                intervals_res, hyperparams['intervals'], c)

            with open(f"{run_prefix}_intervals.txt", "w") as textfile:
                for best in best_interval_per_client:
                    textfile.write(f"{best}\n")


In [None]:
# Note: dataset_name, skew_type, experiment_name and ds_info stay module-level
# names on purpose — run(), client_gridsearch() and experiment() read them as
# notebook globals.
for dataset_name in DATASETS:
    experiment_name = f"{dataset_name}_non-iid"
    for skew_type, skew_levels in SKEWS.items():
        ds, ds_test, ds_info = load_tf_dataset(dataset_name=dataset_name,
                                               skew_type=skew_type,
                                               decentralized=True,
                                               display=False)
        test_dataset = tf.data.Dataset.from_tensor_slices(ds_test).batch(64)

        hyperparams = {
            'act_fn': ['relu'],
            'act_fn_approx': [reluApprox],
            'intervals': HP_GRID["interval"],
            'client_lr': HP_GRID["lr"],
            'client_momentum': HP_GRID["mom"],
            'batch_size': HP_GRID["bs"],
            'epochs': [30],
            'clients_set': NR_PARTIES,
            'skews_set': skew_levels,
        }

        run(hyperparams, ds, test_dataset, ds_info, with_intervals=True, display=True)