
[Inductor] Add Dynamic shape support to user defined triton kernels #112523

Closed · wants to merge 2 commits
35 changes: 24 additions & 11 deletions test/dynamo/test_functions.py
@@ -1819,8 +1819,9 @@ def prep():

@requires_cuda()
@requires_triton()
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
def test_triton_kernel_with_views(self, backend):
def test_triton_kernel_with_views(self, dynamic, backend):
def call_triton_take_view(x: torch.Tensor):
output = torch.zeros_like(x)
n_elements = output.numel()
@@ -1839,13 +1840,13 @@ def call_triton_return_view(x: torch.Tensor):
t_view = t.view(16)

compiled_func = torch.compile(
Contributor: Should add some more concrete tests about recompilation here with dynamic shapes.

Contributor Author: Will do as a follow-up.

call_triton_take_view, backend=backend, fullgraph=True
call_triton_take_view, backend=backend, fullgraph=True, dynamic=dynamic
)
self.assertEqual(2 * t_view, compiled_func(t_view))
self.assertEqual(2 * t, compiled_func(t_view).view(4, 4))

compiled_func = torch.compile(
call_triton_return_view, backend=backend, fullgraph=True
call_triton_return_view, backend=backend, fullgraph=True, dynamic=dynamic
)
self.assertEqual(2 * t_view, compiled_func(t).view(16))
self.assertEqual(2 * t, compiled_func(t))
@@ -1902,8 +1903,9 @@ def pow2_kernel(
@requires_cuda()
@requires_triton()
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@patch.object(torch._inductor.config, "implicit_fallbacks", False)
def test_triton_kernel_no_clones(self, grad):
def test_triton_kernel_no_clones(self, grad, dynamic):
from torch._inductor.utils import run_and_get_code

def call_triton_add(
@@ -1922,7 +1924,9 @@ def call_triton_add(
t2 = torch.rand(5, device="cuda", requires_grad=grad)

torch_add = t1 + t2
test, (code,) = run_and_get_code(torch.compile(call_triton_add), t1, t2)
test, (code,) = run_and_get_code(
torch.compile(call_triton_add, dynamic=dynamic), t1, t2
)
self.assertEqual(torch_add, test)
self.assertTrue("aten.copy" not in code)
self.assertTrue("aten.clone" not in code)
@@ -2058,9 +2062,10 @@ def call_triton(
@requires_triton()
@skipIfRocm
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@common_utils.parametrize("grid_type", [1, 2, 3])
def test_triton_kernel_autotune(self, grad, backend, grid_type):
def test_triton_kernel_autotune(self, grad, dynamic, backend, grid_type):
def call_triton(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x, requires_grad=grad)
n_elements = output.numel()
@@ -2082,16 +2087,19 @@ def grid_fn(meta):
t2 = torch.rand(256, device="cuda", requires_grad=grad)

torch_add = call_triton(t1, t2)
compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
compiled_func = torch.compile(
call_triton, backend=backend, fullgraph=True, dynamic=dynamic
)
self.assertEqual(compiled_func(t1, t2), torch_add)

@requires_cuda()
@requires_triton()
@skipIfRocm
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@common_utils.parametrize("grid_type", [1, 2, 3])
def test_triton_kernel_2d_autotune(self, grad, backend, grid_type):
def test_triton_kernel_2d_autotune(self, grad, dynamic, backend, grid_type):
def call_triton(x: torch.Tensor, y: torch.Tensor):
output = torch.zeros_like(x, requires_grad=grad)
x_elements = output.size()[0]
@@ -2120,15 +2128,18 @@ def grid_fn(meta):
t2 = torch.rand((512, 256), device="cuda", requires_grad=grad)

torch_result = call_triton(t1, t2)
compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
compiled_func = torch.compile(
call_triton, backend=backend, fullgraph=True, dynamic=dynamic
)
self.assertEqual(compiled_func(t1, t2), torch_result)

@requires_cuda()
@requires_triton()
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@patch.object(torch._inductor.config, "implicit_fallbacks", False)
def test_triton_kernel_native(self, grad, backend):
def test_triton_kernel_native(self, grad, dynamic, backend):
def call_triton_add(
x: torch.Tensor, y: torch.Tensor, grid_type: int, num=1, positional=False
):
@@ -2163,7 +2174,9 @@ def grid_fn(meta):
self.assertEqual(call_triton_add(t1, t2, 1, True), torch_add)

# With Dynamo
compiled_func = torch.compile(call_triton_add, backend=backend, fullgraph=True)
compiled_func = torch.compile(
call_triton_add, backend=backend, fullgraph=True, dynamic=dynamic
)
# With simple kernel
self.assertEqual(compiled_func(t1, t2, 0), torch_add)
# With lambda kernel
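A minimal sketch of the recompilation follow-up discussed in the thread above, assuming a user-defined Triton add kernel compiled with dynamic=True (all names below are illustrative and not part of this PR); running the compiled wrapper on two different sizes should ideally compile only once:

import torch
import triton
import triton.language as tl
from torch._dynamo.testing import CompileCounterWithBackend

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE chunk of the flattened tensors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def call_add(x, y):
    out = torch.zeros_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=128)
    return out

cnt = CompileCounterWithBackend("inductor")
compiled = torch.compile(call_add, backend=cnt, fullgraph=True, dynamic=True)
for n in (64, 256):
    x = torch.rand(n, device="cuda")
    y = torch.rand(n, device="cuda")
    torch.testing.assert_close(compiled(x, y), x + y)
assert cnt.frame_count == 1  # no recompile across the two sizes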
17 changes: 5 additions & 12 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -76,25 +76,18 @@ def __init__(self):
triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()


def grid_fn_code(name, configs, grids):
    assert len(grids) == len(configs)
    fn_name = f"grid_wrapper_for_{name}"
    grid_fn_str = f"def {fn_name}(meta):"
    for grid, config in zip(grids, configs):
        guards = [f"meta['{name}'] == {val}" for name, val in config.kwargs.items()]
        guards = " and ".join(guards)
        grid_fn_str += f"\n\tif {guards}: return {grid}"
    return fn_name, grid_fn_str


@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(*, kernel_idx, grid, kwargs):
from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code
Contributor: I missed some PRs: previously, we were evaluating the grid at Dynamo time. Did that change?

Contributor Author: We are still doing the same; this is just for emitting the Python code for the multi-option grid function. From above: "This PR moves the grid function codegen to the wrapper so that we can use IndentedBuffers as opposed to manually adding tabs for indentation."

Contributor: Gotcha, thanks for clarifying.


kernel = kernel_side_table.get_kernel(kernel_idx)

if len(grid) == 1:
grid_fn = grid[0]
else:
fn_name, code = grid_fn_code(kernel.fn.__name__, kernel.configs, grid)
fn_name, code = user_defined_kernel_grid_fn_code(
kernel.fn.__name__, kernel.configs, grid
)
namespace: Dict[str, Any] = {}
exec(code, namespace)
grid_fn = namespace[fn_name]
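For reference, a hypothetical example of the source that user_defined_kernel_grid_fn_code would emit for an autotuned kernel named add_kernel with two configs (the BLOCK_SIZE values and grids below are made up); the dense implementation above exec()s this string and then looks the wrapper up by name:

def grid_wrapper_for_add_kernel(meta):
    if meta['BLOCK_SIZE'] == 128: return (32, 1, 1)
    if meta['BLOCK_SIZE'] == 64: return (64, 1, 1)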
29 changes: 21 additions & 8 deletions torch/_inductor/codegen/wrapper.py
@@ -124,6 +124,23 @@ def get_cpp_op_schema(kernel):
return f"{cpp_return_value}({', '.join(cpp_arg_type)})"


def user_defined_kernel_grid_fn_code(name, configs, grids):
    output = IndentedBuffer()

    fn_name = f"grid_wrapper_for_{name}"
    output.writeline(f"def {fn_name}(meta):")
    with output.indent():
        if len(grids) == 1:
            output.writeline(f"return {grids[0]}")
        else:
            assert len(grids) == len(configs)
            for grid, c in zip(grids, configs):
                guards = [f"meta['{name}'] == {val}" for name, val in c.kwargs.items()]
                guards = " and ".join(guards)
                output.writeline(f"if {guards}: return {grid}")
Contributor (aakhundov, Nov 1, 2023): Should we also generate an exception at the end saying something like f"no matching Triton config found for the kernel {name=} and {meta=}"? Otherwise the function will return None, and I'm not sure how clear the downstream error would be.

Contributor Author: Triton will error saying that no suitable grid was found for {name}, but yes, we will not know what meta looked like. I could add the exception, but by construction aren't we always guaranteed to match? What's the scenario where we would fail to match?

Contributor: Not sure about the scenario; it's more of a defensive / informative check against the unexpected. :)

    return fn_name, output.getvalue()


@dataclasses.dataclass
class SymbolicCallArg:
    inner: Any
@@ -497,14 +514,10 @@ def generate_extern_kernel_out(self, output_view, codegen_reference, args, kerne
self.writeline(f"{kernel}({', '.join(args)})")

def generate_user_defined_triton_kernel(self, kernel_name, grid, configs, args):
assert len(grid) != 0
if len(grid) == 1:
grid = f"{grid[0]}"
else:
from torch._higher_order_ops.triton_kernel_wrap import grid_fn_code

grid, code = grid_fn_code(kernel_name, configs, grid)
self.header.splice(code, strip=True)
grid, code = user_defined_kernel_grid_fn_code(kernel_name, configs, grid)
# Must happen after free symbols are already codegened
with self.prefix.indent():
self.prefix.splice(code)
Comment on lines +517 to +520
Contributor Author: Looking for feedback here. I moved the grid function to the prefix (rather than the header) so that it can access the free symbols from dynamic shapes. This happens to work because the prefix contains the free symbols. Is there a better solution, or is there a way for me to assert that the free symbols are already generated?

Contributor (aakhundov, Nov 1, 2023): Because the grid function is a function (and not an inlined block of code), wouldn't all the symbols defined in the outer scope, including the ones defined after this function, be visible in the inner scope of this function's body? E.g., this works in Python:

def fn(a):
  return a * b

b = 123
c = fn(456)

The main constraint is that fn must not be called before b is defined; otherwise it's fine that b is defined after fn itself. So we need to make sure that the calls to Triton kernels are codegened after the required symbols' definitions, but the grid functions can be anywhere in the call function's body before the kernel is called?

As for whether the required symbols will be codegened before the call, as the Triton kernel is represented by a Buffer in the IR which has dependencies, I'd hope that the existing dependency management mechanics should take care of this.

Contributor Author: I didn't know this is how Python semantics worked; C++ certainly does not work this way.

Contributor: Depending on how we decide to deal with codegening the grid in AOTInductor's C++ wrapper codegen, maybe we could opt to rely on this feature of Python.


stream_name = self.write_get_raw_stream(V.graph.scheduler.current_device.index)
self.writeline(
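To make the scoping argument from the thread above concrete, a simplified pure-Python sketch (hypothetical names, not actual Inductor output): the grid wrapper emitted in the prefix only refers to s0 by name, so it may appear before s0 is assigned, as long as it is only called after the assignment, which is the case when the kernel launch is codegened after the symbol definitions.

def call(x):
    def grid_wrapper_for_add_kernel(meta):
        # s0 is resolved when the wrapper is called, not when it is defined
        return (-(-s0 // meta["BLOCK_SIZE"]),)  # ceil division

    s0 = len(x)  # stand-in for the dynamic-shape symbol assigned in the prefix
    return grid_wrapper_for_add_kernel({"BLOCK_SIZE": 128})

assert call(range(300)) == (3,)  # the wrapper sees s0 even though it was emitted first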