Skip to content

Commit

Permalink
Support calling user defined triton kernels with kernel.run (#112292)
Browse files Browse the repository at this point in the history
Pull Request resolved: #112292
Approved by: https://github.com/jansel
ghstack dependencies: #112290
  • Loading branch information
oulgen authored and pytorchmergebot committed Oct 30, 2023
1 parent 1250032 commit 219763c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ def call_triton_add(
n_elements = output.numel()

grid = (x.numel(),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
add_kernel.run(x, y, output, n_elements, grid=grid, BLOCK_SIZE=16)

return output

Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/variables/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,11 @@ def call_method(
grid=args[0],
**VariableTracker.propagate(self),
)
elif name == "run":
if "grid" not in kwargs:
raise Unsupported("Triton kernel requires to be called with a grid")
grid = kwargs.pop("grid")
return self.clone(grid=grid).call_function(tx, args, kwargs)

# Bail out to parent's implementation
return super().call_method(tx, name, args, kwargs)

0 comments on commit 219763c

Please sign in to comment.