[Inductor] Add triton.autotune support for user defined triton kernels with complex grids #112290
Conversation
…s with complex grids [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112290
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 07165d9 with merge base f5088d2.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…s with complex grids ghstack-source-id: 8849bba31b2748464dc491fa72e98d12f8319c01 Pull Request resolved: #112290
…iton kernels with complex grids" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
```python
def grid_fn(meta):
    configs = kernel.configs
    assert len(grid) == len(configs)
    for i, (grid_val, config) in enumerate(zip(grid, configs)):
        guards = [
            f"meta['{name}'] == {val}" for name, val in config.kwargs.items()
        ]
        guards = " and ".join(guards)
        if eval(guards):
            return grid[i]
```
Configs aren't actually guaranteed to be unique, since they can differ in num_stages/num_warps (but not kwargs). This might be fine though, since I think the grids will be the same for duplicates.
Calling eval() is very expensive in Python, and this function is O(num-configs). This could be implemented more efficiently using a dict.
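The dict-based approach suggested here might look something like the following sketch. This is illustrative only, not the PR's actual code: `FakeConfig` is a stand-in for `triton.Config`, and `make_grid_fn`, `keys`, and the example values are all hypothetical names.

```python
# Sketch of the dict suggestion: key a precomputed table by each config's
# kwarg values, so per-launch dispatch is one tuple build plus one dict
# lookup instead of eval()-ing a guard string per config.

class FakeConfig:  # stand-in for triton.Config, for illustration only
    def __init__(self, kwargs, num_stages=3, num_warps=8):
        self.kwargs = kwargs
        self.num_stages = num_stages
        self.num_warps = num_warps

def make_grid_fn(configs, grids, keys):
    # keys: the kwarg names shared by every config, e.g. ("XBLOCK",)
    precomputed = {
        tuple(cfg.kwargs[k] for k in keys): g
        for cfg, g in zip(configs, grids)
    }

    def grid_fn(meta):
        return precomputed[tuple(meta[k] for k in keys)]

    return grid_fn

configs = [FakeConfig({"XBLOCK": 128}), FakeConfig({"XBLOCK": 64})]
grids = [(8, 1, 1), (16, 1, 1)]
grid_fn = make_grid_fn(configs, grids, ("XBLOCK",))
```

Duplicate configs that differ only in num_stages/num_warps would collapse onto one dict entry, which is harmless as long as their grids agree.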
This might be fine though, since I think the grids will be the same for duplicates
grid computation doesn't use num_warps and num_stages so duplicates are fine, just more branches.
torch/_inductor/codegen/wrapper.py
Outdated
```python
grid_wrapper.writeline(f"def grid_wrapper_for_{kernel_name}(meta):")
assert len(grid) == len(configs)
with grid_wrapper.indent():
    for i, (grid_val, conf) in enumerate(zip(grid, configs)):
        guards = [
            f"meta['{name}'] == {val}" for name, val in conf.kwargs.items()
        ]
        guards = " and ".join(guards)
        grid_wrapper.writeline(f"if {guards}: return grid({grid[i]})(meta)")
```
Could we share code with the implementation above? I think we can preprocess this into a dict, so the body is more like:
```python
def grid(meta):
    return precomputed_grids[(meta["XBLOCK"], meta["YBLOCK"])]
```
One edge case I see here is
```python
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 64}, num_stages=3, num_warps=8
        ),
    ],
    key=[],
)
```
where configs are not of equal size in terms of their BLOCK_SIZE dimensions. Is this something I should disallow?
I think we can just disallow this.
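Disallowing that edge case could be as simple as checking that every config names the same kwargs up front. A hypothetical sketch, not the PR's actual code (`FakeConfig` stands in for `triton.Config`, and `check_configs_same_kwargs` is an invented name):

```python
# Sketch of the "just disallow this" validation: every triton.Config in the
# autotune list must define the same set of kwarg names, so the generated
# grid dispatch can guard on a fixed set of meta keys.

class FakeConfig:  # stand-in for triton.Config, for illustration only
    def __init__(self, kwargs, num_stages=3, num_warps=8):
        self.kwargs = kwargs
        self.num_stages = num_stages
        self.num_warps = num_warps

def check_configs_same_kwargs(configs):
    key_sets = {frozenset(cfg.kwargs) for cfg in configs}
    if len(key_sets) > 1:
        raise ValueError(
            "All triton.Config objects passed to triton.autotune must "
            f"define the same kwargs; got differing key sets: {key_sets}"
        )
```

Configs that differ only in values (or in num_stages/num_warps) pass the check; only mismatched kwarg names are rejected.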
torch/_inductor/triton_heuristics.py
Outdated
```python
# when passed as a single tuple, calling convention order is not
# followed, so we need to reverse to match calling convention order
numels = numels[0][::-1]
```
I don't understand why this change is needed; I thought we weren't using this function for user-defined triton kernels?
cached_autotune uses the launcher function, which requires the grid to be a 3-tuple since it unpacks it.
pytorch/torch/_inductor/triton_heuristics.py
Lines 316 to 319 in c14c4ef
```python
if callable(grid):
    grid_0, grid_1, grid_2 = grid(grid_meta)
else:
    grid_0, grid_1, grid_2 = grid
```
I was using this function to convert a 1-tuple or 2-tuple into a 3-tuple. I can either write my own function to do this, or is there a more pythonic way to do this conversion?
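One idiomatic way to do the padding, shown here as a sketch (the name `pad_grid` is invented, not from the PR), is to extend the tuple with trailing 1s:

```python
# Pad a 1- or 2-element grid out to the 3-tuple the launcher unpacks,
# using tuple concatenation rather than branching on length.
def pad_grid(grid):
    grid = tuple(grid)
    assert 1 <= len(grid) <= 3, grid
    return grid + (1,) * (3 - len(grid))
```

For example, `pad_grid((8,))` yields `(8, 1, 1)` and `pad_grid((4, 2))` yields `(4, 2, 1)`, while a full 3-tuple passes through unchanged.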
Since we are generating the grid function, can't we just make it return a 3-tuple?
Explained offline to @jansel that using a dict imposes unnecessary restrictions. Updated the code to only have a single exec in eager mode, and to share code between eager and inductor.
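The "single exec" idea can be sketched as follows: build the guard chain as source once, exec it once to compile a real function, and then call that function per launch with no further eval/exec. This is a rough reconstruction for illustration, not the PR's actual code; `FakeConfig` and `compile_grid_fn` are invented names.

```python
# Sketch: compile the per-config guard chain into one function up front,
# so launches pay a plain function call instead of eval() per config.

class FakeConfig:  # stand-in for triton.Config, for illustration only
    def __init__(self, kwargs, num_stages=3, num_warps=8):
        self.kwargs = kwargs
        self.num_stages = num_stages
        self.num_warps = num_warps

def compile_grid_fn(configs, grids):
    lines = ["def grid_fn(meta):"]
    for cfg, g in zip(configs, grids):
        guards = " and ".join(
            f"meta[{name!r}] == {val!r}" for name, val in cfg.kwargs.items()
        ) or "True"  # a config with no kwargs matches unconditionally
        lines.append(f"    if {guards}: return {g!r}")
    namespace = {}
    exec("\n".join(lines), namespace)  # single exec, at build time
    return namespace["grid_fn"]
```

Unlike the dict approach, this keeps one branch per config, so configs that overlap in kwargs but differ elsewhere impose no extra restrictions, just more branches.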
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #112292 Approved by: https://github.com/jansel ghstack dependencies: #112290
…s with complex grids (pytorch#112290) Pull Request resolved: pytorch#112290 Approved by: https://github.com/jansel
…112292) Pull Request resolved: pytorch#112292 Approved by: https://github.com/jansel ghstack dependencies: pytorch#112290
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler