Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_draw_boxes_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white")

path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
Expand Down
12 changes: 9 additions & 3 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,16 @@ def draw_bounding_boxes(

txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)

if colors is None:
if labels is None:
colors = _generate_color_palette(len(img_boxes))
else:
assert len(labels) == len(img_boxes)
label_color_map = dict(zip(labels, _generate_color_palette(len(labels))))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since now we agree that len(labels) == len(img_boxes) Line 207 and 204 can be simplified.

colors = [label_color_map[label] for label in labels]

for i, bbox in enumerate(img_boxes):
if colors is None:
color = None
elif isinstance(colors, list):
if isinstance(colors, list):
color = colors[i]
else:
color = colors
Expand Down