Skip to content
Merged
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 27 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,33 @@ def test_draw_boxes_with_coloured_label_backgrounds():
assert_equal(result, expected)


@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
def test_draw_boxes_with_coloured_label_text_boxes():
    """Compare filled, individually coloured label boxes against a stored reference image.

    Draws four labelled boxes with distinct outline, label-text and
    label-background colours (``fill_labels=True``) and checks the rendered
    tensor against the PNG stored under ``assets/fakedata``.

    NOTE(review): ``boxes`` is not defined in this function; presumably it is
    a module-level fixture of the test file -- confirm.
    """
    # 3x100x100 uint8 canvas with every value set to 255.
    canvas = torch.full((3, 100, 100), 255, dtype=torch.uint8)

    drawn = utils.draw_bounding_boxes(
        canvas,
        boxes,
        labels=["a", "b", "c", "d"],
        colors=["green", "#FF00FF", (0, 255, 0), "red"],
        fill=True,
        label_colors=["green", "red", (0, 255, 0), "#FF00FF"],
        label_background_colors=["white", "black", "yellow", "blue"],
        fill_labels=True,
    )

    # The reference image lives next to this test module.
    test_dir = os.path.dirname(os.path.abspath(__file__))
    reference_path = os.path.join(
        test_dir, "assets", "fakedata", "draw_boxes_different_label_background_colors.png"
    )

    # Move the last axis of the loaded array to the front before comparing
    # (presumably HWC -> CHW to match the drawn tensor's layout).
    reference_pixels = np.array(Image.open(reference_path))
    expected = torch.as_tensor(reference_pixels).permute(2, 0, 1)

    assert_equal(drawn, expected)


@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
def test_draw_rotated_boxes():
img = torch.full((3, 500, 500), 255, dtype=torch.uint8)
Expand Down
20 changes: 14 additions & 6 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch
from PIL import __version__ as PILLOW_VERSION_STRING, Image, ImageColor, ImageDraw, ImageFont


__all__ = [
"_Image_fromarray",
"make_grid",
Expand Down Expand Up @@ -293,6 +292,7 @@ def draw_bounding_boxes(
font: Optional[str] = None,
font_size: Optional[int] = None,
label_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
label_background_colors: Optional[Union[list[Union[str, tuple[int, int, int]]], str, tuple[int, int, int]]] = None,
fill_labels: bool = False,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -320,7 +320,10 @@ def draw_bounding_boxes(
font_size (int): The requested font size in points.
label_colors (color or list of colors, optional): Colors for the label text. See the description of the
`colors` argument for details. Defaults to the same colors used for the boxes, or to black if ``fill_labels`` is True.
fill_labels (bool): If `True` fills the label background with specified box color (from the ``colors`` parameter). Default: False.
label_background_colors (color or list of colors, optional): Colors for the label text box fill. Defaults to the
same colors used for the boxes. Ignored when ``fill_labels`` is False.
fill_labels (bool): If `True` fills the label background with specified color (from the ``label_background_colors`` parameter,
or from the ``colors`` parameter if not specified). Default: False.

Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted.
Expand Down Expand Up @@ -362,6 +365,11 @@ def draw_bounding_boxes(
else:
label_colors = colors.copy() # type: ignore[assignment]

if fill_labels and label_background_colors:
label_background_colors = _parse_colors(label_background_colors, num_objects=num_boxes)
else:
label_background_colors = colors.copy() # type: ignore[assignment]

if font is None:
if font_size is not None:
warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.")
Expand All @@ -385,7 +393,7 @@ def draw_bounding_boxes(
else:
draw = _ImageDrawTV(img_to_draw)

for bbox, color, label, label_color in zip(img_boxes, colors, labels, label_colors): # type: ignore[arg-type]
for bbox, color, label, label_color, label_bg_color in zip(img_boxes, colors, labels, label_colors, label_background_colors): # type: ignore[arg-type]
draw_method = draw.oriented_rectangle if len(bbox) > 4 else draw.rectangle
fill_color = color + (100,) if fill else None
draw_method(bbox, width=width, outline=color, fill=fill_color)
Expand All @@ -396,7 +404,7 @@ def draw_bounding_boxes(
if fill_labels:
left, top, right, bottom = draw.textbbox((bbox[0] + margin, bbox[1] + margin), label, font=txt_font)
draw.rectangle(
(left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=color
(left - box_margin, top - box_margin, right + box_margin, bottom + box_margin), fill=label_bg_color
)
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=label_color, font=txt_font) # type: ignore[arg-type]

Expand Down Expand Up @@ -545,7 +553,7 @@ def draw_keypoints(
if visibility.shape != keypoints.shape[:-1]:
raise ValueError(
"keypoints and visibility must have the same dimensionality for num_instances and K. "
f"Got {visibility.shape = } and {keypoints.shape = }"
f"Got {visibility.shape=} and {keypoints.shape=}"
)

original_dtype = image.dtype
Expand Down Expand Up @@ -746,7 +754,7 @@ def _parse_colors(
f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}."
)
elif not isinstance(colors, (tuple, str)):
raise ValueError(f"`colors` must be a tuple or a string, or a list thereof, but got {colors}.")
raise ValueError(f"colors must be a tuple or a string, or a list thereof, but got {colors}.")
elif isinstance(colors, tuple) and len(colors) != 3:
raise ValueError(f"If passed as tuple, colors should be an RGB triplet, but got {colors}.")
else: # colors specifies a single color for all objects
Expand Down