##### Copyright 2021 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td><a target="_blank" href="https://www.tensorflow.org/federated/tutorials/custom_federated_algorithm_with_tff_optimizers"><img src="https://www.tensorflow.org/images/tf_logo_32px.png">TensorFlow.org에서 보기</a></td>
  <td><a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ko/federated/tutorials/custom_federated_algorithm_with_tff_optimizers.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png">Google Colab에서 실행</a></td>
  <td><a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/ko/federated/tutorials/custom_federated_algorithm_with_tff_optimizers.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">GitHub에서 소스 보기</a></td>
  <td><a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ko/federated/tutorials/custom_federated_algorithm_with_tff_optimizers.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png">노트북 다운로드</a></td>
</table>

# 맞춤형 반복 프로세스에서 TFF 옵티마이저 사용

이 튜토리얼은 [페더레이션 평균화](https://arxiv.org/abs/1602.05629) 알고리즘용 맞춤형 반복 프로세스를 빌드하는 [고유한 페더레이션 학습 알고리즘 빌드](building_your_own_federated_learning_algorithm.ipynb) 튜토리얼 및 [simple_fedavg](https://github.com/tensorflow/federated/tree/main/tensorflow_federated/examples/simple_fedavg) 예제의 대안입니다. 이 튜토리얼에서는 Keras 옵티마이저 대신 [TFF 옵티마이저](https://github.com/tensorflow/federated/tree/main/tensorflow_federated/python/learning/optimizers)를 사용합니다. TFF 옵티마이저 추상화는 상태를 입력받아 새 상태를 반환하는(state-in-state-out) 방식으로 설계되어 TFF 반복 프로세스에 쉽게 통합할 수 있습니다. `tff.learning` API 역시 TFF 옵티마이저를 입력 인수로 받습니다.

## 시작하기 전에

시작하기 전에 다음 셀을 실행하여 환경이 올바르게 설정되었는지 확인하세요. 설치나 임포트 중에 오류가 발생하면 [설치](../install.md) 가이드의 지침을 참조하세요.

In [None]:
#@test {"skip": true}
# Install/upgrade TFF and nest-asyncio quietly (IPython `!` shell commands,
# so this cell only runs inside a notebook).
!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
# Patch asyncio so code that drives its own event loop can run inside an
# already-running loop such as the notebook's.
# NOTE(review): presumably needed by TFF's asyncio-based runtime; confirm.
nest_asyncio.apply()

In [None]:
import functools
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

## 데이터 및 모델 준비하기

EMNIST 데이터 처리 및 모델은 [simple_fedavg](https://github.com/tensorflow/federated/tree/main/tensorflow_federated/examples/simple_fedavg) 예제와 매우 유사합니다.

In [None]:
# Restrict EMNIST to the 10 digit classes (read by `create_keras_model`).
only_digits = True

# Fetch the federated EMNIST train and test splits.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=only_digits)

# Preprocessing: raw EMNIST examples -> (images, labels) batches.
def preprocess_fn(dataset, batch_size=16):
  """Batches a client dataset and reshapes it for the CNN.

  Args:
    dataset: A `tf.data.Dataset` whose elements are dicts with 'pixels' and
      'label' entries.
    batch_size: Number of examples per batch.

  Returns:
    A dataset of `(images, labels)` tuples, where the batched pixels get a
    trailing channel axis of size 1.
  """

  def _to_model_inputs(example):
    images = tf.expand_dims(example['pixels'], -1)
    return (images, example['label'])

  batched = dataset.batch(batch_size)
  return batched.map(_to_model_inputs)

# Preprocess and sample clients for prototyping.
train_client_ids = sorted(emnist_train.client_ids)
train_data = emnist_train.preprocess(preprocess_fn)
# Evaluation data is taken from the held-out *test* split. The first sorted
# training client is also used for training later, so drawing evaluation data
# from `emnist_train` would make the reported "test" accuracy partly measure
# training examples.
central_test_data = preprocess_fn(
    emnist_test.create_tf_dataset_for_client(
        sorted(emnist_test.client_ids)[0]))

# Model definition.
def create_keras_model():
  """Builds the CNN used in https://arxiv.org/abs/1602.05629.

  Two 5x5 conv (ReLU, 'same' padding) + 2x2 max-pool stages, a 512-unit ReLU
  dense layer, and a final dense layer emitting logits for 10 classes when
  the module-level `only_digits` is true, otherwise 62.
  """
  fmt = 'channels_last'

  def conv(filters, **extra):
    return tf.keras.layers.Conv2D(
        filters=filters,
        kernel_size=5,
        padding='same',
        data_format=fmt,
        activation=tf.nn.relu,
        **extra)

  def pool():
    return tf.keras.layers.MaxPooling2D(
        pool_size=(2, 2), padding='same', data_format=fmt)

  num_classes = 10 if only_digits else 62
  layers = [
      conv(32, input_shape=[28, 28, 1]),
      pool(),
      conv(64),
      pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(num_classes),
  ]
  return tf.keras.models.Sequential(layers)

# TFF model factory wrapping the Keras CNN.
def model_fn():
  """Returns a new `tff.learning` model backed by a fresh Keras CNN.

  The input spec is taken from `central_test_data`, and the loss treats the
  network outputs as logits.
  """
  return tff.learning.from_keras_model(
      create_keras_model(),
      input_spec=central_test_data.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

## 맞춤형 반복 프로세스


많은 경우 페더레이션 알고리즘에는 4가지 주요 구성 요소가 있습니다.

1. 서버-클라이언트 브로드캐스트 단계.
2. 로컬 클라이언트 업데이트 단계.
3. 클라이언트-서버 업로드 단계.
4. 서버 업데이트 단계.

TFF에서는 일반적으로 페더레이션 알고리즘을 [`tff.templates.IterativeProcess`](https://www.tensorflow.org/federated/api_docs/python/tff/templates/IterativeProcess)(나머지 부분에서 `IterativeProcess`라고 함)로 나타냅니다. 이것은 `initialize` 및 `next` 함수를 포함하는 클래스입니다. 여기서 `initialize`는 서버를 초기화하는 데 사용되며 `next`는 페더레이션 알고리즘의 한 통신 라운드를 수행합니다.

클라이언트 업데이트 단계에서 옵티마이저를 사용하고 서버 업데이트 단계에서 또 다른 옵티마이저를 사용하는 페더레이션 평균화(FedAvg) 알고리즘을 빌드하기 위해 다양한 구성 요소를 도입합니다. 클라이언트 및 서버 업데이트의 핵심 로직은 순수 TF 블록으로 표현될 수 있습니다.

### TF 블록: 클라이언트 및 서버 업데이트

각 클라이언트에서는 라운드마다 로컬 `client_optimizer`의 상태가 새로 초기화되며, 이 옵티마이저로 클라이언트 모델 가중치를 업데이트합니다. 서버에서는 `server_optimizer`가 *이전* 라운드의 상태를 입력으로 받아 다음 라운드에 사용할 새 상태를 반환합니다.

In [None]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Trains `model` on one client's data and returns the weight change.

  The model's trainable variables are first overwritten with
  `server_weights`. A fresh `client_optimizer` state is then initialized and
  one optimizer step is applied per batch of `dataset`.

  Args:
    model: A `tff.learning` model exposing `trainable_variables` and
      `forward_pass`.
    dataset: The client's batched dataset.
    server_weights: Weights broadcast from the server, matching the structure
      of `model.trainable_variables`.
    client_optimizer: A TFF optimizer (`initialize` / `next` API).

  Returns:
    The structure `trainable_variables - server_weights` after local training.
  """
  local_vars = model.trainable_variables
  # Start local training from the broadcast server weights.
  tf.nest.map_structure(lambda var, value: var.assign(value),
                        local_vars, server_weights)
  # Optimizer state is created anew on every call; nothing carries over
  # between calls on the client side.
  var_specs = tf.nest.map_structure(
      lambda var: tf.TensorSpec(var.shape, var.dtype), local_vars)
  opt_state = client_optimizer.initialize(var_specs)
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      batch_output = model.forward_pass(batch)
    gradients = tape.gradient(batch_output.loss, local_vars)
    # The TFF optimizer is functional: it returns new state and new values,
    # which are then written back into the variables.
    opt_state, new_values = client_optimizer.next(
        opt_state, local_vars, gradients)
    tf.nest.map_structure(lambda var, value: var.assign(value),
                          local_vars, new_values)
  return tf.nest.map_structure(tf.subtract, local_vars, server_weights)

In [None]:
@attr.s(eq=False, frozen=True, slots=True)
class ServerState:
  """Immutable server-side state passed from one round to the next."""
  # Current global model weights.
  trainable_weights = attr.ib()
  # State of the server optimizer (from its `initialize` / `next` calls).
  optimizer_state = attr.ib()

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
  """Applies one server optimizer step and returns the new `ServerState`.

  Args:
    server_state: The current `ServerState`.
    mean_model_delta: Averaged client delta (client weights minus server
      weights), matching the structure of `server_state.trainable_weights`.
    server_optimizer: A TFF optimizer (`next` API).

  Returns:
    A copy of `server_state` with updated weights and optimizer state.
  """
  # The delta points from the server weights toward the client weights, so
  # its negation serves as a pseudo-gradient for a descent-style optimizer.
  pseudo_gradient = tf.nest.map_structure(lambda d: -1.0 * d, mean_model_delta)
  next_opt_state, next_weights = server_optimizer.next(
      server_state.optimizer_state,
      server_state.trainable_weights,
      pseudo_gradient)
  return tff.structure.update_struct(
      server_state,
      trainable_weights=next_weights,
      optimizer_state=next_opt_state)

### TFF 블록: `tff.tf_computation` 및 `tff.federated_computation`

이제 오케스트레이션에 TFF를 사용하고 FedAvg에 대한 반복 프로세스를 빌드합니다. 위에서 정의한 TF 블록을 `tff.tf_computation`으로 래핑하고 `tff.federated_computation` 함수에서 TFF 메서드 `tff.federated_broadcast`, `tff.federated_map`, `tff.federated_mean`을 사용해야 합니다. 맞춤형 반복 프로세스를 정의할 때 `initialize` 및 `next` 함수와 함께 `tff.learning.optimizers.Optimizer` API를 쉽게 사용할 수 있습니다.

In [None]:
# 1. Optimizers: the server uses SGD with momentum 0.9; the client optimizer
#    is built with learning_rate=0.01 and no momentum argument.
server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)

# 2. Initial server state, then (below) its placement on the server.
@tff.tf_computation
def server_init():
  """Returns a `ServerState` with fresh model weights and optimizer state."""
  fresh_model = model_fn()
  weights = fresh_model.trainable_variables
  weight_specs = tf.nest.map_structure(
      lambda var: tf.TensorSpec(var.shape, var.dtype), weights)
  return ServerState(
      trainable_weights=weights,
      optimizer_state=server_optimizer.initialize(weight_specs))

@tff.federated_computation
def server_init_tff():
  """Produces the initial `ServerState` and places it at `tff.SERVER`."""
  initial_state = server_init()
  return tff.federated_value(initial_state, tff.SERVER)

# 3. One round of computation and communication.
# Derive the TFF types from `server_init` so that the computations below match
# whatever the model and optimizer actually produce.
server_state_type = server_init.type_signature.result
trainable_weights_type = server_state_type.trainable_weights
print('server_state_type:\n',
      server_state_type.formatted_representation())
print('trainable_weights_type:\n',
      trainable_weights_type.formatted_representation())

# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
  """TFF computation running `server_update` with the global `server_optimizer`.

  Args:
    server_state: Current server state (typed as `server_state_type`).
    model_delta: Averaged client model delta (typed as
      `trainable_weights_type`).

  Returns:
    The updated server state produced by `server_update`.
  """
  # NOTE(review): parameter names may surface in this computation's type
  # signature, so they are intentionally left as-is.
  return server_update(server_state, model_delta, server_optimizer)

# A model instance used here to obtain `input_spec`, which describes a single
# (images, labels) batch; a client dataset is a sequence of such batches.
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
print('tf_dataset_type:\n', 
      tf_dataset_type.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
  """TFF computation running `client_update` for a single client.

  The model is constructed inside this computation's body via `model_fn`,
  and the global `client_optimizer` is used for local training.

  Args:
    dataset: The client's batched dataset (typed as `tf_dataset_type`).
    server_weights: Broadcast server weights (typed as
      `trainable_weights_type`).

  Returns:
    The client's model delta (local weights minus `server_weights`).
  """
  model = model_fn()
  return client_update(model, dataset, server_weights, client_optimizer)

# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  """Executes one FedAvg round and returns the new server-placed state.

  Args:
    server_state: The current `ServerState`, placed at `tff.SERVER`.
    federated_dataset: One dataset per client, placed at `tff.CLIENTS`.

  Returns:
    The updated `ServerState`, placed at `tff.SERVER`.
  """
  # Only the model weights are broadcast; optimizer state stays on the server.
  broadcast_weights = tff.federated_broadcast(server_state.trainable_weights)
  # Every client trains locally and reports its weight delta.
  client_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, broadcast_weights))
  # Unweighted average of the client deltas, computed at the server.
  averaged_delta = tff.federated_mean(client_deltas)
  # Server optimizer step.
  return tff.federated_map(server_update_fn, (server_state, averaged_delta))

# 4. Build the iterative process for FedAvg.
# `initialize` yields the initial server-placed ServerState; `next` runs one
# broadcast / local-train / average / server-update round (`run_one_round`).
fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n', 
      fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n', 
      fedavg_process.next.type_signature.formatted_representation())

server_state_type:
 <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>
trainable_weights_type:
 <
  float32[5,5,1,32],
  float32[32],
  float32[5,5,32,64],
  float32[64],
  float32[3136,512],
  float32[512],
  float32[512,10],
  float32[10]
>
tf_dataset_type:
 <
  float32[?,28,28,1],
  int32[?]
>*
type signature of `initialize`:
 ( -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32

## 알고리즘 평가

중앙 집중식 평가 데이터세트에서 성능을 평가합니다.

In [None]:
def evaluate(server_state):
  """Returns the accuracy of the server's model on `central_test_data`.

  A fresh Keras model is built and its trainable weights are overwritten
  with `server_state.trainable_weights`. Sparse categorical accuracy is then
  accumulated over every batch of the module-level `central_test_data`.

  Args:
    server_state: An object exposing `trainable_weights` in the same order
      and structure as the Keras model's trainable weights.

  Returns:
    The accuracy, converted with `.numpy()`.
  """
  eval_model = create_keras_model()
  tf.nest.map_structure(
      lambda target, value: target.assign(value),
      eval_model.trainable_weights,
      server_state.trainable_weights)
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  for features, labels in central_test_data:
    # The final layer emits logits; the metric compares their argmax to labels.
    logits = eval_model(features, training=False)
    accuracy.update_state(y_true=labels, y_pred=logits)
  return accuracy.result().numpy()

In [None]:
server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)

# Train for NUM_ROUNDS rounds, then evaluate once. For prototyping, the same
# fixed sample of CLIENTS_PER_ROUND clients is reused in every round.
CLIENTS_PER_ROUND = 2
NUM_ROUNDS = 20
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
    train_data.create_tf_dataset_for_client(client)
    for client in sampled_clients]
# `round_idx` (not `round`) avoids shadowing the builtin `round()` for the rest
# of the notebook session.
for round_idx in range(NUM_ROUNDS):
  server_state = fedavg_process.next(server_state, sampled_train_data)
acc = evaluate(server_state)
print('Test accuracy', acc)

Initial test accuracy 0.09677419
Test accuracy 0.13978495
