
Fix embedding jvp support by making embedding_renorm ignore forward mode AD #78560


Closed
samdow wants to merge 3 commits into master from fix_embedding

Conversation

samdow
Contributor

@samdow samdow commented May 31, 2022

On functorch, we started seeing embedding forward-mode AD fail. Looking into it, we found that embedding recently had forward-mode support enabled, but forward mode with embedding and max_norm doesn't work with gradcheck, so that combination was never checked.

What was happening is that `embedding_renorm` ran under `torch.no_grad()`, which only disables backward-mode AD, so functorch's jvp tests were still exercising forward-mode AD during the `embedding_renorm` call. This change makes sure we don't use forward mode during the `embedding_renorm` call.
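For context, here is a minimal sketch of the failure mode (illustrative, not code from this PR): `torch.no_grad()` only suspends backward-mode AD, so a dual tensor built with `torch.autograd.forward_ad` keeps propagating its tangent inside the block.

```python
import torch
import torch.autograd.forward_ad as fwAD

# torch.no_grad() only disables backward-mode AD; forward-mode
# tangents keep propagating inside the block.
with fwAD.dual_level():
    x = fwAD.make_dual(torch.randn(3), torch.ones(3))
    with torch.no_grad():
        y = x * 2  # forward-mode AD still runs here
    # Tangent is tensor([2., 2., 2.]), not None: no_grad did not stop it.
    print(fwAD.unpack_dual(y).tangent)
```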

@facebook-github-bot
Contributor

facebook-github-bot commented May 31, 2022


✅ No Failures (0 Pending)

As of commit 25309c2 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@zou3519 zou3519 requested review from soulitzer and zou3519 May 31, 2022 19:21
```python
with torch.no_grad():
    torch.embedding_renorm_(weight, input, max_norm, norm_type)
set_fwd_grad_enabled(False)
```
Contributor

Has a context manager for enabling / disabling fwd grad been discussed before?

Contributor

cc @albanD

Contributor Author

Sorry for the confusion. Alban and I talked offline: this is a functorch-only function, and I'm guessing it worked locally through some dynamic linking or similar.

This will change in a moment to use detach instead, though it might still be nice to have an equivalent of `torch.no_grad()` for forward mode.
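For illustration, a minimal sketch of the detach-based approach (the helper name mirrors `torch/nn/functional.py`; treat the exact signature as an assumption rather than the PR's final diff): `detach()` returns a tensor that shares storage with `weight` but carries no autograd metadata, so the in-place renorm is invisible to both backward- and forward-mode AD while the update still lands in the original storage.

```python
import torch

# Sketch only: weight.detach() shares storage with weight but has no
# backward graph and no forward-mode tangent, so the in-place renorm
# runs outside both AD modes.
def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
    torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)
```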

@samdow
Contributor Author

samdow commented May 31, 2022

Sorry again about the confusion. The PR should now be in its final state, using detach and including tests. Locally, it fixes functorch and passes these tests (the forward AD test did fail before the change to torch).
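As an illustration of the behavior the tests cover (a hypothetical repro, not the PR's test code): taking a jvp through `F.embedding` with `max_norm` set should now work.

```python
import torch
import torch.autograd.forward_ad as fwAD
import torch.nn.functional as F

# Hypothetical repro: forward-mode AD through embedding with max_norm.
# Before the fix, the renorm step still ran under forward-mode AD and failed.
weight = torch.randn(10, 4)
idx = torch.tensor([0, 2, 5])
with fwAD.dual_level():
    dual_weight = fwAD.make_dual(weight, torch.randn_like(weight))
    out = F.embedding(idx, dual_weight, max_norm=1.0)
    print(fwAD.unpack_dual(out).tangent.shape)  # torch.Size([3, 4])
```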

Contributor

@soulitzer soulitzer left a comment

Thanks!

Collaborator

@albanD albanD left a comment

Thanks for the update!

@samdow
Contributor Author

samdow commented Jun 2, 2022

@pytorchbot merge

@github-actions
Contributor

github-actions bot commented Jun 2, 2022

Hey @samdow.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@samdow samdow added the release notes: nn and topic: bug fixes labels Jun 2, 2022
@malfet
Contributor

malfet commented Jun 2, 2022

@pytorchbot revert
@pytorch-bot

pytorch-bot bot commented Jun 2, 2022

❌ 🤖 pytorchbot command failed:

```
@pytorchbot revert: error: the following arguments are required: -m/--message

usage: @pytorchbot revert -m MESSAGE
                          [-c {nosignal,ignoredsignal,landrace,weird,ghfirst}]
```

Try `@pytorchbot help` for more info.

@malfet
Contributor

malfet commented Jun 2, 2022

@pytorchmergebot revert -m "broke XLA (on CI and trunk), see https://hud.pytorch.org/pytorch/pytorch/commit/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9" -c ignoredsignal

@malfet
Contributor

malfet commented Jun 2, 2022

@pytorchbot revert -m "broke XLA (on CI and trunk), see https://hud.pytorch.org/pytorch/pytorch/commit/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9" -c ignoredsignal

pytorchmergebot added a commit that referenced this pull request Jun 2, 2022
@samdow samdow reopened this Jun 2, 2022
facebook-github-bot pushed a commit that referenced this pull request Jun 2, 2022
Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560) (#78560)

Summary:
On functorch, we started seeing [embedding forward mode fail](pytorch/functorch#816). Looking into it, we found that [embedding recently had forward-mode support enabled](369d9f4), but forward mode with embedding and [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), so that combination was never checked.

What was happening is that `embedding_renorm` ran under `torch.no_grad()`, which only disables backward-mode AD, so functorch's jvp tests were still exercising forward-mode AD during the `embedding_renorm` call. This change makes sure we don't use forward mode during the `embedding_renorm` call.

Pull Request resolved: #78560
Approved by: https://github.com/soulitzer, https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9

Reviewed By: b0noI

Differential Revision: D36865377

fbshipit-source-id: acd33447f1213136019644bab6874ba38ff3836f
facebook-github-bot pushed a commit that referenced this pull request Jun 3, 2022
Revert "Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560)"

Summary:
This reverts commit ce7c7bb.

Reverted #78560 on behalf of https://github.com/malfet due to broke XLA (on CI and trunk), see https://hud.pytorch.org/pytorch/pytorch/commit/ce7c7bb2a9b21c8a05eedc6b43b277654064a1c9

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d578197747db4efef7771bb2304d6bb54638eb73

Reviewed By: b0noI

Differential Revision: D36882233

fbshipit-source-id: 1189a3da13824959f725cdcb784a01f3402faa78
@samdow
Contributor Author

samdow commented Jun 3, 2022

@pytorchbot merge

facebook-github-bot pushed a commit that referenced this pull request Jun 3, 2022
Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560) (#78560)

Summary:
On functorch, we started seeing [embedding forward mode fail](pytorch/functorch#816). Looking into it, we found that [embedding recently had forward-mode support enabled](369d9f4), but forward mode with embedding and [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), so that combination was never checked.

What was happening is that `embedding_renorm` ran under `torch.no_grad()`, which only disables backward-mode AD, so functorch's jvp tests were still exercising forward-mode AD during the `embedding_renorm` call. This change makes sure we don't use forward mode during the `embedding_renorm` call.

Pull Request resolved: #78560
Approved by: https://github.com/soulitzer, https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b7cb4eae6b16ce312fb5e285aa82f227be6549a6

Reviewed By: b0noI

Differential Revision: D36907413

Pulled By: samdow

fbshipit-source-id: 52b6675ccb71287daed41f73a2d9f86b967d99b3
@github-actions github-actions bot deleted the fix_embedding branch February 16, 2024 01:59