In [None]:
# Importing all the necessary libraries
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import os, cv2
import tensorflow as tf

from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation
from tensorflow.keras.optimizers import SGD
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Importing the data

data = pd.read_csv('/content/drive/MyDrive/ham_10000/HAM10000_28.28_L.csv')

Visualising Dataset

In [None]:
data

Unnamed: 0,pixel0000,pixel0001,pixel0002,pixel0003,pixel0004,pixel0005,pixel0006,pixel0007,pixel0008,pixel0009,...,pixel0775,pixel0776,pixel0777,pixel0778,pixel0779,pixel0780,pixel0781,pixel0782,pixel0783,label
0,169,171,170,177,181,182,181,185,194,192,...,184,186,185,180,157,140,140,159,165,2
1,19,57,105,140,149,148,144,155,170,170,...,172,175,160,144,114,89,47,18,18,2
2,155,163,161,167,167,172,155,152,165,175,...,163,178,157,166,167,148,141,136,115,2
3,25,71,116,139,136,153,148,161,172,162,...,125,135,138,137,111,71,32,16,16,2
4,129,162,181,196,205,208,205,213,225,224,...,210,197,172,190,195,193,181,147,88,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10010,172,171,173,175,164,187,207,210,208,206,...,210,217,221,209,185,187,192,192,192,0
10011,2,34,108,116,114,119,131,139,139,145,...,173,169,168,168,143,138,83,23,3,0
10012,122,154,162,170,179,197,200,195,202,199,...,221,215,205,187,209,198,187,164,156,0
10013,137,143,141,139,147,152,155,152,155,159,...,172,171,175,183,177,170,169,166,170,0


# Preprocessing dataset

In [None]:
data.isna().sum().sum()

0

In [None]:
y = data['label'].copy()
X = data.drop('label', axis=1).copy()

There are 7 possible labels.

From the dataset provider:

0: nv - Melanocytic nevi
1: mel - Melanoma
2: bkl - Benign keratosis-like lesions
3: bcc - Basal cell carcinoma
4: akiec - Actinic keratoses and intraepithelial carcinoma / Bowen's disease
5: vasc - Vascular lesions
6: df - Dermatofibroma

In [None]:
y.value_counts()

4    6705
6    1113
2    1099
1     514
0     327
5     142
3     115
Name: label, dtype: int64

In [None]:
label_mapping = {
    0: 'nv',
    1: 'mel',
    2: 'bkl',
    3: 'bcc',
    4: 'akiec',
    5: 'vasc',
    6: 'df'
}

Rescaling

In [None]:
X = X / 255

X

Unnamed: 0,pixel0000,pixel0001,pixel0002,pixel0003,pixel0004,pixel0005,pixel0006,pixel0007,pixel0008,pixel0009,...,pixel0774,pixel0775,pixel0776,pixel0777,pixel0778,pixel0779,pixel0780,pixel0781,pixel0782,pixel0783
0,0.662745,0.670588,0.666667,0.694118,0.709804,0.713725,0.709804,0.725490,0.760784,0.752941,...,0.690196,0.721569,0.729412,0.725490,0.705882,0.615686,0.549020,0.549020,0.623529,0.647059
1,0.074510,0.223529,0.411765,0.549020,0.584314,0.580392,0.564706,0.607843,0.666667,0.666667,...,0.709804,0.674510,0.686275,0.627451,0.564706,0.447059,0.349020,0.184314,0.070588,0.070588
2,0.607843,0.639216,0.631373,0.654902,0.654902,0.674510,0.607843,0.596078,0.647059,0.686275,...,0.635294,0.639216,0.698039,0.615686,0.650980,0.654902,0.580392,0.552941,0.533333,0.450980
3,0.098039,0.278431,0.454902,0.545098,0.533333,0.600000,0.580392,0.631373,0.674510,0.635294,...,0.556863,0.490196,0.529412,0.541176,0.537255,0.435294,0.278431,0.125490,0.062745,0.062745
4,0.505882,0.635294,0.709804,0.768627,0.803922,0.815686,0.803922,0.835294,0.882353,0.878431,...,0.827451,0.823529,0.772549,0.674510,0.745098,0.764706,0.756863,0.709804,0.576471,0.345098
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10010,0.674510,0.670588,0.678431,0.686275,0.643137,0.733333,0.811765,0.823529,0.815686,0.807843,...,0.823529,0.823529,0.850980,0.866667,0.819608,0.725490,0.733333,0.752941,0.752941,0.752941
10011,0.007843,0.133333,0.423529,0.454902,0.447059,0.466667,0.513725,0.545098,0.545098,0.568627,...,0.721569,0.678431,0.662745,0.658824,0.658824,0.560784,0.541176,0.325490,0.090196,0.011765
10012,0.478431,0.603922,0.635294,0.666667,0.701961,0.772549,0.784314,0.764706,0.792157,0.780392,...,0.874510,0.866667,0.843137,0.803922,0.733333,0.819608,0.776471,0.733333,0.643137,0.611765
10013,0.537255,0.560784,0.552941,0.545098,0.576471,0.596078,0.607843,0.596078,0.607843,0.623529,...,0.698039,0.674510,0.670588,0.686275,0.717647,0.694118,0.666667,0.662745,0.650980,0.666667


In [None]:
X = np.array(X)

Binarizing the labels to 1-hot-encode the labels

In [None]:
#binarize the labels
lb = LabelBinarizer()
y = lb.fit_transform(y)

In [None]:
y

array([[0, 0, 1, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       ...,
       [1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 1]])

In [None]:
print(X.shape)
print(y.shape)

(10015, 784)
(10015, 7)


Splitting into train and test

In [None]:
X_train, X_test, Y_train, Y_test = train_test_split(X,y, test_size=0.2, random_state=1)

In [None]:
def create_clients(X_train, Y_train, num_clients=10, initial='clients'):
    #create a list of client names
    client_names = ['{}_{}'.format(initial, i+1) for i in range(num_clients)]

    #randomize the data
    data = list(zip(X_train, Y_train))
    random.shuffle(data)

    #shard data and place at each client
    size = len(data)//num_clients
    shards = [data[i:i + size] for i in range(0, size*num_clients, size)]

    #number of clients must equal number of shards
    assert(len(shards) == len(client_names))

    return {client_names[i] : shards[i] for i in range(len(client_names))} 

In [None]:
#create clients
clients = create_clients(X_train, Y_train, num_clients=10, initial='client')

In [None]:
def batch_data(data_shard, bs=32):
    #seperate shard into data and labels lists
    data, label = zip(*data_shard)
    dataset = tf.data.Dataset.from_tensor_slices((list(data), list(label)))
    return dataset.shuffle(len(label)).batch(bs)

In [None]:
#process and batch the training data for each client
clients_batched = dict()
for (client_name, data) in clients.items():
    clients_batched[client_name] = batch_data(data)
    
#process and batch the test set  
test_batched = tf.data.Dataset.from_tensor_slices((X_test, Y_test)).batch(len(Y_test))

In [None]:
class SimpleMLP:
    @staticmethod
    def build(shape, classes):
        model = Sequential()
        model.add(Dense(200, input_shape=(784,)))
        model.add(Activation("relu"))
        model.add(Dense(200))
        model.add(Activation("relu"))
        model.add(Dense(7))
        model.add(Activation("softmax"))
        return model

In [None]:
learning_rate = 0.01 
comms_round = 100
loss='categorical_crossentropy'
metrics = ['accuracy']
optimizer = SGD(learning_rate=learning_rate, 
                decay=learning_rate/ comms_round, 
                momentum=0.9
               )                      

In [None]:
def weight_scalling_factor(clients_trn_data, client_name):
    client_names = list(clients_trn_data.keys())
    #get the bs
    bs = list(clients_trn_data[client_name])[0][0].shape[0]
    #first calculate the total training data points across clinets
    global_count = sum([tf.data.experimental.cardinality(clients_trn_data[client_name]).numpy() for client_name in client_names])*bs
    # get the total number of data points held by a client
    local_count = tf.data.experimental.cardinality(clients_trn_data[client_name]).numpy()*bs
    return local_count/global_count


def scale_model_weights(weight, scalar):
    '''function for scaling a models weights'''
    weight_final = []
    steps = len(weight)
    for i in range(steps):
        weight_final.append(scalar * weight[i])
    return weight_final

In [None]:
def sum_scaled_weights(scaled_weight_list):
    '''Return the sum of the listed scaled weights. The is equivalent to scaled avg of the weights'''
    avg_grad = list()
    #get the average grad accross all client gradients
    for grad_list_tuple in zip(*scaled_weight_list):
        layer_mean = tf.math.reduce_sum(grad_list_tuple, axis=0)
        avg_grad.append(layer_mean)
        
    return avg_grad


def test_model(X_test, Y_test,  model, comm_round):
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    #logits = model.predict(X_test, batch_size=100)
    logits = model.predict(X_test)
    loss = cce(Y_test, logits)
    acc = accuracy_score(tf.argmax(logits, axis=1), tf.argmax(Y_test, axis=1))
    print('comm_round: {} | global_acc: {:.3%} | global_loss: {}'.format(comm_round, acc, loss))
    return acc, loss

In [None]:
def fl_model(clients_batched, test_batched, comms_round, loss, optimizer, metrics):
    #initialize global model
    smlp_global = SimpleMLP()
    global_model = smlp_global.build(784,7)

    #commence global training loop
    for comm_round in range(comms_round):
            
        # get the global model's weights - will serve as the initial weights for all local models
        global_weights = global_model.get_weights()
    
        #initial list to collect local model weights after scalling
        scaled_local_weight_list = list()

        #randomize client data - using keys
        client_names= list(clients_batched.keys())
        random.shuffle(client_names)
    
        #loop through each client and create new local model
        for client in client_names:
            smlp_local = SimpleMLP()
            local_model = smlp_local.build(784, 7)
            local_model.compile(loss=loss, 
                          optimizer=optimizer, 
                          metrics=metrics)
        
            #set local model weight to the weight of the global model
            local_model.set_weights(global_weights)
        
            #fit local model with client's data
            local_model.fit(clients_batched[client], epochs=1, verbose=0)
        
            #scale the model weights and add to list
            scaling_factor = weight_scalling_factor(clients_batched, client)
            scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
            scaled_local_weight_list.append(scaled_weights)
        
            #clear session to free memory after each communication round
            K.clear_session()
        
        #to get the average over all the local model, we simply take the sum of the scaled weights
        average_weights = sum_scaled_weights(scaled_local_weight_list)
    
        #update global model 
        global_model.set_weights(average_weights)

        #test global model and print out metrics after each communications round
        for(X_test, Y_test) in test_batched:
            global_acc, global_loss = test_model(X_test, Y_test, global_model, comm_round)
            
            # Make predictions on the test set
            y_pred = global_model.predict(X_test)
            y_pred = np.argmax(y_pred, axis=1)
            
            # Get the true classes
            true_classes = np.argmax(Y_test, axis = 1)
      
             # Print the classification report
            report = classification_report(true_classes, y_pred)
            from sklearn.metrics import confusion_matrix
            confusion_matrix = confusion_matrix(true_classes, y_pred)
            if comm_round == 99:
              print(confusion_matrix)
              plt.show()
              print(report)

In [None]:
fl_model(clients_batched, test_batched, comms_round, loss, optimizer, metrics)

comm_round: 0 | global_acc: 66.251% | global_loss: 1.6503382921218872


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 1 | global_acc: 66.251% | global_loss: 1.6510295867919922


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 2 | global_acc: 66.251% | global_loss: 1.6228222846984863


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 3 | global_acc: 66.151% | global_loss: 1.6392771005630493


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 4 | global_acc: 66.151% | global_loss: 1.6379889249801636


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 5 | global_acc: 66.151% | global_loss: 1.628819465637207


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 6 | global_acc: 66.151% | global_loss: 1.6079626083374023


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 7 | global_acc: 66.151% | global_loss: 1.604748010635376


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 8 | global_acc: 66.151% | global_loss: 1.620723009109497


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 9 | global_acc: 66.151% | global_loss: 1.6512597799301147


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 10 | global_acc: 66.151% | global_loss: 1.5875557661056519


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 11 | global_acc: 66.151% | global_loss: 1.6174336671829224


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 12 | global_acc: 66.201% | global_loss: 1.6386327743530273


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 13 | global_acc: 66.201% | global_loss: 1.6042187213897705


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 14 | global_acc: 66.301% | global_loss: 1.6107243299484253


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 15 | global_acc: 66.301% | global_loss: 1.5971627235412598


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 16 | global_acc: 66.251% | global_loss: 1.611230731010437


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 17 | global_acc: 66.251% | global_loss: 1.6039305925369263


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 18 | global_acc: 66.201% | global_loss: 1.6027973890304565


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 19 | global_acc: 66.251% | global_loss: 1.62092924118042


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 20 | global_acc: 66.251% | global_loss: 1.5995558500289917


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 21 | global_acc: 66.201% | global_loss: 1.595375418663025


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 22 | global_acc: 66.251% | global_loss: 1.6118409633636475


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 23 | global_acc: 66.301% | global_loss: 1.5955439805984497


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 24 | global_acc: 66.450% | global_loss: 1.6183533668518066


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 25 | global_acc: 66.400% | global_loss: 1.6100306510925293


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 26 | global_acc: 66.750% | global_loss: 1.6276602745056152


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 27 | global_acc: 66.550% | global_loss: 1.6182169914245605


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 28 | global_acc: 66.550% | global_loss: 1.59121835231781


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 29 | global_acc: 66.400% | global_loss: 1.6097217798233032


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 30 | global_acc: 66.650% | global_loss: 1.6123669147491455


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 31 | global_acc: 66.800% | global_loss: 1.666329264640808


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 32 | global_acc: 66.500% | global_loss: 1.6274524927139282


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 33 | global_acc: 66.400% | global_loss: 1.5707502365112305


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 34 | global_acc: 66.400% | global_loss: 1.571716547012329


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 35 | global_acc: 66.650% | global_loss: 1.6077755689620972


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 36 | global_acc: 66.400% | global_loss: 1.5711044073104858


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 37 | global_acc: 66.500% | global_loss: 1.571199655532837


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 38 | global_acc: 66.700% | global_loss: 1.597794532775879


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 39 | global_acc: 66.700% | global_loss: 1.587345838546753


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 40 | global_acc: 66.450% | global_loss: 1.5677292346954346


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 41 | global_acc: 66.650% | global_loss: 1.6060558557510376


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 42 | global_acc: 67.049% | global_loss: 1.6309257745742798


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 43 | global_acc: 66.700% | global_loss: 1.605690836906433


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 44 | global_acc: 66.650% | global_loss: 1.6000574827194214


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 45 | global_acc: 66.550% | global_loss: 1.5938482284545898


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 46 | global_acc: 66.800% | global_loss: 1.5912338495254517


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 47 | global_acc: 66.600% | global_loss: 1.5730359554290771


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 48 | global_acc: 66.750% | global_loss: 1.601218581199646


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 49 | global_acc: 66.750% | global_loss: 1.5926952362060547


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 50 | global_acc: 66.650% | global_loss: 1.5787019729614258


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 51 | global_acc: 67.000% | global_loss: 1.6115362644195557


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 52 | global_acc: 66.900% | global_loss: 1.5985995531082153


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 53 | global_acc: 66.600% | global_loss: 1.5748372077941895


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 54 | global_acc: 66.650% | global_loss: 1.5642030239105225


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 55 | global_acc: 66.750% | global_loss: 1.5879238843917847


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 56 | global_acc: 67.199% | global_loss: 1.617919921875


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 57 | global_acc: 67.099% | global_loss: 1.5992043018341064


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 58 | global_acc: 66.700% | global_loss: 1.5806244611740112


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 59 | global_acc: 67.049% | global_loss: 1.599320650100708


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 60 | global_acc: 66.750% | global_loss: 1.591705083847046


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 61 | global_acc: 67.049% | global_loss: 1.6224967241287231


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 62 | global_acc: 67.000% | global_loss: 1.5956628322601318


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 63 | global_acc: 67.099% | global_loss: 1.6023776531219482


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 64 | global_acc: 67.149% | global_loss: 1.6079528331756592


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 65 | global_acc: 66.950% | global_loss: 1.5985876321792603


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 66 | global_acc: 66.800% | global_loss: 1.585093379020691


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 67 | global_acc: 66.700% | global_loss: 1.5717345476150513


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 68 | global_acc: 67.000% | global_loss: 1.6092394590377808


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 69 | global_acc: 66.750% | global_loss: 1.5800577402114868


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 70 | global_acc: 67.149% | global_loss: 1.583211898803711


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 71 | global_acc: 67.049% | global_loss: 1.5988764762878418


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 72 | global_acc: 67.199% | global_loss: 1.5828468799591064


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 73 | global_acc: 66.650% | global_loss: 1.5664774179458618


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 74 | global_acc: 66.700% | global_loss: 1.5707708597183228


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 75 | global_acc: 66.900% | global_loss: 1.5797122716903687


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 76 | global_acc: 67.449% | global_loss: 1.6089500188827515


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 77 | global_acc: 66.900% | global_loss: 1.5734789371490479


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 78 | global_acc: 67.249% | global_loss: 1.5952469110488892


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 79 | global_acc: 66.850% | global_loss: 1.575049638748169


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 80 | global_acc: 66.850% | global_loss: 1.5665160417556763


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 81 | global_acc: 67.449% | global_loss: 1.6055033206939697


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 82 | global_acc: 67.049% | global_loss: 1.5853245258331299


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 83 | global_acc: 66.800% | global_loss: 1.552903413772583


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 84 | global_acc: 66.800% | global_loss: 1.593362808227539


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 85 | global_acc: 67.399% | global_loss: 1.6219671964645386


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 86 | global_acc: 67.249% | global_loss: 1.6193095445632935


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 87 | global_acc: 67.149% | global_loss: 1.5927932262420654


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 88 | global_acc: 67.399% | global_loss: 1.5941725969314575


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 89 | global_acc: 67.149% | global_loss: 1.6017471551895142


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 90 | global_acc: 66.900% | global_loss: 1.5708636045455933


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 91 | global_acc: 67.249% | global_loss: 1.5834718942642212


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 92 | global_acc: 67.000% | global_loss: 1.5768953561782837


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 93 | global_acc: 66.950% | global_loss: 1.5748543739318848


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 94 | global_acc: 67.249% | global_loss: 1.578571081161499


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 95 | global_acc: 66.900% | global_loss: 1.569765567779541


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 96 | global_acc: 67.149% | global_loss: 1.5851374864578247


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 97 | global_acc: 67.149% | global_loss: 1.596619725227356


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 98 | global_acc: 67.549% | global_loss: 1.6161102056503296


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 99 | global_acc: 67.299% | global_loss: 1.5771993398666382
[[   0    1   10    0   50    0    0]
 [   0    0    1    0   94    0    1]
 [   0    0   17    0  208    0    3]
 [   0    0    8    0   29    0    0]
 [   0    0    8    0 1312    0    7]
 [   0    0    1    0   29    0    2]
 [   0    0    7    0  196    0   19]]
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        61
           1       0.00      0.00      0.00        96
           2       0.33      0.07      0.12       228
           3       0.00      0.00      0.00        37
           4       0.68      0.99      0.81      1327
           5       0.00      0.00      0.00        32
           6       0.59      0.09      0.15       222

    accuracy                           0.67      2003
   macro avg       0.23      0.16      0.15      2003
weighted avg       0.56      0.67      0.57      2003



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
def test_model(X_test, Y_test,  model, comm_round):
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    #logits = model.predict(X_test, batch_size=100)
    logits = model.predict(X_test)
    loss = cce(Y_test, logits)
    acc = accuracy_score(tf.argmax(logits, axis=1), tf.argmax(Y_test, axis=1))
    print('comm_round: {} | global_acc: {:.3%} | global_loss: {}'.format(comm_round, acc, loss))
    return acc, loss

In [None]:
SGD_dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).shuffle(len(Y_train)).batch(320)
smlp_SGD = SimpleMLP()
SGD_model = smlp_SGD.build(784, 7) 

