Fix embedding jvp support by making embedding_renorm ignore forward mode AD #78560
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit 25309c2. 💚 Looks good so far! There are no failures yet. (This comment was automatically generated by Dr. CI.)
torch/nn/functional.py (outdated)

```python
with torch.no_grad():
    torch.embedding_renorm_(weight, input, max_norm, norm_type)
    set_fwd_grad_enabled(False)
```
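For context, a minimal sketch of what such a forward-grad guard could look like. The `_is_fwd_grad_enabled` / `_set_fwd_grad_enabled` bindings used below are assumed internal names (suggested by the `set_fwd_grad_enabled` call in the snippet above), not public PyTorch API, and the wrapper itself is hypothetical:

```python
import contextlib
import torch

@contextlib.contextmanager
def fwd_grad_disabled():
    # Hypothetical guard: assumes internal torch._C._is_fwd_grad_enabled /
    # torch._C._set_fwd_grad_enabled bindings exist; names and availability
    # may differ across PyTorch versions.
    prev = torch._C._is_fwd_grad_enabled()
    torch._C._set_fwd_grad_enabled(False)
    try:
        yield
    finally:
        torch._C._set_fwd_grad_enabled(prev)
```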
Has a context manager for enabling / disabling fwd grad been discussed before?
cc @albanD
Sorry for the confusion. Alban and I talked offline, and this is a functorch-only function that I'm guessing worked locally through some dynamic linking or something?

This will change in a moment to use detach instead, though it might still be nice to have an equivalent context manager for forward mode.

Sorry again about the confusion. The PR should now be in its final state, using detach and having tests. Locally, it fixes functorch and passes these tests (the forward AD test did fail before the change to torch).
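For reference, a minimal sketch of the detach-based approach mentioned above (an assumed shape, not necessarily the exact PR diff): `detach()` returns a view that shares storage with `weight` but carries neither backward-mode graph history nor a forward-mode tangent, so the in-place renorm updates the weights without recording either kind of gradient.

```python
import torch

def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
    # The detached view shares storage with `weight`, so the in-place
    # renorm still updates the real weights, but autograd (both backward
    # and forward mode) sees no differentiable use of `weight` here.
    torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)
```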
Thanks!
Thanks for the update!
@pytorchbot merge

Hey @samdow.

@pytorchbot revert his broke XLA (on CI and trunk), see https://hud.pytorch.org/pytorch/pytorch/commit/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9
❌ 🤖 pytorchbot command failed:
@pytorchmergebot revert -m "broke XLA (on CI and trunk), see https://hud.pytorch.org/pytorch/pytorch/commit/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9" -c ignoredsignal

@pytorchbot revert -m "broke XLA (on CI and trunk), see https://hud.pytorch.org/pytorch/pytorch/commit/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9" -c ignoredsignal
…orward mode AD (#78560)" This reverts commit ce7c7bb. Reverted #78560 on behalf of https://github.com/malfet due to broke XLA (on CI and trunk), see https://hud.pytorch.org/pytorch/pytorch/commit/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9
Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560) (#78560)

Summary: On functorch, we started seeing [embedding forward mode fail](pytorch/functorch#816). Looking into it, we found that [forward mode support was recently enabled for embedding](369d9f4), and since [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), the combination of forward mode and max_norm was never checked. What was happening is that `embedding_renorm` used `torch.no_grad()`, which only disables backward-mode AD, so functorch's jvp tests were still running forward-mode AD during the `embedding_renorm` call. This change ensures that forward mode is not used during the `embedding_renorm` call.

Pull Request resolved: #78560
Approved by: https://github.com/soulitzer, https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9

Reviewed By: b0noI

Differential Revision: D36865377

fbshipit-source-id: acd33447f1213136019644bab6874ba38ff3836f
…orward mode AD (#78560)" Summary: This reverts commit ce7c7bb. Reverted #78560 on behalf of https://github.com/malfet due to broke XLA (on CI and trunk), see https://hud.pytorch.org/pytorch/pytorch/commit/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9 Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d578197747db4efef7771bb2304d6bb54638eb73 Reviewed By: b0noI Differential Revision: D36882233 fbshipit-source-id: 1189a3da13824959f725cdcb784a01f3402faa78
@pytorchbot merge
Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560) (#78560)

Summary: On functorch, we started seeing [embedding forward mode fail](pytorch/functorch#816). Looking into it, we found that [forward mode support was recently enabled for embedding](369d9f4), and since [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), the combination of forward mode and max_norm was never checked. What was happening is that `embedding_renorm` used `torch.no_grad()`, which only disables backward-mode AD, so functorch's jvp tests were still running forward-mode AD during the `embedding_renorm` call. This change ensures that forward mode is not used during the `embedding_renorm` call.

Pull Request resolved: #78560
Approved by: https://github.com/soulitzer, https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b7cb4eae6b16ce312fb5e285aa82f227be6549a6

Reviewed By: b0noI

Differential Revision: D36907413

Pulled By: samdow

fbshipit-source-id: 52b6675ccb71287daed41f73a2d9f86b967d99b3
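For illustration, a minimal repro sketch of the failure mode described in the summary, using functorch's `jvp`; the shapes, values, and index tensor here are made up:

```python
import torch
import torch.nn.functional as F
from functorch import jvp

weight = torch.randn(10, 3)
tangent = torch.randn(10, 3)
idx = torch.tensor([0, 2, 4])

# Passing max_norm triggers the in-place embedding_renorm_ on the weight.
# Before this fix, that call ran under torch.no_grad() only, so the dual
# weight's forward-mode tangent still flowed into an op without forward
# AD support and the jvp call failed.
out, out_tangent = jvp(
    lambda w: F.embedding(idx, w, max_norm=1.0),
    (weight,),
    (tangent,),
)
```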
On functorch, we started seeing embedding forward mode fail. Looking into it, we found that forward mode support was recently enabled for embedding, and since max_norm doesn't work with gradcheck, the combination of forward mode and max_norm was never checked.

What was happening is that `embedding_renorm` used `torch.no_grad()`, which only disables backward-mode AD, so functorch's jvp tests were still running forward-mode AD during the `embedding_renorm` call. This change ensures that forward mode is not used during the `embedding_renorm` call.
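To make the key point concrete, here is a small self-contained sketch using the public `torch.autograd.forward_ad` API (values made up) showing that `torch.no_grad()` suppresses backward-mode graph recording but not forward-mode tangent propagation:

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3, requires_grad=True)
t = torch.ones(3)

with fwAD.dual_level():
    dual = fwAD.make_dual(x, t)
    with torch.no_grad():
        y = dual * 2  # backward mode is off inside this block...
    primal, tangent = fwAD.unpack_dual(y)
    print(y.grad_fn)  # None: no backward graph was recorded
    print(tangent)    # tensor([2., 2., 2.]): the tangent still propagated
```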