-
Notifications
You must be signed in to change notification settings - Fork 372
Open
Labels
enhancement — New feature or request
Description
I'm trying to use Tensorflow text's max spanning tree op as part of a model that uses tf.function (including experimental_compile=True to enable XLA) to wrap its call method. Unfortunately, this gives an error.
A minimal example is:
import tensorflow as tf
import tensorflow_text as text
# Report the TensorFlow version so the bug report is reproducible.
print(f"TF Version: {tf.__version__}")
# Using tensorflow_text 2.2.0rc2
# print(f"TF Text Version: {text.__version__}")
def call_mst_op():
    """Invoke tensorflow_text's max_spanning_tree op on random inputs.

    Builds a batch of random arc-score matrices plus random sequence
    lengths, and returns the (scores, heads) pair produced by the op.
    """
    num_sequences = 40
    max_len = 20
    # Fake 'scores' for each possible arc.
    arc_scores = tf.random.normal(
        (num_sequences, max_len, max_len), dtype=tf.float32)
    lengths = tf.random.uniform(
        (num_sequences,), minval=1, maxval=max_len, dtype=tf.int32)
    tree_scores, tree_heads = text.max_spanning_tree(lengths, arc_scores)
    return tree_scores, tree_heads
# Eager execution of the op succeeds.
print(call_mst_op())
# Wrapping the same function with XLA compilation enabled is what fails.
mst_op_jit = tf.function(call_mst_op, experimental_compile=True)
mst_op_jit()

This gives:
TF Version: 2.2.0-rc3
(<tf.Tensor: shape=(40,), dtype=float32, numpy=
array([12.709742 , 14.647888 , 24.65781 , 11.281626 , 7.498613 ,
19.726154 , 7.788096 , 15.929688 , 27.823446 , 1.9694456,
3.5884383, 13.10537 , 23.137075 , 8.10518 , 19.063808 ,
4.810276 , 29.56293 , 29.015924 , 18.94247 , 25.819365 ,
3.631906 , 27.130611 , 1.5555735, 8.239564 , 9.661241 ,
10.398854 , 16.17563 , 13.159201 , 1.8612909, 26.980581 ,
20.842346 , 23.399996 , 21.928398 , 11.112592 , 26.93416 ,
8.321238 , 30.637009 , 28.603863 , 26.578001 , 16.191595 ],
dtype=float32)>, <tf.Tensor: shape=(40, 20), dtype=int32, numpy=
array([[ 4, 5, 1, 2, 4, 0, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 1, 6, 9, 1, 2, 7, 7, 4, 2, 9, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 8, 0, 8, 2, 1, 2, 5, 2, 13, 3, 7, 0, 14, 13, 7, -1,
-1, -1, -1, -1],
[ 6, 2, 6, 7, 1, 4, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 5, 3, 4, 0, 4, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 2, 3, 10, 9, 2, 13, 10, 11, 4, 13, 10, 2, 6, 12, -1, -1,
-1, -1, -1, -1],
[ 6, 2, 5, 3, 7, 7, 7, 3, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[10, 4, 1, 11, 11, 1, 6, 4, 7, 10, 3, 6, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 2, 10, 2, 2, 12, 9, 11, 2, 10, 4, 0, 9, 1, 4, 10, 2,
-1, -1, -1, -1],
[ 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 1, 1, 3, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 4, 3, 4, 4, 5, 5, 3, 6, 5, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 8, 6, 1, 13, 8, 10, 11, 7, 1, 5, 11, 7, 5, 8, -1, -1,
-1, -1, -1, -1],
[ 0, 4, 4, 2, 5, 6, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[10, 1, 10, 6, 10, 7, 8, 6, 1, 11, 6, 7, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 2, 0, 2, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[14, 11, 11, 5, 4, 4, 9, 15, 1, 14, 14, 3, 14, 6, 1, 0,
-1, -1, -1, -1],
[ 1, 1, 11, 13, 14, 3, 9, 14, 13, 0, 2, 7, 10, 6, 8, -1,
-1, -1, -1, -1],
[ 9, 2, 5, 10, 8, 5, 5, 2, 2, 10, 8, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[11, 13, 14, 11, 5, 12, 3, 12, 1, 0, 8, 5, 14, 0, 14, -1,
-1, -1, -1, -1],
[ 1, 2, 4, 1, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[11, 0, 11, 4, 2, 3, 0, 14, 8, 7, 8, 8, 14, 10, 1, 4,
-1, -1, -1, -1],
[ 3, 2, 2, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 5, 1, 5, 1, 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 0, 4, 6, 1, 0, 1, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 3, 5, 4, 7, 1, 6, 0, 7, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[11, 0, 1, 5, 1, 10, 11, 3, 0, 7, 8, 11, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 1, 5, 3, 4, 0, 5, 2, 4, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 1, 1, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 7, 6, 6, 12, 2, 5, 14, 10, 11, 15, 5, 15, 5, 6, 15, 7,
-1, -1, -1, -1],
[ 0, 9, 7, 5, 2, 1, 2, 0, 11, 4, 3, 1, 10, -1, -1, -1,
-1, -1, -1, -1],
[ 1, 10, 10, 0, 6, 13, 0, 12, 8, 4, 8, 7, 13, 1, -1, -1,
-1, -1, -1, -1],
[10, 6, 2, 0, 11, 9, 2, 0, 0, 1, 5, 9, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 2, 1, 1, 2, 6, 0, 2, 3, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[10, 15, 1, 9, 13, 12, 15, 3, 4, 10, 11, 11, 0, 15, 4, 3,
-1, -1, -1, -1],
[ 5, 1, 1, 5, 3, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1],
[ 5, 4, 13, 3, 9, 1, 1, 3, 15, 16, 7, 3, 6, 1, 7, 9,
10, 16, -1, -1],
[12, 14, 12, 13, 12, 12, 7, 7, 12, 4, 2, 12, 13, 7, 8, -1,
-1, -1, -1, -1],
[ 5, 1, 3, 15, 13, 9, 10, 10, 9, 1, 5, 6, 9, 3, 9, 8,
0, -1, -1, -1],
[ 3, 7, 1, 6, 5, 5, 8, 5, 2, 8, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1]], dtype=int32)>)
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
in
24 mst_op_jit = tf.function(call_mst_op, experimental_compile=True)
25
---> 26 mst_op_jit()
27
28
~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
574 try:
575 xla_context.Enter()
--> 576 result = self._call(*args, **kwds)
577 finally:
578 xla_context.Exit()
~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
648 *args, **kwds)
649 # If we did not create any variables the trace we have is good enough.
--> 650 return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access
651
652 def fn_with_cond(*inner_args, **inner_kwds):
~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs)
1663 if isinstance(t, (ops.Tensor,
1664 resource_variable_ops.BaseResourceVariable))),
-> 1665 self.captured_inputs)
1666
1667 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1744 # No tape is watching; skip to running the function.
1745 return self._build_call_outputs(self._inference_function.call(
-> 1746 ctx, args, cancellation_manager=cancellation_manager))
1747 forward_backward = self._select_forward_and_backward_functions(
1748 args,
~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
596 inputs=args,
597 attrs=attrs,
--> 598 ctx=ctx)
599 else:
600 outputs = execute.execute_with_cancellation(
~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: Function invoked by the following node is not compilable: {{node __inference_call_mst_op_455}} = __inference_call_mst_op_455[_XlaMustCompile=true, config_proto="\n\007\n\003GPU\020\001\n\007\n\003CPU\020\0012\005*\0010J\0008\001", executor_type=""]().
Uncompilable nodes:
MaxSpanningTree: unsupported op: No registered 'MaxSpanningTree' OpKernel for XLA_GPU_JIT devices compatible with node {{node MaxSpanningTree}}
Stacktrace:
Node: __inference_call_mst_op_455, function:
Node: MaxSpanningTree, function: __inference_call_mst_op_455
[Op:__inference_call_mst_op_455]
As a workaround, tf.function still works if experimental_compile is passed in as False.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
enhancement — New feature or request