[Dynamo] Add native support for Triton Kernels to Dynamo #109623
Conversation
This PR adds native support to Dynamo to detect Triton kernels and create an FX graph node out of them. AOT eager and inductor modes will be supported in follow-up PRs.
# If the grid is a function, then lets execute it and convert it to
# a list
if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
These two function variables and list variables (as per below) are supported. Do I need to support any other type? Triton's docs only ever use lambda and tuple.
I think this is fine, we can add more later if needed.
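For reference, these are the two grid forms Triton's docs show, a plain tuple and a lambda over the launch meta-parameters (kernel name and argument names here are illustrative, not from this PR):

```python
import triton

# Grid as a tuple: a fixed launch shape.
add_kernel[(triton.cdiv(n_elements, 1024),)](x, y, out, n_elements, BLOCK_SIZE=1024)

# Grid as a lambda: Triton calls it with the kernel's meta-parameters.
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
```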
torch/_dynamo/variables/builder.py
Outdated
value,
None,  # No grid provided
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
Is function match correct here?
I think we should use ID_MATCH
torch/utils/triton.py
Outdated
@@ -0,0 +1,13 @@
import functools
This file is copy-pasted from the inductor utils. I cannot use that one directly since it creates a circular dependency. If this looks good, I can have inductor also use this util, although there appear to be two separate implementations of get_device_capability.
let's move things rather than copying them.
I believe the linter will complain unless the filename starts with _
(indicating it is not a public API).
Okay, I will create a new PR under this one that moves everything to this new API. I was worried about the two different implementations of get_device_capability
Err, I wasn't privy to whatever initial design discussions you had @oulgen. Is there something in particular you want me to review, haha?
torch/_dynamo/variables/functions.py
Outdated
(grid,) + proxied_args,
proxied_kwargs,
Grid as a kwarg would be easier to read in output code
Suggested change:
- (grid,) + proxied_args,
- proxied_kwargs,
+ proxied_args,
+ {**proxied_kwargs, "grid": grid},
also need to change call_function above.
Btw, I would recommend figuring out the "explicit API to turn a triton kernel into a custom operator" first (or in parallel with this PR), because there are some additional things we would need to teach Dynamo in order for it to be able to take the triton kernel and the surrounding code and call the explicit API. Concretely, in order to push the triton kernel through AOTAutograd:
IMO we shouldn't require any user annotations for Triton kernels. This seems like worse UX than regular Triton.
Can't we just conservatively assume everything is both an input and an output?
I don't think all of this is actually needed. If we do this, then when it gets to Inductor we will need to reverse it to get efficient code. I think we can do a simple rewrite of the kernel launch, sketched below:
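Roughly something like this (the kernel, argument names, and the wrapper op's exact signature are illustrative; the signature is debated further down in this thread):

```python
# What the user wrote (an ordinary Triton launch):
add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)

# What Dynamo would record instead: a single wrapper-op call carrying the
# kernel, the grid, and all of its arguments, with every tensor argument
# conservatively treated as mutated.
triton_kernel_wrapper_mutation(add_kernel, grid, x, y, out, n_elements, BLOCK_SIZE=1024)
```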
Here we basically just assume all args to the triton kernel are mutated in functionalization. The meta function for this just returns outputs with the same shapes as its inputs, and it can be made functional by copying the inputs before the kernel is called. This will be easier to codegen since it requires zero metadata, and also easier for Inductor to reverse.
Is the idea that the functional version just copies every input before the kernel is called? One risk is that this won't play well with aliasing; example below.
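Something along these lines (hypothetical kernel, purely to illustrate the aliasing concern):

```python
import torch

x = torch.randn(1024, device="cuda")
y = x[:512]                        # y is a view that aliases x
my_kernel[grid](x, y, n_elements)  # the kernel reads/writes through both pointers

# If every argument is independently cloned before the launch, the clones of
# x and y no longer alias each other, so writes through one pointer are no
# longer visible through the other and the kernel's semantics silently change.
```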
Yes, correct.
I think this conversion might need to happen during functionalization. You could have something like:
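Sketch (a, b, and c stand for the kernel's tensor arguments; the signature is illustrative):

```python
# As traced, before functionalization: mutates a, b, and c in place and
# returns nothing.
triton_kernel_wrapper_mutation(my_kernel, grid, a, b, c)
```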
No need for custom user-provided metadata on this, because it returns nothing. Then functionalization changes it to something like:
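Sketch (the functional variant's name and signature are illustrative):

```python
# Functional form: returns fresh tensors instead of mutating its inputs.
a1, b1, c1 = triton_kernel_wrapper_functional(my_kernel, grid, a, b, c)
```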
and rewrites all remaining uses of a/b/c to use a1/b1/c1. The reference implementation could be just:
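A minimal sketch, under the same assumed signature as above:

```python
def triton_kernel_wrapper_functional(kernel, grid, *tensors):
    # Clone each tensor argument, run the mutating wrapper on the clones,
    # and hand the clones back; later uses of a/b/c get rewritten to a1/b1/c1.
    clones = tuple(t.clone() for t in tensors)
    triton_kernel_wrapper_mutation(kernel, grid, *clones)
    return clones
```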
This type of conversion is exactly the same as what functionalization does for existing in-place ops. Basically there should be no need for authors of Triton kernels to provide any metadata, because Triton kernels cannot allocate tensors. Triton kernels also have no concept of sizes/strides, nor is there any restriction on what is an input or an output (or both).
@jansel IMO getting the triton kernel through the custom op mechanism is the right long-term solution (and it does not make it more difficult for Inductor to reverse). In the long term, people will want to compose the triton kernel with transforms like vmap, or teach subsystems like DTensor how to behave when they see the triton kernel. The right way to do this is to integrate with the PyTorch dispatcher: turning the triton kernel into a "custom operator" allows the PyTorch dispatcher to interpose on it. In the long term, folks will use an explicit API for this, but we can support automatic Dynamo capture of triton kernels into that explicit API to cover the way people author existing triton kernels.
For Inductor reversal: my proposal is, given a triton kernel, to turn it into two custom ops: a functional op and an out-variant op. The out-variant op is exactly the triton kernel, and the functional op allocates output tensors and invokes the out variant. After AOTAutograd runs, the functional op will exist in the IR that gets passed to Inductor, and when Inductor sees it, it can reverse the transformation back into the out variant (i.e. the raw triton kernel).
I'm still thinking through this. But I'm not convinced we can generate a correct functionalization rule for this op without taking some performance hits. One particular performance issue is that we are no longer able to reorder usages of (a, b) with respect to the (functionalized) kernel call.
Given how Triton kernels work, I'm not sure that this is coherent. Triton kernels are just too low level for this to make sense. One could use (possibly multiple) Triton kernels to build a custom op like this, but it would require a lot of extra work on the part of the user. There are also some types of Triton kernels that would not fit into this model. We could always do both: vanilla Triton kernels work out of the box, and you can optionally use the same API as one would for custom CUDA kernels (which would require more work to set up).
I think inductor should be able to optimize away the extra copies.
If you conservatively assume that all inputs are read-write, then you force functionalization to pessimize the program as sketched below,
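For example (illustrative kernel; only out is actually written by it):

```python
# Before: the kernel only writes to `out`, so reads of x stay cheap.
sin_kernel[grid](x, out, n_elements)
loss = out.sum() + x.sum()

# After conservative functionalization: every argument is assumed mutated,
# so every argument is cloned and all later uses are redirected to the clones.
x1, out1 = x.clone(), out.clone()
sin_kernel[grid](x1, out1, n_elements)
loss = out1.sum() + x1.sum()
```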
and there isn't really any way that Inductor can optimize this away.
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
You can't do this, because even though you have HAS_TRITON around it, the type hint is evaluated at a different time. Try "tl.constexpr" (the type as a string). If that doesn't work, feel free to omit it.
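I.e. something like this (illustrative kernel):

```python
import triton

@triton.jit
def add_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",  # annotation as a string, so it is not evaluated at import time
):
    ...
```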
def __call__(self, *args, kernel, grid, **kwargs):
    kernel[grid](*args, **kwargs)
Is there a reason why we don't do (self, kernel, grid, args: Tuple, kwargs: Dict) instead? That would be a bit cleaner (what if the triton kernel has a kwarg named grid or kernel?)
Because of functools.partial, kernel is passed as a kwarg, so it has to come after the positional args. I made grid a positional arg earlier on, but @jansel asked me to make it a kwarg as well, so I had to move both of them after the positional *args list.
It is possible that I misunderstand how positional args and kwargs work in Python, but doing (self, kernel, grid, args: Tuple, kwargs: Dict) gives an error saying that kernel and grid are defined twice, once positionally and once as a keyword.
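A standalone repro of that error (not the PR's actual code):

```python
import functools

def call(kernel, grid, *args, **kwargs):
    return (kernel, grid, args, kwargs)

# Binding kernel as a keyword means the first later positional argument also
# tries to fill the `kernel` slot:
kf = functools.partial(call, kernel="my_kernel")
kf("my_grid", 1, 2)  # TypeError: call() got multiple values for argument 'kernel'
```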
Let's say we had a triton kernel that accepted a kernel argument (unlikely, but just going through the exercise):
@triton.jit
def sin_kernel(kernel, out):
    pass
and a function like
@torch.compile
def f(x):
    out = torch.empty_like(x)
    sin_kernel[grid](kernel=x, out=out)
    return out
Will this cause a conflict with the kernel argument to the HigherOrderOp? (If not, then I'm satisfied.) I expect that TritonKernelVariable.call_function receives kwargs={"kernel": TensorVariable, "out": TensorVariable}, and that we then end up invoking triton_kernel_wrapper_mutation with two different kernels.
Yes, this does in fact result in TypeError: 'Tensor' object is not callable. Since we are using functools.partial, we are forced to use kwargs AFAICT. How would we be able to name this something a user cannot write themselves?
Oh I see the problem. If we did (kernel, grid, args, kwargs) as the signature, we could pass functools.partial the kernel as a positional arg. Does that work? For example:
def f(kernel, grid, *args, **kwargs):
    return (kernel, grid, args, kwargs)

kernel = 1
kf = functools.partial(f, kernel)
print(kf(2, 3, 4, five=5))
# (1, 2, (3, 4), {'five': 5})
torch/_dynamo/variables/functions.py
Outdated
# __getitem__ should only be called if we don't already have a grid
assert self.grid is None
It might be better to raise Unsupported here (and in some of the other asserts that the user may see, like `assert len(args) == 1`). That way we fall back to eager mode and the user can see the error that eager mode provides.
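For example, assuming Dynamo's existing unimplemented helper (exact message up to you):

```python
from torch._dynamo.exc import unimplemented

# Graph-break instead of asserting, so the user falls back to eager and sees
# the eager-mode error.
if self.grid is not None:
    unimplemented("Triton kernel __getitem__ called after a grid was already provided")
```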
if not hasattr(fn, "__module__"):
    # Super hacky but on AMD __module__ is not set
    fn.__module__ = "itertools"
I'm not sure this is an AMD thing; it might be a Python version thing (locally for me, functools.partial doesn't return something with a .__module__). Also, it would be better to set the __module__ to triton_kernel_wrapper_mutation's .__module__: IIRC these two things are responsible for how the function displays in the dynamo graph.
I'm deferring to the Dynamo folks (Jason, Voz) for the Dynamo pieces (but those look OK to me). The HigherOrderOp LGTM, would prefer that the signature looks sane like (kernel, grid, args, kwargs), but not a big deal if it doesn't
@zou3519 and I chatted offline; I'll update the diff to make the call functions be (kernel, grid, args, kwargs).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #110185 Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/bdhirsh ghstack dependencies: #109623
Stack from ghstack (oldest at bottom):
This PR adds native support to Dynamo to detect Triton kernels and
create an FX graph node out of them. AOT eager and inductor modes will
be supported in follow-up PRs.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng