diff --git a/torchvision/utils.py b/torchvision/utils.py index ea197d7386e..90657acee4f 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -180,6 +180,10 @@ def draw_bounding_boxes( raise ValueError("Pass individual images, not batches") ndarr = image.permute(1, 2, 0).numpy() + # allow single-channel-images + # shape: (1, H, W) with C = 1 + if ndarr.shape[-1] == 1: + ndarr = np.tile(ndarr, (1, 1, 3)) img_to_draw = Image.fromarray(ndarr) img_boxes = boxes.to(torch.int64).tolist()