FlexAttention isn't using decompositions #124643

Closed
Chillee opened this issue Apr 22, 2024 · 2 comments
Assignees: ydwu4
Labels: module: decompositions, oncall: pt2, triaged

Comments

Chillee commented Apr 22, 2024

🚀 The feature, motivation and pitch

We don't trace with any decompositions today (https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/templated_attention.py#L125).

This can cause errors when we use a pointwise op that is normally handled through a decomposition (e.g. q // kv).
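A minimal sketch of the difference, assuming a recent PyTorch build: it traces a small function with make_fx, once without a decomposition table and once with one. The function name score_mod, the helper floordiv_decomp, and the one-entry table are illustrative assumptions, not the code FlexAttention actually uses.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

aten = torch.ops.aten


def floordiv_decomp(a, b):
    # Hypothetical decomposition: express floor_divide as div with floor rounding.
    return aten.div.Tensor_mode(a, b, rounding_mode="floor")


def score_mod(q_idx, kv_idx):
    # Illustrative pointwise function using the op mentioned in the issue.
    return q_idx // kv_idx


q = torch.arange(1, 5)
kv = torch.full((4,), 2)

# Traced without decompositions: floor_divide is recorded as-is.
plain = make_fx(score_mod)(q, kv)

# Traced with a decomposition table: floor_divide is rewritten while tracing.
decomposed = make_fx(
    score_mod,
    decomposition_table={aten.floor_divide.default: floordiv_decomp},
)(q, kv)

plain_targets = [n.target for n in plain.graph.nodes if n.op == "call_function"]
decomposed_targets = [
    n.target for n in decomposed.graph.nodes if n.op == "call_function"
]
print(plain_targets)
print(decomposed_targets)
```

A backend that only handles the decomposed form would presumably reject the first graph, which is the kind of error described above.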

cc: @ydwu4 who was looking at how this should be done properly.

Alternatives

N/A

Additional context

No response

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @SherlockNoMad

bdhirsh commented Apr 22, 2024

Since we stash the current decomposition table in a global while tracing, would it be reasonable for all HOPs that use make_fx on their inner subgraphs to poke at the global if it's available?

I guess a minor downside is that this is easy to get (silently) wrong without some auditing.

ydwu4 self-assigned this Apr 22, 2024

ydwu4 commented Apr 22, 2024

A related issue: #122972. I'll work on it asap.

jbschlosser added the triaged and module: decompositions labels Apr 24, 2024
ydwu4 added a commit that referenced this issue May 10, 2024
Adds trace_subgraph to _MakefxTracer; the motivation is in #122972. Also migrates all existing usages of reenter_make_fx to the new sub-tracer. Previously, the torch function mode that creates torch_fn metadata was not re-entered when we were in ProxyTensorMode (since it is inside of __torch_function__). This PR reconstructs the torch function mode based on the parent tracer's config and re-enters it so that the metadata is shown in the graph.

**Test Plan:**
Existing tests. We have a bunch of make_fx tests for cond, map and while_loop. Also removes the expected failure for torch_fn, since reenter_make_fx is now able to reconstruct torch function modes.

Also fixes #124643



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
ZelboK pushed a commit to ZelboK/pytorch that referenced this issue May 19, 2024
Adds trace_subgraph to _MakefxTracer; the motivation is in pytorch#122972. Also migrates all existing usages of reenter_make_fx to the new sub-tracer. Previously, the torch function mode that creates torch_fn metadata was not re-entered when we were in ProxyTensorMode (since it is inside of __torch_function__). This PR reconstructs the torch function mode based on the parent tracer's config and re-enters it so that the metadata is shown in the graph.

**Test Plan:**
Existing tests. We have a bunch of make_fx tests for cond, map and while_loop. Also removes the expected failure for torch_fn, since reenter_make_fx is now able to reconstruct torch function modes.

Also fixes pytorch#124643

Pull Request resolved: pytorch#125363
Approved by: https://github.com/Chillee
ghstack dependencies: pytorch#125267