diff --git a/test/test_utils.py b/test/test_utils.py index 90d4fdf8552..727208ec16c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,5 @@ import os +import re import sys import tempfile from io import BytesIO @@ -168,6 +169,13 @@ def test_draw_invalid_boxes(): utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong) +def test_draw_boxes_warning(): + img = torch.full((3, 100, 100), 255, dtype=torch.uint8) + + with pytest.warns(UserWarning, match=re.escape("Argument 'font_size' will be ignored since 'font' is not set.")): + utils.draw_bounding_boxes(img, boxes, font_size=11) + + @pytest.mark.parametrize( "colors", [ diff --git a/torchvision/utils.py b/torchvision/utils.py index 4737a047327..e82752ab28b 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -164,7 +164,7 @@ def draw_bounding_boxes( fill: Optional[bool] = False, width: int = 1, font: Optional[str] = None, - font_size: int = 10, + font_size: Optional[int] = None, ) -> torch.Tensor: """ @@ -223,6 +223,13 @@ def draw_bounding_boxes( colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] + if font is None: + if font_size is not None: + warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.") + txt_font = ImageFont.load_default() + else: + txt_font = ImageFont.truetype(font=font, size=font_size or 10) + # Handle Grayscale images if image.size(0) == 1: image = torch.tile(image, (3, 1, 1)) @@ -236,8 +243,6 @@ def draw_bounding_boxes( else: draw = ImageDraw.Draw(img_to_draw) - txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) - for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] if fill: fill_color = color + (100,)