
Add Context Manager for Disabling Multithreading in Backwards, use in aot autograd #86245

Closed
wants to merge 6 commits into from

Conversation


@eellison eellison commented Oct 4, 2022

Stack from ghstack (oldest at bottom):

We were running into a few issues with multithreaded backward in aot_autograd, such as #86136, and `FakeTensorMode` getting into a weird state because functions were not executed completely sequentially. The multithreading is lost in translation when we trace out the backward anyway, and it adds a lot of additional complexity.
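The context manager added by this PR is exposed as `torch.autograd.set_multithreading_enabled` (see the doc diff later in this thread). A minimal sketch of how it can be used to force a single-threaded backward pass; the tensor values here are illustrative only:

```python
import torch

# Disable the autograd engine's multithreaded backward for this block;
# backward runs on the calling thread only. Outside the block, the
# previous multithreading setting is restored.
x = torch.ones(3, requires_grad=True)
with torch.autograd.set_multithreading_enabled(False):
    loss = (x * 2).sum()
    loss.backward()

# d(sum(2*x))/dx is 2 for each element
print(x.grad)
```

This mirrors the existing `set_grad_enabled` pattern, which is why the PR touches `torch/autograd/grad_mode.py` and `c10/core/AutogradState.h`.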


pytorch-bot bot commented Oct 4, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86245

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 1 Pending

As of commit 67dd731:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Oct 4, 2022
… aot autograd

ghstack-source-id: e59889bd05a2df14235ca9849ebe42e44601eb0f
Pull Request resolved: #86245
Review threads (resolved): torch/autograd/grad_mode.py, torch/csrc/autograd/init.cpp

ezyang commented Oct 5, 2022

This looks basically fine but deferring to @albanD for final review.

@albanD (Collaborator) left a comment:

Sounds good to me.

Review threads (resolved): torch/csrc/autograd/init.cpp, torch/autograd/grad_mode.py, c10/core/AutogradState.h
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 5, 2022
eellison added a commit that referenced this pull request Oct 5, 2022
… aot autograd

ghstack-source-id: a608e1cba5489ab9e28a9c5e511567f89bf827c1
Pull Request resolved: #86245
eellison added a commit that referenced this pull request Oct 5, 2022
… aot autograd

ghstack-source-id: 43cdfe558be6e25bc097c98e1858aa992619fb44
Pull Request resolved: #86245
@albanD (Collaborator) left a comment:

Small error in doc, good to go otherwise!

@@ -268,7 +268,7 @@ Examples::
set_grad_enabled
is_grad_enabled
inference_mode
is_inference_mode_enabled
set_multithreading_enabled
Collaborator (inline review): This should be reverted.

Contributor Author: good catch, thanks

eellison added a commit that referenced this pull request Oct 5, 2022
… aot autograd

ghstack-source-id: c9480c096fe195b176079deeda239de7ca56d3b8
Pull Request resolved: #86245
@albanD (Collaborator) left a comment:

SGTM!

eellison added a commit that referenced this pull request Oct 5, 2022
… aot autograd

ghstack-source-id: b22baf5bbc1d8b3e9ab80eb1ddabdb7306f3c19f
Pull Request resolved: #86245
@eellison eellison removed the ciflow/trunk Trigger trunk jobs on your pull request label Oct 5, 2022

eellison commented Oct 5, 2022

@pytorchbot rebase

@pytorchmergebot

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot

Successfully rebased gh/eellison/335/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/86245)

pytorchmergebot pushed a commit that referenced this pull request Oct 5, 2022
… aot autograd

ghstack-source-id: ba4a24c07c4028a214e19e37a4eea35ce34c57ae
Pull Request resolved: #86245

eellison commented Oct 6, 2022

@pytorchbot merge

@pytorchmergebot

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered without a flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!


github-actions bot commented Oct 6, 2022

Hey @eellison.
You've committed this PR, but it does not have both a 'release notes: ...' and a 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc.) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc.).
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Oct 7, 2022
… aot autograd (#86245) (#86245)

Summary:
We were running into a few issues with running multithreaded backwards in aot_autograd: such as #86136, and `FakeTensorMode` getting into a weird state as a result of not executing functions completely sequentially. The multithreaded backwards is lost in translation when we trace out the backwards anyway, and adds a lot of additional complexity.

Pull Request resolved: #86245
Approved by: https://github.com/albanD, https://github.com/yf225

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d04889323e2bc0b7321b76e564292565c88b9a5e

Reviewed By: seemethere

Differential Revision: D40167028

Pulled By: seemethere

fbshipit-source-id: f427c71e528deaa494521a61fcbf789d1a964711
@facebook-github-bot facebook-github-bot deleted the gh/eellison/335/head branch June 8, 2023 16:16
6 participants