# Federated Learning Model Poisoning Attack Simulation

In [3]:
!pip install -q flwr[simulation] torch torchvision matplotlib

from collections import OrderedDict
from typing import List, Tuple, Optional, Callable
from matplotlib import pyplot as plt
import certifi
import ssl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import flwr as fl
from flwr.common import Metrics
import os

# Workaround kept from the original notebook: without this setting matplotlib was killing the kernel.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Dataset downloads failed certificate validation with the default trust store. Instead of
# switching verification off for the whole process, validate against certifi's CA bundle.
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
DEVICE = torch.device("cpu")  # "cpu" -> train on the CPU | "cuda" -> train on the GPU

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.4/200.4 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 MB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m201.4/201.4 kB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m76.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.1/98.1 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m128.2/128.2 kB[0m [31m10.5 MB/s[0m et

In [4]:
from logging import WARNING # we need those imports to implement the strategy class
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg, aggregate_krum
from flwr.server.strategy.strategy import Strategy

In [5]:
# Label names of the example dataset, in class-index order.
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

# Number of devices participating in the federated learning.
NUM_CLIENTS = 10

# Mini-batch size used by every DataLoader when training the CNN.
BATCH_SIZE = 32

In [6]:
def load_datasets():
    """Download CIFAR-10 and build the per-client data loaders.

    The training set is split into NUM_CLIENTS partitions (one simulated device each);
    every partition is further split 90/10 into a training and a validation subset.
    All splits use a generator seeded with 42, so they are reproducible.

    Returns:
        (trainloaders, valloaders, testloader): two lists with one DataLoader per client,
        plus a single DataLoader over the CIFAR-10 test set.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # random_split needs the lengths to add up to len(trainset). When NUM_CLIENTS does not
    # divide the dataset evenly, the first `remainder` partitions receive one extra sample
    # (for an even split this yields exactly the same partitions as before).
    partition_size, remainder = divmod(len(trainset), NUM_CLIENTS)
    partition_lengths = [
        partition_size + (1 if idx < remainder else 0) for idx in range(NUM_CLIENTS)
    ]
    datasets = random_split(trainset, partition_lengths, torch.Generator().manual_seed(42))

    # split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10% validation set
        len_train = len(ds) - len_val
        ds_train, ds_val = random_split(
            ds, [len_train, len_val], torch.Generator().manual_seed(42)
        )
        trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloaders, valloaders, testloader

trainloaders, valloaders, testloader = load_datasets() # load the datasets

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 79581848.65it/s]


Extracting ./dataset/cifar-10-python.tar.gz to ./dataset
Files already downloaded and verified


In [7]:
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three fully connected layers.

    Only the constructor and the forward pass are defined. The final layer has 10 outputs.
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names and their registration order determine the state_dict layout,
        # which get_parameters/set_parameters rely on - keep both unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass and return the raw (un-normalised) class scores."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Flatten to rows of 16 * 5 * 5 values, the input width of fc1.
        flattened = features.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flattened))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [8]:
def train(net, trainloader, epochs: int, verbose=False):
    """Train `net` in place on `trainloader` for `epochs` passes over the data.

    Uses cross-entropy loss and an Adam optimizer with default hyper-parameters; a fresh
    optimizer is created on every call. Batches are moved to DEVICE before the forward pass.

    Args:
        net: the model to optimise (modified in place).
        trainloader: iterable of (images, labels) batches exposing a `.dataset` attribute.
        epochs: number of passes over `trainloader`.
        verbose: when True, print the loss and accuracy after each epoch.

    Returns:
        None.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):

        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad() # reset the gradients
            outputs = net(images) # do the forward pass
            loss = criterion(outputs, labels) # calculate the loss function
            loss.backward() # calculate the gradients of the loss function
            optimizer.step() # apply one Adam update
            # .item() extracts a plain Python float: the running total no longer holds on
            # to autograd-tracked tensors and the log line prints a number, not a tensor.
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        # Same normalisation as test(): sum of per-batch losses divided by the sample count.
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total

        if verbose: # log information
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate `net` on `testloader` without tracking gradients.

    Returns:
        (loss, accuracy): the summed per-batch cross-entropy divided by the number of
        samples in `testloader.dataset`, and the fraction of correctly classified samples.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    running_loss = 0.0
    num_correct = 0
    num_seen = 0

    net.eval()
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            scores = net(batch_images)
            running_loss += loss_fn(scores, batch_labels).item()
            predicted = torch.max(scores.data, 1).indices
            num_seen += batch_labels.size(0)
            num_correct += (predicted == batch_labels).sum().item()

    # NOTE(review): CrossEntropyLoss averages within a batch by default, so this value scales
    # with the batch size rather than being a per-sample mean - confirm that is intended.
    scaled_loss = running_loss / len(testloader.dataset)
    accuracy = num_correct / num_seen
    return scaled_loss, accuracy

In [9]:
def get_parameters(net) -> List[np.ndarray]:
    """Return every state_dict entry of `net` as a NumPy array, in state_dict order.

    Together with set_parameters this gives Flower a way to serialize/deserialize the model.
    """
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]

def set_parameters(net, parameters: List[np.ndarray]):
    """Load `parameters` into `net`, pairing arrays with state_dict keys by position.

    The arrays must therefore be in the order produced by get_parameters; loading is strict.
    """
    new_state = OrderedDict()
    for key, array in zip(net.state_dict().keys(), parameters):
        new_state[key] = torch.Tensor(array)
    net.load_state_dict(new_state, strict=True)

In [10]:
class FlowerClient(fl.client.NumPyClient):
    """Base Flower client wrapping one model and one client's train/validation loaders.

    Only the three methods below are implemented; subclasses may customise them or add
    extra member variables.
    """

    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the given weights, train for one epoch and return the updated weights.

        Returns:
            (weights, num_examples, metrics) where num_examples is the number of training
            samples (not the number of mini-batches) and metrics is an empty dict.
        """
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        # len(DataLoader) counts batches; the example count is the size of its dataset.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the given weights and evaluate them on the local validation set.

        Returns:
            (loss, num_examples, metrics) where num_examples is the number of validation
            samples and metrics holds the accuracy under the key "accuracy".
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

In [11]:
class BenignFlowerClient(FlowerClient):
    """Client that trains exactly like FlowerClient and labels its update as benign.

    The constructor is inherited unchanged: (net, trainloader, valloader).
    """

    def fit(self, parameters, config):
        """Delegate training to FlowerClient.fit and attach the "BENIGN" intention tag."""
        updated_weights, sample_count, _ignored_metrics = super().fit(parameters, config)
        tagged_metrics = {"intention": "BENIGN"}
        return updated_weights, sample_count, tagged_metrics

class MaliciousFlowerClient(FlowerClient):
    """Client whose update is labelled as malicious.

    Training itself is identical to FlowerClient; only the "intention" tag differs.
    The constructor is inherited unchanged: (net, trainloader, valloader).
    """

    def fit(self, parameters, config):
        """Delegate training to FlowerClient.fit and attach the "MALICIOUS" intention tag."""
        updated_weights, sample_count, _ignored_metrics = super().fit(parameters, config)
        tagged_metrics = {"intention": "MALICIOUS"}
        return updated_weights, sample_count, tagged_metrics

In [12]:
def client_fn(cid: str) -> FlowerClient:
    """Factory (factory design pattern) that Flower calls to create clients on demand.

    Creating clients lazily keeps memory usage low. The client with id `cid` gets a fresh
    model and its own train/validation loaders.

    Args:
        cid: client id; its integer value indexes `trainloaders` and `valloaders`.

    Returns:
        A MaliciousFlowerClient for the first 20% of the client ids, otherwise a
        BenignFlowerClient.
    """
    client_index = int(cid)
    net = Net().to(DEVICE)
    trainloader = trainloaders[client_index]
    valloader = valloaders[client_index]

    # 20 - 80 --> malicious - benign. Client ids start at 0, so the comparison must be
    # strict: `<=` would also mark client id `num_malicious` and yield one attacker too many.
    num_malicious = NUM_CLIENTS // 5
    if client_index < num_malicious:
        return MaliciousFlowerClient(net, trainloader, valloader)
    return BenignFlowerClient(net, trainloader, valloader)

In [13]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate client accuracies into one example-weighted average.

    This function is passed to the Strategy as its evaluation-metrics aggregation function.

    Args:
        metrics: one (num_examples, client_metrics) pair per client; every client_metrics
            dict must contain the key "accuracy".

    Returns:
        {"accuracy": weighted mean}, or an empty dict when there is nothing to average
        (no entries, or a total example count of zero).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Avoid a ZeroDivisionError; report no metric rather than a made-up value.
        return {}

    weighted_accuracy = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_accuracy / total_examples}

In [36]:
class AttackSimulationStrategy(fl.server.strategy.FedAvg): # we inherit from FedAvg strategy and change only what we need
    """FedAvg-derived strategy whose fit aggregation goes through ``aggregate_krum``.

    Everything except ``__init__``, ``__repr__`` and ``aggregate_fit`` is inherited from
    FedAvg. The attack-related settings (``perturbationVector``, ``adversaryKnowledge``)
    are stored on the instance, but the only code that reads them currently sits inside a
    string literal in ``aggregate_fit`` and is therefore never executed.
    """
    def __init__(self, *, fraction_fit: float = 1.0, fraction_evaluate: float = 1.0, min_fit_clients: int = 2,
        min_evaluate_clients: int = 2, min_available_clients: int = 2, num_malicious_clients: int = 2,
        num_clients_to_keep: int = 0, evaluate_fn: Optional[Callable[[int, NDArrays, Dict[str, Scalar]],
        Optional[Tuple[float, Dict[str, Scalar]]],]] = None, on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        perturbationVector: str, adversaryKnowledge: str
    ) -> None:
        """Forward the standard FedAvg options to the parent and store the extra settings.

        All arguments are keyword-only. ``perturbationVector`` and ``adversaryKnowledge``
        have no default and must be supplied.

        Args:
            num_malicious_clients: passed as the second argument to ``aggregate_krum``.
            num_clients_to_keep: passed as the third argument to ``aggregate_krum``.
            perturbationVector: the disabled attack draft compares it against
                "InverseSign" and "InverseStd".
            adversaryKnowledge: the disabled attack draft compares it against
                "agr-only" and "agnostic".

        The remaining arguments are handed to ``FedAvg.__init__`` unchanged.
        """
        super().__init__(
            fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients, min_available_clients=min_available_clients, evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn, accept_failures=accept_failures,
            initial_parameters=initial_parameters, fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.num_malicious_clients = num_malicious_clients
        self.num_clients_to_keep = num_clients_to_keep
        self.perturbationVector = perturbationVector
        self.adversaryKnowledge = adversaryKnowledge

    def __repr__(self) -> str:
        """Return a short description that shows only the ``accept_failures`` setting."""
        rep = f"AttackSimulationStrategy(accept_failures={self.accept_failures})"
        return rep

    def aggregate_fit(self, server_round: int, results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate the clients' fit results of one round with ``aggregate_krum``.

        Args:
            server_round: current round number; only used to log the missing-metrics
                warning once (in round 1).
            results: (client, FitRes) pairs of the clients that finished training.
            failures: failed clients or raised exceptions of this round.

        Returns:
            ``(None, {})`` when there are no results, or when there are failures and
            ``accept_failures`` is False. Otherwise the aggregated parameters together
            with the aggregated fit metrics (an empty dict when no
            ``fit_metrics_aggregation_fn`` is configured).
        """

        if not results: # boilerplate code to handle exceptions
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        # convert results: one (list of ndarrays, num_examples) pair per client
        total_weights_results = [(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results]

        # NOTE: everything between the triple quotes below is a bare string literal, not live
        # code - the attack simulation it contains is disabled and has no effect on
        # total_weights_results. It is kept verbatim as a draft.
        """
        #############################################################################
        #########################----SIMULATE THE ATTACK----#########################
        #############################################################################
        benign_weights_results = [(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results
                         if fit_res.metrics["intention"] == "BENIGN"]
        if self.adversaryKnowledge == "agr-only" or self.adversaryKnowledge == "agnostic":
            malicious_weights_results = [(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results
                                if fit_res.metrics["intention"] == "MALICIOUS"]
            best_parameters = aggregate(total_weights_results) # malicious_weights_results
        else:
            best_parameters = aggregate(total_weights_results)

        # so far we have [1]: all clients parameters, [2]: benign clients parameters (seperate),
        #                [3]: malicious clients parameters (seperate), [4]: best_parameters (based on the adversary's knowledge)
        # no we have to optimize the model poisoning attack based on the adversary's knowledge and the perturbation vector

        # calculate the perturbation vector
        if self.perturbationVector == "InverseSign":
            perturbationVector = [-np.sign(best_parameter) for best_parameter in best_parameters]
        if self.perturbationVector == "InverseStd":
            parameters = [parameters for parameters, _ in total_weights_results]
            perturbationVector = []

            for parameterIdx in range(len(parameters[0])):
                arr = []
                for clientIdx in range(len(parameters)):
                    arr.append(parameters[clientIdx][parameterIdx])
                numpyArr = np.stack(arr)
                perturbationVector.append(-np.std(numpyArr, axis=0))

        # optimize for gamma
        gamma_init = 100.0
        gamma = gamma_init
        gamma_succ = 0.0
        t = 0.000001 # threshold
        step = gamma_init

        while True: # this is a c++ do-while loop emulation
            malicious_parameters = []
            for idx in range(len(perturbationVector)):
                malicious_parameters.append(best_parameters[idx] + gamma*perturbationVector[idx])

            new_malicious_weights_results = [(malicious_parameters, num_examples) for _, num_examples in malicious_weights_results]
            new_total_weights_results = new_malicious_weights_results + benign_weights_results

            are_equal = True
            krum_parameters = aggregate_krum(new_total_weights_results, self.num_malicious_clients, self.num_clients_to_keep)
            for idx in range(len(krum_parameters)):
                if np.array_equal(malicious_parameters[idx], krum_parameters[idx]) == False:
                    are_equal = False
                    break

            if are_equal:
                gamma_succ = gamma
                gamma = gamma + step/2
            else:
                gamma = gamma - step/2
            step = step/2

            if abs(gamma_succ - gamma) < t:
                break

            print(gamma, gamma_succ)

        malicious_parameters = []
        for idx in range(len(perturbationVector)):
            malicious_parameters.append(best_parameters[idx] + gamma_succ*perturbationVector[idx])
        new_malicious_weights_results = [(malicious_parameters, num_examples) for _, num_examples in malicious_weights_results]
        total_weights_results = new_malicious_weights_results + benign_weights_results
        #############################################################################
        #############################################################################
        #############################################################################
        """

        # calculate AGR
        parameters_aggregated = ndarrays_to_parameters(aggregate_krum( # this thing here is for krum-AGR
            total_weights_results, self.num_malicious_clients, self.num_clients_to_keep))

        # boilerplate code if aggregation metrics are provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated



In [37]:
# Build the server-side strategy used for the attack simulation.
strategy = AttackSimulationStrategy(
    fraction_fit=0.8,
    fraction_evaluate=0.5,
    min_fit_clients=5,
    min_evaluate_clients=5,
    min_available_clients=10,
    perturbationVector="InverseStd",
    adversaryKnowledge="agr-only",
    # Metric aggregation function: called in every federated learning round for evaluation,
    # it aggregates the client-side evaluation metrics on the server.
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Launch the simulation.
fl.simulation.start_simulation(
    client_fn=client_fn,  # our client factory function
    num_clients=NUM_CLIENTS,  # number of simulated clients
    config=fl.server.ServerConfig(num_rounds=5),  # number of federated learning rounds
    strategy=strategy,  # our attack simulation strategy
    client_resources=None,
)

INFO flwr 2023-09-16 07:02:18,609 | app.py:175 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
2023-09-16 07:02:23,414	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2023-09-16 07:02:26,191 | app.py:210 | Flower VCE: Ray initialized with resources: {'node:__internal_head__': 1.0, 'node:172.28.0.12': 1.0, 'CPU': 2.0, 'object_store_memory': 3915856281.0, 'memory': 7831712564.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'node:__internal_head__': 1.0, 'node:172.28.0.12': 1.0, 'CPU': 2.0, 'object_store_memory': 3915856281.0, 'memory': 7831712564.0}
INFO flwr 2023-09-16 07:02:26,196 | app.py:218 | No `client_resources` specified. Using minimal resources for clients.
INFO:flwr:No `client_resources` specified. Using minimal resources for clients.
INFO flwr 2023-09-16 07:02:26,198 | app.py:224 | Flower VCE: Resources for each Virtual Clie

History (loss, distributed):
	round 1: 0.061935691118240356
	round 2: 0.0557009561061859
	round 3: 0.053341494178771966
	round 4: 0.053248284673690804
	round 5: 0.049612216806411744
History (metrics, distributed, evaluate):
{'accuracy': [(1, 0.2856), (2, 0.35960000000000003), (3, 0.38639999999999997), (4, 0.3996), (5, 0.43920000000000003)]}