Skip to content
Permalink
Browse files Browse the repository at this point in the history
Ensure non-empty padding_value input to tf.raw_ops.MatrixDiagPartV2, …
…if a padding_value is input

PiperOrigin-RevId: 388314614
Change-Id: If0b51ad58d5d8543a6be6ce8f42ae4755c80d55f
  • Loading branch information
pak-laura authored and tensorflower-gardener committed Aug 2, 2021
1 parent 3b4351c commit 482da92
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tensorflow/core/kernels/linalg/matrix_diag_op.cc
Expand Up @@ -89,7 +89,10 @@ class MatrixDiagPartOp : public OpKernel {
upper_diag_index = diag_index.flat<int32>()(1);
}
}
padding_value = context->input(2).flat<T>()(0);
const Tensor& padding_in = context->input(2);
OP_REQUIRES(context, padding_in.NumElements() == 1,
errors::InvalidArgument("Padding must be scalar."));
padding_value = padding_in.flat<T>()(0);
}
const TensorShape& input_shape = input.shape();

Expand Down

0 comments on commit 482da92

Please sign in to comment.