You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
bad_loop is the intuitive way to write this loop. However, it fails to compile with xla:
>>> xla.compile(bad_loop, [1.0, 3])
InvalidArgumentError: Input 1 to node `StatefulPartitionedCall/range` with op Range must be a compile-time constant.
XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.
[[{{node StatefulPartitionedCall/range}}]]
This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to.
[[cluster]] [Op:__inference_xla_compile_wrapper_166]
In contrast, good_loop calculates the correct result:
Autograph seems to always convert range() into tf.range(), even in for loops. This means that XLA can't compile the function. However, the equivalent loop written as a naive while loop works.
Ideally, Autograph would detect such uses of range in for loops and convert them into the style of good_loop automatically, rather than requiring users to do this. This would let us write cleaner code.
Will this change the current api? How? No
Who will benefit with this feature? Users who want to write normal Python code with Autograph.
Any Other info.
The text was updated successfully, but these errors were encountered:
Looks like a few bugs are being compound here; I'll list them along with recommendations and plans to address -
I think you are correct, in this case there is no way but to detect the use of tf.range; at first, I thought this would be a mere performance optimization, but it seems to be required for XLA. It wasn't already enabled because the detection of tf.range op is not terribly robust, but I think this example justifies it. Will follow up with a fix soon. In the mean time, using tf.range(3) should work (see below).
range is normally not converted to tf.range - this only happens when its argument is a Tensor; xla.compile will auto-cast all arguments to tensors, hence range will receive a Tensor even though you only specify just 3. Even so, using range(tf.constant(3)), is not officially supported and I recommend using tf.range, which is more explicit anyway.
(filed xla.compile + tf.function lose information about compile-time constants #30235) It appears that tf.range only works in XLA if you specify it with an inline constant: tf.range(tf.constant(3)); even though bad_function should be equivalent to that, it looks like xla.compile will not recognize the constant argument and raise an error. For example, the following code will work:
import tensorflow as tf
autograph = tf.contrib.autograph
xla = tf.contrib.compiler.xla
tf.enable_eager_execution()
@tf.function
def good_bad_loop(x):
for _ in tf.range(3):
x += 1
return x
xla.compile(good_bad_loop, (1.0,))
System information
Describe the feature and the current behavior/state.
Consider the following Python code:
bad_loop
is the intuitive way to write this loop. However, it fails to compile with xla:In contrast,
good_loop
calculates the correct result:Autograph seems to always convert
range()
intotf.range()
, even in for loops. This means that XLA can't compile the function. However, the equivalent loop written as a naivewhile
loop works.Ideally, Autograph would detect such uses of
range
in for loops and convert them into the style ofgood_loop
automatically, rather than requiring users to do this. This would let us write cleaner code.Will this change the current api? How? No
Who will benefit with this feature? Users who want to write normal Python code with Autograph.
Any Other info.
The text was updated successfully, but these errors were encountered: