In [None]:
# ------ Copyright (C) 2024 University of Strathclyde and Author ------
# ---------------- Author: Robert Cowlishaw ------------------------
# ----------- e-mail: robert.cowlishaw.2017@uni.strath.ac.uk ----------------

# Crop Type Federated Learning Simulation - Number of Clients Study

This notebook is designed to test different federated learning parameters and how they affect the learning.

## Install python libraries

In [None]:
!pip install flwr flwr_datasets datasets torcheval -q
!pip install flwr[simulation]@git+https://github.com/0x365/flower-fl.git -q

## Download github files

The files:
- `UNET.py` - NN architecture
- `utils.py` - Python utility functions
- `params.yml` - Model hyper-parameter definition file

In [None]:
#@title {vertical-output: true}

from pathlib import Path
# Clone the repository only when no "smart_dao" directory exists yet, so that
# re-running this cell does not attempt a second clone into the same folder.
my_file = Path("smart_dao")
if not my_file.is_dir():
    !git clone https://github.com/strath-ace/smart-dao smart_dao

# Copy the network definition, the helper module and the default hyper-parameter
# file into the working directory; the defaults are saved under the name params.yml.
!cp smart_dao/agriculture-federated-learning/centralised-unet/UNET.py .
!cp smart_dao/agriculture-federated-learning/centralised-unet/utils.py .
!cp smart_dao/agriculture-federated-learning/setup/params_default.yml params.yml

## Import required libraries

Notable libraries are PyTorch and Flwr

In [None]:
import argparse
from collections import OrderedDict
from typing import Dict, Tuple, List
import time
import os
import yaml
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import flwr as fl
from flwr.common import Metrics
from flwr.common.typing import Scalar

from datasets import Dataset, load_dataset
from datasets.utils.logging import disable_progress_bar
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, LinearPartitioner
from flwr_datasets.utils import divide_dataset

from torch.utils.data import random_split

from utils import *
from UNET import UNet

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import Dataset
from torcheval.metrics.functional import multiclass_f1_score
from torch.utils.data import DataLoader
import sklearn.metrics

from tifffile import imread

import numpy as np


## Load hyperparameter file

In [None]:
# NOTE(review): the parser is created but no arguments are added or parsed in this cell.
parser = argparse.ArgumentParser(description="Flower Simulation with PyTorch")

file_location = "."
# exist_ok=True replaces the separate exists()/mkdir() pair, so the cell can be
# re-run safely and there is no gap between the check and the creation.
os.makedirs(file_location+"/model_save", exist_ok=True)
# Hyper-parameters are read with the safe YAML loader (plain data only).
with open(file_location+"/params.yml", "r") as f:
    params_config = yaml.safe_load(f)

## Define Loss Functions

In [None]:

def dice_loss(pred, target, smooth=1.):
    """Soft Dice loss between two tensors of identical shape.

    Dimensions 2 and 3 of both inputs are summed out, a Dice coefficient is
    computed for every remaining entry, and the mean of ``1 - dice`` is returned.

    Args:
        pred: Predicted scores; must have at least four dimensions.
        target: Reference tensor with the same shape as ``pred``.
        smooth: Constant added to numerator and denominator so the ratio stays
            defined when both inputs sum to zero.

    Returns:
        A scalar tensor holding the mean Dice loss.
    """
    def _sum_dims_2_and_3(tensor):
        # After the first reduction the former dim 3 becomes dim 2.
        return tensor.sum(dim=2).sum(dim=2)

    p = pred.contiguous()
    t = target.contiguous()

    overlap = _sum_dims_2_and_3(p * t)
    denominator = _sum_dims_2_and_3(p) + _sum_dims_2_and_3(t) + smooth
    dice = (2. * overlap + smooth) / denominator

    return (1 - dice).mean()

# Cross-entropy term that the training and testing functions combine with dice_loss.
criterion = torch.nn.CrossEntropyLoss()

## Define Training Function

In [None]:
def train(model,dataloaders, optimizer, num_classes, use_scheduler=False, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``dataloaders``.

    The loss for each batch is an equal-weight mix of the module-level
    cross-entropy ``criterion`` and ``dice_loss`` computed on the softmax of the
    model output against one-hot encoded labels.

    Args:
        model: Network to optimise; it is moved to the selected device.
        dataloaders: Iterable yielding ``(inputs, labels)`` batches. ``labels``
            must hold integer class indices with three dimensions, since they
            are one-hot encoded and permuted to channel-first layout.
        optimizer: Optimiser already bound to ``model``'s parameters.
        num_classes: Number of classes used for the one-hot encoding.
        use_scheduler: Accepted for interface compatibility; currently unused.
        epochs: Number of passes over ``dataloaders``.

    Returns:
        ``(model, stored_data)`` where ``stored_data["epoch"]`` maps the 1-based
        epoch number (as a string) to ``{"loss", "epoch_loss", "time"}``:
        the sample-weighted loss sum, its per-sample mean (NaN if no samples
        were seen) and the epoch duration in seconds.
    """
    import time  # local import so the name cannot be shadowed by `from utils import *`

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Inputs are moved to `device` below, so the model has to live there as well.
    model.to(device)

    # Weight of the cross-entropy term; the Dice term receives the remainder.
    multi_scale = 0.5

    stored_data = {"epoch": {}}
    for epoch in range(epochs):
        epoch_start = time.perf_counter()

        model.train()

        sum_loss = 0.0
        totaler = 0

        for inputs, labels in dataloaders:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            # Cross-entropy consumes the raw outputs; Dice uses the softmax probabilities.
            loss_ce = criterion(outputs, labels)
            probabilities = F.softmax(outputs, dim=1)
            one_hot_labels = F.one_hot(labels, num_classes=num_classes).permute(0, 3, 1, 2).float()
            loss_d = dice_loss(probabilities, one_hot_labels)

            loss = loss_ce * multi_scale + loss_d * (1-multi_scale)

            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is per sample, not per batch.
            sum_loss += loss.item() * inputs.size(0)
            totaler += inputs.size(0)

        # An empty loader would otherwise raise ZeroDivisionError.
        epoch_loss = sum_loss / totaler if totaler else float("nan")
        updater = {str(epoch+1): {"loss": sum_loss, "epoch_loss": epoch_loss, "time": time.perf_counter() - epoch_start}}
        stored_data["epoch"].update(updater)
        print("----- Loss:", epoch_loss, "-----")

    return model, stored_data

## Define testing function

In [None]:
def test(model, dataloaders, num_classes, num_epochs=25, threshold=0):
    """Evaluate ``model`` on ``dataloaders`` and collect per-class statistics.

    The loss is the same equal-weight mix of the module-level cross-entropy
    ``criterion`` and ``dice_loss`` that ``train`` uses. Evaluation runs under
    ``torch.no_grad()`` so no autograd graph is kept for the forward passes.

    Args:
        model: Network to evaluate; it is moved to the selected device and put
            into eval mode.
        dataloaders: Iterable yielding ``(inputs, labels)`` batches with integer
            class-index labels.
        num_classes: Number of classes (size of the confusion matrix).
        num_epochs: Accepted for interface compatibility; currently unused.
        threshold: When non-zero, positions whose softmax probabilities are all
            below this value are assigned class 0 instead of the argmax class.

    Returns:
        ``(loss, big_confusion, number_in_class, number_in_pred)``: the mean
        per-sample loss as a float (NaN if no samples were seen), the summed
        confusion matrix (rows are true labels, columns are predictions), and
        the per-class counts of label and predicted values.
    """
    def argmax_with_threshold(tensor, threshold, default_value):
        # Class index per position, replaced by `default_value` wherever every
        # class probability is below `threshold`.
        below_threshold = (tensor < threshold).all(dim=1)
        argmax_indices = torch.argmax(tensor, dim=1)
        argmax_indices[below_threshold] = default_value
        return argmax_indices

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    model.eval()

    # Weight of the cross-entropy term; the Dice term receives the remainder.
    multi_scale = 0.5

    sum_loss = 0.0
    totaler = 0

    big_confusion = np.zeros((num_classes, num_classes))
    number_in_class = np.zeros(num_classes)
    number_in_pred = np.zeros(num_classes)
    class_ids = np.arange(num_classes)

    with torch.no_grad():
        for inputs, labels in dataloaders:
            inputs = inputs.to(device)
            labels = labels.to(device)

            pred = model(inputs)

            loss_ce = criterion(pred, labels)
            pred = F.softmax(pred, dim=1)
            loss_d = dice_loss(pred, F.one_hot(labels, num_classes=num_classes).permute(0, 3, 1, 2).float())
            loss = loss_ce * multi_scale + loss_d * (1-multi_scale)

            # Weight by batch size so the final mean is per sample, not per batch.
            sum_loss += loss.item() * inputs.size(0)
            totaler += inputs.size(0)

            if threshold == 0:
                pred = torch.argmax(pred, dim=1)
            else:
                pred = argmax_with_threshold(pred, threshold, 0)

            flat_pred = pred.cpu().numpy().flatten().astype(int)
            flat_label = labels.cpu().numpy().flatten().astype(int)

            big_confusion += sklearn.metrics.confusion_matrix(flat_label, flat_pred, labels=class_ids)

            # minlength pads the counts to one entry per class even when the
            # highest classes are absent from this batch.
            number_in_class += np.bincount(flat_label, minlength=num_classes)
            number_in_pred += np.bincount(flat_pred, minlength=num_classes)

    # An empty loader would otherwise raise ZeroDivisionError.
    loss = float(sum_loss) / float(totaler) if totaler else float("nan")

    return loss, big_confusion, number_in_class, number_in_pred

## Download dataset and define dataset loading type

In [None]:
#@title {vertical-output: true}

dataset = load_dataset("0x365/eo-crop-type-belgium", split="train")

# Return samples as torch tensors when indexed.
ds = dataset.with_format("torch")

# The split is seeded so that every kernel session uses the same held-out test
# set. Results are cached per configuration and skipped on re-run, so an
# unseeded split would score resumed configurations on different test data.
ds_split = ds.train_test_split(test_size=0.2, seed=42)

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Torch dataset turning per-band samples into ``(inputs, labels)`` pairs.

    Each sample is indexed by band name and by ``"label"``; element ``[0]`` of
    every entry is used. The nine bands are stacked along a new leading
    dimension, cast to float and divided by 255. Labels are cast to long and
    any value above ``num_classes - 1`` is replaced by 0.

    Args:
        data: Indexable collection of samples (also iterable when
            ``big_memory`` is true).
        num_classes: Number of valid classes; larger label values map to 0.
        big_memory: When true, every sample is converted once in ``__init__``
            and kept on ``self.device``; otherwise conversion happens on each
            ``__getitem__`` call.
    """

    def __init__(self, data, num_classes, big_memory=False):
        self.bands = ["B02","B03","B04","B05","B06","B07","B08","B11","B12"]
        self.num_classes = num_classes
        self.big_memory = big_memory
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if big_memory:
            self.data = [self._to_tensors(sample) for sample in data]
        else:
            self.data = data

    def _to_tensors(self, sample):
        """Convert one raw sample into an ``(inputs, labels)`` tuple on ``self.device``."""
        inputs = torch.stack([sample[k][0].to(self.device) for k in self.bands], dim=0).float() / 255
        labels = sample["label"][0].to(self.device).long()
        # Out-of-place replacement: `.to()`/`.long()` can return the very same
        # tensor object, and an in-place write would then alter the source sample.
        labels = labels.masked_fill(labels > self.num_classes-1, 0)
        return (inputs, labels)

    def __getitem__(self, idx):
        if self.big_memory:
            return self.data[idx]
        return self._to_tensors(self.data[idx])

    def __len__(self):
        return len(self.data)

## Function to reformat json files

In [None]:
def reformat_json(obj):
    """Recursively convert NumPy containers and scalars into plain Python values.

    Args:
        obj: Arbitrary object, possibly a nested mix of dicts, lists, tuples,
            NumPy arrays and NumPy scalars.

    Returns:
        The same structure with arrays turned into (nested) lists, NumPy
        scalars turned into built-in Python scalars, and tuples turned into
        lists. Dict keys and any other object are returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()  # Convert numpy array to list
    if isinstance(obj, np.generic):
        # NumPy scalars such as np.int64 are rejected by the json module.
        return obj.item()
    if isinstance(obj, (list, tuple)):
        return [reformat_json(item) for item in obj]  # Recursively apply to list items
    if isinstance(obj, dict):
        return {key: reformat_json(value) for key, value in obj.items()}  # Handle dicts too
    return obj

## Build Flwr Simulation

In [None]:
# Flower client, adapted from Pytorch quickstart example
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a UNet on its own data.

    Args:
        trainset: Raw training samples, wrapped in ``CustomDataset``.
        valset: Raw validation samples, wrapped in ``CustomDataset``.
        num_classes: Number of classes passed to the model and the datasets.
    """

    def __init__(self, trainset, valset, num_classes):
        self.trainset = CustomDataset(trainset, num_classes)
        self.valset = CustomDataset(valset, num_classes)

        self.num_classes = num_classes
        self.model = UNet(num_classes)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)  # send model to device

    def get_parameters(self, config):
        """Return every ``state_dict`` entry of the local model as a NumPy array."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        """Load ``parameters``, train locally and return the updated weights.

        NOTE(review): ``config`` is not read here, so any "epochs"/"batch_size"
        values supplied by the server do not take effect; the batch size of 20
        and the default epoch count of ``train`` are used instead. Confirm
        that this is intended.

        Returns:
            ``(parameters, number of training examples, empty metrics dict)``.
        """
        set_params(self.model, parameters)

        trainloader = DataLoader(self.trainset, batch_size=20, shuffle=True)

        # A fresh optimiser is built for every call.
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

        self.model, _ = train(self.model, trainloader, optimizer, self.num_classes)

        return self.get_parameters({}), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and evaluate them on the local validation set.

        Returns:
            ``(loss, number of validation examples, metrics)`` where metrics
            holds ``"accuracy"`` (confusion-matrix diagonal divided by its
            total, 0.0 when the matrix is empty) and ``"confusion"`` (the
            confusion matrix flattened to a list).
        """
        set_params(self.model, parameters)

        # Construct dataloader
        valloader = DataLoader(self.valset, batch_size=10)

        # Evaluate
        loss, big_confusion, _, _ = test(self.model, valloader, self.num_classes)

        confusion = np.array(big_confusion)
        total = confusion.sum()
        # The metric aggregation function used with this client reads
        # m["accuracy"] from every client, so the key has to be reported.
        accuracy = float(np.trace(confusion) / total) if total > 0 else 0.0

        # Return statistics
        return float(loss), len(valloader.dataset), {"accuracy": accuracy, "confusion": confusion.flatten().tolist()}


