Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor][codegen] Codegen constexpr globals and constexpr annotated globals correctly. #126195

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions test/inductor/test_triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@
fast_dividef as my_fast_dividef,
)


# Define shared triton constants here.
CONSTANT_C = 4
STRING_CONSTANT_C = "CONSTANT_C"
BOOL_CONSTANT_C = True
# Define shared triton constants here.
CONSTANT_C: tl.constexpr = 4
STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C"
BOOL_CONSTANT_C: tl.constexpr = True


class KernelTests(torch._inductor.test_case.TestCase):
Expand Down Expand Up @@ -600,7 +599,7 @@ def mulC_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
if CONSTANT_NAME.value == STRING_CONSTANT_C:
if CONSTANT_NAME == STRING_CONSTANT_C:
output = CONSTANT_C * x
if BOOL_CONSTANT_C:
output *= CONSTANT_C
Expand Down
23 changes: 21 additions & 2 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,9 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):

# Also include any possible kernel being called indirectly
from triton import JITFunction
from triton.language import constexpr

# global constexpr vars handled above
symbols_included = {original_name}

def traverse(cur_kernel):
Expand All @@ -1237,6 +1239,7 @@ def traverse(cur_kernel):
for inst in dis.Bytecode(cur_kernel.fn)
if inst.opname == "LOAD_GLOBAL"
}
global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {})
for symbol_name in cur_kernel.fn.__code__.co_names:
if symbol_name in symbols_included:
continue
Expand All @@ -1248,9 +1251,25 @@ def traverse(cur_kernel):
compile_wrapper.splice(symbol.src, strip=True)
symbols_included.add(symbol_name)
traverse(symbol)
elif isinstance(symbol, (int, str, bool)):
elif isinstance(symbol, (int, str, bool, constexpr)):
compile_wrapper.newline()
compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
if isinstance(symbol, constexpr):
symbol_str = f"tl.constexpr({symbol.value!r})"
else:
symbol_str = f"{symbol!r}"
if annotation := global_annotations.get(symbol_name):
                            annotation_code = ""
if isinstance(annotation, type):
annotation_code = (
f": {annotation.__module__}.{annotation.__name__}"
)
else:
annotation_code = f": {annotation!r}"
compile_wrapper.writeline(
f"{symbol_name}{annotation_code} = {symbol_str}"
)
else:
compile_wrapper.writeline(f"{symbol_name} = {symbol!r}")
symbols_included.add(symbol_name)
elif (
symbol_name in unqualified_loads
Expand Down
Loading