Add new (private) capture_triton API #130178
Conversation
When applied to a triton kernel, `capture_triton` allows the triton kernel to be captured when tracing with `make_fx`. It does this by transforming the call to the triton kernel into a call to the `triton_kernel_wrapper_mutation` HOP, which can actually be traced into a graph via `make_fx`.

We have two main use cases for this:
- Non-strict export doesn't use Dynamo, but people want to use non-strict export to export programs with triton kernels. Non-strict export uses `make_fx` tracing, so this is a necessary step in that direction.
- People want to write inductor passes that replace a sequence of operators with a call to a function that may contain a triton kernel. The way these passes work today is that we have an FX graph and want to replace a subgraph of it with a new subgraph. We obtain said subgraph by calling `make_fx` on the function; this won't work on raw triton kernels, but will work if one uses `capture_triton`.

Test Plan:
- I wrote some manual tests to run `make_fx` over two of the triton kernels in test_triton_kernels. It would be nice to be able to run `make_fx` through all of the tests in the file, but I'm not sure how to do that refactor right now.
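The mechanism described above — rerouting a decorated function's call through a wrapper that a tracer can record as a single graph node — can be sketched in plain Python. This is an illustrative stand-in, not the actual PyTorch implementation: `TRACE`, `TRACING`, and `wrapper_mutation` here are hypothetical simplifications of the FX graph, the `make_fx` tracing context, and the `triton_kernel_wrapper_mutation` HOP.

```python
# Sketch only: models how a decorator can reroute calls through a
# traceable wrapper, the way capture_triton routes triton kernel calls
# through the triton_kernel_wrapper_mutation HOP. All names hypothetical.
TRACE = []          # stands in for the FX graph being built
TRACING = False     # stands in for "are we inside make_fx tracing?"

def wrapper_mutation(kernel, args, kwargs):
    # A tracer can record this single call site into a graph,
    # even though the kernel body itself is opaque to it.
    TRACE.append((kernel.__name__, args, kwargs))
    return kernel(*args, **kwargs)

def capture_triton(kernel):
    def inner(*args, **kwargs):
        if TRACING:
            return wrapper_mutation(kernel, args, kwargs)
        return kernel(*args, **kwargs)  # eager path: call directly
    return inner

@capture_triton
def add_kernel(x, y):
    return x + y

TRACING = True
result = add_kernel(1, 2)
```

After the call, `TRACE` holds one recorded node for `add_kernel` instead of whatever opaque work the kernel body did — which is the property that lets `make_fx` put the call into a graph.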
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130178. Note: links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 8b770a8 with merge base a5f816d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
    return grid(meta)

def check_grid(self, grid):
    if not isinstance(grid, tuple):
```
- I haven't actually verified this, but is it really only ever a tuple? Would, say, a list fail here? IIRC triton only reads it as a decomposition, but good to verify.
- Not sure how the orchestration here happens, but if grid is a function, do we already have it fully resolved before here? Can you point to who calls these?
> I haven't actually verified this but is it really only ever tuple? would say a list fail here?

A list works here, so I'll update the code to convert to tuple (and add a test for it).

> not sure how the orchestration here happens but if grid is a function, we do already have it fully resolved before here? can you point to who calls these?

If grid is a function, it's fully resolved before here. The orchestrator resolves the function before calling `check_grid`; the code for this is over at https://github.com/pytorch/pytorch/pull/130178/files#diff-f5fa7d0e418e91c63fa56d577a92a294c87e19318a3c8b3736ac4254eaa51db9R907-R915
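The resolution described in this exchange — callables resolved first by the orchestrator, lists then converted to tuples — could be handled by a small normalization helper. This is a sketch under those assumptions, not the actual `check_grid` code; `normalize_grid` is a hypothetical name.

```python
def normalize_grid(grid, meta=None):
    # Sketch of grid normalization: callables are resolved first
    # (as the orchestrator does before check_grid is reached), then
    # lists are converted to tuples so downstream code sees one type.
    if callable(grid):
        grid = grid(meta if meta is not None else {})
    if isinstance(grid, list):
        grid = tuple(grid)
    if not isinstance(grid, tuple):
        raise TypeError(f"grid must resolve to a tuple, got {type(grid)}")
    return grid

as_list = normalize_grid([1, 2])             # list -> tuple
as_fn = normalize_grid(lambda meta: (4,))    # callable resolved, then checked
as_tuple = normalize_grid((8, 1, 1))         # passes through unchanged
```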
```python
non_graphable_args = {
    k: v
    for k, v in combined_args.items()
    if not isinstance(v, (torch.Tensor, int, float, bool))
}
```
How did you come up with this list? I'm pretty sure string also belongs in it.
It was from memory; I can add string to it. I'll go digging to see if there's a helper function for this in torch.fx already.
The side table in this file also contains a check for this; maybe use the same one?
> Side table in this file also contains a check for this, maybe use the same one as that one?

AFAICT, in the Dynamo path we put all constant args into the side table. In make_fx there's no concept of constant vs non-constant args, so it doesn't look like we can reuse that check?
Ah, OK. I did that because `triton.dtype` was passed as an arg and we could not put that on the FX graph.
torch.fx has a way of checking this: it encodes the acceptable types in `BaseArgumentType`. I'll use that check.
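The filtering under discussion — splitting kernel arguments into those that can live on an FX graph and those that cannot — can be sketched in plain Python. The allowed-type set below is a simplified stand-in for torch.fx's acceptable argument types, and `split_args` is a hypothetical helper name, not the PR's actual code.

```python
# Sketch: partition a kernel's combined args by whether they are
# representable on an FX graph. GRAPHABLE_TYPES is a simplified
# stand-in (the real check also covers torch.Tensor, via torch.fx).
GRAPHABLE_TYPES = (int, float, bool, str)

def split_args(combined_args):
    graphable = {k: v for k, v in combined_args.items()
                 if isinstance(v, GRAPHABLE_TYPES)}
    non_graphable = {k: v for k, v in combined_args.items()
                     if not isinstance(v, GRAPHABLE_TYPES)}
    return graphable, non_graphable

# A dtype-like object (modeled by object() here) ends up non-graphable,
# matching the triton.dtype case mentioned in the discussion above.
g, ng = split_args({"n": 4, "name": "x", "dtype": object()})
```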
```python
if not is_fx_tracing() or torch._dynamo.is_compiling():
    assert self.kernel is not None
    return self.kernel[self.grid](*args, **kwargs)
```
Be consistent with above:

```diff
- return self.kernel[self.grid](*args, **kwargs)
+ return self.kernel.run(*args, **kwargs, grid=self.grid)
```
Getting `TypeError: JITFunction.run() missing 1 required keyword-only argument: 'warmup'` when I use `.run` here; doing the `__call__` doesn't seem to require the warmup arg.
Pass `warmup=False`; `__call__` does that: https://github.com/triton-lang/triton/blob/c14b033cd979d5c39e5fdb3847c022fa5d71a0c1/python/triton/runtime/jit.py#L326C59-L326C72
Thank you, that worked
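The `TypeError` in this exchange is ordinary Python behavior for a required keyword-only parameter. A minimal reproduction of that error mode (plain Python, not triton code; `run` here is a toy stand-in for `JITFunction.run`):

```python
# Toy stand-in for a signature like JITFunction.run, where a
# parameter after *args with no default is keyword-only and required.
def run(*args, grid=None, warmup):
    return ("ran", args, grid, warmup)

try:
    run(1, 2, grid=(4,))        # omits warmup -> TypeError
    missing_warmup_raised = False
except TypeError:
    missing_warmup_raised = True

ok = run(1, 2, grid=(4,), warmup=False)  # explicit warmup works
```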
@pytorchbot merge -f "rocm tests not ending; everything else passed"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#130178
Approved by: https://github.com/oulgen
ghstack dependencies: pytorch#130177