<a href="https://colab.research.google.com/github/romoreira/distributed_learning/blob/master/TRAILS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tensorflow



In [3]:
import os
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

from tensorflow.keras.datasets import cifar10
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.backend import image_data_format

import logging
import threading
import time

import matplotlib.pyplot as plt
import numpy as np
import copy
import random
import sys

from google.colab import drive
drive.mount('/content/gdrive/')
import sys
sys.path.append('/content/gdrive/MyDrive/Colab Notebooks')

from build_model import Model
import csv

# Client configuration
NUMOFCLIENTS = 2  # number of federated clients (one model and one thread each)
SELECT_CLIENTS = 0.5  # client fraction "c"; used in log messages and the CSV file name
EPOCHS = 1  # number of total iterations
CLIENT_EPOCHS = 5  # number of each client's iterations
BATCH_SIZE = 10  # size of the batches to train and evaluate on
DROP_RATE = 0  # drop percentage; recorded in the CSV file name

# One entry per federated client, filled in by the main script.
federated_clients_as_threads = []

# Model configuration
LOSS = 'categorical_crossentropy'  # loss function
NUMOFCLASSES = 10  # number of classes
lr = 0.0025
# OPTIMIZER = SGD(lr=0.015, decay=0.01, nesterov=False)
# `learning_rate` replaces the deprecated `lr` keyword, which triggers a
# deprecation warning from the optimizer constructor.
OPTIMIZER = SGD(learning_rate=lr, momentum=0.9, decay=lr / (EPOCHS * CLIENT_EPOCHS), nesterov=False)  # lr = 0.015, 67 ~ 69%

def utf8len(s):
    """Return the number of bytes ``s`` occupies once encoded as UTF-8."""
    encoded = s.encode('utf-8')
    return len(encoded)

def write_csv(method_name, list):
    """Write rows to a CSV file whose name encodes the experiment settings.

    The file name is built from ``method_name`` and the module-level
    constants (DROP_RATE, SELECT_CLIENTS, lr, NUMOFCLIENTS, CLIENT_EPOCHS,
    EPOCHS, BATCH_SIZE). The file is created, or overwritten, inside the
    mounted Google Drive notebook folder.

    Args:
        method_name: Prefix of the output file name (e.g. "FedAvg").
        list: Iterable of rows; each row is an iterable of cell values.
            The name shadows the builtin but is kept so that existing
            keyword callers keep working.
    """
    rows = list

    file_name = '{name}_CIFAR10_randomDrop_{drop}%_output_C_{c}_LR_{lr}_CLI_{cli}_CLI_EPOCHS_{cli_epoch}_TOTAL_EPOCHS_{epochs}_BATCH_{batch}.csv'
    file_name = file_name.format(drop=DROP_RATE, name=method_name, c=SELECT_CLIENTS, lr=lr, cli=NUMOFCLIENTS, cli_epoch=CLIENT_EPOCHS, epochs=EPOCHS, batch=BATCH_SIZE)

    save_path = "/content/gdrive/MyDrive/Colab Notebooks"
    complete_name = os.path.join(save_path, file_name)

    # The context manager closes the file even if a row fails to serialise.
    with open(complete_name, 'w', encoding='utf-8', newline='') as f:
        csv.writer(f).writerows(rows)


def load_dataset():
    """Load CIFAR-10 and prepare it for training.

    Images are cast to float32 and divided by 255.0; labels are passed
    through ``to_categorical``.

    Returns:
        ``((X_train, Y_train), (X_test, Y_test))``.
    """
    # Code for experimenting with CIFAR-10 datasets.
    train_split, test_split = cifar10.load_data()

    # Code for experimenting with MNIST datasets.
    # train_split, test_split = mnist.load_data()
    # followed by reshaping both image arrays to (N, 28, 28, 1).

    def _scale(images):
        # Cast to float32, then divide every pixel value by 255.0.
        return images.astype('float32') / 255.0

    X_train = _scale(train_split[0])
    X_test = _scale(test_split[0])

    Y_train = to_categorical(train_split[1])
    Y_test = to_categorical(test_split[1])

    return (X_train, Y_train), (X_test, Y_test)


def init_model(train_data_shape):
    """Build a model for inputs of shape ``train_data_shape``.

    Each call creates its own optimizer instance, cloned from the
    module-level ``OPTIMIZER`` configuration. Passing the same optimizer
    object to every model would make the server and all clients share
    optimizer state while they train in separate threads.

    Args:
        train_data_shape: Shape of one input sample, e.g. ``(32, 32, 3)``.

    Returns:
        Whatever ``Model.fl_paper_model`` returns for that shape.
    """
    print("Data Shape: "+str(train_data_shape))

    # Fresh optimizer with the same hyper-parameters as OPTIMIZER.
    optimizer = type(OPTIMIZER).from_config(OPTIMIZER.get_config())

    model = Model(loss=LOSS, optimizer=optimizer, classes=NUMOFCLASSES)
    fl_model = model.fl_paper_model(train_shape=train_data_shape)

    return fl_model


def client_data_config(x_train, y_train, num_clients=None):
    """Draw one random training subset per client.

    Every client receives ``len(x_train) // num_clients`` samples drawn
    uniformly without replacement. Each client is sampled independently,
    so subsets of different clients may overlap.

    Args:
        x_train: Array of training inputs, indexed along axis 0.
        y_train: Labels aligned with ``x_train`` along axis 0.
        num_clients: Number of subsets to build. Defaults to the
            module-level ``NUMOFCLIENTS``.

    Returns:
        A list with ``num_clients`` tuples ``(x_subset, y_subset)``.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients is None:
        num_clients = NUMOFCLIENTS
    if num_clients < 1:
        raise ValueError("num_clients must be at least 1, got {}".format(num_clients))

    x_train = np.asarray(x_train)
    y_train = np.asarray(y_train)

    total = x_train.shape[0]
    num_of_each_dataset = total // num_clients

    print("Size of x_train: "+str(len(x_train)))

    print("Num_of_each_dataset: "+str(num_of_each_dataset))

    client_data = []
    for _ in range(num_clients):
        # random.sample returns distinct indices directly, which avoids the
        # quadratic "draw, then scan the list for duplicates" loop.
        picked = np.asarray(random.sample(range(total), num_of_each_dataset), dtype=np.intp)
        client_data.append((x_train[picked], y_train[picked]))

    return client_data


def fedAVG(server_weight):
    """Average the weights collected from several clients, layer by layer.

    Averaging each layer separately avoids packing layers of different
    shapes into a single ragged NumPy array. The inputs are not modified.

    Args:
        server_weight: Non-empty sequence with one entry per client; each
            entry is a sequence of per-layer weights (arrays or numbers)
            in the same order and with the same shapes for every client.

    Returns:
        A NumPy array whose element ``i`` is the mean of layer ``i``. If
        all layers share one shape the result is a regular numeric array;
        otherwise it is a 1-D object array holding one array per layer.

    Raises:
        ValueError: If ``server_weight`` is empty or the clients do not
            all provide the same number of layers.
    """
    num_models = len(server_weight)
    if num_models == 0:
        raise ValueError("fedAVG needs at least one set of client weights")

    print("len(server_weight): "+str(num_models))

    num_layers = len(server_weight[0])
    for weights in server_weight:
        if len(weights) != num_layers:
            raise ValueError("all clients must provide the same number of layers")

    averaged = [
        np.mean([np.asarray(layer) for layer in layer_group], axis=0)
        for layer_group in zip(*server_weight)
    ]

    # Layers with one common shape can be stacked into a plain numeric array.
    if len({np.shape(layer) for layer in averaged}) <= 1:
        return np.array(averaged)

    avg_weight = np.empty(num_layers, dtype=object)
    for position, layer in enumerate(averaged):
        avg_weight[position] = layer

    return avg_weight


def client_update(index, client, avg_weight, x_test, y_test):
    """Train one client model, evaluate it and report the scores.

    Runs as the target of a client thread. The training data is read from
    the module-level ``client_data`` list.

    Args:
        index: Position of the client in ``client_data`` and in
            ``federated_clients_as_threads``.
        client: Model to train; ``fit`` is called on it here.
        avg_weight: Accepted but not used by this function.
        x_test: Inputs used to evaluate the trained client.
        y_test: Labels used to evaluate the trained client.
    """
    selected_total = int(NUMOFCLIENTS * SELECT_CLIENTS)
    print("Fed_Client Thread {}/{} fitting\n".format(index + 1, selected_total))

    local_data = client_data[index]
    # NOTE(review): training runs for EPOCHS, not CLIENT_EPOCHS - confirm
    # that this is intended.
    fit_options = dict(
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=1,
        validation_split=0.2,
    )
    client.fit(local_data[0], local_data[1], **fit_options)

    client_label = str(index)
    print("\n\n\nEnd of training Thread: "+client_label)
    print("\n\n\n Validating the client training: "+client_label)

    scores = fed_client_evaluation(client, x_test, y_test)
    print("\n\nThe accuracy of Client "+client_label+" is: "+str(scores)+" \n\n\n\n")

    send_model_to_server(scores, index)

def fed_client_evaluation(model, x_test, y_test):
    """Evaluate ``model`` on the test data and return the result.

    Calls ``model.evaluate`` with the module-level ``BATCH_SIZE`` and
    verbose output, and returns whatever that call returns.
    """
    scores = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=1)
    return scores


def send_model_to_server(metric_scores, client_index):
    """Append a client's evaluation scores to its shared record.

    The record is the entry at ``client_index`` in the module-level
    ``federated_clients_as_threads`` list.
    """
    client_record = federated_clients_as_threads[client_index]
    client_record.append(metric_scores)

def waiter():
    """Print the shared client list every 10 seconds, forever.

    The loop has no exit condition and never returns.
    """
    while True:
        status = str(federated_clients_as_threads)
        print("\n\n\n\nCurrent Status of Federated Client Set: "+status)
        time.sleep(10)


def fed_client_selection(policy):
    """Placeholder for the client-selection step.

    Only the "TRAILS" policy is recognised, and for now it just prints a
    message (in Portuguese: "Decide which client to choose"). Every other
    policy is ignored. Always returns ``None``.
    """
    if policy != "TRAILS":
        return

    print("Decidir qual cliente escolher")
    # Reference: https://github.com/dnanhkhoa/simple-bloom-filter

if __name__ == "__main__":
    # Script entry point: build the server model and one model per client,
    # train every client in its own thread, then print the collected records.

    (x_train, y_train), (x_test, y_test) = load_dataset()

    # One record per client: [thread, client index, EPOCHS, <scores appended later>].
    federated_clients_as_threads = []

    server_model = init_model(train_data_shape=x_train.shape[1:])
    server_model.summary()

    # client_update() reads this module-level list by index.
    client_data = client_data_config(x_train, y_train)
    print("Client_data: "+str(len(client_data)))
    fl_models = []
    for i in range(NUMOFCLIENTS):
        fl_models.append(init_model(train_data_shape=client_data[i][0].shape[1:]))

    # Zero every layer separately: the layers have different shapes, so the
    # weight list cannot be treated as one rectangular array.
    avg_weight = [np.zeros_like(layer) for layer in server_model.get_weights()]
    print("AVG_Weight: "+str(avg_weight))
    server_evaluate_acc = []

    print("NUMOFCLIENTS: "+str(NUMOFCLIENTS))
    print("Select_clients: "+str(SELECT_CLIENTS))

    # Background status printer. waiter() never returns, so this thread is a
    # daemon and must not be joined. It gets its own name so that the
    # waiter function is not overwritten.
    status_thread = threading.Thread(target=waiter, args=(), daemon=True)
    status_thread.start()

    for index, client in enumerate(fl_models):
        worker = threading.Thread(
            target=client_update,
            args=(index, client, avg_weight, x_test, y_test),
        )
        federated_clients_as_threads.append([worker, index, EPOCHS])

    for client_record in federated_clients_as_threads:
        client_record[0].start()

    # Wait for every client to finish training and to report its scores.
    for client_record in federated_clients_as_threads:
        client_record[0].join()

    print(federated_clients_as_threads)

    # TODO: server-side aggregation is not wired in yet. Sketch kept from
    # the earlier sequential version:
    #
    #            recv_model = client_update(index, client, epoch, avg_weight)
    #            evaluation = fed_client_evaluation(recv_model)
    #            fed_client_selection(evaluation)
    #
    #            print("Size of RECV_MODEL: "+str(utf8len(str(recv_model))))
    #
    #            rand = random.randint(0,99)
    #            drop_communication = range(DROP_RATE)
    #            if rand not in drop_communication:
    #                server_weight.append(copy.deepcopy(recv_model.get_weights()))
    #
    #        avg_weight = fedAVG(server_weight)
    #        server_model.set_weights(avg_weight)
    #        print("server {}/{} evaluate".format(epoch + 1, EPOCHS))
    #        server_evaluate_acc.append(server_model.evaluate(x_test, y_test, batch_size=BATCH_SIZE, verbose=1))
    #
    #    write_csv("FedAvg", server_evaluate_acc)
      
    
  

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


  super(SGD, self).__init__(name, **kwargs)


Data Shape: (32, 32, 3)
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 32, 32, 32)        2432      
                                                                 
 conv2d_1 (Conv2D)           (None, 32, 32, 32)        25632     
                                                                 
 max_pooling2d (MaxPooling2D  (None, 16, 16, 32)       0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 16, 16, 32)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 16, 16, 64)        51264     
                                                                 
 conv2d_3 (Conv2D)           (None, 16, 16, 64)        102464    
                                



 262/2000 [==>...........................] - ETA: 38s - loss: 11.6709 - accuracy: 0.1603



Current Status of Federated Client Set: [[<Thread(Thread-12, started 140533439444736)>, 0, 1], [<Thread(Thread-13, started 140533431052032)>, 1, 1]]



Current Status of Federated Client Set: [[<Thread(Thread-12, started 140533439444736)>, 0, 1], [<Thread(Thread-13, started 140533431052032)>, 1, 1]]



Current Status of Federated Client Set: [[<Thread(Thread-12, started 140533439444736)>, 0, 1], [<Thread(Thread-13, started 140533431052032)>, 1, 1]]



Current Status of Federated Client Set: [[<Thread(Thread-12, started 140533439444736)>, 0, 1], [<Thread(Thread-13, started 140533431052032)>, 1, 1]]



Current Status of Federated Client Set: [[<Thread(Thread-12, started 140533439444736)>, 0, 1], [<Thread(Thread-13, started 140533431052032)>, 1, 1]]




Current Status of Federated Client Set: [[<Thread(Thread-12, started 140533439444736)>, 0, 1], [<Thread(Thread-13, started 140533431052032)>, 1, 1]

KeyboardInterrupt: ignored