Skip to content

Commit

Permalink
[inductor] make thread order consistent with loop order
Browse files Browse the repository at this point in the history
ghstack-source-id: 92cbccde98c5e74b363fef0cdd9eec2bdaa1a78a
Pull Request resolved: #106827
  • Loading branch information
shunting314 committed Aug 8, 2023
1 parent f8817d8 commit b6fd42d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
7 changes: 5 additions & 2 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,11 +820,14 @@ def set_last_usage(self, nodes):
)

def initialize_range_tree(self, pid_cache):
names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
names = list(
reversed(["xindex", "yindex", "zindex"][: len(self.numels) - 1])
) + ["rindex"]
for i in range(len(self.numels)):
pid_idx = i if names[i][0] == "r" else "xyz".find(names[i][0])
self.range_trees.append(
IterationRangesRoot(
names[i], self.numels[i], names[i][0], i, self, pid_cache
names[i], self.numels[i], names[i][0], pid_idx, self, pid_cache
)
)
for tree in self.range_trees:
Expand Down
11 changes: 10 additions & 1 deletion torch/_inductor/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,9 +976,18 @@ def foreach(meta, num_warps, filename=None):
)


def grid(xnumel, ynumel=None, znumel=None):
def grid(*numels):
"""Helper function to compute triton grids"""

if len(numels) == 1:
xnumel, ynumel, znumel = numels[0], None, None
elif len(numels) == 2:
xnumel, ynumel, znumel = numels[1], numels[0], None
elif len(numels) == 3:
xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
else:
raise AssertionError(f"invalid size for numels {len(numels)}")

def get_grid_dim(numel, block):
if numel is None:
return 1
Expand Down

0 comments on commit b6fd42d

Please sign in to comment.