Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace node.meta source_fn with source_fn_stack #108595

Closed
wants to merge 1 commit into from
Closed

Conversation

ydwu4
Copy link
Contributor

@ydwu4 ydwu4 commented Sep 5, 2023

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:

backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))

The graph captured is shown below:

class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_
        
        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)
        
    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add
            
    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond
            
        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub
                
        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub

the source_fn for inner cond, sin, cos will be a (name, target) tuple:

('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.

[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)], 
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')], 
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]

Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @ngimel

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 5, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108595

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit ce69b53 with merge base 40b83d9 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

@ydwu4 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Sep 6, 2023
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack


Test Plan: See added tests in test_higher_order_ops.py and modify existing test.

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

@jerryzh168
Copy link
Contributor

jerryzh168 commented Sep 7, 2023

can you also add some examples for the format for the previous "source_fn" and the current "source_fn_stack" for both (one fn stack and two fn stack) in PR summary?

@ydwu4
Copy link
Contributor Author

ydwu4 commented Sep 7, 2023

can you also add some examples for the format for the previous "source_fn" and the current "source_fn_stack" for both (one fn stack and two fn stack) in PR summary?

Added in the PR summary

facebook-github-bot pushed a commit that referenced this pull request Sep 7, 2023
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack


Test Plan: See added tests in test_higher_order_ops.py and modify existing test.

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

@jerryzh168
Copy link
Contributor

('cond', <torch._ops.HigherOrderOperator object at xxx>)

thanks, can you add a print for the graph and subgraphs as well? and add some comments for the results?
e.g. for
('cond', <torch._ops.HigherOrderOperator object at xxx>)
is this (node_name, target)?

and for the source_fn_stack, why ('cond', <torch._ops.HigherOrderOperator object at xxx>) appeared multiple times?

@ydwu4
Copy link
Contributor Author

ydwu4 commented Sep 7, 2023

add a print for the graph and subgraphs, and some comments for the results?

Sure! See the updated summary.

why ('cond', <torch._ops.HigherOrderOperator object at xxx>) appeared multiple times?

I use a more complicated example where we have a nested cond so the operators will have the cond -> cond -> op stack as a list of (name, target) tuple.

@jerryzh168
Copy link
Contributor

thanks! makes sense

facebook-github-bot pushed a commit that referenced this pull request Sep 7, 2023
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack


Test Plan: See added tests in test_higher_order_ops.py and modify existing test.

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

@@ -465,12 +467,11 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
f"{k}:({v[0]},{self.serialize_operator(v[1])})"
for k, v in nn_module_stack.items()
]
ret["nn_module_stack"] = ";".join(nn_module_list)
ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we are technically changing the IR, does ExportedProgram's version number need a bump? Or is the version number bump automatic (and in-line with the PyTorch version number?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we are technically changing the IR, does ExportedProgram's version number need a bump

I thought about this too, but @zhxchen17 and @gmagogsfm 's opinion was that it's a minor change so we don't need a version bump

@@ -150,6 +150,7 @@ def speculate_subgraph(
graph_checkpoint,
checkpoint,
description,
source_target,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we put this after the *? Also, what is the type of this? (HigherOrderOpVariable ?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! It's the .value of HigherOrderOpVariable, which is the target of the proxy that we created for the higherOrderOperator.

Copy link
Contributor

@zou3519 zou3519 Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, so this is just a string? The name "source_target" is not that descriptive, could we add a comment with what you said?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems reasonable, some questions, otherwise minor comments

facebook-github-bot pushed a commit that referenced this pull request Sep 18, 2023
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```


Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Also updated bin by running:
"buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

facebook-github-bot pushed a commit that referenced this pull request Sep 18, 2023
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```


Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Also updated bin by running:
"buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

facebook-github-bot pushed a commit that referenced this pull request Sep 18, 2023
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```


Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Also updated bin by running:
"buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

facebook-github-bot pushed a commit that referenced this pull request Sep 20, 2023
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```


Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Also updated bin by running:
"buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

facebook-github-bot pushed a commit that referenced this pull request Sep 27, 2023
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```


Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Also updated bin by running:
"buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copy over the descriptions:

This is a follow-up of the discussion in #108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```


Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Also updated bin by running:
"buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D48984986

facebook-github-bot pushed a commit to pytorch/executorch that referenced this pull request Sep 28, 2023
Summary:
Pull Request resolved: #210

A resubmit of pytorch/pytorch#108447. Copy over the descriptions:

This is a follow-up of the discussion in pytorch/pytorch#108356, where we want to repalce source_fn with source_fn_stack

Before this PR, for the following example:
```python
backend = EagerAndRecordGraphs()

torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```
The graph captured is shown below:
```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_

        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            add = l_x_ + l_y_;  l_x_ = l_y_ = None
            return add

    class GraphModule(torch.nn.Module):
        def forward(self, l_pred2_, l_x_, l_y_):
            cond_true_0 = self.cond_true_0
            cond_false_0 = self.cond_false_0
            cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
            return cond

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                sin = l_x_.sin();  l_x_ = None
                cos = l_y_.cos();  l_y_ = None
                sub = sin - cos;  sin = cos = None
                return sub

        class GraphModule(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                cos = l_x_.cos();  l_x_ = None
                sin = l_y_.sin();  l_y_ = None
                sub = cos - sin;  cos = sin = None
                return sub
```
the source_fn for inner cond, sin, cos will be a (name, target) tuple:
```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub'. <built-in function sub>)
```

After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list.
```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```

X-link: pytorch/pytorch#108595

Test Plan:
See added tests in test_higher_order_ops.py and modify existing test.
Also updated bin by running:
"buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4

fbshipit-source-id: 12beb1cee5c6a5b7dc976eda54baaa8f564a0f1e
@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 28, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants