Skip to content

Commit

Permalink
Fix dtype bug in draw bounding boxes.
Browse files Browse the repository at this point in the history
Boxes always needs to be type `float`.

PiperOrigin-RevId: 461800676
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jul 19, 2022
1 parent 785d67a commit da0d65c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/image/draw_bounding_box_op.cc
Expand Up @@ -119,7 +119,7 @@ class DrawBoundingBoxesOp : public OpKernel {

for (int64_t b = 0; b < batch_size; ++b) {
const int64_t num_boxes = boxes.dim_size(1);
const auto tboxes = boxes.tensor<T, 3>();
const auto tboxes = boxes.tensor<float, 3>();
for (int64_t bb = 0; bb < num_boxes; ++bb) {
int64_t color_index = bb % color_table.size();
const int64_t min_box_row =
Expand Down
Expand Up @@ -50,11 +50,16 @@ def _fillBorder(self, image, color):
image[height - 1, 0:width, 0:depth] = color
return image

def _testDrawBoundingBoxColorCycling(self, img, colors=None):
def _testDrawBoundingBoxColorCycling(self,
img,
dtype=dtypes.float32,
colors=None):
"""Tests if cycling works appropriately.
Args:
img: 3-D numpy image on which to draw.
dtype: image dtype (float, half).
colors: color table.
"""
color_table = colors
if colors is None:
Expand Down Expand Up @@ -82,7 +87,7 @@ def _testDrawBoundingBoxColorCycling(self, img, colors=None):
bboxes = math_ops.cast(bboxes, dtypes.float32)
bboxes = array_ops.expand_dims(bboxes, 0)
image = ops.convert_to_tensor(image)
image = image_ops_impl.convert_image_dtype(image, dtypes.float32)
image = image_ops_impl.convert_image_dtype(image, dtype)
image = array_ops.expand_dims(image, 0)
image = image_ops.draw_bounding_boxes(image, bboxes, colors=colors)
with self.cached_session(use_gpu=False) as sess:
Expand Down Expand Up @@ -118,6 +123,14 @@ def testDrawBoundingBoxRGBAColorCyclingWithColors(self):
[0, 0, 0.5, 1]])
self._testDrawBoundingBoxColorCycling(image, colors=colors)

def testDrawBoundingBoxHalf(self):
"""Test if RGBA color cycling works correctly with provided colors."""
image = np.zeros([10, 10, 4], "float32")
colors = np.asarray([[0.5, 0, 0.5, 1], [0.5, 0.5, 0, 1], [0.5, 0, 0, 1],
[0, 0, 0.5, 1]])
self._testDrawBoundingBoxColorCycling(
image, dtype=dtypes.half, colors=colors)


if __name__ == "__main__":
test.main()

0 comments on commit da0d65c

Please sign in to comment.