
[Inductor] Add triton.autotune support for user defined triton kernels with complex grids #112290

Closed
wants to merge 3 commits

Conversation

@pytorch-bot

pytorch-bot bot commented Oct 27, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112290

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 07165d9 with merge base f5088d2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

oulgen added a commit that referenced this pull request Oct 27, 2023
…s with complex grids

ghstack-source-id: 8849bba31b2748464dc491fa72e98d12f8319c01
Pull Request resolved: #112290
@oulgen added the ciflow/trunk (Trigger trunk jobs on your pull request) and release notes: inductor labels Oct 27, 2023
…iton kernels with complex grids"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
Comment on lines 87 to 96
def grid_fn(meta):
    configs = kernel.configs
    assert len(grid) == len(configs)
    for i, (grid_val, config) in enumerate(zip(grid, configs)):
        guards = [
            f"meta['{name}'] == {val}" for name, val in config.kwargs.items()
        ]
        guards = " and ".join(guards)
        if eval(guards):
            return grid[i]
Contributor

Configs aren't actually guaranteed to be unique, since they can differ in num_stages/num_warps (but not kwargs). This might be fine though, since I think the grids will be the same for duplicates.

Calling eval() is very expensive in Python, and this function is O(num_configs).

This could be implemented more efficiently using a dict.
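
For illustration, a minimal sketch of the dict-based lookup being suggested (make_grid_fn and its signature are hypothetical, not from the PR), assuming all configs define the same kwarg names:

# Hypothetical sketch: preprocess the configs into a dict once, so each call
# is an O(1) lookup instead of eval()-ing a guard string per config.
def make_grid_fn(grid, configs):
    keys = sorted(configs[0].kwargs)  # assumes every config uses the same kwarg names
    precomputed = {
        tuple(config.kwargs[k] for k in keys): grid_val
        for grid_val, config in zip(grid, configs)
    }

    def grid_fn(meta):
        return precomputed[tuple(meta[k] for k in keys)]

    return grid_fn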

Contributor Author

This might be fine though, since I think the grids will be the same for duplicates

Grid computation doesn't use num_warps or num_stages, so duplicates are fine; they just add more branches.

Comment on lines 505 to 513
grid_wrapper.writeline(f"def grid_wrapper_for_{kernel_name}(meta):")
assert len(grid) == len(configs)
with grid_wrapper.indent():
    for i, (grid_val, conf) in enumerate(zip(grid, configs)):
        guards = [
            f"meta['{name}'] == {val}" for name, val in conf.kwargs.items()
        ]
        guards = " and ".join(guards)
        grid_wrapper.writeline(f"if {guards}: return grid({grid[i]})(meta)")
Contributor

Could we share code with the implementation above? I think we can preprocess this into a dict, so the body is more like:

def grid(meta):
    return precomputed_grids[(meta["XBLOCK"], meta["YBLOCK"])]

Contributor Author

oulgen commented Oct 28, 2023

One edge case I see here is

@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_X": 64}, num_stages=3, num_warps=8
        ),
    ],
    key=[],
)

where configs are not of equal size in terms of their BLOCK_SIZE dimensions. Is this something I should disallow?

Contributor

I think we can just disallow this.
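
A hypothetical sketch of such a check (kernel.configs as in the diff above; the exact error type and message are assumptions):

# Hypothetical: reject configs whose kwargs don't share the same set of
# names, since a lookup keyed on those names would otherwise be ambiguous.
kwarg_names = {frozenset(config.kwargs) for config in kernel.configs}
if len(kwarg_names) > 1:
    raise ValueError("triton.autotune configs must define the same kwarg names")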

Comment on lines 1195 to 1197
# when passed as a single tuple, calling convention order is not
# followed, so we need to reverse to match calling conversion order
numels = numels[0][::-1]
Contributor

I don't understand why this change is needed. I thought we weren't using this function for triton kernels?

Contributor Author

cached_autotune uses the launcher function, which requires everything to be a 3-tuple, since it unpacks:

if callable(grid):
    grid_0, grid_1, grid_2 = grid(grid_meta)
else:
    grid_0, grid_1, grid_2 = grid

I was using this function to convert a 1-tuple or 2-tuple into a 3-tuple. I can either write my own function to do this, or is there a more pythonic way to do this conversion?
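
For reference, one pythonic way to pad a 1- or 2-tuple to the 3-tuple the launcher unpacks (pad_grid is a hypothetical helper, not existing API):

def pad_grid(grid):
    # Pad with 1s up to the (grid_0, grid_1, grid_2) the launcher expects.
    assert 1 <= len(grid) <= 3
    return tuple(grid) + (1,) * (3 - len(grid))

pad_grid((128,))   # -> (128, 1, 1)
pad_grid((64, 4))  # -> (64, 4, 1)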

Contributor

Since we are generating the grid function, can't we just make it return a 3-tuple?

…iton kernels with complex grids"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@oulgen requested a review from jansel October 29, 2023 06:25
@oulgen
Contributor Author

oulgen commented Oct 29, 2023

Explained offline to @jansel that using a dict imposes unnecessary restrictions. Updated the code to only have a single exec in eager mode, and to share code between eager and inductor.
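
A hypothetical sketch of that single-exec approach (grid and configs as in the diffs above): build the whole guard chain as source once, exec it once, and then every launch just calls the compiled function:

# Hypothetical: emit one if-branch per config, compile with a single exec.
lines = ["def grid_fn(meta):"]
for grid_val, config in zip(grid, configs):
    guard = " and ".join(
        f"meta[{name!r}] == {val!r}" for name, val in config.kwargs.items()
    ) or "True"  # a config with no kwargs matches unconditionally
    lines.append(f"    if {guard}: return {grid_val!r}")
namespace = {}
exec("\n".join(lines), namespace)
grid_fn = namespace["grid_fn"]  # no further eval/exec at launch time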

@oulgen
Contributor Author

oulgen commented Oct 30, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 30, 2023
@facebook-github-bot deleted the gh/oulgen/18/head branch November 3, 2023 14:27
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Nov 19, 2023