[AOTInductor] Implement autograd eager backend for native triton kernels #110403
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110403
Note: Links to docs will display an error until the doc builds have completed.
✅ You can merge normally! (1 unrelated failure.) As of commit b8a9266 with merge base efb73fe: UNSTABLE - the following job failed, but it was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 1eecfac35d1796ee51e5f174fdafb796ea3b1f91 Pull Request resolved: #110403
if isinstance(proxy, fx.Proxy):
    set_meta(proxy, e)

# example use case: allreduce_ returns ([tensor], work)
I will update this comment.
…triton kernels" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng [ghstack-poisoned]
ghstack-source-id: e41cd757f7a3f29c44855fb1ba45cbdf8830bbcb Pull Request resolved: #110403
LGTM, minor nits
# FX graph needs __name__ and __module__ attributes
fn.__name__ = func_overload.__name__  # type: ignore[attr-defined]
if not hasattr(fn, "__module__"):
    # Super hacky, but on AMD __module__ is not set
    fn.__module__ = "itertools"
Could we unify this logic with the Dynamo logic that does something similar? Maybe put it into a helper function somewhere? Also, we might want to note somewhere that this is a hack to work around the fact that fx does not allow functions in the graph.
fn,
(),
proxy_args,
name="triton_kernel_wrapper_mutation",
nit: the name should probably be overload.__name__; in the functional case it is better for the name of the output to be triton_kernel_wrapper_functional than triton_kernel_wrapper_mutation.
@@ -170,6 +170,15 @@ def from_functional(self):
    torch._sync(self)
    return torch._from_functional_tensor(self.elem)

def replace_(self, output) -> None:
    torch._functionalize_replace(self.elem, output)  # type: ignore[attr-defined]
Nit: add the type annotation to https://github.com/pytorch/pytorch/blob/main/torch/_C/__init__.pyi.in to make mypy happy, but it looks like we didn't care much about that in this file
node_args = {"grid": grid, "kwargs": kwargs}
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
Comment that is probably more relevant for the previous dynamo PR (although fine to deal with as a follow-up): it's probably worth having a test for the case where the triton kernel (and maybe the grid lambda) is an inner function. Does everything still work out? (The graph that we compile probably directly holds a reference to that inner function, so will the lifetime of the inner function match what its lifetime would have been in eager mode?)
@torch.compile
def f(x, y):
    @triton.jit
    def kernel(in_ptr, out_ptr0, BLOCK_SIZE: "tl.constexpr"):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        ...
    # if grid/kernel hold onto some global state (like a large cuda tensor),
    # then that memory will stick around until grid/kernel is cleaned up.
    return kernel[grid](x)
Will also do as a follow-up.
def trace_triton_kernel_wrapper(proxy_mode, func_overload, *, kernel, grid, kwargs):
    with disable_proxy_modes_tracing():
        out = func_overload(kernel=kernel, grid=grid, kwargs=kwargs)
I believe this is always supposed to return None, right? Since this will redispatch to the FakeTensorMode implementation, which returns None (since this higher-order op is only ever used for triton kernels). Maybe add an assert?
Only the mutation version returns None; the functional version returns the kwargs.
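The return-value invariant from this exchange could be encoded as a small check (a sketch only; the helper name is hypothetical, and whether such an assert belongs here at all is left open by the thread):

```python
# Sketch: per the discussion above, the mutation overload returns None while
# the functional overload returns a kwargs dict. The helper name
# `check_wrapper_result` is hypothetical.
def check_wrapper_result(overload_name, out):
    if overload_name == "triton_kernel_wrapper_mutation":
        assert out is None, "mutation overload must return None"
    elif overload_name == "triton_kernel_wrapper_functional":
        assert isinstance(out, dict), "functional overload must return kwargs"
    return out
```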
    return None


@triton_kernel_wrapper_mutation.py_functionalize_impl
In this implementation, we're also relying on the fact that both grid and kernel will never mutate the metadata of any of the inputs. I think we said we'll (eventually) try to check for mutations from grid in dynamo and error. But is it possible for a triton kernel to mutate the metadata of a tensor input, e.g. in_tensor.transpose_(1, 0)?
This FakeTensorMode rule for the functional op:
with mode:
    return {
        key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
        for key, val in kwargs.items()
    }
This is implicitly assuming that the outputs of the triton kernel always have the same metadata (sizes/strides/etc.) as the inputs. So if we're handed a triton kernel that mutates the sizes/strides of a tensor input, our shape-prop logic will be wrong during tracing.
To be fair, though, I'm imagining that it's either impossible or very rare for a triton kernel to do this; so if it is possible, the answer is probably to detect/ban it in dynamo.
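One way the detect/ban idea could look (an illustrative sketch only; the function names are hypothetical, and where such a check would live, dynamo or the higher-order-op impls, is not settled in this thread):

```python
# Sketch of the detect/ban idea: snapshot (size, stride) of each tensor-like
# kwarg before the kernel runs and raise if any changed afterwards. The names
# snapshot_metadata / assert_no_metadata_mutation are hypothetical.
def snapshot_metadata(kwargs):
    return {
        k: (tuple(v.size()), tuple(v.stride()))
        for k, v in kwargs.items()
        if hasattr(v, "size") and hasattr(v, "stride")
    }

def assert_no_metadata_mutation(before, kwargs):
    after = snapshot_metadata(kwargs)
    for k, meta in before.items():
        if after.get(k) != meta:
            raise RuntimeError(
                f"triton kernel mutated metadata of input {k!r}: "
                f"{meta} -> {after.get(k)}"
            )
```

With something like this, the fake-tensor shape-prop assumption above would fail loudly instead of silently producing wrong strides.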
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
To be fair, I don't think we'll ever actually hit this in the "vanilla" torch.compile flow: functionalization runs above ProxyTorchDispatchMode, so we will always have converted triton_kernel_wrapper_mutation calls into triton_kernel_wrapper_functional calls in the graph before hitting the proxy mode.
Mostly just an FYI; still agreed that we need this (for example, pre_dispatch tracing today will hit this, since it does not (yet) use functionalization).
@@ -194,6 +194,18 @@ def wrap_with_proxy(e, proxy, constant):
    # example use case: allreduce_ returns ([tensor], work)
    for idx, ee in enumerate(e):
        wrap_with_proxy(ee, proxy[idx], get_constant(idx))
elif isinstance(e, dict):
    assert constant is None
Maybe worth a comment: in theory we could support const-prop when proxy-tensor-tracing operators that return dicts of tensors, but we have no use case for it today (since the only op we currently trace that can return a dict is triton_kernel_wrapper_functional/mutation, which does not participate in const-prop).
@@ -1541,6 +1541,73 @@ def add_kernel(
    # Make sure it is NOT modified
    self.assertEqual(output, torch.zeros_like(t1))

@requires_cuda()
@requires_triton()
def test_triton_kernel_functionalize(self):
I think these tests are fine for this PR, but when we add inductor support it would be great to have some more comprehensive tests (just testing that the outputs match eager mode):
(1) triton kernel that mutates an input to the graph
(2) triton kernel that mutates a view of an input to the graph
(3) graph that returns a view/alias of one of the mutated arguments to a triton kernel
For the first case (simple input mutation), there are ~3 individual cases worth testing. This is because these cases affect whether or not AOTAutograd sends the copy_() (input mutation) directly to inductor in the graph, so they affect whether or not inductor will be able to optimize away the clone.
(a) the graph is an inference graph (the test is under no_grad(), or no inputs to the graph require gradients)
(b) the graph is a training graph (at least one input requires grad), but the arguments to the triton kernel do not require grad
(c) the graph is a training graph, and the argument(s) to the triton kernel also require grad
Will add as a follow-up, although I need some pointers on how to implement some of these test cases.
Right now all my tests are (1), I believe, and none of my tensors require grad (I did not pass requires_grad=True).
- How do I make inputs to the triton kernel require grad?
- Do I need to test things differently when tensors require grad, or do I still check that outputs match?
How do I make inputs to the triton kernel require grad?
You can just set your tensor inputs to the compiled code to require grad:
@torch.compile
def f(x):
    return kernel[grid](x, ...)

x = torch.randn(4, requires_grad=True)
out = f(x)
Do I need to test things differently when tensors require grad, or do I still check that outputs match?
Testing just the forward is fine, but you're right: a more comprehensive test would also run autograd and check that the gradients match between eager and compile. Example:
@torch.compile
def f_triton_input_mutation_of_view(x, out):
    out_view = out[:, 1]
    kernel[grid](x=x, out=out_view)
    return torch.mul(out_view, 2)

x = torch.randn(4)
out = torch.randn(4, 4, requires_grad=True)
# run autograd
f_triton_input_mutation_of_view(x, out).sum().backward()
# Run the same with ref inputs on eager, and check that the grads match
self.assertEqual(out.grad, out_ref.grad)
looks good! Left some minor comments that don't all need to be addressed in this PR
…triton kernels" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng [ghstack-poisoned]
ghstack-source-id: 295f7ad5da97982c28ff91e88874053217c9c66b Pull Request resolved: #110403
…triton kernels" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng [ghstack-poisoned]
ghstack-source-id: 7623fe472754dfef821250f3c5c49f4abdb30ab4 Pull Request resolved: #110403
…triton kernels" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #110486 Approved by: https://github.com/jansel ghstack dependencies: #110403
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng