# Federated Learning using Flower


## 1. Preparing and loading the Dataset
### Installing Dependencies

In [None]:
!pip install --quiet flwr[simulation] flwr-datasets[vision]
# `datasets.utils.logging` did not work with the `datasets` version installed
# above, so `datasets` is upgraded afterwards.
# See https://github.com/huggingface/datasets/issues/6985
!pip install --quiet --upgrade datasets

### Import libraries

In [None]:
%matplotlib inline
from collections import OrderedDict
from typing import List, Tuple

import matplotlib.pyplot as plt 
import numpy as np 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context 
from flwr.server import ServerApp, ServerConfig, ServerAppComponents 
from flwr.server.strategy import FedAvg 
from flwr.simulation import run_simulation 
from flwr_datasets import FederatedDataset

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Throwing error: https://github.com/huggingface/datasets/issues/6985
# from datasets.utils.logging import disable_progress_bar
# disable_progress_bar()

In [None]:
# Use the GPU when one is present and fall back to the CPU otherwise, so the
# notebook also runs on CPU-only machines (later cells branch on DEVICE.type).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on Device: {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
# Turn off the Hugging Face `datasets` progress bars.
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

### Load the Data

We will load the CIFAR-10 dataset which has 10 different classes. We will train a CNN on this dataset. 

To simulate a federated setting, we will split the CIFAR-10 training set into partitions, creating a scenario in which multiple organizations each hold their own dataset. We will use `FederatedDataset` from the `flwr_datasets` library to divide the data into 10 partitions.

In [None]:
# Number of simulated clients; the training split is divided into this many
# partitions and the simulation starts this many supernodes.
NUM_CLIENTS = 10
# Mini-batch size used by every DataLoader built in this notebook.
BATCH_SIZE = 32

### Load Dataset

In [None]:
def load_datasets(partition_id: int):
    """Build the train/validation/test DataLoaders for one client partition.

    The ``cifar10`` training split is divided into ``NUM_CLIENTS`` partitions.
    The partition selected by ``partition_id`` is split 80/20 (fixed seed) into
    a local training subset and a local validation subset. The test loader
    wraps the dataset's ``test`` split.

    Args:
        partition_id: Index of the partition to load.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the training
        loader shuffles; all three use ``BATCH_SIZE``.
    """
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def apply_transforms(batch):
        # Registered through `with_transform` instead of being passed as a
        # `transform=` argument to a torchvision dataset; the transform
        # pipeline itself is the same.
        batch['img'] = [to_normalized_tensor(img) for img in batch['img']]
        return batch

    fds = FederatedDataset(dataset="cifar10", partitioners={'train': NUM_CLIENTS})

    # Per-client split: 80% train, 20% validation.
    local_splits = fds.load_partition(partition_id).train_test_split(test_size=0.2, seed=42)
    local_splits = local_splits.with_transform(apply_transforms)

    train_loader = DataLoader(local_splits['train'], batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(local_splits['test'], batch_size=BATCH_SIZE)
    test_loader = DataLoader(
        fds.load_split('test').with_transform(apply_transforms),
        batch_size=BATCH_SIZE,
    )
    return train_loader, val_loader, test_loader

### Sanity Check

In [None]:
# Pull one mini-batch from the first client's training partition.
train_loader, _, _ = load_datasets(partition_id=0)
batch = next(iter(train_loader))
images = batch['img']
labels = batch['label']

# matplotlib expects (height, width, 3), so move the channel axis last and
# convert to NumPy; `x / 2 + 0.5` then undoes Normalize(mean=0.5, std=0.5).
images = images.permute(0, 2, 3, 1).numpy() / 2 + 0.5

# 4 x 8 grid: one subplot per image in the batch.
fig, axs = plt.subplots(4, 8, figsize=(12, 6))

# Draw each image with its class name as the title.
for i, ax in enumerate(axs.flat):
    ax.imshow(images[i])
    ax.set_title(train_loader.dataset.features['label'].int2str([labels[i]])[0])

# Show the plot
fig.tight_layout()
plt.show()

## 2. Centralized Training with PyTorch
### Define the Model

We will define a simple CNN Model described in this [PyTorch Tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network) for the CIFAR-10 Dataset.

In [None]:
class Net(nn.Module):
    """Small CNN for 10-class classification of 3-channel images.

    Two conv + ReLU + max-pool stages are followed by three fully connected
    layers. The ``16 * 5 * 5`` flatten size corresponds to 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5 along each spatial axis).
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class logits, one row of 10 values per input image."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the output layer: CrossEntropyLoss expects
        # unnormalized logits, and a ReLU here would clamp every negative
        # logit to zero.
        return self.fc3(x)

### Train function

In [None]:
def train(net, train_loader, epochs: int, verbose = False):
    """Train the network on the training set.

    Args:
        net: Model to optimise in place. Batches are moved to ``DEVICE``, so
            the model is expected to live there already.
        train_loader: Iterable of dict batches holding ``'img'`` and
            ``'label'`` tensors.
        epochs: Number of full passes over ``train_loader``.
        verbose: When true, print the mean per-sample loss and the accuracy
            after every epoch.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    # A fresh Adam optimizer is created on every call.
    optimizer = torch.optim.Adam(net.parameters())
    net.train()

    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in train_loader:
            images, labels = batch['img'].to(DEVICE), batch['label'].to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            batch_size = labels.size(0)
            # `loss` is the batch mean. Use `.item()` so no autograd graph is
            # kept alive, and weight by the batch size so the epoch figure is
            # a true per-sample mean even when the last batch is smaller.
            epoch_loss += loss.item() * batch_size
            total += batch_size
            correct += (torch.max(outputs.detach(), 1)[1] == labels).sum().item()
        epoch_loss /= total
        epoch_acc = correct / total
        if verbose:
            print(f"Epoch {epoch + 1}: train loss {epoch_loss}, accuracy {epoch_acc}" )

### Test function

In [None]:
def test(net, test_loader):
    """Evaluate the network on the entire test set.

    Args:
        net: Model to evaluate. Batches are moved to ``DEVICE``, so the model
            is expected to live there already.
        test_loader: Iterable of dict batches holding ``'img'`` and
            ``'label'`` tensors.

    Returns:
        A ``(loss, accuracy)`` tuple: the mean per-sample cross-entropy loss
        and the fraction of correctly classified samples.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch['img'].to(DEVICE), batch['label'].to(DEVICE)
            outputs = net(images)
            batch_size = labels.size(0)
            # The criterion returns the batch mean; weight it by the batch
            # size so that dividing by the sample count below gives the mean
            # per-sample loss rather than a value scaled down by the batch size.
            loss += criterion(outputs, labels).item() * batch_size
            _, predicted = torch.max(outputs, 1)
            total += batch_size
            correct += (predicted == labels).sum().item()
    loss /= total
    accuracy = correct / total
    return loss, accuracy

### Train the Model

In [None]:
# Centralized baseline: train and validate on partition 0 only, then report
# the result on the test split.
train_loader, val_loader, test_loader = load_datasets(partition_id = 0)
net = Net().to(DEVICE)

# Five calls of one epoch each, so validation metrics are printed after every
# epoch. Note that train() builds a new Adam optimizer on each call.
for epoch in range(5):
    train(net, train_loader, 1)
    loss, accuracy = test(net, val_loader)
    print(f"Epoch {epoch+1}: validation loss {loss}, accuracy {accuracy}")

# Final evaluation on the test split.
loss, accuracy = test(net, test_loader)
print(f"Final test set performance: \n\tloss {loss}\n\taccuracy {accuracy}")

## 3. Federated Learning with Flower
### Update model parameters
We need two helper functions: `set_parameters`, which updates the local model with the parameters received from the server, and `get_parameters`, which reads the updated parameters back out of the local model. The following two functions do just that for the PyTorch model above.



In [None]:
def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into ``net``'s state dict.

    The arrays are matched to the state-dict entries by position, in the
    order returned by ``net.state_dict().keys()``.

    Args:
        net: Module whose state is overwritten in place.
        parameters: One array per state-dict entry.

    Raises:
        ValueError: If the number of arrays differs from the number of
            state-dict entries (``zip`` would otherwise drop the surplus
            silently).
        RuntimeError: Raised by ``load_state_dict`` when an array's shape does
            not match its entry.
    """
    keys = list(net.state_dict().keys())
    if len(parameters) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} parameter arrays, got {len(parameters)}"
        )
    # torch.tensor keeps each array's dtype and shape, including 0-dim integer
    # entries such as BatchNorm's `num_batches_tracked`; the legacy
    # torch.Tensor(...) constructor always produces the default float dtype.
    state_dict = OrderedDict(
        (key, torch.tensor(array)) for key, array in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    """Return every tensor in ``net``'s state dict as a NumPy array.

    The arrays are listed in state-dict order, after being moved to the CPU.
    """
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays

### Define the Flower ClientApp

In [None]:
class FlowerClient(NumPyClient):
    """Flower client that trains and evaluates one model on one data partition."""

    def __init__(self, net, trainloader, valloader):
        """Store the model and the DataLoaders used for local training and validation."""
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model parameters as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load ``parameters``, train for one epoch, and return the updated parameters.

        Returns:
            ``(parameters, num_examples, metrics)``, where ``num_examples`` is
            the number of training samples and ``metrics`` is empty.
        """
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        # Report the number of samples, not len(DataLoader), which is the
        # number of batches.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and evaluate them on the local validation data.

        Returns:
            ``(loss, num_examples, metrics)``, where ``num_examples`` is the
            number of validation samples and ``metrics`` holds ``"accuracy"``.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        # Report the number of samples, not the number of batches.
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

In [None]:
def client_fn(context: Context) -> Client:
    """Build the Flower client for the node described by ``context``.

    The partition to load is read from ``context.node_config["partition-id"]``,
    so each node trains and evaluates on its own data partition.
    """
    # Fresh model on the target device.
    net = Net().to(DEVICE)

    # Data partition assigned to this node (CIFAR-10); the test loader is not
    # needed on the client.
    node_partition = context.node_config["partition-id"]
    trainloader, valloader, _ = load_datasets(partition_id=node_partition)

    # FlowerClient is a NumPyClient subclass; `.to_client()` converts it to a
    # `flwr.client.Client`.
    numpy_client = FlowerClient(net, trainloader, valloader)
    return numpy_client.to_client()


# Create the ClientApp; `client_fn` is the factory it is given to build clients.
client = ClientApp(client_fn=client_fn)

### Define the Flower ServerApp

In [None]:
# Create FedAvg strategy.
# The client minimums are derived from NUM_CLIENTS rather than hard-coded, so
# they stay in step with the number of supernodes the simulation starts.
strategy = FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=NUM_CLIENTS,  # Never sample fewer than all clients for training
    min_evaluate_clients=max(1, NUM_CLIENTS // 2),  # Never sample fewer than half the clients for evaluation
    min_available_clients=NUM_CLIENTS,  # Wait until all clients are available
)

In [None]:
def server_fn(context: Context) -> ServerAppComponents:
    """Assemble the components that define the ServerApp's behaviour.

    Uses the module-level ``strategy`` and a fixed five rounds of training.
    ``context`` is not read here; settings in ``context.run_config`` could be
    used to parameterize the strategy or the number of rounds.
    """
    return ServerAppComponents(
        strategy=strategy,
        config=ServerConfig(num_rounds=5),  # five rounds of training
    )


# Create the ServerApp; `server_fn` supplies its strategy and run configuration.
server = ServerApp(server_fn=server_fn)

### Run the training

In [None]:
# Resources each client needs: always one CPU, plus an entire GPU per client
# when running on CUDA (no GPU otherwise).
# Refer to the Flower framework documentation for more details about Flower
# simulations and how to set up the `backend_config`.
backend_config = {
    "client_resources": {
        "num_cpus": 1,
        "num_gpus": 1.0 if DEVICE.type == "cuda" else 0.0,
    }
}

In [None]:
# Run simulation: start NUM_CLIENTS supernodes that run `client` against
# `server`, with the per-client resources given in `backend_config`.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)

### Where’s the accuracy?

In [None]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client accuracies into one example-weighted average.

    Each entry of ``metrics`` is ``(num_examples, client_metrics)``; every
    client's ``"accuracy"`` is weighted by its ``num_examples``.
    """
    weighted_sum = 0
    total_examples = 0
    for num_examples, client_metrics in metrics:
        weighted_sum += num_examples * client_metrics["accuracy"]
        total_examples += num_examples

    # Custom aggregated metric (weighted average).
    return {"accuracy": weighted_sum / total_examples}

In [None]:
def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    Builds a FedAvg strategy that aggregates client evaluation metrics with
    ``weighted_average`` and configures five rounds of training. ``context``
    is not read here; settings in ``context.run_config`` could be used to
    parameterize the strategy or the number of rounds.
    """

    # Create FedAvg strategy. The client minimums are derived from NUM_CLIENTS
    # rather than hard-coded, so they stay in step with the number of
    # supernodes the simulation starts.
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=0.5,
        min_fit_clients=NUM_CLIENTS,
        min_evaluate_clients=max(1, NUM_CLIENTS // 2),
        min_available_clients=NUM_CLIENTS,
        evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    )

    # Configure the server for 5 rounds of training
    config = ServerConfig(num_rounds=5)

    return ServerAppComponents(strategy=strategy, config=config)


# Create a new server instance with the updated FedAvg strategy
# (the `server_fn` defined just above, which aggregates accuracy).
server = ServerApp(server_fn=server_fn)

# Run simulation again with the same client app, number of supernodes and
# backend configuration as before.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)