Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add missing validation to matrix_diag_op.cc
PiperOrigin-RevId: 387923533
Change-Id: Idfffeb328d5f9c6748d992d28a56d6e9e45103a0
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 31, 2021
1 parent ff88940 commit f2a673b
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/linalg/matrix_diag_op.cc
Expand Up @@ -73,6 +73,9 @@ class MatrixDiagPartOp : public OpKernel {
errors::InvalidArgument(
"diag_index must be a scalar or vector, received shape: ",
diag_index.shape().DebugString()));
OP_REQUIRES(context, diag_index.NumElements() > 0,
errors::InvalidArgument(
"Expected diag_index to have at least 1 element"));
lower_diag_index = diag_index.flat<int32>()(0);
upper_diag_index = lower_diag_index;
if (TensorShapeUtils::IsVector(diag_index.shape())) {
Expand Down Expand Up @@ -179,6 +182,9 @@ class MatrixDiagOp : public OpKernel {
errors::InvalidArgument(
"diag_index must be a scalar or vector, received shape: ",
diag_index.shape().DebugString()));
OP_REQUIRES(context, diag_index.NumElements() > 0,
errors::InvalidArgument(
"Expected diag_index to have at least 1 element"));
lower_diag_index = diag_index.flat<int32>()(0);
upper_diag_index = lower_diag_index;
if (TensorShapeUtils::IsVector(diag_index.shape())) {
Expand Down

0 comments on commit f2a673b

Please sign in to comment.