
Minimal Stand-Alone Implementation of Federated Averaging

This is intended to be a flexible and minimal implementation of Federated Averaging; the code is designed to be modular and reusable. This implementation of the Federated Averaging algorithm uses only key TFF functions and does not depend on advanced features in tff.learning. See fed_avg.py for a more full-featured implementation.

Instructions

A minimal implementation of the Federated Averaging algorithm is provided here, along with an example federated EMNIST experiment. The implementation demonstrates the three main types of logic of a typical federated learning simulation.

  • An outer Python script that drives the simulation by selecting simulated clients from a dataset and then executing federated learning algorithms.

  • Individual pieces of TensorFlow code that run in a single location (e.g., on clients or on a server). The code pieces are typically tf.functions that can be used and tested outside TFF.

  • The orchestration logic that binds the local computations together by wrapping them as tff.tensorflow.computation instances and using key TFF functions such as tff.federated_broadcast and tff.federated_map inside a tff.federated_computation. A minimal sketch of this pattern follows the list.
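
For illustration, here is a minimal sketch of that orchestration pattern. It is a hypothetical scalar example, not code from this repository, and it assumes a TFF release where TF blocks are wrapped with tff.tensorflow.computation (older releases expose the same decorator as tff.tf_computation):

import tensorflow as tf
import tensorflow_federated as tff

# A plain TF block that can be used and tested outside TFF.
@tff.tensorflow.computation(tf.float32, tf.float32)
def client_update(server_value, client_value):
  return server_value + client_value

# Orchestration: broadcast a server-placed value to the clients, then map
# the TF block over each client's local value.
@tff.federated_computation(
    tff.FederatedType(tf.float32, tff.SERVER),
    tff.FederatedType(tf.float32, tff.CLIENTS))
def run_one_round(server_value, client_values):
  broadcast_value = tff.federated_broadcast(server_value)
  return tff.federated_map(
      client_update, tff.federated_zip((broadcast_value, client_values)))

In a simulation, run_one_round(10.0, [1.0, 2.0]) would return [11.0, 12.0], one result per client.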

This EMNIST example can easily be adapted for experimental changes:

We expect that researchers working on applications and models will only need to change the driver file emnist_fedavg_main; researchers working on optimization can implement most of their ideas by writing pure TF code in simple_fedavg_tf; and researchers who need more control over the orchestration strategy should get familiar with the TFF code in simple_fedavg_tff. We encourage readers to work through the following exercises when using this code for their research:

  1. Try a more complicated server optimizer such as Adam. You only need to change emnist_fedavg_main (first sketch after this list).

  2. Implement a model that uses L2 regularization. You will need to change the model definition in emnist_fedavg_main and add Keras regularization losses in the KerasModelWrapper class in simple_fedavg_tf (second sketch below).

  3. Implement a decaying learning rate schedule on the clients based on the global round, using the round_num broadcast to the clients in simple_fedavg_tf (third sketch below).

  4. Implement a more complicated aggregation procedure that drops the client updates with the largest and smallest L2 norms. You will need to change simple_fedavg_tff (fourth sketch below).
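
For exercise 1, a minimal sketch of the change in emnist_fedavg_main. It assumes the builder in simple_fedavg_tff exposes server_optimizer_fn and client_optimizer_fn keywords and that tff_model_fn is the model constructor already defined in the driver:

import tensorflow as tf

import simple_fedavg_tff  # the orchestration module from this example

# Hypothetical edit: swap the server-side SGD optimizer for Adam. Only the
# server_optimizer_fn passed from the driver changes.
iterative_process = simple_fedavg_tff.build_federated_averaging_process(
    model_fn=tff_model_fn,
    server_optimizer_fn=lambda: tf.keras.optimizers.Adam(learning_rate=0.01),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))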
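For exercise 2, a sketch of the two pieces: a hypothetical Keras model with an L2 penalty (the layer shapes here are illustrative, not the repository's EMNIST model), and the extra loss term the KerasModelWrapper in simple_fedavg_tf would need to include:

import tensorflow as tf

def create_regularized_model():
  # Hypothetical model: a single dense layer with an L2 kernel penalty.
  return tf.keras.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
      tf.keras.layers.Dense(
          10, kernel_regularizer=tf.keras.regularizers.l2(1e-4)),
  ])

# Inside the wrapper's forward pass, the regularization penalties that
# Keras collects in keras_model.losses must be added to the training loss:
#   loss = loss_fn(y_true=labels, y_pred=predictions)
#   loss += tf.add_n(keras_model.losses)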
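For exercise 3, a sketch of one possible schedule, assuming exponential decay driven by the round_num that is broadcast to the clients; the base rate of 0.1 and decay factor of 0.99 are illustrative:

import tensorflow as tf

@tf.function
def client_learning_rate(round_num):
  # Hypothetical schedule: decay a base rate of 0.1 by a factor of 0.99
  # per global round.
  return 0.1 * tf.pow(0.99, tf.cast(round_num, tf.float32))

The client update in simple_fedavg_tf could then construct its local optimizer with this rate instead of a fixed one.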
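For exercise 4, a sketch of the trimming logic alone, in pure TF over a stacked tensor of flattened client updates; wiring it into the orchestration in simple_fedavg_tff, in place of the plain federated mean, is the actual exercise:

import tensorflow as tf

def trimmed_mean(client_updates):
  # client_updates: hypothetical [num_clients, num_params] tensor of
  # flattened model deltas. Drop the single largest- and smallest-norm
  # updates, then average the rest.
  norms = tf.norm(client_updates, axis=1)
  indices = tf.range(tf.shape(norms)[0])
  keep = tf.logical_and(
      tf.not_equal(indices, tf.argmax(norms, output_type=tf.int32)),
      tf.not_equal(indices, tf.argmin(norms, output_type=tf.int32)))
  return tf.reduce_mean(tf.boolean_mask(client_updates, keep), axis=0)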

Citation

@inproceedings{mcmahan2017communication,
  title={Communication-Efficient Learning of Deep Networks from Decentralized Data},
  author={McMahan, Brendan and Moore, Eider and Ramage, Daniel and Hampson, Seth and y Arcas, Blaise Aguera},
  booktitle={Artificial Intelligence and Statistics},
  pages={1273--1282},
  year={2017}
}