Skip to content

Commit

Permalink
[py_function] Don't attach py_function to the global eager graph.
Browse files Browse the repository at this point in the history
Eager mode can incorrectly have a global graph.  Disabling global graph
on eager mode breaks too many assumptions so first introduce a flag indicating it.

Also, avoid attaching py_function to eager mode global graph, which is a leak.

Though this CL doesn't fix the leak yet as there are two more references that leads
to the leak, `tape_cache` and `ag_dnc_wrapper__` .

#35084

PiperOrigin-RevId: 288415011
Change-Id: Ica53e29521320af22c10609857d0a0219a9596ce
  • Loading branch information
tensorflower-gardener committed Jan 7, 2020
1 parent 5b74ee4 commit 3b74a63
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 7 additions & 0 deletions tensorflow/python/framework/ops.py
Expand Up @@ -2785,6 +2785,11 @@ def __init__(self):
# tuples: (input_shape_tuple, reduction_indices_tuple), and the values
# are pairs of tuples: (output_shape_kept_dims, tile_scaling).
self._reduced_shape_cache = {}
# In eager mode, the top level graph can still be created. This is
# incorrect and undesriable but currently so many places are relying on
# this. This is a flag indicating that, and meant to be set manually after
# this graph construction.
self._is_eager_graph = False

# TODO(skyewm): fold as much of the above as possible into the C
# implementation
Expand Down Expand Up @@ -5356,6 +5361,8 @@ def _GetGlobalDefaultGraph(self):
# the global default graph and an explicit graph are combined in the
# same process.
self._global_default_graph = Graph()
if context.executing_eagerly():
self._global_default_graph._is_eager_graph = True # pylint: disable=protected-access
return self._global_default_graph

def reset(self):
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/python/ops/script_ops.py
Expand Up @@ -316,9 +316,11 @@ def _internal_py_func(func,
while True:
current_graph = graph
if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access
graph = graph._outer_graph # pylint: disable=protected-access
if not graph._outer_graph._is_eager_graph: # pylint: disable=protected-access
graph = graph._outer_graph # pylint: disable=protected-access
elif isinstance(graph, func_graph.FuncGraph):
graph = graph.outer_graph
if not graph.outer_graph._is_eager_graph: # pylint: disable=protected-access
graph = graph.outer_graph
if graph is current_graph:
break

Expand Down

0 comments on commit 3b74a63

Please sign in to comment.