[inductor] Persistent reductions
ghstack-source-id: b49f6e06caad33e4dc50b82bda477c42212a0df2
Pull Request resolved: #92267
jansel committed Feb 8, 2023
1 parent 83275d8 commit 6a53a83
Showing 5 changed files with 120 additions and 21 deletions.
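
For readers skimming the diff, the core idea: a "persistent" reduction picks RBLOCK large enough to cover the whole reduction extent (rnumel), so each program instance keeps its rows in registers and no accumulator loop over roffset is needed. Below is a minimal sketch of the two styles for a row-wise sum; it is illustrative only (the kernel names, row-major indexing, and launch parameters are assumptions for the example, not the exact code Inductor generates).

import triton
import triton.language as tl


@triton.jit
def sum_rows_loop(in_ptr, out_ptr, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    # loop-style reduction: walk the reduction dim RBLOCK elements at a time
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    acc = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + tl.arange(0, RBLOCK)[None, :]
        rmask = rindex < rnumel
        acc += tl.load(in_ptr + xindex * rnumel + rindex, xmask & rmask, other=0.0)
    tl.store(out_ptr + xindex, tl.sum(acc, 1)[:, None], xmask)


@triton.jit
def sum_rows_persistent(in_ptr, out_ptr, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    # persistent reduction: RBLOCK >= rnumel, so each row is loaded once and
    # reduced directly; there is no roffset loop and no accumulator
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    val = tl.load(in_ptr + xindex * rnumel + rindex, xmask & rmask, other=0.0)
    tl.store(out_ptr + xindex, tl.sum(val, 1)[:, None], xmask)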
1 change: 1 addition & 0 deletions .gitignore
@@ -149,6 +149,7 @@ torchgen/packaged/*
*.swo
*.swp
*~
.~lock.*

# macOS dir files
.DS_Store
24 changes: 22 additions & 2 deletions test/inductor/test_torchinductor.py
@@ -965,6 +965,13 @@ def fn(a):
for i in inputs:
self.common(fn, (i,))

@config.patch(unroll_reductions_threshold=1)
def test_reduction5(self):
def fn(a):
return (a.sum(), a.max(), a.min(), a.argmax())

self.common(fn, (torch.full((4,), float("-inf")),))

def test_unroll_small_reduction(self):
def fn(x):
val1, index1 = x.min(-1)
@@ -2865,18 +2872,31 @@ def fn(a, b):
if self.device != "cpu":
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

def test_softmax_one_kernel(self):
@patch.object(config.triton, "persistent_reductions", True)
def test_softmax_one_kernel_persist(self):
def fn(x):
dim = 1
x_max = torch.amax(x, dim, keepdim=True)
unnormalized = torch.exp(x * x_max)
unnormalized = torch.exp(x - x_max)
result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
return result

self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
if self.device != "cpu":
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

@patch.object(config.triton, "persistent_reductions", False)
def test_softmax_one_kernel_loop(self):
def fn(x):
x_max = torch.amax(x, 1, keepdim=True)
unnormalized = torch.exp(x - x_max)
result = unnormalized / torch.sum(unnormalized, 1, keepdim=True)
return result

self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
if self.device != "cpu":
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

def test_cauchy(self):
def fn(x, y):
return torch.sum(1 / (torch.unsqueeze(x, -1) - y))
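
As a quick illustration of what the renamed softmax tests assert, namely that the whole softmax (amax, exp, sum, div) fuses into a single generated kernel, here is a standalone sketch; it assumes a CUDA build where the inductor backend and its kernel-count metric are available, and the function name is just for the example.

import torch
from torch._inductor import metrics


def softmax(x):
    x_max = torch.amax(x, 1, keepdim=True)
    unnormalized = torch.exp(x - x_max)
    return unnormalized / torch.sum(unnormalized, 1, keepdim=True)


metrics.reset()
torch.compile(softmax)(torch.randn(16, 32, device="cuda"))
assert metrics.generated_kernel_count == 1  # amax, exp, sum, div all fused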
85 changes: 66 additions & 19 deletions torch/_inductor/codegen/triton.py
@@ -352,6 +352,8 @@ def __init__(
var_ranges: Dict[sympy.Symbol, sympy.Expr],
numel: sympy.Expr,
prefix: str,
*,
kernel: "Kernel",
divisor=sympy.Integer(1),
length=sympy.Integer(1),
):
@@ -363,9 +365,10 @@ def __init__(
self.prefix = prefix
self.divisor = divisor
self.length = length
self.kernel = kernel

def is_loop(self):
return self.prefix == "r"
return self.prefix == "r" and not self.kernel.persistent_reduction


class IterationRangesRoot(IterationRanges):
@@ -386,9 +389,9 @@ def __init__(
var_ranges={},
numel=numel,
prefix=prefix,
kernel=kernel,
)
self.index = index
self.kernel = kernel
# Store all the nodes in one flat list
self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {}
# This is for re-ordering program ID in triton mm template
@@ -475,6 +478,11 @@ def codegen_header(self, code):
x = self.prefix
if self.is_loop():
code.writeline(f"{self.name} = {x}offset + {x}base")
elif x == "r" and self.kernel.persistent_reduction:
# no need to "roffset = "
code.writeline(
f"{self.name} = {self.ranges_code()}",
)
else:
pid = self.pid_cache_lookup(f"tl.program_id({self.index})")
code.writelines(
@@ -503,6 +511,7 @@ def __init__(
prefix=parent.prefix,
divisor=divisor,
length=length,
kernel=parent.kernel,
)
self.parent = parent
self.codegen = functools.lru_cache(None)(self._codegen)
@@ -575,8 +584,9 @@ def __init__(
self.indexing_code = IndentedBuffer()
self.suffix = IndentedBuffer()
self.outside_loop_vars = set()
self.initialize_range_tree(pid_cache)
self.reduction_hint = reduction_hint
self.persistent_reduction = self.should_use_persistent_reduction()
self.initialize_range_tree(pid_cache)

# define this in a closure to make cache local to object
@functools.lru_cache(None)
@@ -588,6 +598,26 @@ def simplify_indexing(index: sympy.Expr):

self.simplify_indexing = simplify_indexing

def should_use_persistent_reduction(self):
"""
Heuristic to set self.persistent_reduction and add guards
if needed.
"""
if not (self.inside_reduction and config.triton.persistent_reductions):
return False
threshold = {
ReductionHint.INNER: 1024,
}.get(self.reduction_hint, 64)
hint = V.graph.sizevars.size_hint(self.numels[-1])
if hint > threshold:
return False

from triton import next_power_of_2

# will need to recompile if we cross a larger power of 2 boundary
V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint))
return True

def initialize_range_tree(self, pid_cache):
names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
for i in range(len(self.numels)):
@@ -598,7 +628,7 @@ def initialize_range_tree(self, pid_cache):
)
for tree in self.range_trees:
# reduction indexing goes inside a loop
if tree.prefix != "r":
if not tree.is_loop():
tree.codegen_header(self.body)
if self.inside_reduction and self.range_trees[-1].is_loop():
# workaround for this issue:
@@ -612,13 +642,15 @@ def ctx():
assert not self.inside_reduction
yield
return
# calling codegen_body() will flush all the pending buffers
# and write out a reduction loop
self.codegen_body()
if not self.persistent_reduction:
# calling codegen_body() will flush all the pending buffers
# and write out a reduction loop
self.codegen_body()
self.inside_reduction = False
yield
# flush out any code before opening the next loop
self.codegen_body()
if not self.persistent_reduction:
# flush out any code before opening the next loop
self.codegen_body()
self.inside_reduction = True

return ctx()
@@ -875,7 +907,7 @@ def load(self, name: str, index: sympy.Expr):
original_index = index
index, mask_vars, mask = self.indexing(index)

if "rmask" in mask:
if "rmask" in mask and not self.persistent_reduction:
# This eviction policy heuristic is untested.
# ptillet suggested we should try only doing this for
# the first N-1 loops and not for the final loop.
@@ -906,6 +938,7 @@ def load(self, name: str, index: sympy.Expr):

if (
self.inside_reduction
and not self.persistent_reduction
and "rmask" not in mask
and "tmp" not in mask
and not indirect_indexing
@@ -960,7 +993,16 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
dim = len(self.range_trees) - 1
result_var = self.cse.newvar()
result_var.mask_vars = set(var for var in masks if var[0] != "r")
if (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
if self.persistent_reduction:
cond = " & ".join(masks)
masked_value = self.cse.generate(
self.compute, f"tl.where({cond}, {value}, {default})"
)
result_var = self.cse.generate(
self.compute,
f"tl.{reduction_type}({masked_value}, {dim})[{', '.join(sizes)}]",
)
elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
accumulator = f"_{result_var}"
default_value = f" + {default}" if default != 0 else ""
@@ -1046,7 +1088,7 @@ def codegen_body(self):
):
return

if self.inside_reduction:
if self.inside_reduction and not self.persistent_reduction:
self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
with self.body.indent():
# last range tree is always reduction
@@ -1078,11 +1120,14 @@ def codegen_kernel(self, name=None):
size_hints = [
next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels
]
if not self.inside_reduction:
if self.persistent_reduction:
assert self.inside_reduction
heuristics = "persistent_reduction"
elif self.inside_reduction:
heuristics = "reduction"
else:
size_hints.pop()
heuristics = "pointwise"
else:
heuristics = "reduction"

if name is None:
code.splice(
@@ -1145,10 +1190,12 @@ def codegen_kernel(self, name=None):
if self.inside_reduction:
reduction_hint = self.reduction_hint
heuristics_line = f"""
@{heuristics}(size_hints={size_hints!r},
reduction_hint={reduction_hint},
filename=__file__,
meta={triton_meta!r})
@{heuristics}(
size_hints={size_hints!r},
reduction_hint={reduction_hint},
filename=__file__,
meta={triton_meta!r}
)
@triton.jit
"""
else:
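
The new should_use_persistent_reduction heuristic above boils down to a size check plus a power-of-two guard. A hypothetical walk-through follows; the helper name is made up for illustration, while the thresholds and the guard behavior are the ones in the diff.

from triton import next_power_of_2


def would_persist(rnumel_hint, is_inner_reduction):
    # mirrors the heuristic: inner reductions get a 1024-element cap, everything
    # else 64; a persistent kernel is specialized to the next power of two, so a
    # guard is installed and crossing that boundary forces a recompile
    threshold = 1024 if is_inner_reduction else 64
    if rnumel_hint > threshold:
        return None  # fall back to the loop-style reduction
    return next_power_of_2(rnumel_hint)


print(would_persist(32, False))    # 32   -> persistent, guarded at rnumel <= 32
print(would_persist(768, True))    # 1024 -> persistent (inner hint raises the cap)
print(would_persist(2048, True))   # None -> loop reduction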
3 changes: 3 additions & 0 deletions torch/_inductor/config.py
@@ -175,6 +175,9 @@ class triton:
# should we put op names in kernel names
descriptive_kernel_names = False

# use alternate codegen for smaller reductions
persistent_reductions = True


# create a directory containing lots of debug information
class trace:
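
The new flag defaults to on. A sketch of two ways to turn it off locally: the decorator form matches the new tests above, and the scoped form assumes the inductor config module's patch helper accepts a dict keyed by dotted option names.

from unittest.mock import patch

import torch._inductor.config as config


@patch.object(config.triton, "persistent_reductions", False)
def run_with_loop_reductions():
    ...  # anything compiled in here uses the loop-style reduction codegen


with config.patch({"triton.persistent_reductions": False}):
    ...  # same effect, scoped to the block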
28 changes: 28 additions & 0 deletions torch/_inductor/triton_ops/autotune.py
@@ -479,6 +479,34 @@ def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
raise NotImplementedError(f"size_hints: {size_hints}")


def persistent_reduction(size_hints, reduction_hint=False, meta=None, filename=None):
xnumel, rnumel = size_hints

configs = [
triton_config_reduction(size_hints, xblock, rnumel)
for xblock in (1, 8, 32, 128)
if rnumel * xblock <= 4096 and xblock <= xnumel
]

# TODO(jansel): we should be able to improve these heuristics
if reduction_hint == ReductionHint.INNER and rnumel >= 256:
configs = configs[:1]
elif reduction_hint == ReductionHint.OUTER:
configs = configs[-1:]
elif reduction_hint == ReductionHint.OUTER_TINY:
configs = [
triton_config_reduction(
size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
)
]

return cached_autotune(
configs,
meta=meta,
filename=filename,
)


def template(num_stages, num_warps, meta, filename=None):
"""
Compile a triton template
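
To make the new persistent_reduction autotuning heuristic concrete, here is a worked example (pure illustration, not code from the PR) of which tile sizes survive its filter for size_hints = [64, 64] with no reduction hint.

# xblock candidates are (1, 8, 32, 128); keep those with rnumel * xblock <= 4096
# and xblock <= xnumel, so for xnumel = 64, rnumel = 64: 128 is dropped
# (64 * 128 = 8192) and the survivors are 1, 8, 32 -- three configs, each with
# RBLOCK = 64. ReductionHint.INNER with rnumel >= 256 would keep only the first
# (smallest XBLOCK) config; ReductionHint.OUTER keeps only the last.
candidates = [
    (xblock, 64)
    for xblock in (1, 8, 32, 128)
    if 64 * xblock <= 4096 and xblock <= 64
]
print(candidates)  # [(1, 64), (8, 64), (32, 64)]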
