
[Bootcamp] [AI Accelerators] Convert attention equation to attention call in torch.compile() #94729

Closed
wants to merge 1 commit

Conversation

kadeng (Contributor) commented Feb 13, 2023

Summary:

Goals

Many attention-based models implement the scaled dot product attention used by transformer models directly from component PyTorch operators. Recognize these patterns with FX rewrites and convert them into calls to sdpa(), so that the optimized implementation is chosen where applicable.

Implementation

Given that the usual scaled dot product attention equation references a shape variable (everything gets normalized by dividing by sqrt(embedding_dim_size)), the graph rewrite is not entirely trivial.

It requires a new scale_factor parameter on torch.nn.functional.scaled_dot_product_attention (not part of this PR; see the discussion in #94729 (comment) about an earlier version of this PR).

This in turn makes it possible to rewrite any pattern like

    torch.matmul(query, key.transpose(-2, -1))
        .div(scale_factor)
        .softmax(dim=-1)
        .matmul(value)

into a corresponding (potentially optimized) API call like

    scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale_factor=scale_factor,
    )

without knowing the scale factor or query and key shapes at compile time.
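
For reference, the numerical equivalence the rewrite relies on can be checked by hand. A minimal sketch, assuming torch >= 2.0 (where F.scaled_dot_product_attention is available with its default 1/sqrt(head_dim) scaling; the scale_factor parameter itself is not part of this PR):

    import math
    import torch
    import torch.nn.functional as F

    # With the default 1/sqrt(head_dim) scaling, the fused call matches the
    # manual pattern exactly when scale_factor == sqrt(head_dim). Handling
    # arbitrary scale factors is what the proposed scale_factor parameter
    # is for.
    batch, heads, seq, head_dim = 2, 4, 16, 32
    query = torch.randn(batch, heads, seq, head_dim)
    key = torch.randn(batch, heads, seq, head_dim)
    value = torch.randn(batch, heads, seq, head_dim)
    scale_factor = math.sqrt(head_dim)

    manual = (
        torch.matmul(query, key.transpose(-2, -1))
        .div(scale_factor)
        .softmax(dim=-1)
        .matmul(value)
    )
    fused = F.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
    )
    torch.testing.assert_close(manual, fused, rtol=1e-4, atol=1e-4)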

Reviewer Notes:

As discussed with drisspg, this code will be ready for integration into torch._inductor.overrides once the scale_factor parameter is available on torch.nn.functional.scaled_dot_product_attention. Once that is done, the corresponding sections in the unit tests can be re-enabled to verify not just the graph rewrite but also numerical equivalence.

Followup Tasks:

Once torch.nn.functional.scaled_dot_product_attention has a scale_factor parameter, this rewrite can be activated by configuring it as the sdpa implementation, and the inactive code sections in the unit tests can be re-enabled.

Test Plan:
Added several new unit tests. These can be executed selectively by running (Meta-internal tooling):

buck2 test @//mode/dev-nosan caffe2/test/inductor:pattern_matcher -- test_sdpa_rewriter

Differential Revision: D43234635

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @soumith @yanboliang @anijain2305 @desertfire @mlazos

pytorch-bot bot commented Feb 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94729

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 791358f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Feb 13, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: kadeng / name: Kai Londenberg (0c26d4a)

facebook-github-bot (Contributor):

This pull request was exported from Phabricator. Differential Revision: D43234635

kadeng (Contributor, Author) commented Feb 13, 2023

CLA is signed now. I am a Meta employee in Bootcamp.

soumith (Member) commented Feb 13, 2023

/easycla

albanD requested review from ngimel and drisspg and removed the review request for albanD, February 13, 2023 22:12
kadeng added a commit to kadeng/pytorch that referenced this pull request Feb 15, 2023
…rch#94729)

Summary:
Pull Request resolved: pytorch#94729

Many attention-based models implement the scaled dot product attention used by transformer models directly from component PyTorch operators. Recognize these patterns with FX rewrites and convert them into calls to sdpa(), so that the optimized implementation is chosen where applicable.

Implementation

Given that the usual scaled dot product attention equation references a shape variable (everything gets normalized by dividing by sqrt(embedding_dim_size)), the graph rewrite is not entirely trivial.

It was necessary to introduce a new function, torch.nn.scale_factor_dot_product_attention, which accepts a scale factor as a parameter. This in turn allows us to rewrite any pattern like

  torch.matmul(query, key.transpose(-2, -1))
        .div(scale_factor)
        .softmax(dim=-1)
        .matmul(value)

into a corresponding (potentially optimized) API call like

  scale_factor_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale_factor=scale_factor,
    )

without knowing the scale factor or query and key shapes at compile time.

As I consider this rewriter an alpha-quality feature, a new **feature switch** was introduced at *torch._inductor.pattern_replace_scaled_dot_product_attention*, which defaults to False.
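
A minimal usage sketch, assuming the switch lands exactly where the text above places it (the flag name is taken from this PR's description and may not exist in released PyTorch):

    import math
    import torch
    import torch._inductor

    # Assumed flag location, taken from this PR's description; not part
    # of released PyTorch.
    torch._inductor.pattern_replace_scaled_dot_product_attention = True

    def attention(query, key, value, scale_factor):
        return (
            torch.matmul(query, key.transpose(-2, -1))
            .div(scale_factor)
            .softmax(dim=-1)
            .matmul(value)
        )

    compiled = torch.compile(attention)
    q = k = v = torch.randn(2, 4, 16, 32)
    out = compiled(q, k, v, math.sqrt(32))  # pattern eligible for the sdpa rewrite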

torch.nn.scale_factor_dot_product_attention would be unnecessary if we introduced an optional scale_factor parameter on the pre-existing torch.nn.scaled_dot_product_attention.

I chose not to do this because it would suggest that an optimized function is used regardless of the choice of scale_factor, and because of the higher risk of introducing a breaking change.

Also, it should be possible to push the logic that currently lives in native code within torch.nn.functional.scale_factor_dot_product_attention into Python. It is not clear what the benefits of that would be, but maybe you prefer it that way.
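
For illustration, such a Python-level reference might look as follows. This is a hedged sketch: the name, defaults, and mask handling are assumptions rather than the signature from this PR, and only additive float masks are handled.

    import torch
    import torch.nn.functional as F

    def _scale_factor_sdpa_reference(query, key, value, attn_mask=None,
                                     dropout_p=0.0, is_causal=False,
                                     scale_factor=1.0):
        # Hypothetical Python reference for the native function discussed
        # above. Scores are divided by scale_factor, as in the matched
        # pattern.
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale_factor
        if is_causal:
            seq_q, seq_k = scores.shape[-2], scores.shape[-1]
            causal = torch.ones(seq_q, seq_k, device=scores.device).tril().bool()
            scores = scores.masked_fill(~causal, float("-inf"))
        if attn_mask is not None:
            scores = scores + attn_mask  # additive float mask only
        attn = torch.softmax(scores, dim=-1)
        attn = F.dropout(attn, p=dropout_p, training=True)
        return torch.matmul(attn, value)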

Test Plan:
Added several new unit tests. These can be executed selectively by running

  buck2 test @//mode/dev-nosan caffe2/test/inductor:pattern_matcher -- test_scaled_dot_product_attention
  buck2 test @//mode/dev-nosan caffe2/test:test_transformers -- test_scale_factor_dot_product_attention

Differential Revision: D43234635
kadeng (Contributor, Author) commented Feb 15, 2023

Update:
Addressed issues flagged by CI and made torch.nn.functional.scale_factor_dot_product_attention private (it is now called torch.nn.functional._scale_factor_dot_product_attention).

Had to do a force-commit from the patch file after I was unable to merge via export from the Meta-internal diff.

kadeng added a commit to kadeng/pytorch that referenced this pull request Feb 15, 2023
…call in torch.compile() (pytorch#94729)

Summary:
Pull Request resolved: pytorch#94729

### Goals

Many attention-based models directly implement scaled dot product attention used by transformer models using component pytorch operators.  Recognize these with FX rewrites and convert them to calls to sdpa(), such that
the optimized implementation is chosen where applicable.

Implementation

Given that the usual scaled dot product attention equation references a shape variable ( because everything gets normalized by a dividing by sqrt(embedding_dim_size), the graph rewrite is not entirely trivial.

It was neccessary to introduce a new function ( torch.nn.scale_factor_dot_product_attention ).
which accepts a scale factor as parameter. This in turn allows to rewrite any pattern like

  torch.matmul(query, key.transpose(-2, -1))
        .div(scale_factor)
        .softmax(dim=-1)
        .matmul(value)

into a corresponding (potentially optimized) API call like

  scale_factor_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale_factor=scale_factor,
    )

without knowing the scale factor or query and key shapes at compile time.

### Feature Switch

As I would consider this rewriter to be an alpha-quality feature, a new **feature switch** was introduced
at *torch._inductor.pattern_replace_scaled_dot_product_attention* which defaults to False

### Reviewer Notes:

torch.nn.scale_factor_dot_product_attention would be unneccessary if we would introduce an
optional scale_factor parameter to the pre-existing torch.nn.scaled_dot_product_attention.

I chose not to do this, as it would suggest that an optimized function would be used, regardless
of the choice of scale_factor, and due to the higher risk of introducing a breaking change.

Also, it should be possible to push the logic which is now in native code within torch.nn.functional.scale_factor_dot_product_attention into Python code. Not clear what the benefits of that would be, but maybe you prefer it that way.

Test Plan:
Added several new unit tests. These can be executed selectively by running

  buck2 test @//mode/dev-nosan caffe2/test/inductor:pattern_matcher -- test_scaled_dot_product_attention
  buck2 test @//mode/dev-nosan caffe2/test:test_transformers -- test_scale_factor_dot_product_attention

Differential Revision: D43234635

fbshipit-source-id: d5517f78a1e39088bb604d4e27468b574f5eea21
facebook-github-bot (Contributor):

This pull request was exported from Phabricator. Differential Revision: D43234635

jansel added a commit that referenced this pull request Apr 2, 2023
Patterns based on #94729, mainly as a forcing function for implementing joint graph replacements.

Up until now, we had two places to do pattern matching:
1) Pre-grad has janky infra (graph not normalized or functional), but is
   desirable for many types of passes where you want your change to
   affect grad formulas.
2) Post-grad has good infra, but can't change grad formulas.

This PR adds a third place to do pattern matching: the joint
forward+backward graph. The idea is to take the patterns, lower
them to a joint graph, and replace both the forward and backward
passes before we partition them. This allows us to do something
similar to pre-grad transforms, but run after normalization and
functionalization.

Note that we don't seem to have kernels for all of these patterns; some get decomposed in the dispatcher.

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
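
For orientation, the general mechanics of such an FX subgraph rewrite can be sketched with the public torch.fx.subgraph_rewriter API. This is a simplified sketch assuming torch >= 2.0; it is not the inductor pattern matcher used by this work, and a fixed scale is baked into the pattern because this simple form cannot express the shape-dependent scale discussed above:

    import math
    import torch
    import torch.nn.functional as F
    from torch.fx import symbolic_trace, subgraph_rewriter

    SCALE = math.sqrt(32)  # fixed scale; matching a runtime scale is the hard part

    def pattern(query, key, value):
        return (
            torch.matmul(query, key.transpose(-2, -1))
            .div(SCALE)
            .softmax(dim=-1)
            .matmul(value)
        )

    def replacement(query, key, value):
        # Numerically equivalent only when SCALE == sqrt(head_dim), i.e.
        # when it coincides with sdpa's default scaling.
        return F.scaled_dot_product_attention(query, key, value)

    class Model(torch.nn.Module):
        def forward(self, q, k, v):
            scores = torch.matmul(q, k.transpose(-2, -1)).div(SCALE)
            return scores.softmax(dim=-1).matmul(v)

    gm = symbolic_trace(Model())
    matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
    print(f"{len(matches)} match(es) rewritten")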
pytorchmergebot pushed a commit that referenced this pull request Apr 10, 2023
Pull Request resolved: #97741
Approved by: https://github.com/Chillee
skotapati pushed a commit to kulinseth/pytorch that referenced this pull request Apr 10, 2023
ZainRizvi pushed a commit that referenced this pull request Apr 19, 2023
github-actions bot commented Jun 16, 2023
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

github-actions bot added the Stale label Jun 16, 2023
pytorch-bot bot added the release notes: fx label Jun 16, 2023
github-actions bot closed this Jul 16, 2023