[AOTInductor] Implement autograd eager backend for native triton kernels
ghstack-source-id: e41cd757f7a3f29c44855fb1ba45cbdf8830bbcb
Pull Request resolved: #110403
oulgen committed Oct 2, 2023
1 parent 934cc18 commit 47b8727
Showing 6 changed files with 333 additions and 18 deletions.
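
As a rough usage sketch of the path these changes exercise (not part of the diff; it assumes CUDA and Triton are available, and add_kernel is a stand-in for the kernel already defined in the test file), a function that launches a user-defined Triton kernel can now be compiled with the aot_eager backend:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    # Stand-in for the add_kernel defined in test/dynamo/test_functions.py.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def call_triton_add(x, y):
    out = torch.empty_like(x)
    n_elements = out.numel()
    add_kernel[(n_elements,)](x, y, out, n_elements, BLOCK_SIZE=16)
    return out

t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")
# Previously only backend="eager" was covered; this commit makes the aot_eager
# (autograd eager) backend handle the kernel launch as well.
compiled = torch.compile(call_triton_add, backend="aot_eager", fullgraph=True)
torch.testing.assert_close(compiled(t1, t2), t1 + t2)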
90 changes: 80 additions & 10 deletions test/dynamo/test_functions.py
@@ -1541,6 +1541,73 @@ def add_kernel(
# Make sure it is NOT modified
self.assertEqual(output, torch.zeros_like(t1))

@requires_cuda()
@requires_triton()
def test_triton_kernel_functionalize(self):
import functorch
from functorch import make_fx
from torch._subclasses.functional_tensor import (
CppFunctionalizeAPI,
FunctorchFunctionalizeAPI,
PythonFunctionalizeAPI,
)

@triton.jit
def kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = 2 * x
tl.store(out_ptr + offsets, output, mask=mask)

def f(x, output):
out = triton_kernel_wrapper_functional(
kernel=kernel,
grid=(x.numel(),),
kwargs={
"in_ptr0": x,
"out_ptr": output,
"n_elements": output.numel(),
"BLOCK_SIZE": 16,
},
)
return out["out_ptr"]

t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")

gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)

gm = make_fx(CppFunctionalizeAPI().functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)

gm = make_fx(torch.func.functionalize(f))(t1, t2)
# Make sure t2 was not modified
self.assertNotEqual(gm(t1, t2), t2)

gm = make_fx(f, tracing_mode="fake")(t1, t2)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1, output_1):
triton_kernel_wrapper_mutation = functools_triton_kernel_wrapper_functional(grid = (5,), kwargs = {'in_ptr0': x_1, 'out_ptr': output_1, 'n_elements': 5, 'BLOCK_SIZE': 16}); x_1 = output_1 = None
getitem = triton_kernel_wrapper_mutation['in_ptr0']
getitem_1 = triton_kernel_wrapper_mutation['out_ptr']
getitem_2 = triton_kernel_wrapper_mutation['n_elements']
getitem_3 = triton_kernel_wrapper_mutation['BLOCK_SIZE']; triton_kernel_wrapper_mutation = None
return getitem_1""",
)

@requires_cuda()
@requires_triton()
def test_triton_kernel_by_hand(self):
@@ -1594,16 +1661,19 @@ def grid_fn(meta):
# No Dynamo -- Make sure triton kernel works (with positional BLOCK_SIZE)
self.assertEqual(call_triton_add(t1, t2, 1, True), torch_add)

for backend in ["eager", "aot_eager"]:
# With Dynamo
compiled_func = torch.compile(
call_triton_add, backend=backend, fullgraph=True
)
# With simple kernel
self.assertEqual(compiled_func(t1, t2, 0), torch_add)
# With lambda kernel
self.assertEqual(compiled_func(t1, t2, 1), torch_add)
# With lambda kernel (with positional BLOCK_SIZE)
self.assertEqual(compiled_func(t1, t2, 1, 1, True), torch_add)
# With user defined function kernel
self.assertEqual(compiled_func(t1, t2, 2, 200), torch_add)

def test_dataclass_factory(self):
@dataclass
151 changes: 143 additions & 8 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -1,16 +1,23 @@
import functools

import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._prims_common import clone_preserve_strides
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)


# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
def __init__(self):
super().__init__("triton_kernel_wrapper_mutation")


triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()

@@ -20,13 +27,141 @@ class TritonKernelWrapperFunctional(HigherOrderOperator):
def __init__(self):
super().__init__("triton_kernel_wrapper_functional")


triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(*, kernel, grid, kwargs):
kernel[grid](**kwargs)


@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(mode, *, kernel, grid, kwargs):
with mode:
return None


def trace_triton_kernel_wrapper(proxy_mode, func_overload, *, kernel, grid, kwargs):
with disable_proxy_modes_tracing():
out = func_overload(kernel=kernel, grid=grid, kwargs=kwargs)

fn = functools.partial(func_overload, kernel=kernel)
# FX graph needs __name__ and __module__ attributes
fn.__name__ = func_overload.__name__ # type:ignore[attr-defined]
if not hasattr(fn, "__module__"):
# Super hacky but on AMD __module__ is not set
fn.__module__ = "itertools"

node_args = {"grid": grid, "kwargs": kwargs}
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function",
fn,
(),
proxy_args,
name="triton_kernel_wrapper_mutation",
)
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
mode, *, kernel, grid, kwargs
):
if mode.enable_tracing:
trace_triton_kernel_wrapper(
mode,
triton_kernel_wrapper_mutation,
kernel=kernel,
grid=grid,
kwargs=kwargs,
)
else:
triton_kernel_wrapper_mutation(kernel=kernel, grid=grid, kwargs=kwargs)

return None


@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(ctx, kernel, grid, kwargs):
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
with ctx.redispatch_to_next():
unwrapped_outputs = triton_kernel_wrapper_functional(
kernel=kernel, grid=grid, kwargs=unwrapped_kwargs
)

assert unwrapped_outputs.keys() == kwargs.keys()
for key, output_arg in unwrapped_outputs.items():
if not isinstance(output_arg, Tensor):
continue
input_arg = kwargs[key]
assert isinstance(input_arg, Tensor)

ctx.replace(input_arg, output_arg)
ctx.commit_update(input_arg)
ctx.sync(input_arg)
return None


@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(*, kernel, grid, kwargs):
kwargs = {
key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
for key, val in kwargs.items()
}
triton_kernel_wrapper_mutation(kernel=kernel, grid=grid, kwargs=kwargs)
return kwargs


@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(mode, *, kernel, grid, kwargs):
    with mode:
        return {
            key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
            for key, val in kwargs.items()
        }


@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
mode, *, kernel, grid, kwargs
):
if mode.enable_tracing:
return trace_triton_kernel_wrapper(
mode,
triton_kernel_wrapper_functional,
kernel=kernel,
grid=grid,
kwargs=kwargs,
)
else:
return triton_kernel_wrapper_functional(kernel=kernel, grid=grid, kwargs=kwargs)


@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(ctx, kernel, grid, kwargs):
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
with ctx.redispatch_to_next():
outputs = triton_kernel_wrapper_functional(
kernel=kernel, grid=grid, kwargs=unwrapped_kwargs
)
return ctx.wrap_tensors(outputs)


triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)

triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
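
For context, a minimal sketch of how the two wrappers relate when called directly (not part of the diff; it assumes CUDA and Triton, and double_kernel mirrors the kernel in the new test): the mutation wrapper launches the kernel in place, while the functional wrapper clones the tensor arguments, runs the mutation on the clones, and returns them, leaving the originals untouched.

import torch
import triton
import triton.language as tl
from torch._higher_order_ops.triton_kernel_wrap import (
    triton_kernel_wrapper_functional,
    triton_kernel_wrapper_mutation,
)

@triton.jit
def double_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    # Mirrors the kernel in test_triton_kernel_functionalize: out = 2 * in.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets, 2 * tl.load(in_ptr0 + offsets, mask=mask), mask=mask)

x = torch.rand(5, device="cuda")
out = torch.zeros_like(x)
kwargs = {"in_ptr0": x, "out_ptr": out, "n_elements": x.numel(), "BLOCK_SIZE": 16}

# Mutating form: the dense impl launches the kernel, writing into `out` in place.
triton_kernel_wrapper_mutation(kernel=double_kernel, grid=(x.numel(),), kwargs=kwargs)

# Functional form: tensor kwargs are cloned first, so `out` is not modified;
# the doubled values come back under the same key in the returned dict.
result = triton_kernel_wrapper_functional(
    kernel=double_kernel, grid=(x.numel(),), kwargs=kwargs
)["out_ptr"]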
52 changes: 52 additions & 0 deletions torch/_subclasses/functional_tensor.py
@@ -170,6 +170,15 @@ def from_functional(self):
torch._sync(self)
return torch._from_functional_tensor(self.elem)

def replace_(self, output) -> None:
torch._functionalize_replace(self.elem, output) # type: ignore[attr-defined]

def commit_update(self) -> None:
torch._functionalize_commit_update(self.elem) # type: ignore[attr-defined]

def sync(self) -> None:
torch._functionalize_sync(self.elem) # type: ignore[attr-defined]


class FunctionalTensorMode(TorchDispatchMode):
def __init__(self):
@@ -382,6 +391,18 @@ def functionalize(self, inner_f: Callable) -> Callable:
def redispatch_to_next(self) -> ContextManager:
pass

@abstractmethod
def replace(self, input_tensor, output_tensor) -> None:
pass

@abstractmethod
def commit_update(self, tensor) -> None:
pass

@abstractmethod
def sync(self, tensor) -> None:
pass


class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
@@ -400,6 +421,19 @@ def functionalize(self, inner_f: Callable) -> Callable:
def redispatch_to_next(self) -> ContextManager:
return unset_functional_temporarily()

def replace(self, input_tensor, output_tensor) -> None:
assert isinstance(input_tensor, FunctionalTensor)
assert not isinstance(output_tensor, FunctionalTensor)
input_tensor.replace_(output_tensor)

def commit_update(self, tensor) -> None:
assert isinstance(tensor, FunctionalTensor)
tensor.commit_update()

def sync(self, tensor) -> None:
assert isinstance(tensor, FunctionalTensor)
tensor.sync()


class CppFunctionalizeAPI(BaseFunctionalizeAPI):
def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
@@ -422,6 +456,15 @@ def redispatch_to_next(self) -> ContextManager:
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)

def replace(self, input_tensor, output_tensor) -> None:
torch._functionalize_replace(input_tensor, output_tensor) # type: ignore[attr-defined]

def commit_update(self, tensor) -> None:
torch._functionalize_commit_update(tensor) # type: ignore[attr-defined]

def sync(self, tensor) -> None:
torch._functionalize_sync(tensor) # type: ignore[attr-defined]


class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
def __init__(self, interpreter):
@@ -451,3 +494,12 @@ def functionalize(self, inner_f: Callable) -> Callable:

def redispatch_to_next(self) -> ContextManager:
return self.interpreter.lower()

def replace(self, input_tensor, output_tensor) -> None:
torch._functionalize_replace(input_tensor, output_tensor) # type: ignore[attr-defined]

def commit_update(self, tensor) -> None:
torch._functionalize_commit_update(tensor) # type: ignore[attr-defined]

def sync(self, tensor) -> None:
torch._functionalize_sync(tensor) # type: ignore[attr-defined]