SGD_model.compile(loss=loss, 
              optimizer=optimizer, 
              metrics=metrics)

# fit the SGD training data to model
_ = SGD_model.fit(SGD_dataset, epochs=100, verbose=0)

#test the SGD global model and print out metrics
for(X_test, Y_test) in test_batched:
        SGD_acc, SGD_loss = test_model(X_test, Y_test, SGD_model, 1)
# Make predictions on the test set
y_pred = SGD_model.predict(X_test)
y_pred = np.argmax(y_pred, axis=1)
            
# Get the true classes
true_classes = np.argmax(Y_test, axis = 1)
      
# Print the classification report and the confusion matrix
report = classification_report(true_classes, y_pred)
from sklearn.metrics import confusion_matrix
confusion_matrix = confusion_matrix(true_classes, y_pred)
print(confusion_matrix)
plt.show()
print(report)

comm_round: 1 | global_acc: 67.099% | global_loss: 1.556148648262024
[[   0    2    7    0   52    0    0]
 [   0    1    1    0   93    0    1]
 [   0    1   13    0  209    0    5]
 [   0    1    6    0   30    0    0]
 [   0    0    7    0 1312    0    8]
 [   0    0    1    0   29    0    2]
 [   0    0    6    0  198    0   18]]
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        61
           1       0.20      0.01      0.02        96
           2       0.32      0.06      0.10       228
           3       0.00      0.00      0.00        37
           4       0.68      0.99      0.81      1327
           5       0.00      0.00      0.00        32
           6       0.53      0.08      0.14       222

    accuracy                           0.67      2003
   macro avg       0.25      0.16      0.15      2003
weighted avg       0.56      0.67      0.56      2003



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Label-flipping

In [None]:
## Pseudocode for flipping percent p of labels.

# Store all possible labels in list all_labels.
# k =  (total samples of client * p)/100     

# For one particular client dataset- do following.
# Get k random indexes of dataset and store in a list all_indexes (list contains k random numbers from 0 - number of samples for that client).
# For each value in the all_indexes list obtained above - do following
# Replace label with some random label from list all_labels

In [None]:
client_1,label_1=zip(*clients['client_1'])

In [None]:
# before replacing labels for label_1
for i, label in enumerate(label_1):
    print(i, label)

0 [0 0 0 0 1 0 0]
1 [0 0 0 0 1 0 0]
2 [0 0 0 0 1 0 0]
3 [0 0 1 0 0 0 0]
4 [0 0 0 0 1 0 0]
5 [0 0 0 0 0 0 1]
6 [0 0 0 0 1 0 0]
7 [0 0 0 0 0 1 0]
8 [0 0 1 0 0 0 0]
9 [0 0 0 0 1 0 0]
10 [0 0 0 0 1 0 0]
11 [0 0 0 0 0 1 0]
12 [0 0 0 0 0 0 1]
13 [0 0 0 0 1 0 0]
14 [0 0 0 0 1 0 0]
15 [0 0 0 0 1 0 0]
16 [0 1 0 0 0 0 0]
17 [0 0 1 0 0 0 0]
18 [0 0 0 0 0 0 1]
19 [0 0 0 0 1 0 0]
20 [0 0 0 0 0 0 1]
21 [1 0 0 0 0 0 0]
22 [0 0 0 0 1 0 0]
23 [0 0 0 0 1 0 0]
24 [0 0 0 0 1 0 0]
25 [0 1 0 0 0 0 0]
26 [0 0 0 0 1 0 0]
27 [0 0 0 0 1 0 0]
28 [0 0 0 0 0 0 1]
29 [0 0 0 0 0 0 1]
30 [0 0 0 0 1 0 0]
31 [0 0 0 0 1 0 0]
32 [0 0 0 0 1 0 0]
33 [0 0 0 0 1 0 0]
34 [0 0 0 0 1 0 0]
35 [0 1 0 0 0 0 0]
36 [0 0 0 0 1 0 0]
37 [0 0 1 0 0 0 0]
38 [0 0 0 0 1 0 0]
39 [0 0 0 0 1 0 0]
40 [0 0 1 0 0 0 0]
41 [0 0 0 0 0 0 1]
42 [0 0 0 0 1 0 0]
43 [0 0 0 0 1 0 0]
44 [0 0 0 0 0 0 1]
45 [0 0 0 0 1 0 0]
46 [0 0 0 0 1 0 0]
47 [0 0 0 0 1 0 0]
48 [0 0 0 0 1 0 0]
49 [0 0 0 0 0 0 1]
50 [0 0 0 0 1 0 0]
51 [0 1 0 0 0 0 0]
52 [0 0 0 0 1 0 0]
53 

