# Secure Aggregation with MNIST

### Introduction

In Federated Learning (FL), Secure Aggregation (SecAgg) is a technique that allows participants to collaborate on a central model without revealing their individual contributions (local model updates). The goal is to let participants aggregate their model updates in a secure, privacy-preserving fashion.

SecAgg is therefore used to:
- Protect the updates sent by the clients from interception or manipulation by malicious actors.
- Increase trust in the federation by ensuring that client updates remain private and are not accessible to any other participant.
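
The core idea can be illustrated with pairwise masks that cancel when the updates are summed. The sketch below is a minimal illustration, assuming scalar updates and small integer seeds standing in for the keys each pair of clients agrees on; `pairwise_masked_updates` and the seed values are hypothetical and not part of OpenFL.

```python
import random

def pairwise_masked_updates(updates, pair_seeds):
    """Mask each client's scalar update with pairwise masks that cancel in the sum.

    updates: dict client_id -> float (the private update)
    pair_seeds: dict (i, j) with i < j -> integer seed shared by clients i and j
    """
    masked = {}
    for i, value in updates.items():
        y = value
        for j in updates:
            if i == j:
                continue
            a, b = min(i, j), max(i, j)
            # Both members of the pair derive the same mask from their shared seed.
            mask = random.Random(pair_seeds[(a, b)]).uniform(-1000, 1000)
            # The lower-indexed client adds the mask and the higher-indexed one subtracts it.
            y += mask if i < j else -mask
        masked[i] = y
    return masked

updates = {1: 0.25, 2: 0.40, 3: 0.10}
pair_seeds = {(1, 2): 11, (1, 3): 22, (2, 3): 33}
masked = pairwise_masked_updates(updates, pair_seeds)
print(masked)                 # individual values look random
print(sum(masked.values()))   # close to 0.75, the true sum (up to floating-point rounding)
```

Each masked value on its own looks random, yet the sum is unchanged because every mask is added by one client of the pair and subtracted by the other.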

In [1]:
#| default_exp experiment

### Installing Prerequisites
We start by installing OpenFL and the dependencies of the workflow interface. These dependencies are exported and become requirements for the Federated Learning environment.

In [2]:
#| export

!pip install git+https://github.com/securefederatedai/openfl.git
!pip install -r ../../workflow_interface_requirements.txt
!pip install pycryptodome
!pip install torch==2.4.1
!pip install torchvision==0.19.1
!pip install -U ipywidgets


### Model definition

We use the quintessential example of a PyTorch CNN trained on the MNIST dataset to demonstrate Secure Aggregation.

In [3]:
# | export

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Hyperparameters
learning_rate = 0.01
momentum = 0.5
batch_size = 32
log_interval = 10

# Model definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Explicit dim avoids PyTorch's implicit-dimension deprecation warning.
        return F.log_softmax(x, dim=1)


# Helper function to validate the model
def validate(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    accuracy = float(correct / len(test_loader.dataset))
    return accuracy


# Helper function to train the model
def _train_model(model, optimizer, data_loader):
    train_loss = 0
    model.train()
    for _, (X, y) in enumerate(data_loader):
        optimizer.zero_grad()
        output = model(X)
        loss = F.nll_loss(output, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * len(X)

    train_loss /= len(data_loader.dataset)
    return train_loss


# Helper function to initialize seed for reproducibility
def initialize_seed(random_seed=42):
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
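
The `320` used in `x.view(-1, 320)` and `nn.Linear(320, 50)` follows from the layer shapes. The sketch below reproduces that arithmetic in plain Python, assuming the standard output-size formula for unpadded convolutions and pooling; `conv_out` is a hypothetical helper, not part of the notebook.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2D convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                        # MNIST images are 28x28
size = conv_out(size, 5)         # conv1, kernel_size=5 -> 24
size = conv_out(size, 2, 2)      # max_pool2d(..., 2) -> 12
size = conv_out(size, 5)         # conv2, kernel_size=5 -> 8
size = conv_out(size, 2, 2)      # max_pool2d(..., 2) -> 4
flat_features = 20 * size * size # conv2 has 20 output channels
print(flat_features)             # 320
```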

### Dataset definition

We now download the MNIST training and test datasets, which are used to run the experiment with `LocalRuntime`.

In [4]:
#| export

import torchvision

# Train and Test datasets
mnist_train = torchvision.datasets.MNIST(
    "../files/",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    ),
)

mnist_test = torchvision.datasets.MNIST(
    "../files/",
    train=False,
    download=True,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    ),
)
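
`ToTensor` scales pixel values to [0, 1], and `Normalize((0.1307,), (0.3081,))` then standardizes them with the commonly used mean and standard deviation of the MNIST training pixels. A small plain-Python sketch of the resulting value range; `normalize` is a hypothetical helper:

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081   # values passed to Normalize above

def normalize(pixel):
    """Apply (x - mean) / std to a pixel already scaled to [0, 1] by ToTensor."""
    return (pixel - MNIST_MEAN) / MNIST_STD

print(round(normalize(0.0), 4))  # background pixel, about -0.4242
print(round(normalize(1.0), 4))  # fully inked pixel, about 2.8215
```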

### Secure Aggregation definitions

We will be using [Practical Secure Aggregation for Privacy-Preserving Machine Learning](https://eprint.iacr.org/2017/281.pdf) as a reference for the steps involved in secure aggregation.

We will use the secure aggregation utility functions built into OpenFL.
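
One building block from the paper is Shamir's secret sharing, which the flow below uses (via `create_secret_shares` and `reconstruct_secret`) to split each collaborator's private seed and private key into shares. The toy scheme below illustrates the idea, assuming integer secrets smaller than a fixed prime; `make_shares`, `recover`, and `PRIME` are hypothetical and not necessarily how OpenFL implements these utilities. Like the flow, it packs a float seed into 8 bytes with `struct.pack("d", ...)`.

```python
import random
import struct

PRIME = 2**127 - 1  # Mersenne prime used as the field size

def make_shares(secret, n, t):
    """Split an integer 0 <= secret < PRIME into n shares; any t of them recover it."""
    rng = random.SystemRandom()
    # Random polynomial of degree t - 1 whose constant term is the secret.
    coeffs = [secret] + [rng.randrange(PRIME) for _ in range(t - 1)]

    def poly(x):
        return sum(c * pow(x, k, PRIME) for k, c in enumerate(coeffs)) % PRIME

    # Share x is the point (x, poly(x)); x starts at 1 because poly(0) is the secret.
    return {x: poly(x) for x in range(1, n + 1)}

def recover(shares):
    """Lagrange interpolation at x = 0 over the field of integers mod PRIME."""
    secret = 0
    for xi, yi in shares.items():
        num, den = 1, 1
        for xj in shares:
            if xj != xi:
                num = num * (-xj) % PRIME
                den = den * (xi - xj) % PRIME
        # pow(den, -1, PRIME) is the modular inverse (Python 3.8+).
        secret = (secret + yi * num * pow(den, -1, PRIME)) % PRIME
    return secret

seed = 0.6394
secret = int.from_bytes(struct.pack("d", seed), "big")  # 8-byte float as an integer
shares = make_shares(secret, n=3, t=2)
subset = {1: shares[1], 3: shares[3]}                   # any 2 of the 3 shares
restored = struct.unpack("d", recover(subset).to_bytes(8, "big"))[0]
print(recover(shares) == secret, restored == seed)     # True True
```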

In [5]:
#| export

from openfl.utilities.secagg import (
    create_secret_shares,
    reconstruct_secret,
    create_ciphertext,
    decipher_ciphertext,
    generate_agreed_key,
    calculate_mask
)



def masked_train_model(
    model, optimizer, data_loader,
    collaborator_index, agreed_keys, private_seed,
):
    """
    Helper function to perform training of the model giving masked input
    vector instead of loss directly.
    """
    loss = _train_model(model, optimizer, data_loader)
    collaborator_mask = calculate_mask(collaborator_index, agreed_keys, private_seed)

    return loss + collaborator_mask
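
`masked_train_model` returns `loss + collaborator_mask`, and the `join` step later subtracts a mask recomputed with `calculate_mask(...)` from the reconstructed seed. The sketch below shows why that works, assuming the mask is a deterministic function of its inputs; `toy_mask` is hypothetical and is not the formula used by `calculate_mask`.

```python
import random

def toy_mask(seed, scale=1000.0):
    """Hypothetical mask: a pseudo-random float derived deterministically from a seed."""
    return random.Random(seed).uniform(-scale, scale)

private_seed = 0.6394267984578837   # a float, like the notebook's random.random() seed
true_loss = 0.3125

masked_loss = true_loss + toy_mask(private_seed)   # what the collaborator shares
# Whoever holds (or has reconstructed) the seed regenerates the same mask and removes it.
recovered = masked_loss - toy_mask(private_seed)
print(abs(recovered - true_loss) < 1e-9)  # True
```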

### Workflow definition

Next, we import `FLSpec` and the placement decorators (`aggregator`/`collaborator`), and define the `FedAvg` helper function.

In [6]:
# | export

from copy import deepcopy

from openfl.experimental.workflow.interface import FLSpec
from openfl.experimental.workflow.placement import aggregator, collaborator

# Helper function for federated averaging
def FedAvg(agg_model, models, weights=None):
    state_dicts = [model.state_dict() for model in models]
    agg_state_dict = agg_model.state_dict()
    for key in models[0].state_dict():
        agg_state_dict[key] = torch.from_numpy(
            np.average([state[key].numpy() for state in state_dicts], axis=0, weights=weights)
        )

    agg_model.load_state_dict(agg_state_dict)
    return agg_model
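
`FedAvg` averages every entry of the collaborators' state dicts element-wise; with `weights=None`, each collaborator counts equally. The plain-Python sketch below mirrors that behavior on toy lists and shows how per-collaborator weights (for example, hypothetical local sample counts) would change the result. `fedavg_lists` is a hypothetical helper, not part of OpenFL.

```python
def fedavg_lists(state_dicts, weights=None):
    """Average parameter lists key by key, like np.average(..., weights=weights)."""
    if weights is None:
        weights = [1.0] * len(state_dicts)
    total = float(sum(weights))
    averaged = {}
    for key in state_dicts[0]:
        # Group the i-th element of this parameter across all collaborators.
        columns = zip(*(sd[key] for sd in state_dicts))
        averaged[key] = [sum(w * v for w, v in zip(weights, col)) / total for col in columns]
    return averaged

bengaluru = {"fc2.bias": [0.0, 2.0]}
portland = {"fc2.bias": [1.0, 4.0]}
print(fedavg_lists([bengaluru, portland]))                         # {'fc2.bias': [0.5, 3.0]}
print(fedavg_lists([bengaluru, portland], weights=[30000, 10000])) # {'fc2.bias': [0.25, 2.5]}
```

In this notebook both collaborators receive equal-size shards, so equal weights give the same result as weighting by sample count.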

Let us now define the Workflow.

![SecAggFlow](mermaid-flow.png)

In [None]:
# | export
import struct

from Crypto.PublicKey import ECC


class SecureAggregation_MNIST(FLSpec):
    """
    Federated Flow to train a CNN on MNIST dataset using Secure Aggregation
    """

    def __init__(self, model=None, optimizer=None, learning_rate=1e-2, momentum=0.5, rounds=3, **kwargs):
        super().__init__(**kwargs)

        if model is not None:
            self.model = model
            self.optimizer = optimizer
        else:
            initialize_seed()
            self.model = Net()
            self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate, momentum=momentum)

        self.learning_rate = learning_rate
        self.momentum = momentum
        self.rounds = rounds
        self.results = []
        self.collaborator_secagg_data = {}

    @aggregator
    def start(self):
        """
        This is the start of the Flow.
        """
        print(f"Initializing Workflow .... ")

        self.collaborators = self.runtime.collaborators
        self.collaborator_count = len(self.collaborators)
        self.current_round = 0

        self.next(self.generate_keys, foreach="collaborators")

    @collaborator
    def generate_keys(self):
        """
        Receives the collaborator IDs for all the collaborators.
        Shares the public keys of the key pairs generated as private
        attributes with the aggregator.
        """
        print(f"<Collab: {self.input}> generating key pairs for secure aggregation...")
        # Derive the public keys used for secure aggregation.
        # The private attribute self.private_key is a list of 2 PEM-encoded
        # ECC private keys (strings), so each is imported before deriving its public key.
        self.public_key = [
            ECC.import_key(pvt_key).public_key().export_key(format="PEM")
            for pvt_key in self.private_key
        ]

        self.next(self.generate_collaborator_index)

    @aggregator
    def generate_collaborator_index(self, inputs):
        """
        Shares the collaborators' public keys with all the collaborators
        in the experiment.
        """
        print(f"<Agg>: Generating unique indices for collaborators...")
        index = 1
        # This step is needed as the aggregator also needs to have the
        # public keys for all collaborators.
        for input in inputs:
            self.collaborator_secagg_data[input.input] = {
                # Set IDs for the collaborators.
                "index": index,
                "public_key": input.public_key
            }
            index += 1

        self.next(self.generate_ciphertexts, foreach="collaborators")

    @collaborator
    def generate_ciphertexts(self):
        """
        Generates the ciphertexts for all the other collaborators which
        includes source collaborator ID, recipient collaborator ID, source
        collaborator's private seed share for recipient collaborator, source
        collaborator's private key share for recipient collaborator.
        """
        print(f"<Collab: {self.input}> generating ciphertexts for secure aggregation...")
        # Find the index of the current collaborator.
        for collab in self.collaborator_secagg_data:
            index = self.collaborator_secagg_data[collab]["index"]
            if self.collaborator_secagg_data[collab]["public_key"][0] == self.public_key[0]:
                self.index = index
                break

        self.agreed_key = {}
        self.cipher_verification = {}
        self.ciphers = {}

        # Using Shamir's secret sharing.
        seed_shares = create_secret_shares(
            # Using private attribute private_seed.
            # Converts the floating-point number private_seed into an 8-byte
            # binary representation.
            struct.pack("d", self.private_seed),
            self.collaborator_count,
            self.collaborator_count,
        )

        key_shares = create_secret_shares(
            str.encode(ECC.import_key(self.private_key[0]).export_key(format="PEM")),
            self.collaborator_count,
            self.collaborator_count,
        )

        for collab in self.collaborator_secagg_data:
            self.collaborator_secagg_data[collab]["ciphertext_from"] = {}
            collab_index = self.collaborator_secagg_data[collab]["index"]
            # Use the private attribute (private_key) and public keys
            # shared in `generate_keys` to generate agreed keys.
            self.agreed_key[collab_index] = [
                generate_agreed_key(
                    self.private_key[0],
                    self.collaborator_secagg_data[collab]["public_key"][0]
                ),
                generate_agreed_key(
                    self.private_key[1],
                    self.collaborator_secagg_data[collab]["public_key"][1]
                )
            ]
            # Generate ciphertext for the collaborator.
            ciphertext, mac, nonce = create_ciphertext(
                self.agreed_key[collab_index][0],   # agreed key
                self.index,                         # source collaborator index
                collab_index,                       # destination collaborator index
                seed_shares[collab_index],          # seed share from source to dest
                key_shares[collab_index]            # key share from source to dest
            )
            # Adding the ciphertext from self.index to collab_index to class
            # attribute collaborator_secagg_data such that it can be accessed by other
            # collaborators.
            self.collaborator_secagg_data[collab]["ciphertext_from"][self.index] = [ciphertext, mac, nonce]

        self.next(self.filter_ciphertexts)

    @aggregator
    def filter_ciphertexts(self, inputs):
        """
        Receives ciphertexts from all collaborators and filters them for each collaborator
        that they are addressed to.
        """
        for input in inputs:
            for collab in input.collaborator_secagg_data:
                if "ciphertext_from" not in self.collaborator_secagg_data[collab]:
                    self.collaborator_secagg_data[collab]["ciphertext_from"] = input.collaborator_secagg_data[collab]["ciphertext_from"]
                else:
                    self.collaborator_secagg_data[collab]["ciphertext_from"].update(
                        input.collaborator_secagg_data[collab]["ciphertext_from"]
                    )

        self.next(self.decrypt_ciphertext, foreach="collaborators")

    @collaborator
    def decrypt_ciphertext(self):
        """
        Receives the ciphertexts addressed to it and deciphers them.
        Shares the deciphered values with the aggregator.
        """
        print(f"<Collab: {self.input}> decrypting ciphertexts for secure aggregation... ")
        self.seed_shares = {}
        self.key_shares = {}
        for collab in self.collaborator_secagg_data:
            if self.collaborator_secagg_data[collab]["index"] == self.index:
                addressed_ciphertexts = self.collaborator_secagg_data[collab].get("ciphertext_from", {})
                for source_id in addressed_ciphertexts:
                    source_public_key = [
                        self.collaborator_secagg_data[collab]["public_key"][0] 
                        if self.collaborator_secagg_data[collab]["index"] == source_id 
                        else None
                        for collab in self.collaborator_secagg_data
                    ]
                    source_public_key = list(filter(None, source_public_key))[0]
                    _, _, seed_share, key_share = decipher_ciphertext(
                        generate_agreed_key(
                            self.private_key[0],
                            source_public_key
                        ),                                          # agreed_key_1
                        addressed_ciphertexts[source_id][0],        # ciphertext
                        addressed_ciphertexts[source_id][1],        # mac
                        addressed_ciphertexts[source_id][2],        # nonce
                    )
                    self.seed_shares[source_id] = [self.index, str(seed_share)]
                    self.key_shares[source_id] = [self.index, str(key_share)]

        self.next(
            self.reconstruct,
            exclude=["agreed_key", "cipher_verification"]
        )

    @aggregator
    def reconstruct(self, inputs):
        """
        Reconstructs the secrets for all the collaborators
        which are required for unmasking during aggregation.
        """
        print(f"<Agg>: Reconstructing secrets for secure aggregation...")
        seed_shares = {}
        key_shares = {}
        for input in inputs:
            # Create a dictionary of shares for each seed.
            for source_id in input.seed_shares:
                if source_id not in seed_shares:
                    seed_shares[source_id] = {}
                seed_shares[source_id][input.seed_shares[source_id][0]] = input.seed_shares[source_id][1][2:-1]
            # Create a dictionary of shares for each key.
            for source_id in input.key_shares:
                if source_id not in key_shares:
                    key_shares[source_id] = {}
                key_shares[source_id][input.key_shares[source_id][0]] = input.key_shares[source_id][1][2:-1]

        # Reconstruct the secrets (seeds and private keys) for all source
        # collaborators.
        self._private_seeds = {}
        self._private_keys = {}
        for source_id in seed_shares:
            self._private_seeds[source_id] = struct.unpack("d", reconstruct_secret(seed_shares[source_id]))[0]
            self._private_keys[source_id] = reconstruct_secret(key_shares[source_id])

        # Generate all agreed keys.
        self.agreed_keys = {}
        for source in self.collaborator_secagg_data:
            source_index = self.collaborator_secagg_data[source]["index"]
            source_private_key = self._private_keys[source_index]
            self.agreed_keys[source_index] = {}
            for dest in self.collaborator_secagg_data:
                dest_index = self.collaborator_secagg_data[dest]["index"]
                dest_public_key = self.collaborator_secagg_data[dest]["public_key"][0]
                self.agreed_keys[source_index][dest_index] = [dest, generate_agreed_key(
                    source_private_key,
                    dest_public_key
                )]

        self.next(self.aggregated_model_validation, foreach="collaborators")

    @collaborator
    def aggregated_model_validation(self):
        """
        Perform validation of aggregated model on collaborators.
        """
        print(f"<Collab: {self.input}> Performing Validation on aggregated model ... ")
        self.agg_validation_score = validate(self.model, self.test_loader)
        print(
            f"<Collab: {self.input}> Aggregated Model validation score = {self.agg_validation_score:.4f}"
        )

        self.next(self.train)

    @collaborator
    def train(self):
        """
        Train the model on the collaborator's local dataset.
        """
        print(f"<Collab: {self.input}>: Training Model on local dataset ... ")

        self.optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=self.momentum)

        self.loss = masked_train_model(
            self.model,
            self.optimizer,
            self.train_loader,
            self.index,
            self.agreed_keys[self.index],
            self.private_seed,
        )

        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        """
        Validate locally trained model.
        """
        print(f"<Collab: {self.input}> Performing Validation on locally trained model ... ")
        self.local_validation_score = validate(self.model, self.test_loader)
        print(
            f"<Collab: {self.input}> Local model validation score = {self.local_validation_score:.4f}"
        )
        self.next(self.join)

    @aggregator
    def join(self, inputs):
        """
        Model aggregation step.
        """
        print(f"<Agg>: Joining models from collaborators...")

        # Average Training loss, aggregated and locally trained model accuracy
        total_loss = sum(input.loss for input in inputs)
        # Calculate and remove the masks from the total loss.
        for input in inputs:
            total_loss -= calculate_mask(input.index, self.agreed_keys[input.index], self._private_seeds[input.index])

        self.average_loss = total_loss / len(inputs)
        self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs) / len(inputs)
        self.local_model_accuracy = sum(input.local_validation_score for input in inputs) / len(inputs)

        print(f"Avg. aggregated model validation score = {self.aggregated_model_accuracy:.4f}")
        print(f"Avg. training loss = {self.average_loss:.4f}")
        print(f"Avg. local model validation score = {self.local_model_accuracy:.4f}")

        # Only the training loss is masked in this example (its masks were removed above);
        # the model weights are averaged directly with FedAvg.
        self.model = FedAvg(self.model, [input.model for input in inputs])

        self.results.append(
            [
                self.current_round,
                self.aggregated_model_accuracy,
                self.average_loss,
                self.local_model_accuracy,
            ]
        )

        self.current_round += 1
        if self.current_round < self.rounds:
            self.next(self.aggregated_model_validation, foreach="collaborators")
        else:
            self.next(self.end)

    @aggregator
    def end(self):
        """
        This is the last step in the Flow.
        """
        print(f"This is the end of the flow")

### Simulation: LocalRuntime

We now import and define the `LocalRuntime` and the participants (`Aggregator`/`Collaborator`), and initialize the private attributes for the participants.

- `Runtime` – Defines where the flow runs. `LocalRuntime` simulates the flow on the local node.
- `Aggregator`/`Collaborator` – (Local) participants in the simulation


In [None]:
# | export

from openfl.experimental.workflow.interface import Aggregator, Collaborator
from openfl.experimental.workflow.runtime import LocalRuntime

# Setup Aggregator & initialize private attributes
aggregator = Aggregator()
aggregator.private_attributes = {}

# Setup Collaborators & initialize shards of MNIST dataset as private attributes
n_collaborators = 2
collaborator_names = ["Bengaluru", "Portland"]

collaborators = [Collaborator(name=name) for name in collaborator_names]
for idx, collaborator in enumerate(collaborators):
    local_train = deepcopy(mnist_train)
    local_test = deepcopy(mnist_test)
    local_train.data = mnist_train.data[idx::n_collaborators]
    local_train.targets = mnist_train.targets[idx::n_collaborators]
    local_test.data = mnist_test.data[idx::n_collaborators]
    local_test.targets = mnist_test.targets[idx::n_collaborators]

    collaborator.private_attributes = {
        "train_loader": torch.utils.data.DataLoader(
            local_train, batch_size=batch_size, shuffle=False
        ),
        "test_loader": torch.utils.data.DataLoader(
            local_test, batch_size=batch_size, shuffle=False
        ),
        "private_seed": random.random(),
        "private_key": [
            ECC.generate(curve="ed25519").export_key(format="PEM"),
            ECC.generate(curve="ed25519").export_key(format="PEM"),
        ],
    }

local_runtime = LocalRuntime(
    aggregator=aggregator, collaborators=collaborators, backend="single_process"
)
print(f"Local runtime collaborators = {local_runtime.collaborators}")

### Start Simulation

Now that we have our flow and runtime defined, let's run the simulation!

In [None]:
#| export

model = None
optimizer = None
flflow = SecureAggregation_MNIST(model, optimizer, learning_rate, momentum, rounds=2, checkpoint=True)
flflow.runtime = local_runtime
flflow.run()

Let us check the simulation results.

In [None]:
from tabulate import tabulate 

headers = ["Rounds", "Agg Model Validation Score", "Local Train loss", "Local Model Validation score"]
print('********** Simulation results **********')
simulation_results = flflow.results
print(tabulate(simulation_results, headers=headers, tablefmt="outline"))

### Setup Federation: Director & Envoys

Before we can deploy the experiment, let us create the participants in the federation: the Director and the Envoys. As this tutorial uses two collaborators, we launch three participants:
1. Director: The central node in the federation
2. Bengaluru: The first envoy in the federation
3. Portland: The second envoy in the federation

The participants can be launched by following the steps in the [README](https://github.com/securefederatedai/openfl/blob/develop/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/README.md).


### Deploy: FederatedRuntime

We now import and instantiate `FederatedRuntime` to deploy the experiment on distributed infrastructure. Initializing `FederatedRuntime` requires the user to provide the following inputs:

- `director` – director information (the `director_info` dictionary below), including the FQDN of the director node, its port, and certificate information
- `collaborators` – names of the collaborators participating in the experiment
- `notebook_path` – path to this Jupyter notebook


In [12]:
#| export

from openfl.experimental.workflow.runtime import FederatedRuntime

director_info = {
    'director_node_fqdn': 'localhost',
    'director_port': 50050,
}

federated_runtime = FederatedRuntime(
    collaborators=collaborator_names,
    director=director_info, 
    notebook_path='./MNIST_SecAgg.ipynb'
)

Let us connect to the federation and check whether the envoys are connected to the director using the `get_envoys` method of `FederatedRuntime`. If the participants were launched successfully in the previous step, the status of `Bengaluru` and `Portland` should be displayed as `Online`.

In [None]:
federated_runtime.get_envoys()

Now that our distributed infrastructure is ready, let us set the flow's runtime to the `FederatedRuntime` instance and deploy the experiment.

The progress of the flow is available in:
1. The Jupyter notebook, if the `checkpoint` attribute of the flow object is set to `True`
2. The Director and Envoy terminals


In [None]:
flflow.results = [] # clear results from previous run
flflow.runtime = federated_runtime
flflow.run()

Let us compare the simulation results from `LocalRuntime` with the federation results from `FederatedRuntime`.

In [None]:
headers = ["Rounds", "Agg Model Validation Score", "Local Train loss", "Local Model Validation score"]
print('********** Simulation results **********')
print(tabulate(simulation_results, headers=headers, tablefmt="outline"))

print('********** Federation results **********')
federation_results = flflow.results
print(tabulate(federation_results, headers=headers, tablefmt="outline"))
