Skip to content

Commit fa6b778

Browse files
edlopertensorflower-gardener
authored andcommitted
Fix null pointer exception in shape inference function when tf.ragged.cross() is called with invalid inputs.
PiperOrigin-RevId: 400045848 Change-Id: Ia65501583b85cf1ec14a252d83fbdd716817a516
1 parent e1cfe3c commit fa6b778

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

Diff for: tensorflow/core/ops/ragged_array_ops.cc

+7
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ REGISTER_OP("RaggedCross")
9999
int dense_start = num_ragged * 2 + num_sparse * 3;
100100
for (int i = 0; i < dense_types.size(); ++i) {
101101
ShapeHandle dense_input = c->input(i + dense_start);
102+
int32 rank = c->Rank(dense_input);
103+
if (rank == InferenceContext::kUnknownRank) {
104+
continue;
105+
} else if (rank != 2) {
106+
return errors::InvalidArgument(
107+
"tf.ragged.cross only supports inputs with rank=2");
108+
}
102109
int64_t batch_size = c->Value(c->Dim(dense_input, 0));
103110
if (batch_size != InferenceContext::kUnknownDim) {
104111
ShapeHandle row_splits = c->Vector(batch_size + 1);

Diff for: tensorflow/python/ops/ragged/ragged_cross_op_test.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818

1919
import numpy as np
2020

21+
from tensorflow.python.eager import def_function
2122
from tensorflow.python.framework import dtypes
2223
from tensorflow.python.framework import errors
2324
from tensorflow.python.framework import ops
2425
from tensorflow.python.framework import sparse_tensor
26+
from tensorflow.python.framework import tensor_spec
2527
from tensorflow.python.framework import test_util
2628
from tensorflow.python.ops import sparse_ops
2729
from tensorflow.python.ops.ragged import ragged_array_ops
@@ -358,6 +360,16 @@ def testRaggedCrossLargeBatch(self):
358360
dense_const([[2], [3]])],
359361
exception=(ValueError, errors.InvalidArgumentError),
360362
message='inputs must all have the same batch dimension size'),
363+
dict(
364+
testcase_name='3DDenseTensor',
365+
inputs=[dense_const([[[1]]])],
366+
exception=(ValueError, errors.InvalidArgumentError),
367+
message='tf.ragged.cross only supports inputs with rank=2'),
368+
dict(
369+
testcase_name='0DDenseTensor',
370+
inputs=[dense_const(1)],
371+
exception=(ValueError, errors.InvalidArgumentError),
372+
message='tf.ragged.cross only supports inputs with rank=2'),
361373
])
362374
def testStaticError(self, inputs, exception=ValueError, message=None):
363375
with self.assertRaisesRegex(exception, message):
@@ -368,17 +380,36 @@ def testStaticError(self, inputs, exception=ValueError, message=None):
368380
testcase_name='3DRaggedTensor',
369381
inputs=[ragged_const([[[1]]], ragged_rank=1)],
370382
message='tf.ragged.cross only supports inputs with rank=2'),
383+
dict(
384+
testcase_name='0DDenseTensor',
385+
inputs=[dense_const(1)],
386+
signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
387+
exception=(ValueError, errors.InvalidArgumentError),
388+
message='tf.ragged.cross only supports inputs with rank=2'),
389+
dict(
390+
testcase_name='1DDenseTensor',
391+
inputs=[dense_const([1])],
392+
signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
393+
exception=(ValueError, errors.InvalidArgumentError),
394+
message='tf.ragged.cross only supports inputs with rank=2'),
371395
dict(
372396
testcase_name='3DDenseTensor',
373397
inputs=[dense_const([[[1]]])],
398+
signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
399+
exception=(ValueError, errors.InvalidArgumentError),
374400
message='tf.ragged.cross only supports inputs with rank=2'),
375401
])
376402
def testRuntimeError(self,
377403
inputs,
378404
exception=errors.InvalidArgumentError,
379-
message=None):
405+
message=None,
406+
signature=None):
407+
@def_function.function(input_signature=signature)
408+
def fn(x):
409+
return ragged_array_ops.cross(x)
410+
380411
with self.assertRaisesRegex(exception, message):
381-
self.evaluate(ragged_array_ops.cross(inputs))
412+
self.evaluate(fn(inputs))
382413

383414
def _ragged_to_sparse(self, t):
384415
if ragged_tensor.is_ragged(t):

0 commit comments

Comments
 (0)