From b626b561fc58014d30fa93bb400db4fbc101e526 Mon Sep 17 00:00:00 2001 From: ABD-01 Date: Tue, 19 Oct 2021 17:47:16 +0000 Subject: [PATCH 1/3] Updated utils.py for generating random colors for bounding boxes --- torchvision/utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torchvision/utils.py b/torchvision/utils.py index a71e0f234b4..c03488c612a 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -199,10 +199,15 @@ def draw_bounding_boxes( txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) + if colors is None: + if labels is None: + colors = _generate_color_palette(len(img_boxes)) + else: + label_color_map = dict(zip(labels, _generate_color_palette(len(labels)))) + colors = [label_color_map[label] for label in labels] + for i, bbox in enumerate(img_boxes): - if colors is None: - color = None - elif isinstance(colors, list): + if isinstance(colors, list): color = colors[i] else: color = colors From 9dcd960ff7433be5448ff5123c31260dca0492e4 Mon Sep 17 00:00:00 2001 From: Muhammed Abdullah Date: Tue, 26 Oct 2021 06:29:14 +0000 Subject: [PATCH 2/3] Added colors=white in draw_boxes test --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 1439a974368..b08810895bf 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -122,7 +122,7 @@ def test_draw_boxes_vanilla(): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img_cp = img.clone() boxes_cp = boxes.clone() - result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7) + result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png") if not os.path.exists(path): From 10c8ce86d999393497e6fb5eca3cae64db10c653 Mon Sep 17 00:00:00 2001 From: ABD-01 Date: Wed, 27 Oct 2021 18:01:50 +0000 Subject: [PATCH 3/3] Added Assertion --- torchvision/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/utils.py b/torchvision/utils.py index 8f52a456aac..c0ae4816abf 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -203,6 +203,7 @@ def draw_bounding_boxes( if labels is None: colors = _generate_color_palette(len(img_boxes)) else: + assert len(labels) == len(img_boxes) label_color_map = dict(zip(labels, _generate_color_palette(len(labels)))) colors = [label_color_map[label] for label in labels]