Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] make thread order consistent with loop order #106827

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,11 +820,14 @@ def set_last_usage(self, nodes):
)

def initialize_range_tree(self, pid_cache):
names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
names = list(
reversed(["xindex", "yindex", "zindex"][: len(self.numels) - 1])
) + ["rindex"]
for i in range(len(self.numels)):
pid_idx = i if names[i][0] == "r" else "xyz".find(names[i][0])
self.range_trees.append(
IterationRangesRoot(
names[i], self.numels[i], names[i][0], i, self, pid_cache
names[i], self.numels[i], names[i][0], pid_idx, self, pid_cache
)
)
for tree in self.range_trees:
Expand Down
11 changes: 10 additions & 1 deletion torch/_inductor/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,9 +976,18 @@ def foreach(meta, num_warps, filename=None):
)


def grid(xnumel, ynumel=None, znumel=None):
def grid(*numels):
"""Helper function to compute triton grids"""

if len(numels) == 1:
xnumel, ynumel, znumel = numels[0], None, None
elif len(numels) == 2:
xnumel, ynumel, znumel = numels[1], numels[0], None
elif len(numels) == 3:
xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
else:
raise AssertionError(f"invalid size for numels {len(numels)}")

def get_grid_dim(numel, block):
if numel is None:
return 1
Expand Down