Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for labels to ttir analysis #119836

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 34 additions & 0 deletions test/dynamo/test_triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,40 @@ def add_4_times_kernel(
["out_ptr"],
)

@make_mutation_test
def test_labels():
@triton.jit
def kernel_with_label(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
if pid > 1:
return
Comment on lines +1236 to +1237
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add another test where we do different things in if and else blocks with a data-dependent condition like this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already have a test for that which turns into a scf.if not labels

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, so labels are only used on if ...: return? or also if can have a body and this is still generated as label?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends on what happens in the body, is it dynamic or not

block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)

t = torch.randn(4)
return (
kernel_with_label,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["out_ptr"],
)


if HAS_CUDA and HAS_LARK:
t = torch.randn(4)
Expand Down
10 changes: 7 additions & 3 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,14 @@ def parse_ttir(ttir, kwargs):

func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func

?stmt: op | if | for
?stmt: op | if | for | label_stmt | cf_stmt

if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if

for: [assign_lhs "="] "scf.for" args rest stmt* "}" LOC -> process_for

label_stmt: LABEL ":" "// pred:" LABEL
cf_stmt: "cf" "." NAME /.+/ NEWLINE

op: OP_NAME LOC
| [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op

Expand All @@ -223,10 +225,12 @@ def parse_ttir(ttir, kwargs):
INTERMEDIATE.4: "%" DIGIT+
INTERMEDIATE_CONSTANT.3: "%" NAME
CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")?
LABEL: "^bb" DIGIT+

NAME: (LETTER | DIGIT | "_")+
NON_CF_NAME: /(?!(cf))/ NAME
FN_NAME: "@" (NAME | ESCAPED_STRING)
OP_NAME: "\\""? NAME ("." NAME)+ "\\""?
OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""?

LOC.5: "loc(#loc" DIGIT* ")"

Expand Down