In [None]:
"""Federated Learning Hyperparameters Configuration.
Change this cell to set the hyperparameters for your federated learning experiment."""

# Number of clients (K)
K = 100
# Classes per client (N); CIFAR-100 has 100 classes, so N is at most 100
N = 100
# Fraction of clients sampled in each communication round (C)
C = 0.1
# Number of local training steps per selected client per round (J)
J = 4

# Fail fast on invalid settings instead of deep inside data loading / training.
assert isinstance(K, int) and K >= 1, "K must be a positive integer"
assert isinstance(N, int) and 1 <= N <= 100, "N must be in [1, 100] (CIFAR-100 has 100 classes)"
assert 0 < C <= 1, "C must be in (0, 1]"
assert isinstance(J, int) and J >= 1, "J must be a positive integer"

In [None]:
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from data.cifar100_loader import get_federated_cifar100_dataloaders
from model.prepare_model import get_dino_vits16_model, freeze_backbone
from model.hyperparameter_tuning import run_grid_search_federated
from eval import evaluate
from train import train
from model.federated_averaging import get_trainable_keys, train_on_client, average_metrics, average_models

In [None]:
# Compute device: use CUDA whenever it is available, otherwise fall back to CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoints are written under this folder; create it up front (no-op if it already exists)
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Loss shared by the grid search, local client training and evaluation
criterion = nn.CrossEntropyLoss()

# Per-client training datasets plus a validation loader for hyperparameter tuning
# (val_split=0.1 — presumably a 10% hold-out; confirm in the loader).
# federated_test=True makes the FedAvg loop evaluate on one test loader per client;
# False evaluates on a single aggregated test loader.
federated_test = False
train_datasets, val_loader, test_loaders, client_class_map = get_federated_cifar100_dataloaders(
    K, N, federatedTest=federated_test, val_split=0.1
)

In [None]:
# Candidate configurations: one dict per learning rate to try
configs = [{'lr': lr} for lr in (0.01, 0.005, 0.001)]

# Run the federated grid search; best_result holds the winning config under 'cfg'
best_result, results = run_grid_search_federated(
    train_datasets,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs,
    num_clients=K,
    C=C,
    steps=J,
    device=device,
)

In [None]:
# Winning hyperparameters from the grid search (only best_cfg['lr'] is used below)
best_cfg = best_result['cfg']

# Fresh model for the federated run, with its backbone frozen
# (presumably leaving only the head trainable — see freeze_backbone)
collaborative_model = get_dino_vits16_model(device)
freeze_backbone(collaborative_model)

In [None]:
# Rebuild the client datasets with no validation hold-out (val_split=0) for the
# final federated run; the validation loader is discarded.
# NOTE(review): if the loader partitions clients randomly, these shards may differ
# from the ones used in the grid search — confirm the partitioning is deterministic.
train_datasets, _, test_loaders, client_class_map = get_federated_cifar100_dataloaders(
    K, N, federatedTest=federated_test, val_split=0
)

In [None]:
# Index of the first round (used for round numbering and checkpoint names).
start_round = 0
# 800 // J rounds keeps rounds * local steps ~= 800, so runs with different J
# get a comparable total local-step budget.
num_rounds = 800 // J
# Best global-model test accuracy seen so far (updated by the FedAvg loop).
best_test_acc = 0.0

# Per-round metric histories, filled by the FedAvg loop and plotted afterwards.
hist_train_loss = []
hist_train_acc = []
hist_test_loss = []
hist_test_acc = []

# Seed Python's `random` (client sampling) and torch (presumably used by local
# training, e.g. data shuffling) so the FedAvg loop is repeatable.
# NOTE(review): data partitioning and model initialisation run in earlier cells,
# before these seeds are set; seed in the config cell for full reproducibility.
random.seed(42)
torch.manual_seed(42)

In [None]:
# FedAvg loop: each round samples a fraction C of the K clients, trains each one
# locally for J steps from the current global model, averages the trainable
# parameters weighted by client dataset size, then evaluates the new global model.

# Clients per round: m = max(C * K, 1) as in FedAvg. The tiny epsilon keeps float
# error (e.g. 0.29 * 100 == 28.999...) from truncating away a client, and the
# max() avoids sampling zero clients (which would divide by zero below).
num_selected_clients = max(1, int(C * K + 1e-9))

# Loop-invariant: the trainable parameter names (presumably derived from
# requires_grad, which load_state_dict does not change) are computed once.
trainable_keys = get_trainable_keys(collaborative_model)

