Add vmap support for torch.nn.functional.smooth_l1_loss
#98357
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98357
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 03570dd.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed 95d48c2 to 6d2f9c4
Force-pushed 6d2f9c4 to 7a13a9b
Force-pushed ecab6ac to b7817ea
`smooth_l1_loss` (`torch.nn.functional.smooth_l1_loss`)
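For reference, this is the per-element formula `smooth_l1_loss` computes, written out as a small sketch. The `smooth_l1_ref` helper name is hypothetical (for illustration only); the formula and the `beta=1.0` / `reduction="mean"` defaults match the documented behavior.

```python
import torch
import torch.nn.functional as F

# Reference implementation of smooth L1 loss, per element:
#   0.5 * d**2 / beta   if |d| < beta   (quadratic region)
#   |d| - 0.5 * beta    otherwise       (linear region), with d = input - target
# `smooth_l1_ref` is an illustrative helper name, not a PyTorch API.
def smooth_l1_ref(input, target, beta=1.0):
    d = input - target
    per_elem = torch.where(d.abs() < beta, 0.5 * d**2 / beta, d.abs() - 0.5 * beta)
    return per_elem.mean()  # default reduction is "mean"

x, t = torch.randn(6), torch.randn(6)
assert torch.allclose(smooth_l1_ref(x, t), F.smooth_l1_loss(x, t))
```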
```diff
@@ -282,6 +303,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   m.impl("nll_loss_backward", nll_loss_backward_decomposition);
   m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
   VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
+  VMAP_SUPPORT(smooth_l1_loss, smooth_l1_loss_batch_rule);
```
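With the batch rule registered, `vmap` can push a batch dimension through `smooth_l1_loss` directly. A small usage sketch (using `torch.func.vmap`, where the functorch API now lives):

```python
import torch
import torch.nn.functional as F
from torch.func import vmap

inputs = torch.randn(8, 5)
targets = torch.randn(8, 5)

# Each vmapped call reduces only over the non-batched dims, so this yields
# one scalar loss per batch element rather than a single mean over everything.
per_sample = vmap(lambda x, t: F.smooth_l1_loss(x, t))(inputs, targets)

# Equivalent (but slower) explicit loop:
looped = torch.stack([F.smooth_l1_loss(x, t) for x, t in zip(inputs, targets)])
assert torch.allclose(per_sample, looped)
```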
minor: Move this below the comment `// mse_loss_backwards uses a decomposition for its batch rule` (as it corresponds to `mse_loss`).
Thanks, I will reorder that in an updated commit.
We need to add a batching rule for the backward as well. We can proceed in either of the following ways:
I will add it in this PR, but I'll probably need some pointers here. What I think should be done is:
Is there an easier way to do it? Perhaps something similar to
Apologies for the delayed response. For adding a decomposition, you'll need to add it in `torch/_decomp/decompositions.py` (lines 356 to 362 at commit 95621b3).
And then register this decomposition with vmap, similar to this PR: https://github.com/pytorch/functorch/pull/866/files. However, I would recommend doing that in a separate PR.
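As a rough illustration of the decomposition approach being suggested, a backward for smooth L1 can be written in plain PyTorch ops, which lets vmap batch it without a dedicated batch rule. This is a sketch only: the function name and the reduction handling are assumptions for illustration, not the code that actually landed in `torch/_decomp/decompositions.py`.

```python
import torch
import torch.nn.functional as F

def smooth_l1_loss_backward_decomp(grad_output, input, target, reduction, beta):
    # Hypothetical decomposition sketch, not the upstream implementation.
    d = input - target
    # d/d(input) of the per-element loss: d/beta in the quadratic region,
    # sign(d) in the linear region.
    grad = torch.where(d.abs() < beta, d / beta, d.sign())
    if reduction == "mean":
        grad = grad / input.numel()  # mean reduction scales every gradient
    return grad * grad_output

# Quick check against autograd (mean reduction, beta=1.0):
x = torch.randn(4, requires_grad=True)
t = torch.randn(4)
F.smooth_l1_loss(x, t).backward()
manual = smooth_l1_loss_backward_decomp(torch.tensor(1.0), x.detach(), t, "mean", 1.0)
assert torch.allclose(manual, x.grad)
```

Because the body is built from `where`, `sign`, and elementwise arithmetic, all of which already have batch rules, registering it as a decomposition gives vmap support for free.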
@kshitij12345 Thanks! I haven't had the chance to wrap this PR up, but I will do it as soon as I can, and send a separate PR for the backward.
@kshitij12345 done!
Overall looks good! Just minor changes related to the naming of lambda arguments and cleaning up commented-out code :)
Thanks!
Thanks, appreciate your help throughout!
LGTM! Thank you @yhl48
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Partially fixes #97246 and #97558. Pull Request resolved: #98357 Approved by: https://github.com/kshitij12345
Follow-up of #98357 Pull Request resolved: #99429 Approved by: https://github.com/kshitij12345, https://github.com/zou3519