<a href="https://colab.research.google.com/github/sameepshrestha/federated_learning_compression/blob/main/federated__learning_rough_simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True) # Set force_remount to True

Mounted at /content/drive


In [3]:
!tar -xf yourfile.tar

tar: yourfile.tar: Cannot open: No such file or directory
tar: Error is not recoverable: exiting now


In [4]:
# for data load
import os

# for reading and processing images
import imageio
from PIL import Image

# for visualizations

import matplotlib.pyplot as plt

import numpy as np # for using np arrays

# for bulding and running deep learning model
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import concatenate
from tensorflow.keras.losses import binary_crossentropy
from sklearn.model_selection import train_test_split

In [5]:
def LoadData (input_dir, target_dir):
    """
    Looks for relevant filenames in the shared path
    Returns 2 lists for original and masked files respectively

    """
    input_img_paths =sorted([os.path.join(input_dir , fname )for fname in os.listdir(input_dir) if fname.endswith(".jpg")])
    target_img_paths = sorted([os.path.join(target_dir , fname) for fname in os.listdir(target_dir) if fname.endswith(".png") and not fname.startswith(".")])
    # Read the images folder like a list


    return input_img_paths, target_img_paths

In [6]:
def PreprocessData(img, mask, target_shape_img, target_shape_mask, path1, path2):
    """
    Processes the images and mask present in the shared list and path
    Returns a NumPy dataset with images as 3-D arrays of desired size
    Please note the masks in this dataset have only one channel
    """
    # Pull the relevant dimensions for image and mask
    m = len(img)                     # number of images
    i_h,i_w,i_c = target_shape_img   # pull height, width, and channels of image
    m_h,m_w,m_c = target_shape_mask  # pull height, width, and channels of mask

    # Define X and Y as number of images along with shape of one image
    X = np.zeros((m,i_h,i_w,i_c), dtype=np.float32)
    y = np.zeros((m,m_h,m_w,m_c), dtype=np.int32)

    # Resize images and masks
    for file in img:
        # convert image into an array of desired shape (3 channels)
        index = img.index(file)
        path = os.path.join(path1, file)
        single_img = Image.open(path).convert('RGB')
        single_img = single_img.resize((i_h,i_w))
        single_img = np.reshape(single_img,(i_h,i_w,i_c))
        single_img = single_img/256.
        X[index] = single_img

        # convert mask into an array of desired shape (1 channel)
        single_mask_ind = mask[index]
        path = os.path.join(path2, single_mask_ind)
        single_mask = Image.open(path)
        single_mask = single_mask.resize((m_h, m_w))
        single_mask = np.reshape(single_mask,(m_h,m_w,m_c))
        single_mask = single_mask - 1 # to ensure classes #s start from 0
        y[index] = single_mask
    return X, y

In [7]:
img_size = (128 , 128)
def path_to_img(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img , channels=3)
    img = tf.image.resize(img , img_size)
    img = tf.cast(img , tf.float32)
    return img
def path_to_target(path):
    img = tf.io.read_file(path )
    img = tf.image.decode_png(img ,channels=1 )
    img = tf.image.resize(img , img_size)
    img = tf.cast(img , tf.uint8) - 1
    return img
def map_fn(img_path , target_path):
    img = path_to_img(img_path)
    mask = path_to_target(target_path)
    return img , mask
num_valid_samples =1000

In [8]:
def EncoderMiniBlock(inputs, n_filters=32, dropout_prob=0.3, max_pooling=True,name="name"):
    """
    This block uses multiple convolution layers, max pool, relu activation to create an architecture for learning.
    Dropout can be added for regularization to prevent overfitting.
    The block returns the activation values for next layer along with a skip connection which will be used in the decoder
    """
    conv = Conv2D(n_filters,
                  3,   # Kernel size
                  activation='relu',
                  padding='same',
                  kernel_initializer='HeNormal', name = name+"1")(inputs)
    conv = Conv2D(n_filters,
                  3,   # Kernel size
                  activation='relu',
                  padding='same',
                  kernel_initializer='HeNormal', name = name+"2")(conv)

    conv = BatchNormalization()(conv, training=False)
    if dropout_prob > 0:
        conv = tf.keras.layers.Dropout(dropout_prob)(conv)
    if max_pooling:
        next_layer = tf.keras.layers.MaxPooling2D(pool_size = (2,2))(conv)
    else:
        next_layer = conv
    skip_connection = conv

    return next_layer, skip_connection
