diff --git a/test/assets/fakedata/draw_boxes_different_label_background_colors.png b/test/assets/fakedata/draw_boxes_different_label_background_colors.png new file mode 100644 index 00000000000..bf641e9a122 Binary files /dev/null and b/test/assets/fakedata/draw_boxes_different_label_background_colors.png differ diff --git a/test/test_utils.py b/test/test_utils.py index 8b6f357ce6e..afdf6738d2c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -166,6 +166,33 @@ def test_draw_boxes_with_coloured_label_backgrounds(): assert_equal(result, expected) +@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1") +def test_draw_boxes_with_coloured_label_text_boxes(): + img = torch.full((3, 100, 100), 255, dtype=torch.uint8) + labels = ["a", "b", "c", "d"] + colors = ["green", "#FF00FF", (0, 255, 0), "red"] + label_colors = ["green", "red", (0, 255, 0), "#FF00FF"] + label_background_colors = ["white", "black", "yellow", "blue"] + result = utils.draw_bounding_boxes( + img, + boxes, + labels=labels, + colors=colors, + fill=True, + label_colors=label_colors, + label_background_colors=label_background_colors, + fill_labels=True, + ) + path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "assets", + "fakedata", + "draw_boxes_different_label_background_colors.png", + ) + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + assert_equal(result, expected) + + @pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1") def test_draw_rotated_boxes(): img = torch.full((3, 500, 500), 255, dtype=torch.uint8) diff --git a/torchvision/utils.py b/torchvision/utils.py index 0d819ef8330..20534dec2f6 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -10,7 +10,6 @@ import torch from PIL import __version__ as PILLOW_VERSION_STRING, Image, ImageColor, ImageDraw, ImageFont - __all__ = [ "_Image_fromarray", "make_grid", @@ -293,6 +292,7 @@ def draw_bounding_boxes( font: Optional[str] = None, font_size: Optional[int] = None, label_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None, + label_background_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None, fill_labels: bool = False, ) -> torch.Tensor: """ @@ -320,7 +320,10 @@ def draw_bounding_boxes( font_size (int): The requested font size in points. label_colors (color or list of colors, optional): Colors for the label text. See the description of the `colors` argument for details. Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True. - fill_labels (bool): If `True` fills the label background with specified box color (from the ``colors`` parameter). Default: False. + label_background_colors (color or list of colors, optional): Colors for the label text box fill. Defaults to the + same colors used for the boxes. Ignored when ``fill_labels`` is False. + fill_labels (bool): If `True` fills the label background with specified color (from the ``label_background_colors`` parameter, + or from the ``colors`` parameter if not specified). Default: False. Returns: img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. @@ -362,6 +365,11 @@ def draw_bounding_boxes( else: label_colors = colors.copy() # type: ignore[assignment] + if fill_labels and label_background_colors: + label_background_colors = _parse_colors(label_background_colors, num_objects=num_boxes) + else: + label_background_colors = colors.copy() # type: ignore[assignment] + if font is None: if font_size is not None: warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.") @@ -385,7 +393,7 @@ def draw_bounding_boxes( else: draw = _ImageDrawTV(img_to_draw) - for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type] + for bbox, color, label, label_color, label_bg_color in zip(img_boxes, colors, labels, label_colors, label_background_colors): # type: ignore[arg-type] draw_method = draw.oriented_rectangle if len(bbox) > 4 else draw.rectangle fill_color = color + (100,) if fill else None draw_method(bbox, width=width, outline=color, fill=fill_color) @@ -396,7 +404,7 @@ def draw_bounding_boxes( if fill_labels: left, top, right, bottom = draw.textbbox((bbox[0] + margin, bbox[1] + margin), label, font=txt_font) draw.rectangle( - (left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=color + (left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=label_bg_color ) draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) # type: ignore[arg-type] @@ -545,7 +553,7 @@ def draw_keypoints( if visibility.shape != keypoints.shape[:-1]: raise ValueError( "keypoints and visibility must have the same dimensionality for num_instances and K. " - f"Got {visibility.shape = } and {keypoints.shape = }" + f"Got {visibility.shape=} and {keypoints.shape=}" ) original_dtype = image.dtype @@ -746,7 +754,7 @@ def _parse_colors( f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}." ) elif not isinstance(colors, (tuple, str)): - raise ValueError(f"`colors` must be a tuple or a string, or a list thereof, but got {colors}.") + raise ValueError(f"colors must be a tuple or a string, or a list thereof, but got {colors}.") elif isinstance(colors, tuple) and len(colors) != 3: raise ValueError(f"If passed as tuple, colors should be an RGB triplet, but got {colors}.") else: # colors specifies a single color for all objects