# Federated Runtime: 301_MNIST_Watermarking

This tutorial is based on the LocalRuntime example [301_MNIST_Watermarking](https://github.com/securefederatedai/openfl/blob/develop/openfl-tutorials/experimental/workflow/301_MNIST_Watermarking.ipynb). It adapts that example to the FederatedRuntime and walks through deploying the watermarking workflow within a federation, showing how to move from a local setup to a federated environment. Before deploying the experiment, follow the steps in [README.md](https://github.com/securefederatedai/openfl/blob/develop/openfl-tutorials/experimental/workflow/FederatedRuntime/301_MNIST_Watermarking/README.md) to make sure the participants in the federation are launched.


# Getting Started

We start by specifying the module to which cells marked with the `#| export` directive will be automatically exported.

In the following cell, `#| default_exp experiment` indicates that the exported file will be named 'experiment'. The name can be changed to suit your requirements and preferences.

In [1]:
#| default_exp experiment

Once the module name is specified, every subsequent cell that belongs to the experiment must be marked with the `#| export` directive, as the cells below show. Make sure that *all* the notebook functionality required by the Federated Learning experiment is included under this directive.
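
To make the effect of the directive concrete, here is a minimal sketch of how export filtering could work. It is an illustration only, not the export tooling's actual implementation; `collect_exported_cells` and the sample `cells` list are hypothetical.

```python
import re

# Matches export directives written as "#| export" or "# | export".
EXPORT_DIRECTIVE = re.compile(r"^#\s*\|\s*export\s*$")


def collect_exported_cells(cells):
    """Return the body of every cell whose first line is an export directive."""
    exported = []
    for source in cells:
        # Split the first line (the candidate directive) from the rest of the cell.
        first, _, rest = source.lstrip("\n").partition("\n")
        if EXPORT_DIRECTIVE.match(first.strip()):
            exported.append(rest.strip("\n"))
    return exported


# Hypothetical notebook: only the second and fourth cells carry the directive.
cells = [
    "#| default_exp experiment",
    "#| export\nimport torch",
    "federated_runtime.get_envoys()",
    "# | export\nx = 1",
]
exported = collect_exported_cells(cells)
print(exported)
```

Cells without the directive, such as an interactive status check, stay in the notebook and are left out of the exported module.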

We start by installing OpenFL and the dependencies of the workflow interface.
> These dependencies must be exported because they become the requirements of the Federated Learning workspace.

In [2]:
#| export

!pip install git+https://github.com/securefederatedai/openfl.git
!pip install -r ../../../workflow_interface_requirements.txt
!pip install matplotlib
!pip install torch==2.3.1
!pip install torchvision==0.18.1
!pip install git+https://github.com/pyviz-topics/imagen.git@master
!pip install holoviews==1.15.4
!pip install -U ipywidgets

We now define our model, optimizer, and some helper functions, as we would for any other deep learning experiment.

> This cell and the subsequent cells that define the experiment are essential parts of the Federated Learning experiment and are therefore annotated with the `#| export` directive.

In [3]:
#| export

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import numpy as np

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

class Net(nn.Module):
    def __init__(self, dropout=0.0):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(1, 32, 2),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(32, 64, 2),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 2),
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(128 * 5**2, 200)
        self.fc2 = nn.Linear(200, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.dropout(x)
        out = self.block(x)
        out = out.view(-1, 128 * 5**2)
        out = self.dropout(out)
        out = self.relu(self.fc1(out))
        out = self.dropout(out)
        out = self.fc2(out)
        return F.log_softmax(out, 1)


def inference(network, test_loader):
    network.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    accuracy = float(correct / len(test_loader.dataset))
    return accuracy


def train_model(model, optimizer, data_loader, entity, round_number, log=False):
    # Helper function to train the model
    train_loss = 0
    log_interval = 20
    model.train()
    for batch_idx, (X, y) in enumerate(data_loader):
        optimizer.zero_grad()

        output = model(X)
        loss = F.nll_loss(output, y)
        loss.backward()

        optimizer.step()

        train_loss += loss.item() * len(X)
        if batch_idx % log_interval == 0 and log:
            print("{:<20} Train Epoch: {:<3} [{:<3}/{:<4} ({:<.0f}%)] Loss: {:<.6f}".format(
                    entity,
                    round_number,
                    batch_idx * len(X),
                    len(data_loader.dataset),
                    100.0 * batch_idx / len(data_loader),
                    loss.item(),
                )
            )
    train_loss /= len(data_loader.dataset)
    return train_loss
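
The `128 * 5**2` input size of `fc1` in `Net` follows from the spatial size of the feature map after the convolution and pooling layers. The sketch below reproduces that arithmetic, assuming 28x28 MNIST inputs and the PyTorch defaults used in `Net` (stride 1 and no padding for `Conv2d`, stride equal to the kernel size for `MaxPool2d`). The helper names are hypothetical.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output height/width of a square convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def maxpool2d_out(size, kernel):
    """Output height/width of max pooling with stride equal to the kernel size."""
    return (size - kernel) // kernel + 1


size = 28                      # MNIST images are 28x28
size = conv2d_out(size, 2)     # Conv2d(1, 32, 2)   -> 27
size = maxpool2d_out(size, 2)  # MaxPool2d(2)       -> 13
size = conv2d_out(size, 2)     # Conv2d(32, 64, 2)  -> 12
size = maxpool2d_out(size, 2)  # MaxPool2d(2)       -> 6
size = conv2d_out(size, 2)     # Conv2d(64, 128, 2) -> 5

flattened = 128 * size ** 2    # 128 channels of 5x5 feature maps
print(size, flattened)         # 5 3200
```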

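Next, we import `FLSpec` and the placement decorators (`aggregator` and `collaborator`), and define a `FedAvg` helper that averages the model weights received from the collaborators.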

In [4]:
#| export

from openfl.experimental.workflow.interface import FLSpec
from openfl.experimental.workflow.placement import aggregator, collaborator

def FedAvg(agg_model, models, weights=None):
    # Average each parameter tensor across the given models
    # (optionally weighted) and load the result into agg_model
    state_dicts = [model.state_dict() for model in models]
    state_dict = agg_model.state_dict()
    for key in state_dicts[0]:
        state_dict[key] = torch.from_numpy(np.average([state[key].numpy() for state in state_dicts],
                                                      axis=0,
                                                      weights=weights))
        
    agg_model.load_state_dict(state_dict)
    return agg_model
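
`FedAvg` delegates the arithmetic to `np.average`, which divides the weighted sum of the inputs by the sum of the weights. The plain-Python sketch below mirrors that computation on two tiny, made-up parameter vectors; `weighted_average` is a hypothetical stand-in, and the values are illustrative only.

```python
def weighted_average(values, weights=None):
    """Average equally sized flat parameter lists element by element."""
    if weights is None:
        # No weights given: every contributor counts equally.
        weights = [1.0] * len(values)
    total = sum(weights)
    return [
        sum(w * v[i] for w, v in zip(weights, values)) / total
        for i in range(len(values[0]))
    ]


# Made-up parameters from two collaborators.
bangalore = [1.0, 2.0]
chandler = [3.0, 6.0]

uniform = weighted_average([bangalore, chandler])
weighted = weighted_average([bangalore, chandler], weights=[3, 1])
print(uniform)   # [2.0, 4.0]
print(weighted)  # [1.5, 3.0]
```

In this tutorial the `join` step calls `FedAvg` without `weights`, so each collaborator's model contributes equally, as in the `uniform` case.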

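Let us now define the workflow for watermark embedding. The flow reads several attributes that this notebook does not define, such as `train_loader`, `test_loader`, `watermark_data_loader`, `pretrain_epochs`, `retrain_epochs`, and `watermark_acc_threshold`; these are expected to be supplied as private attributes by the participants in the federation.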

In [None]:
#| export

class FederatedFlow_MNIST_Watermarking(FLSpec):
    """
    This Flow demonstrates Watermarking on a Deep Learning Model in Federated Learning
    Ref: WAFFLE: Watermarking in Federated Learning (https://arxiv.org/abs/2008.07298)
    """

    def __init__(
        self,
        model=None,
        optimizer=None,
        watermark_pretrain_optimizer=None,
        watermark_retrain_optimizer=None,
        round_number=0,
        n_rounds=3,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if model is not None:
            self.model = model
            self.optimizer = optimizer
            self.watermark_pretrain_optimizer = watermark_pretrain_optimizer
            self.watermark_retrain_optimizer = watermark_retrain_optimizer
        else:
            self.model = Net()
            self.optimizer = optim.SGD(
                self.model.parameters(), lr=learning_rate, momentum=momentum
            )
            self.watermark_pretrain_optimizer = optim.SGD(
                self.model.parameters(),
                lr=watermark_pretrain_learning_rate,
                momentum=watermark_pretrain_momentum,
                weight_decay=watermark_pretrain_weight_decay,
            )
            self.watermark_retrain_optimizer = optim.SGD(
                self.model.parameters(), lr=watermark_retrain_learning_rate
            )
        self.round_number = round_number
        self.n_rounds = n_rounds
        self.watermark_pretraining_completed = False

    @aggregator
    def start(self):
        """
        This is the start of the Flow.
        """
        print("<Agg>: Start of flow ... ")
        self.collaborators = self.runtime.collaborators

        self.next(self.watermark_pretrain)

    @aggregator
    def watermark_pretrain(self):
        """
        Pre-Train the Model before starting Federated Learning.
        """
        if not self.watermark_pretraining_completed:

            print("<Agg>: Performing Watermark Pre-training")

            for i in range(self.pretrain_epochs):

                watermark_pretrain_loss = train_model(
                    self.model,
                    self.watermark_pretrain_optimizer,
                    self.watermark_data_loader,
                    "<Agg>:",
                    i,
                    log=False,
                )
                watermark_pretrain_validation_score = inference(
                    self.model, self.watermark_data_loader
                )

                print(f"<Agg>: Watermark Pretraining: Round: {i:<3}"
                      + f" Loss: {watermark_pretrain_loss:<.6f}"
                      + f" Acc: {watermark_pretrain_validation_score:<.6f}")

            self.watermark_pretraining_completed = True

        self.next(
            self.aggregated_model_validation,
            foreach="collaborators",
        )

    @collaborator
    def aggregated_model_validation(self):
        """
        Perform Aggregated Model validation on Collaborators.
        """
        self.agg_validation_score = inference(self.model, self.test_loader)
        print(f"<Collab: {self.input}>"
              + f" Aggregated Model validation score = {self.agg_validation_score}"
              )

        self.next(self.train)

    @collaborator
    def train(self):
        """
        Train model on Local collab dataset.
        """
        print("<Collab>: Performing Model Training on Local dataset ... ")

        self.optimizer = optim.SGD(
            self.model.parameters(), lr=learning_rate, momentum=momentum
        )

        self.loss = train_model(
            self.model,
            self.optimizer,
            self.train_loader,
            f"<Collab: {self.input}>",
            self.round_number,
            log=True,
        )

        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        """
        Validate locally trained model.
        """
        self.local_validation_score = inference(self.model, self.test_loader)
        print(
            f"<Collab: {self.input}> Local model validation score = {self.local_validation_score}"
        )
        self.next(self.join)

    @aggregator
    def join(self, inputs):
        """
        Model aggregation step.
        """
        self.average_loss = sum(input.loss for input in inputs) / len(inputs)
        self.aggregated_model_accuracy = sum(
            input.agg_validation_score for input in inputs
        ) / len(inputs)
        self.local_model_accuracy = sum(
            input.local_validation_score for input in inputs
        ) / len(inputs)

        print("<Agg>: Joining models from collaborators...")

        print(
            f"   Aggregated model validation score = {self.aggregated_model_accuracy}"
        )
        print(f"   Average training loss = {self.average_loss}")
        print(f"   Average local model validation values = {self.local_model_accuracy}")

        self.model = FedAvg(self.model, [input.model for input in inputs])

        self.next(self.watermark_retrain)

    @aggregator
    def watermark_retrain(self):
        """
        Retrain the aggregated model.
        """
        print("<Agg>: Performing Watermark Retraining ... ")
        self.watermark_retrain_optimizer = optim.SGD(
            self.model.parameters(), lr=watermark_retrain_learning_rate
        )

        retrain_round = 0

        # Perform re-training until (accuracy >= watermark_acc_threshold) or
        # (retrain_round >= retrain_epochs)
        self.watermark_retrain_validation_score = inference(
            self.model, self.watermark_data_loader
        )
        while (
            self.watermark_retrain_validation_score < self.watermark_acc_threshold
        ) and (retrain_round < self.retrain_epochs):
            self.watermark_retrain_train_loss = train_model(
                self.model,
                self.watermark_retrain_optimizer,
                self.watermark_data_loader,
                "<Agg>",
                retrain_round,
                log=False,
            )
            self.watermark_retrain_validation_score = inference(
                self.model, self.watermark_data_loader
            )

            print(f"<Agg>: Watermark Retraining: Train Epoch: {self.round_number:<3}"
                  + f" Retrain Round: {retrain_round:<3}"
                  + f" Loss: {self.watermark_retrain_train_loss:<.6f},"
                  + f" Acc: {self.watermark_retrain_validation_score:<.6f}")
            retrain_round += 1

        self.next(self.internal_loop)
    
    @aggregator
    def internal_loop(self):
        """
        Internal loop to continue the Federated Learning process.
        """
        if self.round_number == self.n_rounds - 1:
            print(f"\nCompleted training for all {self.n_rounds} round(s)")
            self.next(self.end)
        else:
            self.round_number += 1
            print(f"\nCompleted round: {self.round_number}")
            self.next(self.aggregated_model_validation, foreach='collaborators')

    @aggregator
    def end(self):
        """
        This is the last step in the Flow.
        """
        print("This is the end of the flow")
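
The order in which the steps of `FederatedFlow_MNIST_Watermarking` execute can be hard to read off the `self.next(...)` calls. The following sketch traces that order in plain Python, without OpenFL. It is a simplified model: `trace_flow` is a hypothetical helper, and collaborators are listed one after another even though in a federation they run independently of each other.

```python
COLLABORATOR_STEPS = ("aggregated_model_validation", "train", "local_model_validation")
AGGREGATOR_STEPS = ("join", "watermark_retrain", "internal_loop")


def trace_flow(collaborators, n_rounds):
    """Return the ordered (participant, step) pairs for the whole flow."""
    # Pre-training happens once, before the first federated round.
    trace = [("aggregator", "start"), ("aggregator", "watermark_pretrain")]
    round_number = 0
    while True:
        # foreach="collaborators": every collaborator runs its three steps.
        for name in collaborators:
            for step in COLLABORATOR_STEPS:
                trace.append((name, step))
        # Control returns to the aggregator for aggregation and retraining.
        for step in AGGREGATOR_STEPS:
            trace.append(("aggregator", step))
        # Same exit test as internal_loop in the flow.
        if round_number == n_rounds - 1:
            trace.append(("aggregator", "end"))
            return trace
        round_number += 1


trace = trace_flow(["Bangalore", "Chandler"], n_rounds=3)
for participant, step in trace:
    print(f"{participant:<12} {step}")
```

With two collaborators and three rounds, the watermark is pre-trained once and retrained after every aggregation, which matches the loop from `internal_loop` back to `aggregated_model_validation`.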

## Defining and Initializing the Federated Runtime
We initialize the Federated Runtime by providing:
- `director_info`: the director's connection information
- `authorized_collaborators`: a list of authorized collaborators
- `notebook_path`: the path to this Jupyter notebook

In [6]:
#| export

from openfl.experimental.workflow.runtime import FederatedRuntime

director_info = {
    'director_node_fqdn':'localhost',
    'director_port':50050,
}

authorized_collaborators = ['Bangalore', 'Chandler']

federated_runtime = FederatedRuntime(
    collaborators=authorized_collaborators,
    director=director_info, 
    notebook_path='./MNIST_Watermarking.ipynb',
)

The status of the connected Envoys can be checked using the `get_envoys()` method of the `federated_runtime`.

In [None]:
federated_runtime.get_envoys()

With `federated_runtime` now instantiated, we deploy the watermarking workspace and run the experiment.

In [None]:
#| export

# Set random seed
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)
torch.backends.cudnn.enabled = False

# MNIST parameters
learning_rate = 5e-2
momentum = 5e-1
log_interval = 20

# Watermarking parameters
watermark_pretrain_learning_rate = 1e-1
watermark_pretrain_momentum = 5e-1
watermark_pretrain_weight_decay = 5e-05
watermark_retrain_learning_rate = 5e-3

model = Net()
optimizer = optim.SGD(
    model.parameters(), lr=learning_rate, momentum=momentum
)
watermark_pretrain_optimizer = optim.SGD(
    model.parameters(),
    lr=watermark_pretrain_learning_rate,
    momentum=watermark_pretrain_momentum,
    weight_decay=watermark_pretrain_weight_decay,
)
watermark_retrain_optimizer = optim.SGD(
    model.parameters(), lr=watermark_retrain_learning_rate
)

flflow = FederatedFlow_MNIST_Watermarking(
    model,
    optimizer,
    watermark_pretrain_optimizer,
    watermark_retrain_optimizer,
    checkpoint=True,
)
flflow.runtime = federated_runtime
flflow.run()
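
To give a feel for what `learning_rate`, `momentum`, and `weight_decay` do in the cell above, the sketch below applies the SGD-with-momentum update rule to a single scalar parameter. It is a plain-Python illustration that assumes the default `torch.optim.SGD` behavior (no dampening, no Nesterov momentum), not the optimizer's implementation; `sgd_momentum_steps` and the gradient values are hypothetical.

```python
def sgd_momentum_steps(param, grads, lr, momentum, weight_decay=0.0):
    """Apply one SGD-with-momentum update per gradient to a scalar parameter."""
    velocity = None
    history = []
    for grad in grads:
        # Weight decay adds an L2 penalty term to the gradient.
        grad = grad + weight_decay * param
        # The first step uses the gradient itself as the velocity.
        velocity = grad if velocity is None else momentum * velocity + grad
        param = param - lr * velocity
        history.append(param)
    return history


# Same learning rate and momentum as the MNIST parameters above,
# applied to a made-up constant gradient of 2.0.
history = sgd_momentum_steps(1.0, [2.0, 2.0], lr=5e-2, momentum=5e-1)
print(history)  # approximately [0.9, 0.75]
```

The second step moves further than the first (0.15 versus 0.10) because the velocity accumulates the previous gradient.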