Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix null pointer exception in shape inference function when tf.ragged…
….cross() is called with invalid inputs.

PiperOrigin-RevId: 400045848
Change-Id: Ia65501583b85cf1ec14a252d83fbdd716817a516
  • Loading branch information
edloper authored and tensorflower-gardener committed Sep 30, 2021
1 parent e1cfe3c commit fa6b778
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
7 changes: 7 additions & 0 deletions tensorflow/core/ops/ragged_array_ops.cc
Expand Up @@ -99,6 +99,13 @@ REGISTER_OP("RaggedCross")
int dense_start = num_ragged * 2 + num_sparse * 3;
for (int i = 0; i < dense_types.size(); ++i) {
ShapeHandle dense_input = c->input(i + dense_start);
int32 rank = c->Rank(dense_input);
if (rank == InferenceContext::kUnknownRank) {
continue;
} else if (rank != 2) {
return errors::InvalidArgument(
"tf.ragged.cross only supports inputs with rank=2");
}
int64_t batch_size = c->Value(c->Dim(dense_input, 0));
if (batch_size != InferenceContext::kUnknownDim) {
ShapeHandle row_splits = c->Vector(batch_size + 1);
Expand Down
35 changes: 33 additions & 2 deletions tensorflow/python/ops/ragged/ragged_cross_op_test.py
Expand Up @@ -18,10 +18,12 @@

import numpy as np

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_array_ops
Expand Down Expand Up @@ -358,6 +360,16 @@ def testRaggedCrossLargeBatch(self):
dense_const([[2], [3]])],
exception=(ValueError, errors.InvalidArgumentError),
message='inputs must all have the same batch dimension size'),
dict(
testcase_name='3DDenseTensor',
inputs=[dense_const([[[1]]])],
exception=(ValueError, errors.InvalidArgumentError),
message='tf.ragged.cross only supports inputs with rank=2'),
dict(
testcase_name='0DDenseTensor',
inputs=[dense_const(1)],
exception=(ValueError, errors.InvalidArgumentError),
message='tf.ragged.cross only supports inputs with rank=2'),
])
def testStaticError(self, inputs, exception=ValueError, message=None):
with self.assertRaisesRegex(exception, message):
Expand All @@ -368,17 +380,36 @@ def testStaticError(self, inputs, exception=ValueError, message=None):
testcase_name='3DRaggedTensor',
inputs=[ragged_const([[[1]]], ragged_rank=1)],
message='tf.ragged.cross only supports inputs with rank=2'),
dict(
testcase_name='0DDenseTensor',
inputs=[dense_const(1)],
signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
exception=(ValueError, errors.InvalidArgumentError),
message='tf.ragged.cross only supports inputs with rank=2'),
dict(
testcase_name='1DDenseTensor',
inputs=[dense_const([1])],
signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
exception=(ValueError, errors.InvalidArgumentError),
message='tf.ragged.cross only supports inputs with rank=2'),
dict(
testcase_name='3DDenseTensor',
inputs=[dense_const([[[1]]])],
signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
exception=(ValueError, errors.InvalidArgumentError),
message='tf.ragged.cross only supports inputs with rank=2'),
])
def testRuntimeError(self,
inputs,
exception=errors.InvalidArgumentError,
message=None):
message=None,
signature=None):
@def_function.function(input_signature=signature)
def fn(x):
return ragged_array_ops.cross(x)

with self.assertRaisesRegex(exception, message):
self.evaluate(ragged_array_ops.cross(inputs))
self.evaluate(fn(inputs))

def _ragged_to_sparse(self, t):
if ragged_tensor.is_ragged(t):
Expand Down

0 comments on commit fa6b778

Please sign in to comment.