Skip to content
38 changes: 38 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from PIL import Image


def path_to_tensor(filepath):
from numpy import array as to_numpy_array
return torch.from_numpy(to_numpy_array(Image.open(filepath)))


class Tester(unittest.TestCase):

def test_make_grid_not_inplace(self):
Expand Down Expand Up @@ -41,6 +46,39 @@ def test_normalize_in_make_grid(self):
self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')

# def test_bboxes_not_inplace(self):
# t = torch.rand(5, 3, 10, 10) * 255
# t_clone = t.clone()
#
# TODO: this doesn't work; we need to pass in bboxes
# utils.draw_bounding_bboxes(t, draw_labels=False)
# self.assertTrue(torch.equal(t, t_clone), 'draw_bounding_bboxes modified tensor in-place')
#
# utils.draw_bounding_bboxes(t, draw_labels=True)
# self.assertTrue(torch.equal(t, t_clone), 'draw_bounding_bboxes modified tensor in-place')

def test_bboxes(self):
from numpy import array as to_numpy_array

IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
inp_img_path = os.path.join(IMAGE_DIR, 'a4.png')
out_img_path = os.path.join(IMAGE_DIR, 'b5.png')

inp_img_pil = path_to_tensor(inp_img_path)
bboxes = ((1, 2, 10, 18), (4, 8, 9, 11))
# TODO: maybe write the rectangle programatically in this test instead of
# statically loading output?
out_img_pil = path_to_tensor(out_img_path)

self.assertTrue(
torch.equal(
utils.draw_bounding_bboxes(inp_img_pil, bboxes, draw_labels=False),
out_img_pil,
),
'draw_bounding_bboxes returned an incorrect result',
)

@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_save_image(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
Expand Down
70 changes: 69 additions & 1 deletion torchvision/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional, List, Tuple, Text, BinaryIO
from typing import Union, Optional, List, Tuple, Text, BinaryIO, Sequence, Dict
import io
import pathlib
import torch
Expand Down Expand Up @@ -128,3 +128,71 @@ def save_image(
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)


BBox = Tuple[int, int, int, int]
BBoxes = Sequence[BBox]
Color = Tuple[int, int, int]
DEFAULT_COLORS: Sequence[Color]


def draw_bounding_boxes(
image: torch.Tensor,
bboxes: Union[BBoxes, Dict[str, Sequence[BBox]]],
colors: Optional[Dict[str, Color]] = None,
draw_labels: bool = None,
width: int = 1,
) -> torch.Tensor:
# TODO: docstring

bboxes_is_seq = BBoxes.__instancecheck__(bboxes)
# bboxes_is_dict is Dict[str, Sequence[BBox]].__instancecheck__(bboxes)
bboxes_is_dict = not bboxes_is_seq

if bboxes_is_seq:
# TODO: raise better Errors
if colors is not None:
# can't pass custom colors if bboxes is a sequence
raise Error
if draw_labels is True:
# can't draw labels if bboxes is a sequence
raise Error

if draw_labels is None:
if bboxes_is_seq:
draw_labels = False
else: # BBoxes.__instancecheck__(Dict[str, Sequence[BBox]])
draw_labels = True

# colors: Union[Sequence[Color], Dict[str, Color]]
if colors is None:
# TODO: default to one of @pmeir's suggestions as a seq
colors_: Sequence[Color] = colors
else:
colors_: Dict[str, Color] = colors

from PIL import Image, ImageDraw
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(
1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
draw = ImageDraw.Draw(im)

if bboxes_is_dict:
if Sequence[Color].__instancecheck__(colors_):
# align the colors seq with the bbox classes
colors = dict(zip(sorted(bboxes.keys()), colors_))

for bbox_class, bbox in enumerate(bboxes.items()):
draw.rectangle(bbox, outline=colors_[bbox_class], width=width)
if draw_labels:
# TODO: this will probably overlap with the bbox
# hard-code in a margin for the label?
label_tl_x, label_tl_y, _, _ = bbox
draw.text((label_tl_x, label_tl_y), bbox_class)
else: # bboxes_is_seq
for i, bbox in enumerate(bboxes):
draw.rectangle(bbox, outline=colors_[i], width=width)

from numpy import array as to_numpy_array
return torch.from_numpy(to_numpy_array(im))