In [None]:
import os

import tensorflow as tf

# Prefer GPUs: one worker slot per GPU that TensorFlow can see.
DEVICE = 'GPU'
numWorkers = len(tf.config.list_physical_devices('GPU'))
if numWorkers == 0:
    # No GPU visible: fall back to the CPU with one slot per logical core.
    # os.cpu_count() replaces the former `!cat /proc/cpuinfo | grep ... | wc -l`
    # shell pipeline, which only works on Linux and only under IPython.
    DEVICE = 'CPU'
    # cpu_count() may return None when the count cannot be determined.
    numWorkers = os.cpu_count() or 1

print(numWorkers)

In [None]:
import collections
import math
import os
import threading
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path

import numpy as np
import pandas
import talos
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt
from tensorflow.keras.optimizers import SGD

from dataset_loader import load_tf_dataset
from models_keras import get_model
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [None]:
def randomSplitClientsData(data, labels, ds_info):
    """Cut ``data``/``labels`` into equally sized, consecutive per-client chunks.

    Despite the name, nothing is shuffled here: client ``k`` receives the
    samples ``[k * n, (k + 1) * n)`` with ``n = int(len(data) / num_clients)``.
    Samples left over at the tail are dropped.

    Args:
        data: array of samples, indexed along axis 0.
        labels: array of labels, indexed along axis 0 in step with ``data``.
        ds_info: mapping providing ``'num_clients'``, ``'sample_shape'``
            (height, width, channels) and ``'num_classes'``.

    Returns:
        ``(clientsData, clientsDataLabels)``: float64 arrays of shape
        ``(num_clients, n, height, width, channels)`` and
        ``(num_clients, n, num_classes)``.
    """
    party_count = ds_info['num_clients']
    height, width, channels = ds_info['sample_shape']
    class_count = ds_info['num_classes']

    per_client = int(data.shape[0] / party_count)

    # Pre-allocated (float64) buffers that each client's chunk is copied into.
    clientsData = np.zeros((party_count, per_client, height, width, channels))
    clientsDataLabels = np.zeros((party_count, per_client, class_count))

    for client in range(party_count):
        chunk = slice(client * per_client, (client + 1) * per_client)
        clientsData[client] = data[chunk]
        clientsDataLabels[client] = labels[chunk]

    return clientsData, clientsDataLabels

def prepare_data_for_X_clients(x_train, y_train, ds_info):
    """Split the training set into per-client arrays.

    Thin wrapper that forwards to ``randomSplitClientsData`` and returns its
    ``(clientsData, clientsDataLabels)`` pair unchanged.
    """
    return randomSplitClientsData(x_train, y_train, ds_info)

In [None]:
# plot diagnostic learning curves
def summarize_diagnostics(history, params):
    """Print ``params`` and plot the loss and accuracy curves of one training run.

    Args:
        history: object whose ``.history`` mapping has the keys ``'loss'``,
            ``'val_loss'``, ``'accuracy'`` and ``'val_accuracy'`` (for example
            the value returned by ``model.fit`` in ``experiment``).
        params: hyper-parameters of the run; printed as-is for reference.
    """
    print('##########################################################')
    print(params)
    # plot loss
    plt.subplot(211)
    plt.title('MSE')
    plt.plot(history.history['loss'], color='blue', label='train')
    plt.plot(history.history['val_loss'], color='orange', label='test')
    # The curve labels are only rendered once a legend is requested.
    plt.legend()
    # plot accuracy
    plt.subplot(212)
    plt.title('Classification Accuracy')
    plt.plot(history.history['accuracy'], color='blue', label='train')
    plt.plot(history.history['val_accuracy'], color='orange', label='test')
    plt.legend()
    # Keep the lower subplot's title from overlapping the upper subplot.
    plt.tight_layout()
    plt.show()
    print('##########################################################')

