[user triton] mutation analysis for on-device TMA #155380
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155380
Note: links to docs will display an error until the docs builds have completed. ✅ You can merge normally (1 unrelated failure). As of commit f1f3905 with merge base 79aef14. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Previously, the user-defined Triton kernel mutation analysis would not detect mutation caused by a TMA store if the TMA descriptor was created via on-device TMA creation. This PR fixes on-device TMA store handling. On-device TMA works like this:

```
@triton.jit
def kernel(A_ptr, workspace_ptr, ...):
    tl.extra.cuda.experimental_device_tensormap_create2d(workspace_ptr, A_ptr, ...)
    tl._experimental_descriptor_store(workspace_ptr, data, ...)
```

The first call (`tensormap_create2d`) mutates the contents of `workspace_ptr` to contain descriptor data (including the fact that this TMA descriptor points to `A_ptr`). The second call (`experimental_descriptor_store`) writes to the location specified by the data in `workspace_ptr`: `A_ptr`, in this case.

[TODO caveats / places it doesn't work]

Differential Revision: [D76175117](https://our.internmc.facebook.com/intern/diff/D76175117)
Previously, the user-defined Triton kernel mutation analysis would not detect mutation caused by a TMA store if the TMA descriptor was created via on-device TMA creation. This PR adds partial support for mutation analysis on programs that do stores via on-device TMA. On-device TMA works like this:

```
@triton.jit
def kernel(A_ptr, workspace_ptr, ...):
    tl.extra.cuda.experimental_device_tensormap_create2d(workspace_ptr, A_ptr, ...)
    tl._experimental_descriptor_store(workspace_ptr, data, ...)
```

The first call (`tensormap_create2d`) mutates the contents of `workspace_ptr` to contain descriptor data (including the fact that this TMA descriptor points to `A_ptr`). The second call (`experimental_descriptor_store`) writes to the location specified by the data in `workspace_ptr`: `A_ptr`, in this case.

The approach here is to do a first pass to identify all the `experimental_descriptor_store`s (and collect the associated descriptor values); then, during mutation analysis, any TMA creation on a mutated descriptor value (e.g. on `workspace_ptr` in the above example) will actually register as a mutation to the associated data pointer (e.g. `data` in the above example).

Consider this example, which I'll use to describe the pros and cons of this approach:

```
@triton.jit
def create_tma(global_ptr, workspace_ptr):
    tl.extra.cuda.experimental_device_tensormap_create2d(workspace_ptr, global_ptr, ...)

@triton.jit
def kernel(A, B, workspace_ptr):
    create_tma(A, workspace_ptr)
    workspace_B = workspace_ptr + 128
    create_tma(B, workspace_B)
    data = tl._experimental_descriptor_load(workspace_ptr, ...)
    tl._experimental_descriptor_store(workspace_B, data, ...)
```

An alternative approach would be to modify `tl.extra.cuda.experimental_device_tensormap_create2d` so that it returns a descriptor, and to use that descriptor in subsequent uses (i.e. to "functionalize" the uses of the TMA creation API). However, this would (a) require "functionalization" through any function calls (e.g. to `create_tma`), and (b) lead to both `A` and `B` being marked as mutated (i.e. mutation to `workspace_B` -> mutation to `workspace_ptr` -> mutation to `A`).

A downside of the current approach is that it doesn't understand offsets into workspaces. E.g. if one were to recompute `workspace_B` instead of reusing the variable, the analysis pass would not understand that these values point to the same descriptor.

Differential Revision: [D76175117](https://our.internmc.facebook.com/intern/diff/D76175117)
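The two-pass bookkeeping described above can be sketched in plain Python over a toy op list. This is only an illustration: the op kinds and field names here are hypothetical, and the real analysis walks the kernel's Triton IR rather than a flat list.

```python
# Toy sketch of the two-pass approach (hypothetical op/field names; the real
# analysis operates on Triton IR, not a flat op list like this).

def find_tma_mutations(ops):
    # Pass 1: collect every descriptor value that is stored through.
    stored_descriptors = {
        op["desc"] for op in ops
        if op["kind"] == "experimental_descriptor_store"
    }
    # Pass 2: a TMA creation over a stored-through descriptor registers as a
    # mutation of the global pointer the descriptor was created over.
    mutated = set()
    for op in ops:
        if (op["kind"] == "tensormap_create2d"
                and op["desc"] in stored_descriptors):
            mutated.add(op["global_ptr"])
    return mutated

# Mirrors the kernel(A, B, workspace_ptr) example above: only workspace_B is
# stored through, so only B is reported as mutated -- not A.
ops = [
    {"kind": "tensormap_create2d", "desc": "workspace_ptr", "global_ptr": "A"},
    {"kind": "tensormap_create2d", "desc": "workspace_B", "global_ptr": "B"},
    {"kind": "experimental_descriptor_load", "desc": "workspace_ptr"},
    {"kind": "experimental_descriptor_store", "desc": "workspace_B"},
]
print(find_tma_mutations(ops))  # {'B'}
```

Note how this avoids the over-approximation of the functionalization alternative: the load through `workspace_ptr` contributes nothing, so `A` stays unmarked.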
```
@MemoizeWithCycleCheck
def get_tma_stores(
```
The PR makes sense, but why can't it be part of the analyze call? I.e., why does it need to be a standalone walk?
I wanted to simplify the implementation of analyze_kernel_mutations (e.g. it's easier to test standalone behavior, and we don't have to return multiple types from analyze_kernel_mutations). Is there data suggesting this matters for compilation time (or is there another reason why this should be part of the analyze call)?
No worries around compile time; I was just thinking about it as one function call to deal with all analysis-related business, which is easier for logging, etc.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Previously, the user-defined Triton kernel mutation analysis would not detect mutation caused by a TMA store if the TMA descriptor was created via on-device TMA creation. This PR adds partial support for mutation analysis on programs that do stores via on-device TMA.

On-device TMA works like this:

```
@triton.jit
def kernel(A_ptr, workspace_ptr, ...):
    tl.extra.cuda.experimental_device_tensormap_create2d(workspace_ptr, A_ptr, ...)
    tl._experimental_descriptor_store(workspace_ptr, data, ...)
```

The first call (`tensormap_create2d`) mutates the contents of `workspace_ptr` to contain descriptor data (including the fact that this TMA descriptor points to `A_ptr`). The second call (`experimental_descriptor_store`) writes to the location specified by the data in `workspace_ptr`: `A_ptr`, in this case.

The approach here is to do a first pass to identify all the `experimental_descriptor_store`s (and collect the associated descriptor values); then, during mutation analysis, any TMA creation on a mutated descriptor value (e.g. on `workspace_ptr` in the above example) will actually register as a mutation to the associated data pointer (e.g. `data` in the above example).

Consider this example, which I'll use to describe the pros and cons of this approach:

```
@triton.jit
def create_tma(global_ptr, workspace_ptr):
    tl.extra.cuda.experimental_device_tensormap_create2d(workspace_ptr, global_ptr, ...)

@triton.jit
def kernel(A, B, workspace_ptr):
    create_tma(A, workspace_ptr)
    workspace_B = workspace_ptr + 128
    create_tma(B, workspace_B)
    data = tl._experimental_descriptor_load(workspace_ptr, ...)
    tl._experimental_descriptor_store(workspace_B, data, ...)
```

An alternative approach would be to modify `tl.extra.cuda.experimental_device_tensormap_create2d` so that it returns a descriptor, and to use that descriptor in subsequent uses (i.e. to "functionalize" the uses of the TMA creation API). However, this would (a) require "functionalization" through any function calls (e.g. to `create_tma`), and (b) lead to both `A` and `B` being marked as mutated (i.e. mutation to `workspace_B` -> mutation to `workspace_ptr` -> mutation to `A`).

A downside of the current approach is that it doesn't understand offsets into workspaces. E.g. if one were to recompute `workspace_B` instead of reusing the variable, the analysis pass would not understand that these values point to the same descriptor.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov
Differential Revision: D76175117