In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt
import collections
import nest_asyncio
%reload_ext tensorboard
nest_asyncio.apply()
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

In [2]:
class server:
    
    def __init__(self, numOfClients):
        self.numOfClients = numOfClients #Num of Client
        self.clientIds = list(range(0,self.numOfClients)) #Client List
        self.updateRoundNum = 0  # Update Round Number
        self.serverModel = self.createBaseModel() #Create a Keras Model for MNIST
        self.clientActive = [] #If client is available or not based on time - day/night
        self.clientCIF = []
        self.clientSelected = []
        self.maxClientsPerRound = 5 
        self.updateFromClients=[]
        self.__serverTestData_X=[]
        self.__serverTestData_Y=[]
        self.__serverTestData = self.setServerTestDataTFF(emnist_test)
        self.predictionAcc=[]
        self.predictionLoss=[]
        
        
    def createBaseModel(self):
        """Build the global model: one 10-way softmax layer over a 784-feature input.

        The Dense kernel is initialised to zeros, so every freshly built model
        starts from identical weights.
        """
        layerStack = [
            tf.keras.layers.Input(shape=(784,)),
            tf.keras.layers.Dense(10, kernel_initializer='zeros'),
            tf.keras.layers.Softmax(),
        ]
        return tf.keras.models.Sequential(layerStack)
    
    def setServerTestData(self,Xtest,Ytest):  # NOT USED IN THIS CASE
        self.__serverTestData_X=Xtest
        self.__serverTestData_Y=Ytest
        
    def setServerTestDataTFF(self, serverTestData, clientIndex=30, batchSize=100,
                             shuffleBuffer=100, prefetchBuffer=10, repeatCount=10):
        """Build the server-side evaluation pipeline from one client's data.

        Args:
            serverTestData: object exposing `client_ids` and
                `create_tf_dataset_for_client(client_id)` (presumably a TFF
                ClientData test split; confirm against the caller).
            clientIndex: position in `serverTestData.client_ids` of the client whose
                data is used. The default 30 reproduces the previous hard-coded choice.
            batchSize, shuffleBuffer, prefetchBuffer, repeatCount: tf.data settings.
                The defaults reproduce the previous hard-coded values.

        Returns:
            `tfds.as_numpy(...)` of the pipeline. Each batch is an OrderedDict with
            `x` = pixels reshaped to [-1, 784] and `y` = labels reshaped to [-1, 1].
        """
        def batch_format_fn(element):
            """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
            return collections.OrderedDict(
                x=tf.reshape(element['pixels'], [-1, 784]),
                y=tf.reshape(element['label'], [-1, 1]))

        def preprocess(dataset):
            # NOTE(review): repeat + shuffle on an evaluation set means each fresh
            # iterator likely yields a different random first batch; kept for compatibility.
            return dataset.repeat(repeatCount).shuffle(shuffleBuffer).batch(
                batchSize).map(batch_format_fn).prefetch(prefetchBuffer)

        clientId = serverTestData.client_ids[clientIndex]
        clientDataset = serverTestData.create_tf_dataset_for_client(clientId)
        return tfds.as_numpy(preprocess(clientDataset))
     
        
    def initialBroadcast(self):
        for i in self.clientIds:
            clientName = client_list["client_"+str(self.clientIds[i])]  #Create Client Name, Using Client_List Dictionary
            #clientName = "client_"+str(self.clientIds[i])  #Create Client Name
            #print(clientName)
            clientName.setInitialModel(self.serverModel) #setInitialModel-> Method of client Class
            #eval(clientName).setInitialModel(self.serverModel) #setInitialModel-> Method of client Class
            
    def getClientActiveStatus(self):  #if the client is available for update or not randomly set in client
        self.clientActive = []
        for i in self.clientIds:
            clientName = client_list["client_"+str(self.clientIds[i])]  #Create Client Name
            #print(clientName)
            if(clientName.sendActiveStatus() == 1): #getActiveStatus() -> method of client class
                self.clientActive.append(i)
        print("Active Clients in Round :--: ", self.clientActive)
    
    
    def getClientCIF(self): # Get Client Importance Factor for active clients
        self.clientCIF=[]
        for i in self.clientActive:
            clientName = client_list["client_"+str(i)]
            c_cif = clientName.sendCIF()   #sendCIF() - method of client
            self.clientCIF.append(c_cif)
        print("Clients with Acceptable CIF: ", self.clientCIF)
        
    def getClientSelected(self): # Select The Top (N = maxClientsPerRound) with the highest CIF value
        self.clientSelected = []
        # Use non-class, local variables to leave class variables clean
        clientCIF, clientActive = ( list(t) for t in zip(*sorted(zip(self.clientCIF, self.clientActive)))) # Sort Active Clients and their CIF values, by CIF
        self.clientSelected  = clientActive[-self.maxClientsPerRound:] # Select the N=maxClientsPerRound of clients with highest CIF
        self.clientCIF = clientCIF[-self.maxClientsPerRound:] # Select the N=maxClientsPerRound of highest CIF
            
    
    def getModelUpdateFromClients(self):
        print("--------------------------------------------\n","Round NO:",self.updateRoundNum)
        print("Active Clients: ",self.clientActive)
        print("CIF of Active Clients: ", self.clientCIF)
        print("Selected Clients: ", self.clientSelected)
        print("Training Clients with CIF > 10")
        self.updateRoundNum +=1
        self.updateFromClients=[]
        #self.dataPointsClients=[]
        
        for i in range (0,len(self.clientSelected)):
            if self.clientCIF[i] >= 10:  #Select Client only if CIF > some value 10 chosen for testing ONLY
                clientName = client_list["client_"+str(int(self.clientSelected[i]))]
                self.updateFromClients.append(clientName.sendClientUpdate()) #Get a list of updates from selected clients
            
        avgModelWeights = self.computeFedAvg(self.updateFromClients)
        #print("average model weights = \n",avgModelWeights)
        self.serverModel.set_weights(avgModelWeights)   
        
    def computeFedAvg(self,updates):  #Compute Federated Averaging from Available Clients
        totalDataPoints = 0
        scaleFactor=[]
        for i in range (0,len(updates)):
            totalDataPoints += updates[i][1]   #Sum the total Data Points on All Available Clients
            scaleFactor.append(updates[i][1])  #Store individual number datapoints for clients

        scaleFactor = np.array(scaleFactor)/totalDataPoints #Create the scale factor
       
        sumOfAvgWeights = []*len(updates[0][0]) 
        
        for j in range (0,len(updates[0][0])): #range of layers
            k=np.zeros_like(updates[0][0][j])
            #print("ShapeofK:",k.shape)
            for i in range(0,len(updates)): #range of clients
                k=k+updates[i][0][j]*scaleFactor[i]
                #print("ShapeofK:",k.shape)
            sumOfAvgWeights.append(k)
      
        return sumOfAvgWeights   
    
    def testServerModel(self):
        """Evaluate the global model on the first batch of the server test data.

        The model is (re)compiled on every call with sparse categorical
        cross-entropy, SGD(lr=0.01) and an 'accuracy' metric. The value returned is
        whatever Keras `evaluate` returns for that compilation.
        """
        sgdOptimizer = tf.keras.optimizers.SGD(learning_rate=.01)
        self.serverModel.compile(loss='sparse_categorical_crossentropy',
                                 optimizer=sgdOptimizer,
                                 metrics=['accuracy'])
        firstBatch = next(iter(self.__serverTestData))
        features, labels = firstBatch['x'], firstBatch['y']
        return self.serverModel.evaluate(features, labels)
        
        
    def updateAllClients(self):
        for i in self.clientIds:
            clientName = client_list["client_"+str(self.clientIds[i])]  #Create Client Name
            clientName.setModelUpdateWeights(self.serverModel.get_weights()) 