# Imports


Standard Library & Package Imports


In [1]:
import itertools
from functools import partial
from typing import List

import torch

Project Imports


In [2]:
from attacks.weight_attack import weight_attack
from federated.client import (
    AdversarialClient,
    PrivateClient,
    PrivateAdversarialClient,
    PublicClient,
    TClient
)
from federated.server import Server
from data_loaders.mnist.data_loader import DataLoader
from models.mnist.mnist_cnn import MnistCNN as Model
from setup import FederatedLearningConfig 
from training import train_model, test_model 
from utilities import save_results

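# Setting up the device and experiment configurations

(The data loader, clients, and server are built separately for each configuration inside the training loop below.)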


In [3]:
# Pick the accelerator: Apple MPS first, then CUDA, otherwise fall back to the CPU.
DEVICE = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

In [4]:
def calculate_trustworthy_thresholds(n_clients, factors=(1.1, 1.3, 1.5, 2.0)):
    """Return the candidate trustworthy thresholds to sweep for ``n_clients`` clients.

    The list always starts with ``0`` (no threshold). Each further entry is the
    uniform per-client share ``1 / n_clients`` divided by one of ``factors``,
    i.e. ``1 / (factor * n_clients)``; with every factor > 1, each non-zero
    threshold lies strictly below the uniform share.

    Args:
        n_clients: Number of clients in the federation; must be >= 1.
        factors: Divisors applied to the uniform share. The default reproduces
            the original sweep (1.1, 1.3, 1.5, 2.0).

    Returns:
        ``[0, 1 / (factors[0] * n_clients), 1 / (factors[1] * n_clients), ...]``.

    Raises:
        ValueError: If ``n_clients < 1`` (the uniform share is undefined, and a
            negative count would silently yield negative thresholds).
    """
    if n_clients < 1:
        raise ValueError(f"n_clients must be >= 1, got {n_clients}")
    # Same arithmetic as the original literal list, so results are bit-identical.
    return [0] + [1 / (factor * n_clients) for factor in factors]

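Build the grid of experiment configurations: every combination of the option lists below, expanded with per-configuration trustworthy thresholds and filtered to drop invalid or redundant combinations. (The torch device is selected above; the data loader is created per configuration in the training loop.)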


In [11]:
# Option lists for the sweep; create_configs() builds every combination of them.
CONFIG_OPTIONS = dict(
    # Total number of clients in the federation.
    n_clients=[1, 2, 5, 10],
    # How many of those clients are adversarial (the first n_adv clients).
    n_adv=[0, 1, 2, 5],
    # Forwarded to weight_attack; only matters when n_adv > 0.
    noise_multiplier=[0.001, 0.01, 0.1, 1, 10],
    # Number of federated training rounds (train_model's num_rounds).
    n_rounds=[40],
    # Forwarded to train_model(L=...). NOTE(review): document what -1 means there.
    L=[-1],
    # Batch size handed to the DataLoader.
    batch_size=[64],
    # Whether client data is split IID (DataLoader use_iid).
    should_use_iid_training_data=[False, True],
    # Whether the server's adversary protection is enabled.
    should_enable_adv_protection=[False, True],
    # Selects the Private* client classes; the privacy targets below only apply then.
    should_use_private_clients=[False],
    target_epsilon=[None],
    target_delta=[None],
)


# Function to create configurations with dynamic trustworthy thresholds
def _is_valid_config(config, baseline_noise=0.001):
    """Return True if a grid point should be run.

    None of these rules look at the trustworthy threshold, so they can be
    applied before the threshold expansion.

    Args:
        config: One grid point: a dict with at least ``n_clients``, ``n_adv``,
            ``noise_multiplier``, ``should_enable_adv_protection`` and
            ``should_use_iid_training_data``.
        baseline_noise: The single ``noise_multiplier`` kept for runs without
            adversaries.
    """
    n_clients = config["n_clients"]
    n_adv = config["n_adv"]
    has_adversaries = n_adv != 0
    protection = config["should_enable_adv_protection"]
    single_client = n_clients <= 1
    return (
        # cannot have more adversaries than clients
        n_adv <= n_clients
        # without adversaries the attack is never used, so one noise level suffices
        and (has_adversaries or config["noise_multiplier"] == baseline_noise)
        # adversary protection is only swept when there are adversaries
        and (has_adversaries or not protection)
        # a single client is run only as the honest, IID, unprotected baseline
        and not (single_client and has_adversaries)
        and not (single_client and protection)
        and not (single_client and not config["should_use_iid_training_data"])
    )


def create_configs(configs, thresholds_fn=None):
    """Expand a dict of option lists into a flat list of run configurations.

    Every combination of the ``configs`` values (Cartesian product, in key
    order) is built. Combinations rejected by ``_is_valid_config`` are dropped.
    Each survivor gets a ``trustworthy_threshold`` key:

    * threshold ``0`` is always emitted;
    * when ``should_enable_adv_protection`` is true, one extra configuration is
      emitted for every non-zero value of ``thresholds_fn(n_clients)``.

    Args:
        configs: Mapping of option name -> list of candidate values. Not mutated.
        thresholds_fn: Callable ``n_clients -> iterable of thresholds``. Defaults
            to ``calculate_trustworthy_thresholds``. Zero entries are ignored,
            since threshold 0 is always included.

    Returns:
        A list of new dicts, one per run, with the keys of ``configs`` (in order)
        followed by ``trustworthy_threshold``.
    """
    if thresholds_fn is None:
        thresholds_fn = calculate_trustworthy_thresholds

    keys = list(configs)
    extended_configs = []
    for values in itertools.product(*configs.values()):
        config = dict(zip(keys, values))
        if not _is_valid_config(config):
            continue

        extended_configs.append({**config, "trustworthy_threshold": 0})
        if config["should_enable_adv_protection"]:
            # Skip zeros rather than slicing off index 0, so the expansion does
            # not silently depend on thresholds_fn putting 0 first.
            extended_configs.extend(
                {**config, "trustworthy_threshold": threshold}
                for threshold in thresholds_fn(config["n_clients"])
                if threshold != 0
            )
    return extended_configs

configs = create_configs(configs=CONFIG_OPTIONS)  # flat list of run configurations

In [12]:
# Sanity-check the grid before starting the long-running sweep, then show its size.
assert all(cfg["n_adv"] <= cfg["n_clients"] for cfg in configs), "n_adv exceeds n_clients"
assert len({tuple(cfg.items()) for cfg in configs}) == len(configs), "duplicate configurations"
len(configs)

487

In [13]:
# Preview the first few configurations; the full list is too long to display usefully.
configs[:5]

[{'n_clients': 1,
  'n_adv': 0,
  'noise_multiplier': 0.001,
  'n_rounds': 40,
  'L': -1,
  'batch_size': 64,
  'should_use_iid_training_data': True,
  'should_enable_adv_protection': False,
  'should_use_private_clients': False,
  'target_epsilon': None,
  'target_delta': None,
  'trustworthy_threshold': 0},
 {'n_clients': 2,
  'n_adv': 0,
  'noise_multiplier': 0.001,
  'n_rounds': 40,
  'L': -1,
  'batch_size': 64,
  'should_use_iid_training_data': False,
  'should_enable_adv_protection': False,
  'should_use_private_clients': False,
  'target_epsilon': None,
  'target_delta': None,
  'trustworthy_threshold': 0},
 {'n_clients': 2,
  'n_adv': 0,
  'noise_multiplier': 0.001,
  'n_rounds': 40,
  'L': -1,
  'batch_size': 64,
  'should_use_iid_training_data': True,
  'should_enable_adv_protection': False,
  'should_use_private_clients': False,
  'target_epsilon': None,
  'target_delta': None,
  'trustworthy_threshold': 0},
 {'n_clients': 2,
  'n_adv': 1,
  'noise_multiplier': 0.001,
  'n_

# Training Loop


In [8]:
SEED = 42  # every configuration starts from this torch RNG state
START_CONFIG_INDEX = 0  # 0-based index into `configs`; raise it to resume an interrupted sweep
MAX_GRAD_NORM = 100.0  # passed as `max_grad_norm` to the private (DP) clients

for c_i, cfg in enumerate(configs[START_CONFIG_INDEX:], start=START_CONFIG_INDEX):
    print(f"Running configuration {c_i + 1}/{len(configs)}")
    fl_config = FederatedLearningConfig(**cfg)
    print(fl_config)

    # Re-seed for every configuration so a run does not depend on which
    # configurations ran before it; resuming via START_CONFIG_INDEX therefore
    # reproduces the results of an uninterrupted sweep.
    # NOTE(review): torch.manual_seed does not cover Python's `random` or NumPy;
    # seed those too if DataLoader / weight_attack draw from them.
    torch.manual_seed(SEED)

    batch_size = fl_config.batch_size
    enable_adv_protection = fl_config.should_enable_adv_protection
    n_adv = fl_config.n_adv
    n_clients = fl_config.n_clients
    noise_multiplier = fl_config.noise_multiplier
    num_rounds = fl_config.n_rounds
    should_use_private = fl_config.should_use_private_clients
    target_delta = fl_config.target_delta
    target_epsilon = fl_config.target_epsilon
    trust_threshold = fl_config.trustworthy_threshold
    use_iid = fl_config.should_use_iid_training_data

    # Only the adversarial clients receive the attack; noise_multiplier parameterizes it.
    attack = partial(weight_attack, noise_multiplier=noise_multiplier)

    dataloader = DataLoader(
        batch_size=batch_size,
        device=DEVICE,
        test_split=0.2,
        val_split=0.2,
        n_clients=n_clients,
        use_iid=use_iid,
    )

    # Extra constructor arguments shared by both private client classes.
    private_kwargs = dict(
        target_epsilon=target_epsilon,
        target_delta=target_delta,
        num_epochs=num_rounds,
        max_grad_norm=MAX_GRAD_NORM,
    )

    # Client i trains on dataloader.train_loaders[i]: indices [0, n_adv) are
    # adversarial, [n_adv, n_clients) are honest. Adversarial clients are built
    # first and come first in `clients`.
    adv_clients: List[TClient]
    non_adv_clients: List[TClient]
    if should_use_private:
        adv_clients = [
            PrivateAdversarialClient(
                id=f"Private Adversarial Client {i}",
                model=Model(),
                device=DEVICE,
                data_loader=dataloader.train_loaders[i],
                attack=attack,
                **private_kwargs,
            )
            for i in range(n_adv)
        ]
        non_adv_clients = [
            PrivateClient(
                id=f"Private Client {i}",
                model=Model(),
                device=DEVICE,
                data_loader=dataloader.train_loaders[i],
                **private_kwargs,
            )
            for i in range(n_adv, n_clients)
        ]
    else:
        adv_clients = [
            AdversarialClient(
                id=f"Adversarial Client {i}",
                model=Model(),
                device=DEVICE,
                data_loader=dataloader.train_loaders[i],
                attack=attack,
            )
            for i in range(n_adv)
        ]
        non_adv_clients = [
            PublicClient(
                id=f"Client {i}",
                model=Model(),
                device=DEVICE,
                data_loader=dataloader.train_loaders[i],
            )
            for i in range(n_adv, n_clients)
        ]
    clients = adv_clients + non_adv_clients

    server = Server(
        model=Model(),
        device=DEVICE,
        validation_data=dataloader.val_loader,
        enable_adversary_protection=enable_adv_protection,
        weight_threshold=trust_threshold,
    )

    train_model(
        server=server,
        num_rounds=num_rounds,
        clients=clients,
        L=fl_config.L,
        is_verbose=False,
    )

    test_model(
        model=server.model,
        test_loader=dataloader.test_loader,
    )

    save_results(
        server=server,
        clients=clients,
        config=fl_config,
    )

Running configuration 1/548
FederatedLearningConfig(n_clients=1, n_adv=0, noise_multiplier=0.001, n_rounds=40, L=-1, batch_size=64, trustworthy_threshold=0, should_use_iid_training_data=False, should_enable_adv_protection=False, should_use_private_clients=False, target_epsilon=None, target_delta=None)
Round 1 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 178.14it/s]


Accuracy of the server on the validation images: 95.47%
Round 3 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 187.67it/s]


Accuracy of the server on the validation images: 98.26%
Round 5 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 183.51it/s]


Accuracy of the server on the validation images: 98.74%
Round 7 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 186.55it/s]


Accuracy of the server on the validation images: 98.78%
Round 9 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 179.17it/s]


Accuracy of the server on the validation images: 98.94%
Round 11 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 172.70it/s]


Accuracy of the server on the validation images: 98.99%
Round 13 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 178.14it/s]


Accuracy of the server on the validation images: 99.05%
Round 15 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 173.75it/s]


Accuracy of the server on the validation images: 99.06%
Round 17 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 157.76it/s]


Accuracy of the server on the validation images: 98.99%
Round 19 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 165.37it/s]


Accuracy of the server on the validation images: 99.03%
Round 21 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 152.26it/s]


Accuracy of the server on the validation images: 99.15%
Round 23 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 179.12it/s]


Accuracy of the server on the validation images: 99.03%
Round 25 of 40


Validating model: 100%|██████████| 188/188 [00:01<00:00, 149.63it/s]


Accuracy of the server on the validation images: 99.07%


KeyboardInterrupt: 