
fix inference_mode with torch.compile #101219

Closed · 9 commits

Conversation

@bdhirsh (Contributor) commented May 11, 2023

It looks like inference_mode wasn't playing well with functionalization.

If you run torch.compile on a function whose inputs are tensors created outside of inference mode, then when we create functional tensor wrappers for those inputs during compilation, the wrappers need to properly mirror whether or not the original tensor is an inference tensor.

Hopefully fixes #101151

Stack from ghstack (oldest at bottom):

@pytorch-bot bot commented May 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101219

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a469373:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request May 11, 2023
ghstack-source-id: 96a941f43bbc85d927166fd541d45fb57dc416d0
Pull Request resolved: #101219
// E.g. when running torch.compile under inference mode, we need to make sure that
// any inputs created outside of inference mode (which are not inference tensors)
// get functional wrappers that are also not inference tensors.
version_counter_ = value_.unsafeGetTensorImpl()->version_counter();
Collaborator:

Wouldn't this access to the version_counter raise an error on inference Tensors?

@bdhirsh (Contributor, Author):

Talked offline - we're not accessing the version counter, just straight up copying the struct onto the wrapper.

Also - we copy the dispatch keyset from the inner tensor onto the wrapper, so if the inner tensor has the Autograd dispatch key (because it was created outside of inference mode), then the wrapper will as well (even though it was created in inference mode).
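The user-visible property the wrapper has to mirror can be checked with `Tensor.is_inference()` (an illustrative sketch, not code from this PR):

```python
import torch

# a tensor created outside inference mode carries Autograd dispatch keys
outside = torch.ones(2)
assert not outside.is_inference()

# a tensor created inside inference mode is an inference tensor:
# no version counter tracking, no autograd metadata
with torch.inference_mode():
    inside = torch.ones(2)
assert inside.is_inference()

# merely *using* `outside` under inference mode does not make it an
# inference tensor; that is the property the functional wrapper must copy
with torch.inference_mode():
    assert not outside.is_inference()
```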

@albanD (Collaborator) left a comment:
Sounds ok!

bdhirsh added a commit that referenced this pull request May 12, 2023
ghstack-source-id: 7c24b4ff106383a8840c86c6a24a57eb92a0676f
Pull Request resolved: #101219
@ezyang (Contributor) left a comment:
Thank you for fixing this!!! What a pain haha

(@ezyang posted the same review comment three more times.)

@Chillee (Contributor) commented May 15, 2023

Wow Ed is really thankful for this fix.

bdhirsh added a commit that referenced this pull request May 17, 2023
ghstack-source-id: 5bc15b0c2ec4cae3db00f095e875fbb0bee21ddc
Pull Request resolved: #101219
@bdhirsh (Contributor, Author) commented May 17, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 17, 2023
@pytorchmergebot (Collaborator):

Merge failed

Reason: This PR needs a label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@PaliC (Contributor) commented May 19, 2023

@pytorchbot revert -c "nosignal" -m "breaking inductor tests"
Add ciflow/inductor to run the tests on the pull request.

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator):

@bdhirsh your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request May 19, 2023
This reverts commit 11f7ae1.

Reverted #101219 on behalf of https://github.com/PaliC due to breaking inductor tests ([comment](#101219 (comment)))
@bdhirsh bdhirsh reopened this May 19, 2023
@github-actions github-actions bot requested review from albanD and ezyang May 19, 2023 19:36
pytorchmergebot added a commit that referenced this pull request May 19, 2023
#100570)"

This reverts commit 1fabee3.

Reverted #100570 on behalf of https://github.com/PaliC due to breaking inductor tests along with #101219 ([comment](#100570 (comment)))
pytorchmergebot referenced this pull request May 19, 2023
Fixes #100977

This will hopefully fix this error (from [issue](#99616))

This PR fixes an internal model: we were running an inductor inference graph, but `torch.is_grad_enabled()` was True, which caused an error inside the inference graph when we encountered an out= operator.

I haven't been able to create a smaller repro - before landing this, I want to create a smaller repro to convince myself of why we need to separate out these guards.

Pull Request resolved: #100570
Approved by: https://github.com/ezyang
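For context, `torch.is_grad_enabled()` is the guard state in question, and it is distinct from tensor inference-ness (an illustrative sketch, not code from the referenced PR):

```python
import torch

# default: grad mode is on
assert torch.is_grad_enabled()

# no_grad disables grad mode, but tensors created inside it
# are still ordinary (non-inference) tensors
with torch.no_grad():
    assert not torch.is_grad_enabled()
    a = torch.ones(2)
assert not a.is_inference()

# inference_mode also disables grad mode, and additionally makes
# newly created tensors inference tensors
with torch.inference_mode():
    assert not torch.is_grad_enabled()
    b = torch.ones(2)
assert b.is_inference()
```

A compiled graph specialized for one of these states can misbehave if run under another, which is why the guards need to distinguish them.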
bdhirsh added a commit that referenced this pull request May 22, 2023
ghstack-source-id: fc724bccbfeb8315dc90599b7bd5e8299a11a652
Pull Request resolved: #101219
bdhirsh added a commit that referenced this pull request May 23, 2023
ghstack-source-id: 4aef060e6ca8055b49b354d2f8f1ace49f962a01
Pull Request resolved: #101219
bdhirsh added a commit that referenced this pull request May 24, 2023
ghstack-source-id: 2061c74c08bd4a763a31ab973b1e113a639bd840
Pull Request resolved: #101219
@bdhirsh (Contributor, Author) commented May 24, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/416/head branch June 8, 2023 15:48
Labels
ciflow/inductor · ciflow/trunk · Merged · release notes: composability · Reverted