official/utils/logging/metric_hook.py (106 additions, 0 deletions)
@@ -0,0 +1,106 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Session hook for logging benchmark metric."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from official.utils.logging import logger


class LoggingMetricHook(tf.train.LoggingTensorHook):
"""Hook to log benchmark metric information.

This hook is very similar to tf.train.LoggingTensorHook, which logs the given
tensors every N local steps, every N seconds, or at the end. The metric
information will be logged to the given log_dir or via metric_logger in JSON
format, which can be consumed by a data analysis pipeline later.

Note that if `at_end` is True, `tensors` should not include any tensor
whose evaluation produces a side effect such as consuming additional inputs.
"""

def __init__(self, tensors, log_dir=None, metric_logger=None,
every_n_iter=None, every_n_secs=None, at_end=False):
"""Initializer for LoggingMetricHook.

Args:
tensors: `dict` that maps string-valued tags to tensors/tensor names,
or `iterable` of tensors/tensor names.
log_dir: `string`, directory path that the metric hook should write the log to.
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that the
hook should use to write the log. Exactly one of `log_dir` and
`metric_logger` should be provided.
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
every_n_secs: `int` or `float`, print the values of `tensors` once every N
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
provided.
at_end: `bool` specifying whether to print the values of `tensors` at the
end of the run.

Raises:
ValueError:
1. `every_n_iter` is non-positive, or
2. neither or both of `every_n_iter` and `every_n_secs` are provided, or
3. neither or both of `log_dir` and `metric_logger` are provided.
"""
super(LoggingMetricHook, self).__init__(
tensors=tensors,
every_n_iter=every_n_iter,
every_n_secs=every_n_secs,
at_end=at_end)

if (log_dir is None) == (metric_logger is None):
raise ValueError(
"exactly one of log_dir and metric_logger should be provided.")

if log_dir is not None:
self._logger = logger.BenchmarkLogger(log_dir)
else:
self._logger = metric_logger

def begin(self):
super(LoggingMetricHook, self).begin()
self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use LoggingMetricHook.")
if self._global_step_tensor.name not in self._current_tensors:
self._current_tensors[self._global_step_tensor.name] = (
self._global_step_tensor)

def after_run(self, unused_run_context, run_values):
# _should_trigger is an internal state populated in before_run, which uses
# self._timer to determine whether the hook should trigger.
if self._should_trigger:
self._log_metric(run_values.results)

self._iter_count += 1

def end(self, session):
if self._log_at_end:
values = session.run(self._current_tensors)
self._log_metric(values)

def _log_metric(self, tensor_values):
Contributor:

It would be nice to sync timestamps for different tensors in the same measurement. Right now the times are slightly off, which could be annoying later.

{"name": "train_accuracy",  "timestamp": "2018-03-21T12:40:03.460442Z", ... "global_step": 33}
{"name": "learning_rate",   "timestamp": "2018-03-21T12:40:03.460687Z", ... "global_step": 33}

Member Author:

Understood; I think Karmel also had a similar comment about bulk-logging metrics. Currently we can still align them via global_step. Will address this in a future change.
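
As a hedged illustration of that alignment, a downstream consumer could group the JSON records on global_step rather than timestamp. A minimal sketch, assuming one JSON object per line in a hypothetical metric.log file (the file name and the exact record fields are assumptions based on the example above):

import collections
import json

def group_metrics_by_step(path):
  # Group metric records by global_step instead of timestamp, so records
  # emitted in the same measurement land in the same bucket.
  grouped = collections.defaultdict(dict)
  with open(path) as f:
    for line in f:
      record = json.loads(line)
      grouped[record["global_step"]][record["name"]] = record["value"]
  return grouped

# e.g. group_metrics_by_step("metric.log")[33]
# -> {"train_accuracy": ..., "learning_rate": ...}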

self._timer.update_last_triggered_step(self._iter_count)
global_step = tensor_values[self._global_step_tensor.name]
# self._tag_order is populated during the init of LoggingTensorHook
for tag in self._tag_order:
self._logger.log_metric(tag, tensor_values[tag], global_step=global_step)
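
For context, a minimal usage sketch of wiring this hook into TF 1.x Estimator training; the tensor names, model_fn, and input_fn below are hypothetical placeholders, not part of this change:

# Hypothetical wiring of LoggingMetricHook into Estimator training.
hook = metric_hook.LoggingMetricHook(
    tensors=["cross_entropy", "learning_rate"],  # assumed tensor names
    log_dir="/tmp/benchmark_log",  # assumed log directory
    every_n_iter=100)
estimator = tf.estimator.Estimator(model_fn=model_fn)  # model_fn assumed
estimator.train(input_fn=input_fn, hooks=[hook])  # input_fn assumed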
official/utils/logging/metric_hook_test.py (232 additions, 0 deletions)
@@ -0,0 +1,232 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metric_hook."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tempfile
import time

import tensorflow as tf
from tensorflow.python.training import monitored_session

from official.utils.logging import metric_hook


class LoggingMetricHookTest(tf.test.TestCase):

def setUp(self):
super(LoggingMetricHookTest, self).setUp()

class MockMetricLogger(object):
def __init__(self):
self.logged_metric = []

def log_metric(self, name, value, unit=None, global_step=None,
extras=None):
self.logged_metric.append({
"name": name,
"value": float(value),
"unit": unit,
"global_step": global_step,
"extras": extras})

self._log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
self._logger = MockMetricLogger()

def tearDown(self):
super(LoggingMetricHookTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())

def test_illegal_args(self):
with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=0)
with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=-10)
with self.assertRaisesRegexp(ValueError, 'xactly one of'):
metric_hook.LoggingMetricHook(
tensors=['t'], every_n_iter=5, every_n_secs=5)
with self.assertRaisesRegexp(ValueError, 'xactly one of'):
metric_hook.LoggingMetricHook(tensors=['t'])
with self.assertRaisesRegexp(ValueError, 'log_dir and metric_logger'):
metric_hook.LoggingMetricHook(tensors=['t'], every_n_iter=5)
with self.assertRaisesRegexp(ValueError, 'log_dir and metric_logger'):
metric_hook.LoggingMetricHook(
tensors=['t'], every_n_iter=5, log_dir=self._log_dir,
metric_logger=self._logger)

def test_print_at_end_only(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
t = tf.constant(42.0, name='foo')
train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook(
tensors=[t.name], at_end=True, metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(tf.global_variables_initializer())

for _ in range(3):
mon_sess.run(train_op)
self.assertEqual(self._logger.logged_metric, [])

hook.end(sess)
self.assertEqual(len(self._logger.logged_metric), 1)
metric = self._logger.logged_metric[0]
self.assertRegexpMatches(metric["name"], "foo")
self.assertEqual(metric["value"], 42.0)
self.assertEqual(metric["unit"], None)
self.assertEqual(metric["global_step"], 0)

def test_global_step_not_found(self):
with tf.Graph().as_default(), tf.Session() as sess:
t = tf.constant(42.0, name='foo')
hook = metric_hook.LoggingMetricHook(
tensors=[t.name], at_end=True, metric_logger=self._logger)

with self.assertRaisesRegexp(
RuntimeError, 'should be created to use LoggingMetricHook.'):
hook.begin()

def test_log_tensors(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
t1 = tf.constant(42.0, name='foo')
t2 = tf.constant(43.0, name='bar')
train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook(
tensors=[t1, t2], at_end=True, metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(tf.global_variables_initializer())

for _ in range(3):
mon_sess.run(train_op)
self.assertEqual(self._logger.logged_metric, [])

hook.end(sess)
self.assertEqual(len(self._logger.logged_metric), 2)
metric1 = self._logger.logged_metric[0]
self.assertRegexpMatches(str(metric1["name"]), "foo")
self.assertEqual(metric1["value"], 42.0)
self.assertEqual(metric1["unit"], None)
self.assertEqual(metric1["global_step"], 0)

metric2 = self._logger.logged_metric[1]
self.assertRegexpMatches(str(metric2["name"]), "bar")
self.assertEqual(metric2["value"], 43.0)
self.assertEqual(metric2["unit"], None)
self.assertEqual(metric2["global_step"], 0)

def _validate_print_every_n_steps(self, sess, at_end):
t = tf.constant(42.0, name='foo')

train_op = tf.constant(3)
hook = metric_hook.LoggingMetricHook(
tensors=[t.name], every_n_iter=10, at_end=at_end,
metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(tf.global_variables_initializer())
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
for _ in range(3):
self._logger.logged_metric = []
for _ in range(9):
mon_sess.run(train_op)
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)

# Add additional run to verify proper reset when called multiple times.
self._logger.logged_metric = []
mon_sess.run(train_op)
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)

self._logger.logged_metric = []
hook.end(sess)
if at_end:
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
else:
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)

def test_print_every_n_steps(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
self._validate_print_every_n_steps(sess, at_end=False)
# Verify proper reset.
self._validate_print_every_n_steps(sess, at_end=False)

def test_print_every_n_steps_and_end(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
self._validate_print_every_n_steps(sess, at_end=True)
# Verify proper reset.
self._validate_print_every_n_steps(sess, at_end=True)

def _validate_print_every_n_secs(self, sess, at_end):
t = tf.constant(42.0, name='foo')
train_op = tf.constant(3)

hook = metric_hook.LoggingMetricHook(
tensors=[t.name], every_n_secs=1.0, at_end=at_end,
metric_logger=self._logger)
hook.begin()
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(tf.global_variables_initializer())

mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)

# assertNotRegexpMatches is not supported by python 3.1 and later
self._logger.logged_metric = []
mon_sess.run(train_op)
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
time.sleep(1.0)

self._logger.logged_metric = []
mon_sess.run(train_op)
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)

self._logger.logged_metric = []
hook.end(sess)
if at_end:
self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
else:
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)

def test_print_every_n_secs(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
self._validate_print_every_n_secs(sess, at_end=False)
# Verify proper reset.
self._validate_print_every_n_secs(sess, at_end=False)

def test_print_every_n_secs_and_end(self):
with tf.Graph().as_default(), tf.Session() as sess:
tf.train.get_or_create_global_step()
self._validate_print_every_n_secs(sess, at_end=True)
# Verify proper reset.
self._validate_print_every_n_secs(sess, at_end=True)


if __name__ == '__main__':
tf.test.main()