In [None]:
"""Experiment configuration for federated learning on unbalanced client data.

Edit the values in this cell to configure a run:
    - K: total number of clients
    - C: fraction of clients sampled in each federated round
    - J: number of local training steps per selected client in the main rounds
    - disco: use discrepancy-aware aggregation weights (FedDISCO)
    - sama: use severity-aware masking (FedSAMA)
"""

K = 45        # total number of clients
C = 0.2       # fraction of clients selected per round
J = 64        # local steps per selected client (main training rounds)
disco = True  # FedDISCO: discrepancy-aware weighting
sama = True   # FedSAMA: severity-aware masking

In [None]:
# Plain baseline run: neither FedDISCO weighting nor FedSAMA masking is enabled
baseline = not sama and not disco

In [None]:
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import random
from data.cifar100_loader import get_unbalanced_cifar100_datasets
from eval import evaluate
from train import train, train_steps
from model.prepare_model import get_dino_vits16_model, freeze_backbone, unfreeze_backbone, freeze_head, unfreeze_head
from model.unbalance import compute_kl_discrepancy, compute_discrepancy_aware_weights_sigmoid, compute_severity, get_class_distribution_vector
from model.model_editing import mask_calculator, freeze_and_clean_client_masks
from model.federated_averaging import train_on_client, average_metrics, average_models, get_trainable_keys
from tqdm import tqdm

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# All checkpoints written by this notebook go under this directory
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Loss used for both local training and evaluation
criterion = nn.CrossEntropyLoss()

# Split CIFAR-100 across the K clients using the loader's default size/bias division
train_datasets, test_loader, client_class_map, client_metadata = get_unbalanced_cifar100_datasets(K)

In [None]:
# One KL-divergence-based discrepancy value per client dataset, in client order
discrepancies = [compute_kl_discrepancy(client_dataset) for client_dataset in train_datasets]

In [None]:
# Number of samples held by each client
client_sizes = list(map(len, train_datasets))

# SAMA needs per-client severities (built from sizes and discrepancies)
if sama:
    severities = compute_severity(client_sizes, discrepancies)

# DISCO needs per-client aggregation weights (sigmoid-based, temperature tau=0.1)
if disco:
    weights_disc = compute_discrepancy_aware_weights_sigmoid(
        client_sizes,
        discrepancies,
        tau=0.1,
    )

In [None]:
# Bar chart: number of training samples held by each client
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(range(K), client_sizes, color='blue', alpha=0.7)
ax.set_title('Number of samples per client')
ax.set_xlabel('Client ID')
ax.set_ylabel('Number of samples')
ax.set_xticks(range(K))
ax.tick_params(axis='x', labelrotation=90)
fig.tight_layout()
plt.show()

In [None]:
# Bar chart: KL discrepancy of each client's data
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(range(K), discrepancies, color='orange', alpha=0.7)
ax.set_title('Discrepancy per client')
ax.set_xlabel('Client ID')
ax.set_ylabel('Discrepancy')
ax.set_xticks(range(K))
ax.tick_params(axis='x', labelrotation=90)
fig.tight_layout()
plt.show()

In [None]:
if sama:
    # Bar chart: SAMA severity of each client (only computed when SAMA is on)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(range(K), severities, color='red', alpha=0.7)
    ax.set_title('Client Severity')
    ax.set_xlabel('Client ID')
    ax.set_ylabel('Client severity')
    ax.set_xticks(range(K))
    ax.tick_params(axis='x', labelrotation=90)
    fig.tight_layout()
    plt.show()

In [None]:
if disco:
    # Bar chart: DISCO aggregation weight of each client (only computed when DISCO is on)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(range(K), weights_disc, color='green', alpha=0.7)
    ax.set_title('Client weights')
    ax.set_xlabel('Client ID')
    ax.set_ylabel('Client weight')
    ax.set_xticks(range(K))
    ax.tick_params(axis='x', labelrotation=90)
    fig.tight_layout()
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_client_scatter_weights_sizes_discrepancies(client_metrics, discrepancies, client_sizes, metric=None):
    """Scatter client dataset size against discrepancy, encoding a per-client metric.

    Each client is one point at (size, discrepancy). The supplied metric (e.g. a
    DISCO aggregation weight or a SAMA severity) drives both the point colour
    (viridis colormap, shown with a colorbar) and the marker area. The area is
    scaled so the client with the largest metric gets an area of 300.

    Args:
        client_metrics: per-client values used for colour and marker size.
        discrepancies: per-client discrepancy values, plotted on the y-axis.
        client_sizes: per-client dataset sizes, plotted on the x-axis.
        metric: human-readable name of ``client_metrics``. It is used in the
            colorbar label and the title. Defaults to ``'Metric'`` when omitted.

    The y-axis is not clamped; it autoscales to the discrepancy values.
    """
    client_metrics = np.array(client_metrics)
    discrepancies = np.array(discrepancies)
    client_sizes = np.array(client_sizes)
    # Avoid rendering the literal text "None" in labels when no name is given
    metric_label = 'Metric' if metric is None else metric

    # Scale marker areas relative to the largest metric. A non-positive maximum
    # (e.g. all-zero severities) would give NaN/inf sizes, so fall back to
    # equal-sized markers in that case.
    max_metric = client_metrics.max() if client_metrics.size else 0.0
    if max_metric > 0:
        point_sizes = 300 * (client_metrics / max_metric)
    else:
        point_sizes = np.full_like(client_metrics, 300.0, dtype=float)

    fig, ax = plt.subplots(figsize=(10, 6))
    scatter = ax.scatter(client_sizes, discrepancies, c=client_metrics, s=point_sizes,
                         cmap='viridis', edgecolors='k', alpha=0.8)
    fig.colorbar(scatter, ax=ax, label=f'Client {metric_label}')
    ax.set_xlabel('Client Size')
    ax.set_ylabel('Discrepancy')
    ax.set_title(f'Client Size vs Discrepancy (Color & Size = {metric_label})')
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()
    plt.show()

