Description
System information
- Have I written custom code: yes
- OS Platform and Distribution: Linux Mint 19.1
- TensorFlow installed from: binary (using pip)
- TensorFlow version: 2.0.0-beta1 & 2.0.0-dev10290731 (tried on both)
- Python version: 3.6.8
- CUDA/cuDNN version: 10.0 / 7.5
- GPU model and memory: Nvidia Quadro P1000 - 4 GB GDDR5
Describe the current behavior
When Eager execution is enabled, `tf.while_loop` uses a backend implementation suited to Eager tensors only, which in practice rules out using while loops with Keras symbolic tensors. Specifically, the line `while cond(*loop_vars):` (in the source as of this writing) is only valid when `loop_vars` is a list of `EagerTensor` instances, since only then can the `while` check be evaluated through the concrete (NumPy) value of the condition's result.
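To illustrate why the Eager branch cannot accept symbolic inputs, here is a minimal pure-Python model of it, assuming that branch boils down to the `while cond(*loop_vars):` loop quoted above. `eager_style_while_loop` and `SymbolicPlaceholder` are hypothetical names, not TensorFlow APIs; the placeholder only mimics the one property that matters here, namely that its truth value cannot be computed.

```python
class SymbolicPlaceholder:
    """Hypothetical stand-in for a Keras symbolic tensor (not a TensorFlow class)."""

    def __bool__(self):
        # A symbolic value has no concrete truth value yet.
        raise TypeError("Using a symbolic value as a Python bool is not allowed.")

    def __lt__(self, other):
        # Comparisons on symbolic values produce new symbolic values, not bools.
        return SymbolicPlaceholder()

    def __gt__(self, other):
        return SymbolicPlaceholder()


def eager_style_while_loop(cond, body, loop_vars):
    """Hypothetical model of the Eager branch: a plain Python loop."""
    # The `while` statement calls bool() on cond's result, which only works
    # when that result is concrete.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


def cond(i, n):
    return i < n


def body(i, n):
    return i + 1, n


# Concrete loop variables: the loop runs to completion.
print(eager_style_while_loop(cond, body, [0, 3]))  # (3, 3)

# Symbolic loop variable: `0 < placeholder` yields a placeholder, and the
# implicit bool() call made by `while` raises.
try:
    eager_style_while_loop(cond, body, [0, SymbolicPlaceholder()])
except TypeError as err:
    print("TypeError:", err)
```

With concrete values the loop terminates normally; with a placeholder, `while` needs a Python `bool` and fails, which mirrors the `TypeError` / `OperatorNotAllowedInGraphError` reported at the end of this issue.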
Now, my issue is that I actually need to implement a function that works on Keras symbolic tensors and uses `tf.while_loop`, which seemingly proves impossible (apart from disabling Eager execution, which is a workaround I am comfortable using in the long term, but it does not feel like an actual solution).
Intuitively, I would think a way to fix the issue would be to follow an alternative route within `tf.while_loop`'s source code when using symbolic tensors (e.g. the one used when Eager is disabled), and to keep the Eager-suited one otherwise. I tried a nasty fix which consisted of appending `and all(isinstance(tensor, ops.EagerTensor) for tensor in loop_vars)` to the `executing_eagerly = context.executing_eagerly()` line at the beginning of `while_loop`'s body in the source code, but this results in `AttributeError: Tensor.name is meaningless when eager execution is enabled.` being raised within the loop constructor (only with symbolic tensors; with Eager ones, everything works fine). I would be happy to dig deeper and contribute a fix if I can find one, but first it would be nice to know whether a solution already exists (I might just be missing something, perhaps in the Keras backend submodule).
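The dispatch the proposed fix aims for can be sketched as a toy model in plain Python: take the Eager route only when every loop variable is concrete, and otherwise hand the loop over to a graph-building route. `SymbolicPlaceholder`, `DeferredLoop` and `dispatching_while_loop` are hypothetical names, and the deferred branch merely records the loop instead of building graph ops. The sketch deliberately leaves out the part that failed in practice, i.e. actually constructing the loop while the Eager context is still active, which is presumably where the `AttributeError` above originates.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


class SymbolicPlaceholder:
    """Hypothetical stand-in for a Keras symbolic tensor (not a TensorFlow class)."""

    def __bool__(self):
        raise TypeError("Using a symbolic value as a Python bool is not allowed.")


@dataclass
class DeferredLoop:
    """Hypothetical record of a loop that would be built as graph ops later."""
    cond: Callable[..., Any]
    body: Callable[..., Any]
    loop_vars: List[Any]


def dispatching_while_loop(cond, body, loop_vars):
    """Run the loop eagerly only if all loop variables are concrete."""
    # Mirrors the proposed extra check: Eager route only when every
    # loop variable is concrete.
    all_concrete = all(not isinstance(v, SymbolicPlaceholder) for v in loop_vars)
    if not all_concrete:
        # Graph route (modelled only): `cond` is never evaluated as a Python bool.
        return DeferredLoop(cond, body, list(loop_vars))
    # Eager route: a plain Python loop on concrete values.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


def count_up(i):
    return (i + 1,)


print(dispatching_while_loop(lambda i: i < 3, count_up, [0]))  # (3,)
deferred = dispatching_while_loop(lambda i: i < 3, count_up, [SymbolicPlaceholder()])
print(type(deferred).__name__)  # DeferredLoop
```

The check is per call and based only on the inputs, so the same function keeps its Eager behaviour for concrete tensors; the hard part left open is what the deferred route must do to build the loop in a proper graph context.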
Describe the expected behavior
I would like to be able to run symbolic tensors through a while loop without disabling Eager execution. Basically, I would like Eager execution not to take away practical functionality that is very useful when designing models, and not to have to resort to small hacks of the framework, which are bound to decrease readability and stability.
Code to reproduce the issue
Base code defining the function I want to implement, plus two minimalist tests (in practice my symbolic tensors are not mere inputs, but the issue is exactly the same):
```python
import tensorflow as tf


def pred_in_top_k(y_true, y_pred, k=5):
    """Check whether targets are in top K predictions, for batched samples.

    Extension of `tf.keras.metrics.sparse_top_k_categorical_accuracy`
    to batched sequences of targets and predictions.

    y_true : true labels; tf.int32 or tf.int64 Tensor of shape
             (batch_len, max_seq_len)
    y_pred : predicted probabilities; tf.float32 Tensor of shape
             (batch_len, max_seq_len, n_labels)
    """
    # Define the loop's body.
    def body(i, matches):
        """Compute matches for a given sample in the batch."""
        matching = tf.nn.in_top_k(y_true[i], y_pred[i], k=k)
        matching = tf.expand_dims(tf.cast(matching, tf.float32), 0)
        updated = tf.concat([matches[:i], matching, matches[i + 1:]], axis=0)
        updated.set_shape(matches.shape)
        return i + 1, updated

    # Define the loop's stopping condition.
    def cond(i, _):
        """Stop when the entire batch has been processed."""
        return tf.less(i, tf.shape(y_true)[0])

    # Run the loop and return the results.
    loop_vars = [tf.constant(0), tf.zeros_like(y_true, dtype=tf.float32)]
    _, matches = tf.while_loop(cond, body, loop_vars)
    return matches


def test_random_tensors():
    """Run pred_in_top_k on random tensors."""
    y_true = tf.random.uniform(shape=(4, 10), maxval=20, dtype=tf.int64)
    y_pred = tf.nn.softmax(tf.random.normal(shape=(4, 10, 20)))
    return pred_in_top_k(y_true, y_pred)


def test_symbolic_tensors():
    """Run pred_in_top_k on symbolic tensors."""
    y_true = tf.keras.Input((None,), dtype=tf.int64)
    y_pred = tf.keras.Input((None, 20), dtype=tf.float32)
    return pred_in_top_k(y_true, y_pred)
```

`test_random_tensors()` works both with and without having executed `tf.compat.v1.disable_eager_execution()` first.
`test_symbolic_tensors()` fails when Eager is left enabled, with the following error messages depending on the TensorFlow 2.0 installation used:

2.0b1:
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

2.0 nightly:
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

Note that decoration with `@tf.function` does not solve the issue, as functions wrapped this way do not accept Keras symbolic tensors as inputs.
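For checking the values that `test_random_tensors()` produces (or that a fixed `tf.while_loop` would produce on symbolic inputs), a plain-Python reference of what `pred_in_top_k` is meant to compute may be handy. It is a sketch assuming the tie rule documented for `tf.math.in_top_k` (classes tied with the target at the top-k boundary all count as in the top k), and it ignores edge cases such as out-of-range targets or non-finite scores; `in_top_k_row` and `pred_in_top_k_reference` are hypothetical helpers, not part of TensorFlow.

```python
from typing import List, Sequence


def in_top_k_row(targets: Sequence[int],
                 predictions: Sequence[Sequence[float]],
                 k: int) -> List[float]:
    """Top-k membership for one sample, as 1.0 (match) or 0.0 (no match)."""
    row = []
    for target, scores in zip(targets, predictions):
        target_score = scores[target]
        # Only classes scored strictly higher than the target push it out of
        # the top k, so ties at the boundary still count as a match.
        higher = sum(1 for s in scores if s > target_score)
        row.append(1.0 if higher < k else 0.0)
    return row


def pred_in_top_k_reference(y_true, y_pred, k=5):
    """Batched reference: one row of matches per sample in the batch."""
    return [in_top_k_row(t, p, k) for t, p in zip(y_true, y_pred)]


y_true = [[1, 2]]
y_pred = [[[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]]
print(pred_in_top_k_reference(y_true, y_pred, k=1))  # [[1.0, 0.0]]
print(pred_in_top_k_reference(y_true, y_pred, k=3))  # [[1.0, 1.0]]
```

In Eager mode, the random inputs of `test_random_tensors()` can be converted with `.numpy().tolist()` and fed to this reference to cross-check the loop body independently of `tf.while_loop`.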