[AOTInductor] Implement autograd eager backend for native triton kernels #110403
Changes from 1 commit
@@ -1,16 +1,23 @@
import functools

import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._prims_common import clone_preserve_strides
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
    def __init__(self):
        super().__init__("triton_kernel_wrapper_mutation")

    def __call__(self, *, kernel, grid, kwargs):
        kernel[grid](**kwargs)


triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
@@ -20,13 +27,141 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
    def __init__(self):
        super().__init__("triton_kernel_wrapper_functional")

    def __call__(self, *, kernel, grid, kwargs):
        kwargs = {


triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(*, kernel, grid, kwargs):
    kernel[grid](**kwargs)


@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(mode, *, kernel, grid, kwargs):
    with mode:
        return None


def trace_triton_kernel_wrapper(proxy_mode, func_overload, *, kernel, grid, kwargs):
    with disable_proxy_modes_tracing():
        out = func_overload(kernel=kernel, grid=grid, kwargs=kwargs)
Review comment: I believe this is always supposed to return `None`. Maybe add an assert?
Reply: Only the mutation version returns `None`, but the functional version returns kwargs.
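A minimal sketch (not from the PR) of the kind of assert the comment suggests, placed inside trace_triton_kernel_wrapper right after the call above; branching on func_overload is my own framing:

```python
# Hypothetical assert; not part of the PR.
if func_overload is triton_kernel_wrapper_mutation:
    assert out is None, "the mutation op should not return anything"
else:
    assert isinstance(out, dict), "the functional op should return the (cloned) kwargs dict"
```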
    fn = functools.partial(func_overload, kernel=kernel)
    # FX graph needs __name__ and __module__ attributes
    fn.__name__ = func_overload.__name__  # type:ignore[attr-defined]
    if not hasattr(fn, "__module__"):
        # Super hacky but on AMD __module__ is not set
        fn.__module__ = "itertools"
Review comment: Could we unify this logic with the Dynamo logic that does something similar? Maybe put it into a helper function somewhere? Also, we might want to note somewhere that we are doing a hack to work around the fact that FX does not allow functions in the graph.
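A rough sketch of the kind of shared helper the comment asks for; the name, location, and default module are hypothetical:

```python
# Hypothetical helper; name and location are not from the PR.
def patch_fx_callable_metadata(fn, name, module="itertools"):
    """Give a callable the __name__/__module__ attributes FX expects.

    FX nodes created via create_proxy("call_function", ...) read these
    attributes, and functools.partial objects (and some callables on AMD
    builds) do not have them set.
    """
    fn.__name__ = name  # type: ignore[attr-defined]
    if not hasattr(fn, "__module__"):
        fn.__module__ = module
    return fn
```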
    node_args = {"grid": grid, "kwargs": kwargs}
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
Review comment: Comment that is probably more relevant for the previous Dynamo PR (although fine to deal with as a follow-up): it's probably worth having a test for the case where the triton kernel (and maybe the grid lambda) is an inner function. Does everything still work out? (The graph that we compile probably directly holds a reference to that inner function, so will the lifetime of the inner function match what its lifetime would have been in eager mode?)
Reply: Will also do as a follow-up.
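A minimal sketch of such a test, assuming a trivial elementwise kernel defined as an inner function (all names here are made up, not from the PR):

```python
# Hypothetical test sketch; not part of the PR.
import torch
import triton
import triton.language as tl


def run_with_inner_kernel(x):
    # Both the kernel and the grid lambda are inner functions/closures.
    @triton.jit
    def double_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        tl.store(out_ptr + offsets, 2 * tl.load(x_ptr + offsets, mask=mask), mask=mask)

    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    double_kernel[grid](x, out, n, BLOCK_SIZE=1024)
    return out


x = torch.randn(1024, device="cuda")
torch.testing.assert_close(torch.compile(run_with_inner_kernel)(x), 2 * x)
```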
"call_function", | ||
fn, | ||
(), | ||
proxy_args, | ||
name="triton_kernel_wrapper_mutation", | ||
Review comment: nit: the name should probably be taken from `func_overload` rather than hardcoded, since this helper also traces the functional op.
    )
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
Review comment: To be fair, I don't think we'll ever actually hit this in the "vanilla" torch.compile flow: functionalization runs above `ProxyTorchDispatchMode`. Mostly just an FYI - still agreed that we need this (for example, pre_dispatch tracing today will hit this, since it does not (yet) use functionalization).
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
    mode, *, kernel, grid, kwargs
):
    if mode.enable_tracing:
        trace_triton_kernel_wrapper(
            mode,
            triton_kernel_wrapper_mutation,
            kernel=kernel,
            grid=grid,
            kwargs=kwargs,
        )
    else:
        triton_kernel_wrapper_mutation(kernel=kernel, grid=grid, kwargs=kwargs)

    return None


@triton_kernel_wrapper_mutation.py_functionalize_impl
Review comment: In this implementation, we're also relying on the fact that both
Review comment: This FakeTensorMode rule for the functional op is implicitly assuming that the outputs of the triton kernel always have the same metadata (sizes/strides/etc.) as the inputs. So if we're handed a triton kernel that mutates sizes/strides of a tensor input, our shape prop logic will be wrong during tracing. To be fair though - I'm imagining that it's either impossible or very rare for a triton kernel to do this - so if it is possible, the answer is probably to detect/ban it in Dynamo.
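A hedged sketch of the kind of detect/ban check the comment alludes to, comparing the functional outputs' metadata against the corresponding inputs; the helper name and placement are hypothetical:

```python
# Hypothetical metadata check; not part of the PR.
import torch


def check_kernel_preserves_metadata(kwargs, outputs):
    for key, out in outputs.items():
        if not isinstance(out, torch.Tensor):
            continue
        inp = kwargs[key]
        if out.shape != inp.shape or out.stride() != inp.stride():
            raise RuntimeError(
                f"triton kernel changed metadata of argument '{key}'; not supported"
            )
```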
def triton_kernel_wrapper_mutation_functionalize(ctx, kernel, grid, kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        unwrapped_outputs = triton_kernel_wrapper_functional(
            kernel=kernel, grid=grid, kwargs=unwrapped_kwargs
        )

    assert unwrapped_outputs.keys() == kwargs.keys()
    for key, output_arg in unwrapped_outputs.items():
        if not isinstance(output_arg, Tensor):
            continue
        input_arg = kwargs[key]
        assert isinstance(input_arg, Tensor)

        ctx.replace(input_arg, output_arg)
        ctx.commit_update(input_arg)
        ctx.sync(input_arg)
    return None


@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(*, kernel, grid, kwargs):
    kwargs = {
        key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
        for key, val in kwargs.items()
    }
    triton_kernel_wrapper_mutation(kernel=kernel, grid=grid, kwargs=kwargs)
    return kwargs


@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(mode, *, kernel, grid, kwargs):
    with mode:
        return {
            key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
            for key, val in kwargs.items()
        }
@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
    mode, *, kernel, grid, kwargs
):
    if mode.enable_tracing:
        return trace_triton_kernel_wrapper(
            mode,
            triton_kernel_wrapper_functional,
            kernel=kernel,
            grid=grid,
            kwargs=kwargs,
        )
    else:
        return triton_kernel_wrapper_functional(kernel=kernel, grid=grid, kwargs=kwargs)


@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(ctx, kernel, grid, kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        outputs = triton_kernel_wrapper_functional(
            kernel=kernel, grid=grid, kwargs=unwrapped_kwargs
        )
        return ctx.wrap_tensors(outputs)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)


triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
(second file in the diff)

@@ -170,6 +170,15 @@ def from_functional(self):
        torch._sync(self)
        return torch._from_functional_tensor(self.elem)

    def replace_(self, output) -> None:
        torch._functionalize_replace(self.elem, output)  # type: ignore[attr-defined]
Review comment: Nit: add the type annotation to https://github.com/pytorch/pytorch/blob/main/torch/_C/__init__.pyi.in to make mypy happy, but it looks like we didn't care much about that in this file.
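For illustration, the stub entries might look roughly like the following; the exact signatures are an assumption, not taken from the PR:

```python
# Hypothetical additions to torch/_C/__init__.pyi.in; not part of the PR.
def _functionalize_replace(tensor: Tensor, other: Tensor) -> None: ...
def _functionalize_commit_update(tensor: Tensor) -> None: ...
def _functionalize_sync(tensor: Tensor) -> None: ...
```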
    def commit_update(self) -> None:
        torch._functionalize_commit_update(self.elem)  # type: ignore[attr-defined]

    def sync(self) -> None:
        torch._functionalize_sync(self.elem)  # type: ignore[attr-defined]


class FunctionalTensorMode(TorchDispatchMode):
    def __init__(self):
@@ -382,6 +391,18 @@ def functionalize(self, inner_f: Callable) -> Callable:
    def redispatch_to_next(self) -> ContextManager:
        pass

    @abstractmethod
    def replace(self, input_tensor, output_tensor) -> None:
        pass

    @abstractmethod
    def commit_update(self, tensor) -> None:
        pass

    @abstractmethod
    def sync(self, tensor) -> None:
        pass


class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
@@ -400,6 +421,19 @@ def functionalize(self, inner_f: Callable) -> Callable:
    def redispatch_to_next(self) -> ContextManager:
        return unset_functional_temporarily()

    def replace(self, input_tensor, output_tensor) -> None:
        assert isinstance(input_tensor, FunctionalTensor)
        assert not isinstance(output_tensor, FunctionalTensor)
        input_tensor.replace_(output_tensor)

    def commit_update(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.commit_update()

    def sync(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.sync()


class CppFunctionalizeAPI(BaseFunctionalizeAPI):
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
@@ -422,6 +456,15 @@ def redispatch_to_next(self) -> ContextManager:
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)  # type: ignore[attr-defined]

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)  # type: ignore[attr-defined]

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)  # type: ignore[attr-defined]


class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(self, interpreter):
@@ -451,3 +494,12 @@ def functionalize(self, inner_f: Callable) -> Callable:

    def redispatch_to_next(self) -> ContextManager:
        return self.interpreter.lower()

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)  # type: ignore[attr-defined]

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)  # type: ignore[attr-defined]

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)  # type: ignore[attr-defined]
Review comment: I think these tests are fine for this PR, but when we add inductor support it would be great to have some more comprehensive tests (just testing that the outputs match eager mode):
(1) a triton kernel that mutates an input to the graph
(2) a triton kernel that mutates a view of an input to the graph
(3) a graph that returns a view/alias of one of the mutated arguments to a triton kernel
For the first case (simple input mutation), there are ~3 individual cases worth testing. This is because these cases impact whether or not AOTAutograd sends the copy_() (input mutation) directly to inductor in the graph, so they affect whether or not inductor will be able to optimize away the clone.
(a) the graph is an inference graph (the test is under no_grad(), or no inputs to the graph require gradient)
(b) the graph is a training graph (at least one input requires grad), but the arguments to the triton kernel do not require grad
(c) the graph is a training graph, and the argument(s) to the triton kernel also require grad
Reply: Will add as a follow-up, although I need some pointers for how to implement some of these test cases. Right now all my tests are (1), I believe, and none of my tensors require grad (I did not set requires_grad=True).
Review comment: You can just set your tensor inputs to the compiled code to require grad. Testing just the forward is fine - but you're right, a more comprehensive test would involve also running autograd and checking that the gradients match between eager and compile. Example:
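A hedged sketch of what such an eager-vs-compile parity test might look like (kernel and variable names are made up, not the reviewer's original snippet):

```python
# Hypothetical parity test; not part of the PR.
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def f(x, y):
    # Differentiable prologue so this becomes a training graph when the
    # inputs require grad; the triton kernel writes into a fresh buffer.
    z = x * y
    out = torch.empty_like(z)
    grid = lambda meta: (triton.cdiv(z.numel(), meta["BLOCK_SIZE"]),)
    add_kernel[grid](z, z, out, z.numel(), BLOCK_SIZE=1024)
    return z.sum(), out


x_eager = torch.randn(1024, device="cuda", requires_grad=True)
y_eager = torch.randn(1024, device="cuda", requires_grad=True)
x_comp = x_eager.detach().clone().requires_grad_(True)
y_comp = y_eager.detach().clone().requires_grad_(True)

loss_eager, out_eager = f(x_eager, y_eager)
loss_comp, out_comp = torch.compile(f)(x_comp, y_comp)
torch.testing.assert_close(out_eager, out_comp)

# Run autograd and check that the gradients match between eager and compile.
loss_eager.backward()
loss_comp.backward()
torch.testing.assert_close(x_eager.grad, x_comp.grad)
torch.testing.assert_close(y_eager.grad, y_comp.grad)
```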