
Federated Adversarial Learning/Training Framework. A testing ground for conducting relevant research.


Federated Adversarial Learning


Federated Adversarial Training improves a model's adversarial robustness while preserving data privacy under federated learning settings. Typically, data is biased and stored at remote clients, and each client performs adversarial training on its skewed local data before uploading the trained weights to a central server. The server then aggregates the weights from the different clients and sends the aggregated model back to them. Simple as this seems, many problems need to be addressed, such as communication constraints, computational budget, robustness fairness and non-IID data. Many methods have been proposed to overcome these issues: FAT is the first work to study this setting; FedDynAT reduces model shift by dynamically adjusting the number of local epochs in each communication round; FedRBN proposes an effective approach to propagate adversarial robustness among non-IID users; GEAR and CalFAT focus on margin calibration under local class heterogeneity; and Li et al. give a comprehensive convergence analysis.

Fig.1 - FAL illustration
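As a rough illustration of the server-side aggregation step described above, here is a minimal sketch of FedAvg-style weighted averaging, where each client's parameters are weighted by its number of local samples. The function name `fedavg`, the dict-of-arrays parameter format and the example numbers are hypothetical and not taken from this repository.

```python
import numpy as np


def fedavg(client_weights, client_sizes):
    """Average client parameters, weighting each client by its local data size.

    client_weights: list of dicts mapping parameter name -> np.ndarray
    client_sizes:   number of local training samples on each client
    """
    total = float(sum(client_sizes))
    global_weights = {}
    for name in client_weights[0]:
        global_weights[name] = sum(
            (n / total) * weights[name]
            for weights, n in zip(client_weights, client_sizes)
        )
    return global_weights


# Two hypothetical clients holding 1 and 3 samples respectively.
clients = [{"w": np.array([1.0, 2.0])}, {"w": np.array([3.0, 6.0])}]
print(fedavg(clients, client_sizes=[1, 3])["w"])  # 0.25*[1, 2] + 0.75*[3, 6] = [2.5, 5.0]
```

The other aggregation rules in this repository change how the clients' contributions are computed or weighted, but they start from the same idea of combining client updates on the server.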


To this end, this repository is a testing ground where I've implemented FAT, FedPGD, FedTRADES, FedMART, FedGAIRAT, CalFAT and Ensemble Distillation, as well as aggregation algorithms such as FedAvg, FedProx, SCAFFOLD, FedNova and FedWAvg. I haven't run a full test yet, so be careful and be prepared to improvise a little. My code is built on top of and You could also treat this as an unofficial implementation of:

General Code Structure: structure
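FedProx, one of the aggregation algorithms listed above, adds a proximal term (μ/2)·‖w − w_global‖² to every client's local objective so that local models do not drift too far from the global model; μ corresponds to `_C.USER.MU` in the configuration file. The sketch below is only an illustration under stated assumptions: a toy least-squares loss and full-batch gradient descent stand in for the adversarial training loss used here, and `fedprox_local_update` is a hypothetical name.

```python
import numpy as np


def fedprox_local_update(w_global, X, y, mu=1.0, lr=0.1, epochs=200):
    """Full-batch gradient descent on a least-squares loss plus the FedProx
    proximal term (mu / 2) * ||w - w_global||^2."""
    w_global = np.asarray(w_global, dtype=float)
    w = w_global.copy()
    for _ in range(epochs):
        # Gradient of the local loss (1 / 2n) * ||X w - y||^2.
        grad_loss = X.T @ (X @ w - y) / len(y)
        # Gradient of the proximal term: pulls w back towards the global model.
        grad_prox = mu * (w - w_global)
        w -= lr * (grad_loss + grad_prox)
    return w


X = np.eye(2)
y = np.array([1.0, 1.0])
w_global = np.zeros(2)
print(fedprox_local_update(w_global, X, y, mu=0.0))  # close to [1, 1]
print(fedprox_local_update(w_global, X, y, mu=1.0))  # close to [1/3, 1/3]
```

With `mu=0` the client converges to its own least-squares solution; a larger `mu` keeps the local model closer to `w_global`, which is the intended effect under non-IID data.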

🚀 Getting Started

First, clone this repo, then run:

pip install -r requirements.txt

I use YACS for configuration files. A detailed explanation of the parameters is in the config file; details about yacs:

from yacs.config import CfgNode as CN

_C = CN()
# set up random seed.
_C.SEED = 3047
# ["mnist", "fmnist", "femnist", "cifar10", "cifar100", "svhn"]
_C.DATASET = "mnist"
# ["simple-cnn", "moderate-cnn", "vggs", "resnet18","resnet50", "alexnet", "wideresnet", "small-cnn", "nin"]
_C.MODEL = "vgg16"
# data partition strategy, more details:
_C.PARTITION = "noniid-labeldir"
# concentration number for dirichlet distribution
_C.BETA = 0.5
# the frequency to save models
_C.CKPT = 100
# whether to assign helpers to clients during the training process. If assigned, you might want to tweak the loss both in the Loss folder and in the local updates.
_C.HELP = False

# Data Path and Log Path
_C.PATH = CN()
_C.PATH.DATADIR = "../data/"
_C.PATH.DISTILLDATA = "../data/distillation_data"
_C.PATH.LOGDIR = "./logs/"

_C.SERVER = CN()
# Number of remote clients
# Communication rounds
# Aggregation algorithms: ["FedAvg", "FedProx", "FedNova", "SCAFFOLD", "FedWAvg"]
# Fraction of clients used for updating in each communication round (float, 1.0)
# Scale factor in FedWAvg algorithm
# The frequency to assign new helpers for each client
# Learning rate for the distillation process; the distillation optimizer is kept the same as the local training optimizer, but without weight decay
_C.SERVER.LR = 0.0001

_C.USER = CN()
# Number of local training epochs
# Training batch size
# Learning rate
_C.USER.LR = 0.001
# optimizer, ["SGD", "Adam", "Adagrad"]
# momentum in SGD
# Whether to perform adversarial training and testing.
_C.USER.ADV = True
# weight decay in optimizer
# fraction of local data to do adversarial training
# The proximal term parameter for FedProx
_C.USER.MU = 1.0
# the number of helpers (<= number of clients)

def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()
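`PARTITION = "noniid-labeldir"` together with `BETA` refers to Dirichlet label skew as studied in [3]: for each class, the share of its samples assigned to each client is drawn from Dir(β), so a smaller β gives more skewed clients. Below is a simplified sketch assuming NumPy; the function name is hypothetical, and reference implementations may add extra constraints (for example, a minimum number of samples per client).

```python
import numpy as np


def dirichlet_label_partition(labels, n_clients, beta, seed=0):
    """Assign sample indices to clients with per-class proportions ~ Dir(beta)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(n_clients)]
    for k in np.unique(labels):
        idx_k = np.flatnonzero(labels == k)
        rng.shuffle(idx_k)
        # Share of class k given to each client; a smaller beta gives more skew.
        proportions = rng.dirichlet(np.full(n_clients, beta))
        # Turn cumulative shares into split points along idx_k.
        cuts = (np.cumsum(proportions)[:-1] * len(idx_k)).astype(int)
        for client, part in enumerate(np.split(idx_k, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices


labels = np.repeat(np.arange(10), 100)  # 10 classes, 100 samples each
parts = dirichlet_label_partition(labels, n_clients=5, beta=0.5, seed=3047)
print([len(p) for p in parts])  # uneven client sizes and label mixes
```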

You can write new experiment configuration YAML files in ./experiments; details of the different experiments:

To run CalFAT

python3 ./experiments/align_with_calfat.yaml

If you want ensemble distillation for model fusion

Ensemble Distillation requires an unlabeled dataset on the server. You can generate synthetic data with BigGAN as follows:

python3 -s 33 -b 64 -nipc 128 -o ../data/distillation_data

or use an existing dataset, such as stl10 for cifar10.
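Ensemble Distillation [4] fine-tunes the aggregated model on the server's unlabeled data so that its predictions match the ensemble of client models. The sketch below only computes that objective, assuming the teacher is formed by averaging the client logits and the loss is KL(teacher ‖ student); the function names are hypothetical, and the optimizer loop (which, per the config comments, uses `_C.SERVER.LR`) is omitted.

```python
import numpy as np


def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def ensemble_distillation_loss(client_logits, student_logits, temperature=1.0):
    """Mean KL(teacher || student) over an unlabeled batch.

    client_logits:  (n_clients, batch, n_classes) logits of every client model
                    on the same unlabeled batch; the teacher averages them.
    student_logits: (batch, n_classes) logits of the aggregated server model.
    """
    log_teacher = log_softmax(client_logits.mean(axis=0) / temperature)
    log_student = log_softmax(student_logits / temperature)
    kl = np.sum(np.exp(log_teacher) * (log_teacher - log_student), axis=-1)
    return kl.mean()


rng = np.random.default_rng(0)
client_logits = rng.normal(size=(3, 4, 10))  # 3 clients, batch of 4, 10 classes
print(ensemble_distillation_loss(client_logits, np.zeros((4, 10))))  # > 0
print(ensemble_distillation_loss(client_logits, client_logits.mean(axis=0)))  # ~0
```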

To run with a different local training loss

Several losses are provided in ./Loss. To switch, change the loss term used in local training, e.g. replace calibrated_loss with trades_loss in local_train_fedavg.
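For reference, the TRADES objective [5] combines cross-entropy on clean inputs with a KL term between the predictions on clean and adversarial inputs, weighted by β. The sketch below only evaluates this objective from precomputed logits, assuming NumPy; in TRADES the adversarial input is itself generated by maximizing the KL term, which is omitted here, and `trades_objective` is a hypothetical name, not the `trades_loss` in ./Loss.

```python
import numpy as np


def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def trades_objective(logits_clean, logits_adv, labels, beta=6.0):
    """Cross-entropy on clean inputs + beta * KL(p(x) || p(x_adv)), batch-averaged."""
    log_p_clean = log_softmax(logits_clean)
    log_p_adv = log_softmax(logits_adv)
    natural = -log_p_clean[np.arange(len(labels)), labels].mean()
    robust = np.sum(np.exp(log_p_clean) * (log_p_clean - log_p_adv), axis=-1).mean()
    return natural + beta * robust


labels = np.array([0, 2])
clean = np.zeros((2, 3))                     # uniform predictions
adv = np.array([[5.0, 0, 0], [0, 5.0, 0]])   # predictions shifted by the attack
print(trades_objective(clean, clean, labels))  # log(3) ≈ 1.0986: no KL penalty
print(trades_objective(clean, adv, labels))    # larger: KL term is positive
```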

By default, the model with the best robust accuracy and the final model are saved. To attack trained models, the torchattacks framework is used:

python3 -b 64 -d cifar10 -m simple-cnn -s 2022 --model-path /path/to/saved/model/
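Robust accuracy is simply accuracy measured on adversarial examples. To show what such an attack does without depending on torchattacks, the sketch below implements L∞ PGD by hand for a toy linear softmax classifier in NumPy (no random start and no clipping to a valid pixel range); `pgd_linf` and `accuracy` are hypothetical helpers, not part of this repository or of torchattacks.

```python
import numpy as np


def pgd_linf(W, b, x, y, eps=0.1, alpha=0.02, steps=10):
    """L-infinity PGD against a linear softmax classifier with logits x @ W + b."""
    onehot = np.eye(W.shape[1])[y]
    x_adv = x.copy()
    for _ in range(steps):
        logits = x_adv @ W + b
        p = np.exp(logits - logits.max(axis=1, keepdims=True))
        p /= p.sum(axis=1, keepdims=True)
        # Gradient of the cross-entropy loss w.r.t. the input.
        grad = (p - onehot) @ W.T
        x_adv = x_adv + alpha * np.sign(grad)     # step that increases the loss
        x_adv = np.clip(x_adv, x - eps, x + eps)  # project back into the eps-ball
    return x_adv


def accuracy(W, b, x, y):
    return float(np.mean(np.argmax(x @ W + b, axis=1) == y))


W, b = np.eye(2), np.zeros(2)
x, y = np.array([[1.0, 0.5]]), np.array([0])
x_adv = pgd_linf(W, b, x, y, eps=0.3, alpha=0.02, steps=20)
print(accuracy(W, b, x, y), accuracy(W, b, x_adv, y))  # 1.0 0.0
```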

🔥 Recommended Resources


[1] C. Chen, Y. Liu, X. Ma, and L. Lyu, “CalFAT: Calibrated Federated Adversarial Training with Label Skewness.” arXiv, May 30, 2022. doi: 10.48550/arXiv.2205.14926.
[2] G. Zizzo, A. Rawat, M. Sinn, and B. Buesser, “FAT: Federated Adversarial Training.” arXiv, Dec. 03, 2020. doi: 10.48550/arXiv.2012.01791.
[3] Q. Li, Y. Diao, Q. Chen, and B. He, “Federated Learning on Non-IID Data Silos: An Experimental Study.” arXiv, Oct. 28, 2021.
[4] T. Lin, L. Kong, S. U. Stich, and M. Jaggi, “Ensemble Distillation for Robust Model Fusion in Federated Learning.” arXiv, Mar. 27, 2021. doi: 10.48550/arXiv.2006.07242.
[5] H. Zhang, Y. Yu, J. Jiao, E. P. Xing, L. E. Ghaoui, and M. I. Jordan, “Theoretically Principled Trade-off between Robustness and Accuracy.” arXiv, Jun. 24, 2019.
