Skip to content

Commit

Permalink
Add shape checks to GPU TridiagonalMatMul.
Browse files Browse the repository at this point in the history
When given invalid shapes, the GPU TridiagonalMatMul op could read invalid or uninitialized GPU memory.

PiperOrigin-RevId: 401775483
Change-Id: Ib5500aeb8225e50d4ce790b06d2c34751f544ad8
  • Loading branch information
reedwm authored and tensorflower-gardener committed Oct 8, 2021
1 parent c36d5a1 commit 68422b2
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tensorflow/core/kernels/linalg/tridiagonal_matmul_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class TridiagonalMatMulOpGpu : public OpKernel {
const Tensor& rhs = context->input(3);

const int ndims = rhs.dims();
OP_REQUIRES(
context, ndims >= 2,
errors::InvalidArgument("Input must have rank >= 2, but got ", ndims));
OP_REQUIRES_OK(context, ValidateInputTensor(superdiag, "superdiag", rhs));
OP_REQUIRES_OK(context, ValidateInputTensor(maindiag, "maindiag", rhs));
OP_REQUIRES_OK(context, ValidateInputTensor(subdiag, "subdiag", rhs));
int64 batch_size = 1;
for (int i = 0; i < ndims - 2; i++) {
batch_size *= rhs.dim_size(i);
Expand All @@ -85,6 +91,39 @@ class TridiagonalMatMulOpGpu : public OpKernel {
maindiag.flat<Scalar>().data(), subdiag.flat<Scalar>().data(),
rhs.flat<Scalar>().data(), output->flat<Scalar>().data()));
}

private:
Status ValidateInputTensor(const Tensor& tensor,
const std::string& tensor_name,
const Tensor& rhs) {
const int ndims = rhs.dims();
if (tensor.dims() != ndims) {
return errors::InvalidArgument(tensor_name,
" must have same rank as rhs, but got ",
tensor.dims(), " and ", ndims);
}
for (int i = 0; i < ndims - 2; i++) {
if (tensor.dim_size(i) != rhs.dim_size(i)) {
return errors::InvalidArgument(
tensor_name,
" must have same outer dimensions as rhs, but for index ", i,
", got ", tensor.dim_size(i), " and ", rhs.dim_size(i));
}
}
if (tensor.dim_size(ndims - 2) != 1) {
return errors::InvalidArgument(
tensor_name, "'s second-to-last dimension must be 1, but got ",
tensor.dim_size(ndims - 2));
}
if (tensor.dim_size(ndims - 1) != rhs.dim_size(ndims - 2)) {
return errors::InvalidArgument(tensor_name,
"'s last dimension size must be rhs's "
"second-to-last dimension size, but got ",
tensor.dim_size(ndims - 1), " and ",
rhs.dim_size(ndims - 2));
}
return Status::OK();
}
};

REGISTER_LINALG_OP_GPU("TridiagonalMatMul", (TridiagonalMatMulOpGpu<float>),
Expand Down
34 changes: 34 additions & 0 deletions tensorflow/python/kernel_tests/tridiagonal_matmul_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import numpy as np

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
Expand Down Expand Up @@ -175,6 +178,37 @@ def testGradientComplexWithBatches(self):
rhs = self._randomComplexArray((b, m, n))
self._gradientTest(diags, rhs, dtype=dtypes.complex128)

def _testErrorWithShapesEager(self, exception_regex, superdiag_shape,
maindiag_shape, subdiag_shape, rhs_shape):
with context.eager_mode():
superdiag = array_ops.ones(superdiag_shape)
maindiag = array_ops.ones(maindiag_shape)
subdiag = array_ops.ones(subdiag_shape)
rhs = array_ops.ones(rhs_shape)
with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
exception_regex):
linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs)

def testInvalidShapesEagerGpu(self):
if not test.is_gpu_available():
self.skipTest('Test requires GPU')
self._testErrorWithShapesEager('Input must have rank >= 2, but got ',
[2], [2], [2], [2])
self._testErrorWithShapesEager(
'superdiag must have same rank as rhs, but got 3 and 2',
[2, 1, 2], [2, 1], [2, 1], [2, 2])
self._testErrorWithShapesEager(
'maindiag must have same outer dimensions as rhs, but for index 0, got '
'3 and 2',
[2, 1, 2], [3, 1, 2], [2, 1, 2], [2, 2, 2])
self._testErrorWithShapesEager(
"subdiag's second-to-last dimension must be 1, but got 3",
[2, 1, 2], [2, 1, 2], [2, 3, 2], [2, 2, 2])
self._testErrorWithShapesEager(
"subdiag's last dimension size must be rhs's second-to-last dimension "
"size, but got 3 and 2",
[2, 1, 2], [2, 1, 2], [2, 1, 3], [2, 2, 2])

# Benchmark
class TridiagonalMatMulBenchmark(test.Benchmark):
sizes = [(100000, 1, 1), (1000000, 1, 1), (10000000, 1, 1), (100000, 10, 1),
Expand Down

0 comments on commit 68422b2

Please sign in to comment.