
autograph should handle "for" loops over "range" in a manner that is compatible with XLA compilation #30182

Closed
shoyer opened this issue Jun 26, 2019 · 4 comments
Labels: comp:autograph, TF 2.0, type:bug, WIP

shoyer commented Jun 26, 2019

System information

  • TensorFlow version (you are using): 1.14
  • Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state.

Consider the following Python code:

import tensorflow as tf
autograph = tf.contrib.autograph
xla = tf.contrib.compiler.xla

tf.enable_eager_execution()

@tf.function
def bad_loop(x, count):
  # When called via xla.compile, `count` arrives as a Tensor, so
  # Autograph rewrites range(count) into tf.range(count).
  for _ in range(count):
    x += 1
  return x

@tf.function
def good_loop(x, count):
  # The same computation written explicitly as a while loop.
  i = 0
  while i < count:
    x += 1
    i += 1
  return x

bad_loop is the intuitive way to write this loop. However, it fails to compile with xla.compile:

>>> xla.compile(bad_loop, [1.0, 3])
InvalidArgumentError: Input 1 to node `StatefulPartitionedCall/range` with op Range must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.
	 [[{{node StatefulPartitionedCall/range}}]]
	This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to.
	 [[cluster]] [Op:__inference_xla_compile_wrapper_166]

In contrast, good_loop calculates the correct result:

>>> xla.compile(good_loop, [1.0, 3])
[<tf.Tensor: id=229, shape=(), dtype=float32, numpy=4.0>]

Autograph seems to always convert range() into tf.range(), even in for loops. This means that XLA can't compile the function. However, the equivalent loop written as a naive while loop works.

Ideally, Autograph would detect such uses of range in for loops and convert them into the style of good_loop automatically, rather than requiring users to do this. This would let us write cleaner code.
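
For illustration, a while-style lowering that Autograph could emit for bad_loop might look roughly like the sketch below. This is a hand-written sketch, not actual Autograph output, and lowered_loop is a hypothetical name:

import tensorflow as tf

tf.enable_eager_execution()

@tf.function
def lowered_loop(x, count):
  # Hypothetical hand-written lowering of `for _ in range(count)`;
  # no tf.range op is emitted, so XLA never needs `count` to be a
  # compile-time constant.
  def cond(i, x):
    return i < count
  def body(i, x):
    return i + 1, x + 1
  _, x = tf.while_loop(cond, body, (0, x))
  return x

# Should succeed just like good_loop, since the graph contains no Range op.
xla.compile(lowered_loop, [1.0, 3])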

Will this change the current API? How? No

Who will benefit from this feature? Users who want to write normal Python code with Autograph.


@ravikyram ravikyram self-assigned this Jun 28, 2019
@ravikyram ravikyram added comp:autograph Autograph related issues type:bug Bug labels Jun 28, 2019
ravikyram commented:
I have tried this on Colab with TF version 1.14 and was able to reproduce the issue. Thanks!

@ravikyram ravikyram assigned ymodak and unassigned ravikyram Jun 28, 2019
@ymodak ymodak assigned mdanatg and unassigned ymodak Jun 28, 2019
@ymodak ymodak added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 28, 2019

mdanatg commented Jun 28, 2019

Looks like a few bugs are compounding here; I'll list them along with recommendations and plans to address them:

  1. I think you are correct; in this case there is no way around detecting the use of tf.range. At first I thought this would be a mere performance optimization, but it seems to be required for XLA. It wasn't already enabled because detection of the tf.range op is not terribly robust, but I think this example justifies it. Will follow up with a fix soon. In the meantime, using tf.range(3) should work (see below).

  2. range is normally not converted to tf.range; this only happens when its argument is a Tensor. xla.compile will auto-cast all arguments to Tensors, hence range will receive a Tensor even though you specify just 3. Even so, range(tf.constant(3)) is not officially supported, and I recommend using tf.range, which is more explicit anyway (see the unrolling sketch after the code below).

  3. (filed xla.compile + tf.function lose information about compile-time constants #30235) It appears that tf.range only works in XLA if you specify it with an inline constant: tf.range(tf.constant(3)). Even though bad_loop should be equivalent to that, it looks like xla.compile will not recognize the constant argument and will raise an error. For example, the following code will work:

import tensorflow as tf
autograph = tf.contrib.autograph
xla = tf.contrib.compiler.xla

tf.enable_eager_execution()

@tf.function
def good_bad_loop(x):
  # tf.range over an inline Python constant: the loop bound is visible
  # to XLA at compile time.
  for _ in tf.range(3):
    x += 1
  return x

xla.compile(good_bad_loop, (1.0,))
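
To make point 2 concrete, here is a minimal sketch of the distinction (unrolled_loop is a hypothetical name): when the bound is a plain Python int rather than a Tensor, Autograph leaves range() alone and the loop is unrolled at trace time, so no tf.range op ever reaches XLA:

@tf.function
def unrolled_loop(x):
  # 3 is a Python int, not a Tensor, so range() stays an ordinary
  # Python loop and the body is traced (unrolled) three times
  # into the graph.
  for _ in range(3):
    x += 1
  return x

# Should also compile: the resulting graph contains only three adds.
xla.compile(unrolled_loop, (1.0,))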

@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jun 29, 2019
@mdanatg mdanatg added TF 2.0 Issues relating to TensorFlow 2.0 WIP labels Jul 9, 2019

shoyer commented Jul 10, 2019

Thank you @mdanatg!
