# Federated Learning Model Poisoning Attack Simulation

In [57]:
!pip install -q flwr[simulation] torch torchvision matplotlib

from collections import OrderedDict
from typing import List, Tuple, Optional, Callable
from matplotlib import pyplot as plt
import certifi
import ssl
import numpy as np
import torch
import torch.nn as nn
import random
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset, Dataset
from torchvision.datasets import CIFAR10
import flwr as fl
from flwr.common import Metrics
import os

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" # fixed something with matplotlib killing the kernel
ssl._create_default_https_context = ssl._create_unverified_context  # fixed something with the ssl certificate of the dataset
DEVICE = torch.device("cpu")  # "cpu" -> train in cpu | "cuda" -> train in gpu

In [58]:
from logging import WARNING # we need those imports to implement the strategy class
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg, aggregate_krum
from flwr.server.strategy.strategy import Strategy

In [59]:
CLASSES = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck") # classes of the example dataset

NUM_CLIENTS = 10 # this is the number of devices participating in the federated learning

BATCH_SIZE = 32 # this is the size of a mini-batch for the training of a CNN using SGD

In [60]:
class MaliciousDataset(Dataset):
    def __init__(self):
        self.modifiedData = []

    def __len__(self):
        return len(self.modifiedData)

    def __getitem__(self, idx):
        return self.modifiedData[idx]

    def __insertitem___(self, item):
        self.modifiedData.append(item)

In [61]:
def load_datasets(): # download and transform cifar-10
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # split training set into 10 partitions to simulate the individual dataset
    partition_size = len(trainset) // NUM_CLIENTS
    lengths = [partition_size] * NUM_CLIENTS
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    dsidx = 0
    for ds in datasets:
        len_val = len(ds) // 10  # 10% validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))

        #############################################################################
        ######################----SIMULATE THE DATA ATTACK----#######################
        #############################################################################
        if dsidx < int(NUM_CLIENTS*0.3):
            datasetIndices = ds_train.indices
            clientDataset = ds_train.dataset
            poisonedDataset = MaliciousDataset()
            for index in range(len(datasetIndices)):
                poisonedDataset.__insertitem___((clientDataset[datasetIndices[index]][0], random.randint(0, 9)))
            ds_train = Subset(poisonedDataset, [i for i in range(poisonedDataset.__len__())])
        #############################################################################
        #############################################################################
        #############################################################################

        trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))
        dsidx = dsidx + 1
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloaders, valloaders, testloader

trainloaders, valloaders, testloader = load_datasets() # load the datasets

Files already downloaded and verified
Files already downloaded and verified


In [62]:
class LeNet5(nn.Module):
    def __init__(self, num_classes, grayscale=False):
        super(LeNet5, self).__init__()

        self.grayscale = grayscale
        self.num_classes = num_classes

        if self.grayscale:
            in_channels = 1
        else:
            in_channels = 3

        self.features = nn.Sequential(

            nn.Conv2d(in_channels, 6*in_channels, kernel_size=5),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6*in_channels, 16*in_channels, kernel_size=5),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(16*5*5*in_channels, 120*in_channels),
            nn.Tanh(),
            nn.Linear(120*in_channels, 84*in_channels),
            nn.Tanh(),
            nn.Linear(84*in_channels, num_classes),
        )


    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas

In [63]:
def train(net, trainloader, epochs: int, verbose=False): # we need a train and a test function that our clients will be using
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):

        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad() # reset the gradients
            outputs, probas = net(images) # do the forward pass
            loss = criterion(outputs, labels) # calculate the loss function
            loss.backward() # calculate the gradients of the loss function
            optimizer.step() # do one stochastic gradient descent step
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total

        if verbose: # log information
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs, probas = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(probas.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [64]:
def get_parameters(net) -> List[np.ndarray]: # we need these two function because this way flower knows how
    return [val.cpu().numpy() for _, val in net.state_dict().items()] # to serialize/deserialize data

def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

In [65]:
class FlowerClient(fl.client.NumPyClient): # here we define the FlowerClient. We have to implement only the following 3 methods
    def __init__(self, net, trainloader, valloader): # of course we can customize these methods as we like
        self.net = net                            # we can add extra member variables as well
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

In [66]:
class BenignFlowerClient(FlowerClient):
    def __init__(self, net, trainloader, valloader):
        super().__init__(net, trainloader, valloader)

    def fit(self, parameters, config):
        #print("Benign Client Got Selected", flush=True)
        parameters, num_examples, _ = super().fit(parameters, config)
        return parameters, num_examples, {"intention": "BENIGN"}

class MaliciousFlowerClient(FlowerClient):
    def __init__(self, net, trainloader, valloader):
        super().__init__(net, trainloader, valloader)

    def fit(self, parameters, config):
        #print("Malicious Client Got Selected", flush=True)
        parameters, num_examples, _ = super().fit(parameters, config)
        return parameters, num_examples, {"intention": "MALICIOUS"}

In [67]:
def client_fn(cid: str) -> FlowerClient: # this is a factory function (factory desing pattern)
    net = LeNet5(10).to(DEVICE) # flower calls this function to create FlowerClients on demand (this way it uses less memory)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]

    if int(cid) <= NUM_CLIENTS * 0.2:    # 20 - 80 --> malicious - benign
        return MaliciousFlowerClient(net, trainloader, valloader)
    else:
        return BenignFlowerClient(net, trainloader, valloader)

In [68]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # we pass this function as an argument to the Strategy
    # calculate accuracy
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # return it as a dictionary
    return {"accuracy": sum(accuracies) / sum(examples)}

In [69]:
class AttackSimulationStrategy(fl.server.strategy.FedAvg): # we inherit from FedAvg strategy and change only what we need
    def __init__(self, *, fraction_fit: float = 1.0, fraction_evaluate: float = 1.0, min_fit_clients: int = 2,
        min_evaluate_clients: int = 2, min_available_clients: int = 2, num_malicious_clients: int = 2,
        num_clients_to_keep: int = 0, evaluate_fn: Optional[Callable[[int, NDArrays, Dict[str, Scalar]],
        Optional[Tuple[float, Dict[str, Scalar]]],]] = None, on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        perturbationVector: str, adversaryKnowledge: str
    ) -> None:
        super().__init__(
            fraction_fit=fraction_fit, fraction_evaluate=fraction_evaluate, min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients, min_available_clients=min_available_clients, evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn, accept_failures=accept_failures,
            initial_parameters=initial_parameters, fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.num_malicious_clients = num_malicious_clients
        self.num_clients_to_keep = num_clients_to_keep
        self.perturbationVector = perturbationVector
        self.adversaryKnowledge = adversaryKnowledge

    def __repr__(self) -> str:
        rep = f"AttackSimulationStrategy(accept_failures={self.accept_failures})"
        return rep

    def aggregate_fit(self, server_round: int, results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:

        if not results: # boilerplate code to handle exceptions
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        # convert results
        total_weights_results = [(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results]

        # calculate AGR
        parameters_aggregated = ndarrays_to_parameters(aggregate_krum( # this thing here is for krum-AGR
            total_weights_results, self.num_malicious_clients, self.num_clients_to_keep))

        # boilerplate code if aggregation metrics are provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated



In [70]:
# create a strategy
strategy = AttackSimulationStrategy(fraction_fit=0.8, fraction_evaluate=0.8, min_fit_clients=5,
    min_evaluate_clients=5, min_available_clients=10, perturbationVector="InverseStd", adversaryKnowledge="agr-only",
    evaluate_metrics_aggregation_fn=weighted_average,) # <-- pass the metric aggregation function. This function will be called
                                                       # in every federated learning round for evaluation (it aggregates the
                                                       # client-side evaluation metrics in the server)

# start simulation
fl.simulation.start_simulation(
    client_fn=client_fn, # out factory function
    num_clients=NUM_CLIENTS, # number of clients
    config=fl.server.ServerConfig(num_rounds=5), # number of federated learning rounds
    strategy=strategy, # our attack simulation strategy
    client_resources=None,
)

INFO flwr 2023-10-17 19:33:14,286 | app.py:175 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
2023-10-17 19:33:21,181	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2023-10-17 19:33:26,072 | app.py:210 | Flower VCE: Ray initialized with resources: {'object_store_memory': 3916605849.0, 'memory': 7833211700.0, 'CPU': 2.0, 'node:__internal_head__': 1.0, 'GPU': 1.0, 'node:172.28.0.12': 1.0}
INFO:flwr:Flower VCE: Ray initialized with resources: {'object_store_memory': 3916605849.0, 'memory': 7833211700.0, 'CPU': 2.0, 'node:__internal_head__': 1.0, 'GPU': 1.0, 'node:172.28.0.12': 1.0}
INFO flwr 2023-10-17 19:33:26,076 | app.py:218 | No `client_resources` specified. Using minimal resources for clients.
INFO:flwr:No `client_resources` specified. Using minimal resources for clients.
INFO flwr 2023-10-17 19:33:26,098 | app.py:224 | Flower VCE: Resourc

History (loss, distributed):
	round 1: 0.07381261092424393
	round 2: 0.05530920425057411
	round 3: 0.05014918866753578
	round 4: 0.049052857935428615
	round 5: 0.04509883189201355
History (metrics, distributed, evaluate):
{'accuracy': [(1, 0.09774999999999999), (2, 0.38125), (3, 0.426), (4, 0.45524999999999993), (5, 0.48575)]}