##### Copyright 2020 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>     <a target="_blank" href="https://www.tensorflow.org/federated/tutorials/building_your_own_federated_learning_algorithm"><img src="https://www.tensorflow.org/images/tf_logo_32px.png">View on TensorFlow.org</a>
</td>
  <td>     <a target="_blank" href="https://colab.research.google.com/github/tensorflow/federated/blob/v0.34.0/docs/tutorials/building_your_own_federated_learning_algorithm.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a>
</td>
  <td>     <a target="_blank" href="https://github.com/tensorflow/federated/blob/v0.34.0/docs/tutorials/building_your_own_federated_learning_algorithm.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">View source on GitHub</a>
</td>
  <td>     <a href="https://storage.googleapis.com/tensorflow_docs/federated/docs/tutorials/building_your_own_federated_learning_algorithm.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png">Download notebook</a>   </td>
</table>

## Before you start

Before you start, run the following cells to make sure that your environment is
set up correctly. If the installation or the imports fail, refer to the
[Installation](../install.md) guide for instructions.

In [None]:
#@test {"skip": true}
!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()

In [None]:
import tensorflow as tf
import tensorflow_federated as tff

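**Note**: This colab has been verified to work with the [latest released version](https://github.com/tensorflow/federated#compatibility) of the `tensorflow_federated` pip package, but the TensorFlow Federated project is still in pre-release development and may not work on `main`.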

# Building Your Own Federated Learning Algorithm

In the [image classification](federated_learning_for_image_classification.ipynb) and [text generation](federated_learning_for_text_generation.ipynb) tutorials, you learned how to set up model and data pipelines for Federated Learning (FL), and performed federated training via the `tff.learning` API layer of TFF.

This is only the tip of the iceberg when it comes to FL research. This tutorial discusses how to implement federated learning algorithms *without* deferring to the `tff.learning` API. In this tutorial, you will accomplish the following:

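**Goals:**

- Understand the general structure of federated learning algorithms.
- Explore the *Federated Core* of TFF.
- Use the Federated Core to implement Federated Averaging directly.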

While this tutorial is self-contained, it may be useful to first check out the [image classification](federated_learning_for_image_classification.ipynb) and [text generation](federated_learning_for_text_generation.ipynb) tutorials.


## Preparing the input data

First load and preprocess the EMNIST dataset included in TFF. For more details, see the [image classification](federated_learning_for_image_classification.ipynb) tutorial.

In [None]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In order to feed the dataset into our model, the data is flattened, and each example is converted into a tuple of the form `(flattened_image_vector, label)`.

In [None]:
NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)
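
To make the resulting shapes concrete, here is a minimal sketch in plain Python (no TensorFlow), assuming a toy list of all-zero 28 by 28 images. The helper names `flatten_image` and `batch_examples` are hypothetical; they only mirror what `preprocess` does with `tf.reshape` and `dataset.batch`.

```python
# Hypothetical plain-Python mirror of `preprocess`; not part of TFF.
BATCH_SIZE = 20

def flatten_image(pixels):
    """Flatten a 28x28 nested list into a 784-element list."""
    return [value for row in pixels for value in row]

def batch_examples(examples, batch_size=BATCH_SIZE):
    """Group (pixels, label) examples into (features, labels) batches."""
    batches = []
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        features = [flatten_image(pixels) for pixels, _ in chunk]
        labels = [[label] for _, label in chunk]  # one-element list per label
        batches.append((features, labels))
    return batches

# 45 fake examples: all-zero images with labels cycling through 0..9.
toy_examples = [([[0.0] * 28 for _ in range(28)], i % 10) for i in range(45)]
toy_batches = batch_examples(toy_examples)
print([len(features) for features, _ in toy_batches])  # batch sizes
print(len(toy_batches[0][0][0]))  # length of one flattened image
```

The last batch can be smaller than `BATCH_SIZE`, so the batch dimension is not fixed in advance.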

Now, select a small number of clients, and apply the preprocessing above to their datasets.

In [None]:
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(x))
    for x in client_ids
]

## Preparing the model

This uses the same model as in the [image classification](federated_learning_for_image_classification.ipynb) tutorial. This model (implemented via `tf.keras`) has a single dense layer, followed by a softmax layer.

In [None]:
def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

In order to use this model in TFF, wrap the Keras model as a [`tff.learning.Model`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model). This allows one to perform the model's [forward pass](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model#forward_pass) within TFF, and [extract model outputs](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model#report_local_unfinalized_metrics). For more details, also see the [image classification](federated_learning_for_image_classification.ipynb) tutorial.

In [None]:
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

While the above used `tf.keras` to create a `tff.learning.Model`, TFF supports much more general models. These models have the following relevant attributes capturing the model weights:

- `trainable_variables`: An iterable of the tensors corresponding to trainable layers.
- `non_trainable_variables`: An iterable of the tensors corresponding to non-trainable layers.

In this tutorial, only the `trainable_variables` will be used (as the model only has those!).

# Building your own Federated Learning algorithm

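While the `tff.learning` API allows one to create many variants of Federated Averaging, there are other federated algorithms that do not fit neatly into this framework. For example, you may want to add regularization, clipping, or more complicated algorithms such as [federated GAN training](https://github.com/tensorflow/federated/tree/main/tensorflow_federated/python/research/gans). You may also be interested in [federated analytics](https://ai.googleblog.com/2020/05/federated-analytics-collaborative-data.html).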

For these more advanced algorithms, you'll have to write your own custom algorithm using TFF. In many cases, federated algorithms have 4 main components:

1. A server-to-client broadcast step.
2. A local client update step.
3. A client-to-server upload step.
4. A server update step.
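
The four steps above can be sketched in plain Python before any TFF is involved. This is a toy stand-in, assuming a model with a single scalar weight and a squared-error loss; every function name here is hypothetical and none of them is a TFF API.

```python
# Toy stand-in for federated averaging; every name here is hypothetical.
def broadcast(server_weight, num_clients):
    """Step 1: send the same server weight to every client."""
    return [server_weight] * num_clients

def client_update(weight, data, learning_rate=0.1):
    """Step 2: one pass of gradient descent on 0.5 * (weight - x)**2."""
    for x in data:
        gradient = weight - x
        weight = weight - learning_rate * gradient
    return weight

def upload_mean(client_weights):
    """Step 3: clients upload weights; the server receives their mean."""
    return sum(client_weights) / len(client_weights)

def server_update(mean_client_weight):
    """Step 4: vanilla averaging replaces the server weight with the mean."""
    return mean_client_weight

def run_round(server_weight, client_datasets):
    """Run one communication round made of the four steps."""
    weights_at_clients = broadcast(server_weight, len(client_datasets))
    client_weights = [
        client_update(w, data)
        for w, data in zip(weights_at_clients, client_datasets)
    ]
    return server_update(upload_mean(client_weights))

client_datasets = [[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]
weight = 0.0
for _ in range(50):
    weight = run_round(weight, client_datasets)
print(weight)  # approaches the mean of the client optima
```

In this toy setting each client pulls the weight toward its own data, and the averaged weight settles near the mean of the client optima.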

In TFF, a federated algorithm is typically represented as a [`tff.templates.IterativeProcess`](https://www.tensorflow.org/federated/api_docs/python/tff/templates/IterativeProcess) (which will be referred to as just an `IterativeProcess` throughout). This is a class that contains `initialize` and `next` functions. Here, `initialize` is used to initialize the server, and `next` will perform one communication round of the federated algorithm. Let's write a skeleton of what our iterative process for FedAvg should look like.

First, there is an initialize function that simply creates a `tff.learning.Model`, and returns its trainable weights.

In [None]:
def initialize_fn():
  model = model_fn()
  return model.trainable_variables

This function looks good, but as you will see later, you will need to make a small modification to make it a "TFF computation".

Next, let's write a sketch of the `next_fn`.

In [None]:
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

Let's implement these four components separately, starting with the parts that can be written in pure TensorFlow: the client and server update steps.


## TensorFlow Blocks

### Client update

The `tff.learning.Model` can be used to do client training in essentially the same way you would train a TensorFlow model. In particular, one can use `tf.GradientTape` to compute the gradient on batches of data, then apply these gradients using a `client_optimizer`. This will only involve the trainable weights.


In [None]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

### Server update

The server update for FedAvg is simpler than the client update. This tutorial will implement "vanilla" federated averaging, in which the server model weights are replaced by the average of the client model weights. Again, this only uses the trainable weights.

In [None]:
@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

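The snippet could be simplified by simply returning the `mean_client_weights`. However, more advanced implementations of Federated Averaging use `mean_client_weights` with more sophisticated techniques, such as momentum or adaptivity.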

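**Challenge**: Implement a version of `server_update` that updates the server weights to be the midpoint of `model_weights` and `mean_client_weights`. (Note: This kind of "midpoint" approach is analogous to recent work on the [Lookahead optimizer](https://arxiv.org/abs/1907.08610)!)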
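
As a hint for the challenge, the midpoint arithmetic can be sketched without TensorFlow. This is one possible approach, not the tutorial's reference answer; `midpoint_server_update` is a hypothetical name, and the weights are assumed to be flat lists of floats.

```python
# One possible answer to the challenge, sketched without TensorFlow.
# `midpoint_server_update` is a hypothetical name, not a TFF function.
def midpoint_server_update(model_weights, mean_client_weights):
    """Move each server weight halfway toward the client average."""
    return [
        0.5 * (server_w + client_w)
        for server_w, client_w in zip(model_weights, mean_client_weights)
    ]

model_weights = [0.0, 2.0, -4.0]
mean_client_weights = [1.0, 2.0, 0.0]
print(midpoint_server_update(model_weights, mean_client_weights))
# prints [0.5, 2.0, -2.0]
```

In the TensorFlow version, the same arithmetic would go inside the lambda passed to `tf.nest.map_structure`. Note that the update would then presumably need the current server weights as an explicit input, since a freshly created model does not carry them.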

So far, this has only involved TensorFlow code. This is by design, as TFF allows you to use much of the TensorFlow code you're already familiar with. Next you will have to specify the **orchestration logic**, that is, the logic that dictates what the server broadcasts to the client, and what the client uploads to the server.

This will require the *Federated Core* of TFF.

# Introduction to the Federated Core

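The Federated Core (FC) is a set of lower-level interfaces that serve as the foundation for the `tff.learning` API. However, these interfaces are not limited to learning. In fact, they can be used for analytics and many other computations over distributed data.

At a high level, the Federated Core is a development environment that enables compactly expressed program logic to combine TensorFlow code with distributed communication operators (such as distributed sums and broadcasts). The goal is to give researchers and practitioners explicit control over the distributed communication in their systems, without requiring system implementation details (such as specifying point-to-point network message exchanges).

One key point is that TFF is designed for privacy preservation. Therefore, it allows explicit control over where data resides, to prevent unwanted accumulation of data at the centralized server location.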

## Federated data

A key concept in TFF is "federated data", which refers to a collection of data items hosted across a group of devices in a distributed system (e.g. client datasets, or the server model weights). The entire collection of values across all devices is represented as a single *federated value*.

For example, suppose there are client devices that each have a float representing the temperature of a sensor. These floats can be represented as a *federated float* as follows:

In [None]:
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

Federated types are specified by a type `T` of their member constituents (e.g. `tf.float32`) and a group `G` of devices. Typically, `G` is either `tff.CLIENTS` or `tff.SERVER`. Such a federated type is represented as `{T}@G`, as shown below.

In [None]:
str(federated_float_on_clients)

'{float32}@CLIENTS'
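
The notation can be mimicked with a small stand-in class. This is a hypothetical sketch, not the TFF implementation: it assumes that braces mark a value that may differ from device to device, and that a value known to be identical on every member of the group (such as a server-placed value) is printed without braces.

```python
from dataclasses import dataclass

# Hypothetical stand-in for a federated type; not the TFF class itself.
@dataclass(frozen=True)
class FederatedTypeSketch:
    member: str              # type T of each member constituent
    placement: str           # group G of devices
    all_equal: bool = False  # True when every device holds the same value

    def __str__(self):
        if self.all_equal:
            return f"{self.member}@{self.placement}"
        return f"{{{self.member}}}@{self.placement}"

print(FederatedTypeSketch("float32", "CLIENTS"))
print(FederatedTypeSketch("float32", "SERVER", all_equal=True))
```

The first line printed matches the `{float32}@CLIENTS` string above; the second shows the brace-free form.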

Why does TFF care so much about placements? A key goal of TFF is to enable writing code that could be deployed on a real distributed system. This means that it is vital to reason about which subsets of devices execute which code, and where different pieces of data reside.

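TFF focuses on three things: *data*, where the data is *placed*, and how the data is being *transformed*. The first two are encapsulated in federated types, while the last is encapsulated in *federated computations*.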

## Federated computations

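TFF is a strongly typed functional programming environment whose basic units are *federated computations*. These are pieces of logic that accept federated values as input, and return federated values as output.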

For example, suppose you wanted to average the temperatures on our client sensors. You could define the following (using our federated float):

In [None]:
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

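You might ask, how is this different from the `tf.function` decorator in TensorFlow? The key answer is that the code generated by `tff.federated_computation` is neither TensorFlow nor Python code; it is a specification of a distributed system in an internal platform-independent *glue language*.

While this may sound complicated, you can think of TFF computations as functions with well-defined type signatures. These type signatures can be directly queried.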

In [None]:
str(get_average_temperature.type_signature)

'({float32}@CLIENTS -> float32@SERVER)'

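This `tff.federated_computation` accepts arguments of federated type `{float32}@CLIENTS`, and returns values of federated type `float32@SERVER`. Federated computations may also go from server to client, from client to client, or from server to server. Federated computations can also be composed like normal functions, as long as their type signatures match up.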

To support development, TFF allows you to invoke a `tff.federated_computation` as a Python function. For example, you can call

In [None]:
get_average_temperature([68.5, 70.3, 69.8])

69.53334
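
As a sanity check, the same number can be reproduced in plain Python. This sketch only recomputes the arithmetic mean of the three inputs; it says nothing about how TFF executes the computation.

```python
# Plain-Python check of what the federated mean computes for these inputs.
client_temperatures = [68.5, 70.3, 69.8]
average = sum(client_temperatures) / len(client_temperatures)
print(average)
```

The result agrees with the notebook output to about four decimal places; the small difference in the trailing digits is presumably because the notebook value is a `float32`.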

## Non-eager computations and TensorFlow

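There are two key restrictions to be aware of. First, when the Python interpreter encounters a `tff.federated_computation` decorator, the function is traced once and serialized for future use. Due to the decentralized nature of Federated Learning, this future usage may occur elsewhere, such as in a remote execution environment. Therefore, TFF computations are fundamentally *non-eager*. This behavior is somewhat analogous to that of the [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) decorator in TensorFlow.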
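
The trace-once behavior can be illustrated with a toy decorator. This is a hypothetical sketch of the general idea only; `traced_computation` is not a TFF API and TFF's actual tracing and serialization are far more involved.

```python
# Hypothetical sketch of trace-once behavior; this is not how TFF is built.
trace_log = []

def traced_computation(fn):
    """Run `fn` once at decoration time and keep only the recorded result."""
    trace_log.append(f"tracing {fn.__name__}")
    recorded_ops = fn()  # the Python body executes exactly once, here

    def invoke():
        # Later calls replay the recorded description; `fn` is not re-run.
        return recorded_ops

    return invoke

@traced_computation
def build_plan():
    trace_log.append("python body executed")
    return ["broadcast", "client_update", "mean", "server_update"]

build_plan()
build_plan()
print(trace_log)  # the body ran once, even though build_plan was called twice
```

Side effects in the Python body, such as the log entry here, happen only during tracing, which is the practical meaning of "non-eager".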

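Second, a federated computation can only consist of federated operators (such as `tff.federated_mean`); it cannot contain TensorFlow operations. TensorFlow code must be confined to blocks decorated with `tff.tf_computation`. Most ordinary TensorFlow code can be directly decorated, such as the following function that takes a number and adds `0.5` to it.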

In [None]:
@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

These also have type signatures, but *without placements*. For example, you can call

In [None]:
str(add_half.type_signature)

'(float32 -> float32)'

This showcases an important difference between `tff.federated_computation` and `tff.tf_computation`. The former has explicit placements, while the latter does not.

You can use `tff.tf_computation` blocks in federated computations by specifying placements. Let's create a function that adds half, but only to federated floats at the clients. You can do this by using `tff.federated_map`, which applies a given `tff.tf_computation`, while preserving the placement.

In [None]:
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

This function is almost identical to `add_half`, except that it only accepts values with placement at `tff.CLIENTS`, and returns values with the same placement. This can be seen in its type signature:

In [None]:
str(add_half_on_clients.type_signature)

'({float32}@CLIENTS -> {float32}@CLIENTS)'
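
Conceptually, mapping over a clients-placed value applies the function to each client's item independently. The following is a minimal plain-Python sketch of that idea, assuming a federated value can be modeled as a list with one entry per client; `federated_map_sketch` is a hypothetical name.

```python
# Hypothetical model: a clients-placed value is a list, one item per client.
def add_half(x):
    """Plain-Python analogue of the `add_half` computation."""
    return x + 0.5

def federated_map_sketch(fn, client_values):
    """Apply `fn` independently to each client's value."""
    return [fn(value) for value in client_values]

print(federated_map_sketch(add_half, [1.0, 2.5, -0.5]))
# prints [1.5, 3.0, 0.0]
```

The output has one entry per client, mirroring how the placement is preserved in the type signature.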

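To summarize:

- TFF operates on federated values.
- Each federated value has a *federated type*, with a *type* (e.g. `tf.float32`) and a *placement* (e.g. `tff.CLIENTS`).
- Federated values can be transformed using *federated computations*, which must be decorated with `tff.federated_computation` and a federated type signature.
- TensorFlow code must be contained in blocks with `tff.tf_computation` decorators.
- These blocks can then be incorporated into federated computations.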


# Building your own Federated Learning algorithm, revisited

Now that you've gotten a glimpse of the Federated Core, you can build your own federated learning algorithm. Remember that above, you defined an `initialize_fn` and `next_fn` for the algorithm. The `next_fn` will make use of the `client_update` and `server_update` you defined using pure TensorFlow code.

However, in order to make the algorithm a federated computation, both the `next_fn` and the `initialize_fn` need to be a `tff.federated_computation`.

## TensorFlow Federated blocks

### Creating the initialization computation

The initialize function will be quite simple: you will create a model using `model_fn`. However, remember that you must separate out your TensorFlow code using `tff.tf_computation`.

In [None]:
@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

You can then pass this directly into a federated computation using `tff.federated_value`.

In [None]:
@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

### Creating the `next_fn`

The client and server update code can now be used to write the actual algorithm. First, you will turn the `client_update` into a `tff.tf_computation` that accepts a client dataset and server weights, and outputs the updated client weights.

You will need the corresponding types to properly decorate the function. Luckily, the type of the server weights can be extracted directly from the model.

In [None]:
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

Let's look at the dataset type signature. Remember that you took 28 by 28 images (with integer labels) and flattened them.

In [None]:
str(tf_dataset_type)

'<float32[?,784],int32[?,1]>*'

You can also extract the model weights type by using the `server_init` function above.

In [None]:
model_weights_type = server_init.type_signature.result

Examining the type signature, you'll be able to see the architecture of the model!

In [None]:
str(model_weights_type)

'<float32[784,10],float32[10]>'

You can now create the `tff.tf_computation` for the client update.

In [None]:
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

The `tff.tf_computation` version of the server update can be defined in a similar way, using types you've already extracted.

In [None]:
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

Last, but not least, you need to create the `tff.federated_computation` that brings this all together. This function will accept two *federated values*, one corresponding to the server weights (with placement `tff.SERVER`), and the other corresponding to the client datasets (with placement `tff.CLIENTS`).

Note that both these types were defined above! You simply need to give them the proper placement using `tff.FederatedType`.

In [None]:
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

Remember the 4 elements of an FL algorithm?

1. A server-to-client broadcast step.
2. A local client update step.
3. A client-to-server upload step.
4. A server update step.

Now that you've built up the above, each part can be compactly represented as a single line of TFF code. This simplicity is why you had to take extra care to specify things such as federated types!

In [None]:
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  
  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

You now have a `tff.federated_computation` for both the algorithm initialization and for running one step of the algorithm. To finish the algorithm, you pass these into `tff.templates.IterativeProcess`.

In [None]:
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

Let's look at the *type signature* of the `initialize` and `next` functions of the iterative process.

In [None]:
str(federated_algorithm.initialize.type_signature)

'( -> <float32[784,10],float32[10]>@SERVER)'

This reflects the fact that `federated_algorithm.initialize` is a no-arg function that returns a single-layer model (with a 784-by-10 weight matrix, and 10 bias units).
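
The shapes in that signature also give the size of the model. A quick arithmetic check in plain Python, with the shapes copied by hand from the signature string:

```python
import math

# Shapes read off the type signature '<float32[784,10],float32[10]>'.
weight_shapes = [(784, 10), (10,)]
parameter_counts = [math.prod(shape) for shape in weight_shapes]
print(parameter_counts, sum(parameter_counts))
# prints [7840, 10] 7850
```

So each round broadcasts and uploads 7850 floats per participating client.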

In [None]:
str(federated_algorithm.next.type_signature)

'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

Here, one can see that `federated_algorithm.next` accepts a server model and client data, and returns an updated server model.

## Evaluating the algorithm

Let's run a few rounds, and see how the loss changes. First, you will define an evaluation function using the *centralized* approach discussed in the second tutorial.

You will first create a centralized evaluation dataset, and then apply the same preprocessing you used for the training data.

In [None]:
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

Next, you will write a function that accepts a server state, and uses Keras to evaluate on the test dataset. If you're familiar with `tf.keras`, this will all look familiar, though note the use of `set_weights`!

In [None]:
def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

Now, let's initialize the algorithm and evaluate on the test set.

In [None]:
server_state = federated_algorithm.initialize()
evaluate(server_state)



Let's train for a few rounds and see if anything changes.

In [None]:
for _ in range(15):  # avoid shadowing the built-in `round`
  server_state = federated_algorithm.next(server_state, federated_train_data)

In [None]:
evaluate(server_state)



There is a slight decrease in the loss function. While the jump is small, you've only performed 15 training rounds, and on a small subset of clients. To see better results, you may have to do hundreds if not thousands of rounds.

## Modifying the algorithm

At this point, let's stop and think about what you've accomplished. You've implemented Federated Averaging directly by combining pure TensorFlow code (for the client and server updates) with federated computations from the Federated Core of TFF.

To perform more sophisticated learning, you can simply alter what you have above. In particular, by editing the pure TF code above, you can change how the client performs training, or how the server updates its model.

**Challenge:** Add [gradient clipping](https://towardsdatascience.com/what-is-gradient-clipping-b8e815cdfb48) to the `client_update` function.
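
As a starting point, here is a minimal plain-Python sketch of clipping by global norm, one common form of gradient clipping. The helper below is hypothetical and works on nested lists of floats; it only illustrates the rescaling rule.

```python
import math

# Hypothetical helper illustrating clip-by-global-norm; not a TFF function.
def clip_by_global_norm(grads, clip_norm):
    """Rescale a list of gradient lists so their joint L2 norm <= clip_norm."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # The scale is 1.0 when the gradients are already within the bound.
    scale = clip_norm / max(global_norm, clip_norm)
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, global_norm

grads = [[3.0, 0.0], [0.0, 4.0]]
clipped, norm = clip_by_global_norm(grads, clip_norm=1.0)
print(norm)
print(clipped)
```

In `client_update`, the analogous step would be a call to `tf.clip_by_global_norm(grads, clip_norm)` between `tape.gradient` and `apply_gradients`, with `clip_norm` chosen by you.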


If you wanted to make larger changes, you could also have the server store and broadcast more data. For example, the server could also store the client learning rate, and make it decay over time! Note that this will require changes to the type signatures used in the `tff.tf_computation` calls above.

**Harder Challenge:** Implement Federated Averaging with learning rate decay on the clients.
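
The structural change can be sketched in plain Python before touching any TFF types. This is a toy illustration under stated assumptions: a scalar model, a squared-error loss, and a hypothetical `DECAY` factor. The point is only that the server state becomes a pair of model weights and learning rate.

```python
# Hypothetical sketch: the server state carries the client learning rate.
DECAY = 0.5  # assumed multiplicative decay applied after every round

def client_update(weight, data, learning_rate):
    """Gradient descent on 0.5 * (weight - x)**2 with the broadcast rate."""
    for x in data:
        weight -= learning_rate * (weight - x)
    return weight

def next_round(state, client_datasets):
    """One round: broadcast (weight, rate), train, average, decay the rate."""
    weight, learning_rate = state
    client_weights = [
        client_update(weight, data, learning_rate) for data in client_datasets
    ]
    mean_weight = sum(client_weights) / len(client_weights)
    return mean_weight, learning_rate * DECAY

state = (0.0, 0.4)
rates = []
for _ in range(3):
    rates.append(state[1])
    state = next_round(state, [[1.0], [3.0]])
print(rates)
# prints [0.4, 0.2, 0.1]
```

In TFF, the corresponding change would be to make the server state a structure containing both the weights and the learning rate, and to update the type signatures accordingly.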

At this point, you may begin to realize how much flexibility there is in what you can implement in this framework. For ideas (including the answer to the harder challenge above) you can see the source code for [`tff.learning.algorithms.build_weighted_fed_avg`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/algorithms/build_weighted_fed_avg), or check out various [research projects](https://github.com/google-research/federated) using TFF.