
tmlr-group/SFAT

 
 


SFAT: Slack Federated Adversarial Training


This repo contains sample code for our proposed framework, Slack Federated Adversarial Training (SFAT), from our paper: Combating Exacerbated Heterogeneity for Robust Models in Federated Learning (ICLR 2023).

Figure. Framework overview of SFAT.

TODO:

  • Update the Project Page of SFAT.
  • Update the Presentation Slides and Video.
  • Released the arXiv version of SFAT.
  • Released the early version of sample code.

TL;DR

Our SFAT assigns a client-wise slack during aggregation to combat the intensified heterogeneity induced by the inner-maximization of adversarial training on heterogeneous data in federated learning.

Introduction

Emerging privacy and security issues in real-world applications motivate us to pursue adversarially robust federated models. However, straightforwardly combining adversarial training and federated learning in one framework can cause undesired robustness deterioration.

Figure 1. Robustness deterioration in federated adversarial training.

We dive into the issue of robustness deterioration and find that it may be attributed to the intensified heterogeneity induced by adversarial training on local clients. In federated learning, one of the primary difficulties is the biased optimization caused by local training on heterogeneous data. In adversarial training, the key distinction from standard training is the use of inner-maximization to generate adversarial data in pursuit of better adversarial robustness. When the two learning paradigms are combined, we conjecture that the following issue may arise, especially in the non-IID setting:

the inner-maximization for pursuing adversarial robustness would exacerbate the data heterogeneity among local clients in federated learning.

Figure 2. Illustration of the $\alpha$-slack mechanism.

Quick preview of our SFAT

Environment

Python (3.8)
PyTorch (1.7.0 or above)
torchvision
CUDA
NumPy

File Structure

./SFAT-main
├─ Centralized_AT.py        # Training and evaluation
├─ SFAT.py
├─ attack_generator.py      # Attack generation
├─ eval_pgd.py
├─ logger.py                # Log support
├─ models.py
├─ options.py               # Options and hyperparameters
├─ readme.md
├─ sampling.py              # Data split
├─ update.py
└─ utils.py                 # Aggregation and other utils

Running example

To train a robust federated model, we provide the examples below; the first trains FAT (--agg-center='FedAvg') and the second trains SFAT with --pri=1.2:

CUDA_VISIBLE_DEVICES='0' python SFAT.py --dataset=cifar-10 --local_ep=10 --local_bs=32 --iid=0 --epochs=100 --num_users=5 --agg-opt='FedAvg' --agg-center='FedAvg' --out-dir='../output_results_FAT_FedAvg'

CUDA_VISIBLE_DEVICES='1' python SFAT.py --dataset=cifar-10 --local_ep=10 --local_bs=32 --iid=0 --epochs=100 --num_users=5 --agg-opt='FedAvg' --agg-center='SFAT' --pri=1.2 --out-dir='../output_results_SFAT_FedAvg'

Figure 3. Comparison of FAT and SFAT using approximated client drift.

Compared with FAT, SFAT selectively upweights the clients with small adversarial training loss and downweights those with large adversarial training loss during aggregation, to alleviate the exacerbated heterogeneity. This follows our $\alpha$-slack mechanism, which relaxes the original objective into a lower bound. By adopting the $\alpha$-slack mechanism, SFAT achieves a smaller client drift than FAT, i.e., a less heterogeneous aggregation.
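To make the reweighting idea concrete, here is a minimal sketch that turns per-client adversarial training losses into aggregation coefficients. It assumes (hypothetically) that the k clients with the smallest loss are scaled by (1 + alpha), the remaining clients by (1 - alpha), on top of data-size weights, and that the result is renormalized. The names `slack_coefficients`, `alpha`, and `k` are illustrative only; this is not the repository's average_weights_alpha(), and it does not claim to reproduce how `--pri` is used there.

```python
from typing import List, Sequence


def slack_coefficients(losses: Sequence[float], sizes: Sequence[int],
                       alpha: float, k: int) -> List[float]:
    """Hypothetical alpha-slack aggregation coefficients (a sketch, not the repo code).

    losses: per-client adversarial training losses
    sizes:  per-client numbers of training samples
    alpha:  slack level in [0, 1); alpha = 0 gives size-weighted FedAvg
    k:      number of lowest-loss clients to upweight
    """
    if not 0 <= alpha < 1:
        raise ValueError("alpha must lie in [0, 1)")
    if len(losses) != len(sizes):
        raise ValueError("losses and sizes must have the same length")
    # Indices of the k clients with the smallest adversarial training loss.
    order = sorted(range(len(losses)), key=lambda i: losses[i])
    upweighted = set(order[:k])
    # Upweight low-loss clients by (1 + alpha), downweight the rest by (1 - alpha).
    raw = [sizes[i] * ((1 + alpha) if i in upweighted else (1 - alpha))
           for i in range(len(losses))]
    total = sum(raw)
    # Renormalize so the coefficients sum to 1.
    return [r / total for r in raw]


if __name__ == "__main__":
    losses = [1.2, 0.8, 1.5, 0.9, 1.1]
    sizes = [100] * 5
    print(slack_coefficients(losses, sizes, alpha=0.0, k=2))
    print(slack_coefficients(losses, sizes, alpha=0.2, k=2))
```

With alpha = 0 every client gets 0.2, i.e. plain size-weighted averaging. With alpha = 0.2 and k = 2, clients 1 and 3 (the two smallest losses) receive 0.25 each and the other three about 0.167 each, so the aggregate leans toward the clients whose adversarial loss is small.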

Implementation details

Following conventional federated learning implementations, we implement the overall SFAT framework in SFAT.py, which coordinates the local optimization in update.py and the aggregation functions in utils.py.

In SFAT.py, we obtain the local model on each client and aggregate the local models into the global model.

# local updates: run local (adversarial) training on each selected client
for idx in idxs_users:
    local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx],
                              logger=logger, alg=args.agg_opt, anchor=global_model,
                              anchor_mu=args.mu, local_rank=ipx, method=args.train_method)
    ''' '''

# aggregation method
if args.agg_center == 'FedAvg':
    # FAT: average the local weights
    global_weights = average_weights(local_weights)
elif args.agg_center == 'SFAT':
    ''' '''
    # SFAT: alpha-slack reweighted average controlled by args.pri
    global_weights = average_weights_alpha(local_weights, idt, idtxnum, args.pri)
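The aggregation step above can be pictured with the following self-contained sketch. It uses plain Python dicts of float lists in place of PyTorch `state_dict()` tensors, and the helpers `weighted_average` and `aggregate` are hypothetical. It assumes that the 'FedAvg' branch is a uniform average and that the 'SFAT' branch differs only in using precomputed per-client coefficients; it is not the code of average_weights() or average_weights_alpha().

```python
from typing import Dict, List, Optional, Sequence

StateDict = Dict[str, List[float]]


def weighted_average(states: Sequence[StateDict],
                     coeffs: Sequence[float]) -> StateDict:
    """Combine client parameter dicts key by key with the given coefficients."""
    if len(states) != len(coeffs):
        raise ValueError("need exactly one coefficient per client")
    averaged: StateDict = {}
    for key, values in states[0].items():
        averaged[key] = [
            sum(c * s[key][j] for s, c in zip(states, coeffs))
            for j in range(len(values))
        ]
    return averaged


def aggregate(states: Sequence[StateDict], agg_center: str,
              coeffs: Optional[Sequence[float]] = None) -> StateDict:
    """Hypothetical dispatch mirroring the 'FedAvg' / 'SFAT' branches above."""
    if agg_center == "FedAvg":
        # FAT: uniform average of the local models (assumed behavior).
        return weighted_average(states, [1.0 / len(states)] * len(states))
    if agg_center == "SFAT":
        # SFAT: reweighted average with externally computed slack coefficients.
        if coeffs is None:
            raise ValueError("SFAT aggregation needs per-client coefficients")
        return weighted_average(states, coeffs)
    raise ValueError(f"unknown aggregation method: {agg_center}")


if __name__ == "__main__":
    local = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]
    print(aggregate(local, "FedAvg"))              # {'w': [2.0, 3.0]}
    print(aggregate(local, "SFAT", [0.75, 0.25]))  # {'w': [1.5, 2.5]}
```

Keeping the coefficient computation separate from the averaging means the same averaging routine serves both branches; only the weights change between FAT and SFAT.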

In update.py, we implement local adversarial training on each client and define LocalUpdate().
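The inner-maximization performed during local adversarial training can be illustrated with L-infinity projected gradient descent (PGD, here used as gradient ascent on the loss) on a toy logistic-regression model, where the input gradient has a closed form. This is a minimal sketch under that simplifying assumption: LocalUpdate() works with neural networks in PyTorch, and `pgd_attack`, `eps`, `step_size`, and `steps` are illustrative names and values rather than the repository's settings.

```python
import math
import random
from typing import List, Sequence


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logit(w: Sequence[float], b: float, x: Sequence[float]) -> float:
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def bce_loss(w, b, x, y) -> float:
    """Binary cross-entropy of the logistic model on one example (y in {0, 1})."""
    p = min(max(sigmoid(logit(w, b, x)), 1e-12), 1.0 - 1e-12)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def pgd_attack(w, b, x, y, eps=0.1, step_size=0.025, steps=10, seed=0) -> List[float]:
    """Maximize the loss within the L-infinity eps-ball around x and the range [0, 1]."""
    rng = random.Random(seed)
    # Random start inside the eps-ball (Madry et al.-style PGD), clipped to [0, 1].
    x_adv = [min(max(xi + rng.uniform(-eps, eps), 0.0), 1.0) for xi in x]
    for _ in range(steps):
        p = sigmoid(logit(w, b, x_adv))
        # For a logistic model, d(BCE)/dx = (p - y) * w.
        grad = [(p - y) * wi for wi in w]
        # Ascent step along the sign of the gradient.
        x_adv = [xa + step_size * sign(g) for xa, g in zip(x_adv, grad)]
        # Project back onto the eps-ball around x, then onto the valid range [0, 1].
        x_adv = [min(max(xa, xi - eps), xi + eps) for xa, xi in zip(x_adv, x)]
        x_adv = [min(max(xa, 0.0), 1.0) for xa in x_adv]
    return x_adv


if __name__ == "__main__":
    w, b = [2.0, -1.0], 0.0
    x, y = [0.5, 0.5], 1
    x_adv = pgd_attack(w, b, x, y)
    print(x_adv, bce_loss(w, b, x, y), bce_loss(w, b, x_adv, y))
```

For this linear toy model the attack ends at the corner x - eps * sign(w) of the eps-ball (about [0.4, 0.6] here), which raises the loss. With a neural network, the input gradient would come from autograd instead of the closed form.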

In utils.py, we implement the aggregation methods, defining FAT as average_weights() and SFAT as average_weights_alpha(), as well as their unequal versions. For our SFAT, the critical part of the code is as follows, where lw and idx help choose the corresponding clients and p is our $\alpha$-slack parameter for reweighting.

Data split

The data split is implemented in sampling.py and used in utils.py to generate the local data loader for each client. The pre-defined split function can be used as follows to obtain the local data:

def get_dataset(args):
    ''' ''' 
    user_groups = cifar_noniid_skew(train_dataset, args.num_users)
    ''' '''
    return train_dataset, test_dataset, user_groups
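The internals of cifar_noniid_skew() are not shown here. For intuition, the sketch below builds a label-skewed non-IID split in the spirit of the shard partition from the original FedAvg paper (McMahan et al.): sort sample indices by label, cut them into shards, and deal a few shards to each client. The function `shard_split` and its `shards_per_user` parameter are hypothetical, and this is not claimed to match the repository's split.

```python
import random
from typing import Dict, Sequence, Set


def shard_split(labels: Sequence[int], num_users: int, shards_per_user: int = 2,
                seed: int = 0) -> Dict[int, Set[int]]:
    """Return {client_id: set of sample indices} with label-skewed shards."""
    num_shards = num_users * shards_per_user
    if num_shards > len(labels):
        raise ValueError("more shards than samples")
    # Sort indices by label so that each shard covers only a few classes.
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    shard_size = len(order) // num_shards
    shards = [order[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    # Leftover samples (when the sizes do not divide evenly) join the last shard.
    shards[-1].extend(order[num_shards * shard_size:])
    # Deal the shards to clients in a random order.
    rng = random.Random(seed)
    shard_ids = list(range(num_shards))
    rng.shuffle(shard_ids)
    groups: Dict[int, Set[int]] = {}
    for u in range(num_users):
        picked = shard_ids[u * shards_per_user:(u + 1) * shards_per_user]
        groups[u] = {i for s in picked for i in shards[s]}
    return groups


if __name__ == "__main__":
    labels = [i % 10 for i in range(1000)]  # 100 samples for each of 10 classes
    groups = shard_split(labels, num_users=5, shards_per_user=2)
    for u, idxs in groups.items():
        print(u, len(idxs), sorted({labels[i] for i in idxs}))
```

With 5 clients and 2 shards each on this balanced toy label list, every client ends up with 200 samples drawn from exactly two classes, which is the kind of heterogeneity the non-IID experiments target.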

Choosing different optimization and aggregation methods

To choose among federated optimization methods (e.g., FedAvg, FedProx, Scaffold) and aggregation methods (e.g., FAT and SFAT) for training a robust federated model, use the parameters defined in options.py, where --agg-center='FedAvg' corresponds to FAT:

parser.add_argument('--agg-opt',type=str,default='FedAvg',help='option of on-device learning: FedAvg, FedProx, Scaffold')
parser.add_argument('--agg-center',type=str,default='FedAvg',help='option of aggregation: FedAvg, SFAT')
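Because argparse converts dashes in option names to underscores, --agg-opt and --agg-center are read in SFAT.py as args.agg_opt and args.agg_center. The minimal parser below reproduces the two options shown above and adds --pri as a float; the type and default of --pri are assumptions for illustration, since the rest of options.py is not reproduced here.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Subset of SFAT options (sketch)")
    parser.add_argument('--agg-opt', type=str, default='FedAvg',
                        help='option of on-device learning: FedAvg, FedProx, Scaffold')
    parser.add_argument('--agg-center', type=str, default='FedAvg',
                        help='option of aggregation: FedAvg, SFAT')
    # Assumption: --pri is a float slack parameter; this default is illustrative only.
    parser.add_argument('--pri', type=float, default=1.0,
                        help='alpha-slack reweighting parameter (assumed type/default)')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(['--agg-center=SFAT', '--pri=1.2'])
    # Dashes become underscores in the attribute names.
    print(args.agg_opt, args.agg_center, args.pri)  # FedAvg SFAT 1.2
```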

Evaluation

To evaluate a trained model under various attacks, we provide eval_pgd.py, which contains different evaluation metrics for natural and robust performance. Run the following command with your model path to conduct the evaluation:

CUDA_VISIBLE_DEVICES='0' python eval_pgd.py --net [NETWORK STRUCTURE] --dataset [DATASET] --model_path [MODEL PATH]

Sample results:

CIFAR-10 (Non-IID), accuracy (%):

Optimization  Aggregation  Natural       FGSM          PGD-20        CW            AutoAttack
FedAvg        FAT          58.13 (0.68)  40.06 (0.62)  32.56 (0.01)  30.88 (0.37)  29.17 (0.03)
FedAvg        SFAT         63.36 (0.07)  44.82 (0.32)  37.14 (0.03)  33.39 (0.61)  31.66 (0.70)
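The absolute gains of SFAT over FAT can be read off the table; the snippet below simply recomputes them from the reported mean accuracies (in percentage points), without using the values in parentheses.

```python
fat = {"Natural": 58.13, "FGSM": 40.06, "PGD-20": 32.56, "CW": 30.88, "AutoAttack": 29.17}
sfat = {"Natural": 63.36, "FGSM": 44.82, "PGD-20": 37.14, "CW": 33.39, "AutoAttack": 31.66}

# Difference SFAT - FAT per metric, rounded to two decimals.
gains = {k: round(sfat[k] - fat[k], 2) for k in fat}
print(gains)
# {'Natural': 5.23, 'FGSM': 4.76, 'PGD-20': 4.58, 'CW': 2.51, 'AutoAttack': 2.49}
```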

During training, logger.py also tracks the accuracy and saves the model performance at each epoch.

To extend and design new methods in our framework

Both the local optimization and the aggregation method can be redesigned within our framework by modifying update.py and utils.py, respectively.

Reference Code


If you find our paper and repo useful, please cite our paper:

@inproceedings{zhu2023combating,
title       ={Combating Exacerbated Heterogeneity for Robust Models in Federated Learning},
author      ={Jianing Zhu and Jiangchao Yao and Tongliang Liu and Quanming Yao and Jianliang Xu and Bo Han},
booktitle   ={The Eleventh International Conference on Learning Representations},
year        ={2023},
url         ={https://openreview.net/forum?id=eKllxpLOOm}
}
