Skip to content

Commit

Permalink
Add support for labels to ttir analysis (#119836)
Browse files Browse the repository at this point in the history
Pull Request resolved: #119836
Approved by: https://github.com/aakhundov
ghstack dependencies: #119834, #119835
  • Loading branch information
oulgen authored and pytorchmergebot committed Feb 14, 2024
1 parent 3f09c5e commit 5ffac76
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
34 changes: 34 additions & 0 deletions test/dynamo/test_triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,40 @@ def add_4_times_kernel(
["out_ptr"],
)

@make_mutation_test
def test_labels():
@triton.jit
def kernel_with_label(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
if pid > 1:
return
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)

t = torch.randn(4)
return (
kernel_with_label,
{
"in_ptr0": t,
"in_ptr1": t,
"out_ptr": t,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["out_ptr"],
)


if HAS_CUDA and HAS_LARK:
t = torch.randn(4)
Expand Down
10 changes: 7 additions & 3 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,14 @@ def parse_ttir(ttir, kwargs):
func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func
?stmt: op | if | for
?stmt: op | if | for | label_stmt | cf_stmt
if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if
for: [assign_lhs "="] "scf.for" args rest stmt* "}" LOC -> process_for
label_stmt: LABEL ":" "// pred:" LABEL
cf_stmt: "cf" "." NAME /.+/ NEWLINE
op: OP_NAME LOC
| [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op
Expand All @@ -223,10 +225,12 @@ def parse_ttir(ttir, kwargs):
INTERMEDIATE.4: "%" DIGIT+
INTERMEDIATE_CONSTANT.3: "%" NAME
CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")?
LABEL: "^bb" DIGIT+
NAME: (LETTER | DIGIT | "_")+
NON_CF_NAME: /(?!(cf))/ NAME
FN_NAME: "@" (NAME | ESCAPED_STRING)
OP_NAME: "\\""? NAME ("." NAME)+ "\\""?
OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""?
LOC.5: "loc(#loc" DIGIT* ")"
Expand Down

0 comments on commit 5ffac76

Please sign in to comment.