
Conversation

Contributor

@davidberard98 commented Jun 6, 2025

Stack from ghstack (oldest at bottom):

Previously, the user-defined triton kernel mutation analysis would not detect mutations caused by TMA stores if the TMA descriptor was created via on-device TMA creation. This PR adds partial support for mutation analysis of programs that store via on-device TMA.

On-device TMA works like this:

```
@triton.jit
def kernel(A_ptr, workspace_ptr, ...):
    tl.extra.cuda.experimental_device_tensormap_create2d(workspace_ptr, A_ptr, ...)
    tl._experimental_descriptor_store(workspace_ptr, data, ...)
```

The first call (tensormap_create2d) mutates the contents of workspace_ptr so that it holds the TMA descriptor data (including the fact that this TMA descriptor points to A_ptr). The second call (experimental_descriptor_store) writes to the location specified by the descriptor in workspace_ptr: A_ptr, in this case.

The approach here is to do a first pass that identifies all of the experimental_descriptor_store calls (and collects the associated descriptor values); then, during mutation analysis, any TMA creation on a descriptor value that is later stored through (e.g. workspace_ptr in the above example) is registered as a mutation of the associated global data pointer (e.g. A_ptr in the above example).
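
To make the two-pass idea concrete, here is a minimal standalone sketch. It is not the actual PyTorch implementation (which walks the kernel's IR); the flat `Call` list and the helper names below are simplified, hypothetical stand-ins.

```
from dataclasses import dataclass

@dataclass
class Call:
    fn: str     # e.g. "experimental_device_tensormap_create2d"
    args: list  # argument names as strings

def collect_tma_store_descriptors(calls):
    # Pass 1: gather every descriptor value that is later stored through.
    return {c.args[0] for c in calls if c.fn == "_experimental_descriptor_store"}

def analyze_mutations(calls):
    stored_descriptors = collect_tma_store_descriptors(calls)
    mutated = set()
    for c in calls:
        if c.fn == "experimental_device_tensormap_create2d":
            workspace, global_ptr = c.args[0], c.args[1]
            mutated.add(workspace)  # the descriptor bytes themselves are written
            if workspace in stored_descriptors:
                # A later TMA store writes through this descriptor, so the global
                # pointer the descriptor was created for is mutated as well.
                mutated.add(global_ptr)
    return mutated

# The kernel above, flattened into calls:
calls = [
    Call("experimental_device_tensormap_create2d", ["workspace_ptr", "A_ptr"]),
    Call("_experimental_descriptor_store", ["workspace_ptr", "data"]),
]
print(sorted(analyze_mutations(calls)))  # ['A_ptr', 'workspace_ptr']
```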

Consider this example, which I'll use to describe the pros and cons of this approach:

```
@triton.jit
def create_tma(global_ptr, workspace_ptr):
    tl.extra.cuda.experimental_device_tensormap_create2d(workspace_ptr, global_ptr, ...)

@triton.jit
def kernel(A, B, workspace_ptr):
    create_tma(A, workspace_ptr)
    workspace_B = workspace_ptr + 128
    create_tma(B, workspace_B)
    data = tl._experimental_descriptor_load(workspace_ptr, ...)
    tl._experimental_descriptor_store(workspace_B, data, ...)
```

An alternative approach would be to modify tl.extra.cuda.experimental_device_tensormap_create2d so that it returns a descriptor, and to thread that descriptor through subsequent uses (i.e. to "functionalize" the TMA creation API). However, this would (a) require "functionalization" through any intermediate function calls (e.g. through create_tma), and (b) lead to both A and B being marked as mutated (i.e. a mutation to workspace_B implies a mutation to workspace_ptr, which implies a mutation to A).
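
For concreteness, a hypothetical functionalized variant (not what this PR implements, written in the same abbreviated style as the examples above) would look roughly like this:

```
@triton.jit
def create_tma(global_ptr, workspace_ptr):
    # (a) every helper would have to return the descriptor so callers can thread it through
    return tl.extra.cuda.experimental_device_tensormap_create2d(workspace_ptr, global_ptr, ...)

@triton.jit
def kernel(A, B, workspace_ptr):
    desc_A = create_tma(A, workspace_ptr)
    desc_B = create_tma(B, workspace_ptr + 128)
    data = tl._experimental_descriptor_load(desc_A, ...)
    # (b) desc_B is derived from workspace_ptr, so this store would be treated as a
    # mutation of workspace_ptr and therefore of A as well as B
    tl._experimental_descriptor_store(desc_B, data, ...)
```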

A downside of the current approach is that it doesn't understand offsets into workspaces: for example, if one were to recompute workspace_ptr + 128 instead of reusing the workspace_B variable, the analysis pass would not understand that the two values refer to the same descriptor.
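
For example (a hypothetical variant of `kernel`, again in the abbreviated style above), recomputing the offset defeats the analysis:

```
@triton.jit
def kernel(A, B, workspace_ptr):
    create_tma(A, workspace_ptr)
    create_tma(B, workspace_ptr + 128)
    data = tl._experimental_descriptor_load(workspace_ptr, ...)
    # `workspace_ptr + 128` here is a fresh value, so the analysis does not connect this
    # store to the create_tma(B, ...) call above, and the mutation of B is missed.
    tl._experimental_descriptor_store(workspace_ptr + 128, data, ...)
```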

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Differential Revision: D76175117

@davidberard98 requested a review from zou3519 as a code owner June 6, 2025 23:50

pytorch-bot bot commented Jun 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155380

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit f1f3905 with merge base 79aef14:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

davidberard98 added a commit that referenced this pull request Jun 7, 2025
ghstack-source-id: 50e2458
Pull Request resolved: #155380
Contributor Author

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot bot added the ciflow/trunk label Jun 7, 2025
davidberard98 added a commit that referenced this pull request Jun 9, 2025
ghstack-source-id: 916936e
Pull Request resolved: #155380
@davidberard98 requested review from aakhundov and oulgen June 9, 2025 20:12
davidberard98 added a commit that referenced this pull request Jun 9, 2025
ghstack-source-id: f6d9b75
Pull Request resolved: #155380

```
@MemoizeWithCycleCheck
def get_tma_stores(
```
Contributor

the PR makes sense, but why can't it be part of the analyze call, i.e. why does it need to be a standalone walk?

Contributor Author

I wanted to simplify the implementation of analyze_kernel_mutations (e.g. it's easier to test standalone behavior, and we don't have to return multiple types from analyze_kernel_mutations). Is there data suggesting that this matters for compilation time (or is there another reason why this should be part of the analyze call)?

Contributor

no worries around compile time, I was just thinking of it as one function call to deal with all analysis-related business, easier for logging etc.

@davidberard98
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@github-actions bot deleted the gh/davidberard98/367/head branch July 13, 2025 02:22