Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[r2.0 CherryPick]: @tf.function: Show a warning message when tracing happens too frequently #32258

Merged
merged 1 commit into from Sep 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 67 additions & 1 deletion tensorflow/python/eager/def_function.py
Expand Up @@ -38,6 +38,41 @@
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export

FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
FREQUENT_TRACING_WARNING_THRESHOLD = 5


class _CallCounter(object):
"""Class keeping track of how many recent calls triggered tracing."""

def __init__(self, max_call_history):
self._max_call_history = max_call_history
self._calls_per_tracings = []
self.call_count = 0

def called_with_tracing(self):
self.call_count += 1
self._calls_per_tracings.append(1)

while self._calls_per_tracings:
if self.call_count - self._calls_per_tracings[0] > self._max_call_history:
self.call_count -= self._calls_per_tracings.pop(0)
else:
break

def called_without_tracing(self):
# TODO(kkimlabs): This is an unnecessary defensive check. Since this is last
# minute CL before 2.0 release, I've decided to be very defensive here to
# avoid a potential crash. Remove once we release 2.0.
if not self._calls_per_tracings:
return

self._calls_per_tracings[-1] += 1
self.call_count += 1

def get_tracing_count(self):
return len(self._calls_per_tracings)


class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
"""Variable which does not lift its initializer out of function context.
Expand Down Expand Up @@ -297,6 +332,7 @@ def __init__(self,
self._stateless_fn = None # GUARDED_BY(self._lock)
self._descriptor_cache = weakref.WeakKeyDictionary()
self._name = name
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)

def _defun_with_scope(self, scope):
"""Creates a defun wrapped inside a variable creator scope."""
Expand Down Expand Up @@ -406,11 +442,41 @@ def _decorate(self, decorator):
self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
self._python_function, self.input_signature)

def _get_tracing_count(self):
result = self._stateless_fn.tracing_count if self._stateless_fn else 0
result += self._stateful_fn.tracing_count if self._stateful_fn else 0
return result

def __call__(self, *args, **kwds):
"""Calls the graph function."""
"""Calls the graph function and warn too frequent tracings."""
context.ensure_initialized()
if RUN_FUNCTIONS_EAGERLY:
return self._python_function(*args, **kwds)

tracing_count = self._get_tracing_count()
result = self._call(*args, **kwds)
if tracing_count == self._get_tracing_count():
self._call_counter.called_without_tracing()
return result

self._call_counter.called_with_tracing()
recent_tracing_count = self._call_counter.get_tracing_count()
if recent_tracing_count >= FREQUENT_TRACING_WARNING_THRESHOLD:
logging.warning(
"{} out of the last {} calls to {} triggered tf.function retracing. "
"Tracing is expensive and the excessive number of tracings is likely "
"due to passing python objects instead of tensors. Also, tf.function "
"has experimental_relax_shapes=True option that relaxes argument "
"shapes that can avoid unnecessary retracing. Please refer to "
"https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args"
" and https://www.tensorflow.org/api_docs/python/tf/function for more "
"details.".format(recent_tracing_count, self._call_counter.call_count,
self._python_function))

return result

def _call(self, *args, **kwds):
"""Calls the graph function."""
self._lock.acquire()
if self._created_variables:
# Release the lock early so that multiple threads can perform the call
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/python/eager/function.py
Expand Up @@ -1809,6 +1809,7 @@ def __init__(self,
self._function_cache = FunctionCache()
self._function_attributes = attributes or {}
self._capture_by_value = capture_by_value
self.tracing_count = 0

self._lock = threading.Lock()
# _descriptor_cache is a of instance of a class to an instance-specific
Expand Down Expand Up @@ -2011,6 +2012,8 @@ def _cache_key(self, args, kwargs, include_tensor_ranks_only=False):

def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
"""Create a `ConcreteFunction` from `args` and `kwargs`."""
self.tracing_count += 1

if self.input_signature is None:
arglen = len(args)
else:
Expand Down