In [None]:
# Copyright (C) 2022-2024 TU Darmstadt
# SPDX-License-Identifier: Apache-2.0

# -----------------------------------------------------------
# Primary author: Phillip Rieger <phillip.rieger@trust.tu-darmstadt.de>
# Co-authored-by: Torsten Krauss <torsten.krauss@uni-wuerzburg.de>
# ------------------------------------------------------------

import argparse
import os
import pickle
import time
import warnings
from copy import deepcopy
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
import torch.optim as optim
from torchvision import transforms, datasets
from sklearn.cluster import AgglomerativeClustering, DBSCAN

from CrowdGuardClientValidation import CrowdGuardClientValidation
from openfl.experimental.workflow.interface import Aggregator, Collaborator, FLSpec
from openfl.experimental.workflow.placement import aggregator, collaborator
from openfl.experimental.workflow.runtime import LocalRuntime
from cifar10_crowdguard import MEAN, STD_DEV, poison_data, seed_random_generators
from cifar10_crowdguard import BATCH_SIZE_TRAIN, BATCH_SIZE_TEST, Net, test, default_optimizer
from cifar10_crowdguard import FederatedFlow
from cifar10_crowdguard import PRETRAINED_MODEL_FILE, download_pretrained_model
warnings.filterwarnings("ignore")

In [None]:
# Federation size and attacker share.
TOTAL_CLIENT_NUMBER = 4
# PMR: fraction of clients treated as malicious. Any positive value yields at
# least one malicious client.
PMR = 0.25
if PMR > 0:
    NUMBER_OF_MALICIOUS_CLIENTS = max(1, int(TOTAL_CLIENT_NUMBER * PMR))
else:
    NUMBER_OF_MALICIOUS_CLIENTS = 0
NUMBER_OF_BENIGN_CLIENTS = TOTAL_CLIENT_NUMBER - NUMBER_OF_MALICIOUS_CLIENTS
# Number of federated communication rounds.
NUMBER_OF_ROUNDS = 10

In [None]:
class CommandLineArgumentSimulator:
    """Stand-in for a parsed command-line namespace when running in a notebook.

    Bundles the experiment settings read by later cells as ``args.<name>``.
    """

    def __init__(self):
        # Fractions of the pooled CIFAR-10 samples assigned to the test and
        # train splits, respectively.
        self.test_dataset_ratio, self.train_dataset_ratio = 0.4, 0.4
        self.log_dir = 'test_debug'
        # Number of federated communication rounds handed to the flow.
        self.comm_round = NUMBER_OF_ROUNDS
        self.flow_internal_loop_test = False
        # Optimizer kind passed to default_optimizer(...).
        self.optimizer_type = 'SGD'


args = CommandLineArgumentSimulator()

In [None]:
# Fetch the pretrained model checkpoint. Presumably this writes
# PRETRAINED_MODEL_FILE, which a later cell loads into Net via torch.load.
download_pretrained_model()

In [None]:
# Federation participants: one aggregator plus benign and malicious
# collaborators. Malicious clients are identified later by the 'malicious'
# substring in their name.
aggregator_object = Aggregator()
aggregator_object.private_attributes = {}
collaborator_names = (
    [f'benign_{i:02d}' for i in range(NUMBER_OF_BENIGN_CLIENTS)]
    + [f'malicious_{i:02d}' for i in range(NUMBER_OF_MALICIOUS_CLIENTS)]
)
collaborators = [Collaborator(name=name) for name in collaborator_names]

# Use the default CUDA device rather than a fixed index such as "cuda:1",
# which fails with "invalid device ordinal" on machines with fewer than two
# GPUs. To pin a specific GPU, set CUDA_VISIBLE_DEVICES before starting.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Normalise images with MEAN / STD_DEV from cifar10_crowdguard.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD_DEV)])

# Materialise both official CIFAR-10 splits as lists of (image, label) pairs.
# The next cell pools and re-splits them.
cifar_train = list(datasets.CIFAR10(root="./data", train=True, download=True, transform=transform))
cifar_test = list(datasets.CIFAR10(root="./data", train=False, download=True, transform=transform))
cifar_all = cifar_train + cifar_test
X = torch.stack([image for image, _ in cifar_all])
# Build labels directly as int64. This avoids the platform-dependent NumPy
# integer dtype that torch.LongTensor(np.array(...)) would receive.
Y = torch.tensor([label for _, label in cifar_all], dtype=torch.long)

In [None]:
# Shuffle the pooled samples and carve out disjoint train/test splits whose
# sizes are fractions (args.*_dataset_ratio) of the pool.
# Assumes seed_random_generators(0) seeds NumPy's global RNG, which
# np.random.shuffle uses; confirm in cifar10_crowdguard.
seed_random_generators(0)
shuffled_indices = np.arange(X.shape[0])
np.random.shuffle(shuffled_indices)

