# Federated Learning Project
This notebook demonstrates how to set up and compare Federated Learning (FL) with Centralized Learning (CL) using the CIFAR-100 dataset and the LeNet-5 model.

## 1. Setup
We start by importing necessary libraries and setting global constants for the experiments.

In [16]:
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
import random
from torch.utils.data import Subset, DataLoader
from statistics import mean

from models.model import LeNet5 #import the model

sys.path.append('../data/cifar100/')
from cifar100_loader import load_cifar100

from federated_utils import sharding, client_update, fedavg_aggregate

# Constants for FL training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_CLIENTS = 100  # Total number of clients in the federation
FRACTION_CLIENTS = 0.1  # Fraction of clients selected per round (C)
LOCAL_EPOCHS = 4  # Number of local epochs each client runs per round (J)
GLOBAL_ROUNDS = 2000  # Total number of communication rounds

BATCH_SIZE = 64  # Batch size for local training
LR = 0.005  # Initial learning rate for local optimizers (best value found in the centralized experiments)
MOMENTUM = 0.9  # Momentum for SGD optimizer
WEIGHT_DECAY = 5e-5  # Regularization term for local training

LOG_FREQUENCY = 10  # Frequency of logging training progress

## 2. Data Loading
We load the CIFAR-100 dataset and split it into training, validation, and test sets, holding out 25% of the original training data for validation. This is done with `load_cifar100` from the `cifar100_loader` module.

In [17]:
#load the dataset
trainloader, validloader, testloader = load_cifar100(batch_size=BATCH_SIZE, validation_split=0.25)

print("Data loaded successfully!\n")
print("Dimension of the training dataset:", len(trainloader.dataset))
print("Dimension of the validation dataset:", len(validloader.dataset))
print("Dimension of the test dataset:", len(testloader.dataset))

Files already downloaded and verified
Files already downloaded and verified
Data loaded successfully!

Dimension of the training dataset: 37500
Dimension of the validation dataset: 12500
Dimension of the test dataset: 10000
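
Combining these sizes with the constants from Setup gives a rough picture of the work done in each round. The sketch below is only a back-of-the-envelope calculation: it assumes `sharding` splits the training set into equal-sized shards, which this notebook does not show.

```python
import math

# Values copied from the Setup cell and from the output above.
NUM_CLIENTS = 100
FRACTION_CLIENTS = 0.1
LOCAL_EPOCHS = 4
BATCH_SIZE = 64
TRAIN_SIZE = 37500

# Assumption: equal-sized shards (the real `sharding` may behave differently).
samples_per_client = TRAIN_SIZE // NUM_CLIENTS
clients_per_round = int(FRACTION_CLIENTS * NUM_CLIENTS)
batches_per_epoch = math.ceil(samples_per_client / BATCH_SIZE)
updates_per_round = batches_per_epoch * LOCAL_EPOCHS  # SGD steps per selected client

print(samples_per_client, clients_per_round, batches_per_epoch, updates_per_round)
# prints: 375 10 6 24
```

Under that assumption, each selected client sees only 375 images and performs 24 SGD updates before its weights are sent back for aggregation.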


## 3. Federated Training
We simulate federated learning by splitting the training set into one shard per client and, in each round, training only on a randomly selected fraction of the clients.
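
The `sharding` function used below comes from `federated_utils` and its implementation is not shown in this notebook. As an illustration only, an IID split could look like the following sketch, where `iid_shard_indices` and its `seed` parameter are hypothetical names introduced here.

```python
import random

def iid_shard_indices(num_samples, num_clients, seed=0):
    """Hypothetical helper: shuffle sample indices and deal them into equal shards."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # a seeded shuffle keeps the split reproducible
    shard_size = num_samples // num_clients
    # Client i receives the i-th contiguous slice of the shuffled indices.
    return [indices[i * shard_size:(i + 1) * shard_size] for i in range(num_clients)]

shards = iid_shard_indices(num_samples=37500, num_clients=100)
print(len(shards), len(shards[0]))
```

Each index list could then be wrapped in `torch.utils.data.Subset` to obtain a per-client dataset. A non-IID split would instead group the indices by label before dealing them out.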

### Initialize Model & Loss

In [18]:
global_model = LeNet5()
criterion = nn.NLLLoss()  # Negative log-likelihood loss; it expects log-probabilities, so this assumes LeNet5 ends with log_softmax

### Federated Learning Training Loop

In [19]:
# Federated Learning Training Loop
def federated_training(global_model, criterion, dataset, valid_dataset, num_clients, rounds, C=0.1, local_steps=4, detailed_print=False):
    shards = sharding(dataset, num_clients)  # Each shard represents the training data of one client
    client_sizes = [len(shard) for shard in shards]

    global_model.to(DEVICE)  # Move the global model to the selected device (CPU or GPU)

    # ********************* HOW IT WORKS ***************************************
    # The training runs for `rounds` iterations (GLOBAL_ROUNDS = 2000 in this notebook).
    # Each round simulates one communication step in federated learning, including:
    # 1) client selection
    # 2) local training (of each client)
    # 3) central aggregation
    for round_num in range(rounds):
        if round_num % LOG_FREQUENCY == 0 and detailed_print:
            print(f"------------------------------------- Round {round_num} ------------------------------------------------")
        # 1) client selection: in each round, a fraction C (e.g., 10%) of the clients is sampled at random to participate.
        #    This reduces computation cost and mimics real deployments, where not all devices are available at once.
        selected_clients = random.sample(range(num_clients), max(1, int(C * num_clients)))  # always select at least one client
        client_states = []

        # 2) local training: each selected client trains a copy of the global model on its own shard for local_steps epochs
        for client_id in selected_clients:
            local_model = deepcopy(global_model)  # Local copy, so the global model stays untouched during the round
            optimizer = optim.SGD(local_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)  # Same settings as the centralized version
            client_loader = DataLoader(shards[client_id], batch_size=BATCH_SIZE, shuffle=True)

            local_state = client_update(local_model, client_id, client_loader, criterion, optimizer, local_steps, round_num % LOG_FREQUENCY == 0 and detailed_print)
            client_states.append(local_state)

        # 3) central aggregation: aggregates participating client updates using fedavg_aggregate
        #    and replaces the current parameters of global_model with the returned ones.
        global_model.load_state_dict(fedavg_aggregate(global_model, client_states, [client_sizes[i] for i in selected_clients]))

        if round_num % LOG_FREQUENCY == 0 and detailed_print:
            print(f"------------------------------ Round {round_num} terminated: model updated -----------------------------\n\n")
        # Optional server-side validation (requires `evaluate` from Section 4 to be defined first):
        # if round_num % LOG_FREQUENCY == 0:
        #     val_accuracy, val_loss = evaluate(global_model, DataLoader(valid_dataset, batch_size=BATCH_SIZE), criterion)
        #     print(f"Round {round_num}: Validation Loss = {val_loss:.4f}, Accuracy = {100 * val_accuracy:.2f}%")


    return global_model  # Return the updated global model
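
Step 3 relies on `fedavg_aggregate`, which lives in `federated_utils` and is not shown here. The sketch below illustrates the size-weighted average that FedAvg computes, using plain Python lists in place of tensors. The name `weighted_average_states` and the toy parameter `"w"` are hypothetical, and the real function may differ in its details.

```python
def weighted_average_states(client_states, client_sizes):
    """Average each parameter across clients, weighted by local dataset size."""
    total = sum(client_sizes)
    averaged = {}
    for name in client_states[0]:
        length = len(client_states[0][name])
        averaged[name] = [
            sum(state[name][i] * size for state, size in zip(client_states, client_sizes)) / total
            for i in range(length)
        ]
    return averaged

# Two toy clients: the second holds three times as much data as the first.
states = [{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}]
sizes = [100, 300]
print(weighted_average_states(states, sizes))
# prints: {'w': [2.5, 5.0]}
```

The result sits closer to the second client's weights because that client contributes 300 of the 400 samples. With equal-sized shards the weighting reduces to a plain mean.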

### Run the training

In [20]:
# Run Federated Learning
refined_model = federated_training(global_model, criterion, trainloader.dataset, validloader.dataset, num_clients=NUM_CLIENTS, rounds=GLOBAL_ROUNDS, C=FRACTION_CLIENTS, local_steps=LOCAL_EPOCHS, detailed_print=True)

------------------------------------- Round 0 ------------------------------------------------
Client 26 --> Final Loss (Epoch 4): 4.602545261383057
Client 13 --> Final Loss (Epoch 4): 4.594453811645508
Client 31 --> Final Loss (Epoch 4): 4.611208915710449
Client 39 --> Final Loss (Epoch 4): 4.608333110809326
Client 4 --> Final Loss (Epoch 4): 4.599559307098389
Client 58 --> Final Loss (Epoch 4): 4.595484733581543
Client 49 --> Final Loss (Epoch 4): 4.6042890548706055
Client 98 --> Final Loss (Epoch 4): 4.601893901824951
Client 92 --> Final Loss (Epoch 4): 4.609266757965088
Client 15 --> Final Loss (Epoch 4): 4.601667881011963
------------------------------ Round 0 terminated: model updated -----------------------------


------------------------------------- Round 10 ------------------------------------------------
Client 43 --> Final Loss (Epoch 4): 4.610133171081543
Client 20 --> Final Loss (Epoch 4): 4.604770660400391
Client 99 --> Final Loss (Epoch 4): 4.592456340789795
Client 68 

KeyboardInterrupt: 
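
The losses logged before the interruption all sit close to 4.60. For a 100-class problem, this is the negative log-likelihood of a uniform guess, which suggests the model had barely started to learn by round 10. A quick check, assuming the loss is an NLL over log-probabilities as set up above:

```python
import math

NUM_CLASSES = 100  # CIFAR-100

# NLL of a model that assigns probability 1/100 to every class.
uniform_nll = -math.log(1 / NUM_CLASSES)
print(round(uniform_nll, 3))
# prints: 4.605
```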

## 4. Validation

In [None]:
def evaluate(model, dataloader, criterion):
    model.eval()  # Set the network to evaluation mode
    running_corrects = 0
    losses = []
    with torch.no_grad():  # Gradients are not needed for evaluation
        for data, targets in dataloader:
            data = data.to(DEVICE)        # Move the data to the selected device
            targets = targets.to(DEVICE)  # Move the targets to the selected device
            # Forward pass
            outputs = model(data)
            loss = criterion(outputs, targets)
            losses.append(loss.item())
            # Predicted class = index of the highest output
            preds = outputs.argmax(dim=1)
            # Count the correct predictions
            running_corrects += (preds == targets).sum().item()

    # Accuracy over the whole dataset, computed once after the loop
    accuracy = running_corrects / len(dataloader.dataset)
    return accuracy, mean(losses)  # mean(losses) averages the per-batch losses

### Run the test

In [None]:
accuracy = evaluate(refined_model, testloader, criterion)[0]
print('\nTest Accuracy: {}'.format(accuracy))