In [None]:
!pip install git+https://github.com/passerim/peer-reviewed-flower.git

In [None]:
import os
import random
from collections import OrderedDict
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Iterable

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from flwr.common import (
    EvaluateIns,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    Parameters,
    Scalar,
    Weights,
    parameters_to_weights,
    weights_to_parameters
)
from flwr.server.client_manager import ClientManager, SimpleClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
from prflwr.peer_review import PeerReviewClient
from prflwr.peer_review import PrConfig
from prflwr.peer_review import PeerReviewServer
from prflwr.peer_review.strategy import PeerReviewStrategy
from prflwr.peer_review.strategy import (
    AggregateTrainException,
    AggregateReviewException,
    ConfigureReviewException,
)
from prflwr.simulation import start_simulation
from prflwr.utils import import_dataset_utils
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets

# NOTE(review): presumably makes Flower's dataset helpers reachable as `fl.dataset.utils`
# (load_datasets below calls fl.dataset.utils.common.create_lda_partitions) — confirm in prflwr.utils.
import_dataset_utils()

In [None]:
def set_seed(seed: int):
    """Seed every random number generator used by this notebook with ``seed``.

    Covers, in this order: Python's ``random`` module, NumPy's global generator,
    PyTorch's default generator and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


In [None]:
# Setting random seed for reproducibility
SEED = 123
set_seed(SEED)

# Experiment configuration: these module-level constants are read by the cells below.
DATASET = "CIFAR10"  # possible values: "CIFAR10" or "CIFAR100"
NUM_EPOCHS = 50  # only used to derive NUM_ROUNDS
NUM_CLIENTS = 50  # number of simulated clients, i.e. of training-set partitions
LOCAL_EPOCHS = 2  # local epochs per round once fit_config stops returning 1
BATCH_SIZE = 32  # batch size of every DataLoader created by load_datasets
FRACTION_REV = 1/4  # fraction of available clients sampled for the review phase
FRACTION_FIT = 1/4  # fraction of available clients sampled for training
FRACTION_EVAL = 0  # fraction of clients for federated evaluation, forwarded to FedAvg as fraction_eval
# Rounded to the nearest integer; evaluates to 100 with the values above.
NUM_ROUNDS = int(NUM_EPOCHS / (LOCAL_EPOCHS * FRACTION_FIT) + 0.5)
LEARNING_RATE = 0.1  # NOTE(review): not referenced by any cell visible in this notebook
LR_GAMMA = 0.1  # multiplicative decay applied to FedEs.lr in aggregate_after_review
# Metric keys under which clients report the two losses computed in FlowerClient.review.
CANDIDATE_REVIEW_SCORE = "c_score"
ANTITHETIC_REVIEW_SCORE = "a_score"

# Device to use for training and evaluation
# NOTE(review): the simulation launch cell requests num_gpus=1 per client although
# everything runs on the CPU here — confirm that is intended.
DEVICE = torch.device("cpu")
print(f"Training on {DEVICE}")

In [None]:
def load_datasets(
    num_clients: int,
    dataset: str = "CIFAR10",
    src: str = ".",
    iid: bool = True,
    concentration: float = 1,
    use_augmentation: bool = False
) -> Tuple[List[DataLoader], List[DataLoader], DataLoader]:
    """Download CIFAR and build per-client train/validation loaders plus a test loader.

    Args:
        num_clients: Number of partitions the training set is split into.
        dataset: Name of the torchvision dataset class, "CIFAR10" or "CIFAR100".
        src: Base directory; data is stored under ``<src>/./data``.
        iid: If True, partitions are drawn uniformly at random with a fixed seed;
            otherwise ``create_lda_partitions`` is used.
        concentration: Forwarded to ``create_lda_partitions`` (non-IID only).
        use_augmentation: If True, pad / random-flip / random-crop the training
            images before normalisation. The test set is never augmented.

    Returns:
        ``(trainloaders, valloaders, testloader)``: one train and one validation
        loader per client (a 90/10 split of each partition) and a single loader
        over the whole test set. All loaders use the global ``BATCH_SIZE``.

    Raises:
        ValueError: If ``dataset`` is unknown or ``num_clients`` is not positive.
    """
    if dataset not in ["CIFAR10", "CIFAR100"]:
        raise ValueError(
            "Unknown dataset! Admissible values are: 'CIFAR10' or 'CIFAR100'."
        )
    if num_clients < 1:
        raise ValueError("num_clients must be a positive integer.")

    # Download and transform CIFAR dataset (train and test)
    augmentation = [
        transforms.Pad(4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
    ] if use_augmentation else []
    transform = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    dataset_cls = getattr(datasets, dataset)
    data_dir = os.path.join(src, "./data")
    trainset = dataset_cls(
        data_dir,
        train=True,
        download=True,
        transform=transforms.Compose([*augmentation, *transform]),
    )
    testset = dataset_cls(
        data_dir,
        train=False,
        download=True,
        transform=transforms.Compose([*transform]),
    )

    # Split training set into `num_clients` partitions to simulate different local datasets
    if not iid:
        targets = np.array(trainset.targets)
        idxs = np.array(range(len(targets)))
        # Partition sample *indices* (not images) so each partition can be wrapped in a Subset.
        # A separate name is used so the `dataset` argument is not shadowed.
        indexed_targets = [idxs, targets]
        train_partitions, _ = fl.dataset.utils.common.create_lda_partitions(
            indexed_targets,
            num_partitions=num_clients,
            concentration=concentration,
            accept_imbalanced=False,
        )
        subsets = [Subset(trainset, partition[0]) for partition in train_partitions]
    else:
        # random_split requires the lengths to sum to len(trainset); hand the
        # remainder out one sample at a time to the first partitions so the split
        # also works when the dataset size is not a multiple of num_clients.
        partition_size, remainder = divmod(len(trainset), num_clients)
        lengths = [
            partition_size + (1 if i < remainder else 0) for i in range(num_clients)
        ]
        subsets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in subsets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        # Fixed seed: the train/val split of a partition is the same on every run.
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloaders, valloaders, testloader


In [None]:
class Net(nn.Module):
    """Small CNN: three conv + max-pool stages followed by two linear layers.

    Each 3x3 convolution uses padding 1 and is followed by ReLU and a 2x2
    max-pool. The flattened feature size is hard-coded to 512, which matches
    32 channels x 4 x 4, i.e. 32x32 input images after three poolings.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        # Creation order is kept (conv1, conv2, conv3, pool, fc1, fc2) so that both
        # the state_dict ordering and seeded weight initialisation stay the same.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=512, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for ``x``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.pool(F.relu(conv(features)))
        flat = features.view(-1, 512)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


