Skip to content
Permalink
Browse files Browse the repository at this point in the history
Check correct input/output scalar types for LinearAlgebraOp.
If calling the `tf.raw_ops` versions, it's possible to provide bad output scalar
types (e.g. for `Eig`), causing a failing check when trying to actually compute
outputs.  Here we add an appropriate input validation check.

NOTE: no test is added because it's impossible to trigger these using the
official public API.
PiperOrigin-RevId: 461637318
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jul 18, 2022
1 parent 146f252 commit aed3691
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/linalg/linalg_ops_common.cc
Expand Up @@ -15,14 +15,17 @@ limitations under the License.

#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"

#include <initializer_list>
#include <utility>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

Expand Down Expand Up @@ -152,6 +155,10 @@ void LinearAlgebraOp<InputScalar, OutputScalar>::AnalyzeInputs(
input_matrix_shapes->emplace_back(
std::initializer_list<int64_t>({num_rows, num_cols}));
inputs->emplace_back(&in);
OP_REQUIRES(
context, in.dtype() == DataTypeToEnum<InputScalar>::v(),
errors::InvalidArgument("Invalid input dtype ", in.dtype(), " vs ",
DataTypeToEnum<InputScalar>::v()));
}
// Have the derived class validate that the inputs are as expected.
ValidateInputMatrixShapes(context, *input_matrix_shapes);
Expand Down Expand Up @@ -212,6 +219,11 @@ void LinearAlgebraOp<InputScalar, OutputScalar>::PrepareOutputs(
OP_REQUIRES_OK(context, context->allocate_output(
output_idx, output_tensor_shape, &out));
}
OP_REQUIRES(
context, out->dtype() == DataTypeToEnum<OutputScalar>::v(),
errors::InvalidArgument("Invalid output dtype ", out->dtype(), " vs ",
DataTypeToEnum<OutputScalar>::v()));

outputs->emplace_back(out);
}
}
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/kernel_tests/linalg/eig_op_test.py
Expand Up @@ -18,8 +18,10 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
Expand Down Expand Up @@ -88,6 +90,16 @@ def testMatrixThatFailsWhenFlushingDenormsToZero(self):
self.assertAllClose(matrix,
np.matmul(np.matmul(v, np.diag(e)), v.transpose()))

def testMismatchedDtypes(self):
tensor = constant_op.constant([[0, 1], [2, 3]], dtype=dtypes_lib.float32)
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"Invalid output dtype"):
self.evaluate(
gen_linalg_ops.eig(
input=tensor,
Tout=dtypes_lib.complex128, # Expected dtype: complex64.
compute_v=True))


def SortEigenValues(e):
perm = np.argsort(e.real + e.imag, -1)
Expand Down

0 comments on commit aed3691

Please sign in to comment.