In [149]:
import sys
sys.path.append('../data/cifar100/')
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
import random
from torch.utils.data import Subset
from statistics import mean
#from cifar100_loader import load_cifar100
#from models.model import LeNet5 #import the model

### Constants for FL training

In [150]:
# ---- Federated-learning configuration ----

# Compute device: prefer the GPU when PyTorch can see one
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(DEVICE)

# Federation layout
NUM_CLIENTS = 100        # K: total number of clients in the federation
FRACTION_CLIENTS = 0.1   # C: fraction of clients sampled in every round
LOCAL_STEPS = 4          # J: local optimisation steps per selected client
GLOBAL_ROUNDS = 2000     # number of communication rounds

# Local optimiser settings (SGD)
BATCH_SIZE = 64          # mini-batch size used by each client
LR = 0.01                # initial learning rate
MOMENTUM = 0.9           # SGD momentum
WEIGHT_DECAY = 0.0001    # L2 regularisation strength

# Progress is logged once every LOG_FREQUENCY rounds
LOG_FREQUENCY = 10

cuda


# Loaders

In [151]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split

class CIFAR100DataLoader:
    """
    Builds the train / validation / test `DataLoader`s for CIFAR-100.

    The official training split is divided (stratified by label, fixed seed 42)
    into a training part and a validation part. Training samples go through
    random-crop + horizontal-flip augmentation; validation and test samples are
    only centre-cropped. All images end up 24x24 and channel-normalised.

    Attributes:
    train_loader -- shuffled loader over the augmented training subset.
    val_loader -- loader over the held-out validation subset (no augmentation).
    test_loader -- loader over the official test split (no augmentation).
    """

    def __init__(self, batch_size=32, validation_split=0.1, download=True, num_workers=4, pin_memory=True):
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.download = download
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # Per-channel statistics shared by both pipelines
        channel_stats = dict(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])

        # Training pipeline: augmentation followed by normalisation
        augmentation_steps = [
            transforms.RandomCrop(24, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(**channel_stats),
        ]
        self.train_transform = transforms.Compose(augmentation_steps)

        # Evaluation pipeline: deterministic crop to the same 24x24 size
        evaluation_steps = [
            transforms.CenterCrop(24),
            transforms.ToTensor(),
            transforms.Normalize(**channel_stats),
        ]
        self.test_transform = transforms.Compose(evaluation_steps)

        # Build the three loaders right away
        self.train_loader, self.val_loader, self.test_loader = self._prepare_loaders()

    def _cifar100(self, train, transform):
        """Open one CIFAR-100 split from ./data with the given transform."""
        return datasets.CIFAR100(root='./data', train=train, download=self.download, transform=transform)

    def _loader(self, dataset, shuffle):
        """Wrap `dataset` in a DataLoader using this object's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def _prepare_loaders(self):
        """Create and return the (train, validation, test) loaders."""
        # Training split seen through the augmenting transform
        augmented_trainset = self._cifar100(train=True, transform=self.train_transform)

        # Stratified, reproducible split of the sample ids
        sample_ids = list(range(len(augmented_trainset)))
        train_ids, val_ids = train_test_split(
            sample_ids,
            train_size=1 - self.validation_split,
            test_size=self.validation_split,
            random_state=42,
            stratify=augmented_trainset.targets,
            shuffle=True
        )

        train_loader = self._loader(Subset(augmented_trainset, train_ids), shuffle=True)

        # Same underlying images, but with the evaluation transform, so the
        # validation subset is never augmented
        plain_trainset = self._cifar100(train=True, transform=self.test_transform)
        val_loader = self._loader(Subset(plain_trainset, val_ids), shuffle=False)

        # Official test split
        testset = self._cifar100(train=False, transform=self.test_transform)
        test_loader = self._loader(testset, shuffle=False)

        return train_loader, val_loader, test_loader

    def __iter__(self):
        """Iterate over the loaders in the order train, validation, test."""
        return iter([self.train_loader, self.val_loader, self.test_loader])

### Load the dataset

In [152]:
# Hold out 10% of the official training split for validation
data_loader = CIFAR100DataLoader(
    batch_size=32,
    validation_split=0.1,
    download=True,
    num_workers=2,
    pin_memory=True,
)

# Iterating a CIFAR100DataLoader yields its loaders as (train, validation, test)
trainloader, validloader, testloader = data_loader

print("Dimension of the training dataset:", len(trainloader.dataset))
print("Dimension of the validation dataset:", len(validloader.dataset))
print("Dimension of the test dataset:", len(testloader.dataset))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Dimension of the training dataset: 45000
Dimension of the validation dataset: 5000
Dimension of the test dataset: 10000


## Checkpointing

In [153]:
import os
import torch
import json

# Root directory for every checkpoint written by this notebook.
# The path is relative, so it resolves against the current working directory.
CHECKPOINT_DIR = '../checkpoints/'

# Create the directory up front; exist_ok=True makes re-running this cell harmless
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(model, optimizer, epoch, hyperparameters, subfolder="", checkpoint_data=None):
    """
    Saves a model checkpoint and then removes older checkpoints of the same run.

    The new checkpoint is written first; only afterwards are the files of earlier
    epochs (same `hyperparameters`, lower epoch number) deleted, so a failed save
    never leaves the run without a checkpoint. Older files are found by their exact
    file name pattern, which works whatever the interval between two saves is.

    Arguments:
    model -- The model whose state is to be saved.
    optimizer -- The optimizer whose state is to be saved (can be None).
    epoch -- The current epoch (integer) of the training process.
    hyperparameters -- A string representing the model's hyperparameters for file naming.
    subfolder -- Optional subfolder within the checkpoint directory to save the checkpoint.
    checkpoint_data -- Data to save in a JSON file (e.g., training logs). Skipped when empty/None.
    """
    # Folder that holds the checkpoints of this experiment
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)
    os.makedirs(subfolder_path, exist_ok=True)

    # Both files share the same stem: model_epoch_<epoch>_params_<hyperparameters>
    name_prefix = "model_epoch_"
    name_suffix = f"_params_{hyperparameters}"
    stem = f"{name_prefix}{epoch}{name_suffix}"
    filepath = os.path.join(subfolder_path, f"{stem}.pth")
    filepath_json = os.path.join(subfolder_path, f"{stem}.json")

    # Prepare the checkpoint dictionary; the optimizer state is optional
    checkpoint = {'model_state_dict': model.state_dict(), 'epoch': epoch}
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    # Write the new checkpoint before touching any older file
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved: {filepath}")

    # If additional data (e.g., training logs) is provided, save it to a JSON file
    if checkpoint_data:
        with open(filepath_json, 'w') as json_file:
            json.dump(checkpoint_data, json_file, indent=4)

    # Remove the .pth/.json files of earlier epochs for exactly these hyperparameters.
    # The epoch is parsed out of the name so that e.g. "..._WD0.00011" is never
    # mistaken for "..._WD0.0001".
    for file_name in os.listdir(subfolder_path):
        base, extension = os.path.splitext(file_name)
        if extension not in ('.pth', '.json'):
            continue
        if not (base.startswith(name_prefix) and base.endswith(name_suffix)):
            continue
        epoch_text = base[len(name_prefix):len(base) - len(name_suffix)]
        if epoch_text.isdigit() and int(epoch_text) < epoch:
            os.remove(os.path.join(subfolder_path, file_name))

def load_checkpoint(model, optimizer, hyperparameters, subfolder=""):
    """
    Loads the latest checkpoint available for the specified hyperparameters.

    Only files named exactly `model_epoch_<N>_params_<hyperparameters>.pth` are
    considered, so a run whose hyperparameter string merely starts with the same
    characters (e.g. "WD0.00011" vs "WD0.0001") is never picked up by mistake.
    Tensors are loaded onto the CPU first; `load_state_dict` then copies them onto
    whatever device the model (and optimizer parameters) already live on, which
    lets a checkpoint written on a GPU machine be resumed on a CPU-only one.

    Arguments:
    model -- The model whose state will be updated from the checkpoint.
    optimizer -- The optimizer whose state will be updated from the checkpoint (can be None).
    hyperparameters -- A string representing the model's hyperparameters for file naming.
    subfolder -- Optional subfolder within the checkpoint directory to look for checkpoints.

    Returns:
    The next epoch to resume from (1 when nothing was found) and the associated
    JSON data if available (otherwise None).
    """
    # Define the path to the subfolder where checkpoints are stored
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)

    # If the subfolder doesn't exist, print a message and start from epoch 1
    if not os.path.exists(subfolder_path):
        print("No checkpoint found, starting from epoch 1.")
        return 1, None  # Epoch starts from 1

    # Collect {epoch: file name} for the checkpoints of exactly this configuration
    name_prefix = "model_epoch_"
    name_suffix = f"_params_{hyperparameters}.pth"
    checkpoints_by_epoch = {}
    for file_name in os.listdir(subfolder_path):
        if file_name.startswith(name_prefix) and file_name.endswith(name_suffix):
            epoch_text = file_name[len(name_prefix):len(file_name) - len(name_suffix)]
            if epoch_text.isdigit():
                checkpoints_by_epoch[int(epoch_text)] = file_name

    # If no checkpoint is found, print a message and start from epoch 1
    if not checkpoints_by_epoch:
        print("No checkpoint found, starting from epoch 1..\n\n")
        return 1, None  # Epoch starts from 1

    # Load the checkpoint with the highest epoch number
    latest_file = checkpoints_by_epoch[max(checkpoints_by_epoch)]
    filepath = os.path.join(subfolder_path, latest_file)
    checkpoint = torch.load(filepath, map_location='cpu', weights_only=True)

    # Load the model state from the checkpoint
    model.load_state_dict(checkpoint['model_state_dict'])

    # Restore the optimizer only if one was given AND its state was actually saved
    # (save_checkpoint omits it when called with optimizer=None)
    if optimizer is not None:
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            print("Checkpoint contains no optimizer state; optimizer left unchanged.")

    # Try to load the associated JSON file (same stem, .json extension)
    json_filepath = os.path.splitext(filepath)[0] + '.json'
    json_data = None
    if os.path.exists(json_filepath):
        with open(json_filepath, 'r') as json_file:
            json_data = json.load(json_file)
        print("Data loaded!")
    else:
        print("No data found")

    # Print the epoch from which the model is resuming
    print(f"Checkpoint found: Resuming from epoch {checkpoint['epoch'] + 1}\n\n")
    return checkpoint['epoch'] + 1, json_data


def delete_existing_checkpoints(subfolder=""):
    """
    Removes every regular file found directly inside a checkpoint subfolder.

    Sub-directories are left untouched. When the folder does not exist, a
    message is printed and nothing else happens.

    Arguments:
    subfolder -- Optional subfolder within the checkpoint directory to delete checkpoints from.
    """
    subfolder_path = os.path.join(CHECKPOINT_DIR, subfolder)

    # Nothing to clean up if the folder was never created
    if not os.path.exists(subfolder_path):
        print(f"No checkpoint folder found at {subfolder_path}.")
        return

    # Keep only plain files as deletion candidates, then remove them one by one
    candidates = (os.path.join(subfolder_path, entry) for entry in os.listdir(subfolder_path))
    for candidate in candidates:
        if os.path.isfile(candidate):
            os.remove(candidate)

    print(f"All existing checkpoints in {subfolder_path} have been deleted.")


# Models

In [154]:
import torch.nn as nn
import torch.nn.functional as F
"""
Model architecture for the CIFAR-100 dataset.
The model is based on the LeNet-5 architecture with some modifications.
Reference: Hsu et al., Federated Visual Classification with Real-World Data Distribution, ECCV 2020

CNN similar to LeNet5 which has two 5×5, 64-channel convolution layers, each precedes a 2×2
max-pooling layer, followed by two fully-connected layers with 384 and 192
channels respectively and finally a softmax linear classifier
"""


import torch.nn as nn
import torch.nn.functional as F
"""
Model architecture for the CIFAR-100 dataset.
The model is based on the LeNet-5 architecture with some modifications.
Reference: Hsu et al., Federated Visual Classification with Real-World Data Distribution, ECCV 2020

CNN similar to LeNet5 which has two 5×5, 64-channel convolution layers, each precedes a 2×2
max-pooling layer, followed by two fully-connected layers with 384 and 192
channels respectively and finally a softmax linear classifier
"""


class LeNet5(nn.Module):
    """
    LeNet-5 style CNN for CIFAR-100.

    Two 5x5 convolutions with 64 channels, each followed by ReLU and 2x2
    max-pooling, then fully-connected layers of 384 and 192 units and a
    100-way classifier. `forward` returns log-probabilities.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: (conv 5x5 -> ReLU -> max-pool 2x2) twice
        feature_stages = [
            nn.Conv2d(3, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.conv_layer = nn.Sequential(*feature_stages)

        # Classifier head. A 24x24 input shrinks 24 -> 20 -> 10 -> 6 -> 3 through
        # the stages above, hence 64 * 3 * 3 input features.
        classifier_stages = [
            nn.Linear(64 * 3 * 3, 384),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.ReLU(),
            nn.Linear(192, 100),  # one output per CIFAR-100 class
        ]
        self.fc_layer = nn.Sequential(*classifier_stages)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities."""
        feature_maps = self.conv_layer(x)
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, -1)  # one row per sample
        logits = self.fc_layer(flattened)
        return F.log_softmax(logits, dim=1)

# Training

### Data Sharding for IID (Independent and Identically Distributed) FL Simulation

In [155]:
import numpy as np

def _dataset_targets(dataset):
    """
    Return the class label of every sample in `dataset` as a NumPy array.

    Labels are read without decoding any image when possible: a `Subset` is
    resolved through its parent's labels and its `indices`, and a dataset that
    exposes a `targets` attribute is read directly. Only as a last resort is the
    dataset iterated sample by sample.
    """
    if isinstance(dataset, Subset):
        parent_targets = _dataset_targets(dataset.dataset)
        return parent_targets[np.asarray(dataset.indices, dtype=int)]
    targets = getattr(dataset, "targets", None)
    if targets is not None:
        return np.asarray(targets)
    return np.asarray([int(target) for _, target in dataset])


def sharding(dataset, number_of_clients, number_of_classes=100):
    """
    Splits `dataset` into `number_of_clients` disjoint partitions (one per client).

    dataset: dataset to be split (yields `(sample, label)` pairs);
    number_of_clients: the number of partitions we want to obtain;
    number_of_classes: (int) the number of classes inside each partition, or 100 for IID.

    IID case (`number_of_classes == 100`): the samples are shuffled and dealt out
    in near-equal chunks (sizes differ by at most one).

    Non-IID case: every client is assigned `number_of_classes` distinct classes
    (capped at the number of classes actually present). Classes are handed out
    round-robin over a random class order, so different clients hold different
    classes and every class is held by roughly the same number of clients. The
    samples of each class are then shuffled and split evenly among the clients
    holding it. A class is left unused only when
    `number_of_clients * number_of_classes` is smaller than the number of classes.

    Randomness comes from NumPy's global generator, so `np.random.seed` makes the
    split reproducible.

    Returns a list of `number_of_clients` `Subset` objects over `dataset`.
    Raises ValueError for an out-of-range `number_of_classes` or a
    non-positive `number_of_clients`.
    """
    # Validation of input parameters
    if not (1 <= number_of_classes <= 100):
        raise ValueError("number_of_classes should be an integer between 1 and 100")
    if number_of_clients < 1:
        raise ValueError("number_of_clients should be a positive integer")

    if number_of_classes == 100:  # IID Case
        # array_split hands the first (len % clients) chunks one extra sample
        shuffled = np.random.permutation(len(dataset))
        return [Subset(dataset, chunk.tolist()) for chunk in np.array_split(shuffled, number_of_clients)]

    # non-IID Case
    targets = _dataset_targets(dataset)
    class_order = np.random.permutation(np.unique(targets))
    n_classes = len(class_order)
    classes_per_client = min(number_of_classes, n_classes)

    # holders[p] lists the clients that receive the class at position p of class_order.
    # Consecutive positions modulo n_classes guarantee distinct classes per client.
    holders = [[] for _ in range(n_classes)]
    for client in range(number_of_clients):
        for offset in range(classes_per_client):
            holders[(client * classes_per_client + offset) % n_classes].append(client)

    # Split the (shuffled) samples of each class evenly among the clients holding it
    client_indices = [[] for _ in range(number_of_clients)]
    for position, class_ in enumerate(class_order):
        if not holders[position]:
            continue
        members = np.random.permutation(np.flatnonzero(targets == class_))
        for client, chunk in zip(holders[position], np.array_split(members, len(holders[position]))):
            client_indices[client].extend(chunk.tolist())

    return [Subset(dataset, indices) for indices in client_indices]

### Client Local Update

# Client class

In [156]:
from torch.backends import cudnn
import time


class Client:
    """A single participant of the federation: a local model plus its data loader."""

    def __init__(self, client_id, data_loader, model, device):
        """
        Initializes a federated learning client.
        :param client_id: Unique identifier for the client.
        :param data_loader: Data loader specific to the client.
        :param model: The local model instance trained by the client (moved to `device`).
        :param device: The device (CPU/GPU) to perform computations.
        """
        self.client_id = client_id
        self.data_loader = data_loader
        self.model = model.to(device)
        self.device = device

    def client_update(self, client_data, criterion, optimizer, local_steps=4, detailed_print=False):
        """
        Trains the client's local model for a fixed number of optimisation steps.

        One step is one mini-batch update. `client_data` is cycled through again
        whenever it is exhausted before `local_steps` updates were made. If
        `client_data` yields no batch at all, training stops immediately and the
        unchanged weights are returned (instead of looping forever).

        Args:
            client_data (DataLoader): Iterable of `(data, targets)` batches for this client.
            criterion (Loss): The loss function used for training (e.g., CrossEntropyLoss).
            optimizer (Optimizer): The optimizer updating this client's model parameters (e.g., SGD).
            local_steps (int): Number of mini-batch updates to perform.
            detailed_print (bool): If True, logs per-step timing and the final loss.

        Returns:
            dict: The state dictionary of the updated model.
        """
        # Let cuDNN pick the fastest convolution algorithms for the shapes in use
        cudnn.benchmark = True

        self.model.train()  # Set the model to training mode
        step_count = 0
        loss = None  # stays None when no step could be performed
        while step_count < local_steps:
            steps_before_pass = step_count
            for data, targets in client_data:
                # Move the batch to the device this client's model lives on
                data, targets = data.to(self.device), targets.to(self.device)

                start_time = time.time()  # only reported when detailed_print is set

                # Standard training step: reset grads, forward, loss, backward, update
                optimizer.zero_grad()
                outputs = self.model(data)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                if detailed_print:
                    elapsed_time = time.time() - start_time
                    print(f'Single step time taken: {elapsed_time:.4f} seconds')

                step_count += 1
                if step_count >= local_steps:
                    break

            # A full pass that produced no batch means the loader is empty
            if step_count == steps_before_pass:
                break

        # Optionally, print the loss of the last step performed
        if detailed_print and loss is not None:
            print(f'Client {self.client_id} --> Final Loss (Step {step_count}/{local_steps}): {loss.item()}')

        # Return the updated model's state dictionary (weights)
        return self.model.state_dict()

### Central Server Aggregation with FedAvg

# Server class

In [157]:
class Server:
    """Central server of the federation: owns the global model and runs FedAvg."""

    def __init__(self, global_model):
        self.global_model = global_model

    def fedavg_aggregate(self, client_states, client_sizes):
        """
        Computes the FedAvg weighted average of the clients' state dictionaries.

        Each entry is weighted by `size / sum(client_sizes)`. The average is
        accumulated in floating point and cast back to the dtype of the matching
        entry of the global model, so integer buffers are handled too (their
        average is truncated by the cast).

        Arguments:
        client_states -- list of state dictionaries, one per participating client.
        client_sizes -- list with the number of samples of each of those clients.

        Returns:
        A new state dictionary with the same keys as the global model's.

        Raises:
        ValueError -- if no client state is given, the two lists differ in length,
                      or the total number of samples is not positive.
        """
        if len(client_states) == 0:
            raise ValueError("fedavg_aggregate needs at least one client state")
        if len(client_states) != len(client_sizes):
            raise ValueError("client_states and client_sizes must have the same length")
        total_samples = sum(client_sizes)
        if total_samples <= 0:
            raise ValueError("the selected clients hold no samples")

        # Start from a copy of the global state so keys and container type are preserved
        new_state = deepcopy(self.global_model.state_dict())
        for key, reference in new_state.items():
            accumulate_dtype = reference.dtype if reference.is_floating_point() else torch.float32
            averaged = torch.zeros_like(reference, dtype=accumulate_dtype)
            for state, size in zip(client_states, client_sizes):
                averaged += state[key].to(device=averaged.device, dtype=accumulate_dtype) * (size / total_samples)
            new_state[key] = averaged.to(reference.dtype)
        return new_state

    # Federated Learning Training Loop
    def train_federated(self, criterion, trainloader, validloader, num_clients, num_classes, rounds, lr, momentum, batchsize, wd, C=0.1, local_steps=4, log_freq=10, detailed_print=False):
        """
        Runs FedAvg training of `self.global_model`.

        Rounds are numbered from 1 to `rounds` (inclusive); a run resumed from a
        checkpoint continues from the round after the saved one. Each round:
        1) a fraction `C` of the clients (at least one) is sampled at random;
        2) every sampled client trains a copy of the global model for
           `local_steps` mini-batch steps with SGD;
        3) the client models are averaged with `fedavg_aggregate` and become the
           new global model, which is then evaluated on `validloader`.
        Every `log_freq` rounds the model is also evaluated on `trainloader`,
        progress is printed and a checkpoint is saved under "Federated/".

        NOTE: the best-model tracking only covers the rounds executed by this call;
        after a resume, earlier rounds are not considered. The client shards are
        also re-drawn on every call.

        Arguments:
        criterion -- loss function used for local training and evaluation.
        trainloader -- loader whose `.dataset` is sharded among the clients.
        validloader -- loader used for the per-round validation.
        num_clients -- total number of clients.
        num_classes -- classes per client, forwarded to `sharding` (100 = IID).
        rounds -- total number of communication rounds.
        lr, momentum, wd -- SGD learning rate, momentum and weight decay of the clients.
        batchsize -- mini-batch size of the clients' loaders.
        C -- fraction of clients selected per round.
        local_steps -- local mini-batch steps per selected client.
        log_freq -- logging / checkpointing period in rounds.
        detailed_print -- forward verbose logging to the clients on logging rounds.

        Returns:
        (global_model, val_accuracies, val_losses, train_accuracies, train_losses,
        client_selection_count), where the model holds the best validation weights
        seen during this call.
        """
        val_accuracies = []
        val_losses = []
        train_accuracies = []
        train_losses = []
        client_selection_count = [0] * num_clients  # how many times each client was selected
        best_val_acc = 0.0

        # Each shard is the training data of one client
        shards = sharding(trainloader.dataset, num_clients, num_classes)
        client_sizes = [len(shard) for shard in shards]

        # Move the global model to the compute device
        self.global_model.to(DEVICE)

        # Resume from a checkpoint of this configuration if one exists
        hyperparameters = f"LR{lr}_WD{wd}"
        checkpoint_start_step, data_to_load = load_checkpoint(model=self.global_model, optimizer=None, hyperparameters=hyperparameters, subfolder="Federated/")
        if data_to_load is not None:
            val_accuracies = data_to_load['val_accuracies']
            val_losses = data_to_load['val_losses']
            train_accuracies = data_to_load['train_accuracies']
            train_losses = data_to_load['train_losses']
            client_selection_count = data_to_load['client_selection_count']

        # Fall back to the starting weights so the final load_state_dict never
        # receives None (no round executed, or accuracy never above 0)
        best_model_state = deepcopy(self.global_model.state_dict())

        # Always train at least one client, and never more than exist
        clients_per_round = min(num_clients, max(1, int(C * num_clients)))

        for round_num in range(checkpoint_start_step, rounds + 1):
            is_log_round = round_num % log_freq == 0
            if is_log_round:
                print(f"------------------------------------- Round {round_num} ------------------------------------------------" )

            # 1) client selection: only a random fraction of the clients takes part
            selected_clients = random.sample(range(num_clients), clients_per_round)
            client_states = []

            # 2) local training: each selected client updates its own copy of the global model
            for client_id in selected_clients:
                client_selection_count[client_id] += 1

                local_model = deepcopy(self.global_model)
                optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
                client_loader = DataLoader(shards[client_id], batch_size=batchsize, shuffle=True)

                client = Client(client_id, client_loader, local_model, DEVICE)
                local_state = client.client_update(client_loader, criterion, optimizer, local_steps, is_log_round and detailed_print)
                client_states.append(local_state)

            # 3) central aggregation: the weighted average replaces the global weights
            aggregated_state = self.fedavg_aggregate(client_states, [client_sizes[i] for i in selected_clients])
            self.global_model.load_state_dict(aggregated_state)

            # Validation at the server (every round)
            val_accuracy, val_loss = evaluate(self.global_model, validloader, criterion)
            val_accuracies.append(val_accuracy)
            val_losses.append(val_loss)
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                best_model_state = deepcopy(self.global_model.state_dict())

            if is_log_round:
                train_accuracy, train_loss = evaluate(self.global_model, trainloader, criterion)
                train_accuracies.append(train_accuracy)
                train_losses.append(train_loss)

                print(f"--> best validation accuracy: {best_val_acc}\n--> training accuracy: {train_accuracy}")
                print(f"--> validation loss: {val_loss}\n--> training loss: {train_loss}")

                # checkpointing
                checkpoint_data = {
                    'val_accuracies': val_accuracies,
                    'val_losses': val_losses,
                    'train_accuracies': train_accuracies,
                    'train_losses': train_losses,
                    'client_selection_count': client_selection_count
                }
                save_checkpoint(model=self.global_model, optimizer=None, epoch=round_num, hyperparameters=hyperparameters, subfolder="Federated/", checkpoint_data=checkpoint_data)

                print(f"------------------------------ Round {round_num} terminated: model updated -----------------------------\n\n" )

        # Hand back the weights that achieved the best validation accuracy
        self.global_model.load_state_dict(best_model_state)

        return self.global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count


In [158]:
def evaluate(model, dataloader, criterion):
    """
    Evaluates `model` on every batch of `dataloader` without tracking gradients.

    Batches are moved to the device the model's parameters live on, so the
    function does not depend on any global device setting. The model is left in
    evaluation mode afterwards.

    Arguments:
    model -- the network to evaluate.
    dataloader -- loader yielding `(data, targets)` batches; its `.dataset` length
                  is the denominator of the accuracy.
    criterion -- loss function applied to `(outputs, targets)` of each batch.

    Returns:
    (accuracy, loss) -- fraction of correctly classified samples and the mean of
    the per-batch loss values.

    Raises:
    ValueError -- if the dataloader yields no batch.
    """
    # Device of the model (None for a parameter-free model: batches are then left as they are)
    first_parameter = next(model.parameters(), None)
    device = first_parameter.device if first_parameter is not None else None

    model.train(False)  # Set Network to evaluation mode
    running_corrects = 0
    losses = []

    with torch.no_grad():
        for data, targets in dataloader:
            if device is not None:
                data = data.to(device)
                targets = targets.to(device)

            # Forward pass and batch loss
            outputs = model(data)
            losses.append(criterion(outputs, targets).item())

            # Predicted class = index of the largest output
            _, preds = torch.max(outputs, 1)
            running_corrects += (preds == targets).sum().item()

    if not losses:
        raise ValueError("evaluate() received a dataloader that yields no batches")

    # Accuracy over the whole dataset, computed once after the loop
    accuracy = running_corrects / float(len(dataloader.dataset))

    return accuracy, mean(losses)

### Federated Learning Training Loop

### Initialize Model & Loss

In [159]:
# Fresh global model; LeNet5.forward ends with log_softmax, so its outputs are log-probabilities
global_model = LeNet5()
criterion = nn.NLLLoss()  # negative log-likelihood: the loss that pairs with log_softmax outputs


### Run the training

In [None]:
# Hyperparameters of this run: for now a single configuration taken from the constants
lr = LR
wd = WEIGHT_DECAY
# Uncomment the next line to delete the files in the "Federated/" checkpoint folder,
# so that training starts again from round 1 instead of resuming:
#delete_existing_checkpoints("Federated/")
# Instantiate the server around the global model
server = Server(global_model)
# Run federated learning; num_classes=100 selects the IID split in sharding()
global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(
    criterion=criterion,
    trainloader=trainloader,
    validloader=validloader,
    num_clients=NUM_CLIENTS,
    num_classes=100,
    rounds=GLOBAL_ROUNDS,
    lr=lr,
    momentum=MOMENTUM,
    batchsize=BATCH_SIZE,
    wd=wd,
    C=FRACTION_CLIENTS,
    local_steps=LOCAL_STEPS,
    log_freq=LOG_FREQUENCY,
    detailed_print=False
)

No checkpoint found, starting from epoch 1.
------------------------------------- Round 10 ------------------------------------------------
--> best validation accuracy: 0.0108
--> training accuracy: 0.009888888888888888
--> validation loss: 4.605470189623013
--> training loss: 4.605592923547329
Checkpoint saved: ../checkpoints/Federated/model_epoch_10_params_LR0.01_WD0.0001.pth
------------------------------ Round 10 terminated: model updated -----------------------------


------------------------------------- Round 20 ------------------------------------------------
--> best validation accuracy: 0.011
--> training accuracy: 0.01071111111111111
--> validation loss: 4.604295551397239
--> training loss: 4.604527847599119
Checkpoint saved: ../checkpoints/Federated/model_epoch_20_params_LR0.01_WD0.0001.pth
------------------------------ Round 20 terminated: model updated -----------------------------


------------------------------------- Round 30 ---------------------------------------

# Validation

### Run the test



In [None]:
# evaluate() returns (accuracy, mean loss); only the accuracy on the test split is reported
accuracy = evaluate(global_model, testloader, criterion)[0]
print('\nTest Accuracy: {}'.format(accuracy))