In [None]:
# getting all labels under all_labels variable
all_labels = ()
clients_1 = ['client_1','client_2','client_3','client_4','client_5','client_6','client_7','client_8','client_9','client_10']
for client in range(len(clients_1)):
    name = clients_1[client]
    image,label = zip(*clients[name])
    all_labels=all_labels+(label)
print(list(all_labels))


[array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 1, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 1, 0]), array([0, 0, 1, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 1, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 1, 0, 0, 0, 0, 0]), array([0, 0, 1, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([1, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 1, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0,

In [None]:
len(all_labels)

8010

In [None]:
# finding the k
p=20
k= len(clients['client_1'])*p//100
print(k)

160


In [None]:
# unzipped the client_1 into images and labels
len_of_one_client_dataset = len(client_1)
print(len_of_one_client_dataset)

801


In [None]:
all_indexes = np.random.randint(0, 801, k)

print(all_indexes)

[513 360 798 369 485 726 443 554 677 591 252 653 617  64 778 320 512 207
 781 744 391  86 447 128 286 345  63 594 408 526 609 308 505  85 295 427
 120 263 348 121 561 581  60 750 351 344 257 232 482 473 624 681 610 629
 396 137 578 189 722 213 388 586 509 313 389 315 557 528 407  73 106 424
 641 690 633  30 536 730 452 790 523 647 613 709 785 202 377 762 531 519
 476 789 707 281 362 191 384 379 447 270 511  15 292 105 512  49 282 362
 725 304 334 773 469 731 715 228 301 548  82 359 708 519 331 521  48 447
 693 212 394 604 149 622 777 590 333 500 752 288 641 137 409 380 777 199
 108 452 163 647 727 591  18 759 106 129 373 426  31 326 206 134]


In [None]:
label_1 = list(label_1)

# Replace the label at each index in all_indexes with a random label from all_labels
for index in all_indexes:
    label_1[index] = random.choice(all_labels)

In [None]:
print(label_1)

[array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 1, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 1, 0]), array([0, 0, 1, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 1, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 1, 0, 0, 0, 0, 0]), array([0, 0, 1, 0, 0, 0, 0]), array([0, 0, 1, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([1, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 1, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0]), array([0, 1, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 1, 0, 0]), array([0,

In [None]:
# cross checking from all_labels
all_labels[513]

array([0, 0, 1, 0, 0, 0, 0])

In [None]:
type(label_1)

list

In [None]:
# assigning the replaced labels to particular client
# Replace the labels in the clients dictionary
clients["client_1"] = list(zip(client_1, label_1))

In [None]:
# cross checking if it is assigned or not
clients['client_1'][513]

(array([0.00392157, 0.02352941, 0.04705882, 0.04313725, 0.05490196,
        0.20392157, 0.29803922, 0.20784314, 0.09411765, 0.0745098 ,
        0.07843137, 0.08235294, 0.08235294, 0.06666667, 0.05098039,
        0.05098039, 0.06666667, 0.09019608, 0.11764706, 0.08627451,
        0.06666667, 0.08235294, 0.10588235, 0.14901961, 0.19215686,
        0.19215686, 0.15686275, 0.09411765, 0.01176471, 0.03137255,
        0.04313725, 0.02745098, 0.2       , 0.40784314, 0.45098039,
        0.4627451 , 0.38431373, 0.24313725, 0.1372549 , 0.0745098 ,
        0.07843137, 0.09019608, 0.0745098 , 0.0627451 , 0.06666667,
        0.07058824, 0.09803922, 0.14509804, 0.10588235, 0.0745098 ,
        0.09803922, 0.12156863, 0.16470588, 0.19215686, 0.17254902,
        0.12156863, 0.01960784, 0.04313725, 0.02745098, 0.11372549,
        0.43137255, 0.5254902 , 0.51764706, 0.55294118, 0.54117647,
        0.52941176, 0.4745098 , 0.3372549 , 0.19607843, 0.09019608,
        0.08627451, 0.08235294, 0.0745098 , 0.08

In [None]:
# cross checking after replacing for label_1
for i in all_indexes:
    print(label_1[i])

[0 0 0 0 1 0 0]
[1 0 0 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 1 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 1 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 1 0 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 0 0 1]
[1 0 0 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 1 0 0 0 0 0]
[1 0 0 0 0 0 0]
[0 1 0 0 0 0 0]
[0 0 0 1 0 0 0]
[0 0 0 0 0 0 1]
[0 0 0 0 1 0 0]
[0 0 1 0 0 0 0]
[0 0 1 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 1 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 1 0 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 0 0 1]
[0 0 0 0 0 0 1]
[0 0 0 0 1 0 0]
[0 0 0 0 0 0 1]
[0 0 1 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 0 0 1]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 1 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 0 0 1]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 1 0 0 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 0 0 1 0 0]
[0 0 1 0 0 0 0]
[0 0 0 0

In [None]:
fl_model(clients_batched, test_batched, comms_round, loss, optimizer, metrics)

comm_round: 0 | global_acc: 66.251% | global_loss: 1.6204502582550049


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 1 | global_acc: 66.251% | global_loss: 1.6243340969085693


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 2 | global_acc: 66.251% | global_loss: 1.6312410831451416


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 3 | global_acc: 66.251% | global_loss: 1.6393557786941528


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 4 | global_acc: 66.251% | global_loss: 1.6426329612731934


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 5 | global_acc: 66.251% | global_loss: 1.6247674226760864


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 6 | global_acc: 66.201% | global_loss: 1.6408997774124146


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 7 | global_acc: 66.201% | global_loss: 1.6250821352005005


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 8 | global_acc: 66.101% | global_loss: 1.6423671245574951


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 9 | global_acc: 66.101% | global_loss: 1.6220383644104004


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 10 | global_acc: 66.051% | global_loss: 1.6299852132797241


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 11 | global_acc: 66.001% | global_loss: 1.629160761833191


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 12 | global_acc: 66.101% | global_loss: 1.6097888946533203


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 13 | global_acc: 66.051% | global_loss: 1.6211830377578735


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 14 | global_acc: 65.951% | global_loss: 1.6197527647018433


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 15 | global_acc: 65.951% | global_loss: 1.622419834136963


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 16 | global_acc: 66.001% | global_loss: 1.611541986465454


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 17 | global_acc: 66.101% | global_loss: 1.623810052871704


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 18 | global_acc: 66.051% | global_loss: 1.6061105728149414


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 19 | global_acc: 66.101% | global_loss: 1.5901769399642944


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 20 | global_acc: 66.101% | global_loss: 1.6131834983825684


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 21 | global_acc: 66.051% | global_loss: 1.6172378063201904


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 22 | global_acc: 66.201% | global_loss: 1.6147520542144775


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 23 | global_acc: 66.001% | global_loss: 1.633989691734314


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 24 | global_acc: 66.101% | global_loss: 1.653146505355835


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 25 | global_acc: 66.001% | global_loss: 1.62153160572052


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 26 | global_acc: 66.101% | global_loss: 1.6177527904510498


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 27 | global_acc: 66.201% | global_loss: 1.6099218130111694


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 28 | global_acc: 66.201% | global_loss: 1.6001791954040527


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 29 | global_acc: 66.400% | global_loss: 1.6419533491134644


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 30 | global_acc: 66.251% | global_loss: 1.6244690418243408


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 31 | global_acc: 66.051% | global_loss: 1.6220827102661133


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 32 | global_acc: 66.251% | global_loss: 1.5982511043548584


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 33 | global_acc: 66.301% | global_loss: 1.6038857698440552


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 34 | global_acc: 66.301% | global_loss: 1.6003109216690063


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 35 | global_acc: 66.251% | global_loss: 1.6039049625396729


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 36 | global_acc: 66.600% | global_loss: 1.6312811374664307


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 37 | global_acc: 66.301% | global_loss: 1.5881730318069458


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 38 | global_acc: 66.350% | global_loss: 1.5926533937454224


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 39 | global_acc: 66.301% | global_loss: 1.5897518396377563


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 40 | global_acc: 66.301% | global_loss: 1.582818865776062


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 41 | global_acc: 66.550% | global_loss: 1.6110446453094482


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 42 | global_acc: 66.550% | global_loss: 1.6132272481918335


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 43 | global_acc: 66.400% | global_loss: 1.5800896883010864


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 44 | global_acc: 66.550% | global_loss: 1.608547329902649


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 45 | global_acc: 66.600% | global_loss: 1.5856695175170898


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 46 | global_acc: 66.600% | global_loss: 1.5950132608413696


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 47 | global_acc: 66.600% | global_loss: 1.5889278650283813


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 48 | global_acc: 66.550% | global_loss: 1.6045074462890625


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 49 | global_acc: 66.550% | global_loss: 1.5862857103347778


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 50 | global_acc: 66.750% | global_loss: 1.603836178779602


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 51 | global_acc: 66.650% | global_loss: 1.5923446416854858


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 52 | global_acc: 66.550% | global_loss: 1.5888924598693848


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 53 | global_acc: 66.700% | global_loss: 1.5996359586715698


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 54 | global_acc: 66.850% | global_loss: 1.6028872728347778


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 55 | global_acc: 66.700% | global_loss: 1.599525809288025


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 56 | global_acc: 66.700% | global_loss: 1.5867457389831543


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 57 | global_acc: 66.800% | global_loss: 1.5950415134429932


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 58 | global_acc: 66.800% | global_loss: 1.5989904403686523


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 59 | global_acc: 66.800% | global_loss: 1.587819218635559


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 60 | global_acc: 66.700% | global_loss: 1.5923256874084473


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 61 | global_acc: 67.199% | global_loss: 1.6155035495758057


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 62 | global_acc: 67.049% | global_loss: 1.6060137748718262


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 63 | global_acc: 66.850% | global_loss: 1.5854461193084717


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 64 | global_acc: 66.700% | global_loss: 1.5933459997177124


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 65 | global_acc: 67.000% | global_loss: 1.5997204780578613


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 66 | global_acc: 67.049% | global_loss: 1.5996078252792358


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 67 | global_acc: 66.650% | global_loss: 1.5756722688674927


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 68 | global_acc: 67.199% | global_loss: 1.6012992858886719


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 69 | global_acc: 67.299% | global_loss: 1.6085642576217651


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 70 | global_acc: 67.049% | global_loss: 1.5920720100402832


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 71 | global_acc: 66.600% | global_loss: 1.5682395696640015


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 72 | global_acc: 67.049% | global_loss: 1.5974937677383423


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 73 | global_acc: 66.950% | global_loss: 1.5838050842285156


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 74 | global_acc: 67.049% | global_loss: 1.5931553840637207


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 75 | global_acc: 67.199% | global_loss: 1.6118539571762085


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 76 | global_acc: 67.049% | global_loss: 1.5992534160614014


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 77 | global_acc: 67.149% | global_loss: 1.6037073135375977


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 78 | global_acc: 66.900% | global_loss: 1.5873627662658691


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 79 | global_acc: 67.099% | global_loss: 1.6055010557174683


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 80 | global_acc: 67.199% | global_loss: 1.5965348482131958


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 81 | global_acc: 67.299% | global_loss: 1.6092232465744019


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 82 | global_acc: 66.850% | global_loss: 1.5796611309051514


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 83 | global_acc: 66.850% | global_loss: 1.5824381113052368


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 84 | global_acc: 66.800% | global_loss: 1.5800069570541382


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 85 | global_acc: 66.800% | global_loss: 1.5796501636505127


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 86 | global_acc: 67.000% | global_loss: 1.5807909965515137


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 87 | global_acc: 66.850% | global_loss: 1.5885957479476929


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 88 | global_acc: 67.199% | global_loss: 1.594296932220459


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 89 | global_acc: 67.149% | global_loss: 1.5947134494781494


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 90 | global_acc: 67.149% | global_loss: 1.5918693542480469


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 91 | global_acc: 67.199% | global_loss: 1.605787992477417


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 92 | global_acc: 67.149% | global_loss: 1.5915873050689697


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 93 | global_acc: 66.850% | global_loss: 1.5794724225997925


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 94 | global_acc: 67.049% | global_loss: 1.5993667840957642


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 95 | global_acc: 67.199% | global_loss: 1.5908466577529907


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 96 | global_acc: 67.000% | global_loss: 1.5764269828796387


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 97 | global_acc: 67.000% | global_loss: 1.585667610168457


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 98 | global_acc: 67.149% | global_loss: 1.605159044265747


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


comm_round: 99 | global_acc: 67.199% | global_loss: 1.5937548875808716
[[   0    2    8    0   51    0    0]
 [   0    0    2    0   93    0    1]
 [   0    1   13    0  209    0    5]
 [   0    0    6    0   31    0    0]
 [   0    1    5    0 1315    0    6]
 [   0    0    1    0   30    0    1]
 [   0    0    7    0  197    0   18]]
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        61
           1       0.00      0.00      0.00        96
           2       0.31      0.06      0.10       228
           3       0.00      0.00      0.00        37
           4       0.68      0.99      0.81      1327
           5       0.00      0.00      0.00        32
           6       0.58      0.08      0.14       222

    accuracy                           0.67      2003
   macro avg       0.22      0.16      0.15      2003
weighted avg       0.55      0.67      0.56      2003



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
SGD_dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).shuffle(len(Y_train)).batch(320)
smlp_SGD = SimpleMLP()
SGD_model = smlp_SGD.build(784, 7) 

SGD_model.compile(loss=loss, 
              optimizer=optimizer, 
              metrics=metrics)

# fit the SGD training data to model
_ = SGD_model.fit(SGD_dataset, epochs=100, verbose=0)

#test the SGD global model and print out metrics
for(X_test, Y_test) in test_batched:
        SGD_acc, SGD_loss = test_model(X_test, Y_test, SGD_model, 1)

        # Make predictions on the test set
        y_pred = SGD_model.predict(X_test)
        y_pred = np.argmax(y_pred, axis=1)
            
        # Get the true classes
        true_classes = np.argmax(Y_test, axis = 1)
      
        # Print the classification report and the confusion matrix
        report = classification_report(true_classes, y_pred)
        from sklearn.metrics import confusion_matrix
        confusion_matrix = confusion_matrix(true_classes, y_pred)
        print(confusion_matrix)
        plt.show()
        print(report)

comm_round: 1 | global_acc: 66.950% | global_loss: 1.5530188083648682
[[   0    0    5    0   56    0    0]
 [   0    0    1    0   94    0    1]
 [   0    0    5    0  217    0    6]
 [   0    0    4    0   33    0    0]
 [   0    0    4    0 1318    0    5]
 [   0    0    1    0   30    0    1]
 [   0    0    2    0  202    0   18]]
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        61
           1       0.00      0.00      0.00        96
           2       0.23      0.02      0.04       228
           3       0.00      0.00      0.00        37
           4       0.68      0.99      0.80      1327
           5       0.00      0.00      0.00        32
           6       0.58      0.08      0.14       222

    accuracy                           0.67      2003
   macro avg       0.21      0.16      0.14      2003
weighted avg       0.54      0.67      0.55      2003



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
