[Dynamo] Add native support for Triton Kernels to Dynamo
This PR adds native support to Dynamo to detect Triton kernels and
create an FX graph node out of them. AOT eager and inductor modes will
be supported in follow-up PRs.
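
As a rough sketch of the user-facing flow this enables (condensed from the new test below; the kernel, sizes, and BLOCK_SIZE are illustrative, and only the eager backend is shown since AOT eager and inductor support come in the follow-ups):

import torch
import triton
from triton import language as tl

@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    # Dynamo now recognizes the JITFunction and traces this launch into the
    # FX graph instead of falling back on it.
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=16)
    return out

compiled = torch.compile(add, backend="eager", fullgraph=True)
x, y = torch.rand(32, device="cuda"), torch.rand(32, device="cuda")
torch.testing.assert_close(compiled(x, y), x + y)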

ghstack-source-id: 75ac402a2d86d1670ff154448a29ea6401483627
Pull Request resolved: #109623
oulgen committed Sep 19, 2023
1 parent 1427b81 commit 981142f
Showing 4 changed files with 167 additions and 1 deletion.
70 changes: 70 additions & 0 deletions test/dynamo/test_functions.py
@@ -24,6 +24,23 @@
    disable_translation_validation_if_dynamic_shapes,
)

try:
    try:
        import triton
        from triton import language as tl
    except ImportError:
        raise unittest.SkipTest("requires triton")

except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise

from torch.testing._internal.inductor_utils import HAS_CUDA

requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")


d = torch.ones(10, 10)
e = torch.nn.Linear(10, 10)
flag = True
@@ -1435,6 +1452,59 @@ def func():
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 1)

    @requires_cuda()
    def test_triton_kernel_by_hand(self):
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            BLOCK_SIZE: tl.constexpr,
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

        def call_triton_add(x: torch.Tensor, y: torch.Tensor, grid_type: int, num=1):
            output = torch.zeros_like(x)
            n_elements = output.numel()

            def grid_fn(meta):
                return (triton.cdiv(num, meta["BLOCK_SIZE"]),)

            if grid_type == 0:
                grid = (x.numel(),)
            elif grid_type == 1:
                grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            else:
                grid = grid_fn

            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=2)
            return output

        t1 = torch.rand(5, device="cuda")
        t2 = torch.rand(5, device="cuda")

        torch_add = t1 + t2

        # No Dynamo -- make sure the Triton kernel works on its own
        self.assertEqual(call_triton_add(t1, t2, True), torch_add)

        # With Dynamo
        compiled_func = torch.compile(call_triton_add, backend="eager", fullgraph=True)
        # With a plain tuple grid
        self.assertEqual(compiled_func(t1, t2, 0), torch_add)
        # With a lambda grid
        self.assertEqual(compiled_func(t1, t2, 1), torch_add)
        # With a user-defined grid function
        self.assertEqual(compiled_func(t1, t2, 2, 200), torch_add)

    def test_dataclass_factory(self):
        @dataclass
        class Output:
16 changes: 16 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -30,6 +30,7 @@
)
from torch.fx.immutable_collections import immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.triton import has_triton
from torch.utils.weak import TensorWeakRef, WeakIdRef

from .. import config, mutation_guard, replay_record, skipfiles
@@ -94,6 +95,7 @@
from .functions import (
    CollectiveFunctionRewriteVariable,
    FunctoolsPartialVariable,
    TritonKernelVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
@@ -140,6 +142,13 @@
    UserDefinedObjectVariable,
)

if has_triton():
    from triton.runtime.jit import JITFunction
else:

    class JITFunction:
        pass


log = logging.getLogger(__name__)

@@ -755,6 +764,13 @@ def index_source(key):
                sym_node_proxy,
                new_symint == 1,
            )
        elif isinstance(value, JITFunction):
            return TritonKernelVariable(
                value,
                None,  # No grid provided
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        else:
            result = UserDefinedObjectVariable(
                value,
69 changes: 68 additions & 1 deletion torch/_dynamo/variables/functions.py
@@ -16,7 +16,7 @@
    GetItemSource,
    GlobalSource,
)
from ..utils import make_cell
from ..utils import make_cell, proxy_args_kwargs
from .base import typestr, VariableTracker


@@ -639,3 +639,70 @@ def as_python_constant(self):
            *[arg.as_python_constant() for arg in self.args],
            **{k: v.as_python_constant() for k, v in self.keywords.items()},
        )


class TritonKernelVariable(VariableTracker):
    def __init__(self, kernel, grid, **kwargs):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.grid = grid

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        def call_kernel(grid, *args, **kwargs):
            self.kernel.run(*args, grid=grid, **kwargs)

        from .dicts import ConstDictVariable
        from .functions import NestedUserFunctionVariable, UserFunctionVariable
        from .lists import BaseListVariable

        grid = self.grid

        # If the grid is a function, execute it and convert the result to
        # a list
        if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
            d = dict(kwargs)
            meta = ConstDictVariable(d, dict)
            grid = grid.call_function(tx, [meta], {})

        # At this point the grid must be a list, either originally or via the
        # conversion above
        if isinstance(grid, BaseListVariable):
            grid = grid.as_proxy()
        else:
            unimplemented(f"grid for the triton kernel is {type(grid)}")

        proxied_args, proxied_kwargs = proxy_args_kwargs(args, kwargs)
        tx.output.create_proxy(
            "call_function",
            call_kernel,
            (grid,) + proxied_args,
            proxied_kwargs,
        )

        return variables.ConstantVariable(
            None,
            **VariableTracker.propagate(self, args),
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__":
            # __getitem__ should only be called if we don't already have a grid
            assert self.grid is None

            # Only the grid needs to be passed
            assert len(args) == 1
            grid = args[0]
            return TritonKernelVariable(
                self.kernel, grid, **VariableTracker.propagate(self)
            )

        # Bail out to the parent's implementation
        return super().call_method(tx, name, args, kwargs)
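
A note on why TritonKernelVariable implements both __getitem__ handling and call_function: a launch like add_kernel[grid](x, y, ...) desugars into an indexing step that binds the grid, followed by a call that runs the kernel, so Dynamo first builds a grid-bound TritonKernelVariable and then emits the proxy on the call. A toy sketch of that calling convention (illustrative only, not Triton's actual launcher):

# Illustrative only: mimics the kernel[grid](...) convention that
# TritonKernelVariable models; this is not Triton's real implementation.
class ToyKernel:
    def __getitem__(self, grid):
        # Indexing returns a launcher bound to the chosen grid.
        def launcher(*args, **kwargs):
            print(f"launch grid={grid} args={args} kwargs={kwargs}")
        return launcher

ToyKernel()[(4,)](1, 2, BLOCK_SIZE=64)  # same shape as add_kernel[grid](x, y, ...)
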
13 changes: 13 additions & 0 deletions torch/utils/triton.py
@@ -0,0 +1,13 @@
import functools
import torch

@functools.lru_cache(None)
def has_triton() -> bool:
    if not torch.cuda.is_available():
        return False
    try:
        import triton

        return triton is not None and torch.cuda.get_device_capability() >= (7, 0)
    except ImportError:
        return False
