# Federated Runtime: CrowdGuard

This work is based on the [CrowdGuard demo code](https://github.com/securefederatedai/openfl/blob/develop/openfl-tutorials/experimental/workflow/CrowdGuard). It has been adapted to demonstrate CrowdGuard in the `FederatedRuntime`.

# Getting Started

We start by specifying the module to which cells marked with the `#| export` directive are automatically exported.

In the following cell, `#| default_exp experiment` indicates that the exported file will be named `experiment`. This name can be changed to suit the user's requirements and preferences.

In [None]:
#| default_exp experiment

Once the module name is specified, every subsequent notebook cell that should be exported must start with the `#| export` directive, as shown below. Make sure that *all* notebook functionality required by the Federated Learning experiment is covered by this directive.

## Installing Prerequisites
We start by installing OpenFL and the dependencies of the workflow interface. These dependencies are exported and become requirements of the Federated Learning environment.

In [None]:
#| export

!pip install git+https://github.com/securefederatedai/openfl.git
!pip install -r ../../workflow_interface_requirements.txt
!pip install numpy
!pip install torch==2.3.1
!pip install torchvision==0.18.1
!pip install -U ipywidgets

### Some global variables for CrowdGuard

In [None]:
# | export

import torch
import numpy as np

BATCH_SIZE_TRAIN = 32
BATCH_SIZE_TEST = 1000
LEARNING_RATE = 0.00075
MOMENTUM = 0.9
LOG_INTERVAL = 10
TOTAL_CLIENT_NUMBER = 4
PMR = 0.25
NUMBER_OF_MALICIOUS_CLIENTS = max(1, int(TOTAL_CLIENT_NUMBER * PMR)) if PMR > 0 else 0
NUMBER_OF_BENIGN_CLIENTS = TOTAL_CLIENT_NUMBER - NUMBER_OF_MALICIOUS_CLIENTS
PRETRAINED_MODEL_FILE = 'pretrained_cifar.pt'

# set the random seed for repeatable results
RANDOM_SEED = 10

VOTE_FOR_BENIGN = 1
VOTE_FOR_POISONED = 0
STD_DEV = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010]))
MEAN = torch.from_numpy(np.array([0.4914, 0.4822, 0.4465]))
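
The number of malicious clients above follows from the poisoned model rate (`PMR`). As a minimal sketch, the same arithmetic can be wrapped in a function to see how it behaves for other settings. The function name `count_malicious_clients` is hypothetical and not part of the notebook or of OpenFL.

```python
def count_malicious_clients(total_clients: int, pmr: float) -> int:
    """Mirror the NUMBER_OF_MALICIOUS_CLIENTS expression from the cell above.

    Hypothetical helper for illustration only.
    """
    if pmr <= 0:
        # No poisoning requested
        return 0
    # Round down, but keep at least one malicious client for any positive PMR
    return max(1, int(total_clients * pmr))


default_setting = count_malicious_clients(4, 0.25)   # the values used in this notebook
tiny_pmr = count_malicious_clients(4, 0.1)            # int(0.4) == 0, raised to 1
larger_setup = count_malicious_clients(10, 0.25)      # int(2.5) == 2
no_attack = count_malicious_clients(4, 0.0)
print(default_setting, tiny_pmr, larger_setup, no_attack)
```

With the defaults of this notebook (4 clients, `PMR = 0.25`), this gives one malicious and three benign clients.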

We now define our model, optimizer, and some helper functions, as we would for any other deep learning experiment.

> This cell and all subsequent cells are essential parts of the Federated Learning experiment and are therefore annotated with the `#| export` directive.

In [None]:
# | export

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import random

def seed_random_generators(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    
seed_random_generators(RANDOM_SEED)

class Net(nn.Module):
    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.features = SequentialWithInternalStatePrediction(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = SequentialWithInternalStatePrediction(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 * 2)
        x = self.classifier(x)
        return x

    def predict_internal_states(self, x):
        result, x = self.features.predict_internal_states(x)
        x = x.view(x.size(0), 256 * 2 * 2)
        result += self.classifier.predict_internal_states(x)[0]
        return result
    

class SequentialWithInternalStatePrediction(nn.Sequential):
    """
    Adapted version of Sequential that implements the function predict_internal_states
    """

    def predict_internal_states(self, x):
        """
        applies the submodules on the input. Compared to forward, this function also returns
        all intermediate outputs
        """
        result = []
        for module in self:
            x = module(x)
            # We can define our layer as we want. We selected Convolutional and
            # Linear Modules as layers here.
            # Differs for every model architecture.
            # Can be defined by the defender.
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                result.append(x)
        return result, x
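
The idea behind `predict_internal_states` is to run the layers one after another and record the output of selected layer types along the way. The following framework-free sketch illustrates that pattern with plain Python callables instead of `torch` modules. The classes `Scale` and `Shift` and the function `run_with_internal_states` are hypothetical stand-ins, not part of the notebook.

```python
class Scale:
    """Toy 'layer' that multiplies its input (stand-in for a recorded layer type)."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


class Shift:
    """Toy 'layer' that adds an offset (stand-in for a layer type that is not recorded)."""

    def __init__(self, offset):
        self.offset = offset

    def __call__(self, x):
        return x + self.offset


def run_with_internal_states(layers, x, recorded_types):
    """Apply the layers in order and collect the outputs of the recorded layer types."""
    recorded = []
    for layer in layers:
        x = layer(x)
        if isinstance(layer, recorded_types):
            recorded.append(x)
    # Same return shape as in the class above: (intermediate outputs, final output)
    return recorded, x


layers = [Scale(2), Shift(1), Scale(3)]
states, final_output = run_with_internal_states(layers, 5, (Scale,))
print(states, final_output)
```

Only the outputs of `Scale` layers are recorded here, just as the class above records only the outputs of `nn.Conv2d` and `nn.Linear` modules.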

## Helper Functions

In [None]:
#| export

def default_optimizer(model, optimizer_type=None, optimizer_like=None):
    """
    Return a new optimizer based on optimizer_type or on an optimizer template.
    Returns None if neither argument selects SGD or Adam.

    Args:
        model: neural network model derived from nn.Module
        optimizer_type: "SGD" or "Adam"
        optimizer_like: existing torch.optim.SGD or torch.optim.Adam instance used as a template
    """
    if optimizer_type == "SGD" or isinstance(optimizer_like, optim.SGD):
        return optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    elif optimizer_type == "Adam" or isinstance(optimizer_like, optim.Adam):
        return optim.Adam(model.parameters())
    
def test(network, test_loader, device, mode='Benign', move_to_cpu_afterward=True,
         test_train='Test'):
    network.eval()
    network.to(device)
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = network(data)
            criterion = nn.CrossEntropyLoss()
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader)
    accuracy = float(correct / len(test_loader.dataset))
    print(
        (
            f"{mode} {test_train} set: Avg. loss: {test_loss}, "
            f"Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * accuracy:5.03f}%)"
        )
    )
    if move_to_cpu_afterward:
        network.to("cpu")
    return accuracy

def scale_update_of_model(to_scale, global_model, scaling_factor):
    """
    Scales the update of a local model (i.e., the difference between the global and local model)
    :param to_scale: local model as state dict
    :param global_model: global model as state dict
    :param scaling_factor: factor by which the update is multiplied
    :return: scaled local model as state dict
    """
    print(f'Scale Model by {scaling_factor}')
    result = {}
    for name, data in to_scale.items():
        if not (name.endswith('.bias') or name.endswith('.weight')):
            result[name] = data
        else:
            update = data - global_model[name]
            scaled = scaling_factor * update
            result[name] = scaled + global_model[name]
    return result

def create_cluster_map_from_labels(expected_number_of_labels, clustering_labels):
    """
    Converts a list of labels into a dictionary where each label is the key and
    the values are lists/np arrays of the indices from the samples that received
    the respective label
    :param expected_number_of_labels number of samples whose labels are contained in
    clustering_labels
    :param clustering_labels list containing the labels of each sample
    :return dictionary of clusters
    """
    assert len(clustering_labels) == expected_number_of_labels

    clusters = {}
    for i, cluster in enumerate(clustering_labels):
        if cluster not in clusters:
            clusters[cluster] = []
        clusters[cluster].append(i)
    return {index: np.array(cluster) for index, cluster in clusters.items()}


def determine_biggest_cluster(clustering):
    """
    Given a clustering, given as dictionary of the form {cluster_id: [items in cluster]}, the
    function returns the id of the biggest cluster
    """
    biggest_cluster_id = None
    biggest_cluster_size = None
    for cluster_id, cluster in clustering.items():
        size_of_current_cluster = np.array(cluster).shape[0]
        if biggest_cluster_id is None or size_of_current_cluster > biggest_cluster_size:
            biggest_cluster_id = cluster_id
            biggest_cluster_size = size_of_current_cluster
    return biggest_cluster_id

def trigger_single_image(image):
    """
    Adds a red square with a height/width of 6 pixels into
    the upper left corner of the given image.
    @param image tensor, containing the normalized pixel values of the image.
    The image will be modified in-place.
    @return given image
    """
    color = (torch.Tensor((1, 0, 0)) - MEAN) / STD_DEV
    image[:, 0:6, 0:6] = color.repeat((6, 6, 1)).permute(2, 1, 0)
    return image


def poison_data(samples_to_poison, labels_to_poison, pdr=0.5):
    """
    Poisons a given local dataset, consisting of samples and labels, such that
    the given ratio of this dataset consists of samples for the backdoor behavior
    :param samples_to_poison tensor containing all samples of the local dataset
    :param labels_to_poison tensor containing all labels
    :param pdr poisoned data rate
    :return poisoned local dataset (samples, labels)
    """
    if pdr == 0:
        return samples_to_poison, labels_to_poison

    assert 0 < pdr <= 1.0
    samples_to_poison = samples_to_poison.clone()
    labels_to_poison = labels_to_poison.clone()

    dataset_size = samples_to_poison.shape[0]
    num_samples_to_poison = int(dataset_size * pdr)
    if num_samples_to_poison == 0:
        # corner case for tiny pdrs
        assert pdr > 0  # Already checked above
        assert dataset_size > 1
        num_samples_to_poison += 1

    indices = np.random.choice(dataset_size, size=num_samples_to_poison, replace=False)
    for image_index in indices:
        image = trigger_single_image(samples_to_poison[image_index])
        samples_to_poison[image_index] = image
    labels_to_poison[indices] = 2
    return samples_to_poison, labels_to_poison.long()
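
To make the arithmetic in `scale_update_of_model` concrete, here is a minimal sketch that applies the same rule to plain floats instead of tensors. The function name `scale_update` and the example parameter names are hypothetical. Entries whose names end in `.weight` or `.bias` are moved along the update direction, and all other entries are copied unchanged.

```python
def scale_update(local_state, global_state, scaling_factor):
    """Scale the difference between a local and the global state, entry by entry.

    Hypothetical float-based analogue of scale_update_of_model.
    """
    scaled = {}
    for name, value in local_state.items():
        if name.endswith((".weight", ".bias")):
            update = value - global_state[name]
            scaled[name] = global_state[name] + scaling_factor * update
        else:
            # Non-parameter entries (e.g. counters) are kept as they are
            scaled[name] = value
    return scaled


global_state = {"fc.weight": 1.0, "fc.bias": 0.5, "bn.num_batches_tracked": 7}
local_state = {"fc.weight": 1.5, "fc.bias": 0.0, "bn.num_batches_tracked": 9}
scaled_state = scale_update(local_state, global_state, 2.0)
print(scaled_state)
```

With a scaling factor of 2, each weight and bias ends up twice as far from the global value as the local model originally was.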

## CrowdGuardClientValidation

The `FederatedRuntime` currently requires all user-defined code to be included in the notebook. The definition of `CrowdGuardClientValidation`, which previously lived in `CrowdGuardClientValidation.py`, is therefore placed here.

In [None]:
#| export

# Copyright (C) 2022-2024 TU Darmstadt
# SPDX-License-Identifier: Apache-2.0

# -----------------------------------------------------------
# Primary author: Phillip Rieger <phillip.rieger@trust.tu-darmstadt.de>
# Co-authored-by: Torsten Krauss <torsten.krauss@uni-wuerzburg.de>
# ------------------------------------------------------------
from enum import Enum
from copy import deepcopy
from scipy.stats import kstest, levene, ttest_ind, bartlett
from sklearn import preprocessing
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt


class DistanceMetric(str, Enum):
    """Enum to identify distance metrics necessary in this project"""
    COSINE = 'cosine'
    EUCLIDEAN = 'euclid'


class DistanceHandler:
    """Helper that calculates distances between two tensors."""

    @staticmethod
    def __get_euclid_distance(t1: torch.Tensor, t2: torch.Tensor) -> float:
        t = t1.view(-1) - t2.view(-1)
        return torch.norm(t, 2).cpu().item()

    @staticmethod
    def __get_cosine_distance(t1: torch.Tensor, t2: torch.Tensor) -> float:
        t1 = t1.view(-1).reshape(1, -1)
        t2 = t2.view(-1).reshape(1, -1)
        return 1 - torch.cosine_similarity(t1, t2).cpu().item()

    @staticmethod
    def get_distance(distance: DistanceMetric, t1: torch.Tensor, t2: torch.Tensor) -> float:
        """Factory Method for Distances"""
        if distance == DistanceMetric.COSINE:
            return DistanceHandler.__get_cosine_distance(t1, t2)
        if distance == DistanceMetric.EUCLIDEAN:
            return DistanceHandler.__get_euclid_distance(t1, t2)

        raise Exception(f"Extractor for {distance} not implemented yet.")


class CrowdGuardClientValidation:

    @staticmethod
    def __distance_global_model_final_metric(distance_type: str, prediction_matrix,
                                             prediction_global_model, sample_indices_by_label,
                                             own_index):
        """
        Calculates the distance matrix containing the metric for CrowdGuard
        with dimensions label x model x layer x values
        """

        sample_count = len(prediction_matrix)
        model_count = len(prediction_matrix[0])
        layer_count = len(prediction_matrix[0][0])

        # We create a distance matrix with distances between global and local models
        # of the dimensions sample x model x layer x values
        global_distance_matrix = [[[0.] * layer_count for _ in range(model_count)]
                                  for _ in range(sample_count)]
        # 1. calculate distances between predictions of global model and each local model
        for s_i, s in enumerate(prediction_matrix):
            g = prediction_global_model[s_i]
            for m_i, m in enumerate(s):
                for l_i, l in enumerate(m):
                    distance = DistanceHandler.get_distance(distance_type, l, g[
                        l_i])  # either euclidean or cosine distance
                    global_distance_matrix[s_i][m_i][l_i] = distance  # line 18

        # 2. Sort the sample-wise distances by the label of the sample
        for label, sample_list in sample_indices_by_label.items():
            # First pick the samples from the global predictions
            global_distance_matrix_for_label_helper = [
                [[0.] * len(sample_list) for _ in range(layer_count)] for _ in
                range(model_count)]

            s_i_new = 0
            for s_i, s in enumerate(global_distance_matrix):
                if s_i not in sample_list:
                    continue
                for m_i, mi in enumerate(s):
                    for l_i, l in enumerate(mi):
                        global_distance_matrix_for_label_helper[m_i][l_i][s_i_new] = l
                s_i_new += 1

        # We produce the first relative matrix
        sample_relation_matrix = [[[0.] * layer_count for _ in range(model_count)] for _ in
                                  range(sample_count)]

        # 3. divide by distances of this client to use its values as reference
        for s_i, s in enumerate(global_distance_matrix):
            distances_for_own_models_predictions = s[own_index]
            for m_j, mj in enumerate(s):
                for l_i, l in enumerate(mj):
                    relation = 0
                    if distances_for_own_models_predictions[l_i] != 0:
                        relation = l / distances_for_own_models_predictions[l_i]
                    sample_relation_matrix[s_i][m_j][l_i] = relation  # line 21

        # Compute the per-label average: instead of keeping every sample, we average over
        # all samples that share a label, so that we obtain one matrix per label
        sample_relation_matrix_for_label = {}

        # 4. Transpose matrix as preparation for averaging
        for label, sample_list in sample_indices_by_label.items():
            sample_relation_matrix_for_label[label] = [[0.] * layer_count for _ in
                                                       range(model_count)]
            sample_relation_matrix_for_label_helper = [
                [[0.] * len(sample_list) for _ in range(layer_count)] for _ in range(model_count)]
            # transpose dimensions of distance matrix, before we had (sample,model, layer) and
            # we transpose it to (model,layer,sample)
            s_i_new = 0
            for s_i, s in enumerate(sample_relation_matrix):
                if s_i not in sample_list:
                    continue
                for m_j, mj in enumerate(s):
                    for l_i, l in enumerate(mj):
                        sample_relation_matrix_for_label_helper[m_j][l_i][s_i_new] = l
                s_i_new += 1

            # 5. Average over all samples from the same label (basically kick-out the last
            # dimension)
            for m_j, mj in enumerate(sample_relation_matrix_for_label_helper):
                for l_i, l in enumerate(mj):
                    sample_relation_matrix_for_label[label][m_j][l_i] = np.mean(l).item()

        avg_sample_relation_matrix_squared_negative_models_first = {}

        # 6. subtract 1 (mainly for cosine distances) and square (but keep the sign)
        for label, label_values in sample_relation_matrix_for_label.items():
            avg_sample_relation_matrix_squared_negative_models_first[label] = [[0.] * layer_count
                                                                               for _ in
                                                                               range(model_count)]
            for m_j, mj in enumerate(label_values):
                for l_i, l in enumerate(mj):
                    x = l - 1
                    relation = x * x
                    relation = relation if x >= 0 else relation * (-1)
                    avg_sample_relation_matrix_squared_negative_models_first[label][m_j][
                        l_i] = relation
        return avg_sample_relation_matrix_squared_negative_models_first

    @staticmethod
    def __predict_for_single_model(model, local_data, device):
        """
        Returns
        - A matrix with Deep Layer Outputs with dimensions sample x layer x values.
        - The labels for all samples in the client's training dataset
        - The number of layers defined in the model
        """
        num_layers = None
        sample_label_list = []
        predictions = []
        model.eval()
        model = model.to(device)
        number_of_previous_samples = 0
        for batch_id, batch in enumerate(local_data):
            data, target = batch
            data, target = data.to(device), target.to(device)
            output = model.predict_internal_states(data)
            if num_layers is None:
                num_layers = len(output)
            assert num_layers == len(output)

            for layer_output_index, layer_output_values in enumerate(output):
                for idx in range(target.shape[0]):
                    sample_idx = number_of_previous_samples + idx
                    assert len(predictions) >= sample_idx
                    if len(predictions) == sample_idx:
                        assert layer_output_index == 0
                        predictions.append([])

                    if layer_output_index == 0:
                        expected_predictions = sample_idx + 1
                    else:
                        expected_predictions = number_of_previous_samples + target.shape[0]
                    assert_msg = f'{len(predictions)} vs. {sample_idx} ({idx} {batch_id} '
                    assert_msg += f'{layer_output_index} {number_of_previous_samples})'
                    assert len(predictions) == expected_predictions, assert_msg
                    assert_msg = f'{len(predictions[sample_idx])} {layer_output_index} '
                    assert_msg += f'{sample_idx} {batch_id} {idx} {number_of_previous_samples}'
                    assert len(predictions[sample_idx]) == layer_output_index, assert_msg
                    value = layer_output_values[idx].clone().detach().cpu()
                    predictions[sample_idx].append(value)
            number_of_previous_samples += target.shape[0]
            for t in target:
                sample_label_list.append(t.detach().clone().cpu().item())
        model.cpu()
        return predictions, sample_label_list, num_layers

    @staticmethod
    def __do_predictions(models, global_model, local_data, device):
        """
        Returns
        - The Deep Layer Outputs for all models in a matrix of dimension
          sample x model x layer x value
        - The Deep Layer Outputs of the global model in the dimension sample x layer x value
        - A dict containing lists of sample indices for each label class
        - The number of layers from the model
        """
        all_models_predictions = []
        for model_index, model in enumerate(models):
            predictions, _, _ = CrowdGuardClientValidation.__predict_for_single_model(model,
                                                                                      local_data,
                                                                                      device)
            for sample_index, layer_predictions_for_sample in enumerate(predictions):
                # why not just append?
                if sample_index >= len(all_models_predictions):
                    assert model_index == 0
                    assert len(all_models_predictions) == sample_index
                    all_models_predictions.append([])
                all_models_predictions[sample_index].append(layer_predictions_for_sample)
        tmp = CrowdGuardClientValidation.__predict_for_single_model(global_model, local_data,
                                                                    device)
        global_model_predictions, sample_label_list, n_layers = tmp
        sample_indices_by_label = {}
        for s_i, label in enumerate(sample_label_list):
            if label not in sample_indices_by_label.keys():
                sample_indices_by_label[label] = []
            sample_indices_by_label[label].append(s_i)

        return all_models_predictions, global_model_predictions, sample_indices_by_label, n_layers

    @staticmethod
    def __prune_poisoned_models(num_layers, total_number_of_clients, own_client_index,
                                distances_by_metric, verbose=False):
        detected_poisoned_models = []
        for distance_type in distances_by_metric.keys():

            # First load the distance Matrix for this client and the samples by labels.
            distance_matrix_la_m_l = distances_by_metric[distance_type]

            # We put all of our labels into one big row.
            layer_length = num_layers * len(distance_matrix_la_m_l)
            dist_matrix_m_lcon = [[0.] * layer_length for _ in range(total_number_of_clients)]
            label_count = 0
            for label_x, dist_matrix_m_l_for_label in distance_matrix_la_m_l.items():
                for model_idx, model_values in enumerate(dist_matrix_m_l_for_label):
                    for layer_idx, layer in enumerate(model_values):
                        dist_matrix_m_lcon[model_idx][layer_idx + label_count * num_layers] = layer
                label_count = label_count + 1

            dist_matrix_m_l = dist_matrix_m_lcon

            client_indices = [i for i, _ in enumerate(dist_matrix_m_l) if i != own_client_index]
            pruned_indices = []
            has_malicious_model = True
            new_round_needed = True
            prune_idx = 0

            max_pruning_count = (len(dist_matrix_m_l) - 1) // 2

            while has_malicious_model and new_round_needed:
                # unique
                pruned_indices_local = deepcopy(pruned_indices)
                # Ignore the own label again and the pruned indices
                pruned_cluster_input_m_l = [
                    value for i, value in enumerate(dist_matrix_m_l) 
                    if i != own_client_index and i not in pruned_indices
                    ]
                pruned_client_indices = [
                    i for i, value in enumerate(dist_matrix_m_l) 
                    if i != own_client_index and i not in pruned_indices
                    ]

                if len(pruned_cluster_input_m_l) <= 1:
                    break

                layer_values = {}

                for m in pruned_cluster_input_m_l:
                    for l_i, l in enumerate(m):
                        if l_i not in layer_values.keys():
                            layer_values[l_i] = []
                        layer_values[l_i].append(l)

                median_layer_values = []

                for l_i, l_values in layer_values.items():
                    median_layer_values.append(np.median(l_values).item())

                median_graph = list(median_layer_values)

                pca_list = []
                for m in pruned_cluster_input_m_l:
                    pca_list.append(m)

                pca_list.append(median_graph)

                scaled_data = preprocessing.scale(pca_list)

                pca = PCA()
                pca.fit(scaled_data)
                pca_data = pca.transform(scaled_data)

                cluster_input = []
                cluster_input_plain = []
                pca_one_data = pca_data.T[0]
                for pca_one_value in pca_one_data:
                    cluster_input.append([pca_one_value])
                    cluster_input_plain.append(pca_one_value)

                # Significance tests
                median_val = np.median(cluster_input_plain)
                if verbose:
                    print(f'cluster_input_plain={cluster_input_plain}')
                x_values = []
                y_values = []
                for value in cluster_input_plain:
                    # Split the samples into two groups
                    distance_value = abs(value - median_val)
                    if value >= median_val:
                        x_values.append(distance_value)
                    else:
                        y_values.append(distance_value)
                print(f'Distance: {distance_type}, use y {len(y_values)}: {y_values}')
                print(f'Distance: {distance_type}, use x {len(x_values)}: {x_values}')

                # Statistical tests
                t_value, t_p_value = ttest_ind(x_values, y_values)
                ks_value, ks_p_value = kstest(x_values, y_values)
                # Note: despite the variable names, this applies Levene's test
                bartlett_value, bartlett_p_value = levene(x_values, y_values)
                # Outlier tests
                # Creating boxplot
                bp_result = plt.boxplot(cluster_input_plain, whis=5.5)
                fliers = bp_result['fliers'][0].get_ydata()
                outlier_boxplot = len(fliers)
                plt.close()

                # Outlier based on variance
                deviation_mean = np.mean(cluster_input_plain)
                deviation_std = abs(np.std(cluster_input_plain))

                max_dist_rule_factor = 0
                for cip in cluster_input_plain:
                    cip_abs = abs(cip - deviation_mean)
                    rule_factor = cip_abs / deviation_std
                    if max_dist_rule_factor < rule_factor:
                        max_dist_rule_factor = rule_factor

                outlier_three_sigma = max_dist_rule_factor

                has_malicious_model_t_threshold = True if t_p_value < 0.01 else False
                has_malicious_model_ks_threshold = True if ks_p_value < 0.01 else False
                has_malicious_model_bartlett_threshold = True if bartlett_p_value < 0.01 else False

                has_boxplot_outlier = True if outlier_boxplot > 0 else False
                has_three_sigma_outlier = True if outlier_three_sigma >= 3 else False

                # Choose exit criterion
                has_malicious_model = (has_malicious_model_t_threshold
                                       or has_malicious_model_ks_threshold
                                       or has_malicious_model_bartlett_threshold
                                       or has_boxplot_outlier
                                       or has_three_sigma_outlier)
                
                # print(f'{t_p_value:.2f}, {ks_p_value:.2f}, {bartlett_p_value:.2f}, {outlier_boxplot:.2f}, {outlier_three_sigma:.2f}')

                ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None,
                                               compute_full_tree=True,
                                               metric="euclidean", memory=None,
                                               connectivity=None,
                                               linkage='single',
                                               compute_distances=True).fit(cluster_input)
                ac_e_labels: list = ac_e.labels_.tolist()
                median_value_cluster_label = ac_e_labels[-1]
                ac_e_malicious_class_indices = [idx for idx, val in enumerate(ac_e_labels) if
                                                val != median_value_cluster_label]

                for m_j, value in enumerate(pruned_client_indices):
                    if m_j in ac_e_malicious_class_indices:
                        pruned_indices_local.append(value)

                pruned_indices_local = list(set(pruned_indices_local))

                # If we now prune more than half, we stop and remove the best items from the last
                # pruning list.
                pruned_too_much = True
                if len(pruned_indices_local) > max_pruning_count:
                    dist_values_of_pruned_models = []
                    for midx in ac_e_malicious_class_indices:
                        dist_to_median = abs(cluster_input[midx][0] - cluster_input[-1][0])
                        dist_values_of_pruned_models.append(dist_to_median)

                    sorted_dist_values_of_pruned_models = list(dist_values_of_pruned_models)
                    sorted_dist_values_of_pruned_models.sort()

                    sorted_ac_e_malicious_class_indices = []
                    for sdv in sorted_dist_values_of_pruned_models:
                        dvidx = dist_values_of_pruned_models.index(sdv)
                        for m_j, value in enumerate(pruned_client_indices):
                            if m_j == ac_e_malicious_class_indices[dvidx]:
                                sorted_ac_e_malicious_class_indices.append(value)
                    overflowed_count = len(pruned_indices_local) - max_pruning_count
                    for oc in range(overflowed_count):
                        # Get the values of the clusters and remove the nearest ones
                        # from pruned_indices_local
                        pruned_indices_local.remove(sorted_ac_e_malicious_class_indices[-1])
                        del sorted_ac_e_malicious_class_indices[-1]
                    pruned_too_much = False

                still_pruning = len(pruned_indices) < len(pruned_indices_local)
                new_round_needed = still_pruning and pruned_too_much
                if has_malicious_model and new_round_needed:
                    pruned_indices = pruned_indices_local

                prune_idx += 1

            # Analyze the voting
            for value in client_indices:
                if value in pruned_indices:
                    detected_poisoned_models.append(value)

        return list(set(detected_poisoned_models))

    @staticmethod
    def validate_models(global_model, models, own_client_index, local_data, device):
        tmp = CrowdGuardClientValidation.__do_predictions(models, global_model, local_data, device)
        prediction_matrix, global_model_predictions, sample_indices_by_label, num_layers = tmp
        distances_by_metric = {}
        for dist_type in [DistanceMetric.COSINE, DistanceMetric.EUCLIDEAN]:
            calculated_distances = CrowdGuardClientValidation.__distance_global_model_final_metric(
                dist_type,
                prediction_matrix,
                global_model_predictions,
                sample_indices_by_label,
                own_client_index)
            distances_by_metric[dist_type] = calculated_distances
        result = CrowdGuardClientValidation.__prune_poisoned_models(num_layers, len(models),
                                                                    own_client_index,
                                                                    distances_by_metric)
        return result

## Workflow definition
Next, we import `FLSpec` and the placement decorators (`aggregator` and `collaborator`), and define the `FedAvg` helper function.

In [None]:
#| export

from openfl.experimental.workflow.interface import FLSpec
from openfl.experimental.workflow.placement import aggregator, collaborator

def FedAvg(models):  # NOQA: N802
    """
    Return a federated average model based on the FedAvg algorithm: H. B. McMahan,
    E. Moore, D. Ramage, S. Hampson, and B. Agüera y Arcas,
    “Communication-efficient learning of deep networks from decentralized data,” 2017.

    Args:
        models: Python list of locally trained models by each collaborator
    """
    new_model = models[0]
    if len(models) > 1:
        state_dicts = [model.state_dict() for model in models]
        state_dict = new_model.state_dict()
        for key in models[1].state_dict():
            state_dict[key] = torch.from_numpy(
                np.average([state[key].numpy() for state in state_dicts], axis=0))
        new_model.load_state_dict(state_dict)
    return new_model
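
To make the averaging step concrete, here is a minimal sketch of the same element-wise mean, using plain Python lists instead of tensors. The helper name `average_state_dicts` and the toy client values are hypothetical and not part of OpenFL; the `FedAvg` function above works on `torch` state dicts.

```python
def average_state_dicts(state_dicts):
    """Element-wise mean of equally shaped parameter lists (hypothetical helper)."""
    if not state_dicts:
        raise ValueError("need at least one state dict")
    averaged = {}
    for key in state_dicts[0]:
        # Collect the same parameter from every client and average position by position.
        columns = zip(*(state[key] for state in state_dicts))
        averaged[key] = [sum(values) / len(state_dicts) for values in columns]
    return averaged


# Toy "models": two clients, each with one weight vector and one bias.
client_a = {"weight": [1.0, 2.0], "bias": [0.0]}
client_b = {"weight": [3.0, 6.0], "bias": [1.0]}
averaged = average_state_dicts([client_a, client_b])
print(averaged)  # {'weight': [2.0, 4.0], 'bias': [0.5]}
```

Unlike this sketch, `FedAvg` reuses `models[0]` as the container for the result, so the first model in the list is overwritten in place when the averaged state dict is loaded.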

In [None]:
#| export

from sklearn.cluster import AgglomerativeClustering, DBSCAN

class FederatedFlow_CrowdGuard(FLSpec):
    def __init__(
        self,
        model,
        optimizer_type,
        total_rounds=50,
        top_model_accuracy=0,
        pmr=0.25,
        aggregation_algorithm='CrowdGuard',
        **kwargs, 
    ):
        super().__init__(**kwargs)
        self.aggregation_algorithm = aggregation_algorithm
        self.model = model
        self.global_model = Net()
        self.pmr = pmr
        self.start_time = None
        self.collaborators = None
        self.private = None
        self.optimizer_type = optimizer_type
        self.total_rounds = total_rounds
        self.top_model_accuracy = top_model_accuracy
        if torch.cuda.is_available():
            self.device = torch.device(
                "cuda:1"
            )  # This will enable Ray library to reserve available GPU(s) for the task
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        print(f"Using device: {self.device}")
        self.round_num = 0  # starting round
        print(20 * "#")
        print(f"Round {self.round_num}...")
        print(20 * "#")

    @aggregator
    def start(self):
        print("Performing initialization for model")
        self.collaborators = self.runtime.collaborators
        self.private = 10
        self.next(
            self.train,
            foreach="collaborators",
            exclude=["private"],
        )

    @collaborator
    def train(self):
        self.collaborator_name = self.input
        print(20 * "#")
        print(f"Performing model training for collaborator {self.input} in round {self.round_num}")

        self.model.to(self.device)
        original_model = {n: d.clone() for n, d in self.model.state_dict().items()}
        test(self.model, self.train_loader, self.device, move_to_cpu_afterward=False,
             test_train='Train')
        test(self.model, self.test_loader, self.device, move_to_cpu_afterward=False)
        test(self.model, self.backdoor_test_loader, self.device, mode='Backdoor',
             move_to_cpu_afterward=False)
        self.optimizer = default_optimizer(self.model, self.optimizer_type)

        self.model.train()
        train_losses = []
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data = data.to(self.device)
            target = target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            criterion = nn.CrossEntropyLoss()
            loss = criterion(output, target).to(self.device)
            loss.backward()
            self.optimizer.step()
            if batch_idx % LOG_INTERVAL == 0:
                train_losses.append(loss.item())

        self.loss = np.mean(train_losses)
        self.training_completed = True

        test(self.model, self.train_loader, self.device, move_to_cpu_afterward=False,
             test_train='Train')
        test(self.model, self.test_loader, self.device, move_to_cpu_afterward=False)
        test(self.model, self.backdoor_test_loader, self.device, mode='Backdoor',
             move_to_cpu_afterward=False)
        if 'malicious' in self.input:
            weights = self.model.state_dict()
            scaled = scale_update_of_model(weights, original_model, 1 / self.pmr)
            self.model.load_state_dict(scaled)
        self.model.to("cpu")
        torch.cuda.empty_cache()
        if self.aggregation_algorithm == 'FedAVG':
            self.next(self.fed_avg_aggregation, exclude=["training_completed"])
        else:
            self.next(self.collect_models, exclude=["training_completed"])

    @aggregator
    def fed_avg_aggregation(self, inputs):
        self.all_models = {input.collaborator_name: input.model.cpu() for input in inputs}
        self.model = FedAvg([m.cpu() for m in self.all_models.values()])
        self.round_num += 1
        # round_num was already incremented above, so compare it directly (as in `defend`)
        if self.round_num < self.total_rounds:
            self.next(self.train, foreach="collaborators")
        else:
            self.next(self.end)

    @aggregator
    def collect_models(self, inputs):
        # Following the CrowdGuard paper, this should be executed within SGX

        self.all_models = {i.collaborator_name: i.model.cpu() for i in inputs}
        self.next(self.local_validation, foreach="collaborators")

    @collaborator
    def local_validation(self):
        # Following the CrowdGuard paper, this should be executed within SGX
        print(
            f"Performing model validation for collaborator {self.input} in round {self.round_num}"
        )
        self.collaborator_name = self.input
        all_names = list(self.all_models.keys())
        all_models = [self.all_models[n] for n in all_names]
        own_client_index = all_names.index(self.collaborator_name)
        detected_suspicious_models = CrowdGuardClientValidation.validate_models(self.global_model,
                                                                                all_models,
                                                                                own_client_index,
                                                                                self.train_loader,
                                                                                self.device)
        detected_suspicious_models = sorted(detected_suspicious_models)
        print(
            f'Suspicious Models detected by {own_client_index}: {detected_suspicious_models}')

        votes_of_this_client = []
        for c in range(len(all_models)):
            if c == own_client_index:
                votes_of_this_client.append(VOTE_FOR_BENIGN)
            elif c in detected_suspicious_models:
                votes_of_this_client.append(VOTE_FOR_POISONED)
            else:
                votes_of_this_client.append(VOTE_FOR_BENIGN)
        self.votes_of_this_client = {}
        for name, vote in zip(all_names, votes_of_this_client):
            self.votes_of_this_client[name] = vote

        self.next(self.defend)

    @aggregator
    def defend(self, inputs):
        # Following the CrowdGuard paper, this should be executed within SGX

        all_names = list(self.all_models.keys())
        all_votes_by_name = {i.collaborator_name: i.votes_of_this_client for i in inputs}

        all_models = [self.all_models[name] for name in all_names]
        binary_votes = [[all_votes_by_name[own_name][val_name] for val_name in all_names] for
                        own_name in all_names]

        ac_e = AgglomerativeClustering(n_clusters=2, distance_threshold=None,
                                       compute_full_tree=True,
                                       metric="euclidean", memory=None, connectivity=None,
                                       linkage='single',
                                       compute_distances=True).fit(binary_votes)
        ac_e_labels: list = ac_e.labels_.tolist()
        agglomerative_result = create_cluster_map_from_labels(len(all_names), ac_e_labels)
        print(f'Agglomerative Clustering: {agglomerative_result}')
        agglomerative_negative_cluster = agglomerative_result[
            determine_biggest_cluster(agglomerative_result)]

        db_scan_input_idx_list = agglomerative_negative_cluster
        print(f'DBScan Input: {db_scan_input_idx_list}')
        db_scan_input_list = [binary_votes[vote_id] for vote_id in db_scan_input_idx_list]

        db = DBSCAN(eps=0.5, min_samples=1).fit(db_scan_input_list)
        dbscan_clusters = create_cluster_map_from_labels(len(agglomerative_negative_cluster),
                                                         db.labels_.tolist())
        biggest_dbscan_cluster = dbscan_clusters[determine_biggest_cluster(dbscan_clusters)]
        print(f'DBScan Clustering: {biggest_dbscan_cluster}')

        single_sample_of_biggest_cluster = biggest_dbscan_cluster[0]
        final_voting = db_scan_input_list[single_sample_of_biggest_cluster]
        negatives = [i for i, vote in enumerate(final_voting) if vote == VOTE_FOR_BENIGN]
        recognized_benign_models = [all_models[n] for n in negatives]

        print(f'Negatives: {negatives}')

        self.model = FedAvg([m.cpu() for m in recognized_benign_models])
        del inputs
        self.round_num += 1
        if self.round_num < self.total_rounds:
            print(f'Finished round {self.round_num}/{self.total_rounds}')
            self.next(self.train, foreach="collaborators")
        else:
            self.next(self.end)

    @aggregator
    def end(self):
        print(20 * "#")
        print("All rounds completed successfully")
        print(20 * "#")
        print("This is the end of the flow")
        print(20 * "#")
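
The final stage of `defend` can be hard to follow from the code alone, so here is a minimal sketch of it in plain Python. Because the vote vectors are binary, two different vectors are at least Euclidean distance 1 apart, so `DBSCAN(eps=0.5, min_samples=1)` only places identical vectors in the same cluster. The sketch therefore groups identical vote rows with `collections.Counter`. It deliberately omits the preceding agglomerative clustering stage, and the helper name `accepted_clients` and the example votes are hypothetical.

```python
from collections import Counter

VOTE_FOR_BENIGN = 1
VOTE_FOR_POISONED = 0


def accepted_clients(binary_votes):
    """Return the indices voted benign by the largest group of identical vote rows."""
    # Identical rows form one group, mirroring DBSCAN with eps=0.5 on 0/1 vectors.
    groups = Counter(tuple(row) for row in binary_votes)
    majority_row, _ = groups.most_common(1)[0]
    return [i for i, vote in enumerate(majority_row) if vote == VOTE_FOR_BENIGN]


# Rows are voters, columns are the models being judged.
# Clients 0-2 flag model 3 as poisoned; client 3 accuses everyone else.
binary_votes = [
    [1, 1, 1, 0],
    [1, 1, 1, 0],
    [1, 1, 1, 0],
    [0, 0, 0, 1],
]
benign = accepted_clients(binary_votes)
print(benign)  # [0, 1, 2]
```

Only the models at the returned indices would then be passed to `FedAvg`, which corresponds to the `negatives` list in `defend`.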

## Defining and Initializing the Federated Runtime
We initialize the Federated Runtime by providing:
- `director_info`: The director's connection information.
- `authorized_collaborators`: A list of authorized collaborators.
- `notebook_path`: The path to this Jupyter notebook.

In [None]:
#| export

from openfl.experimental.workflow.runtime import FederatedRuntime

director_info = {
    'director_node_fqdn': 'localhost',
    'director_port': 50050,
}

authorized_collaborators = ['Amsterdam', 'Bangalore', 'Chandler', 'Detroit']

federated_runtime = FederatedRuntime(
    collaborators=authorized_collaborators,
    director=director_info, 
    notebook_path='./FederatedCrowdGuard.ipynb',
)

The status of the connected Envoys can be checked using the `get_envoys()` method of the `federated_runtime`.

In [None]:
federated_runtime.get_envoys()

With `federated_runtime` instantiated, we now deploy the workspace and run the experiment.

In [None]:
#| export

seed_random_generators(RANDOM_SEED)

model = Net()
pmr = NUMBER_OF_MALICIOUS_CLIENTS / TOTAL_CLIENT_NUMBER

flflow = FederatedFlow_CrowdGuard(
    model,
    optimizer_type='SGD',
    total_rounds=5,
    top_model_accuracy=0,
    pmr=pmr,
    aggregation_algorithm='CrowdGuard',
)

flflow.runtime = federated_runtime
flflow.run()
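
For orientation, the arithmetic behind `pmr` and the `1 / self.pmr` factor used in `train` can be sketched with plain numbers. The `scale_update` helper below is a hypothetical stand-in for `scale_update_of_model`. It assumes that function amplifies the difference between the trained and the original weights, which should be checked against its definition earlier in the notebook.

```python
TOTAL_CLIENT_NUMBER = 4
PMR = 0.25  # share of malicious clients, as in the global-variables cell

# Same formula as the notebook: at least one attacker whenever PMR > 0.
malicious = max(1, int(TOTAL_CLIENT_NUMBER * PMR)) if PMR > 0 else 0
pmr = malicious / TOTAL_CLIENT_NUMBER
scaling_factor = 1 / pmr


def scale_update(trained, original, factor):
    """Hypothetical stand-in: original + factor * (trained - original), per parameter."""
    return {name: original[name] + factor * (trained[name] - original[name])
            for name in trained}


original = {"w": 1.0}
trained = {"w": 1.5}
scaled = scale_update(trained, original, scaling_factor)
print(malicious, pmr, scaling_factor, scaled)  # 1 0.25 4.0 {'w': 3.0}
```

Under this assumption, one attacker among four clients scales its update by 4. If the other three clients leave `w` at 1.0, plain averaging gives (3.0 + 3 × 1.0) / 4 = 1.5, which is exactly the attacker's trained value.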