Skip to content

Commit

Permalink
Add rank checks to GenerateBoundingBoxProposals.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 479681553
  • Loading branch information
cantonios authored and tensorflow-jenkins committed Oct 21, 2022
1 parent ee897ca commit 26cbb6f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
16 changes: 16 additions & 0 deletions tensorflow/core/kernels/image/generate_box_proposals_op.cu.cc
Expand Up @@ -312,6 +312,22 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel {
const auto bbox_deltas = context->input(1);
const auto image_info = context->input(2);
const auto anchors = context->input(3);

OP_REQUIRES(context, scores.dims() == 4,
errors::InvalidArgument("`scores` must be rank 4 but is rank ",
scores.dims()));
OP_REQUIRES(
context, bbox_deltas.dims() == 4,
errors::InvalidArgument("`bbox_deltas` must be rank 4 but is rank ",
bbox_deltas.dims()));
OP_REQUIRES(
context, image_info.dims() == 2,
errors::InvalidArgument("`image_info` must be rank 2 but is rank ",
image_info.dims()));
OP_REQUIRES(context, anchors.dims() == 3,
errors::InvalidArgument("`anchors` must be rank 3 but is rank ",
anchors.dims()));

const auto num_images = scores.dim_size(0);
const auto num_anchors = scores.dim_size(3);
const auto height = scores.dim_size(1);
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/image_ops/BUILD
Expand Up @@ -102,7 +102,7 @@ tf_py_test(
],
)

tf_py_test(
cuda_py_test(
name = "draw_bounding_box_op_test",
size = "small",
srcs = ["draw_bounding_box_op_test.py"],
Expand Down
Expand Up @@ -16,8 +16,11 @@

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import image_ops_impl
Expand Down Expand Up @@ -131,6 +134,22 @@ def testDrawBoundingBoxHalf(self):
self._testDrawBoundingBoxColorCycling(
image, dtype=dtypes.half, colors=colors)

# generate_bound_box_proposals is only available on GPU.
@test_util.run_gpu_only()
def testGenerateBoundingBoxProposals(self):
# Op only exists on GPU.
with self.cached_session(use_gpu=True):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must be rank 4"):
scores = constant_op.constant(
value=[[[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]]])
self.evaluate(
image_ops.generate_bounding_box_proposals(
scores=scores,
bbox_deltas=[],
image_info=[],
anchors=[],
pre_nms_topn=1))

if __name__ == "__main__":
test.main()

0 comments on commit 26cbb6f

Please sign in to comment.