def get_client_fn(dataset, num_classes):
    """Build the factory used to create a client for a given client id.

    Args:
        dataset: Object exposing ``load_partition(index)``; the returned
            partition must provide ``train_test_split``.
        num_classes: Number of classes forwarded to every ``FlowerClient``.

    Returns:
        A ``client_fn(cid)`` callable.
    """

    def client_fn(cid: str) -> fl.client.Client:
        """Create the client whose data is partition ``int(cid)``."""
        partition = dataset.load_partition(int(cid))

        # 90/10 train/validation split with a fixed seed.
        splits = partition.train_test_split(test_size=0.1, seed=42)

        flower_client = FlowerClient(splits["train"], splits["test"], num_classes)
        return flower_client.to_client()

    return client_fn


def fit_config(server_round: int) -> Dict[str, Scalar]:
    """Build the fit configuration from ``params_config``.

    ``server_round`` is not used, so the same values are returned every time:
    ``"epochs"`` from ``CL_epochs`` and ``"batch_size"`` from ``CL_batch_size``.
    """
    return {
        "epochs": params_config["CL_epochs"],
        "batch_size": params_config["CL_batch_size"],
    }


def set_params(model: torch.nn.Module, params: List[np.ndarray]):
    """Set model weights from a list of NumPy ndarrays.

    Args:
        model: Module whose ``state_dict`` is overwritten.
        params: One array per ``state_dict`` entry, in ``state_dict`` order.

    Raises:
        ValueError: If the number of arrays differs from the number of
            ``state_dict`` entries.
        RuntimeError: Raised by ``load_state_dict`` (strict mode), e.g. when a
            shape does not match.
    """
    keys = list(model.state_dict().keys())
    # zip() would silently drop surplus arrays, hiding a model/weights mismatch.
    if len(keys) != len(params):
        raise ValueError(
            "Expected {} parameter arrays but received {}".format(len(keys), len(params))
        )
    # torch.tensor keeps each array's dtype, whereas the legacy torch.Tensor
    # constructor always produces float32.
    state_dict = OrderedDict((k, torch.tensor(np.asarray(v))) for k, v in zip(keys, params))
    model.load_state_dict(state_dict, strict=True)


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregation function for (federated) evaluation metrics, i.e. those returned by
    the client's evaluate() method.

    Computes the example-weighted mean of the ``"accuracy"`` metric. Clients
    that do not report ``"accuracy"`` are left out of the average instead of
    raising ``KeyError``.

    Args:
        metrics: ``(num_examples, client_metrics)`` pairs.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when no client
        reported ``"accuracy"`` or the reporting clients hold zero examples.
    """
    reported = [(num_examples, m["accuracy"]) for num_examples, m in metrics if "accuracy" in m]
    total_examples = sum(num_examples for num_examples, _ in reported)

    # Nothing to average: avoid dividing by zero.
    if total_examples == 0:
        return {}

    # Multiply accuracy of each client by number of examples used
    weighted_sum = sum(num_examples * accuracy for num_examples, accuracy in reported)

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": weighted_sum / total_examples}


def get_evaluate_fn(
    centralized_testset: Dataset,
    num_classes,
):
    """Build the callback used for centralised evaluation.

    Args:
        centralized_testset: Samples to evaluate on; wrapped in
            ``CustomDataset`` on every call of the returned function.
        num_classes: Number of classes for the model and the dataset wrapper.

    Returns:
        An ``evaluate(server_round, parameters, config)`` callable.
    """

    def evaluate(
        server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
    ):
        """Evaluate ``parameters`` on the centralised test set.

        A fresh UNet is built, loaded with ``parameters`` and scored with
        ``test``. The confusion matrix and the per-class counts are printed.

        Returns:
            ``(loss, {"confusion": flattened confusion matrix as a list})``.
        """
        # Determine device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        global_model = UNet(num_classes)
        set_params(global_model, parameters)
        global_model.to(device)

        eval_set = CustomDataset(centralized_testset, num_classes)

        # Disable tqdm for dataset preprocessing
        disable_progress_bar()

        eval_loader = DataLoader(eval_set, batch_size=10)
        loss, big_confusion, number_in_class, number_in_pred = test(global_model, eval_loader, num_classes)

        # Print each statistic on its own line, then all three as one dict.
        named_stats = (
            ("big_confusion", big_confusion),
            ("number_in_class", number_in_class),
            ("number_in_pred", number_in_pred),
        )
        for _, value in named_stats:
            print(reformat_json(value))

        extra_info = {name: reformat_json(value) for name, value in named_stats}
        print((extra_info))

        flat_confusion = np.array(big_confusion).flatten().tolist()
        return loss, {"confusion": flat_confusion}

    return evaluate

## Run Simulation

In [None]:
#@title {vertical-output: true}

# Resources to be assigned to each virtual client
client_resources = {
    "num_cpus": 10,#params_config["NUM_CPU"],
    "num_gpus": 0.1,#params_config["NUM_GPU"],
}

# Result files are written to <file_location>/output.
save_location = os.path.join(file_location, "output")
os.makedirs(save_location, exist_ok=True)

centralized_testset = ds_split["test"]


# Parameter grid: one simulation is run per combination.
PARAM_partition_style = ["iid"]
PARAM_num_classes = [11]
PARAM_num_clients = [256, 128, 64, 32, 16, 8, 4, 2]

###################
# Iteration index; it is part of the result file name and stored in the output.
starter = 1
##################

aggregation_rounds = 20


for partition_style in PARAM_partition_style:
    for num_classes in PARAM_num_classes:
        for num_clients in PARAM_num_clients:

            file_name = (
                f"{save_location}/result"
                f"_iteration_{starter}"
                f"_partition_style_{partition_style}"
                f"_num_classes_{num_classes}"
                f"_num_clients_{num_clients}"
                ".json"
            )

            # An existing result file marks this configuration as done.
            if os.path.exists(file_name):
                print("File already computed for "+file_name)
                continue

            print("##############")
            print("Run test for", file_name)
            print("##############")

            if partition_style == "iid":
                partitioner = IidPartitioner(num_clients)
            elif partition_style == "niid":
                partitioner = LinearPartitioner(num_partitions=num_clients)
            else:
                # Without this branch `partitioner` would be undefined, or left
                # over from the previous configuration.
                raise ValueError("Unknown partition style: "+str(partition_style))

            partitioner.dataset = ds_split["train"]

            # The 256-client run uses fixed sampling fractions; all other runs
            # take them from the hyper-parameter file.
            if num_clients == 256:
                fraction_fit = 0.5
                fraction_evaluate = 0.5
            else:
                fraction_fit = params_config["AGR_fraction_fit"]
                fraction_evaluate = params_config["AGR_fraction_evaluate"]

            # Configure the strategy
            strategy = fl.server.strategy.FedAvg(
                fraction_fit=fraction_fit,
                fraction_evaluate=fraction_evaluate,
                min_available_clients=params_config["AGR_min_available_clients"],
                on_fit_config_fn=fit_config,
                evaluate_metrics_aggregation_fn=weighted_average,  # Aggregate federated metrics
                evaluate_fn=get_evaluate_fn(centralized_testset, num_classes),  # Global evaluation function
            )

            # ClientApp for Flower-Next
            # NOTE(review): `client` and `server` are not passed to start_simulation below.
            client = fl.client.ClientApp(
                client_fn=get_client_fn(partitioner, num_classes),
            )

            # ServerApp for Flower-Next
            server = fl.server.ServerApp(
                config=fl.server.ServerConfig(num_rounds=aggregation_rounds),
                strategy=strategy,
            )

            # Start simulation
            stuff = fl.simulation.start_simulation(
                client_fn=get_client_fn(partitioner, num_classes),
                num_clients=num_clients,
                client_resources=client_resources,
                config=fl.server.ServerConfig(num_rounds=aggregation_rounds),
                strategy=strategy,
            )

            # Store the configuration next to the simulation result.
            json_out = {
                "iteration": starter,
                "partition_style": partition_style,
                "num_classes": num_classes,
                "num_clients": num_clients,
                "result": stuff.repr_json()
            }
            save_json(file_name, json_out)
