Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix crash in MatrixSolve when inputs have different batch dimensions.
Before, the process would crash or certain elements would be silently ignored. Now an InvalidArgument is raised.

PiperOrigin-RevId: 384844020
Change-Id: Iba44417e383bdd0e1abc4012bfca83b2377dd335
  • Loading branch information
reedwm authored and tensorflower-gardener committed Jul 15, 2021
1 parent cd5a9ad commit 579261d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tensorflow/core/kernels/linalg/matrix_solve_op.cc
Expand Up @@ -143,15 +143,22 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
done);
OP_REQUIRES_ASYNC(
context, input.dim_size(ndims - 2) == n,
errors::InvalidArgument("Input matrices must be squares, got",
errors::InvalidArgument("Input matrices must be squares, got ",
input.dim_size(ndims - 2), " != ", n),
done);
OP_REQUIRES_ASYNC(context, rhs.dim_size(ndims - 2) == n,
errors::InvalidArgument(
"Input matrix and right-hand side must have the "
"same number of rows, got",
"same number of rows, got ",
n, " != ", rhs.dim_size(ndims - 2)),
done);
for (int dim = 0; dim < ndims - 2; dim++) {
OP_REQUIRES_ASYNC(
context, input.dim_size(dim) == rhs.dim_size(dim),
errors::InvalidArgument(
"All input tensors must have the same outer dimensions."),
done);
}

// Allocate output.
Tensor* output;
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/python/kernel_tests/matrix_solve_op_test.py
Expand Up @@ -112,6 +112,12 @@ def testWrongDimensions(self):
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
self.evaluate(linalg_ops.matrix_solve(matrix, rhs))

# The matrix and right-hand side should have the same batch dimensions
matrix = np.random.normal(size=(2, 6, 2, 2))
rhs = np.random.normal(size=(2, 3, 2, 2))
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
self.evaluate(linalg_ops.matrix_solve(matrix, rhs))

def testNotInvertible(self):
# The input should be invertible.
with self.assertRaisesOpError("Input matrix is not invertible."):
Expand Down

0 comments on commit 579261d

Please sign in to comment.