Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo] delete graph_out_{n} after restoring local vars #122658

Closed
wants to merge 2 commits into from

Conversation

yifuwang
Copy link
Contributor

@yifuwang yifuwang commented Mar 25, 2024

Stack from ghstack (oldest at bottom):

At graph breaks, we create a graph_out_{n} symbol to hold the graph output and
use it to restore the local vars. In addition to their own symbols, the local
vars are kept alive by the symbol we created. This means that if the graph
break is the last usage of one of the symbols, the symbol would still be kept
alive upon graph resumption.

This PR: delete the graph_out_{n} symbol after restoring local vars so the
lifetime of the local vars is governed by themselves.

Example Problem

Tensor b's last usage is in the graph break. However, it won't be deallocated until bar() completes. In the orignal issue report by @Yuzhen11, b is a large tensor and bar() is an expensive computation.

import torch


def foo(a):
    return torch.mm(a, a)


@torch._dynamo.disable()
def graph_break_fn(a):
    ret = a.bfloat16()
    return ret


def bar(c):
    return torch.mm(c, c)


def fn(a):
    b = foo(a)
    c = graph_break_fn(b)
    # del b
    return bar(c)


fn_compiled = torch.compile(fn, backend="eager")
a = torch.randn(10000, 10000, device="cuda", requires_grad=True)

fn_compiled(a).sum().backward()

Bytecode before this PR:

ORIGINAL BYTECODE fn /home/yifu/microbench/del2.py line 18
 19           0 LOAD_GLOBAL              0 (foo)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               1 (b)

 20           8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (c)

 22          16 LOAD_GLOBAL              2 (bar)
             18 LOAD_FAST                2 (c)
             20 CALL_FUNCTION            1
             22 RETURN_VALUE


MODIFIED BYTECODE fn /home/yifu/microbench/del2.py line 18
 18           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               3 (graph_out_0)
              8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                3 (graph_out_0)
             12 LOAD_CONST               1 (0)
             14 BINARY_SUBSCR

 20          16 CALL_FUNCTION            1
             18 LOAD_GLOBAL              4 (__resume_at_14_1)
             20 ROT_TWO
             22 CALL_FUNCTION            1
             24 RETURN_VALUE

ORIGINAL BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_FAST                0 (___stack0)
              2 JUMP_ABSOLUTE            9 (to 18)
              4 LOAD_GLOBAL              0 (foo)
              6 LOAD_FAST                1 (a)
              8 CALL_FUNCTION            1
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              1 (graph_break_fn)
             14 LOAD_FAST                2 (b)
             16 CALL_FUNCTION            1
        >>   18 STORE_FAST               3 (c)

 22          20 LOAD_GLOBAL              2 (bar)
             22 LOAD_FAST                3 (c)
             24 CALL_FUNCTION            1
             26 RETURN_VALUE


MODIFIED BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_GLOBAL              3 (__compiled_fn_2)
              2 LOAD_FAST                0 (___stack0)
              4 CALL_FUNCTION            1
              6 UNPACK_SEQUENCE          1
              8 RETURN_VALUE

Bytecode after this PR:

ORIGINAL BYTECODE fn /home/yifu/microbench/del2.py line 18
 19           0 LOAD_GLOBAL              0 (foo)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               1 (b)

 20           8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (c)

 22          16 LOAD_GLOBAL              2 (bar)
             18 LOAD_FAST                2 (c)
             20 CALL_FUNCTION            1
             22 RETURN_VALUE


MODIFIED BYTECODE fn /home/yifu/microbench/del2.py line 18
 18           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               3 (graph_out_0)
              8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                3 (graph_out_0)
             12 LOAD_CONST               1 (0)
             14 BINARY_SUBSCR
             16 DELETE_FAST              3 (graph_out_0)

 20          18 CALL_FUNCTION            1
             20 LOAD_GLOBAL              4 (__resume_at_14_1)
             22 ROT_TWO
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

