[Dynamo] Add native support for Triton Kernels to Dynamo #109623

Closed
wants to merge 21 commits

Conversation

oulgen
Contributor

@oulgen oulgen commented Sep 19, 2023

Stack from ghstack (oldest at bottom):

This PR adds native support to Dynamo to detect Triton kernels and
create an FX graph node out of them. AOT eager and inductor modes will
be supported in follow-up PRs.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

@pytorch-bot

pytorch-bot bot commented Sep 19, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109623

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9939238 with merge base 40b83d9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

oulgen added a commit that referenced this pull request Sep 19, 2023

# If the grid is a function, then lets execute it and convert it to
# a list
if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
Contributor Author

These two function variables and list variables (as per below) are supported. Do I need to support any other type? Triton's docs only ever use lambda and tuple.

Contributor

I think this is fine, we can add more later if needed.
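
For reference, a minimal sketch of the two grid forms the Triton docs use, a plain tuple and a lambda over the launch meta-parameters. This assumes Triton and a CUDA device are available; the kernel and tensor names are illustrative.

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)

# Grid as a plain tuple:
add_kernel[(triton.cdiv(x.numel(), 64),)](x, y, out, x.numel(), BLOCK_SIZE=64)

# Grid as a lambda over the kernel's meta-parameters:
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=64)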

value,
None, # No grid provided
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
Contributor Author

Is function match correct here?

Contributor

I think we should use ID_MATCH

oulgen added a commit that referenced this pull request Sep 19, 2023
@oulgen oulgen requested a review from zou3519 September 19, 2023 21:22
oulgen added a commit that referenced this pull request Sep 19, 2023
@@ -0,0 +1,13 @@
import functools
Contributor Author

This file is copy-pasted from the inductor utils; I cannot use that one directly since it creates a circular dependency. If this looks good, I can have inductor also use this util, although there appear to be two separate implementations of get_device_capability.

Contributor

let's move things rather than copying them.

I believe the linter will complain unless the filename starts with _ (indicating it is not a public API).

Contributor Author

Okay, I will create a new PR under this one that moves everything to this new API. I was worried about the two different implementations of get_device_capability

oulgen added a commit that referenced this pull request Sep 20, 2023
@ezyang
Contributor

ezyang commented Sep 20, 2023

Err, I wasn't privy to whatever initial design discussions you had @oulgen. Is there something in particular you want me to review haha

test/dynamo/test_functions.py (resolved conversation)

Comment on lines 680 to 681
(grid,) + proxied_args,
proxied_kwargs,
Contributor

Grid as a kwarg would be easier to read in output code

Suggested change:
- (grid,) + proxied_args,
- proxied_kwargs,
+ proxied_args,
+ {**proxied_kwargs, "grid": grid},

also need to change call_function above.


@zou3519
Contributor

zou3519 commented Sep 20, 2023

Btw, I would recommend figuring out the "explicit API to turn a triton kernel into a custom operator" first, or in parallel with this PR, because there are some additional things we would need to teach Dynamo in order for it to be able to take the triton kernel and the surrounding code and call the explicit API.

Concretely, in order to push the triton kernel through AOTAutograd:

  • Operators need a "schema" that describes which args are inputs and which are outputs: we need to know which args to the triton kernel are inputs and which are outputs. Either a user marks these explicitly, or Dynamo needs to tell us by actually introspecting the triton kernel (a minimal schema sketch follows after this list).
  • We can turn the triton kernel into an in-place operator. AOTAutograd needs to functionalize this, so we need an out-of-place variant of the triton kernel that is responsible for (1) allocating the outputs of the triton kernel and (2) passing them to the triton kernel. We would need to pass the "explicit API" a function describing how to allocate the outputs, and Dynamo would need to browse the code surrounding the triton kernel to come up with such a function.
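
For illustration only, a hedged sketch of what such a schema looks like in PyTorch's operator registration, where the (a!) alias annotation marks an argument as mutated; the namespace and op name here are hypothetical, not anything this PR defines.

import torch

# Hypothetical registration: the "(a!)" annotation tells the dispatcher and
# AOTAutograd that `out` is mutated (an output), while `x` is a read-only input.
lib = torch.library.Library("mylib", "DEF")
lib.define("sin_out(Tensor x, Tensor(a!) out) -> ()")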

@jansel
Contributor

jansel commented Sep 21, 2023

IMO we shouldn't require any user annotations for Triton kernels. This seems like worse UX than regular Triton.

  • Operators need a "schema" that describes which args are inputs and which are outputs: we need to know which args to the triton kernel are inputs and which are outputs. Either a user marks these explicitly, or Dynamo needs to tell us by actually introspecting the triton kernel

Can't we just conservatively assume everything is both an input and an output?

  • We can turn the triton kernel into an in-place operator. AOTAutograd needs to functionalize this, so, we need an out-of-place variant of the triton kernel that is responsible for (1) allocating the outputs of the triton kernel and (2) passing them to the triton kernel. We would need to pass the "explicit API" a function describing how to allocate the outputs and Dynamo would need to browse the code surrounding the triton kernel to come up with such a function

I don't think all of this is actually needed. If we do this, then when it gets to inductor we will need to reverse it to have efficient code. I think we can do a simple rewrite like:

triton_kernel(a, b, c)

becomes:

a, b, c = wrapped_triton_kernel(triton_kernel, a, b, c)

Where we basically just assume all args to the triton kernel are mutated in functionalization. The meta function for this just returns outputs with the same shapes as its inputs, and it can be made functional by copying the inputs before the kernel is called.

This will be easier to codegen since it requires zero metadata, and also easier for inductor to reverse.

@bdhirsh
Contributor

bdhirsh commented Sep 21, 2023

Where we basically just assume all args to the triton kernel are mutated in functionalization. The meta function for this just returns outputs with the same shapes as its inputs, and it can be made functional by copying the inputs before the kernel is called.

Is the idea that wrapped_triton_kernel adds clone() calls before the triton kernel to make it functional, and inductor removes them?

One risk is that this won't play well with aliasing. Example:

def post_aot_functionalized_model(...):
    a, b, c = ...
    # b is mutated by the user's triton kernel, so any of b's aliases will also be mutated
    b_view = b.view(-1)
    a, b, c = triton_kernel_wrapper(triton_fn, a, b, c)
    # If functionalization is incorrectly told that "triton_kernel_wrapper" is functional,
    # then we won't know to update b's aliases during functionalization.

@jansel
Contributor

jansel commented Sep 21, 2023

Is the idea that wrapped_triton_kernel adds clone() calls before the triton kernel to make it functional, and inductor removes them?

Yes, correct.

One risk is that this won't play well with aliasing.

I think this conversion might need to happen during functionalization. You could have something like:

wrapped_triton_kernel_(triton_kernel, a0, b0, c0)

No need for custom user-provided metadata on this, because it returns None and we assume it mutates all of its inputs.

Then functionalization changes it to:

a1, b1, c1 = wrapped_triton_kernel(triton_kernel, a0, b0, c0)

and rewrites all remaining uses of a/b/c to use a1/b1/c1.

The reference implementation could be just:

def wrapped_triton_kernel(fn, *args):
    # Clone tensor args so the caller's inputs are not mutated (functional form)
    args = [(clone_preserve_strides(x) if isinstance(x, Tensor) else x)
            for x in args]
    fn(*args)
    return args

This type of conversion is exactly the same as what functionalization does for relu_ -> relu, so should be easy to implement.

Basically there should be no need for authors of Triton kernels to provide any metadata, because Triton kernels cannot allocate tensors. Triton kernels also have no concept of sizes/strides, nor is there any restriction on what is an input or an output (or both).

cc @bdhirsh @zou3519

@zou3519
Contributor

zou3519 commented Sep 21, 2023

@jansel IMO getting the triton kernel through the custom op mechanism is the right long-term solution (and does not make it more difficult for inductor to reverse).

On the long-term: people will want to compose the triton kernel with transforms like vmap or teach subsystems like DTensor how to behave when it sees the triton kernel. The right way to do this is to integrate with the PyTorch dispatcher: turning the triton kernel into a "custom operator" allows the PyTorch dispatcher to interpose on it. In the long-term, folks will use an explicit API to deal with this, but we can support automatic Dynamo capture of triton kernels into this explicit API to support the way people author existing triton kernels.

I don't think all of this is actually needed. If we do this, then when it gets to inductor we will need to reverse it to have efficient code.

For inductor reversal: my proposal is, given a triton kernel, to turn it into two custom ops: a functional op and an out variant op. The out variant op is exactly the triton kernel, and the functional op allocates output tensors and invokes the out variant. After AOTAutograd runs, the functional op will exist in the IR that gets passed to Inductor. (A sketch of this split follows after the list below.)

When Inductor sees the functional op:

  • it does memory planning by converting the functional op into the out variant
  • when it sees the out variant, it can look up the triton kernel in a side table and then do what it needs to do with it.
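
A rough sketch of that two-op split, with hypothetical names; sin_kernel is assumed to be a @triton.jit kernel with signature (in_ptr, out_ptr, n_elements, BLOCK_SIZE), and nothing here is the code this PR lands.

import torch
import triton

# Hypothetical out-variant op: a thin wrapper around the raw triton launch,
# mutating `out` in place.
def sin_out(x: torch.Tensor, out: torch.Tensor) -> None:
    n = x.numel()
    sin_kernel[(triton.cdiv(n, 1024),)](x, out, n, BLOCK_SIZE=1024)

# Hypothetical functional op: allocates the outputs, then calls the out variant.
# This is what AOTAutograd would see; Inductor would convert it back to the
# out variant during memory planning.
def sin_functional(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    sin_out(x, out)
    return out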

Can't we just conservatively assume everything is both an input and an output? ...
I think we can do a simple rewrite like:
triton_kernel(a, b, c)
becomes:
a, b, c = wrapped_triton_kernel(triton_kernel, a, b, c)

I'm still thinking through this. But I'm not convinced we can generate a correct functionalization rule for this op without taking some performance hits. One particular performance issue is that we are no longer able to reorder usages of (a, b) with respect to wrapped_triton_kernel(triton_kernel, a, b, c)

@jansel
Contributor

jansel commented Sep 21, 2023

On the long-term: people will want to compose the triton kernel with transforms like vmap or teach subsystems like DTensor how to behave when it sees the triton kernel. The right way to do this is to integrate with the PyTorch dispatcher: turning the triton kernel into a "custom operator" allows the PyTorch dispatcher to interpose on it. In the long-term, folks will use an explicit API to deal with this, but we can support automatic Dynamo capture of triton kernels into this explicit API to support the way people author existing triton kernels.

Given how Triton kernels work, I'm not sure that this is coherent. Triton kernels are just too low level for this to make sense.

One could use (possibly multiple) Triton kernels to build a custom op like this, but it would require a lot of extra work on the part of the user. There are also some types of Triton kernels that would not fit into this model.

We could always do both. Vanilla Triton kernels work out of the box, and you can optionally use the same API as one would for custom CUDA kernels (which would require more work to set up).

I'm still thinking through this. But I'm not convinced we can generate a correct functionalization rule for this op without taking some performance hits. One particular performance issue is that we are no longer able to reorder usages of (a, b) with respect to wrapped_triton_kernel(triton_kernel, a, b, c)

I think inductor should be able to optimize away the extra copies.

@ezyang
Contributor

ezyang commented Sep 21, 2023

If you conservatively assume that all inputs are read-write, then you force functionalization to pessimize in this way:

# x input
triton_kernel(x, out)
z = f(x)

to

x2 = x.clone()
triton_kernel(x2, out)
z = f(x)

and there isn't really any way that Inductor can optimize this away.

oulgen added a commit that referenced this pull request Sep 26, 2023
oulgen added a commit that referenced this pull request Sep 26, 2023
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
Collaborator

You can't do this: even though you have HAS_TRITON around it, the type hint is evaluated at a different time. Try "tl.constexpr", i.e. the type as a string. If that doesn't work, feel free to omit it.
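
A sketch of the suggested workaround, assuming Triton accepts the annotation written as a string (so it is not evaluated at module import time); the kernel body here is illustrative, not the test code in this PR.

import triton
import triton.language as tl

@triton.jit
def mul2_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",  # annotation as a string, not evaluated at import time
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets, 2 * tl.load(in_ptr0 + offsets, mask=mask), mask=mask)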

oulgen added a commit that referenced this pull request Sep 26, 2023
oulgen added a commit that referenced this pull request Sep 26, 2023
@oulgen oulgen requested a review from jansel September 27, 2023 02:06
Comment on lines 9 to 10
def __call__(self, *args, kernel, grid, **kwargs):
kernel[grid](*args, **kwargs)
Contributor

Is there a reason why we don't do (self, kernel, grid, args: Tuple, kwargs: Dict) instead? That would be a bit cleaner (what if the triton kernel has a kwarg named grid or kernel?)

Contributor Author

Because of functools.partial, kernel is passed as a kwarg, so it has to come after the positional args. I made grid a positional arg earlier on, but @jansel asked me to make it a kwarg as well, so I needed to move both after the positional *args list.

It is possible that I misunderstand how positional args and kwargs work in Python, but doing (self, kernel, grid, args: Tuple, kwargs: Dict) gives an error saying that kernel and grid are defined twice, once as positional and once as keyword.

Contributor

Let's say we had a triton kernel that accepted a kernel argument (unlikely, but just going through the exercise):

@triton.jit
def sin_kernel(kernel, out):
    pass

and a function like

@torch.compile
def f(x):
   out = torch.empty_like(x)
   sin_kernel[grid](kernel=x, out=out)
   return out

Will this cause a conflict with the kernel argument to the HigherOrderOp? (If not, then I'm satisfied.) I expect that TritonKernelVariable.call_function receives kwargs={"kernel": TensorVariable, "out": TensorVariable}, and that we would then end up invoking triton_kernel_wrapper_mutation with two different kernels.

Contributor

Also, you can pass a positional arg as a kwarg with functools.partial:

[screenshot omitted]

Contributor Author (@oulgen, Sep 27, 2023)

Yes, this does in fact result in TypeError: 'Tensor' object is not callable. Since we are using functools.partial, we are forced to use kwargs AFAICT. How would we be able to name this something a user cannot write themselves?

Contributor

Oh I see the problem. If we did (kernel, grid, args, kwargs) as the signature, we could pass functools.partial the kernel as a positional arg. Does that work? For example:

import functools

def f(kernel, grid, *args, **kwargs):
    return (kernel, grid, args, kwargs)

kernel = 1
kf = functools.partial(f, kernel)
print(kf(2, 3, 4, five=5))
# (1, 2, (3, 4), {'five': 5})

Comment on lines 704 to 705
# __getitem__ should only be called if we don't already have a grid
assert self.grid is None
Contributor

It might be better to raise Unsupported here (and in some of the other asserts that the user may see, like `assert len(args) == 1`). That way we fall back to eager mode and the user can see the error that eager mode provides.
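
A minimal sketch of that suggestion, assuming Dynamo's existing unimplemented helper (which raises Unsupported and lets the frame fall back to eager) is the intended mechanism; the message text is illustrative.

from torch._dynamo.exc import unimplemented

# Instead of `assert self.grid is None`, bail out gracefully so the user sees
# the normal eager-mode error rather than an internal assertion failure.
if self.grid is not None:
    unimplemented("Triton kernel grid was specified more than once")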

Comment on lines +683 to +685
if not hasattr(fn, "__module__"):
# Super hacky but on AMD __module__ is not set
fn.__module__ = "itertools"
Contributor

I'm not sure this is an AMD thing; this might be a Python version thing (locally for me, functools.partial doesn't return something with a .__module__).

Also, it would be better to set the __module__ to triton_kernel_wrapper_mutation's .__module__: IIRC these two things are responsible for how the function displays in the Dynamo graph.
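
A sketch of that suggestion; the exact attribute handling is illustrative, and triton_kernel_wrapper_mutation refers to the HigherOrderOp this PR introduces.

# functools.partial objects may not carry __module__, so fall back to the
# wrapper op's module instead of hard-coding "itertools".
if not hasattr(fn, "__module__"):
    fn.__module__ = triton_kernel_wrapper_mutation.__module__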

@zou3519 (Contributor) left a comment

I'm deferring to the Dynamo folks (Jason, Voz) for the Dynamo pieces (but those look OK to me). The HigherOrderOp LGTM, would prefer that the signature looks sane like (kernel, grid, args, kwargs), but not a big deal if it doesn't

@oulgen
Contributor Author

oulgen commented Sep 27, 2023

@zou3519 and I chatted offline; I'll update the diff to make the call functions be __call__(*, kernel, grid, args, kwargs), where args is passed as a list and kwargs as a dict in Dynamo, so that user-defined kernels do not have any argument name limitations.
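
A minimal sketch of that agreed-upon convention (the wrapper function name here is hypothetical; the PR routes this through the triton_kernel_wrapper_mutation HigherOrderOp):

def triton_kernel_wrapper_call(*, kernel, grid, args, kwargs):
    # `args` is a list and `kwargs` a dict assembled by Dynamo, so a user's
    # kernel may freely use parameter names like "kernel" or "grid" without
    # colliding with the wrapper's keyword-only arguments.
    return kernel[grid](*args, **kwargs)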


@oulgen
Contributor Author

oulgen commented Sep 29, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/oulgen/1/head branch October 3, 2023 14:23
Stonepia added a commit to Stonepia/pytorch that referenced this pull request Jan 31, 2024