<a href="https://colab.research.google.com/github/stefffffffffano/AML_FederatedLearning/blob/main/Federated_CIFAR_100.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys
import torch
import torch.nn as nn

In [None]:
# Constants for FL training
# Compute device: 'cuda' when torch reports an available CUDA device, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)
# Mount Google Drive at /content/drive so that paths below it
# (e.g. CHECKPOINT_DIR further down) resolve. Requires a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
NUM_CLIENTS = 100  # Total number of clients in the federation
FRACTION_CLIENTS = 0.1  # Fraction of clients selected per round (C)
LOCAL_STEPS = 4  # Number of local steps (J)
GLOBAL_ROUNDS = 2000  # Total number of communication rounds

BATCH_SIZE = 64  # Batch size for local training
MOMENTUM = 0.9  # Momentum for SGD optimizer
# Root folder (on the mounted Drive) under which the checkpoint helpers create their subfolders.
CHECKPOINT_DIR = '/content/drive/MyDrive/colab_checkpoints/'
LOG_FREQUENCY = 10  # Frequency of logging training progress

In [3]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split

class CIFAR100DataLoader:
    """
    Builds the three CIFAR-100 data loaders used in the notebook.

    On construction the class creates:
      * ``train_loader`` -- shuffled, augmented (random crop + horizontal flip) training split;
      * ``val_loader``   -- stratified hold-out of the training set, without augmentation;
      * ``test_loader``  -- the official test set, without augmentation.

    Iterating over an instance yields the loaders in the order train, validation, test.
    """

    def __init__(self, batch_size=64, validation_split=0.1, download=True, num_workers=4, pin_memory=True):
        """
        Args:
            batch_size (int): Batch size shared by the three loaders.
            validation_split (float): Fraction of the training set held out for validation.
            download (bool): Forwarded to ``datasets.CIFAR100``.
            num_workers (int): Forwarded to every ``DataLoader``.
            pin_memory (bool): Forwarded to every ``DataLoader``.
        """
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.download = download
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # Per-channel statistics applied by both pipelines
        channel_mean = [0.5071, 0.4867, 0.4408]
        channel_std = [0.2675, 0.2565, 0.2761]

        # Training pipeline: augmentation first, then tensor conversion and normalization
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std)
        ])

        # Evaluation pipeline: tensor conversion and normalization only
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std)
        ])

        # Build the loaders eagerly so they are available as attributes
        self.train_loader, self.val_loader, self.test_loader = self._prepare_loaders()

    def _prepare_loaders(self):
        """Create and return the (train, validation, test) loaders."""

        def cifar(train, transform):
            # Every split reads from the same on-disk root
            return datasets.CIFAR100(root='./data', train=train, download=self.download, transform=transform)

        def make_loader(dataset, shuffle):
            # All loaders share batch size, worker count and pinning
            return DataLoader(
                dataset, batch_size=self.batch_size, shuffle=shuffle,
                num_workers=self.num_workers, pin_memory=self.pin_memory
            )

        # Training images with augmentation
        augmented_trainset = cifar(True, self.train_transform)

        # Stratified, reproducible split of the sample positions
        all_positions = list(range(len(augmented_trainset)))
        train_positions, val_positions = train_test_split(
            all_positions,
            train_size=1 - self.validation_split,
            test_size=self.validation_split,
            random_state=42,
            stratify=augmented_trainset.targets,
            shuffle=True
        )

        train_loader = make_loader(Subset(augmented_trainset, train_positions), True)

        # Validation samples come from the training images but use the evaluation transform
        plain_trainset = cifar(True, self.test_transform)
        val_loader = make_loader(Subset(plain_trainset, val_positions), False)

        # Official test split
        test_loader = make_loader(cifar(False, self.test_transform), False)

        return train_loader, val_loader, test_loader

    def __iter__(self):
        """Allows iteration over all loaders for unified access."""
        return iter((self.train_loader, self.val_loader, self.test_loader))

In [4]:
import torch
from statistics import mean
import torch.nn as nn

"""
Utility function used both in the centralized and federated learning
Computes the accuracy and the loss on the validation/test set depending on the dataloader passed
"""

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.NLLLoss()  # our loss function for classification tasks on CIFAR-100


def evaluate(model, dataloader):
    """
    Compute accuracy and loss of `model` on the data served by `dataloader`.

    The model is switched to evaluation mode and is left in that mode on return.
    Gradients are disabled for the whole pass. The loss uses the module-level
    `criterion`, and batches are moved to the module-level `DEVICE`.

    Args:
        model (nn.Module): Model to evaluate; called as `model(data)`.
        dataloader (DataLoader): Yields `(data, targets)` batches and exposes `.dataset`.

    Returns:
        tuple: `(accuracy, loss)` where accuracy is a percentage computed over
        `len(dataloader.dataset)` and loss is the unweighted mean of the per-batch losses.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.eval()  # Set Network to evaluation mode
    running_corrects = 0
    losses = []
    with torch.no_grad():
        for data, targets in dataloader:
            data = data.to(DEVICE)
            targets = targets.to(DEVICE)
            # Forward pass
            outputs = model(data)
            losses.append(criterion(outputs, targets).item())
            # Predicted class = index of the largest output per sample
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == targets).item()

    if not losses:
        # Without this guard an empty loader surfaced as an obscure UnboundLocalError
        raise ValueError("evaluate() received a dataloader that yields no batches")

    # Accuracy is computed once, after the pass, instead of on every batch
    accuracy = running_corrects / float(len(dataloader.dataset))

    return accuracy * 100, mean(losses)

In [5]:
import torch.nn as nn
import torch.nn.functional as F
"""
Model architecture for the CIFAR-100 dataset.
The model is based on the LeNet-5 architecture with some modifications.
Reference: Hsu et al., Federated Visual Classification with Real-World Data Distribution, ECCV 2020

