# Workflow Interface 104: FedProx and FedAvg with a Synthetic Dataset
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_104_MNIST_with_fedcurv.ipynb)

In this OpenFL workflow interface tutorial, we'll learn how to implement the FedProx and FedAvg algorithms using a synthetic dataset.

# Getting Started

First, we install the necessary dependencies for the workflow interface.

In [None]:
# !pip install git+https://github.com/intel/openfl.git
# !pip install -r requirements_workflow_interface.txt

# Uncomment this if running in Google Colab
#pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/requirements_workflow_interface.txt
#import os
#os.environ["USERNAME"] = "colab"

In [None]:
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torch.nn import functional as F
import torch as pt
import numpy as np

from openfl.interface.aggregation_functions.weighted_average import weighted_average as wa
# from openfl.federated.task.runner_pt import _set_optimizer_state, _get_optimizer_state
from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator
from openfl.utilities.optimizers.torch import FedProxOptimizer

from collections import OrderedDict
# from copy import deepcopy
import math
import random
import warnings
warnings.filterwarnings("ignore")

In [None]:
n_epochs = 100  # number of federated rounds (the flow is run once per round)
batch_size = 10  # mini-batch size for the collaborators' train/test DataLoaders
learning_rate = 0.01  # learning rate passed to FedProxOptimizer
log_interval = 1000  # print training progress every `log_interval` batches
weight_decay = 0.001  # weight decay passed to FedProxOptimizer
E = 20  # local training epochs each collaborator runs per round
NUM_COLLABORATORS = 30  # number of simulated collaborators
RANDOM_SEED = 10  # seed for the random number generators and model initialisation

Set the following `mu` parameter to `0.0` to run FedAvg, and `1.0` to run FedProx

In [None]:
mu = 1.0  # `mu` argument of FedProxOptimizer: 1.0 runs FedProx, 0.0 runs FedAvg

Set seed in order to reproduce the results

In [None]:
# Seed every random number generator in use so that runs are reproducible
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic PyTorch execution."""
    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU and CUDA generators
    pt.manual_seed(seed)
    pt.cuda.manual_seed_all(seed)

    # Deterministic algorithms only; cuDNN autotuning and cuDNN itself are switched off
    pt.use_deterministic_algorithms(True)
    cudnn = pt.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False

set_seed(RANDOM_SEED)

Generate Synthetic Dataset

In [None]:
def one_hot(labels, classes):
    """Return one-hot rows for `labels` by indexing a `classes` x `classes` identity matrix."""
    identity = np.eye(classes)
    return identity[labels]


def softmax(x):
    """Return the softmax of `x`, normalised over all of its elements.

    The maximum is subtracted before exponentiating so that large inputs do
    not overflow to ``inf`` (which would turn the result into ``nan``); the
    shift cancels out in the ratio and does not change the result.

    Args:
        x: array-like of scores

    Returns:
        numpy array with the same shape as `x` whose elements sum to 1
        (an empty array for empty input)
    """
    x = np.asarray(x)
    if x.size == 0:
        return np.exp(x)
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)


def generate_synthetic(alpha, beta, iid, num_collaborators, num_classes):
    """Generate a synthetic classification dataset, one shard per collaborator.

    Features are 60-dimensional Gaussian samples with a diagonal covariance
    whose j-th variance (1-based) is ``j ** -1.2``. The label of a sample is
    the arg-max class of the linear model ``x @ W + b``.

    Args:
        alpha: standard deviation used to draw each collaborator's model mean
            (shared by its weights ``W`` and bias ``b``)
        beta: standard deviation used to draw each collaborator's feature
            offset ``B[i]``
        iid: when ``1`` all collaborators label their data with one global
            model and use ``B[i]`` as a constant feature mean; for any other
            value each collaborator has its own model and a random feature
            mean centred on ``B[i]``
        num_collaborators: number of shards to generate
        num_classes: number of label classes

    Returns:
        X_split: list (one entry per collaborator) of lists of 60-float samples
        y_split: list (one entry per collaborator) of lists of class indices
            stored as floats
    """
    dimension = 60

    # Number of samples per collaborator: log-normal, with a floor of 50
    samples_per_user = np.random.lognormal(4, 2, num_collaborators).astype(int) + 50

    #### define some priors ####
    mean_W = np.random.normal(0, alpha, num_collaborators)
    mean_b = mean_W  # weights and bias share the same per-collaborator mean
    B = np.random.normal(0, beta, num_collaborators)

    # Diagonal covariance with power-law decaying variances
    cov_x = np.diag([np.power(j + 1, -1.2) for j in range(dimension)])

    mean_x = np.zeros((num_collaborators, dimension))
    for i in range(num_collaborators):
        if iid == 1:
            mean_x[i] = np.ones(dimension) * B[i]  # constant mean equal to B[i]
        else:
            mean_x[i] = np.random.normal(B[i], 1, dimension)

    if iid == 1:
        W_global = np.random.normal(0, 1, (dimension, num_classes))
        b_global = np.random.normal(0, 1, num_classes)

    X_split = []
    y_split = []
    for i in range(num_collaborators):
        # Always drawn, even when replaced by the global model below, so the
        # random stream consumed per collaborator does not depend on `iid`
        W = np.random.normal(mean_W[i], 1, (dimension, num_classes))
        b = np.random.normal(mean_b[i], 1, num_classes)

        if iid == 1:
            W, b = W_global, b_global

        xx = np.random.multivariate_normal(
            mean_x[i], cov_x, samples_per_user[i])

        # Softmax is monotonic, so the arg-max of the logits is the predicted
        # class; labels are kept as floats
        yy = np.argmax(xx @ W + b, axis=1).astype(float)

        X_split.append(xx.tolist())
        y_split.append(yy.tolist())

    return X_split, y_split


class SyntheticFederatedDataset:
    """Synthetic dataset split into per-collaborator train / validation shards.

    Attributes:
        batch_size: batch size used for every DataLoader built by `split`
        X_train_all, X_valid_all: per-collaborator float32 feature arrays
        y_train_all, y_valid_all: per-collaborator one-hot label arrays

    The four shard containers are plain lists indexed by collaborator
    position: collaborators hold different numbers of samples, so the shards
    cannot be stacked into one rectangular array.
    """

    def __init__(self, num_collaborators, batch_size=1, num_classes=10, **kwargs):
        """Generate the data and split each shard 90% / 10% into train / validation.

        Args:
            num_collaborators: number of shards to generate
            batch_size: batch size for the DataLoaders created by `split`
            num_classes: number of label classes
            **kwargs: accepted for call-site compatibility and ignored
        """
        self.batch_size = batch_size
        X, y = generate_synthetic(0.0, 0.0, 0, num_collaborators, num_classes)

        self.X_train_all, self.X_valid_all = [], []
        self.y_train_all, self.y_valid_all = [], []
        for features, labels in zip(X, y):
            features = np.asarray(features, dtype=np.float32)
            targets = np.asarray(
                one_hot(np.asarray(labels, dtype=int), num_classes))
            # First 90% of each shard is used for training, the rest for validation
            cut = int(0.9 * len(features))
            self.X_train_all.append(features[:cut])
            self.X_valid_all.append(features[cut:])
            self.y_train_all.append(targets[:cut])
            self.y_valid_all.append(targets[cut:])

    def _make_loader(self, features, targets):
        """Wrap one shard in a shuffling DataLoader that uses this dataset's batch size."""
        return DataLoader(
            TensorDataset(pt.from_numpy(features), pt.from_numpy(targets)),
            batch_size=self.batch_size, shuffle=True
        )

    def split(self, collaborators):
        """Attach train / test DataLoaders to each collaborator's private attributes.

        Args:
            collaborators: sequence of objects; the i-th one receives the i-th
                shard through its `private_attributes` dict
        """
        for i, collab in enumerate(collaborators):
            collab.private_attributes = {
                "train_loader": self._make_loader(
                    self.X_train_all[i], self.y_train_all[i]),
                "test_loader": self._make_loader(
                    self.X_valid_all[i], self.y_valid_all[i]),
            }

Model Class

In [None]:
class Net(nn.Module):
    """
    Model trained on the synthetic dataset

    Maps 60 input features to 10 outputs through a 100-unit linear layer;
    no activation is applied between the two linear layers.

    Args:
        None

    Returns:
        model: class Net object
    """
    def __init__(self):
        # Re-seed PyTorch before the layers are created so that every Net
        # starts from the same initial weights
        seeded_generator = pt.manual_seed(RANDOM_SEED)
        pt.set_rng_state(seeded_generator.get_state())
        super().__init__()
        self.linear1 = nn.Linear(60, 100)
        self.linear2 = nn.Linear(100, 10)

    def forward(self, x):
        hidden = self.linear1(x)
        return self.linear2(hidden)

Loss Function

In [None]:
def cross_entropy(output, target):
    """
    Cross-entropy loss for one-hot encoded targets

    Args:
        output: model output (one row of class scores per sample),
        target: target labels, one row per sample; the index of each row's
                maximum is used as the class index

    Returns:
        crossentropy_loss: scalar tensor
    """
    _, class_indices = pt.max(target, 1)
    return F.cross_entropy(output, class_indices)

Model Test method

In [None]:
def inference(network, dataloader):
    """
    Model test method

    Puts the model in evaluation mode, runs it over every batch of
    `dataloader` without tracking gradients and prints a summary line.

    Args:
        network: class Net object (model)
        dataloader: DataLoader yielding (data, one-hot target) batches

    Returns:
        avg_test_accuracy: fraction of correctly classified samples
        avg_test_loss: cross-entropy loss averaged over the samples
    """
    network.eval()
    loss_sum = 0.0
    correct = 0
    with pt.no_grad():
        for data, target in dataloader:
            output = network(data)
            # cross_entropy returns the mean over the batch; scale it by the
            # batch size so that dividing by the dataset size below gives a
            # true per-sample average (the last batch may be smaller)
            loss_sum += cross_entropy(output, target).item() * len(data)
            tar = target.argmax(dim=1, keepdim=True)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(tar).sum().item()
    dataloader_size = len(dataloader.dataset)
    test_loss = loss_sum / dataloader_size
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, dataloader_size,
        100. * correct / dataloader_size))
    accuracy = correct / dataloader_size
    return accuracy, test_loss

The `weighted_average` function computes the weighted average of a list of model state dicts, optimizer state dicts, losses, or accuracies.

In [None]:
def weighted_average(tensors, weights):
    """
    Weighted average of models / optimizers / losses / accuracies

    When the first element is a dict or OrderedDict the inputs are treated as
    state dictionaries and are averaged entry by entry; anything else is
    passed straight to the averaging function. Optimizer state dictionaries
    are recognised by the "__opt_state_needed" key, which is removed from
    every dictionary in `tensors` (the inputs are modified in place).

    Args:
        tensors: list of model state_dicts, optimizer state_dicts, losses or accuracies
        weights: one weight per element of `tensors`

    Returns:
        dict / float: weighted average model / optimizer / loss / accuracy
    """
    # Plain values (loss / accuracy lists) are averaged directly
    if type(tensors[0]) not in (dict, OrderedDict):
        return wa(tensors, weights)

    # The marker key tells optimizer state dictionaries apart from model ones
    is_optimizer_state = "__opt_state_needed" in tensors[0]
    if is_optimizer_state:
        for state in tensors:
            state.pop("__opt_state_needed")

    # Keys of the first dictionary are reused to rebuild the averaged dictionary
    state_keys = tensors[0].keys()

    # One row of values per state dictionary; tensors are detached first
    per_state_values = []
    for state in tensors:
        row = []
        for value in state.values():
            row.append(value.detach() if type(value) is pt.Tensor else value)
        per_state_values.append(row)

    # averaged[i] is the weighted average of the i-th entry across all rows
    averaged = wa(per_state_values, weights)

    new_state = {}
    for position, key in enumerate(state_keys):
        if is_optimizer_state:
            new_state[key] = averaged[position]
        else:
            new_state[key] = pt.from_numpy(averaged[position].numpy())
    return new_state

`FLSpec` – Defines the flow specification. User defined flows are subclasses of this.

