[dynamo] delete graph_out_{n} after restoring local vars #122658
Conversation
At graph breaks, we create a `graph_out_{n}` symbol to hold the graph output and use it to restore the local vars. In addition to their own symbols, the local vars are kept alive by the symbol we created. This means that if the graph break is the last usage of one of the symbols, the symbol would still be kept alive upon graph resumption. This PR deletes the `graph_out_{n}` symbol after restoring the local vars, so the lifetime of the local vars is governed by the vars themselves.
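The lifetime issue can be illustrated in plain CPython, with a dummy object standing in for the large tensor. This is a sketch, not Dynamo's actual codegen: `BigTensor` is hypothetical, `graph_out_0` only mirrors the name of the symbol Dynamo generates, and the immediate collection relies on CPython's reference counting.

```python
import weakref

class BigTensor:
    # Hypothetical stand-in for a large tensor.
    pass

big = BigTensor()
ref = weakref.ref(big)   # lets us observe when the object is deallocated

graph_out_0 = (big,)     # symbol holding the graph output at a graph break
b = graph_out_0[0]       # local var restored from the graph output
del big

del b                    # b's last usage: analogous to "at the graph break"
print(ref() is None)     # False: graph_out_0 still keeps the object alive

del graph_out_0          # what this PR's added DELETE_FAST does
print(ref() is None)     # True: lifetime now governed by the local alone
```

Dropping the extra `graph_out_0` reference is what lets the object be freed as soon as its own last reference goes away.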
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/122658. Note: links to docs will display an error until the docs builds have been completed.
❗ 1 active SEV. If your PR is affected, please view it below.
✅ You can merge normally (1 unrelated failure). As of commit daf4803 with merge base b2c496b, the following job failed but was likely due to flakiness present on trunk: FLAKY.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Yeah, this looks good. Feel free to undraft it.
@pytorchbot merge
Merge failed. Reason: This PR needs a label. To add a label, you can comment to pytorchbot. Details for Dev Infra team: raised by workflow job.
Woo! Very nice!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge
Can't merge closed PR #122658
follow up with a test to not regress? |
At graph breaks, we create a `graph_out_{n}` symbol to hold the graph output and use it to restore the local vars. In addition to their own symbols, the local vars are kept alive by the symbol we created. This means that if the graph break is the last usage of one of the symbols, the symbol would still be kept alive upon graph resumption.

This PR deletes the `graph_out_{n}` symbol after restoring the local vars, so the lifetime of the local vars is governed by the vars themselves.

## Example Problem

Tensor `b`'s last usage is in the graph break. However, it won't be deallocated until `bar()` completes. In the original issue report by @Yuzhen11, `b` is a large tensor and `bar()` is an expensive computation.

```python
import torch

def foo(a):
    return torch.mm(a, a)

@torch._dynamo.disable()
def graph_break_fn(a):
    ret = a.bfloat16()
    return ret

def bar(c):
    return torch.mm(c, c)

def fn(a):
    b = foo(a)
    c = graph_break_fn(b)
    # del b
    return bar(c)

fn_compiled = torch.compile(fn, backend="eager")
a = torch.randn(10000, 10000, device="cuda", requires_grad=True)

fn_compiled(a).sum().backward()
```

Bytecode before this PR:

```
ORIGINAL BYTECODE fn /home/yifu/microbench/del2.py line 18
 19           0 LOAD_GLOBAL              0 (foo)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               1 (b)

 20           8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (c)

 22          16 LOAD_GLOBAL              2 (bar)
             18 LOAD_FAST                2 (c)
             20 CALL_FUNCTION            1
             22 RETURN_VALUE

MODIFIED BYTECODE fn /home/yifu/microbench/del2.py line 18
 18           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               3 (graph_out_0)
              8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                3 (graph_out_0)
             12 LOAD_CONST               1 (0)
             14 BINARY_SUBSCR

 20          16 CALL_FUNCTION            1
             18 LOAD_GLOBAL              4 (__resume_at_14_1)
             20 ROT_TWO
             22 CALL_FUNCTION            1
             24 RETURN_VALUE

ORIGINAL BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_FAST                0 (___stack0)
              2 JUMP_ABSOLUTE            9 (to 18)
              4 LOAD_GLOBAL              0 (foo)
              6 LOAD_FAST                1 (a)
              8 CALL_FUNCTION            1
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              1 (graph_break_fn)
             14 LOAD_FAST                2 (b)
             16 CALL_FUNCTION            1
        >>   18 STORE_FAST               3 (c)

 22          20 LOAD_GLOBAL              2 (bar)
             22 LOAD_FAST                3 (c)
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

MODIFIED BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_GLOBAL              3 (__compiled_fn_2)
              2 LOAD_FAST                0 (___stack0)
              4 CALL_FUNCTION            1
              6 UNPACK_SEQUENCE          1
              8 RETURN_VALUE
```

Bytecode after this PR:

```
ORIGINAL BYTECODE fn /home/yifu/microbench/del2.py line 18
 19           0 LOAD_GLOBAL              0 (foo)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               1 (b)

 20           8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (c)

 22          16 LOAD_GLOBAL              2 (bar)
             18 LOAD_FAST                2 (c)
             20 CALL_FUNCTION            1
             22 RETURN_VALUE

MODIFIED BYTECODE fn /home/yifu/microbench/del2.py line 18
 18           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               3 (graph_out_0)
              8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                3 (graph_out_0)
             12 LOAD_CONST               1 (0)
             14 BINARY_SUBSCR
             16 DELETE_FAST              3 (graph_out_0)

 20          18 CALL_FUNCTION            1
             20 LOAD_GLOBAL              4 (__resume_at_14_1)
             22 ROT_TWO
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

ORIGINAL BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_FAST                0 (___stack0)
              2 JUMP_ABSOLUTE            9 (to 18)
              4 LOAD_GLOBAL              0 (foo)
              6 LOAD_FAST                1 (a)
              8 CALL_FUNCTION            1
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              1 (graph_break_fn)
             14 LOAD_FAST                2 (b)
             16 CALL_FUNCTION            1
        >>   18 STORE_FAST               3 (c)

 22          20 LOAD_GLOBAL              2 (bar)
             22 LOAD_FAST                3 (c)
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

MODIFIED BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_GLOBAL              3 (__compiled_fn_2)
              2 LOAD_FAST                0 (___stack0)
              4 CALL_FUNCTION            1
              6 UNPACK_SEQUENCE          1
              8 RETURN_VALUE
```

Pull Request resolved: #122658
Approved by: https://github.com/jansel, https://github.com/anijain2305
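The `DELETE_FAST` instruction the modified bytecode gains is ordinary CPython bytecode for `del` on a local variable. A minimal sketch (the `restore_and_release` function below is hypothetical, not Dynamo's actual codegen) shows the same prologue shape and inspects it with the `dis` module:

```python
import dis

def restore_and_release(graph_out):
    # Mirrors the generated prologue: restore a local from the graph
    # output tuple, then drop the graph_out symbol itself.
    c = graph_out[0]
    del graph_out        # compiles to DELETE_FAST (graph_out)
    return c

opnames = [ins.opname for ins in dis.Bytecode(restore_and_release)]
print("DELETE_FAST" in opnames)  # → True
```

After `del graph_out`, the tuple held by the `graph_out` slot no longer pins its elements, so only the restored local `c` keeps its value alive.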
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang