
Add vmap support for torch.nn.functional.smooth_l1_loss #98357

Closed
yhl48 wants to merge 7 commits

Conversation

@yhl48 (Contributor) commented Apr 4, 2023

Partially fixes #97246 and #97558.
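
For context, a minimal sketch of the usage this PR enables; the shapes and values below are illustrative and not taken from the PR or its tests:

import torch
import torch.nn.functional as F

# One (input, target) pair per batch element.
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)

# With the batch rule registered, vmap maps smooth_l1_loss over dim 0
# instead of hitting the (currently disabled) vmap fallback.
per_sample_loss = torch.vmap(F.smooth_l1_loss)(inputs, targets)
print(per_sample_loss.shape)  # torch.Size([8])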

@pytorch-bot (bot) commented Apr 4, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98357

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 03570dd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

yhl48 changed the title from Batch smooth l1 loss to Add vmap support for smooth_l1_loss on Apr 4, 2023
yhl48 marked this pull request as draft on April 4, 2023 23:07
yhl48 changed the title from Add vmap support for smooth_l1_loss to Add vmap support for torch.nn.functional.smooth_l1_loss on Apr 4, 2023
@@ -282,6 +303,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
m.impl("nll_loss_backward", nll_loss_backward_decomposition);
m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
VMAP_SUPPORT(smooth_l1_loss, smooth_l1_loss_batch_rule);
Collaborator:

minor: Move this below the comment // mse_loss_backwards uses a decomposition for its batch rule (as it corresponds to mse_loss).

Contributor Author (@yhl48):
Thanks, I will reorder that in an updated commit.

@kshitij12345 (Collaborator) commented:

RuntimeError: aten::smooth_l1_loss_backward hit the vmap fallback which is currently disabled

We need to add a batching rule for smooth_l1_loss_backward.

We can proceed in either of two ways:

  • Add it in this PR, or
  • keep the relevant failing test as xfail and pursue it in a separate PR.

@yhl48 (Contributor Author) commented Apr 6, 2023

(quoting @kshitij12345) We need to add a batching rule for smooth_l1_loss_backward: either add it in this PR, or keep the relevant failing test as xfail and pursue it in a separate PR.

I will add it in this PR, but I probably need some pointers here. What I think needs to be done is:

  1. Implement a function smooth_l1_loss_backward
  2. Add that function to TORCH_LIBRARY_IMPL

Is there an easier way to do it? Perhaps something similar to mse_loss_backward, as noted in the code comment // mse_loss_backwards uses a decomposition for its batch rule?

@kshitij12345

@kshitij12345 (Collaborator) commented:

Apologies for the delayed response.

To add a decomposition, you'll need to add it in the following file:

@register_decomposition(aten.mse_loss_backward)
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
    return norm * (input - target) * grad_output

And then register this decomposition with vmap, similar to this PR https://github.com/pytorch/functorch/pull/866/files

However, I would recommend doing that in a separate PR.
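
As an illustration only, an analogous decomposition for smooth_l1_loss_backward might look roughly like the following, assuming the same imports and helpers as the mse_loss_backward snippet above. This is a hypothetical sketch following that pattern, not the code that landed in this or any follow-up PR:

@register_decomposition(aten.smooth_l1_loss_backward)
@pw_cast_for_opmath
def smooth_l1_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int, beta: float
):
    # Mean reduction scales the gradient by 1/numel; sum/none leave it unscaled.
    norm = 1.0 / input.numel() if reduction == Reduction.MEAN.value else 1.0
    x = input - target
    # Derivative of smooth L1: x / beta inside the quadratic region, sign(x) outside.
    # (The beta == 0 edge case is not handled in this sketch.)
    grad = torch.where(torch.abs(x) < beta, x / beta, torch.sign(x))
    return norm * grad * grad_output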

@yhl48 (Contributor Author) commented Apr 11, 2023

@kshitij12345 Thanks! I haven't had the chance to wrap this PR up, but I will do it as soon as I can and send a separate PR for smooth_l1_loss_backward.

yhl48 marked this pull request as ready for review on April 13, 2023 14:12
@yhl48 (Contributor Author) commented Apr 13, 2023

@kshitij12345 done!

dagitses added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Apr 13, 2023
@kshitij12345 (Collaborator) left a comment:

Overall looks good! Just minor changes related to naming of lambda arguments and cleaning up commented code :)

Thanks!

Review comments on aten/src/ATen/functorch/BatchRulesLoss.cpp (outdated, resolved)
@yhl48 (Contributor Author) commented Apr 14, 2023

Thanks, appreciate your help throughout!

@kshitij12345 (Collaborator) left a comment:

LGTM! Thank you @yhl48

kshitij12345 added the release notes: functorch label (release notes category; pertaining to torch.func or pytorch/functorch) on Apr 14, 2023
@kshitij12345 (Collaborator) commented:

@pytorchbot merge

pytorch-bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Apr 14, 2023
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

ZainRizvi pushed a commit that referenced this pull request Apr 19, 2023
pytorchmergebot pushed a commit that referenced this pull request Apr 28, 2023
Labels
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
merging
open source
release notes: functorch (release notes category; pertaining to torch.func or pytorch/functorch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

functorch roll-up issue for 2.1
5 participants