Now we come to the flow definition. The OpenFL Workflow Interface adopts the conventions set by Metaflow: every workflow begins with the `start` task and concludes with the `end` task. The aggregator begins with an optionally passed-in model and optimizer. The aggregator begins the flow with the `start` task, where the list of collaborators is extracted from the runtime (`self.collaborators = self.runtime.collaborators`) and is then used as the list of participants to run the task listed in `self.next`, `aggregated_model_validation`, where the aggregated model is tested and we compute the `train loss`, `train accuracy`, `test loss` and `test accuracy` for each collaborator. Where each task runs is determined by the placement decorator that precedes its definition (`@aggregator` or `@collaborator`). Once the metrics are computed, the flow moves on to the `gather_results` step, which executes on the aggregator, collects all the computed results and takes the `weighted_average` of those metrics. The next step, `inner_optimization`, again runs on the collaborators and trains all collaborator models. The flow then moves to the `join` step, which executes on the aggregator and takes the weighted average of the models. Since we are using `FedProxOptimizer`, we do not need to take the weighted average of the optimizer state; for a stateful optimizer (e.g. FedProxAdam) it would be required. The final step is `end`, which increments the round number and ends the flow.

![flow graph](./graph.jpg)

In [None]:
class FederatedFlow(FLSpec):
    """
    One round of federated training.

    Step order: start -> aggregated_model_validation -> gather_results ->
    inner_optimization -> join -> end. Every collaborator validates the
    aggregated model and then trains locally with FedProxOptimizer; the
    aggregator averages the models of a randomly selected third of the
    collaborators.
    """

    def __init__(self, model=None, optimizer=None, **kwargs):
        """
        Args:
            model: model to start from; when None a new Net and a matching
                   FedProxOptimizer are created
            optimizer: optimizer stored together with a passed-in model
            **kwargs: forwarded to FLSpec
        """
        super(FederatedFlow, self).__init__(**kwargs)
        self.round_number = 0
        if model is not None:
            self.model = model
            self.optimizer = optimizer
        else:
            self.model = Net()
            self.optimizer = FedProxOptimizer(
                self.model.parameters(), lr=learning_rate, mu=mu, weight_decay=weight_decay)

    @aggregator
    def start(self):
        """
        Start of the flow
        """
        self.collaborators = self.runtime.collaborators

        self.next(self.aggregated_model_validation, foreach='collaborators')

    @collaborator
    def aggregated_model_validation(self):
        """
        Validate aggregated model.

        Computes loss and accuracy of the aggregated model on this
        collaborator's test and train data and records both dataset sizes.
        """
        print(
            f'Performing aggregated model validation for collaborator {self.input}')
        self.agg_validation_score, self.agg_validation_loss = inference(
            self.model, self.test_loader)
        print(f'{self.input} value of {self.agg_validation_score}')

        # Compute Train Loss and Train Acc
        self.training_accuracy, self.training_loss = inference(
            self.model, self.train_loader)

        # Dataset sizes are used by the aggregator to weight the averages
        self.train_ds = len(self.train_loader.dataset)
        self.test_ds = len(self.test_loader.dataset)

        self.next(self.gather_results)

    @aggregator
    def gather_results(self, inputs):
        """
        Gather results of all collaborators

        Averages the collaborators' metrics weighted by dataset size and
        draws the collaborators whose models will be aggregated in `join`.
        """
        # Calculate train_weights and test_weights
        train_size = sum([input_.train_ds for input_ in inputs])
        self.train_weights = [input_.train_ds / train_size for input_ in inputs]
        test_size = sum([input_.test_ds for input_ in inputs])
        self.test_weights = [input_.test_ds / test_size for input_ in inputs]

        # Weighted average of training loss
        self.training_loss = weighted_average(
            [input_.training_loss for input_ in inputs], self.train_weights)
        print(f'Average training loss = {self.training_loss}')

        # Weighted average of training accuracy
        self.training_accuracy = weighted_average(
            [input_.training_accuracy for input_ in inputs], self.train_weights)

        # Weighted average of aggregated model loss
        self.agg_validation_loss = weighted_average(
            [input_.agg_validation_loss for input_ in inputs], self.test_weights)

        # Weighted average of aggregated model accuracy
        self.aggregated_model_accuracy = weighted_average(
            [input_.agg_validation_score for input_ in inputs], self.test_weights)
        print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')

        # Randomly select 1/3rd of the collaborators (without replacement).
        # All collaborators still train in the next step; only the selected
        # indices are used when the models are averaged in `join`.
        self.selected_collaborator_indices = np.random.choice(range(len(self.collaborators)), \
            math.ceil(len(self.collaborators) / 3), replace=False)
        self.selected_collaborators = [self.collaborators[idx] for idx in self.selected_collaborator_indices]

        self.next(self.inner_optimization, foreach="collaborators")

    @collaborator
    def inner_optimization(self):
        """
        Collaborators Training

        Trains the local model for E epochs and prints the loss and accuracy
        of each epoch.
        """
        # Rebuild optimizer to pass the aggregated model parameters to optimizer
        self.optimizer = FedProxOptimizer(
            self.model.parameters(), lr=learning_rate, mu=mu, weight_decay=weight_decay)

        # Set current model parameters to optimizer after training the same parameters will become old
        self.optimizer.set_old_weights([p.clone().detach() for p in self.model.parameters()])

        num_samples = len(self.train_loader.dataset)
        num_batches = len(self.train_loader)
        for epoch in range(E):
            # Reset the counters every epoch so the printed loss and accuracy
            # describe this epoch only (accuracy stays within [0, 1])
            train_loss = []
            correct = 0
            for batch_idx, (data, target) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = cross_entropy(output, target)
                loss.backward()
                self.optimizer.step()
                pred = output.argmax(dim=1, keepdim=True)
                tar = target.argmax(dim=1, keepdim=True)
                correct += pred.eq(tar).sum().item()
                train_loss.append(loss.item())
                if batch_idx % log_interval == 0:
                    # Progress is reported against the whole dataset / all batches
                    print('Inner Optimization Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(data), num_samples,
                        100. * batch_idx / num_batches, loss.item()))

            inner_opt_training_accuracy = correct / num_samples
            inner_opt_training_loss = np.mean(train_loss)
            print (f"Inner Optimization Training Loss: {inner_opt_training_loss}, and Accuracy: {inner_opt_training_accuracy}")

        self.next(self.join)

    @aggregator
    def join(self, inputs):
        """
        Aggregate Model.

        Records this round's metrics in `loss_and_acc` and replaces the
        aggregated model with the weighted average of the selected
        collaborators' models.
        """
        self.loss_and_acc = {
            input_.input: {"Train Loss": [], "Agg Train Accuracy": []}
            for input_ in inputs
        }
        self.loss_and_acc.update({"Aggregated": {"Test Loss": [], "Test Accuracy": []}})

        for input_ in inputs:
            self.loss_and_acc[input_.input]["Train Loss"].append(input_.training_loss)
            self.loss_and_acc[input_.input]["Agg Train Accuracy"].append(input_.training_accuracy)
        self.loss_and_acc["Aggregated"]["Test Loss"].append(self.agg_validation_loss)
        self.loss_and_acc["Aggregated"]["Test Accuracy"].append(self.aggregated_model_accuracy)

        # Only the collaborators drawn in gather_results contribute to the new model
        input_weights = []
        t_weights = []
        for c in self.selected_collaborator_indices:
            input_weights.append(inputs[c].model.state_dict())
            t_weights.append(self.train_weights[c])

        avg_model_dict = weighted_average(input_weights, t_weights)
        self.model.load_state_dict(avg_model_dict)

        # The optimizer is not averaged; the first collaborator's optimizer is kept
        self.optimizer = inputs[0].optimizer

        self.next(self.end)

    @aggregator
    def end(self):
        """
        This is the 'end' step. All flows must have an 'end' step, which is the
        last step in the flow.
        """
        self.round_number += 1
        print('This is end of the flow')

