diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f149a61dfc95b4..1ed379929c53f6 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2785,6 +2785,11 @@ def __init__(self): # tuples: (input_shape_tuple, reduction_indices_tuple), and the values # are pairs of tuples: (output_shape_kept_dims, tile_scaling). self._reduced_shape_cache = {} + # In eager mode, the top level graph can still be created. This is + # incorrect and undesriable but currently so many places are relying on + # this. This is a flag indicating that, and meant to be set manually after + # this graph construction. + self._is_eager_graph = False # TODO(skyewm): fold as much of the above as possible into the C # implementation @@ -5356,6 +5361,8 @@ def _GetGlobalDefaultGraph(self): # the global default graph and an explicit graph are combined in the # same process. self._global_default_graph = Graph() + if context.executing_eagerly(): + self._global_default_graph._is_eager_graph = True # pylint: disable=protected-access return self._global_default_graph def reset(self): diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 8463ffb8ae01b8..16711e600fbd3c 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -316,9 +316,11 @@ def _internal_py_func(func, while True: current_graph = graph if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access - graph = graph._outer_graph # pylint: disable=protected-access + if not graph._outer_graph._is_eager_graph: # pylint: disable=protected-access + graph = graph._outer_graph # pylint: disable=protected-access elif isinstance(graph, func_graph.FuncGraph): - graph = graph.outer_graph + if not graph.outer_graph._is_eager_graph: # pylint: disable=protected-access + graph = graph.outer_graph if graph is current_graph: break