From fff989eab57a66970fff8e05956f320c9310917d Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Fri, 30 Sep 2022 13:33:07 -0700 Subject: [PATCH] Fix OOB error when op input sizes do not match. In cases where op input sizes are specified as in ``` REGISTER_OP("DynamicStitch") .Input("indices: N * int32") .Input("data: N * T") .Output("merged: T") .Attr("N : int >= 1") .Attr("T : type") .SetShapeFn(DynamicStitchShapeFunction); ``` if a differing number of inputs is provided (e.g. 3 for `indices` and 4 for `data`) we can get a crash in the executor when parsing the inputs, even before the kernel is called. Here we avoid this by checking the return code for the argument id and exiting early. PiperOrigin-RevId: 478068540 --- .../core/common_runtime/eager/execute.cc | 34 +++++++++++++++++-- .../data_structures/dynamic_stitch_op_test.py | 24 +++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 4ad2e2d60aa2af..c4edc9fe6af06a 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -289,8 +289,38 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, return {x, tensorflow::FingerprintCat64(a.high64, x)}; } -Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle, - Device** result) { +const KernelDef* GetKernelDef(const EagerOperation& op, const NodeDef* node_def, + const Device* op_device) { + if (node_def == nullptr || op_device == nullptr) return nullptr; + const KernelDef* kernel_def = nullptr; + Status s = FindKernelDef(DeviceType(op_device->device_type()), *node_def, + &kernel_def, + /*kernel_class_name=*/nullptr); + if (!s.ok()) return nullptr; + return kernel_def; +} + +bool IsHostMemoryArg(const EagerOperation& op, const NodeDef* node_def, + const Device* op_device, const KernelDef* kernel_def, + const int port_id) { + if 
(op.is_function()) return false; + if (node_def == nullptr) return false; + if (kernel_def == nullptr || op_device == nullptr) return false; + const auto& host_memory_args = kernel_def->host_memory_arg(); + const OpDef& op_def = OpRegistry::Global()->LookUp(op.Name())->op_def; + const int arg_id = OpPortIdToArgId(*node_def, op_def.input_arg(), port_id); + // Fail if argument ID not found. + if (arg_id < 0) { + return false; + } + return std::find(host_memory_args.begin(), host_memory_args.end(), + op_def.input_arg(arg_id).name()) != host_memory_args.end(); +} + +Status GetDeviceForInput(const EagerOperation& op, const EagerContext& ctx, + const bool is_host_memory_arg, + TensorHandle* tensor_handle, Device** result) { +>>>>>>> f5381e0e10b (Fix OOB error when op input sizes do not match.) Device* cpu_device = ctx.HostCPU(); string device_name; if (tensor_handle->Type() != TensorHandle::LOCAL) { diff --git a/tensorflow/python/kernel_tests/data_structures/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/data_structures/dynamic_stitch_op_test.py index 28dc151aec6eb4..05e0cb6be2925a 100644 --- a/tensorflow/python/kernel_tests/data_structures/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/data_structures/dynamic_stitch_op_test.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops @@ -308,6 +309,29 @@ def testHigherRankGPU(self): for datum, grad in zip(data, self.evaluate(grads[3:])): self.assertAllEqual(7.0 * self.evaluate(datum), grad) + @test_util.run_in_graph_and_eager_modes + def testMismatchedDataAndIndexListSizes(self): + indices = [ + constant_op.constant([2]), + constant_op.constant([1]), + constant_op.constant([0]), + constant_op.constant([3]), + ] + data = [ + 
constant_op.constant([1.0]), + constant_op.constant([2.0]), + constant_op.constant([3.0]), + constant_op.constant([4.0]) + ] + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), + "expected inputs .* do not match|List argument .* must match"): + self.evaluate(data_flow_ops.dynamic_stitch(indices[0:2], data)) + + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), + "expected inputs .* do not match|List argument .* must match"): + self.evaluate(data_flow_ops.dynamic_stitch(indices, data[0:2])) if __name__ == "__main__": test.main()