In [None]:
# Size-vs-discrepancy scatter for whichever per-client metric is enabled
if sama:
    plot_client_scatter_weights_sizes_discrepancies(
        severities, discrepancies, client_sizes, metric='Severity')

if disco:
    plot_client_scatter_weights_sizes_discrepancies(
        weights_disc, discrepancies, client_sizes, metric='Weight')

In [None]:
# Global (collaborative) DINO ViT-S/16 model shared by all clients. The
# backbone starts frozen for the warm-up rounds (presumably leaving only the
# head trainable; see freeze_backbone).
collaborative_model = get_dino_vits16_model(device)
freeze_backbone(collaborative_model)

In [None]:
# Warm-up schedule: number of federated rounds, and local steps per selected client
warmup_rounds, warmup_steps = 5, 64

In [None]:
# Round range covered by the warm-up loop
start_round, num_rounds = 0, warmup_rounds
best_test_acc = 0.0

# Per-round metric histories, filled in by the warm-up loop
warmup_train_loss, warmup_train_acc = [], []
warmup_test_loss, warmup_test_acc = [], []

# Fixed seed so the per-round client sampling is reproducible
random.seed(42)

In [None]:
# Perform warmup training
print("--- Starting Federated Averaging Warmup ---")

# Clients sampled per round. Clamp to at least one so a small C (C * K < 1)
# cannot produce an empty round with nothing to aggregate.
clients_per_round = max(1, int(C * K))

# FedAvg loop. The index is `round_idx` rather than `round`, so the builtin
# round() is not shadowed for the rest of the notebook session.
for round_idx in range(start_round, start_round + num_rounds):
    print(f"\n--- Round {round_idx + 1}/{start_round + num_rounds} ---")

    # Select clients uniformly at random, without replacement
    selected_clients = random.sample(range(K), clients_per_round)

    # Local training for each selected client
    local_models, train_losses, train_accs = [], [], []
    for client_id in selected_clients:
        model_state, loss, acc = train_on_client(
            client_id,
            collaborative_model,
            train_datasets[client_id],
            warmup_steps,
            criterion,
            lr=0.01,
            device=device
        )
        local_models.append(model_state)
        train_losses.append(loss)
        train_accs.append(acc)

    # Dataset-size weights, used only to average the *reported* train metrics
    client_sample_counts = [len(train_datasets[c]) for c in selected_clients]
    total_samples = sum(client_sample_counts)
    client_weights_metrics = [count / total_samples for count in client_sample_counts]

    # Aggregation weights for the model parameters.
    # NOTE(review): without DISCO this averages uniformly rather than by
    # dataset size (as classic FedAvg does). Confirm this is intended.
    if not disco:
        client_weights = [1.0 / len(selected_clients)] * len(selected_clients)
    else:
        # DISCO weights renormalised over this round's participants
        tot_weights_disc = sum(weights_disc[c] for c in selected_clients)
        client_weights = [weights_disc[c] / tot_weights_disc for c in selected_clients]

    # Federated averaging: only trainable keys are averaged; every other entry
    # of the global state dict is kept as-is
    trainable_keys = get_trainable_keys(collaborative_model)
    averaged_state = average_models(local_models, client_weights, trainable_keys)
    new_state = collaborative_model.state_dict()
    for key in averaged_state:
        new_state[key] = averaged_state[key]
    collaborative_model.load_state_dict(new_state)

    # Log size-weighted average training metrics
    avg_train_loss = average_metrics(train_losses, client_weights_metrics)
    avg_train_acc = average_metrics(train_accs, client_weights_metrics)
    print(f"Avg Train Loss: {avg_train_loss:.4f}, Avg Train Accuracy: {avg_train_acc:.4f}")
    warmup_train_loss.append(avg_train_loss)
    warmup_train_acc.append(avg_train_acc)

    # Evaluate the aggregated model on the test set every round
    avg_test_loss, avg_test_acc = evaluate(collaborative_model, test_loader, criterion, device)

    print(f"Avg Test Loss: {avg_test_loss:.4f}, Avg Test Accuracy: {avg_test_acc:.4f}")
    warmup_test_loss.append(avg_test_loss)
    warmup_test_acc.append(avg_test_acc)

# NOTE(review): the filename is keyed on J even though warm-up ran
# `warmup_steps` local steps. It also does not encode `disco`, although
# `disco` changes the warm-up aggregation. The reload cell uses this exact
# name, so change both together.
torch.save(collaborative_model.state_dict(), os.path.join(CHECKPOINT_DIR, f'pre_trained_federated_unbalance_model_{J}steps.pth'))

In [None]:
# Warm-up loss curves: one train and one test point per round
fig, ax = plt.subplots()
ax.plot(warmup_train_loss, label='Train Loss')
ax.plot(warmup_test_loss, label='Test Loss')
ax.set(xlabel='Rounds', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Warm-up accuracy curves: one train and one test point per round
fig, ax = plt.subplots()
ax.plot(warmup_train_acc, label='Train Accuracy')
ax.plot(warmup_test_acc, label='Test Accuracy')
ax.set(xlabel='Rounds', ylabel='Accuracy')
ax.legend()
plt.show()

In [None]:
# Reload the warm-up checkpoint saved at the end of the warm-up loop.
# map_location remaps stored tensors onto the current device, so a checkpoint
# written on a GPU can still be reloaded on a CPU-only machine.
warmup_ckpt_path = os.path.join(CHECKPOINT_DIR, f'pre_trained_federated_unbalance_model_{J}steps.pth')
collaborative_model.load_state_dict(torch.load(warmup_ckpt_path, map_location=device))
# Switch which part trains: unfreeze the backbone, freeze the head
unfreeze_backbone(collaborative_model)
freeze_head(collaborative_model)

# Mask density passed to mask_calculator for each client
if not sama:
    # Same density for every client
    density = [0.3] * K
else:
    # density_i = 1 - (min + severity_i * (max - min)) = 0.5 - 0.5 * severity_i,
    # so a higher severity gives a lower density.
    # NOTE(review): assuming severities lie in [0, 1], densities fall in
    # [0, 0.5]. min_density/max_density therefore bound 1 - density, not
    # density itself. Confirm the naming and intent.
    min_density = 0.5
    max_density = 1.0
    density = [1 - (severity_i * (max_density - min_density) + min_density) for severity_i in severities]

In [None]:
# Bar chart: mask density assigned to each client
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(range(K), density, color='purple', alpha=0.7)
ax.set_title('Density per client')
ax.set_xlabel('Client ID')
ax.set_ylabel('Density')
ax.set_xticks(range(K))
ax.tick_params(axis='x', labelrotation=90)
fig.tight_layout()
plt.show()

In [None]:
import math

# Compute the mask for each client
client_masks = {}
for client_id in tqdm(range(K)):
    # Per-class sample budget: the class's share of this client's data, x100,
    # rounded up. The budget must cover every entry of the class-distribution
    # vector (one per class). Iterating range(K) walked the number of
    # *clients* (45) and silently dropped the remaining CIFAR-100 classes.
    distribution_vector = get_class_distribution_vector(train_datasets[client_id])
    samples_per_class = [math.ceil(class_share * 100) for class_share in distribution_vector]
    # Use the per-client density computed above (SAMA-aware when enabled)
    client_masks[client_id] = mask_calculator(collaborative_model, train_datasets[client_id], device, rounds=4, density=density[client_id],
                                            samples_per_class=samples_per_class, verbose=False)

In [None]:
# Post-process the per-client masks (threshold 0.01) and get the model's frozen state
client_masks, frozen_state = freeze_and_clean_client_masks(collaborative_model, client_masks, threshold=0.01, K=K)

# Persist both artefacts; the file names encode J and whether SAMA was enabled
masks_path = os.path.join(CHECKPOINT_DIR, f'client_masks_unbalance_{J}steps_{sama}_sama.pth')
frozen_path = os.path.join(CHECKPOINT_DIR, f'frozen_state_unbalance_{J}steps_{sama}_sama.pth')
torch.save(client_masks, masks_path)
torch.save(frozen_state, frozen_path)

In [None]:
# Round range covered by the main federated training loop
start_round, num_rounds = 0, 20
best_test_acc = 0.0

# Metric histories: train metrics every round, test metrics every second round
hist_train_loss, hist_train_acc = [], []
hist_test_loss, hist_test_acc = [], []

# Re-seed so client sampling restarts from the same sequence as in warm-up
random.seed(42)

In [None]:
# Clients sampled per round. Clamp to at least one so a small C (C * K < 1)
# cannot produce an empty round with nothing to aggregate.
clients_per_round = max(1, int(C * K))

# FedAvg loop with per-client masks. The index is `round_idx` rather than
# `round`, so the builtin round() is not shadowed for the rest of the session.
for round_idx in range(start_round, start_round + num_rounds):
    print(f"\n--- Round {round_idx + 1}/{start_round + num_rounds} ---")

    # Select clients uniformly at random, without replacement
    selected_clients = random.sample(range(K), clients_per_round)

    # Local training with each client's own mask
    local_models, train_losses, train_accs = [], [], []
    for client_id in selected_clients:
        model_state, loss, acc = train_on_client(
            client_id,
            collaborative_model,
            train_datasets[client_id],
            J,
            criterion,
            lr=0.01,
            device=device,
            mask=client_masks[client_id]
        )
        local_models.append(model_state)
        train_losses.append(loss)
        train_accs.append(acc)

    # Dataset-size weights, used only to average the *reported* train metrics
    client_sample_counts = [len(train_datasets[c]) for c in selected_clients]
    total_samples = sum(client_sample_counts)
    client_weights_metrics = [count / total_samples for count in client_sample_counts]

    # Aggregation weights for the model parameters.
    # NOTE(review): without DISCO this averages uniformly rather than by
    # dataset size (as classic FedAvg does). Confirm this is intended.
    if not disco:
        client_weights = [1.0 / len(selected_clients)] * len(selected_clients)
    else:
        # DISCO weights renormalised over this round's participants
        tot_weights_disc = sum(weights_disc[c] for c in selected_clients)
        client_weights = [weights_disc[c] / tot_weights_disc for c in selected_clients]

    # Federated averaging: only trainable keys are averaged; every other entry
    # of the global state dict is kept as-is
    trainable_keys = get_trainable_keys(collaborative_model)
    averaged_state = average_models(local_models, client_weights, trainable_keys)
    new_state = collaborative_model.state_dict()
    for key in averaged_state:
        new_state[key] = averaged_state[key]
    collaborative_model.load_state_dict(new_state)

    # Log size-weighted average training metrics
    avg_train_loss = average_metrics(train_losses, client_weights_metrics)
    avg_train_acc = average_metrics(train_accs, client_weights_metrics)
    print(f"Avg Train Loss: {avg_train_loss:.4f}, Avg Train Accuracy: {avg_train_acc:.4f}")
    hist_train_loss.append(avg_train_loss)
    hist_train_acc.append(avg_train_acc)

    # Evaluate on the test set every 2 rounds (after 0-based rounds 1, 3, 5, ...)
    if (round_idx + 1) % 2 == 0:
        avg_test_loss, avg_test_acc = evaluate(collaborative_model, test_loader, criterion, device)

        print(f"Avg Test Loss: {avg_test_loss:.4f}, Avg Test Accuracy: {avg_test_acc:.4f}")
        hist_test_loss.append(avg_test_loss)
        hist_test_acc.append(avg_test_acc)

    # Checkpoint the global model every 5 rounds
    if (round_idx + 1) % 5 == 0:
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"federated_unbalance_round_{round_idx + 1}_{J}steps_{sama}_sama_{disco}_disco.pth")
        torch.save(collaborative_model.state_dict(), checkpoint_path)

In [None]:
# Train loss is logged every round (x = 0, 1, 2, ...). Test loss is logged only
# after 0-based rounds 1, 3, 5, ..., so its x positions are the odd integers.
rounds_train = range(len(hist_train_loss))
rounds_test = range(1, 2 * len(hist_test_loss), 2)

fig, ax = plt.subplots()
ax.plot(rounds_train, hist_train_loss, label='Train Loss')
ax.plot(rounds_test, hist_test_loss, label='Test Loss')
ax.set(xlabel='Rounds', ylabel='Loss')
ax.legend()

plt.show()

In [None]:
# Train accuracy is logged every round (x = 0, 1, 2, ...). Test accuracy is
# logged only after 0-based rounds 1, 3, 5, ..., so its x positions are the odd integers.
rounds_train = range(len(hist_train_acc))
rounds_test = range(1, 2 * len(hist_test_acc), 2)

fig, ax = plt.subplots()
ax.plot(rounds_train, hist_train_acc, label='Train Accuracy')
ax.plot(rounds_test, hist_test_acc, label='Test Accuracy')
ax.set(xlabel='Rounds', ylabel='Accuracy')
ax.legend()

plt.show()