Skip to content

Conversation

@leslie-fang-intel
Copy link
Collaborator

@leslie-fang-intel leslie-fang-intel commented Apr 1, 2024

Stack from ghstack (oldest at bottom):

Summary

Refactor the Scheduler.fuse_nodes changes in #121625. In the previous implementation of Scheduler.fuse_nodes in #121625, we use the enable_outer_loop_fusion context to ensure OuterLoopFusion happens after all the norm fusions.

And there is a discussion in https://github.com/pytorch/pytorch/pull/121625/files#r1527177141 to reuse current score_fusion mechanism. However, given that fuse_nodes will invoke fuse_nodes_once 10 times. We are concerned that the score approach may potentially disrupt pairs of regular fusion nodes in the 2rd invocation of fuse_nodes_once if they have been pick up by the outer loop fusion in the 1st invocation of fuse_nodes_once.

In this PR, we propose adding an abstract of filter_possible_fusions_by_priority. In each invoking of fuse_nodes_once, the possible fusions will be grouped by their priority from the backend. And only the group of possible fusions with highest priority will be fused in this invocation. In this way, we can ensure OuterLoopFusion happens after all the norm fusions.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot
Copy link

pytorch-bot bot commented Apr 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123067

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit b7a05be with merge base d8717c2 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@leslie-fang-intel leslie-fang-intel marked this pull request as draft April 1, 2024 02:24
@leslie-fang-intel leslie-fang-intel changed the title Add the possible fusions group by priority [Inductor] Add the possible fusions group by priority Apr 1, 2024
@leslie-fang-intel leslie-fang-intel added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 1, 2024
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request Apr 1, 2024
ghstack-source-id: 7504d75
Pull Request resolved: #123067
@leslie-fang-intel leslie-fang-intel marked this pull request as ready for review April 2, 2024 07:52
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@leslie-fang-intel
Copy link
Collaborator Author

Hi @lezcano @jgong5, for the discussion in #121625 (comment). Here is one possible abstract comes to my mind. Could you kindly help to take a look?

Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So much cleaner! Thank you!

**Summary**

Refactor the `Scheduler.fuse_nodes` changes in #121625. In the previous implementation of `Scheduler.fuse_nodes` in #121625, we use the `enable_outer_loop_fusion` context to ensure `OuterLoopFusion` happens after all the norm fusions.

And there is a discussion in https://github.com/pytorch/pytorch/pull/121625/files#r1527177141 to reuse current `score_fusion` mechanism. However, given that [fuse_nodes](https://github.com/pytorch/pytorch/blob/f4ff063c333f286d4384523bac67c047aca4d7b9/torch/_inductor/scheduler.py#L1679-L1698) will invoke `fuse_nodes_once` 10 times. We are concerned that the score approach may potentially disrupt pairs of regular fusion nodes in the 2rd invocation of `fuse_nodes_once` if they have been pick up by the outer loop fusion in the 1st invocation of `fuse_nodes_once`.

In this PR, we propose adding an abstract of `filter_possible_fusions_by_priority`. In each invoking of `fuse_nodes_once`, the possible fusions will be grouped by their priority from the backend. And only the group of possible fusions with highest priority will be fused in this invocation. In this way, we can ensure `OuterLoopFusion` happens after all the norm fusions.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
**Summary**

Refactor the `Scheduler.fuse_nodes` changes in #121625. In the previous implementation of `Scheduler.fuse_nodes` in #121625, we use the `enable_outer_loop_fusion` context to ensure `OuterLoopFusion` happens after all the norm fusions.

And there is a discussion in https://github.com/pytorch/pytorch/pull/121625/files#r1527177141 to reuse current `score_fusion` mechanism. However, given that [fuse_nodes](https://github.com/pytorch/pytorch/blob/f4ff063c333f286d4384523bac67c047aca4d7b9/torch/_inductor/scheduler.py#L1679-L1698) will invoke `fuse_nodes_once` 10 times. We are concerned that the score approach may potentially disrupt pairs of regular fusion nodes in the 2rd invocation of `fuse_nodes_once` if they have been pick up by the outer loop fusion in the 1st invocation of `fuse_nodes_once`.

In this PR, we propose adding an abstract of `filter_possible_fusions_by_priority`. In each invoking of `fuse_nodes_once`, the possible fusions will be grouped by their priority from the backend. And only the group of possible fusions with highest priority will be fused in this invocation. In this way, we can ensure `OuterLoopFusion` happens after all the norm fusions.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request Apr 2, 2024
ghstack-source-id: b06a961
Pull Request resolved: #123067
Copy link
Collaborator

@jgong5 jgong5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for simplifying the logic here.

**Summary**

Refactor the `Scheduler.fuse_nodes` changes in #121625. In the previous implementation of `Scheduler.fuse_nodes` in #121625, we use the `enable_outer_loop_fusion` context to ensure `OuterLoopFusion` happens after all the norm fusions.

And there is a discussion in https://github.com/pytorch/pytorch/pull/121625/files#r1527177141 to reuse current `score_fusion` mechanism. However, given that [fuse_nodes](https://github.com/pytorch/pytorch/blob/f4ff063c333f286d4384523bac67c047aca4d7b9/torch/_inductor/scheduler.py#L1679-L1698) will invoke `fuse_nodes_once` 10 times. We are concerned that the score approach may potentially disrupt pairs of regular fusion nodes in the 2rd invocation of `fuse_nodes_once` if they have been pick up by the outer loop fusion in the 1st invocation of `fuse_nodes_once`.

In this PR, we propose adding an abstract of `filter_possible_fusions_by_priority`. In each invoking of `fuse_nodes_once`, the possible fusions will be grouped by their priority from the backend. And only the group of possible fusions with highest priority will be fused in this invocation. In this way, we can ensure `OuterLoopFusion` happens after all the norm fusions.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
**Summary**

Refactor the `Scheduler.fuse_nodes` changes in #121625. In the previous implementation of `Scheduler.fuse_nodes` in #121625, we use the `enable_outer_loop_fusion` context to ensure `OuterLoopFusion` happens after all the norm fusions.

And there is a discussion in https://github.com/pytorch/pytorch/pull/121625/files#r1527177141 to reuse current `score_fusion` mechanism. However, given that [fuse_nodes](https://github.com/pytorch/pytorch/blob/f4ff063c333f286d4384523bac67c047aca4d7b9/torch/_inductor/scheduler.py#L1679-L1698) will invoke `fuse_nodes_once` 10 times. We are concerned that the score approach may potentially disrupt pairs of regular fusion nodes in the 2rd invocation of `fuse_nodes_once` if they have been pick up by the outer loop fusion in the 1st invocation of `fuse_nodes_once`.

In this PR, we propose adding an abstract of `filter_possible_fusions_by_priority`. In each invoking of `fuse_nodes_once`, the possible fusions will be grouped by their priority from the backend. And only the group of possible fusions with highest priority will be fused in this invocation. In this way, we can ensure `OuterLoopFusion` happens after all the norm fusions.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
**Summary**

Refactor the `Scheduler.fuse_nodes` changes in #121625. In the previous implementation of `Scheduler.fuse_nodes` in #121625, we use the `enable_outer_loop_fusion` context to ensure `OuterLoopFusion` happens after all the norm fusions.

And there is a discussion in https://github.com/pytorch/pytorch/pull/121625/files#r1527177141 to reuse current `score_fusion` mechanism. However, given that [fuse_nodes](https://github.com/pytorch/pytorch/blob/f4ff063c333f286d4384523bac67c047aca4d7b9/torch/_inductor/scheduler.py#L1679-L1698) will invoke `fuse_nodes_once` 10 times. We are concerned that the score approach may potentially disrupt pairs of regular fusion nodes in the 2rd invocation of `fuse_nodes_once` if they have been pick up by the outer loop fusion in the 1st invocation of `fuse_nodes_once`.

In this PR, we propose adding an abstract of `filter_possible_fusions_by_priority`. In each invoking of `fuse_nodes_once`, the possible fusions will be grouped by their priority from the backend. And only the group of possible fusions with highest priority will be fused in this invocation. In this way, we can ensure `OuterLoopFusion` happens after all the norm fusions.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
leslie-fang-intel added a commit that referenced this pull request Apr 5, 2024
ghstack-source-id: 71b5913
Pull Request resolved: #123067
@leslie-fang-intel
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@leslie-fang-intel
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
**Summary**

Refactor the `Scheduler.fuse_nodes` changes in pytorch#121625. In the previous implementation of `Scheduler.fuse_nodes` in pytorch#121625, we use the `enable_outer_loop_fusion` context to ensure `OuterLoopFusion` happens after all the norm fusions.

And there is a discussion in https://github.com/pytorch/pytorch/pull/121625/files#r1527177141 to reuse current `score_fusion` mechanism. However, given that [fuse_nodes](https://github.com/pytorch/pytorch/blob/f4ff063c333f286d4384523bac67c047aca4d7b9/torch/_inductor/scheduler.py#L1679-L1698) will invoke `fuse_nodes_once` 10 times. We are concerned that the score approach may potentially disrupt pairs of regular fusion nodes in the 2rd invocation of `fuse_nodes_once` if they have been pick up by the outer loop fusion in the 1st invocation of `fuse_nodes_once`.

In this PR, we propose adding an abstract of `filter_possible_fusions_by_priority`. In each invoking of `fuse_nodes_once`, the possible fusions will be grouped by their priority from the backend. And only the group of possible fusions with highest priority will be fused in this invocation. In this way, we can ensure `OuterLoopFusion` happens after all the norm fusions.

Pull Request resolved: pytorch#123067
Approved by: https://github.com/lezcano, https://github.com/jgong5
ghstack dependencies: pytorch#121625
@github-actions github-actions bot deleted the gh/leslie-fang-intel/85/head branch May 6, 2024 01:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants