Skip to content

Commit

Permalink
Validate (and ensure validation sticks) inputs for `MatrixTriangularS…
Browse files Browse the repository at this point in the history
…olve`.

PiperOrigin-RevId: 370282444
Change-Id: Iaed61a0b0727cc42c830658b72eb69f785f48dc5
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Apr 24, 2021
1 parent 8cae746 commit 480641e
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ class BaseMatrixTriangularSolveOp : public OpKernel {
const Tensor& in1 = ctx->input(1);

ValidateInputTensors(ctx, in0, in1);
if (!ctx->status().ok()) {
return;
}

MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
OP_REQUIRES(
Expand Down Expand Up @@ -230,13 +233,22 @@ class MatrixTriangularSolveOp
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
const auto in0_num_dims = in0.dims();
OP_REQUIRES(
ctx, in0.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
ctx, in0_num_dims >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0_num_dims));

const auto in1_num_dims = in1.dims();
OP_REQUIRES(
ctx, in1.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims()));
ctx, in1_num_dims >= 2,
errors::InvalidArgument("In[1] ndims must be >= 2: ", in1_num_dims));

const auto in0_last_dim = in0.dim_size(in0_num_dims - 1);
const auto in0_prev_dim = in0.dim_size(in0_num_dims - 2);
OP_REQUIRES(ctx, in0_last_dim == in0_prev_dim,
errors::InvalidArgument(
"In[0] matrices in the last dimensions must be square (",
in0_last_dim, " =/= ", in0_prev_dim, ")"));
}
};

Expand Down

0 comments on commit 480641e

Please sign in to comment.