-
Notifications
You must be signed in to change notification settings - Fork 21.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Inductor] Add Dynamic shape support to user defined triton kernels #112523
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,25 +76,18 @@ def __init__(self): | |
triton_kernel_wrapper_functional = TritonKernelWrapperFunctional() | ||
|
||
|
||
def grid_fn_code(name, configs, grids): | ||
assert len(grids) == len(configs) | ||
fn_name = f"grid_wrapper_for_{name}" | ||
grid_fn_str = f"def {fn_name}(meta):" | ||
for grid, config in zip(grids, configs): | ||
guards = [f"meta['{name}'] == {val}" for name, val in config.kwargs.items()] | ||
guards = " and ".join(guards) | ||
grid_fn_str += f"\n\tif {guards}: return {grid}" | ||
return fn_name, grid_fn_str | ||
|
||
|
||
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd) | ||
def triton_kernel_wrapper_mutation_dense(*, kernel_idx, grid, kwargs): | ||
from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I missed some PRs: previously, we were evaluating the grid at Dynamo time. Did that change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are still doing the same, this is just for emitting the python code for multi option grid function. From above
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. gotcha, thanks for clarifying |
||
|
||
kernel = kernel_side_table.get_kernel(kernel_idx) | ||
|
||
if len(grid) == 1: | ||
grid_fn = grid[0] | ||
else: | ||
fn_name, code = grid_fn_code(kernel.fn.__name__, kernel.configs, grid) | ||
fn_name, code = user_defined_kernel_grid_fn_code( | ||
kernel.fn.__name__, kernel.configs, grid | ||
) | ||
namespace: Dict[str, Any] = {} | ||
exec(code, namespace) | ||
grid_fn = namespace[fn_name] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -124,6 +124,23 @@ def get_cpp_op_schema(kernel): | |
return f"{cpp_return_value}({', '.join(cpp_arg_type)})" | ||
|
||
|
||
def user_defined_kernel_grid_fn_code(name, configs, grids): | ||
output = IndentedBuffer() | ||
|
||
fn_name = f"grid_wrapper_for_{name}" | ||
output.writeline(f"def {fn_name}(meta):") | ||
with output.indent(): | ||
if len(grids) == 1: | ||
output.writeline(f"return {grids[0]}") | ||
else: | ||
assert len(grids) == len(configs) | ||
for grid, c in zip(grids, configs): | ||
guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()] | ||
guards = " and ".join(guards) | ||
output.writeline(f"if {guards}: return {grid}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we also generate an exception at the end saying smth like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Triton will error saying that no suitable grid was found for {name}. But yes, we will not know what meta looked like. I could add the exception but by construction aren't we always guaranteed to match? Like what's scenario that we would fail to match? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure about the scenario, more like a defensive / informative code against the unexpected :) |
||
return fn_name, output.getvalue() | ||
|
||
|
||
@dataclasses.dataclass | ||
class SymbolicCallArg: | ||
inner: Any | ||
|
@@ -497,14 +514,10 @@ def generate_extern_kernel_out(self, output_view, codegen_reference, args, kerne | |
self.writeline(f"{kernel}({', '.join(args)})") | ||
|
||
def generate_user_defined_triton_kernel(self, kernel_name, grid, configs, args): | ||
assert len(grid) != 0 | ||
if len(grid) == 1: | ||
grid = f"{grid[0]}" | ||
else: | ||
from torch._higher_order_ops.triton_kernel_wrap import grid_fn_code | ||
|
||
grid, code = grid_fn_code(kernel_name, configs, grid) | ||
self.header.splice(code, strip=True) | ||
grid, code = user_defined_kernel_grid_fn_code(kernel_name, configs, grid) | ||
# Must happen after free symbols are already codegened | ||
with self.prefix.indent(): | ||
self.prefix.splice(code) | ||
Comment on lines
+517
to
+520
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking for feedback here. I moved the grid function to the prefix (rather than the header) so that it can access the free symbols from dynamic shapes. This happens to work because prefix contains the free symbols. Is there a better solution or is there a way for me to assert that free symbols are already generated? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because the grid function is a function (and not some inlined block of code) would all the symbols defined in the outer scope, including the ones defined after this function, not be visible in the inner scope of this function's body? E.g., this works (in Python): def fn(a):
return a * b
b = 123
c = fn(456) The main thing is that As for whether the required symbols will be codegened before the call, as the Triton kernel is represented by a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't know this is how python semantics worked.. C++ certainly does not work this way There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Depending on how we decide to deal with codegening the grid in the AOTInductor's C++ wrapper codegen, maybe we could opt for relying on this feature of Python. |
||
|
||
stream_name = self.write_get_raw_stream(V.graph.scheduler.current_device.index) | ||
self.writeline( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should add some more concrete tests about recompilation here with dynamic shapes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do as follow up