# Imports


Standard-library and third-party package imports


In [None]:
from functools import partial
from typing import List

import torch

Project Imports


In [None]:
from attacks.weight_attack import weight_attack
from federated.client import (
    AdversarialClient,
    PrivateClient,
    PrivateAdversarialClient,
    PublicClient,
)
from federated.server import Server
from data_loaders.mnist.data_loader import DataLoader
from models.mnist.mnist_cnn import MnistCNN as Model
from setup import FederatedLearningConfig 
from training import train_model, test_model, TClient
from utilities import save_results

# Setting up the config, data loader, clients, and server


Build the configuration, select the torch device, and create the data loader


In [None]:
# Experiment configuration. The comments describe how this notebook consumes
# each setting; the authoritative definitions live in FederatedLearningConfig.
config = FederatedLearningConfig(
    # Population: n_clients clients in total, the first n_adv are adversarial.
    n_clients=10,
    n_adv=2,
    # Forwarded to weight_attack as its noise_multiplier.
    noise_multiplier=0.1,
    # Training rounds for train_model (also num_epochs for private clients).
    n_rounds=1,
    # Forwarded unchanged to train_model as L; semantics defined there.
    L=-1,
    # Data loading: batch size and IID vs. non-IID client partitioning.
    batch_size=64,
    should_use_iid_training_data=True,
    # Server-side defence: passed as weight_threshold /
    # enable_adversary_protection to Server.
    trustworthy_threshold=0,
    should_enable_adv_protection=True,
    # Private (Private*Client) mode and the privacy targets those clients get.
    should_use_private_clients=False,
    target_epsilon=None,
    target_delta=None,
)

# Choose the compute device: Apple MPS first, then CUDA, otherwise CPU.
# The conditional expression probes CUDA only when MPS is unavailable.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [None]:
# Data-related settings pulled from the config; ``n_clients`` is reused later
# when the regular clients are built.
n_clients = config.n_clients
use_iid = config.should_use_iid_training_data
batch_size = config.batch_size

# One loader object provides the per-client training loaders plus the
# validation and test loaders used by the server and evaluation cells.
# NOTE(review): test_split / val_split are presumably dataset fractions —
# confirm against the DataLoader implementation.
dataloader = DataLoader(
    batch_size=batch_size,
    device=DEVICE,
    n_clients=n_clients,
    use_iid=use_iid,
    test_split=0.2,
    val_split=0.2,
)

Set up the attack, the adversarial clients, and the regular (non-adversarial) clients


In [None]:
# Attack shared by every adversarial client, with its noise scale taken from
# the config.
attack = partial(weight_attack, noise_multiplier=config.noise_multiplier)

n_adv = config.n_adv
num_rounds = config.n_rounds
should_use_private = config.should_use_private_clients
target_epsilon = config.target_epsilon
target_delta = config.target_delta

# Passed as ``max_grad_norm`` to every private client (previously duplicated
# as a literal in both constructors).
MAX_GRAD_NORM = 100.0

# Clients 0 .. n_adv-1 are adversarial and n_adv .. n_clients-1 are regular;
# client i trains on dataloader.train_loaders[i]. A negative n_adv would make
# range(n_adv, n_clients) reuse loaders via negative indexing, and an n_adv
# above n_clients would build more clients than configured, so fail fast.
if not 0 <= n_adv <= n_clients:
    raise ValueError(
        f"config.n_adv must be between 0 and n_clients ({n_clients}), "
        f"got {n_adv}"
    )


def _make_adversarial_client(i: int) -> TClient:
    """Build the adversarial client that trains on ``train_loaders[i]``.

    Returns a PrivateAdversarialClient when private clients are enabled in the
    config, otherwise an AdversarialClient. Both receive the shared ``attack``.
    """
    if should_use_private:
        return PrivateAdversarialClient(
            id=f"Private Adversarial Client {i}",
            model=Model(),
            device=DEVICE,
            data_loader=dataloader.train_loaders[i],
            target_epsilon=target_epsilon,
            target_delta=target_delta,
            num_epochs=num_rounds,
            max_grad_norm=MAX_GRAD_NORM,
            attack=attack,
        )
    return AdversarialClient(
        id=f"Adversarial Client {i}",
        model=Model(),
        device=DEVICE,
        data_loader=dataloader.train_loaders[i],
        attack=attack,
    )


def _make_regular_client(i: int) -> TClient:
    """Build the non-adversarial client that trains on ``train_loaders[i]``.

    Returns a PrivateClient when private clients are enabled in the config,
    otherwise a PublicClient.
    """
    if should_use_private:
        return PrivateClient(
            id=f"Private Client {i}",
            model=Model(),
            device=DEVICE,
            data_loader=dataloader.train_loaders[i],
            target_epsilon=target_epsilon,
            target_delta=target_delta,
            num_epochs=num_rounds,
            max_grad_norm=MAX_GRAD_NORM,
        )
    return PublicClient(
        id=f"Client {i}",
        model=Model(),
        device=DEVICE,
        data_loader=dataloader.train_loaders[i],
    )


adv_clients: List[TClient] = [_make_adversarial_client(i) for i in range(n_adv)]
non_adv_clients: List[TClient] = [
    _make_regular_client(i) for i in range(n_adv, n_clients)
]

# Adversarial clients first, then regular ones (same order as client indices).
clients = adv_clients + non_adv_clients

clients = adv_clients + non_adv_clients

Set up the server


In [None]:
# Server settings from the config: trust threshold and adversary protection.
trust_threshold = config.trustworthy_threshold
enable_adv_protection = config.should_enable_adv_protection

# The server gets its own fresh Model instance and the validation loader.
server = Server(
    model=Model(),
    device=DEVICE,
    validation_data=dataloader.val_loader,
    weight_threshold=trust_threshold,
    enable_adversary_protection=enable_adv_protection,
)

# Training


In [None]:
# Run num_rounds of federated training across all clients. ``L`` comes
# straight from the config; its meaning is defined by train_model.
train_model(
    server=server,
    clients=clients,
    num_rounds=num_rounds,
    L=config.L,
    is_verbose=False,
)

# Evaluation and Saving


In [None]:
# Evaluate the server's model on the held-out test loader. This is left as
# a bare expression so the notebook displays whatever test_model returns.
test_model(model=server.model, test_loader=dataloader.test_loader)

In [None]:
# Persist the run: the trained server, all clients and the config used.
save_results(server=server, clients=clients, config=config)