Skip to content
Permalink
Browse files Browse the repository at this point in the history
Remove OP_REQUIRES call from helper function.
Since `OP_REQUIRES` macro expands to a `return;` (among other), calling it in a helper function only ends the helper function's execution earlier, but the kernel will still run from start to end. Thus, all the expected validations are actually broken/useless as the code ploughs through the next crash anyway.

PiperOrigin-RevId: 369524386
Change-Id: I54f6cf9328445675ccc392e661b04336b229c9da
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Apr 20, 2021
1 parent 080f1d9 commit e6a7c7c
Showing 1 changed file with 34 additions and 33 deletions.
67 changes: 34 additions & 33 deletions tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/op_requires.h"

#define EIGEN_USE_THREADS

#include "third_party/eigen3/Eigen/Core"
Expand Down Expand Up @@ -82,8 +84,8 @@ class CSRSparseCholeskyCPUOp : public OpKernel {

int64 num_rows;
int batch_size;
ValidateInputs(ctx, *input_matrix, input_permutation_indices, &batch_size,
&num_rows);
OP_REQUIRES_OK(ctx, ValidateInputs(*input_matrix, input_permutation_indices,
&batch_size, &num_rows));

// Allocate batch pointers.
Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
Expand Down Expand Up @@ -226,49 +228,48 @@ class CSRSparseCholeskyCPUOp : public OpKernel {
}

private:
void ValidateInputs(OpKernelContext* ctx,
const CSRSparseMatrix& sparse_matrix,
const Tensor& permutation_indices, int* batch_size,
int64* num_rows) {
OP_REQUIRES(ctx, sparse_matrix.dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
"Asked for a CSRSparseMatrix of type ",
DataTypeString(DataTypeToEnum<T>::value),
" but saw dtype: ", DataTypeString(sparse_matrix.dtype())));
Status ValidateInputs(const CSRSparseMatrix& sparse_matrix,
const Tensor& permutation_indices, int* batch_size,
int64* num_rows) {
if (sparse_matrix.dtype() != DataTypeToEnum<T>::value)
return errors::InvalidArgument(
"Asked for a CSRSparseMatrix of type ",
DataTypeString(DataTypeToEnum<T>::value),
" but saw dtype: ", DataTypeString(sparse_matrix.dtype()));

const Tensor& dense_shape = sparse_matrix.dense_shape();
const int rank = dense_shape.dim_size(0);
OP_REQUIRES(ctx, rank == 2 || rank == 3,
errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
"but dense_shape has size ", rank));
if (rank < 2 || rank > 3)
return errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
"but dense_shape has size ", rank);
const int row_dim = (rank == 2) ? 0 : 1;
auto dense_shape_vec = dense_shape.vec<int64>();
*num_rows = dense_shape_vec(row_dim);
const int64 num_cols = dense_shape_vec(row_dim + 1);
OP_REQUIRES(ctx, *num_rows == num_cols,
errors::InvalidArgument("sparse matrix must be square; got: ",
*num_rows, " != ", num_cols));
if (*num_rows != num_cols)
return errors::InvalidArgument(
"sparse matrix must be square; got: ", *num_rows, " != ", num_cols);
const TensorShape& perm_shape = permutation_indices.shape();
OP_REQUIRES(
ctx, perm_shape.dims() + 1 == rank,
errors::InvalidArgument(
"sparse matrix must have the same rank as permutation; got: ", rank,
" != ", perm_shape.dims(), " + 1."));
OP_REQUIRES(
ctx, perm_shape.dim_size(rank - 2) == *num_rows,
errors::InvalidArgument(
"permutation must have the same number of elements in each batch "
"as the number of rows in sparse matrix; got: ",
perm_shape.dim_size(rank - 2), " != ", *num_rows));
if (perm_shape.dims() + 1 != rank)
return errors::InvalidArgument(
"sparse matrix must have the same rank as permutation; got: ", rank,
" != ", perm_shape.dims(), " + 1.");
if (perm_shape.dim_size(rank - 2) != *num_rows)
return errors::InvalidArgument(
"permutation must have the same number of elements in each batch "
"as the number of rows in sparse matrix; got: ",
perm_shape.dim_size(rank - 2), " != ", *num_rows);

*batch_size = sparse_matrix.batch_size();
if (*batch_size > 1) {
OP_REQUIRES(
ctx, perm_shape.dim_size(0) == *batch_size,
errors::InvalidArgument("permutation must have the same batch size "
"as sparse matrix; got: ",
perm_shape.dim_size(0), " != ", *batch_size));
if (perm_shape.dim_size(0) != *batch_size)
return errors::InvalidArgument(
"permutation must have the same batch size "
"as sparse matrix; got: ",
perm_shape.dim_size(0), " != ", *batch_size);
}

return Status::OK();
}
};

Expand Down

0 comments on commit e6a7c7c

Please sign in to comment.