-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
draw_keypoints() float support #8276
draw_keypoints() float support #8276
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8276
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 4 Unrelated FailuresAs of commit c9fd6ca with merge base c8c3839 (): NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the PR @GsnMithra !
I made a few comments below but the PR looks great overall, so I took the liberty to address these comments myself.
test/test_utils.py
Outdated
keypoints_cp = keypoints.clone() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem to be used anywhere
keypoints_cp = keypoints.clone() |
test/test_utils.py
Outdated
@@ -247,6 +247,24 @@ def test_draw_segmentation_masks(colors, alpha, device): | |||
torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0) | |||
|
|||
|
|||
def test_draw_keypoints_dtypes(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move that test down below, so that it is located next to the rest of the draw_keypoints
tests. Right now it's in the middle of the draw_segmentation_mask
tests which is a bit confusing.
test/test_utils.py
Outdated
@@ -247,6 +247,24 @@ def test_draw_segmentation_masks(colors, alpha, device): | |||
torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0) | |||
|
|||
|
|||
def test_draw_keypoints_dtypes(): | |||
image_uint8 = torch.full((3, 100, 100), 0, dtype=torch.uint8) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This image should not be just zeros, otherwise it will be easy to miss subtle bugs. This should be the same as for the other test i.e.:
image_uint8 = torch.full((3, 100, 100), 0, dtype=torch.uint8) | |
torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8) |
and in fact you'll see that there's a bug because the test will fail
torchvision/utils.py
Outdated
@@ -428,7 +432,7 @@ def draw_keypoints( | |||
width=width, | |||
) | |||
|
|||
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) | |||
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=original_dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just calling .to
here won't scale the float images back down to [0, 1]
so we would end up with a flaot image in [0, 255]
(that's why the test would fail). It's best to just call to_dtype()
for both the uint8 <-> float
conversions.
Hey @NicolasHug, Thank you for pointing out my mistakes. While it might be second nature for you to find these bugs, I am still in the learning process. I apologize for making you go through my code once again. Thanks again. |
No problem at all @GsnMithra thank you for the PR |
Reviewed By: vmoens Differential Revision: D55062794 fbshipit-source-id: 1a9484e4959fef604153857cc7d4a6d7262cbea9 Co-authored-by: Nicolas Hug <contact@nicolas-hug.com> Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
Follow-up PR: #8150
Issue: #8138
Hey there!
I've added functionality to the draw_keypoints() method, allowing it to handle both uint8 and float32 image types.
I welcome any feedback you may have on these changes.
Thank you!