##### Copyright 2021 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Migration examples: estimator.LoggingTensorHook

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/guide/migrate/logging_tensor_hook">
    <img src="https://www.tensorflow.org/images/tf_logo_32px.png" />
    View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/logging_tensor_hook.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />
    Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_tensor_hook.ipynb">
    <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />
    View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/logging_tensor_hook.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This notebook demonstrates how to migrate from `tf.estimator.LoggingTensorHook` to a custom `tf.keras.callbacks.Callback`.

## Setup

First, you need to define a couple of necessary imports.

In [None]:
import tensorflow as tf
import tensorflow.compat.v1 as tf1

Prepare some simple data for demonstration.

In [None]:
# Training examples: each row holds two float features and has one float target.
features = [
    [1., 1.5],
    [2., 2.5],
    [3., 3.5],
]
labels = [
    [0.3],
    [0.5],
    [0.7],
]

# Evaluation examples, laid out the same way as the training examples.
eval_features = [
    [4., 4.5],
    [5., 5.5],
    [6., 6.5],
]
eval_labels = [
    [0.8],
    [0.9],
    [1.],
]

### TF1: Estimator.train/evaluate

To monitor tensors, for example model weights or losses, you can use `tf.estimator.LoggingTensorHook` (`tf1.train.LoggingTensorHook` is its alias), and then pass the hook to `tf.estimator.EstimatorSpec`.

In [None]:
def _input_fn():
  """Returns the training examples as a `tf1` dataset in batches of one."""
  training_dataset = tf1.data.Dataset.from_tensor_slices((features, labels))
  return training_dataset.batch(1)

def _eval_input_fn():
  """Returns the evaluation examples as a `tf1` dataset in batches of one."""
  evaluation_dataset = tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels))
  return evaluation_dataset.batch(1)

def _model_fn(features, labels, mode):
  """Builds a single-unit dense model with two tensor-logging training hooks.

  Args:
    features: Input tensor fed to the dense layer.
    labels: Target tensor compared against the layer output.
    mode: Mode value forwarded unchanged to the returned `EstimatorSpec`.

  Returns:
    A `tf1.estimator.EstimatorSpec` carrying the mean squared error loss, the
    Adagrad train op, and two `LoggingTensorHook` training hooks.
  """
  linear = tf1.layers.Dense(1)
  predictions = linear(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=predictions)
  train_op = tf1.train.AdagradOptimizer(0.05).minimize(
      loss, global_step=tf1.train.get_global_step())

  # Identity ops over the layer's two weights; these tensors are handed to the
  # logging hook directly as tensor objects.
  kernel = tf.identity(linear.weights[0])
  bias = tf.identity(linear.weights[1])

  # Weights hook: tensors passed as a list, logged on every iteration.
  weights_hook = tf1.train.LoggingTensorHook(
      tensors=[kernel, bias], every_n_iter=1)
  # Loss hook: tensor passed in a dict under a custom tag, on a time-based
  # schedule of every three seconds.
  loss_hook = tf1.train.LoggingTensorHook(
      tensors={'loss from LoggingTensorHook': loss}, every_n_secs=3)

  return tf1.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      training_hooks=[weights_hook, loss_hook])

# Build the TF1 estimator around `_model_fn`, then train it on the training
# input function; the hooks attached in `_model_fn` run during this call.
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
estimator.train(input_fn=_input_fn)

### TF2: Keras training API

In TF2, accessing tensors by name is not supported. You need to record and output the logged tensors manually.

When migrating to the TF2 Keras training API, you can control when tensors are logged by overriding the appropriate methods of a custom `tf.keras.callbacks.Callback`. You can also implement the logging frequency in the custom callback. The example below will print the weights every two steps. Other strategies like logging every n seconds are also possible.

Check the API [docs](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback) and [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback) for more details.

In [None]:
class LoggingTensorCallback(tf.keras.callbacks.Callback):
  """Prints the first layer's kernel and bias, plus the loss, every N batches.

  Output is produced at the end of the first batch and then again after each
  further `every_n_iter` batches. The batch counter lives on the callback
  instance, so it keeps counting across epochs instead of restarting.

  Args:
    every_n_iter: Positive integer; the number of batches between two
      consecutive log outputs.

  Raises:
    ValueError: If `every_n_iter` is smaller than 1.
  """

  def __init__(self, every_n_iter):
    super().__init__()
    if every_n_iter < 1:
      raise ValueError(
          "every_n_iter must be >= 1, got {}".format(every_n_iter))
    self._every_n_iter = every_n_iter
    # Number of batches that have finished since the callback was created.
    self._batches_seen = 0

  def on_batch_end(self, batch, logs=None):
    """Logs the tracked values when the finished batch falls on the interval.

    Args:
      batch: Index of the batch that just finished (unused; the callback keeps
        its own counter).
      logs: Optional dict of metric results for this batch; the "loss" entry
        is printed when present, otherwise `None` is printed.
    """
    should_log = self._batches_seen % self._every_n_iter == 0
    self._batches_seen += 1
    if not should_log:
      return

    # `logs` defaults to None, so fall back to an empty dict before reading.
    logs = logs or {}
    # Use the model attached to this callback rather than a notebook global,
    # so the callback works with whichever model it is passed to.
    first_layer = self.model.layers[0]
    print("Logging Tensor Callback: dense/kernel:", first_layer.weights[0])
    print("Logging Tensor Callback: dense/bias:", first_layer.weights[1])
    print("Logging Tensor Callback loss:", logs.get("loss"))

In [None]:
# Input pipelines over the in-memory examples, one example per batch.
dataset = (
    tf.data.Dataset
    .from_tensor_slices((features, labels))
    .batch(1))
eval_dataset = (
    tf.data.Dataset
    .from_tensor_slices((eval_features, eval_labels))
    .batch(1))

# A single dense unit, optimized with Adagrad on mean squared error.
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer=optimizer, loss="mse")

# Train with the logging callback attached, configured with every_n_iter=2.
model.fit(dataset, callbacks=[LoggingTensorCallback(every_n_iter=2)])