Skip to content
Permalink
Browse files Browse the repository at this point in the history
Validate MatrixDiagV{2,3} arguments to prevent breakage.
PiperOrigin-RevId: 369056033
Change-Id: Ic2018c297d3dd6f252dc1dd3667f1ed5cb1eaa42
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Apr 18, 2021
1 parent 4c4f420 commit a7116dd
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions tensorflow/core/kernels/linalg/matrix_diag_op.cc
Expand Up @@ -192,9 +192,22 @@ class MatrixDiagOp : public OpKernel {
upper_diag_index = diag_index.flat<int32>()(1);
}
}
num_rows = context->input(2).flat<int32>()(0);
num_cols = context->input(3).flat<int32>()(0);
padding_value = context->input(4).flat<T>()(0);

auto& num_rows_tensor = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_rows_tensor.shape()),
errors::InvalidArgument("num_rows must be a scalar"));
num_rows = num_rows_tensor.flat<int32>()(0);

auto& num_cols_tensor = context->input(3);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_cols_tensor.shape()),
errors::InvalidArgument("num_cols must be a scalar"));
num_cols = num_cols_tensor.flat<int32>()(0);

auto& padding_value_tensor = context->input(4);
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(padding_value_tensor.shape()),
errors::InvalidArgument("padding_value must be a scalar"));
padding_value = padding_value_tensor.flat<T>()(0);
}

// Size validations.
Expand Down

0 comments on commit a7116dd

Please sign in to comment.