Skip to content

MST op and tf.function(experimental_compile=True) / XLA #283

@tc-wolf

Description

@tc-wolf

I'm trying to use TensorFlow Text's max spanning tree op (`text.max_spanning_tree`) inside a model whose `call` method is wrapped in `tf.function` with `experimental_compile=True` (to enable XLA). Unfortunately, this raises an error.

A minimal example is:

import tensorflow as tf
import tensorflow_text as text

# Report the TensorFlow version this reproduction was run against.
print("TF Version:", tf.__version__)

# tensorflow_text version used for this report: 2.2.0rc2
# (text.__version__ is intentionally not printed here).

def call_mst_op():
    """Run ``text.max_spanning_tree`` on a random batch of arc scores.

    Builds a float32 tensor of shape (40, 20, 20) holding random scores
    for every possible arc, and an int32 tensor of 40 random sequence
    lengths drawn with ``tf.random.uniform(minval=1, maxval=20)``. These
    are then passed to the max-spanning-tree op.

    Returns:
        The two outputs of ``text.max_spanning_tree``: the tree scores
        and the head indices.
    """
    num_sentences, max_len = 40, 20

    # Random stand-in for the score of each possible arc. The normal draw
    # happens before the uniform draw, matching the original op order.
    arc_scores = tf.random.normal(
        (num_sentences, max_len, max_len), dtype=tf.float32
    )
    lengths = tf.random.uniform(
        (num_sentences,), minval=1, maxval=max_len, dtype=tf.int32
    )

    # Separate names so the op's output does not shadow its input scores.
    tree_scores, heads = text.max_spanning_tree(lengths, arc_scores)
    return tree_scores, heads

# Eager (non-compiled) run: this succeeds and prints the op's outputs.
print(call_mst_op())

# Same function wrapped in tf.function with XLA compilation requested.
mst_op_jit = tf.function(call_mst_op, experimental_compile=True)

# In this report the call below raises InvalidArgumentError: there is no
# registered 'MaxSpanningTree' OpKernel for the XLA_GPU_JIT device.
# Passing experimental_compile=False avoids the error.
mst_op_jit()

Running this prints the eager result successfully, then fails on the XLA-compiled call:

TF Version: 2.2.0-rc3
(<tf.Tensor: shape=(40,), dtype=float32, numpy=
array([12.709742 , 14.647888 , 24.65781  , 11.281626 ,  7.498613 ,
       19.726154 ,  7.788096 , 15.929688 , 27.823446 ,  1.9694456,
        3.5884383, 13.10537  , 23.137075 ,  8.10518  , 19.063808 ,
        4.810276 , 29.56293  , 29.015924 , 18.94247  , 25.819365 ,
        3.631906 , 27.130611 ,  1.5555735,  8.239564 ,  9.661241 ,
       10.398854 , 16.17563  , 13.159201 ,  1.8612909, 26.980581 ,
       20.842346 , 23.399996 , 21.928398 , 11.112592 , 26.93416  ,
        8.321238 , 30.637009 , 28.603863 , 26.578001 , 16.191595 ],
      dtype=float32)>, <tf.Tensor: shape=(40, 20), dtype=int32, numpy=
array([[ 4,  5,  1,  2,  4,  0,  1,  2, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 1,  6,  9,  1,  2,  7,  7,  4,  2,  9, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 8,  0,  8,  2,  1,  2,  5,  2, 13,  3,  7,  0, 14, 13,  7, -1,
        -1, -1, -1, -1],
       [ 6,  2,  6,  7,  1,  4,  6,  2,  7, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 5,  3,  4,  0,  4,  2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 2,  3, 10,  9,  2, 13, 10, 11,  4, 13, 10,  2,  6, 12, -1, -1,
        -1, -1, -1, -1],
       [ 6,  2,  5,  3,  7,  7,  7,  3, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [10,  4,  1, 11, 11,  1,  6,  4,  7, 10,  3,  6, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 2, 10,  2,  2, 12,  9, 11,  2, 10,  4,  0,  9,  1,  4, 10,  2,
        -1, -1, -1, -1],
       [ 1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 1,  1,  3,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 4,  3,  4,  4,  5,  5,  3,  6,  5, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 8,  6,  1, 13,  8, 10, 11,  7,  1,  5, 11,  7,  5,  8, -1, -1,
        -1, -1, -1, -1],
       [ 0,  4,  4,  2,  5,  6,  0, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [10,  1, 10,  6, 10,  7,  8,  6,  1, 11,  6,  7, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 2,  0,  2,  2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [14, 11, 11,  5,  4,  4,  9, 15,  1, 14, 14,  3, 14,  6,  1,  0,
        -1, -1, -1, -1],
       [ 1,  1, 11, 13, 14,  3,  9, 14, 13,  0,  2,  7, 10,  6,  8, -1,
        -1, -1, -1, -1],
       [ 9,  2,  5, 10,  8,  5,  5,  2,  2, 10,  8, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [11, 13, 14, 11,  5, 12,  3, 12,  1,  0,  8,  5, 14,  0, 14, -1,
        -1, -1, -1, -1],
       [ 1,  2,  4,  1,  4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [11,  0, 11,  4,  2,  3,  0, 14,  8,  7,  8,  8, 14, 10,  1,  4,
        -1, -1, -1, -1],
       [ 3,  2,  2,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 5,  1,  5,  1,  2,  3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 0,  4,  6,  1,  0,  1,  4, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 3,  5,  4,  7,  1,  6,  0,  7, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [11,  0,  1,  5,  1, 10, 11,  3,  0,  7,  8, 11, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 1,  5,  3,  4,  0,  5,  2,  4, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 1,  1,  0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 7,  6,  6, 12,  2,  5, 14, 10, 11, 15,  5, 15,  5,  6, 15,  7,
        -1, -1, -1, -1],
       [ 0,  9,  7,  5,  2,  1,  2,  0, 11,  4,  3,  1, 10, -1, -1, -1,
        -1, -1, -1, -1],
       [ 1, 10, 10,  0,  6, 13,  0, 12,  8,  4,  8,  7, 13,  1, -1, -1,
        -1, -1, -1, -1],
       [10,  6,  2,  0, 11,  9,  2,  0,  0,  1,  5,  9, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 2,  1,  1,  2,  6,  0,  2,  3, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [10, 15,  1,  9, 13, 12, 15,  3,  4, 10, 11, 11,  0, 15,  4,  3,
        -1, -1, -1, -1],
       [ 5,  1,  1,  5,  3,  6,  2, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1],
       [ 5,  4, 13,  3,  9,  1,  1,  3, 15, 16,  7,  3,  6,  1,  7,  9,
        10, 16, -1, -1],
       [12, 14, 12, 13, 12, 12,  7,  7, 12,  4,  2, 12, 13,  7,  8, -1,
        -1, -1, -1, -1],
       [ 5,  1,  3, 15, 13,  9, 10, 10,  9,  1,  5,  6,  9,  3,  9,  8,
         0, -1, -1, -1],
       [ 3,  7,  1,  6,  5,  5,  8,  5,  2,  8, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1]], dtype=int32)>)
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
 in 
     24 mst_op_jit = tf.function(call_mst_op, experimental_compile=True)
     25 
---> 26 mst_op_jit()
     27 
     28 

~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    574       try:
    575         xla_context.Enter()
--> 576         result = self._call(*args, **kwds)
    577       finally:
    578         xla_context.Exit()

~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    648               *args, **kwds)
    649       # If we did not create any variables the trace we have is good enough.
--> 650       return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
    651 
    652     def fn_with_cond(*inner_args, **inner_kwds):

~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs)
   1663          if isinstance(t, (ops.Tensor,
   1664                            resource_variable_ops.BaseResourceVariable))),
-> 1665         self.captured_inputs)
   1666 
   1667   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1744       # No tape is watching; skip to running the function.
   1745       return self._build_call_outputs(self._inference_function.call(
-> 1746           ctx, args, cancellation_manager=cancellation_manager))
   1747     forward_backward = self._select_forward_and_backward_functions(
   1748         args,

~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    596               inputs=args,
    597               attrs=attrs,
--> 598               ctx=ctx)
    599         else:
    600           outputs = execute.execute_with_cancellation(

~/foo/.venv/lib/python3.6/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError: Function invoked by the following node is not compilable: {{node __inference_call_mst_op_455}} = __inference_call_mst_op_455[_XlaMustCompile=true, config_proto="\n\007\n\003GPU\020\001\n\007\n\003CPU\020\0012\005*\0010J\0008\001", executor_type=""]().
Uncompilable nodes:
MaxSpanningTree: unsupported op: No registered 'MaxSpanningTree' OpKernel for XLA_GPU_JIT devices compatible with node {{node MaxSpanningTree}}
	Stacktrace:
		Node: __inference_call_mst_op_455, function: 
		Node: MaxSpanningTree, function: __inference_call_mst_op_455
 [Op:__inference_call_mst_op_455]

As a workaround, `tf.function` still works if `experimental_compile` is set to `False`, i.e. without XLA compilation.

Metadata

Metadata

Assignees

Labels

enhancement (New feature or request)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions