### Import Numpy, Tensorflow (TFF), Asyncio Packages ###

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt
import collections
import nest_asyncio
#%reload_ext tensorboard
nest_asyncio.apply()
#from tensorflow.keras.models import Sequential
#from tensorflow.keras.layers import Dense
#from tensorflow.keras.utils import to_categorical
#from tensorflow.keras.datasets import mnist

TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.


### Import PyTorch Packages ###

In [2]:
# Neural network model in pytorch defined here
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
#        self.conv1 = nn.Conv2d(1, 20, 5, 1)
#        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = torch.from_numpy(np.array(x))
#        x = F.relu(self.conv1(x))
#        x = F.max_pool2d(x, 2, 2)
#        x = F.relu(self.conv2(x))
#        x = F.max_pool2d(x, 2, 2)
#        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
#        m = nn.LogSoftmax(dim = 1)
        return x

### Import EMNIST Tensorflow Federated Dataset ###

In [3]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=True)

### Create Server Class ###

In [4]:
class server:
    
    def __init__(self, numOfClients):
        self.numOfClients = numOfClients #Num of Client
        self.clientIds = list(range(0,self.numOfClients)) #Client List
        self.updateRoundNum = 0  # Update Round Number
        self.serverModel = self.createBaseModel() #Create a Keras Model for MNIST
        self.clientActive = [] #If client is available or not based on time - day/night
        self.clientCIF = []
        self.updateFromClients=[]
        self.clientSelected = []
        self.maxClientsPerRound = 10 #This must be LESS than numOfClients, or sorting will not work properly!
        self.__serverTestData_X=[]
        self.__serverTestData_Y=[]
        self.__serverTestData= self.setServerTestDataTFF(emnist_test)
        self.predictionAcc=[]
        self.predictionLoss=[]
        
        
    def createBaseModel(self):
        return Net()
    
    def setServerTestData(self,Xtest,Ytest):  # NOT USED IN THIS CASE
        self.__serverTestData_X=Xtest
        self.__serverTestData_Y=Ytest
        
    def setServerTestDataTFF(self,serverTestData):  #Creates a test set for server
        
        BATCH_SIZE = 20
        SHUFFLE_BUFFER = 100
        PREFETCH_BUFFER= 10

        def preprocess(dataset):
            def batch_format_fn(element):
        #        Flatten a batch `pixels` and return the features as an `OrderedDict`.
                return collections.OrderedDict(
                    x=tf.reshape(element['pixels'], [-1, 784]),
                    y=tf.reshape(element['label'], [-1, 1]))

            return dataset.repeat(10).shuffle(SHUFFLE_BUFFER).batch(
              BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

        serverData = tfds.as_numpy(preprocess(serverTestData.create_tf_dataset_for_client(serverTestData.client_ids[30])))
        

        return serverData

    
    def initialBroadcast(self):
        for i in self.clientIds:
            clientName = client_list["client_"+str(self.clientIds[i])]  #Create Client Name, Using Client_List Dictionary
            #clientName = "client_"+str(self.clientIds[i])  #Create Client Name
            #print(clientName)
            clientName.setInitialModel(self.serverModel) #setInitialModel-> Method of client Class
            #eval(clientName).setInitialModel(self.serverModel) #setInitialModel-> Method of client Class
            
    def getClientActiveStatus(self):  #if the client is available for update or not randomly set in client
        self.clientActive = []
        for i in self.clientIds:
            clientName = client_list["client_"+str(self.clientIds[i])]  #Create Client Name
            #print(clientName)
            if(clientName.sendActiveStatus() == 1): #getActiveStatus() -> method of client class
                self.clientActive.append(i)
               
    def getClientCIF(self): # Get Client Importance Factor for active clients
        self.clientCIF=[]
        for i in self.clientActive:
            clientName = client_list["client_"+str(i)]
            c_cif = clientName.sendCIF()   #sendCIF() - method of client
            self.clientCIF.append(c_cif)
        #print("Clients with Acceptable CIF: ", self.clientCIF)
        
    def getClientSelected(self): # Select The Top (N = maxClientsPerRound) with the highest CIF value
        self.clientSelected = []
        # Use non-class, local variables to leave class variables clean
        
        # Filter out clients with too low of a CIF
        print("CIF of Clients:" + str(self.clientCIF))
        print("Active Clients:" + str(self.clientActive))
        print("")
        
        CIF_Threshold = 0
        clientCIF_filt = []
        clientActive_filt = []
        
        for clientCIF, clientActive in zip(self.clientCIF, self.clientActive):
            if clientCIF > CIF_Threshold: #Client CIF must be greater than
                clientCIF_filt.append(clientCIF)
                clientActive_filt.append(clientActive)

        print("Filtered CIF of Clients:" + str(clientCIF_filt))
        print("Filtered Active Clients:" + str(clientActive_filt))
        print("")

        # Sort Filtered clients        
        clientCIF, clientActive = ( list(t) for t in zip(*sorted(zip(clientCIF_filt, clientActive_filt)))) # Sort Active Clients and their CIF values, by CIF
        print("Sorted CIF of Cleints: " + str(clientCIF[-self.maxClientsPerRound:]))
        print("Sorted Active Clients: " + str(clientActive[-self.maxClientsPerRound:]))

        self.clientSelected  = clientActive[-self.maxClientsPerRound:] # Select the N=maxClientsPerRound of clients with highest CIF
        self.clientCIF = clientCIF[-self.maxClientsPerRound:] # Select the N=maxClientsPerRound of highest CIF
    
    def getModelUpdateFromClients(self):
        print("--------------------------------------------\n","Round NO:",self.updateRoundNum)
        print("Active Clients: ",self.clientActive)
        print("CIF of Active Clients: ", self.clientCIF)
        print("Selected Clients: ", self.clientSelected)
        print("Training Clients with CIF > 10")
        self.updateRoundNum +=1
        self.updateFromClients=[]
        #self.dataPointsClients=[]
        for i in range (0,len(self.clientSelected)):
            if self.clientCIF[i] >= 10:  #Select Client only if CIF > some value 10 chosen for testing ONLY
                clientName = client_list["client_"+str(int(self.clientSelected[i]))]
                self.updateFromClients.append(clientName.sendClientUpdate()) #Get a list of updates from selected clients
            
        avgModelWeights = self.computeFedAvg(self.updateFromClients)
        #print("average model weights = \n",avgModelWeights)
        #self.serverModel.set_weights(avgModelWeights)   
        
        
        new_state_dict = {}
        model_counter = 0
        # handle the conv layers part which is not changing
        for param_idx, (key_name, param) in enumerate(self.serverModel.state_dict().items()):
        #print("&"*30)
        #print("Key: {}, Weight Shape: {}, Matched weight shape: {}".format(key_name, param.size(), weights[param_idx].shape))
        #print("&"*30)
            if "conv" in key_name or "features" in key_name:
                if "weight" in key_name:
                    temp_dict = {key_name: avgModelWeights[param_idx].reshape(param.size())}
                elif "bias" in key_name:
                    temp_dict = {key_name: avgModelWeights[param_idx]}
            elif "fc" in key_name or "classifier" in key_name:
                if "weight" in key_name:
                    temp_dict = {key_name: avgModelWeights[param_idx]}
                elif "bias" in key_name:
                    temp_dict = {key_name: avgModelWeights[param_idx]}

            new_state_dict.update(temp_dict)
            
            
        self.serverModel.load_state_dict(new_state_dict)
    
    def computeFedAvg(self,updates):#Compute Federated Averaging from Available Clients
        with torch.no_grad():
            
            totalDataPoints = 0
            scaleFactor=[]
            for i in range (0,len(updates)):
                totalDataPoints += updates[i][1]   #Sum the total Data Points on All Available Clients
                scaleFactor.append(updates[i][1])  #Store individual number datapoints for client
        
        
            scaleFactor = np.array(scaleFactor)/totalDataPoints #Create the scale factor
        
            sumOfAvgWeights = []*len(updates[0][0]) 
        
            for j in range (0,len(updates[0][0])): #range of layers
                k=torch.zeros_like(updates[0][0][j])
                for i in range(0,len(updates)): #range of clients
                    k=k+updates[i][0][j]*scaleFactor[i]
                #print("ShapeofK:",k.shape)
                sumOfAvgWeights.append(k)
        
        
            return sumOfAvgWeights   
    
    def testServerModel(self):
        testData = next(iter(self.__serverTestData))
        output = self.serverModel(testData['x'])
        testData['y'] = torch.LongTensor(testData['y'])
        target = torch.max(testData['y'], 1)[0]
        loss = F.cross_entropy(output, target)
        pred = output.data.max(1, keepdim=True)[1]
        correct = pred.eq(testData['y'].view_as(pred)).sum().item()
        accuracy = 100*correct/len(testData['y'])
        print('\nLength of the test data: ',len(testData['y']))
        print("Loss: ",loss)
        print("Accuracy: ",accuracy)
        return loss, accuracy
      
    def updateAllClients(self):
        
        updatedWeights = []
        for param in self.serverModel.parameters():
            updatedWeights.append(param)
       
        for i in self.clientIds:
            clientName = client_list["client_"+str(self.clientIds[i])]
        
            clientName.setModelUpdateWeights(updatedWeights) ## This might have to be changed to access client_list dictionary


### Create Client Class ###

In [14]:
class client:
    
    def __init__(self,ID):
        self.id = ID
        self.__clientModel = Net()
        self.__clientModelWeights=[]
        self.activeStatus = 0
        self.__clientDataX=[] #Private to Client
        self.__clientDataY=[]
        self.clientCIF=0
        self.meanDistanceToReceiver = int(np.random.normal(15, 2.5, 1)) #Select random int from normal dist 
        self.roundSINR = 0
        self.propFair = 0
        self.epochs = 10
        self.lr = 0.01
        self.__clientData = self.__clientDataFromTFF(emnist_train)
    
    
    #Methods Private to client
    
    def __clientDataFromTFF(self,tffDataset):
        NUM_EPOCHS = self.epochs
        BATCH_SIZE = 100
        SHUFFLE_BUFFER = 100
        PREFETCH_BUFFER= 10

        def preprocess(dataset):
            def batch_format_fn(element):
                """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
                return collections.OrderedDict(
                    x=tf.reshape(element['pixels'], [-1, 784]),
                    y=tf.reshape(element['label'], [-1, 1]))

            return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
              BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)
        
        return tfds.as_numpy(preprocess(tffDataset.create_tf_dataset_for_client(tffDataset.client_ids[self.id])))
        
          
    def __getTrainData(self):
        return self.__clientDataX, self.__clientDataY  # Just for Testing will be expanded to select meaningful data
    
    #Public Methods
    def sendActiveStatus(self):
        self.__setActiveStatus()
        return self.activeStatus
    
    def sendCIF(self):
        self.__setCIF()
        return self.clientCIF
    
        #Methods Private to client
    def __setActiveStatus(self): 
        #self.activeStatus = int(np.random.randint(0,2,1))
        self.activeStatus = 1
    
    def __setRoundSINR(self):
        self.roundSINR = np.round((np.random.normal(self.meanDistancetoReceive, 1, 1)[0]), 2)
                               
    #def __setPropFair(self):
    #    self.propFair = self.roundSINR / self.meanSINR
        
    def __setdataQuality(self):\
        
        ## Calculate for entropy - Start ##
        
#         for values in classification_probabilities:
#           for value in values:
#             score += value * np.log(value)
#         score = score * -1

        ## Calculate for entropy - End ##
        
        self.dataQuality = 20 #len(list(self.__clientData)) # Set 20 so that CIF value will always be passing
    
    def __setCIF(self):
        #self.clientCIF = int(np.random.randint(0,50,1))  # Random Number between 0 and 49 for check purpose only
        self.__setRoundSINR()
        self.__setPropFair() #This is currently unused
        self.__setdataQuality()
        self.clientCIF = self.dataQuality
        # self.clientCIF = (-1/self.roundSINR) + self.dataQuality #Formula for computing client CIF
    
    def setInitialModel(self,model):
        self.__clientModel=model
         
    def setClientData(self,X,Y):
        self.__clientDataX = X
        self.__clientDataY = Y
   

    def sendClientUpdate(self):
        print("->",self.id,end=" ")
        optimizer = optim.SGD(self.__clientModel.parameters(), lr = self.lr)
        #print(self.__clientDataX.shape)
        #print(self.__clientDataY.shape)
        if len(self.__clientDataX) > 0:
            numOfDataPoints = len(self.__clientDataX)
            for epoch in range(self.epochs):
                optimizer.zero_grad()
                output = self.__clientModel(self.__clientDataX)
                loss = F.cross_entropy(output, torch.max(self.__clientDataY, 1)[0])
                loss.backward()
                optimizer.step()
        
        else:
            currentData = next(iter(self.__clientData))
        
            numOfDataPoints = len(currentData['x'])
            for epoch in range(self.epochs):
                optimizer.zero_grad()
                output = self.__clientModel(currentData['x'])
                currentData['y'] = torch.LongTensor(currentData['y'])
                loss = F.cross_entropy(output, torch.max(currentData['y'], 1)[0])
                loss.backward()
                optimizer.step()

            
        updatedWeights = []
        for param in self.__clientModel.parameters():
            updatedWeights.append(param)
       
      
        return (updatedWeights,numOfDataPoints)
    
    def setModelUpdateWeights(self,modelUpdateWeights):
        new_state_dict = {}
        model_counter = 0
        # handle the conv layers part which is not changing
        for param_idx, (key_name, param) in enumerate(self.__clientModel.state_dict().items()):
        #print("&"*30)
        #print("Key: {}, Weight Shape: {}, Matched weight shape: {}".format(key_name, param.size(), weights[param_idx].shape))
        #print("&"*30)
            if "conv" in key_name or "features" in key_name:
                if "weight" in key_name:
                    temp_dict = {key_name: modelUpdateWeights[param_idx].reshape(param.size())}
                elif "bias" in key_name:
                    temp_dict = {key_name: modelUpdateWeights[param_idx]}
            elif "fc" in key_name or "classifier" in key_name:
                if "weight" in key_name:
                    temp_dict = {key_name: modelUpdateWeights[param_idx]}
                elif "bias" in key_name:
                    temp_dict = {key_name: modelUpdateWeights[param_idx]}

            new_state_dict.update(temp_dict)
            
            
        self.__clientModel.load_state_dict(new_state_dict)


        
    def plotClientData(self):
        print("Y_Data",self.__clientDataY)
        id0=[]
        id1=[]
        id2=[]
        id3=[]
        id4=[]
        id5=[]
        id6=[]
        id7=[]
        id8=[]
        id9=[]
        for i in range (0,len(self.__clientData['x'])):
            for j in range (0,10):
                if j==self.__clientData['y']:
                    eval('id'+str(j)).append(i)
        xAxis=['0','1','2','3','4','5','6','7','8','9']
        yAxis=[len(id0),len(id1),len(id2),len(id3),len(id4),len(id5),len(id6),len(id7),len(id8),len(id9)]
        plt.bar(xAxis,yAxis)
        title="Data for Client"+str(self.id)
        plt.title(title)
        plt.show()

## Initializing Clients

In [15]:
client_list = {}

# Set the number of clients
num_clients = 11
for i in range(num_clients):
    name = "client_" + str(i)
    client_list[name] = client(i)

### Initializing Sever

In [16]:
# Initialize server with number of clients
serverFA = server(num_clients)

### Inital Broadcast ###

In [17]:
serverFA.initialBroadcast()

In [18]:
# serverFA.getClientActiveStatus()
# print(serverFA.clientActive)

In [19]:
# serverFA.getClientCIF()
# print(serverFA.clientCIF)

In [20]:
# serverFA.getClientSelected()
# #print(serverFA.clientSelected)

In [21]:
numberOfTraingRounds = 10
trainingMetrics=[]*numberOfTraingRounds
for i in range(0,numberOfTraingRounds):
    serverFA.getClientActiveStatus()
    serverFA.getClientCIF()
    serverFA.getClientSelected()
    serverFA.getModelUpdateFromClients()
    trainingMetrics.append(serverFA.testServerModel())
    serverFA.updateAllClients()


CIF of Clients:[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
Active Clients:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

Filtered CIF of Clients:[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
Filtered Active Clients:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

Sorted CIF of Cleints: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
Sorted Active Clients: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
--------------------------------------------
 Round NO: 0
Active Clients:  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
CIF of Active Clients:  [20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
Selected Clients:  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Training Clients with CIF > 10
-> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 -> 9 -> 10 
Length of the test data:  20
Loss:  tensor(2.2485, grad_fn=<NllLossBackward>)
Accuracy:  10.0
CIF of Clients:[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
Active Clients:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

Filtered CIF of Clients:[20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
Filtered Active Clients:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

StopIteration: 

In [None]:
plt.plot(trainingMetrics)

### Below are Rough cells ... donot run

### --> SNR - Get Value the Channel Importance - Transmit Power, Additional Gaussian Noice, Channel Rayliegh Fading (hk) 

### -->Data Augmentation at Every Round
: Client 0 - 6000 Data, For first round - 3000 - specifying the epochs, batch_size =32 (Default) model.fit()
3400 - Calculate the Entropy - CIF along with the SNR- Importance of the update Training - additional 400 data
4000

### --> Code Benchmarking use EMNIST dataset and check the code w.r.t TFF solution <- 20TH NOV

### --> Check on the leaf dataset - LSTM , CNN, 

### --> Second order optimizer - implement using adahessian < change the model pytorch

### --> comparisons with Just SNR, SNR with Data importance, with and without adahessian, 

### --> check for convergence vs number of users

### --> How will the code perform if the model changed to CNN