def get_parameters(net) -> List[np.ndarray]:
    """Return every entry of ``net.state_dict()`` as a NumPy array, in state-dict order."""
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]


def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, pairing them with the state-dict keys in order.

    Args:
        net: Module whose state is overwritten in place.
        parameters: One array per ``net.state_dict()`` entry, in the same order
            as produced by ``get_parameters``.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when keys are missing
            (fewer arrays than state-dict entries) or shapes do not match.
    """
    params_dict = zip(net.state_dict().keys(), parameters)
    # torch.tensor keeps each array's shape and dtype, including 0-d entries such as
    # BatchNorm's `num_batches_tracked`, which the legacy `torch.Tensor(...)`
    # constructor does not reliably preserve. load_state_dict copies the values
    # into the existing parameters, so their dtype and device are unchanged.
    state_dict = OrderedDict((key, torch.tensor(value)) for key, value in params_dict)
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set.

    Runs ``epochs`` passes over ``trainloader`` with cross-entropy loss and an
    Adam optimizer (default hyper-parameters), created anew on every call.
    Batches are moved to the global ``DEVICE``. After each epoch the
    sample-weighted mean loss and the accuracy are printed.

    Args:
        net: Model to train in place; it is left in training mode.
        trainloader: Iterable of ``(images, labels)`` batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # Single forward pass: the same outputs feed both the loss and the
            # accuracy metric (previously the network was evaluated twice per batch).
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        if total == 0:
            # Empty loader: there is nothing to average, so skip the division.
            print(f"Epoch {epoch+1}: no training samples")
            continue
        epoch_loss /= total
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net: nn.Module, testloader: DataLoader):
    """Evaluate the network on the entire test set.

    Puts ``net`` in eval mode and, without tracking gradients, accumulates the
    cross-entropy loss (weighted by batch size) and the number of correct
    top-1 predictions over ``testloader``. Batches are moved to the global
    ``DEVICE``.

    Returns:
        ``(loss, accuracy)``: per-sample mean loss and fraction of correct
        predictions.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    summed_loss = 0.0
    num_correct = 0
    num_seen = 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            logits = net(images)
            batch_size = labels.size(0)
            summed_loss += loss_fn(logits, labels).item() * batch_size
            predicted = torch.max(logits.data, 1)[1]
            num_seen += batch_size
            num_correct += (predicted == labels).sum().item()
    mean_loss = summed_loss / num_seen
    return mean_loss, num_correct / num_seen