def DecoderMiniBlock(prev_layer_input, skip_layer_input, n_filters=32, name="name"):
    """
    Decoder Block first uses transpose convolution to upscale the image to a bigger size and then,
    merges the result with skip layer results from encoder block
    Adding 2 convolutions with 'same' padding helps further increase the depth of the network for better predictions
    The function returns the decoded layer output
    """
    up = Conv2DTranspose(
                n_filters,
                (3,3),    # Kernel size
                strides=(2,2),
                padding='same',name=name+"transpose")(prev_layer_input)

    conv = Conv2D(n_filters,
                3,     # Kernel size
                activation='relu',
                padding='same',
                kernel_initializer='HeNormal', name = name +"1")(up)
    conv = Conv2D(n_filters,
                3,   # Kernel size
                activation='relu',
                padding='same',
                kernel_initializer='HeNormal', name = name +"2")(conv)
    return conv

def UNetCompiled(input_size=(128, 128, 3), n_filters=32, n_classes=3):

    inputs = Input(input_size)

    cblock1 = EncoderMiniBlock(inputs, n_filters,dropout_prob=0, max_pooling=True, name = "cblock1")
    cblock2 = EncoderMiniBlock(cblock1[0],n_filters*2,dropout_prob=0, max_pooling=True,name = "cblock2")
    cblock3 = EncoderMiniBlock(cblock2[0], n_filters*4,dropout_prob=0, max_pooling=True, name = "cblock3")
    cblock4 = EncoderMiniBlock(cblock3[0], n_filters*8,dropout_prob=0.3, max_pooling=True,name = "cblock4")
    cblock5 = EncoderMiniBlock(cblock4[0], n_filters*16, dropout_prob=0.3, max_pooling=False,name="cblock5")

    ublock6 = DecoderMiniBlock(cblock5[0], cblock4[1],  n_filters * 8, name="ublock6")
    ublock7 = DecoderMiniBlock(ublock6, cblock3[1],  n_filters * 4,name="ublock7")
    ublock8 = DecoderMiniBlock(ublock7, cblock2[1],  n_filters * 2,name="ublock8")
    ublock9 = DecoderMiniBlock(ublock8, cblock1[1],  n_filters,name="ublock9")
    conv9 = Conv2D(n_filters,
                3,
                activation='relu',
                padding='same',
                kernel_initializer='he_normal',name="conv9")(ublock9)

    conv10 = Conv2D(n_classes, 1, padding='same',name="conv10")(conv9)

    # Define the model
    model = tf.keras.Model(inputs=inputs, outputs=conv10)

    return model
model= UNetCompiled(input_size=(128, 128, 3), n_filters=32, n_classes=3)

In [9]:
layer_names = [
 'cblock11',
 'cblock12',
 'cblock21',
 'cblock22',
 'cblock31',
 'cblock32',
 'cblock41',
 'cblock42',
 'cblock51',
 'cblock52',
 'ublock6transpose',
 'ublock61',
 'ublock62',
 'ublock7transpose',
 'ublock71',
 'ublock72',
 'ublock8transpose',
 'ublock81',
 'ublock82',
 'ublock9transpose',
 'ublock91',
 'ublock92',
]

In [10]:
# Load images and masks
path1 = '/content/drive/MyDrive/segmentation/images'
path2 = '/content/drive/MyDrive/segmentation/annotations/trimaps'
img, mask = LoadData(path1, path2)  # LoadData should return lists of image and mask paths


In [11]:
def set_trainable_layers(model, layers_to_train):
    """
    Set only specific layers to be trainable.

    Args:
    model: The model whose layers are to be modified.
    layers_to_train: A list of layer names to be set as trainable.
    """
    for layer in model.layers:
        if layer.name in layers_to_train:
            layer.trainable = True
        else:
            layer.trainable = False