In [None]:
def experiment(x_train, y_train, x_val, y_val, params):
    """Build, train and return one model for a single hyper-parameter combination.

    Passed as the ``model`` callback to ``talos.Scan`` in ``client_gridsearch``.
    Reads the module-level ``ds_info`` and appends to the module-level ``hist``
    and ``hist_params`` lists, which ``grid_search_for_X_clients`` resets.

    Args:
        x_train, y_train: training samples and labels.
        x_val, y_val: validation samples and labels.
        params: mapping with ``'learn_rate'``, ``'momentum'``, ``'epochs'`` and
            ``'batch_size'``; also forwarded to ``get_model``.

    Returns:
        ``(history, model)``: the value returned by ``fit`` and the trained model.
    """
    sgd = SGD(learning_rate=params['learn_rate'],
              momentum=params['momentum'],
              nesterov=False,
              name='SGD')

    net = get_model(params, ds_info)
    tracked_metrics = ['accuracy', tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
    net.compile(optimizer=sgd,
                loss="mean_squared_error",
                metrics=tracked_metrics,
                run_eagerly=False)

    # Stop once validation accuracy has not improved by at least 0.01 for 5 epochs.
    stopper = tf.keras.callbacks.EarlyStopping(monitor="val_accuracy",
                                               min_delta=0.01,
                                               patience=5)

    fit_kwargs = dict(x=x_train,
                      y=y_train,
                      epochs=params['epochs'],
                      batch_size=params['batch_size'],
                      callbacks=[stopper],
                      validation_data=(x_val, y_val),
                      verbose=0)
    training_log = net.fit(**fit_kwargs)

    # Keep every run's curves and parameters for later inspection.
    hist.append(training_log)
    hist_params.append(params)

    return training_log, net

In [None]:
# Guards claiming/releasing slots of the shared `workers` array, which is
# accessed concurrently by the ThreadPool threads running client_gridsearch.
_gridsearch_worker_lock = threading.Lock()


def client_gridsearch(work):
    """Run a talos grid search for one client on a free device slot.

    Intended to be mapped over a ThreadPool by ``grid_search_for_X_clients``,
    which initialises the module-level ``workers`` (1 = free, 0 = busy) and
    ``scan_res`` arrays this function reads and writes.

    Args:
        work: tuple ``(client_number, clientData, clientDataLabels, param_grid)``.

    Returns:
        None. The ``talos.Scan`` result is stored in ``scan_res[client_number]``.
    """
    client_number, clientData, clientDataLabels, param_grid = work

    # Claim a free slot atomically: without the lock two threads can both read
    # the same free index before either of them marks it busy.
    with _gridsearch_worker_lock:
        i = np.where(workers == 1)[0][0]
        workers[i] = 0

    try:
        # Distribute load across DEVICEs
        with tf.device(f"/{DEVICE}:{i}"):
            print(f"training on {DEVICE}: {i}")

            scan_results = talos.Scan(x=clientData,
                                      y=clientDataLabels,
                                      params=param_grid,
                                      model=experiment,
                                      experiment_name=f"{experiment_name}_{client_number}")
            scan_res[client_number] = scan_results
            print(f"client running on {DEVICE}: {i} finished")
    finally:
        # Always hand the slot back, even when the scan raised.
        with _gridsearch_worker_lock:
            workers[i] = 1

    return


def grid_search_for_X_clients(numClients, clientsData, clientsDataLabels, param_grid):
    """Run ``client_gridsearch`` for every client on a thread pool.

    Resets the module-level state shared with the worker threads
    (``scan_res``, ``workers``, ``hist``, ``hist_params``) before starting.

    Args:
        numClients: number of clients to process.
        clientsData, clientsDataLabels: per-client data, indexed by client number.
        param_grid: talos parameter grid; each client receives its own deep copy.

    Returns:
        The ``scan_res`` object array, one slot per client (slots start as 0).
    """
    global scan_res, workers, hist, hist_params

    # One result slot per client, one availability flag per worker slot.
    scan_res = np.zeros(numClients, dtype=object)
    workers = np.ones(numWorkers)

    hist, hist_params = [], []

    jobs = []
    for client in range(numClients):
        jobs.append((client, clientsData[client], clientsDataLabels[client], deepcopy(param_grid)))

    with ThreadPool(len(workers)) as pool:
        pool.map(client_gridsearch, jobs)

    return scan_res

In [None]:
# Guards claiming/releasing slots of the shared `workers` array, which is
# accessed concurrently by the ThreadPool threads running client_test_metrics.
_test_metrics_worker_lock = threading.Lock()


def client_test_metrics(work):
    """Train one client's model with the given parameters and evaluate it.

    Intended to be mapped over a ThreadPool by ``test_metrics_for_X_clients``,
    which initialises the module-level ``workers`` (1 = free, 0 = busy) and
    ``metrics_res`` arrays. Evaluation uses the module-level ``x_test`` /
    ``y_test`` set by ``run``.

    Args:
        work: tuple ``(client_number, clientData, clientDataLabels, avg_test_params)``
            where ``avg_test_params`` provides ``'learn_rate'``, ``'momentum'``,
            ``'epochs'`` and ``'batch_size'`` and is also passed to ``get_model``.

    Returns:
        None. The value returned by ``model.evaluate`` is stored in
        ``metrics_res[client_number]``; given the ``compile`` call below Keras
        reports it as ``[loss, accuracy, recall, precision]``.
    """
    client_number, clientData, clientDataLabels, avg_test_params = work

    # Claim a free slot atomically: without the lock two threads can both read
    # the same free index before either of them marks it busy.
    with _test_metrics_worker_lock:
        i = np.where(workers == 1)[0][0]
        workers[i] = 0

    try:
        # Distribute load across DEVICEs (i is always < numWorkers, the modulo
        # is kept only as a safety net).
        with tf.device(f"/{DEVICE}:{i%numWorkers}"):

            print(f"training on {DEVICE}: {i}")
            model = get_model(avg_test_params, ds_info)
            optimizer = SGD(learning_rate=avg_test_params['learn_rate'], momentum=avg_test_params['momentum'], nesterov=False, name='SGD')
            model.compile(optimizer=optimizer, loss="mean_squared_error", metrics=['accuracy',tf.keras.metrics.Recall(),tf.keras.metrics.Precision()])

            model.fit(x=clientData,
                      y=clientDataLabels,
                      epochs=avg_test_params['epochs'],
                      batch_size=avg_test_params['batch_size'],
                      verbose=0)

            print(f"client running on {DEVICE}: {i} finished")
            metrics = model.evaluate(x_test, y_test)
            metrics_res[client_number] = metrics
    finally:
        # Always hand the slot back, even when training or evaluation raised.
        with _test_metrics_worker_lock:
            workers[i] = 1

    return

def test_metrics_for_X_clients(numClients, clientsData, clientsDataLabels, avg_test_params):
    """Run ``client_test_metrics`` for every client on a thread pool.

    Resets the module-level ``metrics_res`` and ``workers`` arrays shared with
    the worker threads before starting.

    Args:
        numClients: number of clients to process.
        clientsData, clientsDataLabels: per-client data, indexed by client number.
        avg_test_params: parameter mapping handed (not copied) to every client.

    Returns:
        The ``metrics_res`` object array, one slot per client (slots start as 0).
    """
    global metrics_res, workers

    # One result slot per client, one availability flag per worker slot.
    metrics_res = np.zeros(numClients, dtype=object)
    workers = np.ones(numWorkers)

    jobs = []
    for client in range(numClients):
        jobs.append((client, clientsData[client], clientsDataLabels[client], avg_test_params))

    with ThreadPool(len(workers)) as pool:
        pool.map(client_test_metrics, jobs)

    return metrics_res

In [None]:
def run(params, ds, test_split, ds_info):
    """Run the per-client hyper-parameter search pipeline and print its results.

    Stages:
      1. Grid-search ``params`` on every client and write each client's five
         best rows (by ``val_accuracy``) to a CSV file.
      2. Average the best ``learn_rate`` / ``batch_size`` / ``round_epochs`` /
         ``momentum`` over all clients.
      3. Retrain every client with the averaged parameters, evaluate on the
         test split and print the metrics of the most accurate client.
      4. Grid-search the activation function with the averaged parameters and
         print the one that wins on most clients.

    Args:
        params: talos parameter grid for stage 1.
        ds: ``(x_train, y_train)`` pair.
        test_split: ``(x_test, y_test)`` pair.
        ds_info: dataset description; ``'num_clients'`` is read here and the
            mapping is forwarded to the data-splitting helper.

    Side effects:
        Sets the module-level ``x_test`` / ``y_test``, creates the directory
        ``<experiment_name>_res/res<num_clients>`` and writes CSV files into it.
        Reads the module-level ``experiment_name``.

    Returns:
        None.
    """
    # client_test_metrics reads the test split through these module globals.
    global x_test
    global y_test

    numClients = ds_info['num_clients']

    (x_train, y_train) = ds
    (x_test, y_test) = test_split

    Path(experiment_name+"_res/res"+str(numClients)).mkdir(parents=True, exist_ok=True)

    ####### Stage 1: per-client grid search
    clientsData, clientsDataLabels = prepare_data_for_X_clients(x_train, y_train, ds_info)

    res = grid_search_for_X_clients(numClients, clientsData, clientsDataLabels, params)

    # Result slots are initialised to 0; keep only those that were filled.
    big_res = [r for r in res if not r == 0]

    ## Sort dataframes: five best rows per client
    sorted_data = []
    for scan in big_res:
        sorted_data.append(scan.data.sort_values(by='val_accuracy',ascending=False).head())

    ## Write dataframes to files
    for i,df in enumerate(sorted_data):
        df.to_csv(experiment_name+"_res/res"+str(numClients)+"/res"+str(numClients)+"_client_"+str(i)+".csv")

    ####### Stage 2: average every client's best parameters
    avg_params = np.zeros(4, dtype=float) #[lr, batchsize, epochs, momentum]

    for client_data in sorted_data:
        best_row = client_data.head(1)
        avg_params[0] += best_row['learn_rate'].item()
        avg_params[1] += best_row['batch_size'].item()
        avg_params[2] += best_row['round_epochs'].item()
        avg_params[3] += best_row['momentum'].item()

    avg_params = avg_params / len(sorted_data)

    avg_lr = avg_params[0]
    # Batch size and epoch count must be integers: round the averages up.
    avg_batch_size = int(math.ceil(avg_params[1]))
    avg_epochs = int(math.ceil(avg_params[2]))
    avg_momentum = avg_params[3]

    print("\n\n")
    print("Avg lr:", avg_lr, "Avg batchsize:", avg_batch_size, "Avg epochs:", avg_epochs, "Avg momentum:",avg_momentum)
    print("\n\n")

    ####### Stage 3: retrain each client to run the test set and get best metrics
    avg_test_params = dict(learn_rate=avg_lr, batch_size=avg_batch_size, epochs=avg_epochs, act_fn='relu', momentum=avg_momentum)

    clientsData, clientsDataLabels = prepare_data_for_X_clients(x_train, y_train, ds_info)

    res = test_metrics_for_X_clients(numClients, clientsData, clientsDataLabels, avg_test_params)

    avg_test_res = [r for r in res if not r == 0]

    best_val_acc = 0.0
    best_precision = 0.0
    best_recall = 0.0

    # client_test_metrics compiles with metrics ['accuracy', Recall(), Precision()],
    # so each entry is [loss, accuracy, recall, precision]: recall is index 2
    # and precision index 3 (they were previously read the other way round).
    for client_metrics in avg_test_res:
        if client_metrics[1] >= best_val_acc:
            best_val_acc = client_metrics[1]
            best_recall = client_metrics[2]
            best_precision = client_metrics[3]

    print("Best val_acc: ", best_val_acc)
    print("Best precision: ", best_precision)
    print("Best recall: ", best_recall)


    ####### Stage 4: activation functions optimization
    act_fn = ["relu", "sigmoid", "tanh"]
    act_fn_params = dict(learn_rate=[avg_lr], batch_size=[avg_batch_size], epochs=[avg_epochs],momentum=[avg_momentum], act_fn=act_fn)

    clientsData, clientsDataLabels = prepare_data_for_X_clients(x_train, y_train, ds_info)

    res = grid_search_for_X_clients(numClients, clientsData, clientsDataLabels, act_fn_params)

    ## Sort dataframes
    act_big_res = [r for r in res if not r == 0]
    sorted_data_activation = []
    for scan in act_big_res:
        sorted_data_activation.append(scan.data.sort_values(by='val_accuracy',ascending=False))

    for client_table in sorted_data_activation:
        with pandas.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
            display(client_table)

    ## Get best activation function: count, per function, on how many clients it ranked first
    act_fn_count = np.zeros(len(act_fn), dtype=float)

    for client_data in sorted_data_activation:
        winner = client_data.head(1)['act_fn'].item()
        for i,fn in enumerate(act_fn):
            if winner == fn:
                act_fn_count[i] += 1

    best_act_fn = act_fn[np.argmax(act_fn_count)]
    print("Best activation function :", best_act_fn)

In [None]:
# Dataset to load; experiment_name is read by run() and client_gridsearch()
# to name the results directory and the talos experiments.
dataset_name = 'svhn_cropped'
experiment_name = dataset_name + "_iid"

# Grid search parameters. Other candidate grids left by the author:
#   batch_size: [32,16,8,4,2] / [32,64,128] / [8,16,32] / [256,128,64,32]
#   epochs:     [75,125]
#   learn_rate: [0.1,0.15,0.2,0.25,0.3] / [0.08, 0.1, 0.2, 0.3] / [0.001,0.01, 0.1]
params = {
    'act_fn': ["relu"],
    'batch_size': [64],
    'epochs': [125],
    'learn_rate': [0.1],
    'momentum': [0.9],
}

ds, test_split, ds_info = load_tf_dataset(dataset_name=dataset_name,
                                          decentralized=True,
                                          display=True)
# Number of clients the training data is divided between.
ds_info['num_clients'] = 10

run(params, ds, test_split, ds_info)