In [None]:
# Load data
trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS, DATASET)
NUM_CLASSES = len(np.unique(testloader.dataset.targets))

# Create an instance of the model
net = Net(NUM_CLASSES).to(DEVICE)
# Sanity check on a real batch of one image. A batched (N, C, H, W) input checks the
# shape contract the model is actually used with, instead of relying on unbatched
# convolution support and on `view(-1, 512)` creating the batch dimension.
# The tensor has as many elements as the former (3, 32, 32) one, so the same
# amount of random numbers is drawn from the seeded generator.
with torch.no_grad():
    assert net(torch.randn((1, 3, 32, 32), device=DEVICE)).shape == torch.Size([1, NUM_CLASSES])

# Print some stats about the model and the data
print("Model parameters:", sum(p.numel() for p in net.parameters() if p.requires_grad))
print("Client's trainset size:", len(trainloaders[0].dataset))
print("Client's validation set size:", len(valloaders[0].dataset))
print("Server's testset size:", len(testloader.dataset))

In [None]:
class FlowerClient(PeerReviewClient):
    """Simulated client holding a model and its local train/validation loaders.

    Attributes are set directly in ``__init__``: ``cid`` (client id), ``net``
    (the local model), ``trainloader`` and ``valloader``.
    """

    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self):
        """Return the local model's state as a list of NumPy arrays."""
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def train(self, parameters, config):
        """Load ``parameters``, train locally and return the updated weights.

        Args:
            parameters: Weights to start from, in ``get_parameters`` order.
            config: Must contain ``"current_round"`` and ``"local_epochs"``.

        Returns:
            ``(weights, num_examples, metrics)`` where ``num_examples`` is the
            number of local training samples and ``metrics`` is empty.
        """
        # Read values from config
        current_round = config["current_round"]
        local_epochs = config["local_epochs"]
        # Use values provided by the config
        print(f"[Client {self.cid}, round {current_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, local_epochs)
        # Report the number of samples, not the number of batches (`len(loader)`),
        # because the server uses this value as the weight when averaging.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def review(self, parameters, config):
        """Score two concatenated weight sets on the local validation data.

        ``parameters`` holds two models back to back: the first half is scored
        under ``CANDIDATE_REVIEW_SCORE`` and the second half under
        ``ANTITHETIC_REVIEW_SCORE``.

        Returns:
            ``([], num_examples, metrics)`` with both validation losses in
            ``metrics``; no weights are sent back.
        """
        num_tensors = len(parameters)
        assert (num_tensors % 2) == 0
        half = num_tensors // 2
        c_loss, num_examples, _ = self.evaluate(parameters[:half], {})
        a_loss, _, _ = self.evaluate(parameters[half:], {})
        return [], num_examples, {CANDIDATE_REVIEW_SCORE: c_loss, ANTITHETIC_REVIEW_SCORE: a_loss}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and return the loss on the local validation set.

        Returns:
            ``(loss, num_examples, metrics)`` where ``num_examples`` is the
            number of validation samples and ``metrics`` is empty.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, _ = test(self.net, self.valloader)
        # Sample count (not batch count) so that weighted loss averages are exact.
        return float(loss), len(self.valloader.dataset), {}


In [None]:
class FedEs(PeerReviewStrategy):
    """Peer-review strategy that blends a FedAvg candidate into the global model.

    Each round the trained client weights are averaged into a single candidate.
    Review clients then score both the candidate and the current global weights
    on their validation data. Finally the global weights are moved along the
    candidate direction, scaled by ``lr`` and by the relative difference between
    the reviewed candidate loss and the latest server-side loss.

    Client sampling for training/evaluation and server-side evaluation are
    delegated to an internal ``FedAvg`` instance. ``aggregate_after_review``
    needs ``current_loss``, which is only set by ``evaluate``; an ``eval_fn``
    returning a loss is therefore required in practice.
    """

    def __init__(
        self,
        fraction_review: float = 0.1,
        fraction_fit: float = 0.1,
        fraction_eval: float = 0.1,
        min_review_clients: int = 2,
        min_fit_clients: int = 2,
        min_eval_clients: int = 2,
        min_available_clients: int = 2,
        eval_fn: Optional[
            Callable[[Weights], Optional[Tuple[float, Dict[str, Scalar]]]]
        ] = None,
        on_review_config_fn: Optional[Callable[[int, int], Dict[str, Scalar]]] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Store the review settings and build the wrapped ``FedAvg`` strategy.

        ``fraction_review``, ``min_review_clients`` and ``on_review_config_fn``
        (called as ``fn(rnd, review_rnd)``) control the review phase; every
        other argument is forwarded unchanged to ``FedAvg``.
        """
        super(FedEs, self).__init__()
        # Strategy logic fields
        self.current_loss: Optional[float] = None  # latest server-side loss, set by evaluate()
        self.candidate_weights: Optional[Weights] = None  # candidate sent out for review
        self.candidate_loss: Optional[float] = None  # reviewed loss of the candidate
        self.antithetic_loss: Optional[float] = None  # reviewed loss of the current global weights
        self.lr = 1  # step size of the update, decayed by LR_GAMMA
        self.patience = 5  # non-improving rounds tolerated before decaying lr
        self.worse_loss = 0  # consecutive rounds without improvement
        self.prev_loss = np.inf  # reference loss that a round must beat

        # Strategy adapter fields
        self.fraction_review = fraction_review
        self.min_review_clients = min_review_clients
        self.on_review_config_fn = on_review_config_fn
        self.fedavg = FedAvg(
            fraction_fit=fraction_fit,
            fraction_eval=fraction_eval,
            min_fit_clients=min_fit_clients,
            min_eval_clients=min_eval_clients,
            min_available_clients=min_available_clients,
            eval_fn=eval_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def num_review_clients(self, num_available_clients: int):
        """Return the sample size and the required number of available
        clients."""
        num_clients = int(num_available_clients * self.fraction_review)
        return (
            max(num_clients, self.min_review_clients),
            self.fedavg.min_available_clients,
        )

    # Standard strategy
    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Delegate to ``FedAvg.initialize_parameters``."""
        return self.fedavg.initialize_parameters(client_manager)

    def configure_evaluate(
        self, rnd: int, parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Delegate to ``FedAvg.configure_evaluate``."""
        return self.fedavg.configure_evaluate(rnd, parameters, client_manager)

    def aggregate_evaluate(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Delegate to ``FedAvg.aggregate_evaluate``."""
        # NOTE(review): `results` presumably carries evaluation results rather than
        # FitRes; the annotation is kept because no other type is imported here.
        return self.fedavg.aggregate_evaluate(rnd, results, failures)

    def evaluate(
        self, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Run FedAvg's server-side evaluation and remember the resulting loss.

        Returns ``None`` (leaving ``current_loss`` untouched) when FedAvg
        produces no result, instead of failing while unpacking it.
        """
        result = self.fedavg.evaluate(parameters)
        if result is None:
            return None
        loss, metrics = result
        self.current_loss = loss
        return loss, metrics

    # Multiple reviews strategy
    def configure_train(
        self, rnd: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Delegate to ``FedAvg.configure_fit``."""
        return self.fedavg.configure_fit(rnd, parameters, client_manager)

    def aggregate_train(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> List[Tuple[Optional[Parameters], Dict[str, Scalar]]]:
        """Average the trained client weights into a single candidate.

        Returns:
            A one-element list ``[(candidate_parameters, {})]``.

        Raises:
            AggregateTrainException: If there are no results, or there are
                failures and failures are not accepted.
        """
        if not results:
            raise AggregateTrainException

        # Do not aggregate if there are failures and failures are not accepted
        if not self.fedavg.accept_failures and failures:
            raise AggregateTrainException

        # Convert results
        weights_aggregated = aggregate(
            [
                (parameters_to_weights(fit_res.parameters), fit_res.num_examples)
                for _, fit_res in results
            ]
        )
        return [(weights_to_parameters(weights_aggregated), {})]

    def configure_review(
        self,
        rnd: int,
        review_rnd: int,
        parameters: Parameters,
        client_manager: ClientManager,
        parameters_aggregated: List[Optional[Parameters]],
        metrics_aggregated: List[Dict[str, Scalar]],
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Build the review instructions and sample the review clients.

        The first entry of ``parameters_aggregated`` is removed from that list
        and stored as the candidate. Clients receive the candidate tensors
        followed by the current global tensors in one ``Parameters`` object.

        Raises:
            ConfigureReviewException: If ``fraction_review`` is 0.
        """
        # Do not configure federated review if fraction_review is 0
        if self.fraction_review == 0.0:
            raise ConfigureReviewException

        # Parameters and config
        config = {}
        if self.on_review_config_fn is not None:
            # Custom review config function provided
            config = self.on_review_config_fn(rnd, review_rnd)

        # Prepare review instructions
        # NOTE: pop(0) mutates the caller's list; kept as in the original flow.
        params = parameters_aggregated.pop(0)
        self.candidate_weights = parameters_to_weights(params)
        weights = parameters_to_weights(parameters)
        params = [*self.candidate_weights, *weights]
        review_ins = FitIns(weights_to_parameters(params), config)

        # Sample clients
        sample_size, min_num_clients = self.num_review_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, review_ins) for client in clients]

    def aggregate_review(
        self,
        rnd: int,
        review_rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> List[Tuple[Optional[Parameters], Dict[str, Scalar]]]:
        """Average the review scores, weighted by each client's ``num_examples``.

        Stores the results in ``candidate_loss`` and ``antithetic_loss`` and
        returns an empty list.

        Raises:
            AggregateReviewException: If there are no results, or there are
                failures and failures are not accepted.
        """
        if not results:
            raise AggregateReviewException

        # Do not aggregate if there are failures and failures are not accepted
        if not self.fedavg.accept_failures and failures:
            raise AggregateReviewException

        # Aggregate results
        candidate_loss = weighted_loss_avg(
            [
                (review_res.num_examples, review_res.metrics[CANDIDATE_REVIEW_SCORE])
                for _, review_res in results
            ]
        )
        antithetic_loss = weighted_loss_avg(
            [
                (review_res.num_examples, review_res.metrics[ANTITHETIC_REVIEW_SCORE])
                for _, review_res in results
            ]
        )

        # Save losses
        self.candidate_loss = candidate_loss
        self.antithetic_loss = antithetic_loss
        return []

    def aggregate_after_review(
        self,
        rnd: int,
        parameters_aggregated: List[Optional[Parameters]],
        metrics_aggregated: List[Dict[str, Scalar]],
        parameters: Optional[Parameters],
    ) -> Optional[Parameters]:
        """Compute the new global parameters from the reviewed candidate.

        First updates the step size: ``lr`` is multiplied by ``LR_GAMMA`` after
        ``patience`` consecutive rounds in which ``current_loss`` did not drop
        below ``prev_loss``. Then every tensor is updated as
        ``w - lr * (candidate - w) * (candidate_loss - current_loss) / candidate_loss``.

        Requires ``current_loss``, ``candidate_loss`` and ``candidate_weights``
        to have been set by ``evaluate``, ``aggregate_review`` and
        ``configure_review`` respectively; a ``candidate_loss`` of 0 divides by
        zero.
        """
        # Get weights and losses
        weights = parameters_to_weights(parameters)

        # Compute lr
        if self.current_loss < self.prev_loss:
            self.worse_loss = 0
            self.prev_loss = self.current_loss
        else:
            # Round without improvement. This branch used to repeat the `<` test
            # above and was unreachable, so the learning rate could never decay.
            self.worse_loss += 1
        if self.worse_loss >= self.patience:
            self.lr *= LR_GAMMA
            self.prev_loss = self.current_loss
            # Start a fresh patience window, otherwise every following
            # non-improving round would decay the rate again immediately.
            self.worse_loss = 0

        # Compute update
        relative_gap = (self.candidate_loss - self.current_loss) / self.candidate_loss
        parameters_prime = deepcopy(weights)
        for j, candidate in enumerate(self.candidate_weights):
            parameters_prime[j] -= self.lr * (candidate - weights[j]) * relative_gap

        print(f"Round {rnd}: alpha {(self.candidate_loss - self.current_loss) * self.lr}, lr {self.lr}")

        # Return
        return weights_to_parameters(parameters_prime)

    def stop_review(
        self,
        rnd: int,
        review_rnd: int,
        parameters: Parameters,
        client_manager: ClientManager,
        parameters_aggregated: List[Optional[Parameters]],
        metrics_aggregated: List[Dict[str, Scalar]],
    ) -> bool:
        """Always return True (after logging the call)."""
        print("stop_review", flush=True)
        return True


In [None]:
def client_fn(cid) -> FlowerClient:
    """Create the simulated client with id ``cid``.

    The client gets a freshly initialised ``Net`` on ``DEVICE`` and the train /
    validation loaders stored at index ``int(cid)``.
    """
    model = Net(NUM_CLASSES).to(DEVICE)
    partition = int(cid)
    return FlowerClient(
        cid,
        model,
        trainloaders[partition],
        valloaders[partition],
    )


def evaluate(
    weights: fl.common.Weights,
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Server-side evaluation of ``weights`` on the global test loader.

    Loads the weights into a new ``Net``, evaluates it on ``testloader``,
    prints the result and returns ``(loss, {"accuracy": accuracy})``.
    """
    model = Net(NUM_CLASSES).to(DEVICE)
    # Update model with the latest parameters
    set_parameters(model, weights)
    loss, accuracy = test(model, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    metrics = {"accuracy": accuracy}
    return loss, metrics


def review_config(rnd: int, review_rnd: int):
    """Return the review configuration: the round and review-round numbers."""
    return dict(rnd=rnd, review_rnd=review_rnd)


def fit_config(rnd: int):
    """Return training configuration dict for each round.

    Rounds with ``rnd < 2`` train for a single local epoch; every later round
    uses ``LOCAL_EPOCHS`` local epochs.
    """
    if rnd < 2:
        local_epochs = 1
    else:
        local_epochs = LOCAL_EPOCHS
    return {
        "current_round": rnd,  # The current round of federated learning
        "local_epochs": local_epochs,  # Local epochs each client runs this round
    }

In [None]:
# Build the strategy. The minimum client counts are derived from the sampling
# fractions, so with the configuration above the `min_*` values equal the
# truncated fraction of NUM_CLIENTS (0 for evaluation, since FRACTION_EVAL is 0).
strategy = FedEs(
    fraction_review=FRACTION_REV,
    fraction_fit=FRACTION_FIT,
    fraction_eval=FRACTION_EVAL,
    min_review_clients=int(FRACTION_REV * NUM_CLIENTS),
    min_fit_clients=int(FRACTION_FIT * NUM_CLIENTS),
    min_eval_clients=int(FRACTION_EVAL * NUM_CLIENTS),
    min_available_clients=NUM_CLIENTS,
    # Start from the weights of the `net` instance created after loading the data.
    initial_parameters=fl.common.weights_to_parameters(get_parameters(net)),
    on_fit_config_fn=fit_config,
    on_review_config_fn=review_config,
    # Server-side evaluation on the test set; FedEs.evaluate stores its loss as current_loss.
    eval_fn=evaluate,
)
client_manager = SimpleClientManager()

# Launch the simulation. The same strategy and client manager objects are given
# both to the PeerReviewServer and to start_simulation itself.
# NOTE(review): each client requests num_gpus=1 although DEVICE is the CPU —
# confirm this is intended and that it schedules on machines without a GPU.
start_simulation(
    server=PeerReviewServer(client_manager, strategy),
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    num_rounds=NUM_ROUNDS,
    strategy=strategy,
    client_manager=client_manager,
    client_resources={"num_cpus": 1, "num_gpus": 1},
    ray_init_args={"local_mode": True, "include_dashboard": False}
)