diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 9ec18b8f99fdb0..5e747d436489b7 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -313,6 +313,10 @@ bool IsHostMemoryArg(const EagerOperation& op, const NodeDef* node_def, const auto& host_memory_args = kernel_def->host_memory_arg(); const OpDef& op_def = OpRegistry::Global()->LookUp(op.Name())->op_def; const int arg_id = OpPortIdToArgId(*node_def, op_def.input_arg(), port_id); + // Fail if argument ID not found. + if (arg_id < 0) { + return false; + } return std::find(host_memory_args.begin(), host_memory_args.end(), op_def.input_arg(arg_id).name()) != host_memory_args.end(); } diff --git a/tensorflow/python/kernel_tests/data_structures/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/data_structures/dynamic_stitch_op_test.py index 28dc151aec6eb4..05e0cb6be2925a 100644 --- a/tensorflow/python/kernel_tests/data_structures/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/data_structures/dynamic_stitch_op_test.py @@ -18,6 +18,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops @@ -308,6 +309,29 @@ def testHigherRankGPU(self): for datum, grad in zip(data, self.evaluate(grads[3:])): self.assertAllEqual(7.0 * self.evaluate(datum), grad) + @test_util.run_in_graph_and_eager_modes + def testMismatchedDataAndIndexListSizes(self): + indices = [ + constant_op.constant([2]), + constant_op.constant([1]), + constant_op.constant([0]), + constant_op.constant([3]), + ] + data = [ + constant_op.constant([1.0]), + constant_op.constant([2.0]), + constant_op.constant([3.0]), + constant_op.constant([4.0]) + ] + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), + "expected inputs .* do not match|List argument .* must match"): + self.evaluate(data_flow_ops.dynamic_stitch(indices[0:2], data)) + + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), + "expected inputs .* do not match|List argument .* must match"): + self.evaluate(data_flow_ops.dynamic_stitch(indices, data[0:2])) if __name__ == "__main__": test.main()