@jamesjwu (Contributor) commented Feb 19, 2025:

Stack from ghstack (oldest at bottom):

This PR intends to fix the cache-related issues from #147405.
It does not handle the in-process dynamo recompile case, because it does not introduce any extra guards. For FXGraphCache and AOTAutogradCache, we simply need to include the device context in the cache key.

Note that for any function that accepts tensor inputs, the device context is already included in the cache key via the metadata of the example inputs. However, for functions that return constants or take no arguments, the device context still needs to be part of the cache key.
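
For illustration, a minimal repro sketch of the failure mode (hypothetical, not taken from #147405; it assumes a machine with at least two CUDA devices and the FX graph cache enabled):

import torch

@torch.compile
def make_ones():
    # No tensor arguments: the target device comes only from the ambient context.
    return torch.ones(4, device="cuda")

torch.cuda.set_device(0)
print(make_ones().device)  # cuda:0

torch.cuda.set_device(1)
torch._dynamo.reset()      # force a fresh compile that consults the warm cache
print(make_ones().device)  # expected cuda:1; with a device-agnostic cache key the
                           # warm artifact could still target cuda:0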

A more robust fix for this would be to have inductor generate device guards that are dynamic, instead of specialized. This would also help us share more cache artifacts.

I've added unit tests for FXGraphCache and AOTAutogradCache, both of which would fail without this change.
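
A rough sketch of what such a test can look like (illustrative only, not the actual tests added in this PR; it assumes the fxgraph_cache_miss counter in torch._dynamo.utils.counters and at least two CUDA devices):

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters

def test_device_index_is_part_of_cache_key():
    inductor_config.fx_graph_cache = True
    counters.clear()

    def fn():
        # No tensor inputs, so only the ambient device distinguishes the two runs.
        return torch.ones(2, device="cuda")

    torch.cuda.set_device(0)
    torch.compile(fn)()
    assert counters["inductor"]["fxgraph_cache_miss"] == 1

    # Switching the current device should change the cache key, so the second
    # compile must be another miss rather than a stale hit.
    torch._dynamo.reset()
    torch.cuda.set_device(1)
    torch.compile(fn)()
    assert counters["inductor"]["fxgraph_cache_miss"] == 2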

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Differential Revision: [D69875939](https://our.internmc.facebook.com/intern/diff/D69875939)


@pytorch-bot (bot) commented Feb 19, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147464

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2adfd4d with merge base f63db62:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Review comments on this hunk from the diff:

# This device index is usually already encoded by the device of the inputs
# but fx graphs don't necessarily have tensor inputs
if torch.cuda.is_available():
    self.default_cuda_device_index = torch.cuda.current_device()
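
Roughly, the field above takes effect because the hash-details object is serialized and digested as a whole, so any new attribute changes the resulting key. A simplified sketch of that idea (not the actual inductor implementation; the class and method names here are made up):

import hashlib
import pickle

import torch

class HashDetailsSketch:
    # Toy stand-in for inductor's hash-details object.
    def __init__(self, graph_repr: str) -> None:
        self.graph_repr = graph_repr
        # Record the ambient device index so two runs that differ only in the
        # current device produce different keys.
        self.default_cuda_device_index = (
            torch.cuda.current_device() if torch.cuda.is_available() else None
        )

    def cache_key(self) -> str:
        # Every attribute, including the device index, feeds the digest.
        return hashlib.sha256(pickle.dumps(vars(self))).hexdigest()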
@bdhirsh (Contributor) commented Feb 19, 2025:

partially a PSA / question for @albanD - to what extent do you expect people to be using torch.accelerator.current_device() + PT2 today (and I guess... changing their device index from run to run)? James and I talked about it offline a bit - we'll need to do something similar for other accelerators in the long run to avoid accelerator-specific warm cache problems

A collaborator replied:
Yes, you should use torch.accelerator instead of cuda for these things. They lead to the same thing for the cuda device, but you also get rocm/mtia/xpu support for free.

@jamesjwu (author) replied:
Is it guaranteed that:

if torch.cuda.is_available(), then torch.accelerator.is_available()?
And that torch.cuda.current_device() == torch.accelerator.current_device?

If so, happy to change.

@jamesjwu (author) replied:
Seems like it works in my tests, will run with it!
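
For reference, a sanity check along these lines might look like the following (a sketch only; it assumes the accelerator API exposes torch.accelerator.is_available() and torch.accelerator.current_device_index(), and the exact names may differ across PyTorch versions):

import torch

if torch.cuda.is_available():
    # On a CUDA build the generic accelerator API should agree with torch.cuda.
    assert torch.accelerator.is_available()
    assert torch.accelerator.current_device_index() == torch.cuda.current_device()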

@bdhirsh (Contributor) left a comment:
sgtm! Although we should fix the dynamo guard issue too :)

@jamesjwu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Feb 19, 2025.
jamesjwu added a commit that referenced this pull request Feb 20, 2025
ghstack-source-id: ec82e6e
Pull Request resolved: #147464

@jamesjwu (author) commented:
Rebase to rerun tests

@facebook-github-bot commented:
@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@github-actions (bot) deleted the gh/jamesjwu/110/head branch on March 27, 2025.