Add session hook for benchmark metric logging. #3672
Conversation
The current hook is very similar to LoggingTensorHook. Some of the functions are copied directly, since the original ones were not exposed for import. We should seek to eventually move this code to core when it is mature enough.
This looks sufficiently similar to LoggingTensorHook that I think we would be better off subclassing that, calling super for the methods with changes, and doing the post-work. We can overwrite an entire method if necessary. Thoughts?
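A minimal sketch of what that subclassing could look like, assuming the `metric_logger` constructor argument from this PR (names and attribute choices are illustrative, not the final implementation):

```python
import tensorflow as tf


class LoggingMetricHook(tf.train.LoggingTensorHook):
  """Logs benchmark metrics rather than printing tensor values."""

  def __init__(self, tensors, metric_logger=None,
               every_n_iter=None, every_n_secs=None, at_end=False):
    # The parent handles tensor bookkeeping and the trigger timer.
    super(LoggingMetricHook, self).__init__(
        tensors=tensors, every_n_iter=every_n_iter,
        every_n_secs=every_n_secs, at_end=at_end)
    self._logger = metric_logger

  def after_run(self, unused_run_context, run_values):
    # before_run (inherited) requested self._current_tensors when the
    # timer triggered; hand the results to the benchmark logger.
    if self._should_trigger:
      self._log_metric(run_values.results)
```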
```python
from __future__ import division
from __future__ import print_function

from official.utils.logging import logger
```
nit: local (official) imports should go below the third-party ones.
The existing hook is similar enough to LoggingTensorHook that we should eliminate duplication as much as possible.
Good point. Updated to inherit from LoggingTensorHook.
This is exciting. Two notes:
- Our lint tests are working! Please lint.
- After this gets merged, we will want to add the module and its test to the build file.
```python
This hook is very similar to tf.train.LoggingTensorHook, which logs given
tensors every N local steps, every N seconds, or at the end.
```
Some details on how it is different/what it is used for?
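One way the class docstring could spell that out, sketched here from the surrounding discussion (wording only, not the landed text):

```python
class LoggingMetricHook(tf.train.LoggingTensorHook):
  """Hook to write benchmark metrics to a structured log.

  Unlike tf.train.LoggingTensorHook, which prints human-readable text
  via tf.logging, this hook hands each value to a BenchmarkLogger so
  metrics land in a machine-parsable log under log_dir for later
  benchmark analysis.
  """
```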
```python
tensors: `dict` that maps string-valued tags to tensors/tensor names,
  or `iterable` of tensors/tensor names.
log_dir: `string`, directory path that metric hook should write log to.
metric_logger: `BenchmarkLogger`, the benchmark logger that hook should
```
An instance of BenchmarkLogger or a class?
```python
def begin(self):
  super(LoggingMetricHook, self).begin()
  if tf.train.get_global_step() is None:
```
I suspect the graph can optimize this out, but, for clarity: maybe call get_global_step once, assign it to a var, and then check and use that var below?
Done.
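For reference, the suggested shape (GLOBAL_STEP_TENSOR_NAME is the module constant from this diff):

```python
def begin(self):
  super(LoggingMetricHook, self).begin()
  # Fetch the global step once and reuse the local variable below.
  global_step = tf.train.get_global_step()
  if global_step is None:
    raise RuntimeError(
        "Global step should be created to use LoggingMetricHook.")
  self._current_tensors[GLOBAL_STEP_TENSOR_NAME] = global_step
```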
```python
self._current_tensors[GLOBAL_STEP_TENSOR_NAME] = tf.train.get_global_step()

def after_run(self, unused_run_context, run_values):
  if self._should_trigger:
```
Comments for all these methods would be helpful, since most people won't know what the parent class does.
Done.
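For instance, the end() override could carry a note like this (a sketch; `_log_at_end` is the flag LoggingTensorHook sets from the at_end constructor argument):

```python
def end(self, session):
  # Runs once as the session closes; the parent only enables this
  # path when the hook was constructed with at_end=True.
  if self._log_at_end:
    values = session.run(self._current_tensors)
    self._log_metric(values)
```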
```python
def _log_metric(self, tensor_values):
  self._timer.update_last_triggered_step(self._iter_count)
  global_step = tensor_values[GLOBAL_STEP_TENSOR_NAME]
  for tag in self._tag_order:
```
Probably worth noting that this comes from LoggingTensorHook, which captures the keys of tensors
during init.
Done.
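The landed comment could look like the following; the logger call assumes a `log_metric(name, value, global_step=...)` signature on BenchmarkLogger, and the `self._logger` attribute name is illustrative:

```python
def _log_metric(self, tensor_values):
  self._timer.update_last_triggered_step(self._iter_count)
  global_step = tensor_values[GLOBAL_STEP_TENSOR_NAME]
  # self._tag_order is set by LoggingTensorHook.__init__, which captures
  # the keys of `tensors` so metrics are emitted in a stable order.
  for tag in self._tag_order:
    self._logger.log_metric(tag, tensor_values[tag], global_step=global_step)
```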
```python
import time

from official.utils.logging import metric_hook
import tensorflow as tf
```
import order
```python
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import monitored_session
```
I haven't validated this for all, but certainly many of the classes/functions used below are available from the top-level tf import. Let's stick with that unless absolutely necessary.
Done. Thanks for the suggestion.
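For reference, the public TF 1.x equivalents look like this; the private `monitored_session._HookedSession` used by the test has no direct top-level counterpart, so this sketch covers only the symbols that do:

```python
import tensorflow as tf

t = tf.constant(42.0, name="foo")             # was constant_op.constant
train_op = tf.constant(3)
init_op = tf.global_variables_initializer()   # was variables_lib.global_variables_initializer
global_step_key = tf.GraphKeys.GLOBAL_STEP    # was ops.GraphKeys.GLOBAL_STEP
```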
```python
mon_sess = monitored_session._HookedSession(sess, [hook])
sess.run(variables_lib.global_variables_initializer())

# metric_log = os.path.join(self.log_dir, "metric.log")
```
nit: remove debugging lines
```python
if tf.train.get_global_step() is None:
  raise RuntimeError(
      "Global step should be created to use LoggingMetricHook.")
self._current_tensors[GLOBAL_STEP_TENSOR_NAME] = tf.train.get_global_step()
```
Will this create a problem if someone happens to pass in the global step tensor, or a tensor by the same name? TF is often more cautious about placeholder vars that it adds in; see, for example, https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/python/estimator/inputs/numpy_io.py#L51 .
Understood. I got rid of the self-defined name and use the default ops name as the key. I will trust that if the user passes in a tensor with that name, it is also the global step tensor; otherwise they are shooting themselves in the foot. On the other hand, if the user passes the global step tensor as an input under a different name, they will just get an extra metric logged, which does not provide much value.
```python
train_op = constant_op.constant(3)

hook = metric_hook.LoggingMetricHook(
    tensors=[t.name], every_n_secs=1.0, at_end=at_end,
```
Can we make sure to test the case of actually having multiple tensors passed in, and also passed in as a dict? Also, we use tensor names in all of these tests; let's make sure to test also with the tensors themselves, which are allowed according to the docstring.
Done. Added tests with multiple tensors.
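A sketch of such a case, passing a dict of actual tensors rather than names; it assumes a logger test double bound to `self._logger` in the test fixture:

```python
a = tf.constant(1.0, name="a")
b = tf.constant(2.0, name="b")
hook = metric_hook.LoggingMetricHook(
    tensors={"tag_a": a, "tag_b": b},  # dict of real tensors, not names
    metric_logger=self._logger,
    every_n_secs=1.0)
```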
1. Update global step tensor handle. 2. Update tests. 3. Update document.
One last comment, but looks good.
```python
  raise RuntimeError(
      "Global step should be created to use LoggingMetricHook.")
if not self._current_tensors.has_key(ops.GraphKeys.GLOBAL_STEP):
  self._current_tensors[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor
```
I think you should be able to just say global_step_tensor.name here, which also allows us to remove the dependency on ops.
Done.
Ah, looks like some py3 errors as well; should be easy to resolve.
```python
if global_step_tensor is None:
  raise RuntimeError(
      "Global step should be created to use LoggingMetricHook.")
if not self._current_tensors.has_key(ops.GraphKeys.GLOBAL_STEP):
```
.has_key() is deprecated in favor of "in".
Done.
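Combining this with the `.name` suggestion above, the check ends up looking like this sketch:

```python
if global_step_tensor is None:
  raise RuntimeError(
      "Global step should be created to use LoggingMetricHook.")
# Key by the tensor's own name (drops the ops dependency) and use the
# py3-safe `in` operator instead of dict.has_key().
if global_step_tensor.name not in self._current_tensors:
  self._current_tensors[global_step_tensor.name] = global_step_tensor
```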
```python
values = session.run(self._current_tensors)
self._log_metric(values)

def _log_metric(self, tensor_values):
```
It would be nice to sync timestamps for different tensors in the same measurement. Right now the times are slightly off, which could be annoying later.

```json
{"name": "train_accuracy", "timestamp": "2018-03-21T12:40:03.460442Z", ... "global_step": 33}
{"name": "learning_rate", "timestamp": "2018-03-21T12:40:03.460687Z", ... "global_step": 33}
```
Understood. I think Karmel also had a similar comment about bulk-logging metrics. Currently we can still align them via global_step. Will address this in a future change.
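If/when that future change lands, one option is to snapshot a single timestamp per trigger. Note that the `timestamp=` keyword below is hypothetical and would require extending BenchmarkLogger.log_metric:

```python
import datetime  # at module top

def _log_metric(self, tensor_values):
  self._timer.update_last_triggered_step(self._iter_count)
  global_step = tensor_values[GLOBAL_STEP_TENSOR_NAME]
  # Take one timestamp so every metric in this trigger shares it.
  now = datetime.datetime.utcnow()
  for tag in self._tag_order:
    # `timestamp=` is a hypothetical parameter, not the current API.
    self._logger.log_metric(tag, tensor_values[tag],
                            global_step=global_step, timestamp=now)
```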