Skip to content

Commit

Permalink
Add support for while-loops in ttir analysis (#119838)
Browse files Browse the repository at this point in the history
Pull Request resolved: #119838
Approved by: https://github.com/aakhundov
ghstack dependencies: #119834, #119835, #119836
  • Loading branch information
oulgen authored and pytorchmergebot committed Feb 14, 2024
1 parent 5ffac76 commit 1f0e4ac
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
3 changes: 1 addition & 2 deletions test/dynamo/test_triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,8 +1357,7 @@ def kernel_with_label(
"n_elements": 4,
"BLOCK_SIZE": 4,
},
# TODO(oulgen): While loops not implemented yet
["in_ptr0", "in_ptr1", "out_ptr"],
["out_ptr"],
],
[
cond_op_kernel,
Expand Down
9 changes: 7 additions & 2 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def parse_ttir(ttir, kwargs):
parser which further makes parsing much simpler.
"""
# TODO(oulgen):
# - Support parsing while loops
# - Support closures (e.g. "tt.reduce")

try:
Expand All @@ -195,12 +194,15 @@ def parse_ttir(ttir, kwargs):
func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func
?stmt: op | if | for | label_stmt | cf_stmt
?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt
if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if
for: [assign_lhs "="] "scf.for" args rest stmt* "}" LOC -> process_for
while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while
condition_stmt: "scf.condition" "(" arg ")" args rest
label_stmt: LABEL ":" "// pred:" LABEL
| LABEL "(" /.+/ NEWLINE
cf_stmt: "cf" "." NAME /.+/ NEWLINE
op: OP_NAME LOC
Expand Down Expand Up @@ -326,6 +328,9 @@ def process_if(self, ret, _args, _rest, *stmts):
def process_for(self, ret, _args, _rest, *stmts):
return self._process_scf(ret, stmts)

def process_while(self, ret, _args, _rest, *stmts):
return self._process_scf(ret, stmts)

parser = Lark(
grammar, parser="lalr", maybe_placeholders=True, transformer=TransformOps()
)
Expand Down

0 comments on commit 1f0e4ac

Please sign in to comment.