In [12]:
def get_layer_weights(model, layers_to_send):
    """
    Extracts the weights of specific layers from the model.

    Args:
    model: The model from which to extract weights.
    layers_to_send: A list of layer names whose weights need to be sent.

    Returns:
    A dictionary of layer names and their corresponding weights.
    """
    layer_weights = {}
    for layer in model.layers:
        if layer.name in layers_to_send:
            layer_weights[layer.name] = layer.get_weights()
    return layer_weights

def update_layer_weights(model, received_weights):
    """
    Updates specific layers of the model with received weights.

    Args:
    model: The model to update.
    received_weights: A dictionary of layer names and their new weights.
    """
    # print(avg_weights,"average weigths")
    for layer in model.layers:
        if layer.name in received_weights:
            # print(received_weights[layer.name].shape)
            layer.set_weights(received_weights[layer.name])


In [13]:
import numpy as np
import tensorflow as tf

num_clients = 4

# Shuffle the dataset indices
data_indices = np.arange(len(img))  # Create an array of indices corresponding to the dataset
np.random.shuffle(data_indices)     # Shuffle the indices

# Split the shuffled indices into num_clients parts
client_data_indices = np.array_split(data_indices, num_clients)

# Create datasets for each client
client_datasets = []
for indices in client_data_indices:
    client_images = np.array(img)[indices]
    client_masks = np.array(mask)[indices]

    # Create TensorFlow datasets for each client
    client_dataset = tf.data.Dataset.from_tensor_slices((client_images, client_masks))
    client_dataset = client_dataset.map(map_fn).batch(64).prefetch(1)
    client_datasets.append(client_dataset)


In [14]:
# def federated_averaging(client_weights, client_layers):
#     """
#     Performs federated averaging of the weights for specific layers across all clients.

#     Args:
#     client_weights: A list of dictionaries, each containing layer weights from a client.
#     client_layers: A list of lists, each containing layer names that were trained on each client.

#     Returns:
#     A dictionary of averaged weights for the selected layers.
#     """
#     new_weights = {layer_name: None for layer_name in all_layer_names} #empty weights dictionary
#     layer_counts = {layer_name: 0 for layer_name in all_layer_names}

#     for client_weight, layers_to_train in zip(client_weights, client_layers):
#         for layer_name in layers_to_train:#layers to train will contain layers of all the client which was trained
#             if layer_name in client_weight:
#                 if new_weights[layer_name] is None:
#                     print(len(client_weight[layer_name]))
#                     new_weights[layer_name] = client_weight[layer_name][0]
#                 else:
#                     new_weights[layer_name] += client_weight[layer_name][0]
#                 layer_counts[layer_name] += 1
#     for layer_name in new_weights:
#         if layer_counts[layer_name] > 0:  # Only average layers that were trained
#             print(new_weights[layer_name].shape)
#             new_weights[layer_name] /= layer_counts[layer_name]

#     # Filter out None values for layers that were not trained by any client
#     new_weights = {k: v for k, v in new_weights.items() if v is not None}
#     print("new_weights",new_weights.keys())
#     return new_weights


In [15]:
def federated_averaging(client_weights, client_layers):
    """
    Performs federated averaging of the weights for specific layers across all clients.

    Args:
    client_weights: A list of dictionaries, each containing layer weights from a client.
    client_layers: A list of lists, each containing layer names that were trained on each client.

    Returns:
    A dictionary of averaged weights for the selected layers.
    """
    new_weights = {layer_name: None for layer_name in all_layer_names}  # Initialize empty weights dictionary
    layer_counts = {layer_name: 0 for layer_name in all_layer_names}    # Initialize counts for each layer

    # Iterate over all clients and their corresponding layers to average weights
    for client_weight, layers_to_train in zip(client_weights, client_layers):
        for layer_name in layers_to_train:
            if layer_name in client_weight:
                # Initialize the layer weights if not already initialized
                if new_weights[layer_name] is None:
                    # Create a deep copy of the list of weights for the layer
                    new_weights[layer_name] = [np.copy(w) for w in client_weight[layer_name]]
                else:
                    # Sum the weights across clients for averaging
                    for i in range(len(client_weight[layer_name])):
                        new_weights[layer_name][i] += client_weight[layer_name][i]
                layer_counts[layer_name] += 1

    # Average the weights for each layer
    for layer_name in new_weights:
        if layer_counts[layer_name] > 0:  # Only average layers that were trained
            for i in range(len(new_weights[layer_name])):
                new_weights[layer_name][i] /= layer_counts[layer_name]

    # Filter out None values for layers that were not trained by any client
    new_weights = {k: v for k, v in new_weights.items() if v is not None}
    print("new_weights", new_weights.keys())
    return new_weights


