# Workflow Interface 104: MNIST with Fedprox
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_104_MNIST_with_fedprox.ipynb)

In this tutorial, we demonstrate how to use the FedProx aggregation algorithm to tackle data heterogeneity in a federated setup.

- FedProx is a generalization and reparameterization of FedAvg.
- It demonstrates more stable and accurate convergence than FedAvg on non-IID datasets.
- It adds a proximal term to each collaborator's local objective, which keeps local updates close to the global model and helps improve stability.
- FedProx paper: https://arxiv.org/pdf/1812.06127.pdf
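
To make the role of the proximal term concrete, here is a minimal sketch on a one-dimensional toy problem. It assumes a local loss of `0.5 * (w - target)**2` and adds the FedProx penalty `(mu / 2) * (w - w_global)**2`; the function name, the toy loss, and the values are illustrative assumptions and are not part of OpenFL.

```python
def local_fedprox_descent(w_global, target, mu, lr=0.1, steps=200):
    """Minimise 0.5*(w - target)**2 + (mu/2)*(w - w_global)**2 by gradient descent."""
    w = w_global  # local training starts from the global weight
    for _ in range(steps):
        grad_loss = w - target           # gradient of the toy local loss
        grad_prox = mu * (w - w_global)  # gradient of the proximal term
        w -= lr * (grad_loss + grad_prox)
    return w


# mu = 0 recovers plain local training: the weight drifts to the local optimum.
w_fedavg = local_fedprox_descent(w_global=0.0, target=4.0, mu=0.0)
# mu = 1 pulls the result halfway back towards the global weight.
w_fedprox = local_fedprox_descent(w_global=0.0, target=4.0, mu=1.0)

print(f"mu=0.0 -> w = {w_fedavg:.4f}")   # 4.0000
print(f"mu=1.0 -> w = {w_fedprox:.4f}")  # 2.0000
```

For this toy loss the penalised optimum is `target / (1 + mu)` when `w_global = 0`, so a larger `mu` limits how far a collaborator can move away from the global model during local training.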

# Getting Started

First, we install the dependencies required by the Workflow Interface.

In [None]:
!pip install git+https://github.com/intel/openfl.git
!pip install -r requirements_workflow_interface.txt

# Uncomment this if running in Google Colab
#!pip install -r https://raw.githubusercontent.com/intel/openfl/develop/openfl-tutorials/experimental/requirements_workflow_interface.txt
#import os
#os.environ["USERNAME"] = "colab"

We begin with the quintessential example of a small PyTorch CNN model trained on the MNIST dataset. Let's start by defining our dataloaders, model, optimizer, and some helper functions, as we would for any other deep learning experiment.

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import numpy as np

batch_size_train = 64
batch_size_test = 1000


random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

# download=True fetches MNIST into 'files/' if it is not already present
mnist_train = torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

mnist_test = torchvision.datasets.MNIST('files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def compute_loss_and_acc(network,test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
      for data, target in test_loader:
        output = network(data)
        # Sum the per-sample losses; the total is divided by the dataset size below
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
      test_loss, correct, len(test_loader.dataset),
      100. * correct / len(test_loader.dataset)))
    accuracy = float(correct / len(test_loader.dataset))
    return accuracy, test_loss, correct

Next we import the `FLSpec`, `LocalRuntime`, and placement decorators.

- `FLSpec` – Defines the flow specification. User-defined flows are subclasses of this.
- `Runtime` – Defines where the flow runs and provides the infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.
- `aggregator`/`collaborator` – Placement decorators that define where each task will be assigned.

In [None]:
from copy import deepcopy

from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator

# Import the framework adapter plugin, Fedprox optimizer and aggregation function
from openfl.plugins.frameworks_adapters.pytorch_adapter import FrameworkAdapterPlugin as fa
from openfl.experimental.interface.aggregation_functions.fedprox import FedProxAgg
from openfl.utilities.optimizers.torch import FedProxOptimizer

from collections import OrderedDict


Now we come to the flow definition. The OpenFL Workflow Interface adopts the Metaflow convention that every workflow begins with a `start` task and concludes with an `end` task. The aggregator begins with an optionally passed-in model and optimizer; if none is given, the flow creates a `Net` and a `FedProxOptimizer`. In `start`, the list of collaborators is extracted from the runtime (`self.collaborators = self.runtime.collaborators`) and used as the list of participants for the task named in `self.next`, `compute_loss_and_accuracy`. The model, optimizer, and anything else that is not explicitly excluded are passed from `start` on the aggregator to `compute_loss_and_accuracy` on each collaborator. Where a task runs is determined by the placement decorator that precedes its definition (`@aggregator` or `@collaborator`). Once every collaborator has completed `compute_loss_and_accuracy`, the aggregator combines the reported metrics in `gather_results_and_take_weighted_average` and picks a random subset of collaborators in `select_collaborators`. The selected collaborators train locally with the FedProx optimizer in `train_selected_collaborators` and then pass their state to `join` at the aggregator. It is in `join` that a weighted average of the model weights is taken; `internal_loop` then either returns to `start` for the next round or moves on to `end`.
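
The averaging step in `join` can be illustrated without any framework code. The sketch below assumes each collaborator's model is a flat list of floats and that each collaborator is weighted by its share of the training samples, as `join` does when it builds `train_weights`. The helper `weighted_average` is hypothetical and only stands in for the idea behind `FedProxAgg.aggregate_models`; it is not its actual implementation.

```python
def weighted_average(models, sample_counts):
    """Average per-collaborator weight lists, weighting each by its sample count."""
    total = sum(sample_counts)
    weights = [n / total for n in sample_counts]
    n_params = len(models[0])
    return [
        sum(w * model[i] for w, model in zip(weights, models))
        for i in range(n_params)
    ]


# Two hypothetical collaborators with 100 and 300 training samples.
models = [[1.0, 2.0], [3.0, 6.0]]
sample_counts = [100, 300]

avg = weighted_average(models, sample_counts)
print(avg)  # [2.5, 5.0]
```

The collaborator with three times as many samples contributes three times as much to each averaged parameter.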


In [None]:
class FedProxFlow(FLSpec):

    def __init__(self, model=None, optimizer=None, n_selected_collaborators=10, n_rounds=10, **kwargs):
        super(FedProxFlow, self).__init__(**kwargs)
        self.round_number = 1
        self.n_selected_collaborators = n_selected_collaborators
        self.n_rounds = n_rounds
        self.loss_and_acc = {"Train Loss": [], "Test Accuracy": []}
        if model is not None:
            self.model = model
            self.optimizer = optimizer
        else:
            self.model = Net()
            self.optimizer = FedProxOptimizer(
                self.model.parameters(), lr=learning_rate, mu=mu, weight_decay=weight_decay)

        self.agg_func = FedProxAgg()

    @aggregator
    def start(self):
        """
        Start of the flow. Call compute_loss_and_accuracy step for each collaborator
        """
        print(f'\nStarting round number {self.round_number} .... \n')
        self.collaborators = self.runtime.collaborators
        self.next(self.compute_loss_and_accuracy, foreach='collaborators')

    @collaborator
    def compute_loss_and_accuracy(self):
        """
        Compute the training accuracy and training loss, and the validation
        accuracy and validation loss of the aggregated model.
        """
        # Compute Train Loss and Train Acc
        self.training_accuracy, self.training_loss, _, = compute_loss_and_acc(
            self.model, self.train_loader)
        
        # Compute Test Loss and Test Acc
        self.agg_validation_score, self.agg_validation_loss, test_correct = compute_loss_and_acc(
            self.model, self.test_loader)

        self.train_dataset_length = len(self.train_loader.dataset)
        self.test_dataset_length = len(self.test_loader.dataset)

        print(
            "<Collab: {:<5}> | Train Round: {:<5} : Train Loss {:<.6f}, Test Acc: {:<.6f} [{}/{}]".format(
                self.input,
                self.round_number,
                self.training_loss,
                self.agg_validation_score, 
                test_correct, 
                self.test_dataset_length
            )
        )

        self.next(self.gather_results_and_take_weighted_average)

    @aggregator
    def gather_results_and_take_weighted_average(self, inputs):
        """
        Gather the results computed by all collaborators in the previous
        step.
        Compute train and test weights, then compute the weighted average of
        the aggregated training loss and the aggregated test accuracy.
        """
        # Calculate train_weights and test_weights
        train_datasize, test_datasize = [], []
        for input_ in inputs:
            train_datasize.append(input_.train_dataset_length)
            test_datasize.append(input_.test_dataset_length)

        self.train_weights, self.test_weights = [], []
        for input_ in inputs:
            self.train_weights.append(input_.train_dataset_length / sum(train_datasize))
            self.test_weights.append(input_.test_dataset_length / sum(test_datasize))

        aggregated_model_accuracy_list, aggregated_model_loss_list = [], []
        for input_ in inputs:
            aggregated_model_loss_list.append(input_.training_loss)
            aggregated_model_accuracy_list.append(input_.agg_validation_score)

        self.aggregated_model_training_loss, self.aggregated_model_test_accuracy = self.agg_func.aggregate_metrics(
            [aggregated_model_loss_list, aggregated_model_accuracy_list], [self.train_weights, self.test_weights])
        
        # Store experiment results
        self.loss_and_acc["Train Loss"].append(self.aggregated_model_training_loss)
        self.loss_and_acc["Test Accuracy"].append(self.aggregated_model_test_accuracy)

        print(
            "<Agg> | Train Round: {:<5} : Agg Train Loss {:<.6f}, Agg Test Acc: {:<.6f}".format(
                self.round_number,
                self.aggregated_model_training_loss,
                self.aggregated_model_test_accuracy
            )
        )

        self.next(self.select_collaborators)

    @aggregator
    def select_collaborators(self):
        """
        Randomly select n_selected_collaborators collaborators
        """
        np.random.seed(self.round_number)
        self.selected_collaborator_indices = np.random.choice(range(len(self.collaborators)), \
            self.n_selected_collaborators, replace=False)
        self.selected_collaborators = [self.collaborators[idx] for idx in self.selected_collaborator_indices]

        self.next(self.train_selected_collaborators, foreach="selected_collaborators")


    @collaborator
    def train_selected_collaborators(self):
        """
        Train selected collaborators
        """
        self.model.train(mode=True)

        self.train_dataset_length = len(self.train_loader.dataset)

        # Rebuild the optimizer with global model parameters
        self.optimizer = FedProxOptimizer(
            self.model.parameters(), lr=learning_rate, mu=mu, weight_decay=weight_decay)
        # Set global model parameters as old weights to enable computation of proximal term
        self.optimizer.set_old_weights([p.clone().detach() for p in self.model.parameters()])

        for epoch in range(local_epoch):
            train_loss = []
            correct = 0
            for data, target in self.train_loader:
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                self.optimizer.step()
                pred = output.data.max(1, keepdim=True)[1]
                correct += pred.eq(target.data.view_as(pred)).sum()
                train_loss.append(loss.item())
            training_accuracy = float(correct / self.train_dataset_length)
            training_loss = np.mean(train_loss)
            print(
                "<Collab: {:<5}> | Train Round: {:<5} | Local Epoch: {:<3}: FedProx Optimization Train Loss {:<.6f}, Train Acc: {:<.6f} [{}/{}]".format(
                    self.input,
                    self.round_number,
                    epoch,
                    training_loss,
                    training_accuracy,
                    correct, 
                    len(self.train_loader.dataset)
                )
            )

        self.next(self.join)
    
    @aggregator
    def join(self, inputs):
        """
        Compute the total training dataset size and take the weighted average of the collaborator models.
        """
        train_datasize = sum([input_.train_dataset_length for input_ in inputs])
        
        train_tensors, train_weights = [], [] 
        for input_ in inputs:
            train_weights.append(input_.train_dataset_length / train_datasize)
            # Extract the tensor dict once, then split it into values and keys
            tensor_dict = fa.get_tensor_dict(input_.model)
            train_tensors.append(list(tensor_dict.values()))
            keys_list = list(tensor_dict.keys())

        avg_tensors = self.agg_func.aggregate_models(train_tensors, train_weights)
        state_dict = dict(zip(keys_list, avg_tensors))
        fa.set_tensor_dict(self.model, state_dict)
        
        self.next(self.internal_loop)

    @aggregator
    def internal_loop(self):
        """
        Check if training is finished for `self.n_rounds`
        if finished move to end step. Otherwise, go back to start
        step for next round of training.
        """
        if self.round_number < self.n_rounds:
            self.round_number += 1
            self.next(self.start)
        else:
            self.next(self.end)

    @aggregator
    def end(self):
        """
        This is the 'end' step.
        """
        self.round_number += 1
        print('This is end of the flow')

You'll notice in the `FedProxFlow` definition above that there were certain attributes that the flow was not initialized with, namely the `train_loader` and `test_loader` for each of the collaborators. These are **private_attributes** that are exposed only through the runtime. Each participant has its own set of private attributes: a dictionary where the key is the attribute name, and the value is the object that will be made accessible through that participant's task.

Below, we segment shards of the MNIST dataset for **seven collaborators**: Portland, Seattle, Chandler, Bangalore, Guadalajara, Santa Clara and San Jose. Each has its own slice of the dataset that's accessible via the `train_loader` or `test_loader` attribute. Note that the private attributes are flexible, and you can choose to pass in a completely different type of object to any of the collaborators or the aggregator (with an arbitrary name). These private attributes will always be filtered out of the current state when transferring from collaborator to aggregator, or vice versa.
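
The sharding relies on strided slicing (`data[idx::len(collaborators)]`), which hands every seventh sample to the same collaborator. The sketch below reproduces this on a plain list of indices rather than on the dataset itself; it assumes the standard MNIST training set size of 60,000 samples.

```python
n_samples = 60000      # size of the MNIST training set
n_collaborators = 7

indices = list(range(n_samples))
# Collaborator `idx` receives samples idx, idx + 7, idx + 14, ...
shards = [indices[idx::n_collaborators] for idx in range(n_collaborators)]
shard_sizes = [len(shard) for shard in shards]

print(shard_sizes)       # [8572, 8572, 8572, 8571, 8571, 8571, 8571]
print(sum(shard_sizes))  # 60000
```

The shards are disjoint and cover the whole dataset, so the per-collaborator weights derived from their sizes sum to 1. Because the samples are interleaved rather than grouped by label, each shard has roughly the same label distribution as the full dataset.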

In [None]:
# Setup participants
aggregator = Aggregator()
aggregator.private_attributes = {}

# Setup collaborators with private attributes
collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore', 'Guadalajara', 'Santa Clara', 'San Jose']
collaborators = [Collaborator(name=name) for name in collaborator_names]
# Keep a list of collaborator weights. The weights are decided by the number of samples for each collaborator
collaborators_weights_dict = {}

for idx, collaborator in enumerate(collaborators):
    local_train = deepcopy(mnist_train)
    local_test = deepcopy(mnist_test)
    local_train.data = mnist_train.data[idx::len(collaborators)]
    local_train.targets = mnist_train.targets[idx::len(collaborators)]
    local_test.data = mnist_test.data[idx::len(collaborators)]
    local_test.targets = mnist_test.targets[idx::len(collaborators)]
    collaborator.private_attributes = {
            'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),
            'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)
    }
    collaborators_weights_dict[collaborator] = len(local_train.data)

for col in collaborators_weights_dict:
    collaborators_weights_dict[col] /= len(mnist_train.data)

if len(collaborators_weights_dict) != 0:
    assert np.abs(1.0 - sum(collaborators_weights_dict.values())) < 0.01, (
        f'Collaborator weights do not sum to 1.0: {collaborators_weights_dict}'
    )

local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators, backend='single_process')
print(f'Local runtime collaborators = {local_runtime.collaborators}')

Now that we have our flow and runtime defined, let's run the experiment! 

In [None]:
model = None
best_model = None
optimizer = None
n_selected_collaborators = 2
n_epochs = 5
learning_rate = 0.01
weight_decay = 0.001
local_epoch = 5

# Set `mu` to `1.0` for FedProx
mu = 1.0

flflow = FedProxFlow(n_selected_collaborators=n_selected_collaborators, n_rounds=n_epochs, checkpoint=False)
flflow.runtime = local_runtime

flflow.run()
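
With `n_selected_collaborators = 2`, only two of the seven collaborators train in each round, and `select_collaborators` seeds the random generator with the round number so that the choice is reproducible. The sketch below mimics that behaviour with Python's standard `random` module. This is an assumption made to keep the example dependency-free: the flow itself uses `np.random.choice`, so the actual picks will differ.

```python
import random


def select_collaborators(names, n_selected, round_number):
    """Pick n_selected distinct collaborators, seeded by the round number."""
    rng = random.Random(round_number)  # same round number -> same selection
    return rng.sample(names, n_selected)


collaborator_names = ['Portland', 'Seattle', 'Chandler', 'Bangalore',
                      'Guadalajara', 'Santa Clara', 'San Jose']

for round_number in range(1, 4):
    picked = select_collaborators(collaborator_names, 2, round_number)
    print(f"Round {round_number}: {picked}")
```

Seeding per round makes a run repeatable while still varying the selected participants from one round to the next.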