if federated_test:
    # Loop-invariant: per-client test-set sizes used to weight the test metrics.
    test_sample_counts = [len(test_loaders[c].dataset) for c in range(K)]
    total_test_samples = sum(test_sample_counts)
    test_weights = [count / total_test_samples for count in test_sample_counts]

last_round = start_round + num_rounds
# `round_idx` (not `round`) so the builtin round() is not shadowed in the notebook.
for round_idx in range(start_round, last_round):
    print(f"\n--- Round {round_idx + 1}/{last_round} ---")

    # Select clients uniformly at random, without replacement
    selected_clients = random.sample(range(K), num_selected_clients)

    # Local training: every selected client starts from the current global model
    local_models, train_losses, train_accs = [], [], []
    for client_id in selected_clients:
        model_state, loss, acc = train_on_client(
            client_id,
            collaborative_model,
            train_datasets[client_id],
            J,
            criterion,
            best_cfg['lr'],
            device
        )
        local_models.append(model_state)
        train_losses.append(loss)
        train_accs.append(acc)

    # Weight each client by its share of the selected clients' training samples
    client_sample_counts = [len(train_datasets[c]) for c in selected_clients]
    total_samples = sum(client_sample_counts)
    client_weights = [count / total_samples for count in client_sample_counts]

    # Federated averaging of the trainable parameters only; all other entries
    # are kept from the current global state dict.
    averaged_state = average_models(local_models, client_weights, trainable_keys)
    new_state = collaborative_model.state_dict()
    new_state.update(averaged_state)
    collaborative_model.load_state_dict(new_state)

    # Log average training metrics
    avg_train_loss = average_metrics(train_losses, client_weights)
    avg_train_acc = average_metrics(train_accs, client_weights)
    print(f"Avg Train Loss: {avg_train_loss:.4f}, Avg Train Accuracy: {avg_train_acc:.4f}")
    hist_train_loss.append(avg_train_loss)
    hist_train_acc.append(avg_train_acc)

    if federated_test:
        # Evaluate on every client's own test loader, weighted by test-set size
        test_losses, test_accs = [], []
        for client_id in range(K):
            loss, acc = evaluate(collaborative_model, test_loaders[client_id], criterion, device)
            test_losses.append(loss)
            test_accs.append(acc)

        avg_test_loss = average_metrics(test_losses, test_weights)
        avg_test_acc = average_metrics(test_accs, test_weights)
    else:
        # Evaluate on the single aggregated test loader
        avg_test_loss, avg_test_acc = evaluate(collaborative_model, test_loaders, criterion, device)

    print(f"Avg Test Loss: {avg_test_loss:.4f}, Avg Test Accuracy: {avg_test_acc:.4f}")
    hist_test_loss.append(avg_test_loss)
    hist_test_acc.append(avg_test_acc)

    # Track the best test accuracy seen so far and keep a checkpoint of that model.
    if avg_test_acc > best_test_acc:
        best_test_acc = avg_test_acc
        best_path = os.path.join(CHECKPOINT_DIR, f"federated_best_{N}class.pth")
        torch.save(collaborative_model.state_dict(), best_path)

    # Save every 10 rounds, and always after the final round so the last model
    # is not lost when num_rounds is not a multiple of 10.
    if (round_idx + 1) % 10 == 0 or round_idx + 1 == last_round:
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"federated_round_{round_idx + 1}_{N}class.pth")
        torch.save(collaborative_model.state_dict(), checkpoint_path)

In [None]:
# Plot the training and test loss per communication round. Rounds are numbered
# from start_round + 1 to match the "--- Round r/R ---" log lines; each series
# gets its own x range in case an interrupted run left the histories unequal.
fig, ax = plt.subplots()
for series, label in ((hist_train_loss, 'Train Loss'), (hist_test_loss, 'Test Loss')):
    ax.plot(range(start_round + 1, start_round + 1 + len(series)), series, label=label)
ax.set(title='Federated training: loss per round', xlabel='Rounds', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Plot the training and test accuracy per communication round. Rounds are
# numbered from start_round + 1 to match the "--- Round r/R ---" log lines; each
# series gets its own x range in case an interrupted run left the histories unequal.
fig, ax = plt.subplots()
for series, label in ((hist_train_acc, 'Train Accuracy'), (hist_test_acc, 'Test Accuracy')):
    ax.plot(range(start_round + 1, start_round + 1 + len(series)), series, label=label)
ax.set(title='Federated training: accuracy per round', xlabel='Rounds', ylabel='Accuracy')
ax.legend()
plt.show()