N_total_samples = len(cifar_test) + len(cifar_train)
train_dataset_size = int(N_total_samples * args.train_dataset_ratio)
test_dataset_size = int(N_total_samples * args.test_dataset_ratio)
# Ratios summing past 1 would otherwise silently truncate the test split.
assert train_dataset_size + test_dataset_size <= N_total_samples, (
    "train_dataset_ratio + test_dataset_ratio must not exceed 1"
)

# Keep X / Y untouched, so that re-running this cell reproduces the same split
# instead of reshuffling already-shuffled data.
X_shuffled = X[shuffled_indices]
Y_shuffled = Y[shuffled_indices]

train_dataset_data = X_shuffled[:train_dataset_size]
train_dataset_targets = Y_shuffled[:train_dataset_size]

test_dataset_data = X_shuffled[train_dataset_size:train_dataset_size + test_dataset_size]
test_dataset_targets = Y_shuffled[train_dataset_size:train_dataset_size + test_dataset_size]
# Report the train count from the train targets. It previously printed the
# test count twice.
print(f"Dataset info (total {N_total_samples}): train - {train_dataset_targets.shape[0]}, "
      f"test - {test_dataset_targets.shape[0]}, ")


In [None]:
for idx, collab in enumerate(collaborators):
    # Each client gets an interleaved shard of the shared train and test
    # splits: every len(collaborators)-th sample, starting at its own index.
    # The shards of different clients are therefore disjoint.
    shard = slice(idx, None, len(collaborators))
    benign_training_X, benign_training_Y = train_dataset_data[shard], train_dataset_targets[shard]

    # Clients whose name contains 'malicious' train on poisoned data. All
    # other clients keep their shard unchanged.
    if 'malicious' not in collab.name:
        local_train_data, local_train_targets = benign_training_X, benign_training_Y
    else:
        local_train_data, local_train_targets = poison_data(benign_training_X, benign_training_Y)

    local_test_data, local_test_targets = test_dataset_data[shard], test_dataset_targets[shard]
    # Poisoned copy of the local test shard, used for the 'Backdoor' evaluation.
    # pdr=1.0 presumably poisons every sample; confirm in poison_data.
    poison_test_data, poison_test_targets = poison_data(local_test_data, local_test_targets,
                                                        pdr=1.0)

    # (tensors, batch size, shuffle) for each loader stored in private_attributes.
    loader_specs = {
        "train_loader": ((local_train_data, local_train_targets), BATCH_SIZE_TRAIN, True),
        "test_loader": ((local_test_data, local_test_targets), BATCH_SIZE_TEST, False),
        "backdoor_test_loader": ((poison_test_data, poison_test_targets), BATCH_SIZE_TEST, False),
    }
    collab.private_attributes = {
        loader_name: torch.utils.data.DataLoader(
            TensorDataset(*tensors), batch_size=batch_size, shuffle=shuffle
        )
        for loader_name, (tensors, batch_size, shuffle) in loader_specs.items()
    }

In [None]:
# Sanity-check the pretrained checkpoint before federated training.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
pretrained_weights = torch.load(PRETRAINED_MODEL_FILE, map_location=device)
test_model = Net().to(device)
test_model.load_state_dict(pretrained_weights)

# Evaluate on an explicitly chosen client rather than on the `collab` variable
# leaked from the partitioning loop. That variable was the last collaborator,
# which is malicious (with a poisoned training shard) whenever
# NUMBER_OF_MALICIOUS_CLIENTS > 0. collaborators[0] is benign whenever
# NUMBER_OF_BENIGN_CLIENTS > 0.
reference_collab = collaborators[0]
reference_loaders = reference_collab.private_attributes
test(test_model, reference_loaders['train_loader'], device, test_train='Train')
test(test_model, reference_loaders['test_loader'], device)
test(test_model, reference_loaders['backdoor_test_loader'], device, mode='Backdoor')

In [None]:
local_runtime = LocalRuntime(aggregator=aggregator_object, collaborators=collaborators)
print(f"Local runtime collaborators = {local_runtime.collaborators}")

# Global model, initialised from the pretrained checkpoint loaded earlier.
model = Net()
model.load_state_dict(pretrained_weights)
top_model_accuracy = 0

# One optimizer per collaborator, keyed by collaborator name. A generator is
# used (rather than a for-loop) so no loop variable leaks into the notebook
# namespace, where a name like `collaborator` would shadow the imported
# openfl `collaborator` decorator.
optimizers = dict(
    (collab_obj.name, default_optimizer(model, optimizer_type=args.optimizer_type))
    for collab_obj in collaborators
)

# Share of malicious clients, passed to FederatedFlow together with the
# 'CrowdGuard' setting.
malicious_fraction = NUMBER_OF_MALICIOUS_CLIENTS / TOTAL_CLIENT_NUMBER
flflow = FederatedFlow(
    model, optimizers, device, args.comm_round, top_model_accuracy,
    malicious_fraction, 'CrowdGuard',
)

flflow.runtime = local_runtime
flflow.run()