##### Copyright 2019 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Custom training with tf.distribute.Strategy


This tutorial demonstrates how to use `tf.distribute.Strategy`—a TensorFlow API that provides an abstraction for [distributing your training](../../guide/distributed_training.ipynb) across multiple processing units (GPUs, multiple machines, or TPUs)—with custom training loops. In this example, you will train a simple convolutional neural network on the [Fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist) containing 70,000 images of size 28 x 28.

[Custom training loops](../customization/custom_training_walkthrough.ipynb) provide flexibility and greater control over training. They also make it easier to debug the model and the training loop.

In [None]:
# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)

## Download the Fashion MNIST dataset

In [None]:
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
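The preprocessing above only changes array shapes and dtypes, so it can be checked without TensorFlow. This NumPy sketch uses a few random `uint8` arrays as a hypothetical stand-in for the real images. It shows that `[..., None]` appends a trailing channel axis and that dividing by `np.float32(255)` yields `float32` values in `[0, 1]`.

```python
import numpy as np

# Hypothetical stand-in for Fashion MNIST: 3 random grayscale 28x28 uint8 images.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 28, 28), dtype=np.uint8)

# `[..., None]` appends a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
images_with_channel = images[..., None]

# Dividing uint8 pixels by a float32 scalar gives float32 values in [0, 1].
scaled = images_with_channel / np.float32(255)

print(images.shape, images_with_channel.shape, scaled.dtype)
```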

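## Create a strategy to distribute the variables and the graph

How does the `tf.distribute.MirroredStrategy` strategy work?

- All the variables and the model graph are replicated across the replicas.
- Input is evenly distributed across the replicas.
- Each replica calculates the loss and gradients for the input it received.
- The gradients are synced across all the replicas by summing them.
- After the sync, the same update is made to the copies of the variables on each replica.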
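The steps in this list can be simulated on a single machine. The NumPy sketch below is not how `MirroredStrategy` is implemented; it assumes a toy linear model with a squared-error loss and two hypothetical replicas. It shows that summing per-replica gradients of a loss divided by the global batch size reproduces the single-device gradient, after which every copy of the variables receives the same update.

```python
import numpy as np

# Toy simulation in plain NumPy (not TensorFlow): a linear model y_hat = x @ w
# with a squared-error loss, replicated on NUM_REPLICAS hypothetical devices.
NUM_REPLICAS = 2
GLOBAL_BATCH_SIZE = 8
LEARNING_RATE = 0.1

rng = np.random.default_rng(42)
x = rng.normal(size=(GLOBAL_BATCH_SIZE, 3)).astype(np.float32)
y = rng.normal(size=(GLOBAL_BATCH_SIZE,)).astype(np.float32)
w_init = np.zeros(3, dtype=np.float32)

# 1. Variables are replicated: every replica holds its own copy of w.
replica_weights = [w_init.copy() for _ in range(NUM_REPLICAS)]

# 2. The global batch is split evenly across the replicas.
x_shards = np.split(x, NUM_REPLICAS)
y_shards = np.split(y, NUM_REPLICAS)

def scaled_loss_grad(w, xs, ys):
    """Gradient of sum((xs @ w - ys) ** 2) / GLOBAL_BATCH_SIZE with respect to w."""
    residual = xs @ w - ys
    return 2.0 * xs.T @ residual / GLOBAL_BATCH_SIZE

# 3. Each replica computes gradients on its own shard.
per_replica_grads = [
    scaled_loss_grad(w, xs, ys)
    for w, xs, ys in zip(replica_weights, x_shards, y_shards)
]

# 4. The gradients are synced by summing them (an all-reduce).
synced_grad = np.sum(per_replica_grads, axis=0)

# 5. The same update is applied to every replica's copy of the variables.
replica_weights = [w - LEARNING_RATE * synced_grad for w in replica_weights]

# Reference: one device computing the gradient of the mean loss over the full batch.
single_device_grad = 2.0 * x.T @ (x @ w_init - y) / GLOBAL_BATCH_SIZE
print(synced_grad, single_device_grad)
```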

Note: You can put all the code below inside a single scope. This example divides it into several code cells for illustration purposes.


In [None]:
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

In [None]:
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))

## Set up the input pipeline

In [None]:
BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10
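Because `GLOBAL_BATCH_SIZE` grows with the number of replicas, the number of steps per epoch shrinks while each device still processes `BATCH_SIZE_PER_REPLICA` examples per step. The quick calculation below assumes the standard 60,000-image Fashion MNIST training split and `.batch()`'s default behavior of keeping the final partial batch.

```python
import math

NUM_TRAIN_EXAMPLES = 60_000  # Fashion MNIST training split
BATCH_SIZE_PER_REPLICA = 64

def global_batch_and_steps(num_examples, num_replicas, per_replica=BATCH_SIZE_PER_REPLICA):
    """Return (GLOBAL_BATCH_SIZE, batches per epoch) for `.batch(GLOBAL_BATCH_SIZE)`,
    which keeps the final partial batch by default (drop_remainder=False)."""
    global_batch_size = per_replica * num_replicas
    steps = math.ceil(num_examples / global_batch_size)
    return global_batch_size, steps

for replicas in (1, 2, 4, 8):
    gbs, steps = global_batch_and_steps(NUM_TRAIN_EXAMPLES, replicas)
    print(f"{replicas} replica(s): GLOBAL_BATCH_SIZE={gbs}, steps per epoch={steps}")
```

With one replica, for example, an epoch is 938 steps, and the last step holds only 32 examples.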

Create the datasets and distribute them:

In [None]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

## Create the model

Create a model using `tf.keras.Sequential`. You can also use the [Model Subclassing API](https://www.tensorflow.org/guide/keras/custom_layers_and_models) or the [functional API](https://www.tensorflow.org/guide/keras/functional) to do this.

In [None]:
def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model
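As a sanity check on `create_model`, the output shapes and parameter counts can be worked out by hand for a `(28, 28, 1)` input. The plain-Python sketch below assumes the Keras defaults of `'valid'` padding and stride 1 for `Conv2D`, and a pool size and stride of 2 for `MaxPooling2D`; its totals should agree with `model.summary()` once the model is built.

```python
def conv_out(size, kernel):
    """Spatial size after a 'valid' (unpadded) convolution with stride 1."""
    return size - kernel + 1

def pool_out(size, pool=2):
    """Spatial size after max pooling with pool_size=2, strides=2, 'valid' padding."""
    return (size - pool) // pool + 1

def conv_params(kernel, in_ch, out_ch):
    return kernel * kernel * in_ch * out_ch + out_ch  # weights + biases

def dense_params(in_units, out_units):
    return in_units * out_units + out_units  # weights + biases

# Walk through create_model() layer by layer for a 28x28x1 input.
summary = []
size, channels = 28, 1

p = conv_params(3, channels, 32)
size, channels = conv_out(size, 3), 32
summary.append(("Conv2D(32, 3)", (size, size, channels), p))

size = pool_out(size)
summary.append(("MaxPooling2D", (size, size, channels), 0))

p = conv_params(3, channels, 64)
size, channels = conv_out(size, 3), 64
summary.append(("Conv2D(64, 3)", (size, size, channels), p))

size = pool_out(size)
summary.append(("MaxPooling2D", (size, size, channels), 0))

flat = size * size * channels
summary.append(("Flatten", (flat,), 0))
summary.append(("Dense(64)", (64,), dense_params(flat, 64)))
summary.append(("Dense(10)", (10,), dense_params(64, 10)))

total_params = sum(params for _, _, params in summary)
for name, shape, params in summary:
    print(f"{name:<15} {str(shape):<14} {params}")
print("Total params:", total_params)
```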

In [None]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function

Normally, on a single machine with a single GPU/CPU, the loss is divided by the number of examples in the batch of input.

*So, how should the loss be calculated when using a `tf.distribute.Strategy`?*

- For example, let's say you have 4 GPUs and a batch size of 64. One batch of input is distributed across the replicas (4 GPUs), and each replica gets an input of size 16.

- The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by the GLOBAL_BATCH_SIZE (64).

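*Why do this?*

- This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by **summing** them.

*How do you do this in TensorFlow?*

- If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the GLOBAL_BATCH_SIZE: `scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)`. Alternatively, you can use `tf.nn.compute_average_loss`, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.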

- If you are using regularization losses in your model then you need to scale the loss value by the number of replicas. You can do this by using the `tf.nn.scale_regularization_loss` function.

- Using `tf.reduce_mean` is not recommended. Doing so divides the loss by the actual per-replica batch size, which may vary from step to step.

- This reduction and scaling is done automatically in Keras `Model.compile` and `Model.fit`.

- If you use `tf.keras.losses` classes (as in the example below), the loss reduction needs to be explicitly specified to be one of `NONE` or `SUM`. `AUTO` and `SUM_OVER_BATCH_SIZE` are disallowed when used with `tf.distribute.Strategy`. `AUTO` is disallowed because you should explicitly think about which reduction you want, to make sure it is correct in the distributed case. `SUM_OVER_BATCH_SIZE` is disallowed because it would currently divide only by the per-replica batch size and leave the division by the number of replicas to you, which is easy to miss. So, instead, you need to do the reduction explicitly yourself.

- If `labels` is multi-dimensional, average the `per_example_loss` across the number of elements in each sample. For example, if the shape of `predictions` is `(batch_size, H, W, n_classes)` and `labels` is `(batch_size, H, W)`, you will need to update `per_example_loss` like this: `per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)`

    Caution: **Verify the shape of your loss.** Loss functions in `tf.losses`/`tf.keras.losses` typically return the average over the last dimension of the input. The loss classes wrap these functions. Passing `reduction=Reduction.NONE` when creating an instance of a loss class means "no **additional** reduction". For categorical losses with an example input shape of `[batch, W, H, n_classes]`, the `n_classes` dimension is reduced. For pointwise losses like `losses.mean_squared_error` or `losses.binary_crossentropy`, include a dummy axis so that `[batch, W, H, 1]` is reduced to `[batch, W, H]`. Without the dummy axis, `[batch, W, H]` will be incorrectly reduced to `[batch, W]`.
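The last two points can be checked numerically. This NumPy sketch uses a hypothetical `pointwise_mse` helper that, like the Keras loss functions described in the caution, averages over the last axis only. It shows `[batch, W, H]` being wrongly reduced to `[batch, W]` without a dummy axis. It then applies the element-count division from the multi-dimensional `labels` point, so that summing everything and dividing by the global batch size gives the mean over all elements.

```python
import numpy as np

# Hypothetical NumPy stand-in for a pointwise loss that averages over the LAST axis only.
def pointwise_mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2, axis=-1)

batch, W, H = 2, 4, 3
rng = np.random.default_rng(0)
y_true = rng.normal(size=(batch, W, H))
y_pred = rng.normal(size=(batch, W, H))

# Without a dummy axis, H is mistakenly averaged away: [batch, W, H] -> [batch, W].
no_dummy = pointwise_mse(y_true, y_pred)

# With a trailing dummy axis, only the size-1 axis is reduced: [batch, W, H, 1] -> [batch, W, H].
with_dummy = pointwise_mse(y_true[..., None], y_pred[..., None])

# Divide the unreduced loss by the number of elements per example (W * H) ...
num_elements = np.prod(y_true.shape[1:])
per_example_loss = with_dummy / num_elements

# ... so that sum / GLOBAL_BATCH_SIZE (the scaling described above) is the mean over all elements.
GLOBAL_BATCH_SIZE = batch
loss = per_example_loss.sum() / GLOBAL_BATCH_SIZE
print(no_dummy.shape, with_dummy.shape, loss)
```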


In [None]:
with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
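Why dividing by `GLOBAL_BATCH_SIZE` (rather than the per-replica batch size) is the right choice can be shown with the 4-replica, batch-of-64 scenario from the text. `compute_average_loss_sketch` below is a hypothetical NumPy stand-in for the `sum / GLOBAL_BATCH_SIZE` scaling described above, not `tf.nn.compute_average_loss` itself. Summing its per-replica results gives the global mean loss, whereas summing per-replica means (the `tf.reduce_mean` pattern) overshoots by the number of replicas. Because gradients are linear in the loss, the same holds for the summed gradients.

```python
import numpy as np

def compute_average_loss_sketch(per_example_loss, global_batch_size, sample_weight=None):
    """Hypothetical stand-in: sum(per_example_loss * sample_weight) / global_batch_size."""
    per_example_loss = np.asarray(per_example_loss, dtype=np.float64)
    if sample_weight is not None:
        per_example_loss = per_example_loss * sample_weight
    return per_example_loss.sum() / global_batch_size

# 4 replicas x 16 examples = a global batch of 64, as in the example in the text.
NUM_REPLICAS, PER_REPLICA = 4, 16
GLOBAL_BATCH_SIZE = NUM_REPLICAS * PER_REPLICA
rng = np.random.default_rng(1)
losses = rng.uniform(0.0, 2.0, size=GLOBAL_BATCH_SIZE)  # made-up per-example losses
shards = np.split(losses, NUM_REPLICAS)

# Correct: scale by the global batch size on every replica, then SUM across replicas.
global_scaled = sum(compute_average_loss_sketch(s, GLOBAL_BATCH_SIZE) for s in shards)

# Incorrect: a per-replica mean (tf.reduce_mean style), then SUM across replicas.
per_replica_mean = sum(s.mean() for s in shards)

print(global_scaled, losses.mean(), per_replica_mean)
```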

## Define the metrics to track loss and accuracy

These metrics track the test loss and the training and test accuracy. You can use `.result()` to get the accumulated statistics at any time.

In [None]:
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
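To make the "accumulated statistics" behavior concrete, here is a hypothetical pure-Python sketch of a streaming metric with the same `update_state` / `result` / `reset_states` cycle that the training loop uses. It is not the Keras implementation; the accuracy class simply keeps a running mean of whether `argmax(logits)` matches the label.

```python
class StreamingMean:
    """Hypothetical sketch: a running mean kept as a total and a count."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        for v in values:
            self.total += v
            self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        self.total = 0.0
        self.count = 0


class StreamingSparseAccuracy(StreamingMean):
    """Running mean of 1.0/0.0 hits, where a hit means argmax(logits) == label."""

    def update_state(self, labels, logits):
        hits = [float(max(range(len(row)), key=row.__getitem__) == label)
                for label, row in zip(labels, logits)]
        super().update_state(hits)


acc = StreamingSparseAccuracy()
acc.update_state([0, 2], [[2.0, 0.1, 0.3], [0.5, 0.2, 0.1]])  # 1 hit, 1 miss
acc.update_state([1], [[0.0, 3.0, 1.0]])                       # 1 hit
running = acc.result()  # accumulated over all 3 examples seen so far
acc.reset_states()      # start fresh, as the loop does after each epoch
```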

## Training loop

In [None]:
# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [None]:
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)

In [None]:
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()

Things to note in the example above:

- You iterate over the `train_dist_dataset` and `test_dist_dataset` using a `for x in ...` construct.
- The scaled loss is the return value of `distributed_train_step`. This value is aggregated across replicas using the `tf.distribute.Strategy.reduce` call, and then across batches by summing the return values of the `tf.distribute.Strategy.reduce` calls.
- `tf.keras.Metrics` should be updated inside `train_step` and `test_step`, which get executed by `tf.distribute.Strategy.run`.
- `tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can call `tf.distribute.Strategy.reduce` to get an aggregated value, or call `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.


## Restore the latest checkpoint and test

A model checkpointed with a `tf.distribute.Strategy` can be restored with or without a strategy.

In [None]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

In [None]:
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

In [None]:
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
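It is worth noting which weights the restored model holds. The training loop saves only when `epoch % 2 == 0` for a 0-based `epoch`, so with `EPOCHS = 10` the last save happens after the ninth epoch, not the tenth. The sketch below replays that schedule in plain Python, assuming `tf.train.Checkpoint.save` names files `ckpt-1`, `ckpt-2`, and so on from its save counter. The restored accuracy can therefore differ slightly from the last `Test Accuracy` printed during training.

```python
EPOCHS = 10

saved = []  # (checkpoint file prefix, 1-based epoch number as printed by the loop)
for epoch in range(EPOCHS):
    if epoch % 2 == 0:
        save_counter = len(saved) + 1  # assumed numbering: first save is ckpt-1
        saved.append((f"ckpt-{save_counter}", epoch + 1))

latest_prefix, latest_epoch = saved[-1]
print(saved)
print(f"Latest checkpoint: {latest_prefix}, written after epoch {latest_epoch} of {EPOCHS}")
```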

## Alternate ways of iterating over a dataset

### Using iterators

If you want to iterate over a given number of steps and not through the entire dataset, you can create an iterator using the `iter` call and explicitly call `next` on the iterator. You can choose to iterate over the dataset both inside and outside the `tf.function`. Here is a small snippet demonstrating iteration of the dataset outside the `tf.function` using an iterator.


In [None]:
for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
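Each epoch in the loop above calls `iter(train_dist_dataset)` again, so it starts from a fresh iterator and takes exactly 10 steps. The plain-Python sketch below uses a hypothetical list of 25 batch labels in place of the dataset to show both properties: a fresh iterator restarts from the beginning, and requesting more steps than there are batches raises `StopIteration`. With the shuffled `tf.data` pipeline in this tutorial, each new iterator would also see a new order, because `shuffle` reshuffles on each iteration by default.

```python
# Hypothetical stand-in for the distributed dataset: 25 labelled "batches".
batches = [f"batch-{i}" for i in range(25)]
STEPS_PER_EPOCH = 10

def take_steps(dataset, steps):
    """Create a fresh iterator and call next() `steps` times, like each epoch above."""
    it = iter(dataset)
    taken = []
    for _ in range(steps):
        taken.append(next(it))  # raises StopIteration once the data runs out
    return taken

epoch_1 = take_steps(batches, STEPS_PER_EPOCH)
epoch_2 = take_steps(batches, STEPS_PER_EPOCH)

try:
    take_steps(batches, 30)  # more steps than the 25 available batches
    exhausted = False
except StopIteration:
    exhausted = True

print(epoch_1[0], epoch_2[0], exhausted)
```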

### Iterating inside a tf.function

You can also iterate over the entire input `train_dist_dataset` inside a `tf.function` using the `for x in ...` construct or by creating iterators like you did above. The example below demonstrates wrapping one epoch of training with a `@tf.function` decorator and iterating over `train_dist_dataset` inside the function.

In [None]:
@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()

### Tracking training loss across replicas

Note: As a general rule, you should use `tf.keras.Metrics` to track per-sample values and avoid values that have been aggregated within a replica.

Because of the loss scaling computation that is carried out, it's not recommended to use `tf.keras.metrics.Mean` to track the training loss across different replicas.

For example, let's say you have a training job with the following characteristics:

- Two replicas
- Two samples are processed on each replica
- Resulting loss values: [2, 3] and [4, 5] on each replica
- Global batch size = 4

With loss scaling, you calculate the per-sample value of loss on each replica by adding the loss values and then dividing by the global batch size. In this case: `(2 + 3) / 4 = 1.25` and `(4 + 5) / 4 = 2.25`.

If you use `tf.keras.metrics.Mean` to track loss across the two replicas, the result is different. In this example, you end up with a `total` of 3.50 and `count` of 2, which results in `total`/`count` = 1.75 when `result()` is called on the metric. Loss calculated with `tf.keras.Metrics` is therefore off by an additional factor equal to the number of replicas in sync: the correct per-sample average here is `(2 + 3 + 4 + 5) / 4 = 3.50`, twice the metric's 1.75.
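The arithmetic in this example can be replayed in a few lines of plain Python. Summing the per-replica scaled losses, as `strategy.reduce` with `ReduceOp.SUM` does in `distributed_train_step`, recovers the true per-sample mean of 3.5, while averaging them the way a `Mean` metric would gives 1.75.

```python
# Worked example: 2 replicas, 2 samples each, global batch size 4.
GLOBAL_BATCH_SIZE = 4
replica_losses = [[2.0, 3.0], [4.0, 5.0]]

# Loss scaling: each replica sums its per-sample losses and divides by the global batch size.
scaled = [sum(r) / GLOBAL_BATCH_SIZE for r in replica_losses]          # [1.25, 2.25]

# Summing the scaled values across replicas gives the true per-sample mean.
reduced_sum = sum(scaled)                                               # 3.5
true_mean = sum(sum(r) for r in replica_losses) / GLOBAL_BATCH_SIZE     # 3.5

# A Mean metric fed one scaled value per replica keeps total=3.5 and count=2.
metric_total, metric_count = sum(scaled), len(scaled)
metric_result = metric_total / metric_count                             # 1.75
```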

### Guides and examples

Here are some examples of using distribution strategies with custom training loops:

1. [Distributed training guide](../../guide/distributed_training)
2. [DenseNet](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/densenet/distributed_train.py) example using `MirroredStrategy`.
3. [BERT](https://github.com/tensorflow/models/blob/master/official/legacy/bert/run_classifier.py) example trained using `MirroredStrategy` and `TPUStrategy`. This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training etc.
4. [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) example trained using `MirroredStrategy` that can be enabled using the `keras_use_ctl` flag.
5. [NMT](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py) example trained using `MirroredStrategy`.

You can find more examples listed under *Examples and tutorials* in the [Distribution strategy guide](../../guide/distributed_training.ipynb).

## Next steps

- Try out the new `tf.distribute.Strategy` API on your own models.
- Visit the [Better performance with `tf.function`](../../guide/function.ipynb) and [TensorFlow Profiler](../../guide/profiler.md) guides to learn more about tools to optimize the performance of your TensorFlow models.
- Check out the [Distributed training in TensorFlow](../../guide/distributed_training.ipynb) guide, which provides an overview of the available distribution strategies.