In [17]:
# Initialize the global model
import time
import random
from tensorflow.keras import backend as K
import gc
global_model = UNetCompiled(input_size=(128, 128, 3), n_filters=32, n_classes=3)
global_model.compile(optimizer=tf.keras.optimizers.Adam(),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])
# List all layer names
all_layer_names = [layer.name for layer in global_model.layers]

num_rounds = 40

# Clients create a local model initially (in round 0)
client_models = []
for _ in range(len(client_datasets)-1):
    client_model = tf.keras.models.clone_model(global_model)
    client_model.set_weights(global_model.get_weights())  # Initialize with global weights
    client_model.compile(optimizer=tf.keras.optimizers.Adam(),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])
    client_models.append(client_model)

# Initialize the layers to send for the first round
layers_to_send = []  # No layers to send initially
start_time = time.time()
for round_num in range(num_rounds):
    print(f"Round {round_num + 1}/{num_rounds}")

    # Randomly select layers to train for this round
    layers_to_train = all_layer_names

    # Include specific layers (e.g., output layers) that should always be trained
    layers_to_train.append('conv9')
    layers_to_train.append('conv10')

    # Server sends weights of the selected layers from the previous round to clients
    layer_weights_to_send = get_layer_weights(global_model, layers_to_send) if layers_to_send else None

    client_weights = []
    client_layers = []

    # Each client trains its local model with the specified layers
    for client_id, client_dataset in enumerate(client_datasets[0:(len(client_datasets)-1)]):
        client_model = client_models[client_id]  # Reuse the local model created in round 0

        if layer_weights_to_send:
            # Update only the relevant layers for this round

            update_layer_weights(client_model, layer_weights_to_send)


        # Set the specified layers to be trainable
        set_trainable_layers(client_model, layers_to_train)

        # Compile and train the local model on the client's dataset
        client_model.fit(client_dataset, epochs=1)
        print(f"Client {client_id} finished training with layers: {layers_to_train}")

        # Collect only the updated weights for the trained layers
        updated_weights = get_layer_weights(client_model, layers_to_train)
        client_weights.append(updated_weights)
        client_layers.append(layers_to_train)

    # Perform federated averaging with only the trained layers
    avg_weights = federated_averaging(client_weights, client_layers)
    # Update the global model with the averaged weights
    update_layer_weights(global_model, avg_weights)

    # Prepare the layers for the next round
    layers_to_send = layers_to_train

    print(f"Completed Round {round_num + 1}. Global model updated with layers: {layers_to_train}at time")
    print("time is ",(time.time()-start_time))
    del client_weights, client_layers
    K.clear_session()

    # Force garbage collection to prevent memory from accumulating
    gc.collect()

print("time is ",(time.time()-start_time))
# Evaluate on client 4's dataset
loss, accuracy = global_model.evaluate(client_datasets[-1])
print(f"Final evaluation on Client 4's dataset - Loss: {loss}, Accuracy: {accuracy}")

Round 1/40
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m224s[0m 7s/step - accuracy: 0.5196 - loss: 0.9976
Client 0 finished training with layers: ['input_layer_2', 'cblock11', 'cblock12', 'batch_normalization_10', 'max_pooling2d_8', 'cblock21', 'cblock22', 'batch_normalization_11', 'max_pooling2d_9', 'cblock31', 'cblock32', 'batch_normalization_12', 'max_pooling2d_10', 'cblock41', 'cblock42', 'batch_normalization_13', 'dropout_4', 'max_pooling2d_11', 'cblock51', 'cblock52', 'batch_normalization_14', 'dropout_5', 'ublock6transpose', 'ublock61', 'ublock62', 'ublock7transpose', 'ublock71', 'ublock72', 'ublock8transpose', 'ublock81', 'ublock82', 'ublock9transpose', 'ublock91', 'ublock92', 'conv9', 'conv10', 'conv9', 'conv10']
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m104s[0m 3s/step - accuracy: 0.5036 - loss: 0.9932
Client 1 finished training with layers: ['input_layer_2', 'cblock11', 'cblock12', 'batch_normalization_10', 'max_pooling2d_8', 'cblock21', 'cblock2

KeyboardInterrupt: 

In [19]:
all_layer_names

['input_layer',
 'cblock11',
 'cblock12',
 'batch_normalization',
 'max_pooling2d',
 'cblock21',
 'cblock22',
 'batch_normalization_1',
 'max_pooling2d_1',
 'cblock31',
 'cblock32',
 'batch_normalization_2',
 'max_pooling2d_2',
 'cblock41',
 'cblock42',
 'batch_normalization_3',
 'dropout',
 'max_pooling2d_3',
 'cblock51',
 'cblock52',
 'batch_normalization_4',
 'dropout_1',
 'ublock6transpose',
 'ublock61',
 'ublock62',
 'ublock7transpose',
 'ublock71',
 'ublock72',
 'ublock8transpose',
 'ublock81',
 'ublock82',
 'ublock9transpose',
 'ublock91',
 'ublock92',
 'conv9',
 'conv10']

In [18]:
# Initialize the global model
import time
import random
from tensorflow.keras import backend as K
import gc
global_model = UNetCompiled(input_size=(128, 128, 3), n_filters=32, n_classes=3)
global_model.compile(optimizer=tf.keras.optimizers.Adam(),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])
# List all layer names
all_layer_names = [layer.name for layer in global_model.layers]

num_rounds = 40

# Clients create a local model initially (in round 0)
client_models = []
for _ in range(len(client_datasets)-1):
    client_model = tf.keras.models.clone_model(global_model)
    client_model.set_weights(global_model.get_weights())  # Initialize with global weights
    client_model.compile(optimizer=tf.keras.optimizers.Adam(),
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])
    client_models.append(client_model)

# Initialize the layers to send for the first round
layers_to_send = []  # No layers to send initially
start_time = time.time()
for round_num in range(num_rounds):
    print(f"Round {round_num + 1}/{num_rounds}")

    # Randomly select layers to train for this round
    layers_to_train = random.sample(layer_names, 4)  # Example: select 2 random layers
    print("Layers to train:", layers_to_train)

    # Include specific layers (e.g., output layers) that should always be trained
    layers_to_train.append('conv9')
    layers_to_train.append('conv10')

    # Server sends weights of the selected layers from the previous round to clients
    layer_weights_to_send = get_layer_weights(global_model, layers_to_send) if layers_to_send else None

    client_weights = []
    client_layers = []

    # Each client trains its local model with the specified layers
    for client_id, client_dataset in enumerate(client_datasets[0:(len(client_datasets)-1)]):
        client_model = client_models[client_id]  # Reuse the local model created in round 0

        if layer_weights_to_send:
            # Update only the relevant layers for this round

            update_layer_weights(client_model, layer_weights_to_send)


        # Set the specified layers to be trainable
        set_trainable_layers(client_model, layers_to_train)

        # Compile and train the local model on the client's dataset
        client_model.fit(client_dataset, epochs=1)
        print(f"Client {client_id} finished training with layers: {layers_to_train}")

        # Collect only the updated weights for the trained layers
        updated_weights = get_layer_weights(client_model, layers_to_train)
        client_weights.append(updated_weights)
        client_layers.append(layers_to_train)

    # Perform federated averaging with only the trained layers
    avg_weights = federated_averaging(client_weights, client_layers)
    # Update the global model with the averaged weights
    update_layer_weights(global_model, avg_weights)

    # Prepare the layers for the next round
    layers_to_send = layers_to_train

    print(f"Completed Round {round_num + 1}. Global model updated with layers: {layers_to_train}at time")
    print("time is ",(time.time()-start_time))
    del client_weights, client_layers
    K.clear_session()

    # Force garbage collection to prevent memory from accumulating
    gc.collect()

print("time is ",(time.time()-start_time))
# Evaluate on client 4's dataset
loss, accuracy = global_model.evaluate(client_datasets[-1])
print(f"Final evaluation on Client 4's dataset - Loss: {loss}, Accuracy: {accuracy}")

Round 1/40
Layers to train: ['ublock71', 'ublock62', 'cblock32', 'cblock21']
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 566ms/step - accuracy: 0.4366 - loss: 5.2644
Client 0 finished training with layers: ['ublock71', 'ublock62', 'cblock32', 'cblock21', 'conv9', 'conv10']
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 518ms/step - accuracy: 0.4388 - loss: 5.2227
Client 1 finished training with layers: ['ublock71', 'ublock62', 'cblock32', 'cblock21', 'conv9', 'conv10']
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 551ms/step - accuracy: 0.4409 - loss: 5.1734
Client 2 finished training with layers: ['ublock71', 'ublock62', 'cblock32', 'cblock21', 'conv9', 'conv10']
new_weights dict_keys(['cblock21', 'cblock32', 'ublock62', 'ublock71', 'conv9', 'conv10'])
Completed Round 1. Global model updated with layers: ['ublock71', 'ublock62', 'cblock32', 'cblock21', 'conv9', 'conv10']at time
time is  88.21108031272888
Round 2/40
Layers to tr

In [None]:
import numpy as np
import tensorflow as tf
import random

# Number of clients
num_clients = 5

# Shuffle the dataset and divide it into num_clients parts
client_data_indices = np.array_split(np.arange(len(img)), num_clients)

# Create datasets for each client
client_datasets = []
for indices in client_data_indices:
    client_images = np.array(img)[indices]
    client_masks = np.array(mask)[indices]

    # Create TensorFlow datasets for each client
    client_dataset = tf.data.Dataset.from_tensor_slices((client_images, client_masks))
    client_dataset = client_dataset.map(map_fn).batch(64).prefetch(1)
    client_datasets.append(client_dataset)


In [None]:
import random
global_model = UNetCompiled(input_size=(128, 128, 3), n_filters=32, n_classes=3)
all_layer_names = [layer.name for layer in global_model.layers]

def train_client_model(client_dataset, client_id, layers_to_train):
    local_model = tf.keras.models.clone_model(global_model)
    local_model.set_weights(global_model.get_weights())  # Copy global weights to the local model
    set_trainable_layers(local_model, layers_to_train)
    local_model.compile(optimizer=tf.keras.optimizers.Adam(),
                        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                        metrics=['accuracy'])
    local_model.fit(client_dataset, epochs=1)
    print(f"Client {client_id} finished training.")
    return local_model.get_weights(), layers_to_train


In [None]:
def federated_averaging(client_weights, client_layers):
    new_weights = [np.zeros_like(w) for w in client_weights[0]]
    for client_weight, layers_to_train in zip(client_weights, client_layers):
        for i, layer_name in enumerate(all_layer_names):
            if layer_name in layers_to_train:
                new_weights[i] += client_weight[i]
    new_weights = [w / len(client_weights) if all_layer_names[i] in set(sum(client_layers, [])) else w
                   for i, w in enumerate(new_weights)]
    return new_weights


In [None]:
num_rounds = 10

for round_num in range(num_rounds):
    print(f"Round {round_num + 1}/{num_rounds}")

    layers_to_train = random.sample(all_layer_names, 4)  # Example: select 4 random layers
    layers_to_train.extend(["conv9", "conv10"])  # Always train specific layers
    client_weights = []
    client_layers = []
    for client_id, client_dataset in enumerate(client_datasets):
        client_weight, trained_layers = train_client_model(client_dataset, client_id, layers_to_train)
        client_weights.append(client_weight)
        client_layers.append(trained_layers)

    avg_weights = federated_averaging(client_weights, client_layers)
    global_model.set_weights(avg_weights)

    print(f"Completed Round {round_num + 1}. Global model updated with layers: {layers_to_train}")
