Skip to content

Commit b6fd42d

Browse files
committed
[inductor] make thread order consistent with loop order
ghstack-source-id: 92cbccd
Pull Request resolved: #106827
1 parent f8817d8 commit b6fd42d

File tree

2 files changed: 15 additions and 3 deletions.

2 files changed: 15 additions and 3 deletions.

torch/_inductor/codegen/triton.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -820,11 +820,14 @@ def set_last_usage(self, nodes):
820820
)
821821

822822
def initialize_range_tree(self, pid_cache):
823-
names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
823+
names = list(
824+
reversed(["xindex", "yindex", "zindex"][: len(self.numels) - 1])
825+
) + ["rindex"]
824826
for i in range(len(self.numels)):
827+
pid_idx = i if names[i][0] == "r" else "xyz".find(names[i][0])
825828
self.range_trees.append(
826829
IterationRangesRoot(
827-
names[i], self.numels[i], names[i][0], i, self, pid_cache
830+
names[i], self.numels[i], names[i][0], pid_idx, self, pid_cache
828831
)
829832
)
830833
for tree in self.range_trees:

torch/_inductor/triton_heuristics.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -976,9 +976,18 @@ def foreach(meta, num_warps, filename=None):
976976
)
977977

978978

979-
def grid(xnumel, ynumel=None, znumel=None):
979+
def grid(*numels):
980980
"""Helper function to compute triton grids"""
981981

982+
if len(numels) == 1:
983+
xnumel, ynumel, znumel = numels[0], None, None
984+
elif len(numels) == 2:
985+
xnumel, ynumel, znumel = numels[1], numels[0], None
986+
elif len(numels) == 3:
987+
xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
988+
else:
989+
raise AssertionError(f"invalid size for numels {len(numels)}")
990+
982991
def get_grid_dim(numel, block):
983992
if numel is None:
984993
return 1

0 commit comments

Comments
 (0)