Skip to content

Commit

Permalink
[inductor] make thread order consistent with loop order (#106827)
Browse files Browse the repository at this point in the history
I found that for a tiled kernel over a tensor with shape [a, b], we map 'a' to XBLOCK and 'b' to YBLOCK. However, 'a' should actually be the outer loop while 'b' corresponds to the inner loop. This order is picked by our loop ordering algorithm. Mapping 'a' to XBLOCK has the semantics of assigning 'a' to the inner loop instead.

For a simple 'A + B.t()' kernel, making the loop order consistent brings a 1.027x speedup (1.938ms -> 1.887ms). Here are the dumps of the kernels:

- before fix: https://gist.github.com/shunting314/4dacf73cf495cdd7e84dede7c3e0872d
- after fix (this one is done manually): https://gist.github.com/shunting314/441e8839d24e1878c313e539b1ebd551

I tried this on DistillGPT2 and found perf is neutral. But that's because DistillGPT2 has only a single tiled pointwise kernel in its backward graph. Will check the dashboard.

Pull Request resolved: #106827
Approved by: https://github.com/jansel
  • Loading branch information
shunting314 authored and pytorchmergebot committed Aug 11, 2023
1 parent 745d29b commit 6696a75
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
14 changes: 12 additions & 2 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,11 +820,14 @@ def set_last_usage(self, nodes):
)

def initialize_range_tree(self, pid_cache):
names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
names = list(
reversed(["xindex", "yindex", "zindex"][: len(self.numels) - 1])
) + ["rindex"]
for i in range(len(self.numels)):
pid_idx = i if names[i][0] == "r" else "xyz".find(names[i][0])
self.range_trees.append(
IterationRangesRoot(
names[i], self.numels[i], names[i][0], i, self, pid_cache
names[i], self.numels[i], names[i][0], pid_idx, self, pid_cache
)
)
for tree in self.range_trees:
Expand Down Expand Up @@ -1895,6 +1898,13 @@ def dense_size_str(self):
sizes.append(f"{tree.prefix.upper()}BLOCK")
elif tree.prefix == "r" and tree.numel != 1:
sizes.append("1")

if sizes[0:3] == ["ZBLOCK", "YBLOCK", "XBLOCK"]:
sizes[0:3] = reversed(sizes[0:3])

if sizes[0:2] == ["YBLOCK", "XBLOCK"]:
sizes[0:2] = reversed(sizes[0:2])

return f"[{', '.join(sizes)}]"

def call_kernel(self, name: str):
Expand Down
8 changes: 4 additions & 4 deletions torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,15 +801,15 @@ def index_cmp(a, b):

# equivalent to
# np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all()
a_first = all(
a_first = sum(
sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b)
)
b_first = all(
b_first = sum(
sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b)
)
if a_first and not b_first:
if a_first > b_first:
return -1
if b_first and not a_first:
if b_first > a_first:
return 1

# otherwise contiguous
Expand Down
15 changes: 14 additions & 1 deletion torch/_inductor/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,10 @@ def triton_config(
override the num_elements_per_warp.
"""
# Ideally we want to read this from some device config

# for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK
size_hints = list(reversed(size_hints))

maxGridSize = [2147483647, 65535, 65535]

target = conditional_product(x, y, z)
Expand Down Expand Up @@ -976,9 +980,18 @@ def foreach(meta, num_warps, filename=None):
)


def grid(xnumel, ynumel=None, znumel=None):
def grid(*numels):
"""Helper function to compute triton grids"""

if len(numels) == 1:
xnumel, ynumel, znumel = numels[0], None, None
elif len(numels) == 2:
xnumel, ynumel, znumel = numels[1], numels[0], None
elif len(numels) == 3:
xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
else:
raise AssertionError(f"invalid size for numels {len(numels)}")

def get_grid_dim(numel, block):
if numel is None:
return 1
Expand Down

0 comments on commit 6696a75

Please sign in to comment.