Replace node.meta source_fn with source_fn_stack #108595
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108595
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 1 Pending as of commit ce69b53 with merge base 40b83d9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@ydwu4 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copying over the description:

This is a follow-up of the discussion in #108356, where we want to replace source_fn with source_fn_stack.

Test Plan: See the added tests in test_higher_order_ops.py and the modified existing tests.

Differential Revision: D48984986

Pulled By: ydwu4
This pull request was exported from Phabricator. Differential Revision: D48984986
can you also add some examples of the format of the previous "source_fn" and the new "source_fn_stack", for both cases (one-fn stack and two-fn stack), in the PR summary?
Added in the PR summary |
thanks, can you add a print for the graph and subgraphs as well? and add some comments for the results? and for the source_fn_stack, why ...
Sure! See the updated summary.
I use a more complicated example where we have a nested cond, so the operators will have the cond -> cond -> op stack as a list of (name, target) tuples.
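For illustration, a hedged sketch (not literal output from the PR) of the stack an op nested under two conds ends up with:

```python
import torch

# Hypothetical metadata for the inner sin node (two nested conds); the
# innermost frame is the last element of the list.
source_fn_stack = [
    ("cond", torch.ops.higher_order.cond),  # outer cond
    ("cond", torch.ops.higher_order.cond),  # inner cond
    ("sin", "sin"),                         # the op itself
]
```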
thanks! makes sense |
```diff
@@ -465,12 +467,11 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
             f"{k}:({v[0]},{self.serialize_operator(v[1])})"
             for k, v in nn_module_stack.items()
         ]
-        ret["nn_module_stack"] = ";".join(nn_module_list)
+        ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list)
```
since we are technically changing the IR, does ExportedProgram's version number need a bump? Or is the version number bump automatic (and in line with the PyTorch version number)?
> since we are technically changing the IR, does ExportedProgram's version number need a bump
I thought about this too, but @zhxchen17 and @gmagogsfm's opinion was that it's a minor change, so we don't need a version bump.
```diff
@@ -150,6 +150,7 @@ def speculate_subgraph(
     graph_checkpoint,
     checkpoint,
     description,
+    source_target,
```
can we put this after the `*`? Also, what is the type of this? (HigherOrderOpVariable?)
Sure! It's the .value of HigherOrderOpVariable, which is the target of the proxy that we created for the HigherOrderOperator.
Oh, so this is just a string? The name "source_target" is not that descriptive, could we add a comment with what you said?
Sounds good!
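A minimal sketch of the suggested shape, with source_target made keyword-only and documented; the parameter list is abbreviated and the default is an assumption, not the actual torch._dynamo signature:

```python
# Sketch only: the real speculate_subgraph has more parameters; this just
# shows the suggested keyword-only placement and the explanatory comment.
def speculate_subgraph(
    tx,
    f,
    sub_args,
    graph_checkpoint,
    checkpoint,
    description,
    *,
    # source_target is the .value of the HigherOrderOpVariable: the target
    # of the proxy created for the HigherOrderOperator, e.g.
    # torch.ops.higher_order.cond. It seeds the source_fn_stack entries
    # recorded on nodes traced inside the subgraph.
    source_target=None,
):
    ...
```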
seems reasonable, some questions, otherwise minor comments
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copying over the description:

This is a follow-up of the discussion in #108356, where we want to replace source_fn with source_fn_stack.

Before this PR, for the following example:

```python
backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```

The graph captured is shown below:

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_
        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]); l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

class GraphModule(torch.nn.Module):
    def forward(self, l_pred2_, l_x_, l_y_):
        add = l_x_ + l_y_; l_x_ = l_y_ = None
        return add

class GraphModule(torch.nn.Module):
    def forward(self, l_pred2_, l_x_, l_y_):
        cond_true_0 = self.cond_true_0
        cond_false_0 = self.cond_false_0
        cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]); l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
        return cond

class GraphModule(torch.nn.Module):
    def forward(self, l_x_, l_y_):
        sin = l_x_.sin(); l_x_ = None
        cos = l_y_.cos(); l_y_ = None
        sub = sin - cos; sin = cos = None
        return sub

class GraphModule(torch.nn.Module):
    def forward(self, l_x_, l_y_):
        cos = l_x_.cos(); l_x_ = None
        sin = l_y_.sin(); l_y_ = None
        sub = cos - sin; cos = sin = None
        return sub
```

The source_fn for the inner cond, sin, cos, and sub nodes is a single (name, target) tuple:

```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub', <built-in function sub>)
```

After this PR, the source_fn_stack is a list of (name, target) tuples. The bottom of the stack is the end of the list:

```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')]
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```

Test Plan: See the added tests in test_higher_order_ops.py and the modified existing tests. Also updated the bin by running: "buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
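Continuing the example above, a hedged sketch of how one might print the new metadata; backend.graphs is the standard EagerAndRecordGraphs attribute, and the field name follows this PR:

```python
# Run once to trigger compilation, then walk every captured GraphModule
# (including the cond branch submodules) and print each node's stack.
cond_f(torch.tensor(True), torch.tensor(False), torch.randn(3), torch.randn(3))
for gm in backend.graphs:
    for mod in gm.modules():
        if isinstance(mod, torch.fx.GraphModule):
            for node in mod.graph.nodes:
                if "source_fn_stack" in node.meta:
                    print(node.name, node.meta["source_fn_stack"])
```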
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team