# Federated Learning Project
This notebook demonstrates how to set up and compare Federated Learning (FL) with Centralized Learning (CL) using the CIFAR-100 dataset and the modified version of the LeNet-5 model taken from [Hsu et al., Federated Visual Classification with Real-World Data Distribution, ECCV 2020].

## 1. Setup
We start by importing necessary libraries and setting global constants for the experiments.

In [1]:
import sys
import torch
import torch.nn as nn

from models.model import LeNet5 #import the model
import numpy as np
sys.path.append('../data/cifar100/')
from cifar100_loader import CIFAR100DataLoader
from Server import Server
from utils.federated_utils import plot_metrics,test, plot_client_selection,save_data,load_data

# Constants

In [3]:
# Constants for FL training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

NUM_CLIENTS = 100  # Total number of clients in the federation
FRACTION_CLIENTS = 0.1  # Fraction of clients selected per round (C)
LOCAL_STEPS = 4  # Number of local steps (J)
GLOBAL_ROUNDS = 2000  # Total number of communication rounds

BATCH_SIZE = 64  # Batch size for local training
LR = 1e-3  # Initial learning rate for local optimizers
MOMENTUM = 0.9  # Momentum for SGD optimizer
WEIGHT_DECAY = 0.0001  # Regularization term for local training
CHECKPOINT_DIR = './checkpoints/'
LOG_FREQUENCY = 10  # Frequency of logging training progress

cuda


## 2. Data Loading
We load the CIFAR-100 dataset and split it into training, validation, and test sets. This is done using the `data_loader.py` module.

In [3]:
#10% of the dataset kept for validation
data_loader = CIFAR100DataLoader(batch_size=BATCH_SIZE, validation_split=0.1, download=True, num_workers=4, pin_memory=True)
trainloader, validloader, testloader = data_loader.train_loader, data_loader.val_loader, data_loader.test_loader

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## 3. Federated Training
We simulate federated learning by splitting the dataset into shards and training with selected clients in each round.

### Initialize Model & Loss

In [2]:
global_model = LeNet5()
criterion = nn.NLLLoss()# our loss function for classification tasks on CIFAR-100

In [14]:
import os
def load_data(model, file_name):
    """
    Load the model weights and metrics from a file.
    
    Args:
        model (nn.Module): The model to load the weights into.
        file_name (str): Name of the file to load the data from.
    
    Returns:
        tuple: A tuple containing the model, val_accuracies, val_losses, train_accuracies, and train_losses.
    """
    # Fixed base directory
    directory = './trained_models/'
    # Complete path for the file
    file_path = os.path.join(directory, file_name)
    
    # Load the saved data from the specified file
    save_dict = torch.load(file_path)
    
    # Load the model state
    model.load_state_dict(save_dict['model_state'])
    
    # Extract the metrics
    val_accuracies = save_dict['val_accuracies']
    val_losses = save_dict['val_losses']
    train_accuracies = save_dict['train_accuracies']
    train_losses = save_dict['train_losses']
    
    print(f"Data loaded successfully from {file_path}")
    
    return model, val_accuracies, val_losses, train_accuracies, train_losses


mm = LeNet5()

model, val_accuracies, val_losses, train_accuracies, train_losses = load_data(mm,'FederatedBaseline.pth')
val_accuracies = [val_accuracies[i]*100 for i in range(len(val_accuracies))]
plot_metrics(train_losses,train_accuracies,val_losses,val_accuracies,f"FederatedBaselineTuning_lr_{0.1}_wd_{0.001}.png")


  save_dict = torch.load(file_path)


Data loaded successfully from ./trained_models/FederatedBaseline.pth


# Hyperparameters tuning

In [None]:
""" 
Hyperparameter tuning for the learning rate and weight decay
J=4, rounds = 100
"""
# Generate 3 values for the learning rate (lr) between 1e-3 and 1e-1 in log-uniform
lr_values = np.logspace(-3, -1, num=3)

# Generate 4 values for the weight decay (lr) between 1e-4 and 1e-1 in log-uniform
wd_values = np.logspace(-4, -1, num=4)

print("Learning Rate Values (log-uniform):", lr_values)
print("Weight Decay Values (log-uniform):", wd_values)

rounds = 100 #fewer communication rounds for hyperparameter tuning
best_val_accuracy = 0
best_setting = None
for lr in lr_values:
    for wd in wd_values:
        print(f"Learning rate: {lr}, Weight decay: {wd}")
        global_model = LeNet5() 
        server = Server(global_model, DEVICE, CHECKPOINT_DIR)                                                                   
        global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(criterion, trainloader, validloader, num_clients=NUM_CLIENTS, num_classes=100, rounds=rounds, lr=lr, momentum=MOMENTUM, batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS, local_steps=LOCAL_STEPS)
        plot_metrics(train_accuracies, train_losses,val_accuracies, val_losses, f"FederatedBaselineTuning_lr_{lr}_wd_{wd}.png")
        print(f"Validation accuracy: {val_accuracies[-1]} with lr: {lr} and wd: {wd}")
        max_val_accuracy = max(val_accuracies)
        if max_val_accuracy > best_val_accuracy:
            best_val_accuracy = max_val_accuracy
            best_setting = (lr,wd)
print(f"Best setting: {best_setting} with validation accuracy: {best_val_accuracy}")

Learning Rate Values (log-uniform): [0.001 0.01  0.1  ]
Weight Decay Values (log-uniform): [0.0001 0.001  0.01   0.1   ]
Learning rate: 0.001, Weight decay: 0.0001


No checkpoint found, starting from epoch 1.
------------------------------------- Round 10 ------------------------------------------------
Client 43 --> Final Loss (Step 4/4): 4.602455139160156
Client 76 --> Final Loss (Step 4/4): 4.612491607666016
Client 74 --> Final Loss (Step 4/4): 4.620962619781494
Client 11 --> Final Loss (Step 4/4): 4.60128116607666
Client 14 --> Final Loss (Step 4/4): 4.6037163734436035
Client 91 --> Final Loss (Step 4/4): 4.608407974243164
Client 85 --> Final Loss (Step 4/4): 4.60923957824707
Client 13 --> Final Loss (Step 4/4): 4.615657329559326
Client 33 --> Final Loss (Step 4/4): 4.610918045043945
Client 59 --> Final Loss (Step 4/4): 4.602585315704346


KeyboardInterrupt: 

# Training and testing

In [None]:
""" 
Training and testing with J=4, 2000 communication rounds
"""

lr = 0.1
wd = 0.001
global_model = LeNet5() 
server = Server(global_model, DEVICE, CHECKPOINT_DIR)                                                                   
global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(criterion, trainloader, validloader, num_clients=NUM_CLIENTS, num_classes=100, rounds=GLOBAL_ROUNDS, lr=lr, momentum=MOMENTUM, batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS, local_steps=LOCAL_STEPS)
test_accuracy = test(global_model, testloader)
print(f"Test accuracy: {test_accuracy}")
#If needed for future plots or analysis, no need to train again
save_data(global_model, val_accuracies, val_losses, train_accuracies, train_losses, "FederatedBaseline.pth")

# Tests with J = 8 and J = 16
Communication rounds accordingly reduced to 1000 and 500 respectively

# Hyperparameters tuning

In [None]:
""" 
Hyperparameter tuning with J=8, 50 communication rounds
"""
# Generate 3 values for the learning rate (lr) between 1e-3 and 1e-1 in log-uniform
lr_values = np.logspace(-3, -1, num=3)

# Generate 4 values for the weight decay (lr) between 1e-4 and 1e-1 in log-uniform
wd_values = np.logspace(-4, -1, num=4)

print("Learning Rate Values (log-uniform):", lr_values)
print("Weight Decay Values (log-uniform):", wd_values)

rounds = 50 #fewer communication rounds for hyperparameter tuning
best_val_accuracy = 0
best_setting = None
for lr in lr_values:
    for wd in wd_values:
        print(f"Learning rate: {lr}, Weight decay: {wd}")
        global_model = LeNet5() 
        server = Server(global_model, DEVICE, CHECKPOINT_DIR)                                                                   
        global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(criterion, trainloader, validloader, num_clients=NUM_CLIENTS, num_classes=100, rounds=rounds, lr=lr, momentum=MOMENTUM, batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS, local_steps=8)
        plot_metrics(train_accuracies, train_losses,val_accuracies, val_losses, f"FederatedTuningJequalto8_lr_{lr}_wd_{wd}.png")
        print(f"Validation accuracy: {val_accuracies[-1]} with lr: {lr} and wd: {wd}")
        max_val_accuracy = max(val_accuracies)
        if max_val_accuracy > best_val_accuracy:
            best_val_accuracy = max_val_accuracy
            best_setting = (lr,wd)
print(f"Best setting: {best_setting} with validation accuracy: {best_val_accuracy}")

In [None]:
""" 
Hyperparameter tuning with J=16, 25 communication rounds
"""
# Generate 3 values for the learning rate (lr) between 1e-3 and 1e-1 in log-uniform
lr_values = np.logspace(-3, -1, num=3)

# Generate 4 values for the weight decay (lr) between 1e-4 and 1e-1 in log-uniform
wd_values = np.logspace(-4, -1, num=4)

print("Learning Rate Values (log-uniform):", lr_values)
print("Weight Decay Values (log-uniform):", wd_values)

rounds = 25 #fewer communication rounds for hyperparameter tuning
best_val_accuracy = 0
best_setting = None
for lr in lr_values:
    for wd in wd_values:
        print(f"Learning rate: {lr}, Weight decay: {wd}")
        global_model = LeNet5() 
        server = Server(global_model, DEVICE, CHECKPOINT_DIR)                                                                   
        global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(criterion, trainloader, validloader, num_clients=NUM_CLIENTS, num_classes=100, rounds=rounds, lr=lr, momentum=MOMENTUM, batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS, local_steps=16)
        plot_metrics(train_accuracies, train_losses,val_accuracies, val_losses, f"FederatedTuningJequalto16_lr_{lr}_wd_{wd}.png")
        print(f"Validation accuracy: {val_accuracies[-1]} with lr: {lr} and wd: {wd}")
        max_val_accuracy = max(val_accuracies)
        if max_val_accuracy > best_val_accuracy:
            best_val_accuracy = max_val_accuracy
            best_setting = (lr,wd)
print(f"Best setting: {best_setting} with validation accuracy: {best_val_accuracy}")

# Training and testing

In [None]:
""" 
8 local steps, 1000 rounds
"""
#lr and wd to be defined based on the hyperparameter tuning
#lr = 0.1
#wd = 0.001
global_model = LeNet5() 
server = Server(global_model, DEVICE, CHECKPOINT_DIR)                                                                   
global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(criterion, trainloader, validloader, num_clients=NUM_CLIENTS, num_classes=100, rounds=1000, lr=lr, momentum=MOMENTUM, batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS, local_steps=8)
test_accuracy = test(global_model, testloader)
print(f"Test accuracy: {test_accuracy}")
#If needed for future plots or analysis, no need to train again
save_data(global_model, val_accuracies, val_losses, train_accuracies, train_losses, "FederatedJequalto8.pth")

In [None]:
"""
16 local steps, 500 rounds
"""
#lr and wd to be defined based on the hyperparameter tuning
#lr = 0.1
#wd = 0.001
global_model = LeNet5() 
server = Server(global_model, DEVICE, CHECKPOINT_DIR)                                                                   
global_model, val_accuracies, val_losses, train_accuracies, train_losses, client_selection_count = server.train_federated(criterion, trainloader, validloader, num_clients=NUM_CLIENTS, num_classes=100, rounds=500, lr=lr, momentum=MOMENTUM, batchsize=BATCH_SIZE, wd=wd, C=FRACTION_CLIENTS, local_steps=16)
test_accuracy = test(global_model, testloader)
print(f"Test accuracy: {test_accuracy}")
#If needed for future plots or analysis, no need to train again
save_data(global_model, val_accuracies, val_losses, train_accuracies, train_losses, "FederatedJequalto16.pth")