Skip to content

Commit

Permalink
Expose function_captures as a property in FuncGraph
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 522630963
  • Loading branch information
panzhufeng authored and tensorflower-gardener committed Apr 7, 2023
1 parent a38b477 commit b92b675
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 13 deletions.
2 changes: 1 addition & 1 deletion tensorflow/python/compiler/tensorrt/trt_convert.py
Expand Up @@ -947,7 +947,7 @@ def _construct_function_from_graph_def(func, graph_def, frozen_func=None):

captures = {
c.internal.name.split(":")[0]: c.external
for c in frozen_func.graph._function_captures.by_val_captures.values() # pylint: disable = protected-access
for c in frozen_func.graph.function_captures.by_val_captures.values()
}
new_func = wrap_function.function_from_graph_def(
graph_def, [tensor.name for tensor in frozen_func.inputs],
Expand Down
Expand Up @@ -358,7 +358,7 @@ def from_func_graph_no_transforms(
num_outputs=len(signature.output_arg),
output_types=[o.type for o in signature.output_arg],
output_shapes=[o.shape for o in outputs],
control_captures=graph._function_captures.control, # pylint: disable=protected-access
control_captures=graph.function_captures.control,
func_graph_outputs=list(outputs),
attrs=attrs,
graph=graph,
Expand Down
Expand Up @@ -402,7 +402,7 @@ def _maybe_define_function(self, args, kwargs):
args, kwargs, func_graph)

# TODO(b/263520817): Remove access to private attribute.
graph_capture_container = concrete_function.graph._function_captures # pylint: disable=protected-access
graph_capture_container = concrete_function.graph.function_captures
# Maintain the list of all captures
self._func_captures.merge_by_ref_with(graph_capture_container)
# Get current active captures snapshot
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/python/framework/func_graph.py
Expand Up @@ -873,6 +873,10 @@ def variable_captures(self):
"""Map of python object ids of variables to variables which are captured."""
return self.variables

@property
def function_captures(self):
return self._function_captures

def mark_as_unsaveable(self, error_message):
"""Marks this FuncGraph as unsaveable.
Expand Down
10 changes: 5 additions & 5 deletions tensorflow/python/ops/cond_v2.py
Expand Up @@ -254,8 +254,8 @@ def _build_cond(pred,

# Create the If op.
with ops.control_dependencies(
list(true_graph._function_captures.control) + list( # pylint: disable=protected-access
false_graph._function_captures.control)): # pylint: disable=protected-access
list(true_graph.function_captures.control) + list(
false_graph.function_captures.control)):
true_stateful_ops = [
op for op in true_graph.get_operations() if op._is_stateful
]
Expand Down Expand Up @@ -333,7 +333,7 @@ def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None):
func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name)
for external_t, internal_t in zip(inputs, func_graph.inputs):
handle_data_util.copy_handle_data(external_t, internal_t)
func_graph._function_captures.reset_captures(inputs, func_graph.inputs)
func_graph.function_captures.reset_captures(inputs, func_graph.inputs)
# Link the op so that the gradient code can use it.
func_graph._forward_cond = op
return func_graph
Expand Down Expand Up @@ -584,7 +584,7 @@ def _make_inputs_match(branch_graphs, branch_inputs):
branch_graph.inputs = input_list

# Rewrite the FuncGraphs' state to reflect the new inputs.
branch_graph._function_captures.reset_captures(
branch_graph.function_captures.reset_captures(
new_inputs, branch_graph.inputs)

return new_inputs
Expand Down Expand Up @@ -1237,7 +1237,7 @@ def _build_case(branch_index,

# Create the Case op.
with ops.control_dependencies(
sum((list(bg._function_captures.control) for bg in branch_graphs), [])): # pylint: disable=protected-access
sum((list(bg.function_captures.control) for bg in branch_graphs), [])):

def _make_op(inputs):
case_op, tensors = util.get_op_and_outputs(op_fn(
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/python/ops/while_v2.py
Expand Up @@ -279,8 +279,8 @@ def wrapped_body(loop_counter, maximum_iterations_arg, *args):
_check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

with ops.control_dependencies(
list(cond_graph._function_captures.control) + list( # pylint: disable=protected-access
body_graph._function_captures.control)): # pylint: disable=protected-access
list(cond_graph.function_captures.control) + list(
body_graph.function_captures.control)):
output_shapes = [t.shape for t in body_graph.outputs]
orig_loop_vars_range = slice(first_loop_var_index,
first_loop_var_index + num_flattened_outputs)
Expand Down Expand Up @@ -1104,7 +1104,7 @@ def _capture_helper(self, tensor, name):
return captured_tensor

if tensor.graph is not self._forward_graph:
already_captured = id(tensor) in self._function_captures.by_val_captures # pylint: disable=protected-access
already_captured = id(tensor) in self.function_captures.by_val_captures
captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
tensor, name)
if not already_captured:
Expand Down Expand Up @@ -1329,7 +1329,7 @@ def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
keys = [id(t) for t in body_graph_captures]
for k, v in zip(keys, tuples):
capture = capture_container.CaptureContainer(v[0], v[1], k, False)
cond_graph._function_captures._by_val[k] = capture # pylint: disable=protected-access
cond_graph.function_captures._by_val[k] = capture # pylint: disable=protected-access
cond_graph.inputs.extend(tensors)


Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/saved_model/function_serialization.py
Expand Up @@ -127,7 +127,7 @@ def wrap_cached_variables(concrete_function):
"""
outer_graph = func_graph_module.FuncGraph(
"{}_no_cache".format(concrete_function.graph.name))
captures = concrete_function.graph._function_captures._by_val # pylint: disable=protected-access
captures = concrete_function.graph.function_captures._by_val # pylint: disable=protected-access
mapped_captures = None
remapped_captures = {}

Expand Down
Expand Up @@ -32,6 +32,10 @@ tf_class {
name: "finalized"
mtype: "<type \'property\'>"
}
member {
name: "function_captures"
mtype: "<type \'property\'>"
}
member {
name: "graph_def_versions"
mtype: "<type \'property\'>"
Expand Down

0 comments on commit b92b675

Please sign in to comment.