CNN similar to LeNet5 which has two 5×5, 64-channel convolution layers, each precedes a 2×2
max-pooling layer, followed by two fully-connected layers with 384 and 192
channels respectively and finally a softmax linear classifier
"""


class LeNet5(nn.Module):
    """
    LeNet-5-style CNN: two 5x5 convolutions with 64 output channels, each followed
    by ReLU and 2x2 max-pooling, then fully-connected layers of 384, 192 and 100 units.

    `forward` returns log-probabilities (log-softmax over dimension 1).
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor
        feature_layers = [
            nn.Conv2d(3, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        # Classifier head; the first Linear expects 64 * 5 * 5 flattened features
        classifier_layers = [
            nn.Linear(64 * 5 * 5, 384),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.ReLU(),
            nn.Linear(192, 100),  # 100 classes for CIFAR-100
        ]

        self.conv_layer = nn.Sequential(*feature_layers)
        self.fc_layer = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Run the network on a batch `x` and return per-class log-probabilities."""
        feature_maps = self.conv_layer(x)
        # Collapse everything except the batch dimension
        flattened = feature_maps.view(feature_maps.size(0), -1)
        logits = self.fc_layer(flattened)
        return F.log_softmax(logits, dim=1)


In [6]:
import os
import torch
import json

# Ensure the checkpoint directory exists, creating it if necessary.
# Runs when the cell is executed; exist_ok=True makes it a no-op if the folder is already there.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(model, optimizer, epoch, hyperparameters, subfolder="", checkpoint_data=None):
    """
    Saves the model checkpoint and removes older checkpoints of the same run.

    The new checkpoint is written first; only afterwards are older files with
    the same hyperparameter string deleted, so a failed save never leaves the
    folder without a usable checkpoint. Older files are found by scanning the
    folder rather than assuming the previous one is at `epoch - 1`, because
    checkpoints may be saved every N epochs.

    Arguments:
    model -- The model whose state is to be saved.
    optimizer -- The optimizer whose state is to be saved (can be None).
    epoch -- The current epoch of the training process.
    hyperparameters -- A string representing the model's hyperparameters for file naming.
    subfolder -- Optional subfolder within the checkpoint directory to save the checkpoint.
    checkpoint_data -- Data to save in a JSON file (e.g., training logs). Nothing is
                       written when this is None or empty.
    """
    # Define (and create if needed) the subfolder where checkpoints will be stored
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)
    os.makedirs(subfolder_path, exist_ok=True)

    # Filenames for the model checkpoint and the associated JSON file
    stem = f"model_epoch_{epoch}_params_{hyperparameters}"
    filepath = os.path.join(subfolder_path, stem + ".pth")
    filepath_json = os.path.join(subfolder_path, stem + ".json")

    # Prepare the checkpoint data dictionary
    checkpoint = {'model_state_dict': model.state_dict(), 'epoch': epoch}
    # If an optimizer is provided, save its state as well
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    # Save the model and optimizer (if provided) state dictionary to the checkpoint file
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved: {filepath}")

    # If additional data (e.g., training logs) is provided, save it to a JSON file
    if checkpoint_data:
        with open(filepath_json, 'w') as json_file:
            json.dump(checkpoint_data, json_file, indent=4)

    # Remove every older checkpoint (.pth and .json) written for exactly these hyperparameters.
    # The match is on the full "model_epoch_<N>_params_<hyperparameters>.<ext>" pattern so that
    # runs whose hyperparameter string merely starts with this one are left untouched.
    prefix = "model_epoch_"
    for file_name in os.listdir(subfolder_path):
        for extension in (".pth", ".json"):
            suffix = f"_params_{hyperparameters}{extension}"
            if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
                continue
            epoch_text = file_name[len(prefix):len(file_name) - len(suffix)]
            if epoch_text.isdecimal() and int(epoch_text) < epoch:
                os.remove(os.path.join(subfolder_path, file_name))

def load_checkpoint(model, optimizer, hyperparameters, subfolder=""):
    """
    Loads the latest checkpoint available based on the specified hyperparameters.

    Only files named exactly "model_epoch_<N>_params_<hyperparameters>.pth" are
    considered, so a run with hyperparameters "LR0.1_WD0.001" never picks up a
    checkpoint written for "LR0.1_WD0.0015". Other files in the folder are ignored.

    Arguments:
    model -- The model whose state will be updated from the checkpoint.
    optimizer -- The optimizer whose state will be updated from the checkpoint (can be None).
    hyperparameters -- A string representing the model's hyperparameters for file naming.
    subfolder -- Optional subfolder within the checkpoint directory to look for checkpoints.

    Returns:
    The next epoch to resume from and the associated JSON data if available.
    When no checkpoint is found the result is (1, None).
    """
    # Define the path to the subfolder where checkpoints are stored
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)

    # If the subfolder doesn't exist, print a message and start from epoch 1
    if not os.path.exists(subfolder_path):
        print("No checkpoint found, starting from epoch 1.")
        return 1, None  # Epoch starts from 1

    # Collect checkpoint files for exactly these hyperparameters, keyed by epoch number
    prefix = "model_epoch_"
    suffix = f"_params_{hyperparameters}.pth"
    checkpoints_by_epoch = {}
    for file_name in os.listdir(subfolder_path):
        if file_name.startswith(prefix) and file_name.endswith(suffix):
            epoch_text = file_name[len(prefix):len(file_name) - len(suffix)]
            if epoch_text.isdecimal():
                checkpoints_by_epoch[int(epoch_text)] = file_name

    # If no checkpoint is found, print a message and start from epoch 1
    if not checkpoints_by_epoch:
        print("No checkpoint found, starting from epoch 1..\n\n")
        return 1, None  # Epoch starts from 1

    # Load the checkpoint with the highest epoch number
    latest_file = checkpoints_by_epoch[max(checkpoints_by_epoch)]
    filepath = os.path.join(subfolder_path, latest_file)
    # map_location='cpu' lets a checkpoint written on a GPU machine be read on a CPU-only one;
    # load_state_dict below copies the values onto whatever device the model lives on.
    checkpoint = torch.load(filepath, map_location='cpu', weights_only=True)

    # Load the model state from the checkpoint
    model.load_state_dict(checkpoint['model_state_dict'])
    # If an optimizer is provided, load its state as well (when the checkpoint has one)
    if optimizer is not None:
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            print("Checkpoint has no optimizer state; optimizer left unchanged.")

    # Try to load the associated JSON file if available (same name, .json extension)
    json_filepath = os.path.join(subfolder_path, latest_file[:-len('.pth')] + '.json')
    json_data = None
    if os.path.exists(json_filepath):
        with open(json_filepath, 'r') as json_file:
            json_data = json.load(json_file)
        print("Data loaded!")
    else:
        print("No data found")

    # Print the epoch from which the model is resuming
    print(f"Checkpoint found: Resuming from epoch {checkpoint['epoch'] + 1}\n\n")
    return checkpoint['epoch'] + 1, json_data

def delete_existing_checkpoints(subfolder=""):
    """
    Removes every regular file found directly inside a checkpoint subfolder.

    Sub-directories are left in place. A message is printed in both the
    "deleted" and the "folder not found" cases.

    Arguments:
    subfolder -- Optional subfolder within the checkpoint directory to delete checkpoints from.
    """
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)

    # Guard clause: nothing to do when the folder is missing
    if not os.path.exists(subfolder_path):
        print(f"No checkpoint folder found at {subfolder_path}.")
        return

    # Build full paths for all entries, then delete only those that are files
    entry_paths = [os.path.join(subfolder_path, entry) for entry in os.listdir(subfolder_path)]
    for regular_file in filter(os.path.isfile, entry_paths):
        os.remove(regular_file)

    print(f"All existing checkpoints in {subfolder_path} have been deleted.")

In [None]:
# Build the CIFAR-100 loaders once for the whole notebook.
# 10% of the training set is kept for validation (validation_split=0.1).
data_loader = CIFAR100DataLoader(batch_size=BATCH_SIZE, validation_split=0.1, download=True, num_workers=4, pin_memory=True)
# Expose the three loaders under short names used by the cells below.
trainloader, validloader, testloader = data_loader.train_loader, data_loader.val_loader, data_loader.test_loader

In [8]:
# Fresh, untrained model instance.
global_model = LeNet5()
# NLLLoss matches the log-softmax output produced by LeNet5.forward.
# This rebinds the module-level `criterion` that `evaluate` reads.
criterion = nn.NLLLoss()# our loss function for classification tasks on CIFAR-100

In [9]:
from torch.backends import cudnn
import time


class Client:
    """
    A single federated-learning participant that trains a model on its own data
    for a fixed number of optimizer steps.
    """

    def __init__(self, client_id, data_loader, model, device):
        """
        Initializes a federated learning client.
        :param client_id: Unique identifier for the client.
        :param data_loader: Data loader specific to the client (stored on the instance;
                            `client_update` trains on the loader passed to it).
        :param model: The model instance to be trained by the client; it is moved to `device`.
        :param device: The device (CPU/GPU) to perform computations.
        """
        self.client_id = client_id
        self.data_loader = data_loader
        self.model = model.to(device)
        self.device = device

    def client_update(self, client_data, criterion, optimizer, local_steps=4, detailed_print=False):
        """
        Trains the client's model for exactly `local_steps` optimizer steps (mini-batches).
        The loader is re-iterated as many times as needed to reach that number of steps.

        Args:
            client_data (DataLoader): The data loader for the client's dataset.
            criterion (Loss): The loss function used for training.
            optimizer (Optimizer): The optimizer used for updating model parameters (e.g., SGD).
                It must have been built on this client's model parameters.
            local_steps (int): Number of mini-batch steps to perform (must be >= 1).
            detailed_print (bool): If True, logs the loss of the final step.

        Returns:
            tuple: `(state_dict, avg_loss, avg_accuracy)` -- the updated model's state
            dictionary, the sample-weighted mean loss over the processed batches, and the
            training accuracy (percentage) over the processed samples.

        Raises:
            ValueError: If `local_steps` is smaller than 1 or `client_data` yields no batches.
        """
        if local_steps < 1:
            raise ValueError("local_steps must be at least 1")

        # Enable the cuDNN auto-tuner. The previous bare `cudnn.benchmark` expression
        # only read the flag and never switched it on.
        cudnn.benchmark = True

        self.model.train()  # Set the model to training mode
        step_count = 0
        total_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        while step_count < local_steps:
            steps_before_pass = step_count
            for data, targets in client_data:
                # Move data and targets to the specified device (e.g., GPU or CPU)
                data, targets = data.to(self.device), targets.to(self.device)

                # Reset the gradients before backpropagation
                optimizer.zero_grad()

                # Forward pass: compute model predictions
                outputs = self.model(data)

                # Compute the loss
                loss = criterion(outputs, targets)

                # Backward pass: compute gradients and update weights
                loss.backward()
                optimizer.step()

                # Accumulate metrics. The loss is weighted by the batch size so that a
                # smaller final batch does not skew the average.
                total_loss += loss.item() * data.size(0)
                _, predicted = outputs.max(1)
                correct_predictions += predicted.eq(targets).sum().item()
                total_samples += data.size(0)

                step_count += 1
                if step_count >= local_steps:
                    break

            if step_count == steps_before_pass:
                # A loader that yields nothing would otherwise spin in the while-loop forever
                raise ValueError(f"Client {self.client_id}: client_data yielded no batches")

        # Compute averaged metrics
        avg_loss = total_loss / total_samples
        avg_accuracy = correct_predictions / total_samples * 100

        # Optionally, print the loss of the last training step
        if detailed_print:
            print(f'Client {self.client_id} --> Final Loss (Step {step_count}/{local_steps}): {loss.item()}')

        # Return the updated model's state dictionary (weights) and the metrics
        return self.model.state_dict(), avg_loss, avg_accuracy

In [10]:
import torch
import numpy as np
import os
import matplotlib.pyplot as plt

# Base output folder for plots and saved models. The helpers below build their
# sub-folder paths by string concatenation with this value, so the trailing slash matters.
DIR = '/content/drive/MyDrive/colab_plots/'

def plot_client_selection(client_selection_count, file_name):
    """
    Save a bar chart showing how many times each client was selected.

    The image is written as PNG into the 'plots_federated/' folder under DIR,
    which is created if missing.

    Args:
        client_selection_count (list): Selection count per client; the position in the list is the client ID.
        file_name (str): Name of the image file to write.
    """
    # Destination folder and file
    output_dir = DIR +  'plots_federated/'
    os.makedirs(output_dir, exist_ok=True)
    destination = os.path.join(output_dir, file_name)

    client_ids = range(len(client_selection_count))
    # With many clients the tick labels are rotated so they do not overlap
    tick_rotation = 90 if len(client_ids) > 20 else 0

    plt.figure(figsize=(10, 6))
    plt.bar(client_ids, client_selection_count, alpha=0.7, edgecolor='black')
    plt.title("Client Selection Frequency", fontsize=16)
    plt.xlabel("Client ID", fontsize=14)
    plt.ylabel("Selection Count", fontsize=14)
    plt.xticks(client_ids, fontsize=10, rotation=tick_rotation)
    plt.tight_layout()
    plt.savefig(destination, format="png", dpi=300)
    plt.close()

def test(global_model, test_loader):
    """
    Report the accuracy of the global model on the test set.

    Thin wrapper around `evaluate` that discards the loss.

    Args:
        global_model (nn.Module): The global model to be evaluated.
        test_loader (DataLoader): DataLoader for the test dataset.

    Returns:
        float: The accuracy value returned by `evaluate` for the test dataset.
    """
    accuracy_value, _unused_loss = evaluate(global_model, test_loader)
    return accuracy_value

def _save_train_val_plot(x_values, train_values, val_values, metric_name, output_path):
    """Draw one train-vs-validation line chart for `metric_name` and save it as PNG."""
    plt.figure(figsize=(10, 6))
    plt.plot(x_values, train_values, label=f'Train {metric_name}', color='blue')
    plt.plot(x_values, val_values, label=f'Validation {metric_name}', color='red')
    plt.xlabel('Rounds', fontsize=14)
    plt.ylabel(metric_name, fontsize=14)
    plt.title(f'Training and Validation {metric_name}', fontsize=16)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path, format='png', dpi=300)
    plt.close()


def plot_metrics(train_accuracies, train_losses, val_accuracies,val_losses, file_name):
    """
    Plot the training and validation metrics for a federated learning simulation.

    Two PNG files are written into the 'plots_federated' folder under DIR:
    '<name>_loss.png' and '<name>_accuracy.png', where <name> is `file_name`
    without a trailing '.png' (if it has one).

    Args:
        train_accuracies (list): List of training accuracies.
        train_losses (list): List of training losses.
        val_accuracies (list): List of validation accuracies.
        val_losses (list): List of validation losses.
        file_name (str): Base name of the files to save the plots to.
    """
    # os.path.join avoids the doubled slash that plain concatenation with DIR produced
    directory = os.path.join(DIR, 'plots_federated')
    # Ensure the base directory exists
    os.makedirs(directory, exist_ok=True)

    # Complete path for the file
    file_path = os.path.join(directory, file_name)

    # Strip only a real '.png' extension. Names without it (or with dots elsewhere, such as
    # 'run_lr0.01') are kept whole, so the two plots always get distinct file names
    # instead of overwriting each other.
    stem, extension = os.path.splitext(file_path)
    if extension.lower() != '.png':
        stem = file_path

    # One x-axis entry per recorded round, starting from 1
    rounds_axis = list(range(1, len(train_losses) + 1))

    _save_train_val_plot(rounds_axis, train_losses, val_losses, 'Loss', stem + '_loss.png')
    _save_train_val_plot(rounds_axis, train_accuracies, val_accuracies, 'Accuracy', stem + '_accuracy.png')


def save_data(global_model, val_accuracies, val_losses, train_accuracies, train_losses,client_count, file_name):
    """
    Save the global model, val_accuracies, val_losses, train_accuracies, train_losses and client_count to a file.

    The file is written with `torch.save` into the 'trained_models' folder under DIR,
    which is created if missing.

    Args:
        global_model (nn.Module): The global model to be saved (its state_dict is stored).
        val_accuracies (list): List of validation accuracies.
        val_losses (list): List of validation losses.
        train_accuracies (list): List of training accuracies.
        train_losses (list): List of training losses.
        client_count (list): Per-client selection counts, stored under the 'client_count' key.
        file_name (str): Name of the file to save the data.
    """
    # os.path.join works whether or not DIR ends with a slash and avoids the doubled
    # separator that plain concatenation produced here.
    directory = os.path.join(DIR, 'trained_models')
    # Ensure the base directory exists
    os.makedirs(directory, exist_ok=True)

    # Complete path for the file
    file_path = os.path.join(directory, file_name)

    # Save all data (model state and metrics) into a dictionary
    save_dict = {
        'model_state': global_model.state_dict(),
        'val_accuracies': val_accuracies,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'train_losses': train_losses,
        'client_count': client_count
    }

    # Save the dictionary to the specified file
    torch.save(save_dict, file_path)
    print(f"Data saved successfully to {file_path}")

def load_data(model, file_name):
    """
    Load the model weights and metrics from a file written by `save_data`.

    Args:
        model (nn.Module): The model to load the weights into.
        file_name (str): Name of the file (inside the 'trained_models' folder under DIR) to load the data from.

    Returns:
        tuple: A tuple containing the model, val_accuracies, val_losses, train_accuracies, train_losses and client_count.
    """
    # os.path.join works whether or not DIR ends with a slash
    directory = os.path.join(DIR, 'trained_models')
    # Complete path for the file
    file_path = os.path.join(directory, file_name)

    # map_location='cpu' lets a file saved on a GPU machine be read on a CPU-only one;
    # load_state_dict below copies the weights onto the device the model already lives on.
    save_dict = torch.load(file_path, map_location='cpu')

    # Load the model state
    model.load_state_dict(save_dict['model_state'])

    # Extract the metrics
    val_accuracies = save_dict['val_accuracies']
    val_losses = save_dict['val_losses']
    train_accuracies = save_dict['train_accuracies']
    train_losses = save_dict['train_losses']
    client_count = save_dict['client_count']

    print(f"Data loaded successfully from {file_path}")

    return model, val_accuracies, val_losses, train_accuracies, train_losses,client_count

In [11]:
import torch
import torch.optim as optim
from copy import deepcopy
import numpy as np
from torch.utils.data import Subset
import os
from torch.utils.data import DataLoader, Subset
import logging

# Module-level logger named after the current module.
log = logging.getLogger(__name__)

class Server:
    def __init__(self, global_model, device, CHECKPOINT_DIR):
        self.global_model = global_model
        self.device = device
        self.CHECKPOINT_DIR = CHECKPOINT_DIR
        # Ensure the checkpoint directory exists, creating it if necessary
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    def fedavg_aggregate(self, client_states, client_sizes, client_avg_losses, client_avg_accuracies):
        """
        Aggregates model updates and client metrics from selected clients using the Federated Averaging (FedAvg) algorithm.
        The updates and metrics are weighted by the size of each client's dataset.

        Args:
            global_model (nn.Module): The global model whose structure is used for aggregation.
            client_states (list[dict]): A list of state dictionaries (model weights) from participating clients.
            client_sizes (list[int]): A list of dataset sizes for the participating clients.
            client_avg_losses (list[float]): A list of average losses for the participating clients.
            client_avg_accuracies (list[float]): A list of average accuracies for the participating clients.

        Returns:
            tuple: The aggregated state dictionary with updated model parameters, global average loss, and global average accuracy.
        """
        # Copy the global model's state dictionary for aggregation
        new_state = deepcopy(self.global_model.state_dict())

        # Calculate the total number of samples across all participating clients
        total_samples = sum(client_sizes)

        # Initialize all parameters in the new state to zero
        for key in new_state:
            new_state[key] = torch.zeros_like(new_state[key])

        # Initialize metrics
        total_loss = 0.0
        total_accuracy = 0.0

        # Perform a weighted average of client updates and metrics
        for state, size, avg_loss, avg_accuracy in zip(client_states, client_sizes, client_avg_losses, client_avg_accuracies):
            for key in new_state:
                # Add the weighted contribution of each client's parameters
                new_state[key] += (state[key] * size / total_samples)
            total_loss += avg_loss * size
            total_accuracy += avg_accuracy * size

        # Calculate global metrics
        global_avg_loss = total_loss / total_samples
        global_avg_accuracy = total_accuracy / total_samples

        # Return the aggregated state dictionary with updated weights and global metrics
        return new_state, global_avg_loss, global_avg_accuracy


    # Federated Learning Training Loop
    def train_federated(self, criterion, trainloader, validloader, num_clients, num_classes, rounds, lr, momentum, batchsize, wd, C=0.1, local_steps=4, log_freq=10, detailed_print=False,gamma=None):
        """
        Run the federated training loop and return the best global model found.

        Each round performs: (1) client selection, (2) local training of a copy of the
        global model on every selected client, (3) FedAvg aggregation into the global
        model, followed by validation on the server. Every `log_freq` rounds a
        checkpoint (model + metric history) is written to the "Federated/" subfolder,
        and training resumes from the latest matching checkpoint when one exists.

        Round numbers are 1-based: a fresh run executes rounds 1..`rounds`, and the
        checkpoint's `epoch` is the number of completed rounds.

        Args:
            criterion: Loss function passed to every client.
            trainloader (DataLoader): Its `.dataset` is split into per-client shards.
            validloader (DataLoader): Validation data used after every round.
            num_clients (int): Total number of clients.
            num_classes (int): Forwarded to `sharding`.
            rounds (int): Total number of communication rounds.
            lr (float): Learning rate of the clients' SGD optimizer.
            momentum (float): Momentum of the clients' SGD optimizer.
            batchsize (int): Batch size of the clients' data loaders.
            wd (float): Weight decay of the clients' SGD optimizer.
            C (float): Fraction of clients selected per round.
            local_steps (int): Number of local optimizer steps per client.
            log_freq (int): Checkpoint/log every this many rounds.
            detailed_print (bool): Print per-round and per-client details on log rounds.
            gamma (float | None): If given, clients are sampled with skewed
                probabilities drawn once via `skewed_probabilities`; otherwise uniformly.

        Returns:
            tuple: (global_model, val_accuracies, val_losses, train_accuracies,
            train_losses, client_selection_count). The model carries the weights of the
            round with the best validation accuracy seen during this call, or the current
            weights when no round improved on 0.0 or no round was run.
        """
        val_accuracies = []
        val_losses = []
        train_accuracies = []
        train_losses = []
        best_model_state = None  # The model with the best accuracy
        client_selection_count = [0] * num_clients  # Count how many times a client has been selected
        best_val_acc = 0.0
        hyperparameters = f"LR{lr}_WD{wd}"

        # Each shard represents the training data for one client.
        # NOTE: sharding draws a new random permutation on every call, so a resumed run
        # uses a different partition unless the NumPy RNG is seeded by the caller.
        shards = self.sharding(trainloader.dataset, num_clients, num_classes)
        client_sizes = [len(shard) for shard in shards]

        # Move the global model to the specified device (CPU or GPU)
        self.global_model.to(self.device)

        # Load checkpoint if it exists
        checkpoint_start_step, data_to_load = load_checkpoint(model=self.global_model,optimizer=None,hyperparameters=hyperparameters, subfolder="Federated/")
        if data_to_load is not None:
            val_accuracies = data_to_load['val_accuracies']
            val_losses = data_to_load['val_losses']
            train_accuracies = data_to_load['train_accuracies']
            train_losses = data_to_load['train_losses']
            client_selection_count = data_to_load['client_selection_count']

        probabilities = None
        if gamma is not None:
            probabilities = self.skewed_probabilities(num_clients, gamma)

        # load_checkpoint returns 1 for a fresh run and (last saved epoch + 1) on resume,
        # so round_num is treated as 1-based and the range end is inclusive. The previous
        # exclusive end made a fresh run execute only rounds - 1 rounds.
        for round_num in range(checkpoint_start_step, rounds + 1):
            is_log_round = round_num % log_freq == 0
            print_log = is_log_round and detailed_print
            if print_log:
                print(f"------------------------------------- Round {round_num} ------------------------------------------------" )

            # 1) client selection: a fraction C of clients is sampled to participate in this round
            selected_clients = self.client_selection(num_clients, C,probabilities)
            client_states = []
            client_avg_losses = []
            client_avg_accuracies = []

            # 2) local training: each selected client trains a copy of the global model
            #    on its own shard for local_steps optimizer steps
            for client_id in selected_clients:
                client_selection_count[client_id] += 1

                local_model = deepcopy(self.global_model)
                optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
                client_loader = DataLoader(shards[client_id], batch_size=batchsize, shuffle=True)

                client = Client(client_id, client_loader, local_model, self.device)
                client_local_state, client_avg_loss, client_avg_accuracy  = client.client_update(client_loader, criterion, optimizer, local_steps, print_log)

                client_states.append(client_local_state)
                client_avg_losses.append(client_avg_loss)
                client_avg_accuracies.append(client_avg_accuracy)

            # 3) central aggregation: combine the client updates with fedavg_aggregate
            #    and replace the global model's parameters with the result
            aggregated_state, train_loss, train_accuracy = self.fedavg_aggregate(client_states, [client_sizes[i] for i in selected_clients], client_avg_losses, client_avg_accuracies)
            self.global_model.load_state_dict(aggregated_state)

            train_accuracies.append(train_accuracy)
            train_losses.append(train_loss)

            # Validation at the server
            val_accuracy, val_loss = evaluate(self.global_model, validloader)
            val_accuracies.append(val_accuracy)
            val_losses.append(val_loss)
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                best_model_state = deepcopy(self.global_model.state_dict())

            if is_log_round:
                if detailed_print:
                    print(f"--> best validation accuracy: {best_val_acc:.2f}\n--> training accuracy: {train_accuracy:.2f}")
                    print(f"--> validation loss: {val_loss:.4f}\n--> training loss: {train_loss:.4f}")

                # checkpointing
                checkpoint_data = {
                    'val_accuracies': val_accuracies,
                    'val_losses': val_losses,
                    'train_accuracies': train_accuracies,
                    'train_losses': train_losses,
                    'client_selection_count': client_selection_count
                }
                save_checkpoint(self.global_model,optimizer=None, epoch=round_num, hyperparameters=hyperparameters, subfolder="Federated/", checkpoint_data=checkpoint_data)
                if detailed_print:
                    print(f"------------------------------ Round {round_num} terminated: model updated -----------------------------\n\n" )

        # Restore the best weights seen in this call. When nothing was recorded (no round
        # executed, or validation accuracy never rose above 0.0) the current weights are
        # kept instead of calling load_state_dict(None).
        if best_model_state is not None:
            self.global_model.load_state_dict(best_model_state)

        return self.global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count

    def skewed_probabilities(self, number_of_clients, gamma=0.5):
            # Generate skewed probabilities using a Dirichlet distribution
            probabilities = np.random.dirichlet(np.ones(number_of_clients) * gamma)
            return probabilities

    def client_selection(self, number_of_clients, clients_fraction, probabilities=None):
        """
        Sample the clients taking part in one round, without replacement.

        Args:
        number_of_clients (int): Total number of clients.
        clients_fraction (float): Fraction of clients to be selected.
        probabilities (array-like, optional): Per-client selection probabilities,
            forwarded to ``np.random.choice`` as ``p``. ``None`` (default) selects
            the clients uniformly.

        Returns:
        numpy.ndarray: Indices of the selected clients.
        """
        num_clients_to_select = int(number_of_clients * clients_fraction)

        # Truncation can yield 0 (e.g. 5 clients * 0.1), which would leave the round
        # with nobody to train; follow FedAvg's m = max(C * K, 1) in that case.
        if num_clients_to_select == 0 and number_of_clients > 0:
            num_clients_to_select = 1

        # p=None is np.random.choice's uniform case, so one call covers both modes.
        return np.random.choice(
            number_of_clients, num_clients_to_select, replace=False, p=probabilities
        )



    def sharding(self, dataset, number_of_clients, number_of_classes=100):
        """
        Split ``dataset`` into one ``Subset`` per client.

        dataset: dataset to be split; must support ``len`` and integer indexing, with
            each item being a tuple whose second element is the label;
        number_of_clients: the number of partitions we want to obtain (e.g., 100 for 100 clients);
        number_of_classes: (int) the number of classes inside each partition, or 100 for IID (default to 100).

        Returns a list of ``number_of_clients`` ``Subset`` objects over ``dataset``.

        Raises ValueError if ``number_of_classes`` is outside [1, 100] or if the computed
        shard size is 0, and IndexError if a class has no shard left for a client.

        In the non-IID case labels are expected to be the integers 0..K-1, where K is the
        number of distinct labels found in ``dataset``.
        """
        from collections import deque

        # Validate the number of classes input
        if not (1 <= number_of_classes <= 100):
            raise ValueError("number_of_classes should be an integer between 1 and 100")

        # Only the IID branch uses this permutation; it is still drawn unconditionally
        # so the NumPy random stream seen by later code is the same in both branches.
        indices = np.random.permutation(len(dataset))

        if number_of_classes == 100:  # IID Case
            # Randomly assign an (almost) equal number of records to each client:
            # the first `remainder` clients receive one extra sample.
            basic_partition_size = len(dataset) // number_of_clients
            remainder = len(dataset) % number_of_clients

            shards = []
            start_idx = 0
            for i in range(number_of_clients):
                end_idx = start_idx + basic_partition_size + (1 if i < remainder else 0)
                shards.append(Subset(dataset, indices[start_idx:end_idx]))
                start_idx = end_idx
            return shards

        # non-IID Case
        # Only the labels are needed; sample positions (not the samples themselves)
        # are what Subset expects as indices.
        labels = np.array([dataset[i][1] for i in range(len(dataset))])
        TOTAL_NUM_CLASSES = len(np.unique(labels))

        shard_size = len(dataset) // (number_of_clients * number_of_classes)  # Shard size for each class per client
        print("dataset len: ", len(dataset), ", shard size: ", shard_size, ", number of shards: ",(number_of_clients * number_of_classes))
        if shard_size == 0:
            raise ValueError("Shard size is too small; increase dataset size or reduce number of clients/classes.")

        # For every class, cut its sample positions into consecutive shards of
        # shard_size; the last shard of a class may be shorter.
        shards = {}
        for class_id in range(TOTAL_NUM_CLASSES):
            class_positions = np.flatnonzero(labels == class_id).tolist()
            shards[class_id] = deque(
                class_positions[start:start + shard_size]
                for start in range(0, len(class_positions), shard_size)
            )

        client_shards = []  # One Subset per client
        for client_id in range(number_of_clients):
            # Clients walk through the classes in order, wrapping around modulo
            # the number of classes present in the dataset.
            first_label = client_id * number_of_classes
            client_indices = []
            for offset in range(number_of_classes):
                label = (first_label + offset) % TOTAL_NUM_CLASSES
                # Take the next unused shard of this class.
                client_indices.extend(shards[label].popleft())

            client_shards.append(Subset(dataset, client_indices))

        return client_shards

In [None]:
#Hyperparameters tuning function
def hyperparameters_tuning(num_classes, local_steps, rounds):
    """Grid-search learning rate and weight decay for one federated configuration.

    Every (lr, wd) pair from lr in logspace(-3, -1, 3) and wd in logspace(-4, -1, 4)
    trains a fresh LeNet5 through a new Server for ``rounds`` rounds, and its curves
    are plotted. A setting is scored by the maximum validation accuracy it reached.

    Uses the module-level ``criterion``, ``trainloader``, ``validloader`` and the
    FL constants (``NUM_CLIENTS``, ``MOMENTUM``, ``BATCH_SIZE``, ``FRACTION_CLIENTS``).

    Args:
        num_classes: number of classes per client, forwarded to ``train_federated``.
        local_steps: number of local steps, forwarded to ``train_federated``.
        rounds: number of rounds used for each tuning run.

    Returns:
        tuple: ``(lr, wd)`` of the best-scoring setting.
    """
    print(f"Hyperparameter tuning for num_classes={num_classes}, local_steps={local_steps}")
    lr_values = np.logspace(-3, -1, num=3)
    wd_values = np.logspace(-4, -1, num=4)
    # Start below any reachable accuracy so that a setting is always selected, even
    # when every run scores 0 (starting at 0 with a strict '>' would return None).
    best_val_accuracy = float("-inf")
    best_setting = None
    for lr in lr_values:
        for wd in wd_values:
            print(f"Learning rate: {lr}, Weight decay: {wd}")
            # Fresh model and server for every setting so runs do not share state.
            global_model = LeNet5()
            server = Server(global_model, DEVICE, CHECKPOINT_DIR)
            global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(
                criterion, trainloader, validloader,
                num_clients=NUM_CLIENTS, num_classes=num_classes,
                rounds=rounds, lr=lr, momentum=MOMENTUM,
                batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS,
                local_steps=local_steps, log_freq=100,
                detailed_print=False, gamma=None,
            )
            plot_metrics(train_accuracies, train_losses,val_accuracies, val_losses, f"FederatedTuning_Nc_{num_classes}_J_{local_steps}_lr_{lr}_wd_{wd}.png")
            print(f"Validation accuracy: {val_accuracies[-1]} with lr: {lr} and wd: {wd}")
            max_val_accuracy = max(val_accuracies)
            if max_val_accuracy > best_val_accuracy:
                best_val_accuracy = max_val_accuracy
                best_setting = (lr, wd)
    print(f"Best setting: {best_setting} with validation accuracy: {best_val_accuracy}")
    return best_setting

In [None]:
# Experiment grid
LOCAL_STEPS_VALUES = [4, 8, 16]  # Candidate values of J (local steps per round)
NUM_CLASSES_VALUES = [1, 5, 10, 50]  # Classes per client in the non-IID settings
# Rounds for each J: J * rounds is held at 8000, i.e. {4: 2000, 8: 1000, 16: 500}
NUM_RUNDS = {steps: 8000 // steps for steps in LOCAL_STEPS_VALUES}
IID_CLASSES = 100  # Classes per client that selects the IID split

# Function to perform the training and testing for a given configuration
def run_experiment(num_classes, local_steps, plot_suffix):
    """Tune, train, test and persist one federated configuration.

    The number of rounds comes from ``NUM_RUNDS[local_steps]``; tuning uses one
    twentieth of it. After training with the tuned (lr, wd), the returned model is
    evaluated on ``testloader``, the curves are plotted and the results saved.

    Args:
        num_classes: number of classes per client.
        local_steps: number of local steps; must be a key of ``NUM_RUNDS``.
        plot_suffix: text inserted in the plot and data file names.
    """
    print(f"Running experiment: num_classes={num_classes}, local_steps={local_steps}")
    # Model and server are created before tuning starts.
    global_model = LeNet5()
    server = Server(global_model, DEVICE, CHECKPOINT_DIR)

    full_rounds = NUM_RUNDS[local_steps]
    best_lr, best_wd = hyperparameters_tuning(
        num_classes=num_classes,
        local_steps=local_steps,
        rounds=int(full_rounds / 20),
    )

    history = server.train_federated(
        criterion,
        trainloader,
        validloader,
        num_clients=NUM_CLIENTS,
        num_classes=num_classes,
        rounds=full_rounds,
        lr=best_lr,
        momentum=MOMENTUM,
        batchsize=BATCH_SIZE,
        wd=best_wd,
        C=FRACTION_CLIENTS,
        local_steps=local_steps,
        log_freq=100,
        detailed_print=False,
        gamma=None,  # uniform client sampling in this experiment
    )
    (global_model, val_accuracies, val_losses,
     train_accuracies, train_losses, client_selection_count) = history

    # The plot and the saved data share one base name.
    run_tag = f"Federated_{plot_suffix}_LR_{best_lr}_WD_{best_wd}"

    # Evaluate, plot and report
    test_accuracy = test(global_model, testloader)
    plot_metrics(train_accuracies, train_losses, val_accuracies, val_losses, f"{run_tag}.png")
    print(f"Test accuracy for num_classes={num_classes}, local_steps={local_steps}: {test_accuracy}")

    # Keep the model and curves for later analysis
    save_data(
        global_model,
        val_accuracies,
        val_losses,
        train_accuracies,
        train_losses,
        client_selection_count,
        f"{run_tag}.pth",
    )

# Main experiment loop (disabled): sweeps every non-IID setting plus the IID one
# over all local-step values. Uncomment to run the full grid.
# for num_classes in NUM_CLASSES_VALUES + [IID_CLASSES]:  # Include IID setting
#     for local_steps in LOCAL_STEPS_VALUES:
#         plot_suffix = f"num_classes_{num_classes}_local_steps_{local_steps}"
#         run_experiment(num_classes, local_steps, plot_suffix)

# Manual run of a single configuration: edit the two values below and re-run the cell.
num_classes = 100
local_steps = 8
plot_suffix = f"num_classes_{num_classes}_local_steps_{local_steps}"
run_experiment(num_classes, local_steps, plot_suffix)