- `Runtime` – Defines where the flow runs, infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.
- `aggregator/collaborator` – placement decorators that define where the task will be assigned

In [None]:
# Setup aggregator
# Named `aggregator_` so the `aggregator` placement decorator imported at the
# top is not overwritten; otherwise re-running the flow definition cell after
# this one would try to use an Aggregator instance as a decorator.
aggregator_ = Aggregator()
aggregator_.private_attributes = {}

# Setup collaborators with private attributes
collaborator_names = [f"col{i}" for i in range(NUM_COLLABORATORS)]

collaborators = [Collaborator(name=name) for name in collaborator_names]

# `seed` is absorbed by the dataset's **kwargs; reproducibility relies on the
# earlier set_seed(RANDOM_SEED) call
synthetic_federated_dataset = SyntheticFederatedDataset(
    batch_size=batch_size, num_classes=10, num_collaborators=len(collaborators), seed=RANDOM_SEED)
synthetic_federated_dataset.split(collaborators)

local_runtime = LocalRuntime(
    aggregator=aggregator_, collaborators=collaborators, backend="single_process")

# State shared by the experiment cells below
model = None
best_model = None
optimizer = None
top_model_accuracy = 0
loss_and_acc = {}  # per-algorithm history of aggregated test loss / accuracy

Executing FedProx

In [None]:
# Run `n_epochs` federated rounds and record the aggregated test metrics of
# every round under loss_and_acc["FedProx"]
flflow = FederatedFlow(model, optimizer, checkpoint=True)
flflow.runtime = local_runtime
for i in range(n_epochs):
    print(f'Starting round {i}...')
    flflow.run()

    # Look the aggregated metrics up directly instead of scanning every entry
    aggregated_metrics = flflow.loss_and_acc.get("Aggregated")
    if aggregated_metrics is not None:
        fedprox_history = loss_and_acc.setdefault(
            "FedProx", {"Test Loss": [], "Test Accuracy": []})
        # extend() copes with any number of recorded values per round
        fedprox_history["Test Loss"].extend(aggregated_metrics["Test Loss"])
        fedprox_history["Test Accuracy"].extend(aggregated_metrics["Test Accuracy"])

    model = flflow.model
    optimizer = flflow.optimizer
    aggregated_model_accuracy = flflow.aggregated_model_accuracy
    if aggregated_model_accuracy > top_model_accuracy:
        print(f'Accuracy improved to {aggregated_model_accuracy} '
                + f'for round {i}')
        top_model_accuracy = aggregated_model_accuracy

Setting `mu = 0.0` will produce the FedAvg results

In [None]:
mu = 0.0  # `mu` argument of FedProxOptimizer: 0.0 runs FedAvg

Executing FedAvg

In [None]:
# Run `n_epochs` federated rounds and record the aggregated test metrics of
# every round under loss_and_acc["FedAvg"].
# The flow is created without a model so that FederatedFlow builds a fresh
# Net (seeded with RANDOM_SEED): FedAvg then starts from the same initial
# weights as the FedProx run instead of continuing from its trained model,
# which keeps the two plotted curves comparable.
flflow = FederatedFlow(model=None, optimizer=None, checkpoint=True)
flflow.runtime = local_runtime
for i in range(n_epochs):
    print(f'Starting round {i}...')
    flflow.run()

    # Look the aggregated metrics up directly instead of scanning every entry
    aggregated_metrics = flflow.loss_and_acc.get("Aggregated")
    if aggregated_metrics is not None:
        fedavg_history = loss_and_acc.setdefault(
            "FedAvg", {"Test Loss": [], "Test Accuracy": []})
        # extend() copes with any number of recorded values per round
        fedavg_history["Test Loss"].extend(aggregated_metrics["Test Loss"])
        fedavg_history["Test Accuracy"].extend(aggregated_metrics["Test Accuracy"])

    model = flflow.model
    optimizer = flflow.optimizer
    aggregated_model_accuracy = flflow.aggregated_model_accuracy
    if aggregated_model_accuracy > top_model_accuracy:
        print(f'Accuracy improved to {aggregated_model_accuracy} '
                + f'for round {i}')
        top_model_accuracy = aggregated_model_accuracy

Plotting FedProx vs FedAvg graphs

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

# Left panel: aggregated test loss recorded after each round, one curve per algorithm
fedprox_loss = loss_and_acc["FedProx"]["Test Loss"]
fedavg_loss = loss_and_acc["FedAvg"]["Test Loss"]
plt.subplot(1, 2, 1)
plt.plot(fedprox_loss)
plt.plot(fedavg_loss)
# Legend entries follow the order in which the curves were plotted
plt.legend(["FedProx Loss", "FedAvg Loss"])
plt.title("FedProx vs FedAvg Loss")

# Right panel: aggregated test accuracy recorded after each round
fedprox_accuracy = loss_and_acc["FedProx"]["Test Accuracy"]
fedavg_accuracy = loss_and_acc["FedAvg"]["Test Accuracy"]
plt.subplot(1, 2, 2)
plt.plot(fedprox_accuracy)
plt.plot(fedavg_accuracy)
plt.legend(["FedProx Accuracy", "FedAvg Accuracy"])
plt.title("FedProx vs FedAvg Accuracy")