# Federated PyTorch MNIST Tutorial

In [None]:
# Install dependencies if not already installed.
# NOTE: `!` is an IPython shell escape, so this line only works inside a notebook.
# NOTE(review): openfl itself is imported below but not installed here — assumed preinstalled.
!pip install torch torchvision

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import openfl.native as fx
from openfl.federated import FederatedModel,FederatedDataSet

# Seed the torch and numpy RNGs so model initialization and numpy-based
# randomness repeat between runs of this notebook.
torch.manual_seed(0)
np.random.seed(0)

After importing the required packages, the next step is setting up our openfl workspace. To do this, simply run the `fx.init()` command as follows:

In [None]:
# Set up the default OpenFL workspace, logging, etc.
# 'torch_cnn_mnist' is the workspace name handed to fx.init (presumably a
# workspace template name — confirm against the openfl.native docs).
fx.init('torch_cnn_mnist')

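Now we are ready to define our dataset and model to perform federated learning on. The dataset is provided as array-like data: a tensor of images and a numpy array of one-hot labels, for both the training and the validation split. We start with a simple convolutional neural network (CNN) that is trained on the MNIST dataset.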

In [None]:
def one_hot(labels, classes):
    """Return float64 one-hot rows for the integer ``labels`` over ``classes`` classes.

    Row ``k`` of the identity matrix encodes class ``k``, so fancy-indexing the
    identity with ``labels`` yields one encoded row per label.
    """
    identity = np.identity(classes)
    return identity[labels]

# Transform attached to the torchvision datasets. MNIST images have a single
# channel, so Normalize gets one mean/std value; 3-value statistics do not
# match a 1-channel image and fail if the transform is ever applied.
# NOTE: the arrays built below come from the raw `.data` attribute, so this
# transform is NOT applied to them (pixel values stay as raw uint8 levels).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)

# `.data` / `.targets` replace the deprecated `train_data` / `train_labels`.
# unsqueeze(1) adds the channel axis expected by Net: (N, 1, H, W) float32.
train_images = trainset.data.unsqueeze(1).float()
train_labels = one_hot(trainset.targets.numpy(), 10)

validset = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=transform)

# `.data` / `.targets` replace the deprecated `test_data` / `test_labels`.
valid_images = validset.data.unsqueeze(1).float()
valid_labels = one_hot(validset.targets.numpy(), 10)

In [None]:
# Size of dim 1 of the image tensor, i.e. the channel axis added by
# np.expand_dims(..., axis=1) when the images were loaded.
# NOTE(review): not referenced again in the visible notebook cells.
feature_shape = train_images.shape[1]
classes       = 10

# Wrap both splits for OpenFL; batches of 32 samples, 10 label classes.
fl_data = FederatedDataSet(train_images,train_labels,valid_images,valid_labels,batch_size=32,num_classes=classes)

class Net(nn.Module):
    """Small CNN for single-channel images that produces 10 class logits.

    Shape trace for a (N, 1, 28, 28) input: conv 3x3 -> 26x26, pool -> 13x13,
    conv 3x3 -> 11x11, pool -> 5x5, hence the flattened size 32 * 5 * 5.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so seeded initialization (and the
        # state_dict parameter names) stay stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.fc1 = nn.Linear(in_features=32 * 5 * 5, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        # Two conv -> ReLU -> max-pool stages.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        # Two hidden fully connected layers, then raw logits.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)
    
def optimizer(x, lr=1e-4):
    """Build the Adam optimizer used for the federated model.

    Args:
        x: iterable of parameters to optimize (e.g. ``model.parameters()``).
        lr: learning rate; defaults to 1e-4.

    Returns:
        torch.optim.Adam: optimizer over ``x``.
    """
    return optim.Adam(x, lr=lr)

def cross_entropy(output, target):
    """Multi-class cross-entropy loss for one-hot encoded targets.

    Args:
        output: raw logits of shape (batch, num_classes).
        target: one-hot (or per-class score) tensor of shape (batch, num_classes);
            the class index is taken as its ``argmax`` along dim 1.

    Returns:
        Scalar tensor: cross-entropy averaged over the batch
        (``F.cross_entropy``'s default ``reduction='mean'``).
    """
    return F.cross_entropy(input=output, target=torch.argmax(target, dim=1))

In [None]:

# Create a federated model from the PyTorch Net class, the optimizer factory
# (called with the model's parameters), the loss function, and the federated dataset.
fl_model = FederatedModel(build_model=Net,optimizer=optimizer,loss_fn=cross_entropy,data_loader=fl_data)

The `FederatedModel` object is a wrapper around your Keras, TensorFlow or PyTorch model that makes it compatible with OpenFL. It provides built-in federated training and validation functions that we will see used below. Using its `setup` function, collaborator models and datasets can be automatically defined for the experiment.

In [None]:
# Number of collaborators the data is split across; setup() returns one
# collaborator model per collaborator.
num_collaborators = 10
collaborator_models = fl_model.setup(num_collaborators=num_collaborators)
# Collaborators are keyed by their index as a string: '0', '1', ...
collaborators = {str(i): model for i, model in enumerate(collaborator_models)}

In [None]:
# Compare the full MNIST splits with the data given to the first two collaborators.
print(f'Original training data size: {len(train_images)}')
print(f'Original validation data size: {len(valid_images)}\n')

for ordinal, index in (('one', 0), ('two', 1)):
    shard = collaborator_models[index].data_loader
    print(f"Collaborator {ordinal}'s training data size: {len(shard.X_train)}")
    print(f"Collaborator {ordinal}'s validation data size: {len(shard.X_valid)}\n")

We can see the current plan values by running the `fx.get_plan()` function:

In [None]:
# Get the current values of the plan. Each of these can be overridden.
# Keys of this plan (e.g. 'aggregator.settings.rounds_to_train') are the keys
# used in the override dictionary passed to fx.run_experiment.
print(fx.get_plan())

In [None]:
from openfl.interface.aggregation_functions import AggregationFunction
import numpy as np

class ExponentialSmoothingAveraging(AggregationFunction):
    """
        Averaging via exponential smoothing.

        In order to use this mechanism properly you should specify `aggregator.settings.db_store_rounds`
        in `override_config` keyword argument of `run_experiment` function.
        It should be equal to the number of rounds you want to include in smoothing window.

        Args:
            alpha(float): Smoothing term.
    """
    def __init__(self, alpha=0.9):
        # Let the base class initialize its own state before adding ours.
        super().__init__()
        self.alpha = alpha

    def call(self,
             local_tensors,
             db_iterator,
             tensor_name,
             fl_round,
             tags):
        """Aggregate tensors.

        Computes ``alpha * avg + (1 - alpha) * sum_i alpha * (1 - alpha)**i * prev_i``
        where ``avg`` is the weighted mean of the local tensors and ``prev_i`` is the
        i-th most recent previously aggregated value of this tensor (i = 0 is newest).

        Args:
            local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate.
            db_iterator: iterator over history of all tensors. Columns:
                - 'tensor_name': name of the tensor.
                    Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'.
                - 'round': 0-based number of round corresponding to this tensor.
                - 'tags': tuple of tensor tags. Tags that can appear:
                    - 'model' indicates that the tensor is a model parameter.
                    - 'trained' indicates that tensor is a part of a training result.
                        These tensors are passed to the aggregator node after local learning.
                    - 'aggregated' indicates that tensor is a result of aggregation.
                        These tensors are sent to collaborators for the next round.
                    - 'delta' indicates that value is a difference between rounds
                        for a specific tensor.
                    also one of the tags is a collaborator name
                    if it corresponds to a result of a local task.

                - 'nparray': value of the tensor.
            tensor_name: name of the tensor
            fl_round: round number
            tags: tuple of tags for this tensor (unused)
        Returns:
            np.ndarray: aggregated tensor
        """
        tensors, weights = zip(*[(x.tensor, x.weight) for x in local_tensors])
        average = np.average(np.array(tensors), weights=np.array(weights), axis=0)

        # (round, value) pairs of earlier aggregated (non-delta) values of this tensor.
        history = [
            (record['round'], record['nparray'])
            for record in db_iterator
            if record['tensor_name'] == tensor_name
            and 'aggregated' in record['tags']
            and 'delta' not in record['tags']
        ]
        # Newest first, so the most recent value gets the largest weight
        # alpha * (1 - alpha) ** 0 regardless of the iterator's order.
        history.sort(key=lambda item: item[0], reverse=True)
        # NOTE(review): for a finite history these weights plus alpha sum to
        # 1 - (1 - alpha) ** (k + 1) < 1, slightly shrinking the result — confirm intent.
        smoothing_term = np.sum(
            [value * self.alpha * (1 - self.alpha) ** i
             for i, (_, value) in enumerate(history)],
            axis=0,
        )
        return self.alpha * average + (1 - self.alpha) * smoothing_term

In [None]:
from openfl.interface.aggregation_functions import AggregationFunction
import numpy as np

class ClippedAveraging(AggregationFunction):
    def __init__(self, ratio):
        """Average clipped tensors.

            Each collaborator's tensor is replaced by
            ``prev + ratio * (local - prev)`` (``prev`` being the value found for the
            previous round, see ``call``) before the weighted average is taken.

            Args:
                ratio(float): Ratio to multiply with a tensor for clipping
        """
        # Let the base class initialize its own state before adding ours.
        super().__init__()
        self.ratio = ratio

    def call(self,
             local_tensors,
             db_iterator,
             tensor_name,
             fl_round,
             *__):
        """Aggregate tensors.

        Args:
            local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate.
            db_iterator: iterator over history of all tensors. Columns:
                - 'tensor_name': name of the tensor.
                    Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'.
                - 'round': 0-based number of round corresponding to this tensor.
                - 'tags': tuple of tensor tags. Tags that can appear:
                    - 'model' indicates that the tensor is a model parameter.
                    - 'trained' indicates that tensor is a part of a training result.
                        These tensors are passed to the aggregator node after local learning.
                    - 'aggregated' indicates that tensor is a result of aggregation.
                        These tensors are sent to collaborators for the next round.
                    - 'delta' indicates that value is a difference between rounds
                        for a specific tensor.
                    also one of the tags is a collaborator name
                    if it corresponds to a result of a local task.

                - 'nparray': value of the tensor.
            tensor_name: name of the tensor
            fl_round: round number
            *__: any further positional arguments (e.g. the tensor's tags) are ignored.
        Returns:
            np.ndarray: aggregated tensor
        """
        # Value of this tensor stored for the previous round with exactly the
        # ('trained',) tag; the last matching record wins.
        # NOTE(review): presumably the previous round's aggregated training
        # result — confirm against the aggregator's tagging scheme.
        previous_tensor_value = None
        for record in db_iterator:
            if (
                record['round'] == (fl_round - 1)
                and record['tensor_name'] == tensor_name
                and record['tags'] == ('trained',)
            ):
                previous_tensor_value = record['nparray']

        clipped_tensors = []
        weights = []
        for local_tensor in local_tensors:
            # Without a previous value the delta is zero and the tensor is kept as-is.
            prev_tensor = previous_tensor_value if previous_tensor_value is not None else local_tensor.tensor
            delta = local_tensor.tensor - prev_tensor
            clipped_tensors.append(prev_tensor + delta * self.ratio)
            weights.append(local_tensor.weight)

        return np.average(clipped_tensors, weights=weights, axis=0)

In [None]:
from openfl.interface.aggregation_functions import AggregationFunction

class ConditionalThresholdAveraging(AggregationFunction):
    def __init__(self, threshold_fn, metric_name='acc', tags=('metric', 'validate_local')):
        """Average tensors of collaborators whose metric passes a per-round threshold.
        The metric checked is the one each collaborator recorded for the current round.
        If no tensors match threshold condition, a simple weighted averaging will be performed.

           Args:
               threshold_fn(callable): function to define a threshold for each round.
                   Has single argument `round_number`.
                   Returns threshold value at or above which collaborators are allowed
                   to participate in aggregation.
               metric_name(str): name of the metric to trace. Can be either 'acc' or 'loss'.
                   NOTE(review): selection uses `>=`, which favours *higher* loss values.
               tags(Tuple[str]): tags of the metric tensor (a list is accepted as well).
        """
        # Let the base class initialize its own state before adding ours.
        super().__init__()
        self.metric_name = metric_name
        self.threshold_fn = threshold_fn
        self.tags = tags
        # Latest round processed; used to emit the fallback warning once per round.
        self.logged_round = -1

    def call(self,
             local_tensors,
             db_iterator,
             tensor_name,
             fl_round,
             *__):
        """Aggregate tensors.

        Args:
            local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate.
            db_iterator: iterator over history of all tensors. Columns:
                - 'tensor_name': name of the tensor.
                    Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'.
                - 'round': 0-based number of round corresponding to this tensor.
                - 'tags': tuple of tensor tags. Tags that can appear:
                    - 'model' indicates that the tensor is a model parameter.
                    - 'trained' indicates that tensor is a part of a training result.
                        These tensors are passed to the aggregator node after local learning.
                    - 'aggregated' indicates that tensor is a result of aggregation.
                        These tensors are sent to collaborators for the next round.
                    - 'delta' indicates that value is a difference between rounds
                        for a specific tensor.
                    also one of the tags is a collaborator name
                    if it corresponds to a result of a local task.

                - 'nparray': value of the tensor.
            tensor_name: name of the tensor
            fl_round: round number
            *__: any further positional arguments (e.g. the tensor's tags) are ignored.
        Returns:
            np.ndarray: aggregated tensor
        """
        threshold = self.threshold_fn(fl_round)
        # set() accepts both the tuple default and a user-supplied list.
        required_tags = set(self.tags)
        col_names = {local_tensor.col_name for local_tensor in local_tensors}

        # Single pass over the iterator: remember which collaborators reported
        # the metric at or above the threshold for this round. Using a set means
        # a collaborator is selected at most once even if several records match.
        passing_cols = set()
        for record in db_iterator:
            record_tags = set(record['tags'])
            if (
                required_tags <= record_tags
                and record['round'] == fl_round
                and record['tensor_name'] == self.metric_name
                and record['nparray'] >= threshold
            ):
                passing_cols |= col_names & record_tags

        selected = [t for t in local_tensors if t.col_name in passing_cols]
        if not selected:
            if self.logged_round < fl_round:
                fx.logger.warning('No collaborators match threshold condition. Performing simple averaging...')
            selected = list(local_tensors)
        # Jump straight to the current round (not +1) so skipped round numbers
        # cannot cause repeated warnings within one round.
        self.logged_round = max(self.logged_round, fl_round)
        return np.average([t.tensor for t in selected],
                          weights=[t.weight for t in selected],
                          axis=0)

# Privileged Aggregation Functions
Most of the time the `AggregationFunction` interface is sufficient to implement custom methods, but in certain scenarios users may want to store additional information inside the TensorDB dataframe beyond the aggregated tensor. The `openfl.interface.aggregation_functions.experimental.PrivilegedAggregationFunction` interface is provided for this use, and gives the user direct access to the aggregator's TensorDB dataframe (notice that the `tensor_db` param in the call function replaces the `db_iterator` from the standard `AggregationFunction` interface). As the name suggests, this interface is called privileged because with great power comes great responsibility: modifying the TensorDB dataframe directly can lead to unexpected behavior and experiment failures if entries are arbitrarily deleted.

Note that in-place methods (`.loc`) on the tensor_db dataframe are required for write operations. 

In [None]:
from openfl.interface.aggregation_functions.experimental import PrivilegedAggregationFunction
import numpy as np
import pandas as pd

class PrioritizeLeastImproved(PrivilegedAggregationFunction):
    """
        Give the collaborator with the least improvement in validation accuracy
        more influence over future weights.

        Each time a collaborator is the least improved one, a ('least_improved',)
        row naming it is added to the TensorDB (at most one per round); its
        averaging weight is boosted by 0.1 per such row before renormalizing.
    """

    def call(self,
             local_tensors,
             tensor_db,
             tensor_name,
             fl_round,
             tags):
        """Aggregate tensors.

        Args:
            local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate.
            tensor_db: Aggregator's TensorDB [writable]. Columns:
                - 'tensor_name': name of the tensor.
                    Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'.
                - 'round': 0-based number of round corresponding to this tensor.
                - 'tags': tuple of tensor tags. Tags that can appear:
                    - 'model' indicates that the tensor is a model parameter.
                    - 'trained' indicates that tensor is a part of a training result.
                        These tensors are passed to the aggregator node after local learning.
                    - 'aggregated' indicates that tensor is a result of aggregation.
                        These tensors are sent to collaborators for the next round.
                    - 'delta' indicates that value is a difference between rounds
                        for a specific tensor.
                    also one of the tags is a collaborator name
                    if it corresponds to a result of a local task.

                - 'nparray': value of the tensor.
            tensor_name: name of the tensor
            fl_round: round number
            tags: tuple of tags for this tensor
        Returns:
            np.ndarray: aggregated tensor
        """
        from openfl.utilities import change_tags

        tensors = np.array([x.tensor for x in local_tensors])
        # float dtype so the fractional boost below works even for integer weights.
        weights = np.array([x.weight for x in local_tensors], dtype=float)
        collaborators = [x.col_name for x in local_tensors]

        if fl_round > 0:
            metric_tags = ('metric', 'validate_agg')

            def _accuracy(round_number, col_tags):
                """Return the 'acc' value stored with ``col_tags`` for ``round_number``."""
                rows = tensor_db[(tensor_db['tensor_name'] == 'acc')
                                 & (tensor_db['round'] == round_number)
                                 & (tensor_db['tags'] == col_tags)]
                # .iloc[0]: calling float() on a Series is deprecated in pandas.
                return float(rows['nparray'].iloc[0])

            change_in_accuracy = {}
            for col in collaborators:
                col_metric_tag = change_tags(metric_tags, add_field=col)
                change_in_accuracy[col] = (_accuracy(fl_round, col_metric_tag)
                                           - _accuracy(fl_round - 1, col_metric_tag))

            least_improved_collaborator = min(change_in_accuracy, key=change_in_accuracy.get)

            least_improved_rows = tensor_db[tensor_db['tags'] == ('least_improved',)]
            # Don't add the least improved collaborator more than once per round.
            if not (least_improved_rows['round'] == fl_round).any():
                tensor_db.loc[tensor_db.shape[0]] = \
                        ['_', '_', fl_round, True, ('least_improved',), np.array(least_improved_collaborator)]
                least_improved_rows = tensor_db[tensor_db['tags'] == ('least_improved',)]

            # Count only among 'least_improved' rows, so the name is never
            # compared against model weight arrays elsewhere in the DB.
            times_least_improved = sum(
                np.array(value).item() == least_improved_collaborator
                for value in least_improved_rows['nparray']
            )
            weights[collaborators.index(least_improved_collaborator)] += 0.1 * times_least_improved
            weights = weights / np.sum(weights)

        return np.average(tensors, weights=weights, axis=0)

To make the process of writing, reading from, and searching through dataframes easier, we add three methods to the tensor_db dataframe: `store`, `retrieve`, and `search`. Power users can still use all of the built-in pandas dataframe methods. However, some prior knowledge is needed to deal effectively with dataframe column types, to iterate through them, and to store them in a consistent way that won't break other OpenFL functionality. These three methods provide a convenient way to let researchers focus on algorithms instead of internal framework machinery.

In [None]:
class FedAvgM_Selection(PrivilegedAggregationFunction):
    """
        Federated averaging with server-side momentum (FedAvgM).

        Adapted from FeTS Challenge 2021
        Federated Brain Tumor Segmentation: Multi-Institutional Privacy-Preserving Collaborative Learning
        Ece Isik-Polat, Gorkem Polat, Altan Kocyigit, and Alptekin Temizel

        Momentum (0.9) and the aggregator learning rate (1.0) are stored in the
        TensorDB on first use; per-tensor velocities are stored under the
        ('weight_speeds',) tag.
    """

    @staticmethod
    def _fedavg(local_tensors):
        """Plain FedAvg: weighted average of the collaborators' tensors."""
        return np.average([t.tensor for t in local_tensors],
                          weights=[t.weight for t in local_tensors],
                          axis=0)

    @staticmethod
    def _ensure_weight_speed(tensor_db, tensor_name, like):
        """Store a zero velocity shaped like ``like`` unless one exists for ``tensor_name``."""
        stored_names = tensor_db.search(tags=('weight_speeds',))['tensor_name']
        # `x in series` checks the Series *index*; compare against the values.
        if tensor_name not in stored_names.tolist():
            tensor_db.store(
                tensor_name=tensor_name,
                tags=('weight_speeds',),
                nparray=np.zeros_like(like),
            )

    def call(
             self,
             local_tensors,
             tensor_db,
             tensor_name,
             fl_round,
             tags):

        """Aggregate tensors.

        Args:
            local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate.
            tensor_db: Aggregator's TensorDB [writable]. Columns:
                - 'tensor_name': name of the tensor.
                    Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'.
                - 'round': 0-based number of round corresponding to this tensor.
                - 'tags': tuple of tensor tags. Tags that can appear:
                    - 'model' indicates that the tensor is a model parameter.
                    - 'trained' indicates that tensor is a part of a training result.
                        These tensors are passed to the aggregator node after local learning.
                    - 'aggregated' indicates that tensor is a result of aggregation.
                        These tensors are sent to collaborators for the next round.
                    - 'delta' indicates that value is a difference between rounds
                        for a specific tensor.
                    also one of the tags is a collaborator name
                    if it corresponds to a result of a local task.

                - 'nparray': value of the tensor.
            tensor_name: name of the tensor
            fl_round: round number
            tags: tuple of tags for this tensor
        Returns:
            np.ndarray: aggregated tensor
        """
        # Hyper-parameters live in the TensorDB; overwrite=False keeps existing values.
        tensor_db.store(tensor_name='momentum', nparray=0.9, overwrite=False)
        tensor_db.store(tensor_name='aggregator_lr', nparray=1.0, overwrite=False)

        if fl_round == 0:
            # No previous aggregated model yet: plain FedAvg, start from zero velocity.
            self._ensure_weight_speed(tensor_db, tensor_name, local_tensors[0].tensor)
            return self._fedavg(local_tensors)

        if not tensor_name.endswith(('weight', 'bias')):
            # Momentum is only applied to weights and biases.
            return self._fedavg(local_tensors)

        # Aggregator's last value of this tensor (first matching record).
        previous_tensor_value = None
        for _, record in tensor_db.iterrows():
            if (record['round'] == fl_round
                    and record['tensor_name'] == tensor_name
                    and record['tags'] == ('aggregated',)):
                previous_tensor_value = record['nparray']
                break

        if previous_tensor_value is None:
            fx.logger.warning("Error in fedAvgM: previous_tensor_value is None")
            fx.logger.warning("Tensor: " + tensor_name)
            self._ensure_weight_speed(tensor_db, tensor_name, local_tensors[0].tensor)
            return self._fedavg(local_tensors)

        # Weighted average of (previous - local) deltas for this layer.
        deltas = [previous_tensor_value - t.tensor for t in local_tensors]
        average_deltas = np.average(deltas, weights=[t.weight for t in local_tensors], axis=0)

        tensor_weight_speed = tensor_db.retrieve(
            tensor_name=tensor_name,
            tags=('weight_speeds',)
        )
        if tensor_weight_speed is None:
            # NOTE(review): assumes retrieve() returns None when nothing is
            # stored; restart from zero velocity instead of failing on None.
            tensor_weight_speed = np.zeros_like(average_deltas)

        momentum = float(tensor_db.retrieve(tensor_name='momentum'))
        aggregator_lr = float(tensor_db.retrieve(tensor_name='aggregator_lr'))

        # V_(t+1) = momentum * V_t + Average_Delta_t
        new_tensor_weight_speed = momentum * tensor_weight_speed + average_deltas
        tensor_db.store(
            tensor_name=tensor_name,
            tags=('weight_speeds',),
            nparray=new_tensor_weight_speed
        )
        # W_(t+1) = W_t - lr * V_(t+1)
        return previous_tensor_value - aggregator_lr * new_tensor_weight_speed

In [None]:
# Run the experiment and return the trained FederatedModel.
# Override dictionary (keys as listed by fx.get_plan()):
#   - 'aggregator.settings.rounds_to_train': number of federated rounds to run.
#   - 'aggregator.settings.db_store_rounds': presumably how many rounds of history the
#     aggregator's TensorDB keeps (the smoothing-window size ExponentialSmoothingAveraging relies on).
#   - 'tasks.train.aggregation_type': aggregation function applied to trained tensors.
#     The other aggregation classes defined in this notebook are candidates here
#     (NOTE(review): ConditionalThresholdAveraging additionally needs a threshold_fn).
final_fl_model = fx.run_experiment(collaborators,
                                   {
                                       'aggregator.settings.rounds_to_train':5,
                                       'aggregator.settings.db_store_rounds':5,
                                       'tasks.train.aggregation_type': ClippedAveraging(ratio=0.9)
                                   })

In [None]:
# Save the final federated model via FederatedModel.save_native under the
# name 'final_pytorch_model'.
final_fl_model.save_native('final_pytorch_model')