## 1. Federated Averaging (FedAvg)

This is the most common aggregation strategy in federated learning, introduced in the original federated learning paper from Google. FedAvg averages the model weights (or gradients) received from all clients.

#### Formula:
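For each model parameter $\omega$, the new global value is the unweighted average of the corresponding client values, where $\omega_k$ is the value held by client $k$ and $N$ is the number of clients:

$$\omega_{\text{global}} = \frac{1}{N} \sum_{k=1}^{N} \omega_k$$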

In [2]:
def fed_avg(models):
    """
    Aggregates models using Federated Averaging (FedAvg).
    Args:
        models: List of model state_dicts from clients.
    Returns:
        avg_model: The averaged global model.
    """
    avg_model = models[0].copy()
    for key in avg_model.keys():
        avg_model[key] = sum([model[key] for model in models]) / len(models)
    return avg_model
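
To make the arithmetic concrete, here is a minimal sketch that mirrors `fed_avg` on toy data. It assumes plain Python lists in place of tensors so that it runs without PyTorch; `fed_avg_lists` and the three toy clients are hypothetical names introduced only for illustration.

```python
def fed_avg_lists(models):
    """Element-wise unweighted mean over client parameter dicts (lists of floats)."""
    n = len(models)
    return {
        # zip(*...) groups the i-th entry of this parameter across all clients
        key: [sum(vals) / n for vals in zip(*(m[key] for m in models))]
        for key in models[0]
    }


# Three hypothetical clients, each holding a weight vector "w" and a bias "b"
clients = [
    {"w": [1.0, 2.0], "b": [0.0]},
    {"w": [3.0, 4.0], "b": [1.0]},
    {"w": [5.0, 6.0], "b": [2.0]},
]

global_model = fed_avg_lists(clients)
print(global_model)  # {'w': [3.0, 4.0], 'b': [1.0]}
```

Each entry of the result is the mean of the matching entries across the three clients.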


## 2. Weighted Federated Averaging (Weighted FedAvg)

In this variant of FedAvg, the aggregation considers the size of each client's dataset. Clients with more data contribute more to the global model update.

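#### Formula:
With $n_k$ the dataset size of client $k$ and $n = \sum_{k=1}^{N} n_k$ the total number of samples, each client's parameter is weighted by its share of the data:

$$\omega_{\text{global}} = \sum_{k=1}^{N} \frac{n_k}{n} \, \omega_k$$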

In [3]:
def weighted_fed_avg(models, client_sizes):
    """
    Aggregates models using Weighted Federated Averaging.
    Args:
        models: List of model state_dicts from clients.
        client_sizes: List of dataset sizes for each client.
    Returns:
        avg_model: The averaged global model.
    """
    total_size = sum(client_sizes)
    avg_model = models[0].copy()

    for key in avg_model.keys():
        avg_model[key] = sum([model[key] * (client_sizes[i] / total_size) for i, model in enumerate(models)])
    return avg_model
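
The effect of the weights is easiest to see with two clients of unequal size. The sketch below is an illustration under stated assumptions: it uses plain Python lists in place of tensors, and `weighted_fed_avg_lists`, the toy clients, and their sizes are hypothetical.

```python
def weighted_fed_avg_lists(models, client_sizes):
    """Element-wise mean over client parameter dicts, weighted by dataset size."""
    total = sum(client_sizes)
    weights = [n / total for n in client_sizes]
    return {
        key: [
            sum(w * v for w, v in zip(weights, vals))
            for vals in zip(*(m[key] for m in models))
        ]
        for key in models[0]
    }


# Hypothetical clients: the second one holds three times as much data
clients = [{"w": [1.0, 10.0]}, {"w": [3.0, 20.0]}]
sizes = [100, 300]

weighted = weighted_fed_avg_lists(clients, sizes)
unweighted = weighted_fed_avg_lists(clients, [1, 1])  # equal sizes reduce to plain FedAvg

print(weighted)    # {'w': [2.5, 17.5]}
print(unweighted)  # {'w': [2.0, 15.0]}
```

The weighted result sits closer to the larger client's parameters, while equal sizes recover the unweighted average.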


## 3. Median Aggregation (FedMedian)
Instead of averaging model weights, FedMedian uses the median of the client models' parameters. This strategy is more robust to outliers and may help in cases where some clients have noisy or malicious data.

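#### Formula:
For each parameter, take the coordinate-wise median of the values reported by the $N$ clients:

$$\omega_{\text{global}} = \operatorname{median}(\omega_1, \dots, \omega_N)$$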

In [4]:
import numpy as np
import torch

def fed_median(models):
    """
    Aggregates models using Federated Median (FedMedian).
    Args:
        models: List of model state_dicts from clients.
    Returns:
        median_model: The median global model.
    """
    median_model = models[0].copy()

    for key in median_model.keys():
        # Stack all model parameters along a new axis and compute the median along that axis
        stacked_weights = np.stack([model[key].cpu().numpy() for model in models])
        median_model[key] = torch.tensor(np.median(stacked_weights, axis=0))

    return median_model
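
The robustness claim can be checked on a single scalar parameter. The following sketch assumes plain Python lists in place of tensors and uses the standard-library `statistics` module; `fed_median_lists` and the five client values are hypothetical, with the last client playing the role of an outlier.

```python
import statistics


def fed_median_lists(models):
    """Coordinate-wise median over client parameter dicts (lists of floats)."""
    return {
        key: [statistics.median(vals) for vals in zip(*(m[key] for m in models))]
        for key in models[0]
    }


# Four hypothetical clients agree on a value near 1.0; the fifth reports 100.0
values = (1.0, 1.1, 0.9, 1.05, 100.0)
clients = [{"w": [v]} for v in values]

median_model = fed_median_lists(clients)
mean_value = sum(values) / len(values)

print(median_model)  # {'w': [1.05]}
print(mean_value)    # roughly 20.8, dragged up by the single outlier
```

A single extreme client shifts the mean far from the consensus but leaves the median among the honest values.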


## 4. Trimmed Mean Aggregation (Trimmed Mean)
The Trimmed Mean strategy trims the highest and lowest values for each parameter and then takes the mean of the remaining values. This is another robust approach that mitigates the effect of outliers or malicious clients.

#### Formula:
For each parameter, remove the highest and lowest $q\%$ of the values and then average the remaining values.

In [5]:
def trimmed_mean(models, trim_percent=0.1):
    """
    Aggregates models using Trimmed Mean.
    Args:
        models: List of model state_dicts from clients.
        trim_percent: Fraction of extreme values to trim from each side (default 0.1, i.e. 10%).
    Returns:
        trimmed_mean_model: The trimmed mean global model.
    """
    trimmed_mean_model = models[0].copy()
    trim_count = int(trim_percent * len(models))

    for key in trimmed_mean_model.keys():
        # Stack all model parameters along a new axis
        stacked_weights = np.stack([model[key].cpu().numpy() for model in models])

        # Sort the weights and remove the top and bottom trim_percent of the values
        sorted_weights = np.sort(stacked_weights, axis=0)
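        # Use an explicit end index: [trim_count:-trim_count] would be empty when trim_count == 0
        trimmed_weights = sorted_weights[trim_count:len(models) - trim_count]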

        # Compute the mean of the trimmed weights
        trimmed_mean_model[key] = torch.tensor(np.mean(trimmed_weights, axis=0))

    return trimmed_mean_model
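
A small worked example shows both the trimming and an edge case worth knowing: with few clients, `int(trim_percent * len(models))` can round down to zero, in which case nothing is trimmed. The sketch assumes plain Python lists in place of tensors; `trimmed_mean_lists` and the toy values are hypothetical.

```python
def trimmed_mean_lists(models, trim_percent=0.1):
    """Coordinate-wise trimmed mean over client parameter dicts (lists of floats)."""
    n = len(models)
    trim_count = int(trim_percent * n)
    result = {}
    for key in models[0]:
        averaged = []
        for vals in zip(*(m[key] for m in models)):
            # Drop trim_count values from each end of the sorted list
            kept = sorted(vals)[trim_count:n - trim_count]
            averaged.append(sum(kept) / len(kept))
        result[key] = averaged
    return result


# Ten hypothetical clients: eight reasonable values and two extremes
ten = [{"w": [v]} for v in (-50.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 90.0)]
print(trimmed_mean_lists(ten, trim_percent=0.1))   # {'w': [4.5]}

# Five hypothetical clients: int(0.1 * 5) == 0, so the outlier is NOT removed
five = [{"w": [v]} for v in (1.0, 2.0, 3.0, 4.0, 100.0)]
print(trimmed_mean_lists(five, trim_percent=0.1))  # {'w': [22.0]}
```

With ten clients one value is cut from each side; with five clients the same setting trims nothing, so a larger `trim_percent` would be needed to exclude the outlier.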


## 5. Norm-based Clipping Aggregation
This aggregation strategy clips the model weights based on their norms. This is useful when some clients have unusually large weight updates, which can destabilize training.

#### Formula:
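For each client's parameter tensor $\omega_k$, rescale it if its norm exceeds a threshold $\tau$, then average the clipped tensors:

$$\omega_k \leftarrow \omega_k \cdot \min\left(1, \frac{\tau}{\lVert \omega_k \rVert_2}\right)$$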

In [6]:
def norm_clipping(models, clip_threshold=1.0):
    """
    Aggregates models using Norm-based Clipping.
    Args:
        models: List of model state_dicts from clients.
        clip_threshold: Clipping threshold for the norm.
    Returns:
        clipped_model: The clipped global model.
    """
    clipped_model = models[0].copy()

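    for key in clipped_model.keys():
        stacked_weights = np.stack([model[key].cpu().numpy() for model in models])

        # One L2 norm per client, taken over all entries of this parameter tensor
        flat_weights = stacked_weights.reshape(len(models), -1)
        norm_weights = np.linalg.norm(flat_weights, axis=1)

        # Scale each client's tensor by min(1, threshold / norm); the epsilon avoids division by zero
        scale = np.minimum(1.0, clip_threshold / (norm_weights + 1e-12))
        scale = scale.reshape(-1, *([1] * (stacked_weights.ndim - 1)))
        clipped_weights = scale * stacked_weights
        clipped_model[key] = torch.tensor(np.mean(clipped_weights, axis=0))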

    return clipped_model
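
The clipping rule is easiest to follow on two small vectors, one above and one below the threshold. This is a sketch under stated assumptions: plain Python lists stand in for tensors, the norm is the L2 norm of each client's whole vector, and `clip_by_norm`, `clipped_mean`, and the toy vectors are hypothetical.

```python
import math


def clip_by_norm(vec, tau):
    """Rescale vec so its L2 norm is at most tau; shorter vectors pass through unchanged."""
    norm = math.sqrt(sum(v * v for v in vec))
    scale = min(1.0, tau / norm) if norm > 0 else 1.0
    return [scale * v for v in vec]


def clipped_mean(client_vecs, tau):
    """Clip every client's vector, then average element-wise."""
    clipped = [clip_by_norm(vec, tau) for vec in client_vecs]
    n = len(clipped)
    return [sum(vals) / n for vals in zip(*clipped)]


# Hypothetical clients: the first has norm 5.0, the second norm 0.5
client_vecs = [[3.0, 4.0], [0.3, 0.4]]

big = clip_by_norm(client_vecs[0], tau=1.0)    # scaled down to about [0.6, 0.8]
small = clip_by_norm(client_vecs[1], tau=1.0)  # already within the threshold
aggregated = clipped_mean(client_vecs, tau=1.0)  # about [0.45, 0.6]

print(big, small, aggregated)
```

Clipping preserves the direction of the large vector and only shortens it, so the oversized client can no longer dominate the average.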


## 6. Krum Aggregation
Krum is an aggregation strategy designed to be robust to Byzantine (malicious) clients. It selects a single model from the set of client models that is the "most central" by calculating distances between model updates and excluding outliers.

#### Formula:
For each model, compute the distance to all other models and select the one with the smallest sum of distances to its closest $K$ neighbors.

In [7]:
def krum(models, num_neighbors=2):
    """
    Aggregates models using Krum, a robust aggregation technique.
    Args:
        models: List of model state_dicts from clients.
        num_neighbors: Number of closest neighbors to consider (default is 2).
    Returns:
        krum_model: The selected Krum model.
    """
    distances = []

    # Calculate distances between models
    for i, model_i in enumerate(models):
        dists = []
        for j, model_j in enumerate(models):
            if i != j:
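                # Squared Euclidean distance over all parameters; .float() lets integer buffers be compared too
                dist = sum(((model_i[key] - model_j[key]).float() ** 2).sum().item() for key in model_i)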
                dists.append(dist)
        dists.sort()
        distances.append((i, sum(dists[:num_neighbors])))

    # Select the model with the smallest sum of distances to its closest neighbors
    selected_model_index = min(distances, key=lambda x: x[1])[0]
    return models[selected_model_index]
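
The selection rule can be traced by hand on four two-dimensional updates. The sketch below is illustrative only: it assumes each client's update is flattened into a plain Python list and uses squared Euclidean distance; `squared_distance`, `krum_select`, and the toy updates are hypothetical.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def krum_select(vectors, num_neighbors=2):
    """Return the index with the smallest summed distance to its nearest neighbors, plus all scores."""
    scores = []
    for i, v_i in enumerate(vectors):
        dists = sorted(
            squared_distance(v_i, v_j) for j, v_j in enumerate(vectors) if j != i
        )
        scores.append(sum(dists[:num_neighbors]))
    best = min(range(len(vectors)), key=scores.__getitem__)
    return best, scores


# Three hypothetical honest updates clustered near (1, 1) and one far-away update
updates = [[1.0, 1.0], [1.1, 0.9], [0.9, 1.1], [10.0, 10.0]]

selected, scores = krum_select(updates, num_neighbors=2)
print(selected)  # 0
```

The update at index 0 lies in the middle of the cluster, so its two nearest neighbors are closest and it is selected; the distant update has a score several orders of magnitude larger and is never chosen.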
