In [None]:
!pip install git+https://github.com/passerim/peer-reviewed-flower.git

In [None]:
import os
import random
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple, Union

import flwr as fl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from flwr.common import (
    EvaluateRes,
    EvaluateIns,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    Parameters,
    Scalar,
    NDArrays,
    parameters_to_ndarrays,
    ndarrays_to_parameters,
)
from flwr.server import ServerConfig
from flwr.server.client_manager import ClientManager, SimpleClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
from flwr.simulation import start_simulation
from overrides import overrides
from prflwr.peer_review import PeerReviewClient
from prflwr.peer_review import PeerReviewServer
from prflwr.peer_review.strategy import PeerReviewStrategy
from prflwr.peer_review.strategy import (
    AggregateTrainException,
    AggregateReviewException,
    ConfigureReviewException,
)
from prflwr.utils import non_iid_partitions
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets


In [None]:
def set_seed(seed: int):
    """Seed every random number generator used in this notebook.

    The same ``seed`` is handed to Python's ``random`` module, NumPy's
    global generator, PyTorch's default generator and all CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


In [None]:
# Setting random seed for reproducibility
SEED = 123
set_seed(SEED)

# Experiment configuration shared by the cells below.
DATASET = "CIFAR10"  # possible values: "CIFAR10" or "CIFAR100"
NUM_EPOCHS = 50  # only used to derive NUM_ROUNDS below
NUM_CLIENTS = 50  # number of data partitions and of simulated clients
LOCAL_EPOCHS = 2  # local epochs requested by fit_config from round 2 onwards
BATCH_SIZE = 32  # batch size of every DataLoader built by load_datasets
FRACTION_REV = 1 / 4  # passed to the strategy as fraction_review
FRACTION_FIT = 1 / 4  # passed to the strategy as fraction_fit
FRACTION_EVAL = 0  # passed to the strategy as fraction_evaluate
REVIEW_SCORE = "review_score"  # metrics key under which clients report the review loss
# With the values above: 50 // (2 * 0.25) = 100, times (1 + 0.25) = 125 rounds.
NUM_ROUNDS = int((NUM_EPOCHS // (LOCAL_EPOCHS * FRACTION_FIT)) * (1 + FRACTION_FIT))
print(f"Training for {NUM_ROUNDS} rounds")

# Device to use for training and evaluation
DEVICE = torch.device("cpu")
print(f"Training on {DEVICE}")


In [None]:
def load_datasets(
    num_clients: int,
    dataset: str = "CIFAR10",
    src: str = ".",
    iid: bool = True,
    concentration: float = 1,
    use_augmentation: bool = False,
) -> Tuple[List[DataLoader], List[DataLoader], DataLoader]:
    """Download a CIFAR dataset and split it into per-client data loaders.

    Args:
        num_clients: Number of client partitions to create (must be >= 1).
        dataset: Name of the torchvision dataset class, "CIFAR10" or "CIFAR100".
        src: Base directory; data is stored under ``<src>/./data``.
        iid: When True the training set is split uniformly at random; when
            False the split is delegated to ``non_iid_partitions``.
        concentration: Forwarded to ``non_iid_partitions`` (non-IID split only).
        use_augmentation: When True, pad / random-flip / random-crop transforms
            are prepended to the training transform.

    Returns:
        ``(trainloaders, valloaders, testloader)``: one training and one
        validation loader per client (10 % of each partition is held out for
        validation) and a single loader over the test set.

    Raises:
        ValueError: If ``dataset`` is not a supported name or ``num_clients``
            is smaller than 1.
    """
    if dataset not in ["CIFAR10", "CIFAR100"]:
        raise ValueError(
            "Unknown dataset! Admissible values are: 'CIFAR10' or 'CIFAR100'."
        )
    if num_clients < 1:
        raise ValueError("num_clients must be a positive integer.")

    # Download and transform CIFAR dataset (train and test)
    augmentation = (
        [
            transforms.Pad(4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
        ]
        if use_augmentation
        else []
    )
    transform = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    dataset_cls = getattr(datasets, dataset)
    data_root = os.path.join(src, "./data")
    trainset = dataset_cls(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose([*augmentation, *transform]),
    )
    testset = dataset_cls(
        data_root,
        train=False,
        download=True,
        transform=transforms.Compose([*transform]),
    )

    # Split training set into `num_clients` partitions to simulate different local datasets
    if not iid:
        targets = np.array(trainset.targets)
        idxs = np.arange(len(targets))
        # Kept in a separate name so the `dataset` argument is not re-bound.
        indexed_targets = [idxs, targets]
        train_partitions = non_iid_partitions(
            indexed_targets,
            num_partitions=num_clients,
            concentration=concentration,
        )
        subsets = [Subset(trainset, partition) for partition in train_partitions]
    else:
        # random_split requires the lengths to sum to len(trainset), so the
        # remainder of the integer division is spread over the first partitions.
        base_size, remainder = divmod(len(trainset), num_clients)
        lengths = [
            base_size + (1 if i < remainder else 0) for i in range(num_clients)
        ]
        subsets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in subsets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        ds_train, ds_val = random_split(
            ds, [len_train, len_val], torch.Generator().manual_seed(42)
        )
        trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloaders, valloaders, testloader


In [None]:
class Net(nn.Module):
    """Small CNN: three conv + ReLU + max-pool stages and two linear layers.

    The flatten step hard-codes 512 features per sample, i.e. 32 channels of
    4 x 4 activations, which is what 32 x 32 inputs produce after three
    2 x 2 poolings.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        width = 32
        self.conv1 = nn.Conv2d(3, width, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(512, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 512)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array.

    The arrays are listed in ``state_dict`` order.
    """
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]


def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, pairing them with its ``state_dict`` keys.

    Args:
        net: Module whose ``state_dict`` defines the key order.
        parameters: Arrays in the same order as ``net.state_dict()``.

    Raises:
        RuntimeError: Propagated from ``load_state_dict(strict=True)`` when
            keys are missing or shapes do not match.
    """
    keys = net.state_dict().keys()
    # torch.tensor keeps each array's dtype and accepts 0-d arrays (for
    # example integer counters stored as buffers); the legacy torch.Tensor
    # constructor always builds a float32 tensor.
    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set.

    Runs ``epochs`` passes over ``trainloader`` with a freshly created Adam
    optimizer and cross-entropy loss, printing the sample-weighted mean loss
    and the accuracy after each epoch.

    Args:
        net: Model to optimise in place; it is switched to training mode.
        trainloader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # One forward pass per batch: the same outputs feed both the loss
            # and the accuracy metric.
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
        if total == 0:
            # An empty loader would otherwise divide by zero below.
            print(f"Epoch {epoch+1}: no training samples")
            continue
        epoch_loss /= total
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net: nn.Module, testloader: DataLoader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item() * labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= total
    accuracy = correct / total
    return loss, accuracy


In [None]:
# Load data
# iid=False selects the non-IID branch of load_datasets, with concentration 0.1.
trainloaders, valloaders, testloader = load_datasets(
    NUM_CLIENTS, DATASET, iid=False, concentration=0.1
)
# Number of classes = number of distinct labels in the test set.
NUM_CLASSES = len(np.unique(testloader.dataset.targets))

# Create an instance of the model
net = Net(NUM_CLASSES).to(DEVICE)
# Sanity check: a single unbatched 3x32x32 input must yield one row of
# NUM_CLASSES logits (the view(-1, 512) in Net.forward adds the leading 1).
# NOTE(review): this relies on the conv layers accepting unbatched input;
# confirm for the PyTorch version in use.
with torch.no_grad():
    assert net(torch.randn((3, 32, 32), device=DEVICE)).shape == torch.Size(
        [1, NUM_CLASSES]
    )

# Print some stats about the model and the data
# The sizes below are those of client 0 only; non-IID partitions may differ in size.
print("Model parameters:", sum(p.numel() for p in net.parameters() if p.requires_grad))
print("Client's trainset size:", len(trainloaders[0].dataset))
print("Client's validation set size:", len(valloaders[0].dataset))
print("Server's testset size:", len(testloader.dataset))


In [None]:
def histshow(loader):
    """Plot a histogram of all labels yielded by ``loader`` and show it."""
    label_batches = [batch_labels for _, batch_labels in loader]
    all_labels = torch.concat(label_batches)
    plt.hist(all_labels)
    plt.show()


# Label distribution of the first client's training partition.
histshow(trainloaders[0])


In [None]:
class FlowerClient(PeerReviewClient):
    """Peer-review client that trains and scores a model on its local data.

    Attributes are set directly from the constructor arguments:
    ``cid`` (client id used in log lines), ``net`` (local model),
    ``trainloader`` and ``valloader`` (local training / validation loaders).
    """

    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the local model's parameters as a list of NumPy arrays."""
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def train(self, parameters, config):
        """Load ``parameters``, train locally and return the updated weights.

        ``config`` must contain ``"current_round"`` and ``"local_epochs"``.

        Returns:
            ``(weights, num_examples, metrics)`` where ``num_examples`` is the
            number of samples in the local training set.
        """
        # Read values from config
        current_round = config["current_round"]
        local_epochs = config["local_epochs"]
        # Use values provided by the config
        print(f"[Client {self.cid}, round {current_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, local_epochs)
        # Report the sample count, not len(loader) (the number of batches), so
        # that example-weighted aggregation weights clients by data size.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def review(self, parameters, config):
        """Score candidate ``parameters`` by their loss on the validation set.

        Returns:
            ``([], num_examples, {REVIEW_SCORE: loss})``; no parameters are
            sent back.
        """
        loss, num_examples, _ = self.evaluate(parameters, {})
        return [], num_examples, {REVIEW_SCORE: loss}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and compute the loss on the validation set.

        Returns:
            ``(loss, num_examples, metrics)`` where ``num_examples`` is the
            number of samples in the local validation set.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, _ = test(self.net, self.valloader)
        # Sample count rather than batch count, see train().
        return float(loss), len(self.valloader.dataset), {}


In [None]:
class FedLS(PeerReviewStrategy):
    """Federated averaging with a server-side line search on the step size.

    Per server round, as implemented by the methods below:

    1. ``configure_train`` resets the step size to ``max_step_size`` and
       delegates client sampling to an internal ``FedAvg``.
    2. ``aggregate_train`` computes the example-weighted average of the client
       weights and stores its difference from the current global weights
       (called ``gradient`` in the code).
    3. ``configure_review`` builds candidate weights
       ``current + step_size * gradient`` with
       ``step_size = max(min_step_size, max_step_size * step_size_decay ** (review_round - 1))``
       and sends them to the sampled reviewers.
    4. ``aggregate_review`` stores the weighted mean of the reviewers'
       ``REVIEW_SCORE`` metric as ``candidate_loss``.
    5. ``stop_review`` ends the review loop when
       ``candidate_loss <= current_loss - gamma * step_size * ||gradient||^2``
       or when the step size / review round limits are reached.
    6. ``aggregate_after_review`` applies the last step size to the weights.

    ``current_loss`` is refreshed by ``evaluate`` and ``aggregate_evaluate``.
    """

    def __init__(
        self,
        max_review_rounds: int = 5,
        max_step_size: float = 1,
        min_step_size: float = 1e-3,
        step_size_decay: float = 0.1,
        gamma: float = 0,
        fraction_review: float = 0.1,
        fraction_fit: float = 0.1,
        fraction_evaluate: float = 0.1,
        min_review_clients: int = 2,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_review_config_fn: Optional[Callable[[int, int], Dict[str, Scalar]]] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Store the line-search settings and build the wrapped ``FedAvg``.

        ``on_review_config_fn`` is called as
        ``on_review_config_fn(server_round, review_round)``. All arguments not
        related to the review phase or the line search are forwarded to
        ``FedAvg`` unchanged.
        """
        super().__init__()

        # Line-search fields
        self.max_review_rounds = max_review_rounds
        self.max_step_size = max_step_size
        self.min_step_size = min_step_size
        self.step_size_decay = step_size_decay
        self.gamma = gamma
        self.step_size: float = 0.0
        self.current_loss: float = np.inf
        self.candidate_loss: float = np.inf

        # Strategy adapter fields
        self.fraction_review = fraction_review
        self.min_review_clients = min_review_clients
        self.on_review_config_fn = on_review_config_fn
        self.fedavg = FedAvg(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def num_review_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return the sample size and the required number of available
        clients."""
        num_clients = int(num_available_clients * self.fraction_review)
        return (
            max(num_clients, self.min_review_clients),
            self.fedavg.min_available_clients,
        )

    # Standard strategy
    @overrides
    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Delegate to the wrapped ``FedAvg``."""
        return self.fedavg.initialize_parameters(client_manager)

    @overrides
    def configure_evaluate(
        self, server_round: int, parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Delegate to the wrapped ``FedAvg``."""
        return self.fedavg.configure_evaluate(server_round, parameters, client_manager)

    @overrides
    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate client evaluation results and remember the loss.

        ``current_loss`` is only updated when an aggregated loss is available,
        so a round without results cannot replace it with ``None``.
        """
        loss, metrics = self.fedavg.aggregate_evaluate(server_round, results, failures)
        if loss is not None:
            self.current_loss = loss
        return loss, metrics

    @overrides
    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Run the server-side evaluation and remember the loss.

        Returns ``None`` (leaving ``current_loss`` untouched) when the wrapped
        ``FedAvg`` produces no evaluation result.
        """
        result = self.fedavg.evaluate(server_round, parameters)
        if result is None:
            return None
        loss, metrics = result
        self.current_loss = loss
        return loss, metrics

    # Multiple reviews strategy
    @overrides
    def configure_train(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Reset the step size and let ``FedAvg`` pick the training clients."""
        print(f"Round: {server_round}, configure_train")
        self.step_size = self.max_step_size
        return self.fedavg.configure_fit(server_round, parameters, client_manager)

    @overrides
    def aggregate_train(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
        parameters: Optional[Parameters],
    ) -> List[Tuple[Optional[Parameters], Dict[str, Scalar]]]:
        """Return the averaged client update relative to ``parameters``.

        Raises:
            AggregateTrainException: If there are no results, or if there are
                failures and failures are not accepted.
        """
        print(f"Round: {server_round}, aggregate_train")
        if not results:
            raise AggregateTrainException

        # Do not aggregate if there are failures and failures are not accepted
        if not self.fedavg.accept_failures and failures:
            raise AggregateTrainException

        # Example-weighted average of the client weights
        weights_aggregated = aggregate(
            [
                (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
                for _, fit_res in results
            ]
        )
        current_weights = parameters_to_ndarrays(parameters)
        # Update direction: averaged client weights minus current global weights
        gradient = [
            weights_aggregated[i] - current_weights[i]
            for i in range(len(weights_aggregated))
        ]
        return [(ndarrays_to_parameters(gradient), {})]

    @overrides
    def configure_review(
        self,
        server_round: int,
        review_round: int,
        parameters: Parameters,
        client_manager: ClientManager,
        parameters_aggregated: List[Optional[Parameters]],
        metrics_aggregated: List[Dict[str, Scalar]],
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Build candidate weights for this review round and sample reviewers.

        Raises:
            ConfigureReviewException: If ``fraction_review`` is 0.
        """
        # Do not configure federated review if fraction_review is 0
        if self.fraction_review == 0.0:
            raise ConfigureReviewException

        # Parameters and config
        config = {}
        if self.on_review_config_fn is not None:
            # Custom review config function provided
            config = self.on_review_config_fn(server_round, review_round)

        # Prepare review instructions
        gradient = parameters_to_ndarrays(parameters_aggregated[0])
        current_weights = parameters_to_ndarrays(parameters)
        # Stored as a plain Python float: if step_size_decay is a NumPy scalar
        # (e.g. the result of np.cos), a NumPy-scalar step size could upcast
        # the candidate weights under NumPy's scalar promotion rules.
        self.step_size = float(
            max(
                self.min_step_size,
                self.max_step_size * (self.step_size_decay ** (review_round - 1)),
            )
        )
        weights = [
            gradient[i] * self.step_size + current_weights[i]
            for i in range(len(gradient))
        ]
        review_ins = FitIns(ndarrays_to_parameters(weights), config)

        # Sample clients
        sample_size, min_num_clients = self.num_review_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        print(
            f"Round: {server_round}, review_round: {review_round}, configure_review, step_size: {self.step_size}"
        )
        # Return client/config pairs
        return [(client, review_ins) for client in clients]

    @overrides
    def aggregate_review(
        self,
        server_round: int,
        review_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
        parameters: Parameters,
        parameters_aggregated: List[Optional[Parameters]],
        metrics_aggregated: List[Dict[str, Scalar]],
    ) -> List[Tuple[Optional[Parameters], Dict[str, Scalar]]]:
        """Store the weighted mean review loss as ``candidate_loss``.

        The aggregated parameters and metrics are passed through unchanged.

        Raises:
            AggregateReviewException: If there are no results, or if there are
                failures and failures are not accepted.
        """
        print(f"Round: {server_round}, aggregate_review")
        if not results:
            raise AggregateReviewException

        # Do not aggregate if there are failures and failures are not accepted
        if not self.fedavg.accept_failures and failures:
            raise AggregateReviewException

        # Aggregate results
        self.candidate_loss = weighted_loss_avg(
            [
                (review_res.num_examples, review_res.metrics[REVIEW_SCORE])
                for _, review_res in results
            ]
        )
        return list(zip(parameters_aggregated, metrics_aggregated))

    @overrides
    def aggregate_after_review(
        self,
        server_round: int,
        parameters: Optional[Parameters],
        parameters_aggregated: List[Optional[Parameters]],
        metrics_aggregated: List[Dict[str, Scalar]],
    ) -> Optional[Parameters]:
        """Apply the update with the step size left by the last review round."""
        print(f"Round: {server_round}, aggregate_after_review")
        current_weights = parameters_to_ndarrays(parameters)
        gradient = parameters_to_ndarrays(parameters_aggregated[0])

        # Compute update (in place, so each array keeps its dtype)
        for j, direction in enumerate(gradient):
            current_weights[j] += self.step_size * direction
        print(f"Round {server_round}: lr {self.step_size}")

        # Return
        return ndarrays_to_parameters(current_weights)

    @overrides
    def stop_review(
        self,
        server_round: int,
        review_round: int,
        parameters: Parameters,
        client_manager: ClientManager,
        parameters_aggregated: List[Optional[Parameters]],
        metrics_aggregated: List[Dict[str, Scalar]],
    ) -> bool:
        """Decide whether the review loop of this server round should stop.

        Returns True when the sufficient-decrease condition holds, or when the
        step size has reached ``min_step_size`` or the review-round limit is
        hit; returns False to request another review round.
        """
        print(
            f"Round: {server_round}, review_round: {review_round}, stop_review, candidate loss: {self.candidate_loss}"
        )
        gradient = parameters_to_ndarrays(parameters_aggregated[0])
        squared_norm = sum(map(lambda x: np.linalg.norm(x) ** 2, gradient))
        threshold = self.current_loss - self.gamma * self.step_size * squared_norm
        if self.candidate_loss <= threshold:
            return True
        # NOTE(review): `review_round < max_review_rounds + 1` still continues
        # when review_round == max_review_rounds; confirm this matches how the
        # server numbers review rounds.
        can_continue = (self.step_size > self.min_step_size) and (
            review_round < (self.max_review_rounds + 1)
        )
        return not can_continue


In [None]:
def client_fn(cid: str) -> FlowerClient:
    """Build the client whose data loaders sit at index ``int(cid)``.

    Each call creates a fresh ``Net`` on ``DEVICE``.
    """
    model = Net(NUM_CLASSES).to(DEVICE)
    index = int(cid)
    return FlowerClient(cid, model, trainloaders[index], valloaders[index])


def evaluate(
    server_round: int,
    weights: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Evaluate ``weights`` on the server's test loader.

    A fresh ``Net`` is created, loaded with ``weights`` and scored with
    ``test``; the result is printed and returned as
    ``(loss, {"accuracy": accuracy})``. ``server_round`` and ``config`` are
    accepted but not used.
    """
    model = Net(NUM_CLASSES).to(DEVICE)
    # Update model with the latest parameters
    set_parameters(model, weights)
    loss, accuracy = test(model, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    metrics = {"accuracy": accuracy}
    return loss, metrics


def review_config(server_round: int, review_round: int):
    """Return the review configuration dict holding both round counters."""
    return dict(server_round=server_round, review_round=review_round)


def fit_config(server_round: int):
    """Return the training configuration dict for a server round.

    Rounds with ``server_round < 2`` train for one local epoch; every later
    round trains for ``LOCAL_EPOCHS`` local epochs.
    """
    if server_round < 2:
        local_epochs = 1
    else:
        local_epochs = LOCAL_EPOCHS
    return {
        "current_round": server_round,  # The current round of federated learning
        "local_epochs": local_epochs,  # Local epochs to run in this round
    }


In [None]:
# Line-search strategy: up to 4 review rounds, step size shrinking by a factor
# of cos(pi / 4) per review round.
strategy = FedLS(
    max_review_rounds=4,
    step_size_decay=np.cos(np.pi / 4),
    fraction_review=FRACTION_REV,
    fraction_fit=FRACTION_FIT,
    fraction_evaluate=FRACTION_EVAL,
    min_review_clients=int(FRACTION_REV * NUM_CLIENTS),
    min_fit_clients=int(FRACTION_FIT * NUM_CLIENTS),
    min_evaluate_clients=int(FRACTION_EVAL * NUM_CLIENTS),
    min_available_clients=NUM_CLIENTS,
    initial_parameters=ndarrays_to_parameters(get_parameters(net)),
    on_fit_config_fn=fit_config,
    on_review_config_fn=review_config,
    evaluate_fn=evaluate,
)
client_manager = SimpleClientManager()
server = PeerReviewServer(client_manager, strategy)

# Enable this in order to debug or to use the gpu,
# running simulations in local mode however suffers
# from memory leakage and fills the disk with time.
LOCAL_MODE = False

hist = start_simulation(
    server=server,
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_manager=client_manager,
    client_resources={
        "num_cpus": 1,
        # DEVICE is a torch.device, which does not support the `in` operator;
        # compare its `type` attribute instead.
        "num_gpus": 1
        if (torch.cuda.is_available() and LOCAL_MODE and DEVICE.type == "cuda")
        else 0,
    },
    ray_init_args={
        "local_mode": LOCAL_MODE,
        "include_dashboard": False,
    },
)
