Skip to content

Commit

Permalink
Update on "[Dynamo] Add native support for Triton Kernels to Dynamo"
Browse files Browse the repository at this point in the history
This PR adds native support to Dynamo to detect Triton kernels and
create an FX graph node out of them. AOT eager and inductor modes will
be support in follow up PRs.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

[ghstack-poisoned]
  • Loading branch information
oulgen committed Sep 28, 2023
2 parents 312b4f3 + 170bc8e commit 9939238
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
16 changes: 9 additions & 7 deletions torch/_dynamo/variables/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,19 +651,22 @@ def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .dicts import ConstDictVariable
from .lists import BaseListVariable, TupleVariable
from .lists import BaseListVariable

grid = self.grid

if grid is None:
raise Unsupported("Triton kernels should always be called with a grid")

# Both for grid's meta as well as for the kernel, we need combined
# args and kwargs normalized
normalized_args = {**dict(zip(self.kernel.arg_names, args)), **kwargs}
meta = ConstDictVariable(normalized_args, dict)

# If the grid is a function, then lets execute it and convert it to
# a list
if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
# Populate the special "meta" argument to call the grid function
d = {**dict(zip(self.kernel.arg_names, args)), **kwargs}
meta = ConstDictVariable(d, dict)
grid = grid.call_function(tx, [meta], {})

# Now, the grid must be a list either originally or through above
Expand All @@ -684,7 +687,7 @@ def call_function(
# Super hacky but on AMD __module__ is not set
fn.__module__ = "itertools"

# Pass args and kwargs as tuple and dict so that if user defined triton
# Combine args and kwargs and pass as a dict so that if user defined triton
# kernel uses variables as 'grid' or 'kernel', it does not conflict with
# parameters of the wrapper function
tx.output.create_proxy(
Expand All @@ -693,14 +696,13 @@ def call_function(
(),
{
"grid": grid,
"args": TupleVariable(args).as_proxy(),
"kwargs": ConstDictVariable(kwargs, dict).as_proxy(),
"kwargs": meta.as_proxy(),
},
)

return variables.ConstantVariable(
None,
**VariableTracker.propagate(self, args),
**VariableTracker.propagate(self, args, kwargs.values()),
)

def call_method(
Expand Down
4 changes: 2 additions & 2 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ class TritonKernelWrapperMutation(HigherOrderOperator):
def __init__(self):
super().__init__("triton_kernel_wrapper_mutation")

def __call__(self, *, kernel, grid, args, kwargs):
kernel[grid](*args, **kwargs)
def __call__(self, *, kernel, grid, kwargs):
kernel[grid](**kwargs)


triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()

0 comments on commit 9939238

Please sign in to comment.