Skip to content

Commit

Permalink
Fix OOB error when op input sizes do not match.
Browse files Browse the repository at this point in the history
In cases where op input sizes are specified as in
```
REGISTER_OP("DynamicStitch")
    .Input("indices: N * int32")
    .Input("data: N * T")
    .Output("merged: T")
    .Attr("N : int >= 1")
    .Attr("T : type")
    .SetShapeFn(DynamicStitchShapeFunction);
```
if differing numbers of inputs are provided (e.g. 3 for `indices` and 4 for `data`)
we can get a crash in the executor when parsing the inputs, even before the kernel
is called.  Here we avoid this by checking the return code for the argument id and
exiting early.

PiperOrigin-RevId: 478068540
  • Loading branch information
cantonios authored and vinila21 committed Oct 11, 2022
1 parent 9e52217 commit fff989e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
34 changes: 32 additions & 2 deletions tensorflow/core/common_runtime/eager/execute.cc
Expand Up @@ -289,8 +289,38 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
return {x, tensorflow::FingerprintCat64(a.high64, x)};
}

Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
Device** result) {
// Looks up the KernelDef registered for `node_def` on `op_device`'s device
// type. Returns nullptr when either input is null or when no matching kernel
// is registered, so callers can treat "no kernel info" uniformly.
const KernelDef* GetKernelDef(const EagerOperation& op, const NodeDef* node_def,
                              const Device* op_device) {
  if (node_def == nullptr || op_device == nullptr) return nullptr;
  const KernelDef* kernel_def = nullptr;
  const Status lookup_status =
      FindKernelDef(DeviceType(op_device->device_type()), *node_def,
                    &kernel_def,
                    /*kernel_class_name=*/nullptr);
  return lookup_status.ok() ? kernel_def : nullptr;
}

// Returns true iff input `port_id` of `op` corresponds to an argument that the
// kernel declares as host-memory (`host_memory_arg` in its KernelDef).
// Conservatively returns false for functions, or when the node def, kernel
// def, or device is unavailable, or when the port cannot be mapped back to an
// op argument (e.g. mismatched input list sizes — the OOB case this guards).
bool IsHostMemoryArg(const EagerOperation& op, const NodeDef* node_def,
                     const Device* op_device, const KernelDef* kernel_def,
                     const int port_id) {
  if (op.is_function() || node_def == nullptr) return false;
  if (kernel_def == nullptr || op_device == nullptr) return false;
  // NOTE(review): assumes the op is registered, so LookUp() is non-null here —
  // confirm callers only reach this for registered ops.
  const OpDef& op_def = OpRegistry::Global()->LookUp(op.Name())->op_def;
  const int arg_id = OpPortIdToArgId(*node_def, op_def.input_arg(), port_id);
  // A negative id means the port does not map to any argument; exit early
  // rather than indexing input_arg() out of bounds.
  if (arg_id < 0) {
    return false;
  }
  const auto& host_args = kernel_def->host_memory_arg();
  return std::find(host_args.begin(), host_args.end(),
                   op_def.input_arg(arg_id).name()) != host_args.end();
}

Status GetDeviceForInput(const EagerOperation& op, const EagerContext& ctx,
const bool is_host_memory_arg,
TensorHandle* tensor_handle, Device** result) {
>>>>>>> f5381e0e10b (Fix OOB error when op input sizes do not match.)
Device* cpu_device = ctx.HostCPU();
string device_name;
if (tensor_handle->Type() != TensorHandle::LOCAL) {
Expand Down
Expand Up @@ -18,6 +18,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
Expand Down Expand Up @@ -308,6 +309,29 @@ def testHigherRankGPU(self):
for datum, grad in zip(data, self.evaluate(grads[3:])):
self.assertAllEqual(7.0 * self.evaluate(datum), grad)

@test_util.run_in_graph_and_eager_modes
def testMismatchedDataAndIndexListSizes(self):
  """Mismatched list lengths for `indices` vs `data` must raise, not crash."""
  indices = [constant_op.constant([i]) for i in (2, 1, 0, 3)]
  data = [constant_op.constant([v]) for v in (1.0, 2.0, 3.0, 4.0)]
  # Graph mode raises ValueError at construction; eager mode surfaces
  # InvalidArgumentError from the runtime — accept either message form.
  expected_error = "expected inputs .* do not match|List argument .* must match"

  # Fewer index tensors than data tensors.
  with self.assertRaisesRegex(
      (ValueError, errors.InvalidArgumentError), expected_error):
    self.evaluate(data_flow_ops.dynamic_stitch(indices[0:2], data))

  # Fewer data tensors than index tensors.
  with self.assertRaisesRegex(
      (ValueError, errors.InvalidArgumentError), expected_error):
    self.evaluate(data_flow_ops.dynamic_stitch(indices, data[0:2]))

# Standard TensorFlow test entry point: discovers and runs all test cases
# defined in this file when executed directly.
if __name__ == "__main__":
  test.main()

0 comments on commit fff989e

Please sign in to comment.