Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix missing sparse matrix crash.
Calling a sparse matrix op with no matrix currently causes a crash.  Here we check and
return a non-ok status.

PiperOrigin-RevId: 476379116
  • Loading branch information
cantonios authored and tensorflower-gardener committed Sep 23, 2022
1 parent a8a5513 commit f856d02
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/sparse/sparse_matrix.h
Expand Up @@ -25,10 +25,12 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

Expand Down Expand Up @@ -633,6 +635,11 @@ template <typename T>
Status ExtractVariantFromInput(OpKernelContext* ctx, int index,
const T** value) {
const Tensor& input_t = ctx->input(index);
if (!TensorShapeUtils::IsScalar(input_t.shape())) {
return errors::InvalidArgument(
"Invalid input matrix: Shape must be rank 0 but is rank ",
input_t.dims());
}
const Variant& input_variant = input_t.scalar<Variant>()();
*value = input_variant.get<T>();
if (*value == nullptr) {
Expand Down
Expand Up @@ -1313,6 +1313,16 @@ def testOrderingAMD(self):
self.assertLess(cholesky_with_amd_nnz_value,
cholesky_without_ordering_nnz_value)

@test_util.run_in_graph_and_eager_modes
def testNoMatrixNoCrash(self):
# Round-about way of creating an empty variant tensor that works in both
# graph and eager modes.
no_matrix = array_ops.reshape(dense_to_csr_sparse_matrix([[0.0]]), [1])[0:0]
with self.assertRaisesRegex(
(ValueError, errors.InvalidArgumentError),
"(Invalid input matrix)|(Shape must be rank 0)"):
sparse_csr_matrix_ops.sparse_matrix_nnz(no_matrix)


class CSRSparseMatrixOpsBenchmark(test.Benchmark):

Expand Down

0 comments on commit f856d02

Please sign in to comment.