Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions tensorflow/python/keras/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2762,7 +2762,8 @@ class Function(object):
outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
name: A name to help users identify what this function does.
session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`.
session_kwargs: Arguments to `tf.Session.run()`:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay in getting back. This PR LGTM except for a minor comment.

session_kwarg is not precise. It should be called session_run_kwarg, otherwise it may be confused with kwargs that you can use when creating a session (e.g., https://www.tensorflow.org/api_docs/python/tf/Session)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind. This is an existing name and changing it would be breaking.

`fetches`, `feed_dict`, `options`, `run_metadata`.
"""

def __init__(self, inputs, outputs, updates=None, name=None,
Expand Down Expand Up @@ -2796,6 +2797,8 @@ def __init__(self, inputs, outputs, updates=None, name=None,
self.fetches = session_kwargs.pop('fetches', [])
if not isinstance(self.fetches, list):
self.fetches = [self.fetches]
self.run_options = session_kwargs.pop('options', None)
self.run_metadata = session_kwargs.pop('run_metadata', None)
# The main use case of `fetches` being passed to a model is the ability
# to run custom updates
# This requires us to wrap fetches in `identity` ops.
Expand Down Expand Up @@ -2853,6 +2856,9 @@ def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
callable_opts.fetch.append(x.name)
# Handle updates.
callable_opts.target.append(self.updates_op.name)
# Handle run_options.
if self.run_options:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, handle run_metadata as well.

callable_opts.run_options.CopyFrom(self.run_options)
# Create callable.
callable_fn = session._make_callable_from_options(callable_opts)
# Cache parameters corresponding to the generated callable, so that
Expand Down Expand Up @@ -2911,7 +2917,8 @@ def __call__(self, inputs):
session != self._session):
self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)

fetched = self._callable_fn(*array_vals)
fetched = self._callable_fn(*array_vals,
run_metadata=self.run_metadata)
self._call_fetch_callbacks(fetched[-len(self._fetches):])
return fetched[:len(self.outputs)]

Expand Down
24 changes: 24 additions & 0 deletions tensorflow/python/keras/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import scipy.sparse

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
Expand Down Expand Up @@ -276,6 +277,29 @@ def test_function_tf_feed_dict(self):
self.assertEqual(
keras.backend.get_session().run(fetches=[x, y]), [30., 40.])

def test_function_tf_run_options_with_run_metadata(self):
with self.test_session():
x_placeholder = keras.backend.placeholder(shape=())
y_placeholder = keras.backend.placeholder(shape=())

run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
# enable run_options.
f = keras.backend.function(inputs=[x_placeholder, y_placeholder],
outputs=[x_placeholder + y_placeholder],
options=run_options,
run_metadata=run_metadata)
output = f([10., 20.])
self.assertEqual(output, [30.])
self.assertGreater(len(run_metadata.partition_graphs), 0)
# disable run_options.
f1 = keras.backend.function(inputs=[x_placeholder, y_placeholder],
outputs=[x_placeholder + y_placeholder],
run_metadata=run_metadata)
output1 = f1([10., 20.])
self.assertEqual(output1, [30.])
self.assertEqual(len(run_metadata.partition_graphs), 0)

def test_function_fetch_callbacks(self):

class CallbackStub(object):
Expand Down