# FedProx PyTorch MNIST Tutorial using Workflow API
This notebook sets up a distributed training federation that runs the [FedProx](https://arxiv.org/abs/1812.06127) algorithm with OpenFL's [Workflow API](https://openfl.readthedocs.io/en/latest/about/features_index/workflowinterface.html). It runs locally using a [`LocalRuntime`](https://openfl.readthedocs.io/en/latest/about/features_index/workflowinterface.html#runtimes) and can later be scaled to a real federated setting.


Import the relevant libraries

In [20]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.utils
import torch.utils.data
import torchvision
import torchvision.transforms as transforms

from openfl.utilities.optimizers.torch.fedprox import FedProxAdam

from openfl.experimental.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.runtime import LocalRuntime
from openfl.experimental.placement import aggregator, collaborator

Define the model:

In [21]:
class Net(nn.Module):
    """Small two-stage CNN classifier for single-channel MNIST digits.

    Pipeline: conv(1->16, k=3) -> ReLU -> 2x2 max-pool -> conv(16->32, k=3)
    -> ReLU -> 2x2 max-pool -> flatten -> 800 -> 32 -> 84 -> 10.
    fc1 expects a 32x5x5 feature map, which is what 28x28 MNIST images
    produce after the two conv/pool stages.

    The forward pass returns log-probabilities (log_softmax over dim 1).
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the same order as before so parameter
        # initialisation consumes the RNG identically and state_dict keys match.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.fc1 = nn.Linear(in_features=32 * 5 * 5, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        features = x
        # Two identical conv -> ReLU -> pool stages.
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        hidden = features.view(features.size(0), -1)
        # Two hidden fully-connected layers with ReLU, then the class scores.
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        scores = self.fc3(hidden)
        return F.log_softmax(scores, dim=1)

Set up the dataset:

In [22]:
# Preprocessing applied when samples are read through the datasets' own
# __getitem__. MNIST images have a single channel, so Normalize takes a
# one-element mean/std; 3-channel statistics fail to broadcast against a
# 1xHxW tensor as soon as the transform is applied.
# NOTE(review): the collaborator setup later in this notebook reads the raw
# image tensors directly instead of indexing these datasets, so this
# transform is not applied on that path.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Download/cache location shared by the train and test splits.
DATA_DIR = "./files/"

mnist_train = torchvision.datasets.MNIST(
    DATA_DIR,
    train=True,
    download=True,
    transform=transform,
)

mnist_test = torchvision.datasets.MNIST(
    DATA_DIR,
    train=False,
    download=True,
    transform=transform,
)

class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs an indexable collection of images with labels.

    Item ``i`` is the tuple ``(images[i], labels[i])``. The length is taken
    from ``images`` alone; ``labels`` is assumed to be at least as long.
    """

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

Next, set up the participants: an `Aggregator` and a few `Collaborator`s that will train the model. Then partition the dataset between the collaborators and pass all participants to the appropriate runtime environment (in our case, a `LocalRuntime`).


In [23]:
def one_hot(labels, classes):
    """Encode integer class ids as one-hot rows.

    Args:
        labels: integer class id(s) used to index rows of an identity matrix.
        classes: number of classes (width of each one-hot row).

    Returns:
        A float64 NumPy array; row ``i`` is the indicator vector of ``labels[i]``.
    """
    identity = np.identity(classes)
    encoded = identity[labels]
    return encoded

# Setup participants
aggregator_ = Aggregator()
aggregator_.private_attributes = {}

# Setup collaborators with private attributes
collaborator_names = [f'collaborator{i}' for i in range(4)]
collaborators = [Collaborator(name=name) for name in collaborator_names]
batch_size_train = 1024
batch_size_test = 1024
log_interval = 10

# The full train/test tensors are identical for every collaborator, so build
# them once instead of on every loop iteration. `.data`/`.targets` replace the
# deprecated `train_data`/`train_labels`/`test_data`/`test_labels` accessors.
# Add a channel dimension: (N, H, W) -> (N, 1, H, W), cast to float.
train_images = mnist_train.data.unsqueeze(1).float()
# Training targets are one-hot (float64) rows; they are fed to F.cross_entropy
# in the flow's train step.
train_labels = one_hot(np.array(mnist_train.targets), 10)

valid_images = mnist_test.data.unsqueeze(1).float()
# Validation targets stay integer class ids: the accuracy count compares them
# against the arg-max prediction.
valid_labels = np.array(mnist_test.targets)

n_collaborators = len(collaborators)
for idx, collaborator_ in enumerate(collaborators):
    # Round-robin shard: collaborator idx gets samples idx, idx + n, idx + 2n, ...
    collaborator_.private_attributes = {
        'train_loader': torch.utils.data.DataLoader(
            CustomDataset(train_images[idx::n_collaborators],
                          train_labels[idx::n_collaborators]),
            batch_size=batch_size_train,
            shuffle=True),
        # Evaluation sums loss/correct counts over the whole shard, so sample
        # order is irrelevant; don't shuffle.
        'test_loader': torch.utils.data.DataLoader(
            CustomDataset(valid_images[idx::n_collaborators],
                          valid_labels[idx::n_collaborators]),
            batch_size=batch_size_test,
            shuffle=False),
    }

local_runtime = LocalRuntime(aggregator=aggregator_, collaborators=collaborators, backend='single_process')


Define the aggregation algorithm (`FedAvg`), the optimizer (`FedProxAdam`), and a loss function:

In [24]:
# Aggregation algorithm
def FedAvg(models, weights=None):
    """Federated averaging: element-wise (weighted) mean of the models' state dicts.

    Args:
        models: non-empty sequence of ``nn.Module`` instances with identical
            state_dict keys and shapes.
        weights: optional per-model weights. They are normalised by their sum,
            with the same semantics as ``np.average``. ``None`` gives a plain mean.

    Returns:
        A new model (a deep copy of ``models[0]``) loaded with the averaged
        state. The input models are left unmodified.

    Raises:
        ValueError: if ``models`` is empty or ``weights`` has the wrong length.
        ZeroDivisionError: if ``weights`` sum to zero.
    """
    import copy

    if len(models) == 0:
        raise ValueError("FedAvg needs at least one model")

    if weights is None:
        coeffs = [1.0 / len(models)] * len(models)
    else:
        weights = [float(w) for w in weights]
        if len(weights) != len(models):
            raise ValueError("weights must have exactly one entry per model")
        total = sum(weights)
        if total == 0:
            raise ZeroDivisionError("Weights sum to zero, can't be normalized")
        coeffs = [w / total for w in weights]

    state_dicts = [model.state_dict() for model in models]
    new_state_dict = {}
    for key, reference in state_dicts[0].items():
        # Average in float64 (as np.average did) on the tensors' own device, so
        # non-CPU models work without a round-trip through NumPy.
        stacked = torch.stack([sd[key].to(torch.float64) for sd in state_dicts])
        coeff = torch.tensor(coeffs, dtype=torch.float64, device=stacked.device)
        coeff = coeff.view(-1, *([1] * reference.dim()))
        new_state_dict[key] = (stacked * coeff).sum(dim=0).to(reference.dtype)

    # Work on a copy so the caller's models[0] is not overwritten in place.
    new_model = copy.deepcopy(models[0])
    new_model.load_state_dict(new_state_dict)
    return new_model

def get_optimizer(model, lr=1e-3, mu=0.01):
    """Build a fresh ``FedProxAdam`` optimizer over ``model``'s parameters.

    Args:
        model: module whose ``parameters()`` will be optimised.
        lr: Adam learning rate (default 1e-3, the value this tutorial uses).
        mu: FedProx ``mu`` coefficient, the weight of the proximal term in the
            FedProx paper (default 0.01).

    Returns:
        A new ``FedProxAdam`` instance.
    """
    return FedProxAdam(model.parameters(), lr=lr, mu=mu)

def cross_entropy(output, target):
    """Binary cross-entropy loss computed on raw logits.

    Despite its name, this is element-wise *binary* cross-entropy
    (sigmoid + BCE, mean-reduced), with ``target`` cast to float32.

    NOTE(review): ``Net.forward`` returns log-probabilities, not raw logits.
    The flow's ``train`` step uses ``F.cross_entropy`` instead, so this helper
    does not appear to be used by the visible flow.
    """
    float_target = target.float()
    return F.binary_cross_entropy_with_logits(output, float_target)

Set up work to be executed by the aggregator and the collaborators by extending `FLSpec`:

In [25]:
class FederatedFlow(FLSpec):
    """Federated training flow run for ``rounds`` rounds.

    Each round, every collaborator performs these steps in order:
      1. Validate the incoming aggregated model.
      2. Train locally.
      3. Validate the local model.

    The aggregator then averages the collaborator models with ``FedAvg`` and
    reports the mean metrics.

    Args:
        model: initial global model (an ``nn.Module``).
        optimizer: initial optimizer. Collaborators rebuild their own in ``train``.
        rounds: number of federated rounds to run.
        **kwargs: forwarded to ``FLSpec`` (e.g. ``checkpoint``).
    """

    def __init__(self, model=None, optimizer=None, rounds=10, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.optimizer = optimizer
        self.rounds = rounds
        self.loss = 0.

    @aggregator
    def start(self):
        print('Performing initialization for model')
        self.collaborators = self.runtime.collaborators
        self.current_round = 0
        # Fan out: run aggregated_model_validation once per collaborator.
        self.next(self.aggregated_model_validation, foreach='collaborators')

    def compute_accuracy(self, data_loader):
        """Evaluate ``self.model`` on ``data_loader``, print a summary, and return accuracy.

        Targets must be integer class ids; they are compared against the
        arg-max prediction.

        Returns:
            Fraction of correctly classified samples, in [0, 1].
        """
        self.model.eval()
        test_loss = 0.
        correct = 0
        with torch.no_grad():
            for data, target in data_loader:
                output = self.model(data)
                # reduction='sum' is the non-deprecated spelling of size_average=False.
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.max(dim=1, keepdim=True).indices
                # Accumulate a plain int rather than a tensor.
                correct += pred.eq(target.view_as(pred)).sum().item()

        n_samples = len(data_loader.dataset)
        test_loss /= n_samples
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, n_samples, 100. * correct / n_samples))
        return correct / n_samples

    @collaborator
    def aggregated_model_validation(self):
        print(f'Performing aggregated model validation for collaborator {self.input}, model: {id(self.model)}')
        self.agg_validation_score = self.compute_accuracy(self.test_loader)
        self.next(self.train)

    @collaborator
    def train(self):
        # Print progress each time another quarter of the local samples is consumed.
        log_threshold = .25

        self.model.train()
        # A fresh optimizer each round, bound to this round's model parameters.
        # NOTE(review): FedProx's proximal term pulls local weights toward the
        # round's global weights. Confirm whether FedProxAdam needs those weights
        # supplied explicitly (e.g. a set_old_weights-style call) before
        # training; otherwise mu may have no effect.
        self.optimizer = get_optimizer(self.model)
        n_samples = len(self.train_loader.dataset)
        seen = 0
        for data, target in self.train_loader:
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            self.optimizer.step()

            # Count samples actually consumed, including this (possibly short) batch.
            seen += len(data)
            # Always keep the latest batch loss so the aggregator sees a real value.
            self.loss = loss.item()

            if seen / n_samples >= log_threshold:
                print('Train Epoch: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    seen, n_samples, 100. * seen / n_samples, self.loss))
                # Skip any quarters already passed so one batch logs at most once.
                while log_threshold <= seen / n_samples:
                    log_threshold += .25
                # Every collaborator writes these same fixed paths in the working dir.
                torch.save(self.model.state_dict(), 'model.pth')
                torch.save(self.optimizer.state_dict(), 'optimizer.pth')

        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        print(f'Performing local model validation for collaborator {self.input}')
        self.local_validation_score = self.compute_accuracy(self.test_loader)
        print(
            f'Done with local model validation for collaborator {self.input}, Accuracy: {self.local_validation_score}')
        self.next(self.join)

    @aggregator
    def join(self, inputs):
        """Average the collaborator models, report mean metrics, then loop or finish."""
        self.model = FedAvg([collab.model for collab in inputs])
        self.optimizer = inputs[0].optimizer
        self.current_round += 1

        n_inputs = len(inputs)
        self.average_loss = sum(collab.loss for collab in inputs) / n_inputs
        self.aggregated_model_accuracy = sum(
            collab.agg_validation_score for collab in inputs) / n_inputs
        self.local_model_accuracy = sum(
            collab.local_validation_score for collab in inputs) / n_inputs
        print(f'Average aggregated model accuracy = {self.aggregated_model_accuracy}')
        print(f'Average training loss = {self.average_loss}')
        print(f'Average local model validation values = {self.local_model_accuracy}')

        if self.current_round < self.rounds:
            self.next(self.aggregated_model_validation, foreach='collaborators')
        else:
            self.next(self.end)

    @aggregator
    def end(self):
        print('Flow ended')

Aggregator step "start" registered
Collaborator step "aggregated_model_validation" registered
Collaborator step "train" registered
Collaborator step "local_model_validation" registered
Aggregator step "join" registered
Aggregator step "end" registered


Finally, run the federation:

In [26]:
# Seed torch's global RNG so model initialisation and DataLoader shuffling
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Number of federated rounds to run.
ROUNDS = 3

model = Net()
flflow = FederatedFlow(model, get_optimizer(model), rounds=ROUNDS, checkpoint=False)
flflow.runtime = local_runtime
flflow.run()


Calling start
[94mPerforming initialization for model[0m[94m
[0m
Calling aggregated_model_validation
[94mPerforming aggregated model validation for collaborator collaborator0, model: 140162497619616[0m[94m
[0m[94m
Test set: Avg. loss: 4.6833, Accuracy: 171/2500 (7%)
[0m[94m
[0m
Calling train
[0m
Calling local_model_validation
[94mPerforming local model validation for collaborator collaborator0[0m[94m
[0m[94m
Test set: Avg. loss: 0.7548, Accuracy: 1929/2500 (77%)
[0m[94m
[0m[94mDone with local model validation for collaborator collaborator0, Accuracy: 0.7716000080108643[0m[94m
[0mShould transfer from local_model_validation to join

Calling aggregated_model_validation
[94mPerforming aggregated model validation for collaborator collaborator1, model: 140158910463952[0m[94m
[0m[94m
Test set: Avg. loss: 4.7259, Accuracy: 173/2500 (7%)
[0m[94m
[0m
Calling train
[0m
Calling local_model_validation
[94mPerforming local model validation for collaborator collabo