diff --git a/torchvision/utils.py b/torchvision/utils.py index 3a17a46e26e..3ecc3aa3873 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -189,7 +189,7 @@ def draw_bounding_boxes( if image.size(0) == 1: image = torch.tile(image, (3, 1, 1)) - ndarr = image.permute(1, 2, 0).numpy() + ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) img_boxes = boxes.to(torch.int64).tolist()