
Conversation

@zou3519 zou3519 (Contributor) commented Jul 5, 2024

Stack from ghstack (oldest at bottom):

When applied to a triton kernel, capture_triton allows the kernel to be
captured when tracing with make_fx. It does this by transforming calls to
the triton kernel into calls to the triton_kernel_wrapper_mutation HOP,
which make_fx can actually trace into a graph.
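
A minimal sketch of the intended usage is below. The kernel, the wrapper function, and the import path for capture_triton are illustrative assumptions, not code from this PR:

```python
import torch
import triton
import triton.language as tl
from torch.fx.experimental.proxy_tensor import make_fx

# Assumed import location for capture_triton; the actual path may differ.
from torch._higher_order_ops.triton_kernel_wrap import capture_triton


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x, y):
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 16),)
    # Wrapping the kernel lets make_fx record a call to the
    # triton_kernel_wrapper_mutation HOP instead of failing on the raw kernel.
    capture_triton(add_kernel)[grid](x, y, out, n_elements, BLOCK_SIZE=16)
    return out


x = torch.randn(32, device="cuda")
y = torch.randn(32, device="cuda")
gm = make_fx(add)(x, y)
gm.print_readable()  # the traced graph contains triton_kernel_wrapper_mutation
```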

We have two main use cases for this:

  • Non-strict export doesn't use Dynamo, but people want to use it to
    export programs with triton kernels. Since non-strict export uses
    make_fx tracing, this is a necessary step in that direction.
  • People want to write inductor passes that replace a sequence of
    operators with a call to a function that may contain a triton kernel.
    The way these passes work today is that we have an FX graph and want to
    replace a subgraph of it with a new subgraph, which we obtain by calling
    make_fx on the function. This doesn't work on raw triton kernels but
    does work if the kernel is wrapped with capture_triton.

Test Plan:

  • I wrote some manual tests that run make_fx over two of the triton
    kernels in test_triton_kernels. It would be nice to run make_fx through
    all of the tests in that file, but I'm not sure how to do that refactor
    right now.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang


pytorch-bot bot commented Jul 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130178

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8b770a8 with merge base a5f816d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zou3519 added a commit that referenced this pull request Jul 5, 2024
ghstack-source-id: fcc9614
Pull Request resolved: #130178
@zou3519 zou3519 requested review from eellison and oulgen July 8, 2024 13:47
        return grid(meta)

    def check_grid(self, grid):
        if not isinstance(grid, tuple):
@oulgen oulgen (Contributor) commented Jul 8, 2024:

  1. I haven't actually verified this, but is it really only ever a tuple? Would, say, a list fail here? IIRC triton only reads it as decomposition, but good to verify.

  2. Not sure how the orchestration here happens, but if grid is a function, do we already have it fully resolved before here? Can you point to who calls these?

@zou3519 zou3519 (Contributor Author) replied:

  1. I haven't actually verified this, but is it really only ever a tuple? Would, say, a list fail here? IIRC triton only reads it as decomposition, but good to verify.

A list works here, so I'll update the code to convert to tuple (and add a test for it)

  2. Not sure how the orchestration here happens, but if grid is a function, do we already have it fully resolved before here? Can you point to who calls these?

If grid is a function, it's fully resolved before here. The orchestrator resolves the function before calling check_grid; the code for this is over at https://github.com/pytorch/pytorch/pull/130178/files#diff-f5fa7d0e418e91c63fa56d577a92a294c87e19318a3c8b3736ac4254eaa51db9R907-R915
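
For reference, a minimal sketch of the list-to-tuple normalization mentioned above; the helper name and signature are made up for illustration and are not the actual code in this PR:

```python
from typing import Sequence, Tuple


def normalize_grid(grid: Sequence[int]) -> Tuple[int, ...]:
    # Accept both lists and tuples from users and canonicalize to a tuple so
    # downstream checks only need to handle one type. Callable grids are
    # assumed to have already been resolved to a concrete sequence by the
    # caller (see the orchestration code linked above).
    if not isinstance(grid, (list, tuple)):
        raise TypeError(f"grid must be a list or tuple, got {type(grid)}")
    return tuple(grid)
```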

non_graphable_args = {
    k: v
    for k, v in combined_args.items()
    if not isinstance(v, (torch.Tensor, int, float, bool))
}
@oulgen oulgen (Contributor):

How did you come up with this list? Pretty sure string also belongs in it.

@zou3519 zou3519 (Contributor Author) replied:

The list was from memory; I can add string to it. I'll go digging to see if there's a helper function for this in torch.fx already.

@oulgen oulgen (Contributor):

Side table in this file also contains a check for this, maybe use the same one as that one?

@zou3519 zou3519 (Contributor Author) Jul 9, 2024:

Side table in this file also contains a check for this, maybe use the same one as that one?

AFAICT, in the Dynamo path we put all constant args into the side table. In make_fx there's no concept of constant vs. non-constant args, so it doesn't look like we can reuse that check?

@oulgen oulgen (Contributor):

Ah ok, I did that because triton.dtype was passed as an arg and we couldn't put that on the FX graph.

@zou3519 zou3519 (Contributor Author) replied:

torch.fx has a way of checking this -- it encodes the acceptable types in BaseArgumentTypes. I'll use that check.
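
A rough sketch of what that check could look like, assuming BaseArgumentTypes is the typing.Union alias in torch.fx.node; combined_args and the helper name are placeholders rather than the actual diff:

```python
import typing

from torch.fx.node import BaseArgumentTypes

# BaseArgumentTypes is assumed to be a typing.Union of the leaf types an FX
# graph can hold directly (str, int, float, bool, torch.Tensor, torch.dtype,
# ...); get_args turns it into a tuple of classes usable with isinstance.
_GRAPHABLE_TYPES = typing.get_args(BaseArgumentTypes)


def split_graphable_args(combined_args):
    graphable = {k: v for k, v in combined_args.items() if isinstance(v, _GRAPHABLE_TYPES)}
    non_graphable = {k: v for k, v in combined_args.items() if not isinstance(v, _GRAPHABLE_TYPES)}
    return graphable, non_graphable
```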


    if not is_fx_tracing() or torch._dynamo.is_compiling():
        assert self.kernel is not None
        return self.kernel[self.grid](*args, **kwargs)
@oulgen oulgen (Contributor):

be consistent with above

Suggested change:
- return self.kernel[self.grid](*args, **kwargs)
+ return self.kernel.run(*args, **kwargs, grid=self.grid)

@zou3519 zou3519 (Contributor Author) replied:

Getting "TypeError: JITFunction.run() missing 1 required keyword-only argument: 'warmup'" when I use .run here -- going through __call__ doesn't seem to require the warmup arg.

@oulgen oulgen (Contributor):

@zou3519 zou3519 (Contributor Author) replied:

Thank you, that worked
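
For readers following along, a sketch of how the .run spelling can be kept working by passing warmup explicitly, reusing the names from the add_kernel sketch near the top of this page; whether this matches the exact fix applied in the PR is an assumption:

```python
# On the triton version in question, warmup is a required keyword-only
# argument of JITFunction.run but not of __call__; warmup=False launches the
# kernel normally instead of only compiling it.
add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=16)
add_kernel.run(x, y, out, n_elements, BLOCK_SIZE=16, grid=grid, warmup=False)
```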

zou3519 added a commit that referenced this pull request Jul 9, 2024
ghstack-source-id: 4b6b203
Pull Request resolved: #130178
@zou3519 zou3519 added the ciflow/trunk and topic: not user facing labels Jul 9, 2024
zou3519 added a commit that referenced this pull request Jul 9, 2024
ghstack-source-id: 5292d93
Pull Request resolved: #130178
@zou3519 zou3519 added the ci-no-td (Do not run TD on this PR) label Jul 10, 2024
@zou3519 zou3519 (Contributor Author) commented Jul 10, 2024

@pytorchbot merge -f "rocm tests not ending; everything else passed"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
Pull Request resolved: pytorch#130178
Approved by: https://github.com/oulgen
ghstack dependencies: pytorch#130177
@github-actions github-actions bot deleted the gh/zou3519/1018/head branch August 10, 2024 01:58

Labels: ci-no-td, ciflow/trunk, Merged, module: inductor, topic: not user facing
