From c4649e28882185ab6e5d4b70597288fa13286d37 Mon Sep 17 00:00:00 2001 From: TensorFlower Gardener Date: Mon, 10 Jan 2022 11:35:05 -0800 Subject: [PATCH] Merge pull request #53695 from yongtang:53660-tf.sparse.split-crash PiperOrigin-RevId: 420811652 Change-Id: I83742482770ba0bf7c3ccd57508c40fb9cdbe2f7 --- tensorflow/core/kernels/sparse_split_op.cc | 10 ++++++++-- tensorflow/python/kernel_tests/sparse_split_op_test.py | 9 +++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/sparse_split_op.cc b/tensorflow/core/kernels/sparse_split_op.cc index 3b88a8ca2bf6ee..18f787b71b9546 100644 --- a/tensorflow/core/kernels/sparse_split_op.cc +++ b/tensorflow/core/kernels/sparse_split_op.cc @@ -30,11 +30,16 @@ class SparseSplitOp : public OpKernel { } void Compute(OpKernelContext* context) override { - const int64 axis_input = context->input(0).scalar()(); + const Tensor& input_axis = context->input(0); const Tensor& input_indices = context->input(1); const Tensor& input_values = context->input(2); const Tensor& input_shape = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_axis.shape()), + errors::InvalidArgument( + "Input axis should be a scalar but received shape ", + input_axis.shape().DebugString()), + done); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()), errors::InvalidArgument( "Input indices should be a matrix but received shape ", @@ -48,7 +53,8 @@ class SparseSplitOp : public OpKernel { "Input shape should be a vector but received shape ", input_shape.shape().DebugString())); - const int64 input_rank = input_shape.vec().size(); + const int64 axis_input = input_axis.scalar()(); + const int64 input_rank = input_shape.vec().size(); const int64 axis = (axis_input < 0) ? input_rank + axis_input : axis_input; OP_REQUIRES( diff --git a/tensorflow/python/kernel_tests/sparse_split_op_test.py b/tensorflow/python/kernel_tests/sparse_split_op_test.py index 31ef1129f1319c..b5cc3f02d9d4c7 100644 --- a/tensorflow/python/kernel_tests/sparse_split_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_split_op_test.py @@ -257,6 +257,15 @@ def testArgumentErrors(self): with self.assertRaisesRegex(ValueError, 'axis is required'): sparse_ops.sparse_split(num_split=2, sp_input=1) + def testInvalidArgumentError(self): + # Test case for GitHub issue 53660. + axis = [1, 2] + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r'axis should be a scalar'): + self.evaluate( + sparse_ops.sparse_split( + sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis)) + if __name__ == '__main__': test.main()