ORIGINAL BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_FAST                0 (___stack0)
              2 JUMP_ABSOLUTE            9 (to 18)
              4 LOAD_GLOBAL              0 (foo)
              6 LOAD_FAST                1 (a)
              8 CALL_FUNCTION            1
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              1 (graph_break_fn)
             14 LOAD_FAST                2 (b)
             16 CALL_FUNCTION            1
        >>   18 STORE_FAST               3 (c)

 22          20 LOAD_GLOBAL              2 (bar)
             22 LOAD_FAST                3 (c)
             24 CALL_FUNCTION            1
             26 RETURN_VALUE


MODIFIED BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_GLOBAL              3 (__compiled_fn_2)
              2 LOAD_FAST                0 (___stack0)
              4 CALL_FUNCTION            1
              6 UNPACK_SEQUENCE          1
              8 RETURN_VALUE

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang

At graph breaks, we create a graph_out_{n} symbol to hold the graph output and
use it to restore the local vars. In addition to their own symbols, the local
vars are kept alive by the symbol we created. This means that if the graph
break is the last usage of one of the symbols, the symbol would still be kept
alive upon graph resumption.

This PR: delete the graph_out_{n} symbol after restoring local vars so the
lifetime of the local vars is governed by themselves.

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Mar 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122658

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit daf4803 with merge base b2c496b (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

yifuwang added a commit that referenced this pull request Mar 25, 2024
At graph breaks, we create a graph_out_{n} symbol to hold the graph output and
use it to restore the local vars. In addition to their own symbols, the local
vars are kept alive by the symbol we created. This means that if the graph
break is the last usage of one of the symbols, the symbol would still be kept
alive upon graph resumption.

This PR: delete the graph_out_{n} symbol after restoring local vars so the
lifetime of the local vars is governed by themselves.

ghstack-source-id: 8db5e3b77b026a1c767f27ab361f9ab3a9b2bb6d
Pull Request resolved: #122658
@yifuwang yifuwang closed this Mar 26, 2024
@yifuwang yifuwang reopened this Mar 26, 2024
At graph breaks, we create a graph_out_{n} symbol to hold the graph output and
use it to restore the local vars. In addition to their own symbols, the local
vars are kept alive by the symbol we created. This means that if the graph
break is the last usage of one of the symbols, the symbol would still be kept
alive upon graph resumption.

This PR: delete the graph_out_{n} symbol after restoring local vars so the
lifetime of the local vars is governed by themselves.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
yifuwang added a commit that referenced this pull request Mar 26, 2024
At graph breaks, we create a graph_out_{n} symbol to hold the graph output and
use it to restore the local vars. In addition to their own symbols, the local
vars are kept alive by the symbol we created. This means that if the graph
break is the last usage of one of the symbols, the symbol would still be kept
alive upon graph resumption.

This PR: delete the graph_out_{n} symbol after restoring local vars so the
lifetime of the local vars is governed by themselves.

ghstack-source-id: 4abd6a9c503b88635420271918d49fcd8ce4aadc
Pull Request resolved: #122658
@ezyang ezyang requested a review from jansel March 26, 2024 01:25
@ezyang
Copy link
Contributor

ezyang commented Mar 26, 2024

Yeah, this looks good. Feel free to undraft it.

@yifuwang yifuwang marked this pull request as ready for review March 26, 2024 02:56
@yifuwang yifuwang changed the title [RFC] delete graph_out_{n} after restoring local vars [dynamo] delete graph_out_{n} after restoring local vars Mar 26, 2024
@yifuwang
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 26, 2024
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

Copy link
Contributor

@anijain2305 anijain2305 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woo! Very nice!

@yifuwang
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@ezyang
Copy link
Contributor

ezyang commented Mar 27, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Can't merge closed PR #122658

@soumith
Copy link
Member

soumith commented Mar 27, 2024

follow up with a test to not regress?

pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
At graph breaks, we create a graph_out_{n} symbol to hold the graph output and
use it to restore the local vars. In addition to their own symbols, the local
vars are kept alive by the symbol we created. This means that if the graph
break is the last usage of one of the symbols, the symbol would still be kept
alive upon graph resumption.

This PR: delete the graph_out_{n} symbol after restoring local vars so the
lifetime of the local vars is governed by themselves.

## Example Problem
Tensor `b`'s last usage is in the graph break. However, it won't be deallocated until `bar()` completes. In the orignal issue report by @Yuzhen11, `b` is a large tensor and `bar()` is an expensive computation.

```python
import torch

def foo(a):
    return torch.mm(a, a)

@torch._dynamo.disable()
def graph_break_fn(a):
    ret = a.bfloat16()
    return ret

def bar(c):
    return torch.mm(c, c)

def fn(a):
    b = foo(a)
    c = graph_break_fn(b)
    # del b
    return bar(c)

fn_compiled = torch.compile(fn, backend="eager")
a = torch.randn(10000, 10000, device="cuda", requires_grad=True)

fn_compiled(a).sum().backward()
```

Bytecode before this PR:
```
ORIGINAL BYTECODE fn /home/yifu/microbench/del2.py line 18
 19           0 LOAD_GLOBAL              0 (foo)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               1 (b)

 20           8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (c)

 22          16 LOAD_GLOBAL              2 (bar)
             18 LOAD_FAST                2 (c)
             20 CALL_FUNCTION            1
             22 RETURN_VALUE

MODIFIED BYTECODE fn /home/yifu/microbench/del2.py line 18
 18           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               3 (graph_out_0)
              8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                3 (graph_out_0)
             12 LOAD_CONST               1 (0)
             14 BINARY_SUBSCR

 20          16 CALL_FUNCTION            1
             18 LOAD_GLOBAL              4 (__resume_at_14_1)
             20 ROT_TWO
             22 CALL_FUNCTION            1
             24 RETURN_VALUE

ORIGINAL BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_FAST                0 (___stack0)
              2 JUMP_ABSOLUTE            9 (to 18)
              4 LOAD_GLOBAL              0 (foo)
              6 LOAD_FAST                1 (a)
              8 CALL_FUNCTION            1
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              1 (graph_break_fn)
             14 LOAD_FAST                2 (b)
             16 CALL_FUNCTION            1
        >>   18 STORE_FAST               3 (c)

 22          20 LOAD_GLOBAL              2 (bar)
             22 LOAD_FAST                3 (c)
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

MODIFIED BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_GLOBAL              3 (__compiled_fn_2)
              2 LOAD_FAST                0 (___stack0)
              4 CALL_FUNCTION            1
              6 UNPACK_SEQUENCE          1
              8 RETURN_VALUE
```

Bytecode after this PR:
```
ORIGINAL BYTECODE fn /home/yifu/microbench/del2.py line 18
 19           0 LOAD_GLOBAL              0 (foo)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               1 (b)

 20           8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (c)

 22          16 LOAD_GLOBAL              2 (bar)
             18 LOAD_FAST                2 (c)
             20 CALL_FUNCTION            1
             22 RETURN_VALUE

MODIFIED BYTECODE fn /home/yifu/microbench/del2.py line 18
 18           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               3 (graph_out_0)
              8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                3 (graph_out_0)
             12 LOAD_CONST               1 (0)
             14 BINARY_SUBSCR
             16 DELETE_FAST              3 (graph_out_0)

 20          18 CALL_FUNCTION            1
             20 LOAD_GLOBAL              4 (__resume_at_14_1)
             22 ROT_TWO
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

ORIGINAL BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_FAST                0 (___stack0)
              2 JUMP_ABSOLUTE            9 (to 18)
              4 LOAD_GLOBAL              0 (foo)
              6 LOAD_FAST                1 (a)
              8 CALL_FUNCTION            1
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              1 (graph_break_fn)
             14 LOAD_FAST                2 (b)
             16 CALL_FUNCTION            1
        >>   18 STORE_FAST               3 (c)

 22          20 LOAD_GLOBAL              2 (bar)
             22 LOAD_FAST                3 (c)
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

MODIFIED BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_GLOBAL              3 (__compiled_fn_2)
              2 LOAD_FAST                0 (___stack0)
              4 CALL_FUNCTION            1
              6 UNPACK_SEQUENCE          1
              8 RETURN_VALUE

```

Pull Request resolved: #122658
Approved by: https://github.com/jansel, https://github.com/anijain2305
@github-actions github-actions bot deleted the gh/yifuwang/72/head branch April 27, 2024 01:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants