
[AOTInductor] Implement autograd eager backend for native triton kernels #110403

Closed
wants to merge 5 commits
92 changes: 82 additions & 10 deletions test/dynamo/test_functions.py
@@ -28,6 +28,7 @@
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
disable_translation_validation_if_dynamic_shapes,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_CUDA

@@ -1548,6 +1549,74 @@ def add_kernel(
# Make sure it is NOT modified
self.assertEqual(output, torch.zeros_like(t1))

@requires_cuda()
@requires_triton()
@skipIfRocm
def test_triton_kernel_functionalize(self):
Contributor

I think these tests are fine for this PR, but when we add inductor support it would be great to have some more comprehensive tests (just testing that the outputs match eager mode):

(1) triton kernel that mutates an input to the graph
(2) triton kernel that mutates a view of an input to the graph
(3) graph that returns a view/alias of one of the mutated arguments to a triton kernel

For the first case (simple input mutation), there are ~3 individual cases worth testing. This is because these cases impact whether or not AOTAutograd sends the copy_() (input mutation) directly to inductor in the graph, so they affect whether or not inductor will be able to optimize away the clone.

(a) the graph is an inference graph (the test is under no_grad(), or no inputs to the graph require gradient)
(b) the graph is a training graph (at least one input requires grad), but the arguments to the triton kernel do not require grad
(c) the graph is a training graph, and the argument(s) to the triton kernel also require grad (a rough sketch of this case follows below)
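A rough sketch of what a case (c) test could look like, assuming the add_kernel defined earlier in this file takes (in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE) and the usual imports/decorators of this file; the test name and exact shapes are illustrative, not part of this PR:

@requires_cuda()
@requires_triton()
def test_triton_kernel_input_mutation_training(self):
    # Case (c): training graph, and an argument to the triton kernel requires grad.
    def f(x, out):
        # add_kernel writes x + x into `out` in place (signature assumed above).
        add_kernel[(x.numel(),)](x, x, out, x.numel(), BLOCK_SIZE=16)
        # Differentiable path that does not flow through the kernel output.
        return x * 3

    x_ref = torch.randn(16, device="cuda", requires_grad=True)
    out_ref = torch.zeros(16, device="cuda")
    y_ref = f(x_ref, out_ref)
    y_ref.sum().backward()

    x = x_ref.detach().clone().requires_grad_(True)
    out = torch.zeros(16, device="cuda")
    compiled = torch.compile(f, backend="aot_eager", fullgraph=True)
    y = compiled(x, out)
    y.sum().backward()

    self.assertEqual(y, y_ref)
    self.assertEqual(out, out_ref)  # the input mutation is visible outside the graph
    self.assertEqual(x.grad, x_ref.grad)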

Contributor Author


Will add as a follow-up, although I need some pointers on how to implement some of these test cases.

Right now all my tests are case (1), I believe, and none of my tensors require grad (I did not set requires_grad=True).

  • How do I make inputs to the triton kernel require grad?
  • Do I need to test things differently when tensors require grad, or do I still just check that outputs match?

Contributor


How do I make inputs to the triton kernel require grad?

You can just create the tensor inputs to the compiled code with requires_grad=True:

@torch.compile
def f(x):
    return kernel[grid](x, ...)

x = torch.randn(4, requires_grad=True)
out = f(x)

Do I need to test things differently when tensors require grad, or do I still just check that outputs match?

Testing just the forward is fine - but you're right, a more comprehensive test would involve also running autograd and checking that the gradients match between eager and torch.compile. Example:

@torch.compile
def f_triton_input_mutation_of_view(x, out):
    out_view = out[:, 1]
    kernel[grid](x=x, out=out_view)
    return torch.mul(out_view, 2)

x = torch.randn(4, 4)
out = torch.randn(4, 4, requires_grad=True)
# run the compiled function and autograd
y = f_triton_input_mutation_of_view(x, out)
y.sum().backward()

# Run the same with ref inputs (x_ref, out_ref) on eager, check that the grads match
self.assertEqual(out.grad, out_ref.grad)

import functorch
from functorch import make_fx
from torch._subclasses.functional_tensor import (
CppFunctionalizeAPI,
FunctorchFunctionalizeAPI,
PythonFunctionalizeAPI,
)

@triton.jit
def kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = 2 * x
tl.store(out_ptr + offsets, output, mask=mask)

def f(x, output):
out = triton_kernel_wrapper_functional(
kernel=kernel,
grid=(x.numel(),),
kwargs={
"in_ptr0": x,
"out_ptr": output,
"n_elements": output.numel(),
"BLOCK_SIZE": 16,
},
)
return out["out_ptr"]

t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")

gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)

gm = make_fx(CppFunctionalizeAPI().functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)

gm = make_fx(torch.func.functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)

gm = make_fx(f, tracing_mode="fake")(t1, t2)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1, output_1):
triton_kernel_wrapper_functional_proxy = functools_triton_kernel_wrapper_functional(grid = (5,), kwargs = {'in_ptr0': x_1, 'out_ptr': output_1, 'n_elements': 5, 'BLOCK_SIZE': 16}); x_1 = output_1 = None
getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']
getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']
getitem_2 = triton_kernel_wrapper_functional_proxy['n_elements']
getitem_3 = triton_kernel_wrapper_functional_proxy['BLOCK_SIZE']; triton_kernel_wrapper_functional_proxy = None
return getitem_1""",
)

@requires_cuda()
@requires_triton()
def test_triton_kernel_by_hand(self):
@@ -1601,16 +1670,19 @@ def grid_fn(meta):
# No Dynamo -- Make sure triton kernel works (with positional BLOCK_SIZE)
self.assertEqual(call_triton_add(t1, t2, 1, True), torch_add)

# With Dynamo
compiled_func = torch.compile(call_triton_add, backend="eager", fullgraph=True)
# With simple kernel
self.assertEqual(compiled_func(t1, t2, 0), torch_add)
# With lambda kernel
self.assertEqual(compiled_func(t1, t2, 1), torch_add)
# With lambda kernel (with positional BLOCK_SIZE)
self.assertEqual(compiled_func(t1, t2, 1, 1, True), torch_add)
# With user defined function kernel
self.assertEqual(compiled_func(t1, t2, 2, 200), torch_add)
for backend in ["eager", "aot_eager"]:
# With Dynamo
compiled_func = torch.compile(
call_triton_add, backend=backend, fullgraph=True
)
# With simple kernel
self.assertEqual(compiled_func(t1, t2, 0), torch_add)
# With lambda kernel
self.assertEqual(compiled_func(t1, t2, 1), torch_add)
# With lambda kernel (with positional BLOCK_SIZE)
self.assertEqual(compiled_func(t1, t2, 1, 1, True), torch_add)
# With user defined function kernel
self.assertEqual(compiled_func(t1, t2, 2, 200), torch_add)

def test_dataclass_factory(self):
@dataclass
7 changes: 7 additions & 0 deletions tools/pyi/gen_pyi.py
@@ -736,6 +736,13 @@ def gen_pyi(
"_to_functional_tensor": [
"def _to_functional_tensor(t: Tensor) -> Tensor: ..."
],
"_functionalize_replace": [
"def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ..."
],
"_functionalize_commit_update": [
"def _functionalize_commit_update(t: Tensor) -> None: ..."
],
"_functionalize_sync": ["def _functionalize_sync(t: Tensor) -> None: ..."],
"_enable_functionalization": [
"def _enable_functionalization(*, reapply_views: _bool = False): ..."
],
12 changes: 4 additions & 8 deletions torch/_dynamo/variables/functions.py
@@ -677,22 +677,18 @@ def call_function(
unimplemented(f"grid for the triton kernel is {type(grid)}")

from torch._higher_order_ops.triton_kernel_wrap import (
prepare_triton_kernel_for_graph_node,
triton_kernel_wrapper_mutation,
)

fn = functools.partial(triton_kernel_wrapper_mutation, kernel=self.kernel)
# FX graph needs __name__ and __module__ attributes
fn.__name__ = triton_kernel_wrapper_mutation.__name__
if not hasattr(fn, "__module__"):
# Super hacky but on AMD __module__ is not set
fn.__module__ = "itertools"

# Combine args and kwargs and pass as a dict so that if user defined triton
# kernel uses variables as 'grid' or 'kernel', it does not conflict with
# parameters of the wrapper function
tx.output.create_proxy(
"call_function",
fn,
prepare_triton_kernel_for_graph_node(
triton_kernel_wrapper_mutation, self.kernel
),
(),
{
"grid": grid,
167 changes: 159 additions & 8 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -1,16 +1,23 @@
import functools

import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._prims_common import clone_preserve_strides
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)


# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
def __init__(self):
super().__init__("triton_kernel_wrapper_mutation")

def __call__(self, *, kernel, grid, kwargs):
kernel[grid](**kwargs)


triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()

@@ -20,13 +27,157 @@ class TritonKernelWrapperFunctional:
def __init__(self):
super().__init__("triton_kernel_wrapper_functional")


triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(*, kernel, grid, kwargs):
kernel[grid](**kwargs)


@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(mode, *, kernel, grid, kwargs):
with mode:
return None


def prepare_triton_kernel_for_graph_node(f, kernel):
# This is a hack to work around the fact that FX does not allow functions
# in the graph

fn = functools.partial(f, kernel=kernel)
# FX graph needs __name__ and __module__ attributes
fn.__name__ = f.__name__ # type:ignore[attr-defined]
if not hasattr(fn, "__module__"):
# Super hacky but on AMD __module__ is not set
fn.__module__ = "itertools"
return fn
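
A quick illustration of the __name__ problem the helper above works around (functools.partial objects carry no __name__ attribute by default, while the FX graph machinery expects one, per the comment above); the names here are illustrative only:

import functools

def wrapper(*, kernel, grid, kwargs):
    ...

fn = functools.partial(wrapper, kernel=None)
print(hasattr(fn, "__name__"))  # False -- partials have no __name__ until one is set
fn.__name__ = wrapper.__name__
print(fn.__name__)              # "wrapper"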


def trace_triton_kernel_wrapper(proxy_mode, func_overload, *, kernel, grid, kwargs):
with disable_proxy_modes_tracing():
out = func_overload(kernel=kernel, grid=grid, kwargs=kwargs)

node_args = {"grid": grid, "kwargs": kwargs}
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
Contributor

Comment that is probably more relevant for the previous dynamo PR (although fine to deal with as a follow-up): it's probably worth having a test for the case where the triton kernel (and maybe the grid lambda) is an inner function. Does everything still work out? (The graph that we compile probably holds a direct reference to that inner function, so will the lifetime of the inner function match what its lifetime would have been in eager mode?)

@torch.compile
def f(x, y):
    @triton.jit
    def kernel(in_ptr, out_ptr0, BLOCK_SIZE: "tl.constexpr"):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        ...

    # if grid/kernel hold onto some global state (like a large cuda tensor),
    # then that memory will stick around until grid/kernel is cleaned up.
    return kernel[grid](x)
...

Contributor Author

Will also do as a follow-up.

"call_function",
prepare_triton_kernel_for_graph_node(func_overload, kernel),
(),
proxy_args,
name=func_overload.__name__ + "_proxy",
)
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
Contributor

To be fair, I don't think we'll ever actually hit this in the "vanilla" torch.compile flow: functionalization runs above ProxyTorchDispatchMode, so we will always have converted triton_kernel_wrapper_mutation calls into triton_kernel_wrapper_functional calls in the graph before hitting the proxy mode.

Mostly just an FYI - still agreed that we need this (for example, pre_dispatch tracing today will hit this path, since it does not (yet) use functionalization).
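A rough sketch of a trace that would exercise this mutation proxy rule, assuming make_fx's pre_dispatch flag (which skips functionalization); `kernel` stands for a @triton.jit kernel like the one in the tests above, and the exact traced output is not guaranteed:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation

def f(x, output):
    # With functionalization skipped, the mutation variant stays in the graph.
    triton_kernel_wrapper_mutation(
        kernel=kernel,
        grid=(x.numel(),),
        kwargs={
            "in_ptr0": x,
            "out_ptr": output,
            "n_elements": output.numel(),
            "BLOCK_SIZE": 16,
        },
    )
    return output

t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")
# Expect a triton_kernel_wrapper_mutation call (not the functional variant) in gm.code.
gm = make_fx(f, pre_dispatch=True)(t1, t2)
print(gm.code)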

def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
mode, *, kernel, grid, kwargs
):
if mode.enable_tracing:
trace_triton_kernel_wrapper(
mode,
triton_kernel_wrapper_mutation,
kernel=kernel,
grid=grid,
kwargs=kwargs,
)
else:
triton_kernel_wrapper_mutation(kernel=kernel, grid=grid, kwargs=kwargs)

return None


@triton_kernel_wrapper_mutation.py_functionalize_impl
Contributor

In this implementation, we're also relying on the fact that both grid and kernel will never mutate the metadata of any of the inputs. I think we said we'll (eventually) try to check for mutations from the grid lambda in dynamo and raise an error. But is it possible for a triton kernel to mutate the metadata of a tensor input, e.g. in_tensor.transpose_(1, 0)?

Contributor Author

We do the clone/replace for the triton kernel inputs; if the kernel mutates the metadata, will that not be sufficient? @jansel, any comments on this?

As for the grid, @jansel recommended doing something similar to what cond does for mutation checking, so I was planning on doing something like that.

Contributor

This FakeTensorMode rule for the functional op:

    with mode:
        return {
            key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
            for key, val in kwargs.items()
        }

is implicitly assuming that the outputs of the triton kernel always have the same metadata (sizes/strides/etc.) as the inputs. So if we're handed a triton kernel that mutates the sizes/strides of a tensor input, our shape prop logic will be wrong during tracing.

To be fair though - I imagine it's either impossible or very rare for a triton kernel to do this - so if it is possible, the answer is probably to detect/ban it in dynamo.
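A hypothetical eager-mode illustration of the metadata mutation being discussed here - not something a triton kernel body can normally express, since it only sees raw data pointers - showing why the clone-based shape prop above would go stale:

import torch

x = torch.randn(2, 3)
fake_out = x.clone()            # what the FakeTensorMode rule would predict for the "output"
x.transpose_(1, 0)              # hypothetical in-place metadata mutation of the input
print(fake_out.shape, x.shape)  # torch.Size([2, 3]) vs torch.Size([3, 2]) -- shapes no longer agree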

def triton_kernel_wrapper_mutation_functionalize(ctx, kernel, grid, kwargs):
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
with ctx.redispatch_to_next():
unwrapped_outputs = triton_kernel_wrapper_functional(
kernel=kernel, grid=grid, kwargs=unwrapped_kwargs
)

assert unwrapped_outputs.keys() == kwargs.keys()
for key, output_arg in unwrapped_outputs.items():
if not isinstance(output_arg, Tensor):
continue
input_arg = kwargs[key]
assert isinstance(input_arg, Tensor)

ctx.replace(input_arg, output_arg)
ctx.commit_update(input_arg)
ctx.sync(input_arg)
return None


@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(*, kernel, grid, kwargs):
# TODO(oulgen): For performance reasons, we want to ensure that these
# `clone_preserve_strides` calls are never executed at runtime
# (inductor should always optimize them away).
# Requires https://github.com/pytorch/pytorch/issues/109240
kwargs = {
key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
for key, val in kwargs.items()
}
triton_kernel_wrapper_mutation(kernel=kernel, grid=grid, kwargs=kwargs)
return kwargs


@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(mode, *, kernel, grid, kwargs):
# TODO(oulgen): For performance reasons, we want to ensure that these
# `clone_preserve_strides` calls are never executed at runtime
# (inductor should always optimize them away).
# Requires https://github.com/pytorch/pytorch/issues/109240
with mode:
return {
key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
for key, val in kwargs.items()
}


@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
mode, *, kernel, grid, kwargs
):
if mode.enable_tracing:
return trace_triton_kernel_wrapper(
mode,
triton_kernel_wrapper_functional,
kernel=kernel,
grid=grid,
kwargs=kwargs,
)
else:
return triton_kernel_wrapper_functional(kernel=kernel, grid=grid, kwargs=kwargs)


@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(ctx, kernel, grid, kwargs):
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
with ctx.redispatch_to_next():
outputs = triton_kernel_wrapper_functional(
kernel=kernel, grid=grid, kwargs=unwrapped_kwargs
)
return ctx.wrap_tensors(outputs)


triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU)

triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU)