diff --git a/torchvision/utils.py b/torchvision/utils.py index 3ecc3aa3873..91ae37bb9e1 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -346,7 +346,7 @@ def draw_keypoints( if keypoints.ndim != 3: raise ValueError("keypoints must be of shape (num_instances, K, 2)") - ndarr = image.permute(1, 2, 0).numpy() + ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) img_kpts = keypoints.to(torch.int64).tolist()