1818
1919import numpy as np
2020
21+ from tensorflow .python .eager import def_function
2122from tensorflow .python .framework import dtypes
2223from tensorflow .python .framework import errors
2324from tensorflow .python .framework import ops
2425from tensorflow .python .framework import sparse_tensor
26+ from tensorflow .python .framework import tensor_spec
2527from tensorflow .python .framework import test_util
2628from tensorflow .python .ops import sparse_ops
2729from tensorflow .python .ops .ragged import ragged_array_ops
@@ -358,6 +360,16 @@ def testRaggedCrossLargeBatch(self):
358360 dense_const ([[2 ], [3 ]])],
359361 exception = (ValueError , errors .InvalidArgumentError ),
360362 message = 'inputs must all have the same batch dimension size' ),
363+ dict (
364+ testcase_name = '3DDenseTensor' ,
365+ inputs = [dense_const ([[[1 ]]])],
366+ exception = (ValueError , errors .InvalidArgumentError ),
367+ message = 'tf.ragged.cross only supports inputs with rank=2' ),
368+ dict (
369+ testcase_name = '0DDenseTensor' ,
370+ inputs = [dense_const (1 )],
371+ exception = (ValueError , errors .InvalidArgumentError ),
372+ message = 'tf.ragged.cross only supports inputs with rank=2' ),
361373 ])
362374 def testStaticError (self , inputs , exception = ValueError , message = None ):
363375 with self .assertRaisesRegex (exception , message ):
@@ -368,17 +380,36 @@ def testStaticError(self, inputs, exception=ValueError, message=None):
368380 testcase_name = '3DRaggedTensor' ,
369381 inputs = [ragged_const ([[[1 ]]], ragged_rank = 1 )],
370382 message = 'tf.ragged.cross only supports inputs with rank=2' ),
383+ dict (
384+ testcase_name = '0DDenseTensor' ,
385+ inputs = [dense_const (1 )],
386+ signature = [[tensor_spec .TensorSpec (None , dtypes .int32 )]],
387+ exception = (ValueError , errors .InvalidArgumentError ),
388+ message = 'tf.ragged.cross only supports inputs with rank=2' ),
389+ dict (
390+ testcase_name = '1DDenseTensor' ,
391+ inputs = [dense_const ([1 ])],
392+ signature = [[tensor_spec .TensorSpec (None , dtypes .int32 )]],
393+ exception = (ValueError , errors .InvalidArgumentError ),
394+ message = 'tf.ragged.cross only supports inputs with rank=2' ),
371395 dict (
372396 testcase_name = '3DDenseTensor' ,
373397 inputs = [dense_const ([[[1 ]]])],
398+ signature = [[tensor_spec .TensorSpec (None , dtypes .int32 )]],
399+ exception = (ValueError , errors .InvalidArgumentError ),
374400 message = 'tf.ragged.cross only supports inputs with rank=2' ),
375401 ])
376402 def testRuntimeError (self ,
377403 inputs ,
378404 exception = errors .InvalidArgumentError ,
379- message = None ):
405+ message = None ,
406+ signature = None ):
407+ @def_function .function (input_signature = signature )
408+ def fn (x ):
409+ return ragged_array_ops .cross (x )
410+
380411 with self .assertRaisesRegex (exception , message ):
381- self .evaluate (ragged_array_ops . cross (inputs ))
412+ self .evaluate (fn (inputs ))
382413
383414 def _ragged_to_sparse (self , t ):
384415 if ragged_tensor .is_ragged (t ):
0 commit comments