[Dynamo] Add functional triton kernel wrapper (#110185)
Pull Request resolved: #110185
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/bdhirsh
ghstack dependencies: #109623
oulgen authored and pytorchmergebot committed Sep 30, 2023
1 parent 6b84658 commit f7ba3e8
Showing 2 changed files with 102 additions and 0 deletions.
83 changes: 83 additions & 0 deletions test/dynamo/test_functions.py
@@ -20,6 +20,11 @@
from torch import sub
from torch._dynamo.testing import expectedFailureDynamic
from torch._dynamo.utils import same

from torch._higher_order_ops.triton_kernel_wrap import (
triton_kernel_wrapper_functional,
triton_kernel_wrapper_mutation,
)
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
disable_translation_validation_if_dynamic_shapes,
@@ -1458,6 +1463,84 @@ def func():
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 1)

@requires_cuda()
@requires_triton()
def test_triton_kernel_with_kernel_param(self):
@triton.jit
def test_kernel(kernel):
pass

@torch.compile(backend="eager")
def f(x):
grid = (x.numel(),)
test_kernel[grid](kernel=x)

t1 = torch.rand(5, device="cuda")
f(t1)
# No need to assert anything, the goal is to make sure dynamo does
# not crash

@requires_cuda()
@requires_triton()
def test_triton_kernel_higher_order_func(self):
@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)

t1 = torch.rand(5, device="cuda")
t2 = torch.rand(5, device="cuda")

torch_add = t1 + t2

# Test higher order function with mutation
output = torch.zeros_like(t1)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
triton_kernel_wrapper_mutation(
kernel=add_kernel,
grid=grid,
kwargs={
"in_ptr0": t1,
"in_ptr1": t2,
"out_ptr": output,
"n_elements": n_elements,
"BLOCK_SIZE": 16,
},
)
self.assertEqual(output, torch_add)
# Make sure it is modified
self.assertNotEqual(output, torch.zeros_like(t1))

# Test higher order function without mutation
output = torch.zeros_like(t1)
out_dict = triton_kernel_wrapper_functional(
kernel=add_kernel,
grid=grid,
kwargs={
"in_ptr0": t1,
"in_ptr1": t2,
"out_ptr": output,
"n_elements": n_elements,
"BLOCK_SIZE": 16,
},
)
self.assertEqual(out_dict["out_ptr"], torch_add)
# Make sure it is NOT modified
self.assertEqual(output, torch.zeros_like(t1))

@requires_cuda()
@requires_triton()
def test_triton_kernel_by_hand(self):
19 changes: 19 additions & 0 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -1,4 +1,6 @@
from torch import Tensor
from torch._ops import HigherOrderOperator
from torch._prims_common import clone_preserve_strides


# Used for wrapping a Triton Kernel
@@ -11,3 +13,20 @@ def __call__(self, *, kernel, grid, kwargs):


triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()


# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
def __init__(self):
super().__init__("triton_kernel_wrapper_functional")

def __call__(self, *, kernel, grid, kwargs):
kwargs = {
key: (clone_preserve_strides(val) if isinstance(val, Tensor) else val)
for key, val in kwargs.items()
}
triton_kernel_wrapper_mutation(kernel=kernel, grid=grid, kwargs=kwargs)
return kwargs


triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()
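
For readers skimming the diff, the following is a minimal sketch (not part of this commit) of the functionalization pattern that TritonKernelWrapperFunctional implements: clone every Tensor argument, run the mutating wrapper against the clones, and return the cloned kwargs so the caller's original buffers stay untouched. The run_functionally helper name below is hypothetical, and mutating_op stands in for triton_kernel_wrapper_mutation.

import torch


def run_functionally(mutating_op, **kwargs):
    # Clone every Tensor argument so the caller's buffers are never written to.
    cloned = {
        key: (val.clone() if isinstance(val, torch.Tensor) else val)
        for key, val in kwargs.items()
    }
    mutating_op(**cloned)  # the mutation happens only on the clones
    return cloned  # results are read out of the returned dict

The actual implementation uses clone_preserve_strides rather than Tensor.clone, presumably so that kernels which depend on a particular memory layout see the same strides on the cloned buffers.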
