In [3]:
# Import libraries
import os  # For interacting with the file system
import shutil  # For managing files and directories in a cross-platform manner
import keras  # For building deep learning models
import numpy as np  # For numerical operations on arrays
import tensorflow
from glob import glob  # For finding file paths
from tqdm import tqdm  # For progress bars

# Data preprocessing
from keras.preprocessing.image import ImageDataGenerator  # For image data augmentation
# Model architecture
from keras import Sequential  # For building sequential models
from keras.models import load_model  # For loading pre-trained models
from keras.layers import Dense, GlobalAvgPool2D as GAP, Dropout  # For defining model layers

# Training callbacks
from keras.callbacks import ModelCheckpoint, EarlyStopping  # For training callbacks

# Pre-trained models
from tensorflow.keras.applications import ResNet152V2  # For using pre-trained models

# Set the path to the dataset
data_path = '../input/animals10/raw-img'



In [5]:
class_names = sorted(os.listdir(data_path))

num_classes = len(class_names)


class_sizes = []
for name in class_names:
    class_size = len(os.listdir(data_path + "/" + name))
    class_sizes.append(class_size)

print("Class Distribution:\n", class_sizes)

Class Distribution:
 [4863, 2623, 1446, 2112, 3098, 1668, 1866, 1820, 4821, 1862]


In [7]:
class_name_size = dict(zip(class_names, class_sizes))

In [44]:
sampled_data_path = './sampled-data'
if os.path.exists(sampled_data_path):
    shutil.rmtree(sampled_data_path)

num_clients = 16

client_data_path = []

for i in range(num_clients):
    client_data_path.append('./client-data-' + str(i))
    if os.path.exists(client_data_path[i]):
        shutil.rmtree(client_data_path[i])
    if not os.path.exists(client_data_path[i]):
        os.mkdir(client_data_path[i])

# Create the sampled data directory if it doesn't exist
if not os.path.exists(sampled_data_path):
    os.mkdir(sampled_data_path)

# Set the percentage of each class to sample
sample_percent = 0.1

# Define a dictionary that maps the original class names to their English names
class_names_dict = {
    'cane': 'dog',
    'cavallo': 'horse',
    'elefante': 'elephant',
    'farfalla': 'butterfly',
    'gallina': 'chicken',
    'gatto': 'cat',
    'mucca': 'cow',
    'pecora': 'sheep',
    'ragno': 'spider',
    'scoiattolo': 'squirrel'
}

# Loop through each class directory and copy 2000 images or less to the sampled data directory
for class_name in os.listdir(data_path):
    # Get the path to the original class directory
    class_path = os.path.join(data_path, class_name)
    # Get the English name of the class
    class_name_en = class_names_dict[class_name]
    # Get the path to the sampled class directory
    sampled_class_path = os.path.join(sampled_data_path, class_name_en)
    
    client_class_path = []
    for i in range(num_clients):
        new_client_class_path = os.path.join(client_data_path[i], class_name_en)
        client_class_path.append(new_client_class_path)
        if not os.path.exists(new_client_class_path):
            os.mkdir(new_client_class_path)
            
    # Create the sampled class directory and fill it up only if it doesn't exist ********
    if not os.path.exists(sampled_class_path):
        os.mkdir(sampled_class_path)
        # Get a list of all the image files in the class directory
        image_files = os.listdir(class_path)
        # Calculate the number of images to sample **************
        image_class_size = class_name_size[class_name]
        if image_class_size > 2000:
            num_images = 2000
        else:
            num_images = int(image_class_size)
        # Sample the images
        sampled_images = np.random.choice(image_files, size=num_images, replace=False)
        
        # Split the sampled images into all the clients
        i = 0
        j = 0
        num_images_per_client = num_images / num_clients
        # print(num_images_per_client)
        for image_name in sampled_images:
            if(i >= num_images_per_client):
                j=j+1
                i = 0
            src_path = os.path.join(class_path, image_name)
            dst_path = os.path.join(client_class_path[j], image_name)
            shutil.copyfile(src_path, dst_path)
            i=i+1
        
        # Copy the sampled images to the sampled class directory
        for image_name in sampled_images:
            src_path = os.path.join(class_path, image_name)
            dst_path = os.path.join(sampled_class_path, image_name)
            shutil.copyfile(src_path, dst_path)


In [53]:
data_generator = ImageDataGenerator(
    rescale=1./255, 
    horizontal_flip=True, 
    vertical_flip=True, 
    rotation_range=20, 
    validation_split=0.2)

train_data = data_generator.flow_from_directory(
    sampled_data_path, 
    target_size=(256,256), 
    class_mode='binary', 
    batch_size=32, 
    shuffle=True, 
    subset='training')

valid_data = data_generator.flow_from_directory(
    sampled_data_path, 
    target_size=(256,256), 
    class_mode='binary', 
    batch_size=32, 
    shuffle=True, 
    subset='validation')

Found 14931 images belonging to 10 classes.
Found 3731 images belonging to 10 classes.


In [13]:
# Specify the name of the model as "ResNet152V2".
name = "ResNet152V2"

# Load the pre-trained ResNet152V2 model, freeze its weights and exclude its final classification layer.
base_model = ResNet152V2(include_top=False, input_shape=(256,256,3), weights='imagenet')
base_model.trainable = False

# Create a sequential model with the ResNet152V2 base model, a global average pooling layer, two fully connected layers, and a final softmax classification layer.
resnet152V2 = Sequential([
    base_model,
    GAP(),
    Dense(256, activation='relu'),
    Dropout(0.2),
    Dense(num_classes, activation='softmax')
], name=name)

# Compile the model with sparse categorical cross-entropy as the loss function, Adam optimizer and accuracy as the evaluation metric.
resnet152V2.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

# Set up the EarlyStopping and ModelCheckpoint callbacks to monitor the training process and save the best model weights.
cbs = [
    EarlyStopping(patience=3, restore_best_weights=True),
    ModelCheckpoint(name + ".h5", save_best_only=True)
]

# Train the model using the training and validation datasets, using 50 epochs and the previously defined callbacks.
resnet152V2.fit(
    train_data, validation_data=valid_data,
    epochs=5, callbacks=cbs
)

Epoch 1/5

  saving_api.save_model(


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7a05d7fa9030>

In [57]:
class ClientDevice:
    def __init__(self, device_id, data_generator):
        self.device_id = device_id
        self.data_generator = data_generator
        

    def get_train_data(self):
        return self.data_generator.flow_from_directory(
                                                    client_data_path[self.device_id], 
                                                    target_size=(256,256), 
                                                    class_mode='binary', 
                                                    batch_size=32, 
                                                    shuffle=True, 
                                                    subset='training')
    
    def get_valid_data(self):
        return self.data_generator.flow_from_directory(
                                                        client_data_path[self.device_id], 
                                                        target_size=(256,256), 
                                                        class_mode='binary', 
                                                        batch_size=32, 
                                                        shuffle=True, 
                                                        subset='validation')

client_devices = []

for i in range(num_clients):
    client_device = ClientDevice(device_id=i, data_generator=data_generator)
    client_devices.append(client_device)


In [74]:
# Load the pre-trained ResNet152V2 model, freeze its weights and exclude its final classification layer.
global_model = ResNet152V2(include_top=False, input_shape=(256,256,3), weights='imagenet')
global_model.trainable = False

# Create a sequential model with the ResNet152V2 base model, a global average pooling layer, two fully connected layers, and a final softmax classification layer.
model = Sequential([
    global_model,
    GAP(),
    Dense(256, activation='relu'),
    Dropout(0.2),
    Dense(num_classes, activation='softmax')
], name=name)

# Compile the model with sparse categorical cross-entropy as the loss function, Adam optimizer and accuracy as the evaluation metric.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

# Set up the EarlyStopping and ModelCheckpoint callbacks to monitor the training process and save the best model weights.
cbs = [
    EarlyStopping(patience=3, restore_best_weights=True),
    ModelCheckpoint(name + ".h5", save_best_only=True)
]

num_epochs = 5

for epoch in range(num_epochs):
    # Store the current global model to be continuously updated for a single epoch
    all_trained_model_weights = []
    
    for client_id in range(num_clients):
        local_name = name + "_loc_" + str(client_id)
        
        local_model = tensorflow.keras.models.clone_model(model)
        
        # Compile the local model
        local_model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy']
        )
        
        cbs = [
                    EarlyStopping(patience=3, restore_best_weights=True),
                    ModelCheckpoint(local_name + ".h5", save_best_only=True)
                ]
        # Train the local model
        local_model.fit(
            client_devices[client_id].get_train_data(),
            validation_data=client_devices[client_id].get_valid_data(),
            epochs=1, 
            callbacks=cbs
        )
        
        all_trained_model_weights.append(local_model.get_weights())
    
    # average out all the models and reset the global model before epoch ends
    model_weights = model.get_weights()
    
    for i in range(len(model_weights)):
        sum = 0.0
        for local_model in all_trained_model_weights:
            sum += local_model[i]
        model_weights[i] = sum // num_clients
    
    print("Epoch!")
    model.set_weights(model_weights)

    model.evaluate(valid_data)

Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 classes.
Found 937 images belonging to 10 classes.
Found 232 images belonging to 10 c