Skip execution for extract_compiled_graph #4612
Conversation
Where can I read and learn more about this aliasing behavior?
Ah, one thing I realized is that we actually need to materialize some of the tensors, at least the ones we save for future execution. Let me tweak it a bit.
If you search for "Aliasing" in XLA you should see some very brief docs. I mostly just read the code.
@shunting314 @alanwaketan This PR is ready for review.
Is it easy to have a unit test to show the buffer aliasing related issue? |
Yeah, actually.
TBH, I don't quite understand what the 'buffer aliasing' issue fixed by this PR is, or more specifically, how it causes crashes. A simple unit test would help. But it's not blocking; I think we can merge this PR first as well.
LGTM.
@shunting314 I should be more specific. The buffer aliasing is only enabled during a sync (see xla/torch_xla/csrc/xla_graph_executor.cpp, lines 1104 to 1105 at 574cfda).

In this PR I added a test in xla/test/test_input_output_aliases.py (lines 26 to 34 at 574cfda). You can see that we perform two in-place operations on the same input tensor. The bug we had prior to this PR is that we actually saved some inputs for each dynamo graph. However, when we do a `mark_step`-like sync, those saved buffers can be aliased (donated), which invalidates them and causes a crash when we later execute the dynamo graph.
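A minimal sketch of the kind of repro being discussed, assuming a simple in-place update pattern (this is illustrative, not the exact test from `xla/test/test_input_output_aliases.py`):

```python
# Illustrative sketch of the in-place/aliasing pattern discussed above; not the
# exact test from xla/test/test_input_output_aliases.py.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t1 = torch.zeros(4, 2, device=device)
xm.mark_step()   # materialize t1 so it becomes a device-data (input) buffer

t1 += 1          # first in-place op: the input buffer becomes an aliasing candidate
xm.mark_step()   # a mark_step-like sync is where input/output aliasing kicks in

t1 *= 2          # second in-place op on the (possibly donated) buffer
xm.mark_step()

print(t1.cpu())  # expected result: a tensor of 2s
```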
* Fix HLO dumping (#4619)
* Update TF pin to 2/13 (#4615)
  * Update TF pin to 2/13
  * Fix pinned commit
  * Add patch to revert TF 3e24055
  * Add comment to new patch
* Fix patch command in TPU CI (#4623)
* Skip execution for extract_compiled_graph (#4612)
  * Only warm up cache for dynamo extract_graph step
  * Add missing config
  * Make sure warm up run does not cause place holder to be created
  * Fix tests
* Disable failing `test_operations.py` tests on TPU (#4622)
  * Disable `test_operations.py` tests failing on TPU
  * Add to TPU CI
* Bazel (#4528)
  * Replace tensorflow with a bazel external repository
  * Basic migration to bazel for xla_client.
  * Revert to blob
  * Add vscode config.
  * Update newlines
  * Merge with pjrt client test build changes.
  * Migrate tests to new build
  * Format test and plugin
  * Order imports
  * Conditionally apply tf patches; apply pt patches always.
  * Format python
  * configure formatters
  * Mirror TF pin update an fixes in bazel.
  * Support local and sandboxed build based on flags
  * Add cloud cache URLs for llvm.
  * Merge with upstream
  * Update TF pin
  * Fix patching regression
* Revert "Bazel (#4528)" (#4631)
  * This reverts commit 3a90f5a.

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Co-authored-by: Will Cromar <wcromar@google.com>
Co-authored-by: stgpetrovic <stgpetrovic@gmail.com>
* Only warm up cache for dynamo extract_graph step
* Add missing config
* Make sure warm up run does not cause place holder to be created
* Fix tests
Currently we use `xla_sync_multi` to force an execution, which warms up the cache for future dynamo runs. This is not ideal because `xla_sync_multi` is treated as a `mark_step`-like sync and will trigger buffer aliasing. However, we actually save some XLA data when we cache the dynamo graph, and those buffers might be aliased during the `xla_sync_multi`, causing random crashes when we actually execute the dynamo graph.

Confirmed that this change does not regress speed on TPU when running torchbench.
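A rough sketch of the old-vs-new warm-up flow described above; `warm_up_cache_without_execution` is a hypothetical stand-in for the compile-only path, not an actual torch_xla binding:

```python
# Conceptual sketch only. warm_up_cache_without_execution is a hypothetical name
# standing in for the compile-only path this PR adds; it is not a real torch_xla API.
import torch_xla.core.xla_model as xm

def warm_up_cache_without_execution(xla_args):
    # Stand-in for a binding that builds and caches the computation for xla_args
    # without executing it, so no mark_step-like sync (and no buffer aliasing) happens.
    pass

def extract_compiled_graph_old(xla_args):
    # Old behavior: force a real execution (via xla_sync_multi) to warm the cache.
    # That sync behaves like a mark_step, so XLA may alias/donate the buffers of the
    # inputs we saved for replaying the dynamo graph later, leading to random crashes.
    xm.mark_step()  # stands in for the xla_sync_multi call

def extract_compiled_graph_new(xla_args):
    # New behavior: only warm up the compilation cache and skip execution, so the
    # saved inputs are never aliased and stay valid for future dynamo runs.
    warm_up_cache_without_execution(xla_args)
```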
FYI @wconstab @will-cromar