WIP: [inductor] Enable multilayer reductions with dynamic shapes
Currently multilayer reductions (aka split reductions) are only used with static
shapes, which results in worse performance and accuracy when dynamic shapes are
enabled. Instead, this change only requires that the shape has a hint value.

ghstack-source-id: 3a689e9a6ac41cda4a517d94efe012c091716cd7
Pull Request resolved: pytorch#106747
peterbell10 committed Aug 8, 2023
1 parent 3aa78ac commit e015da7
Showing 3 changed files with 25 additions and 20 deletions.
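For context on the terminology in the commit message: a split (multilayer) reduction breaks one large reduction into two smaller layers, where many blocks each reduce a chunk of the input and a second pass reduces the per-chunk partial results. A rough standalone sketch in plain PyTorch (illustrative only, not the code inductor generates; the split factor and zero padding are arbitrary choices):

```python
# Illustrative sketch of a two-layer ("split") sum reduction.
# Not inductor's generated code; split=1024 and the zero padding are arbitrary.
import torch

def split_sum(a: torch.Tensor, split: int = 1024) -> torch.Tensor:
    numel = a.numel()
    block_size = -(-numel // split)                    # ceil-division: one chunk per block
    padded = torch.zeros(split * block_size, dtype=a.dtype)
    padded[:numel] = a                                 # tail lanes hold the identity (0 for sum)
    partials = padded.view(split, block_size).sum(dim=1)  # layer 1: per-chunk partial sums
    return partials.sum()                              # layer 2: reduce the partial results

x = torch.randn(3999971, dtype=torch.float64)          # prime size, as in the new test below
print(split_sum(x).item(), x.sum().item())             # the two values should match closely
```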
10 changes: 10 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -932,6 +932,16 @@ def fn(a):

         self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))

+    def test_multilayer_prime_size(self):
+        def fn(a):
+            return torch.max(a), torch.sum(a)
+
+        # Requires masked loading for the intermediate reduction
+        sample = torch.full((3999971,), torch.iinfo(torch.int64).min, dtype=torch.int64)
+        sample[-1] = 0
+        self.common(fn, (sample,))
+
     def test_expanded_reduction(self):
         if self.device == "cpu":
             raise unittest.SkipTest(
27 changes: 9 additions & 18 deletions torch/_inductor/ir.py
@@ -595,6 +595,9 @@ def num_splits(
         def _is_static(x):
             return isinstance(x, (int, sympy.Integer))

+        reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel)
+        numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))
+
         should_split = (
             is_triton(device)
             and reduction_type
@@ -604,9 +607,9 @@ def _is_static(x):
                 "var_unnormalized",
             }
             and config.split_reductions
-            and all(_is_static(r) for r in ranges)
-            and all(_is_static(r) for r in reduction_ranges)
-            and _is_static(reduction_numel)
+            # We don't support unbacked symints
+            and _is_static(reduction_numel_hint)
+            and _is_static(numel_hint)
         )
         if not should_split:
            return ReductionHint.DEFAULT, 1
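To see why the new guard fires for dynamic shapes while the old one never could: with dynamic shapes the ranges contain SymPy symbols, so `_is_static(...)` on the raw expression is always False, whereas substituting the known hint values yields a concrete integer. A minimal standalone sympy sketch (the symbol name and hint value are made up):

```python
# Minimal sketch of the hint-based check; plain sympy, outside inductor.
import sympy

def _is_static(x):
    return isinstance(x, (int, sympy.Integer))

s0 = sympy.Symbol("s0", integer=True, positive=True)  # dynamic dim; assume its hint is 4096
hints = {s0: 4096}                                     # stand-in for sizevars.var_to_val

reduction_numel = s0 * 352
print(_is_static(reduction_numel))              # False: the old guard disabled the split
print(_is_static(reduction_numel.subs(hints)))  # True: the hinted value is a concrete Integer
```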
@@ -689,8 +692,6 @@ def outer_reduction_splits(reduction_numel_hint, numel_hint):
                 rvals_per_thread * split_size
             )

-        reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel)
-        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
         # easy cases
         if numel_hint == 1:
             return ReductionHint.INNER, inner_reduction_splits(
@@ -994,24 +995,14 @@ def create_multilayer(
         """
         reduction_numel = sympy_product(reduction_ranges)

-        # TODO(jansel): convert this to dynamic shapes
-        # TODO(jansel): realize the reduction so we can do dynamic indexing
-        reduction_ranges = [
-            sympy.Integer(V.graph.sizevars.evaluate_static_shape(s))
-            for s in reduction_ranges
-        ]
-        reduction_numel = sympy.Integer(
-            V.graph.sizevars.evaluate_static_shape(reduction_numel)
+        need_mask = not V.graph.sizevars.is_expr_static_and_true(
+            sympy.Eq(reduction_numel % split, 0)
         )

-        if V.graph.sizevars.size_hint(reduction_numel) % split == 0:
-            need_mask = False
-        else:
-            need_mask = True
-
         split = sympy.Integer(split)
         block_size = FloorDiv(reduction_numel + (split - 1), split)

+        # TODO(jansel): realize the reduction so we can do dynamic indexing
         reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])

         def wrapper_fn(index, reduction_index):
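A quick numeric check of the masking condition above, using the sizes from the new test (plain Python/sympy, not inductor code; 1024 is an arbitrary example split): a prime-sized reduction never divides evenly, so the padded tail of the last block must be masked.

```python
# Standalone check of need_mask / block_size for a prime reduction size.
import sympy

reduction_numel, split = 3999971, 1024
need_mask = not sympy.Eq(reduction_numel % split, 0)   # True: 3999971 % 1024 == 227
block_size = -(-reduction_numel // split)              # same value as FloorDiv(numel + split - 1, split)
print(need_mask, block_size, split * block_size - reduction_numel)  # True 3907 797 padded lanes
```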
8 changes: 6 additions & 2 deletions torch/_inductor/sizevars.py
@@ -358,7 +358,8 @@ def evaluate_static_shape(self, left: Expr) -> int:
     def evaluate_static_shapes(self, left: List[Expr]) -> List[int]:
         return [self.evaluate_static_shape(x) for x in left]

-    def size_hint(self, expr: Expr) -> int:
+    def symbolic_hint(self, expr: Expr) -> Expr:
+        # Substitute all hints into expr, but leave unbacked symints alone
         if not isinstance(expr, Expr):
             assert isinstance(expr, int)
             return expr
@@ -368,7 +369,10 @@ def size_hint(self, expr: Expr) -> int:
         while any(s.name.startswith("ps") for s in free_symbols):
             expr = sympy_subs(expr, self.inv_precomputed_replacements)
             free_symbols = expr.free_symbols
-        out = sympy_subs(expr, self.var_to_val)
+        return sympy_subs(expr, self.var_to_val)
+
+    def size_hint(self, expr: Expr) -> int:
+        out = self.symbolic_hint(expr)
         try:
             return int(out)
         except Exception:
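The refactor above splits the old size_hint into two pieces: symbolic_hint substitutes every known hint but returns an expression (so unbacked symints survive), and size_hint then converts to int, falling back to its except branch when something is still symbolic. A rough standalone sketch with plain sympy (names and values are illustrative, not the real SizeVarAllocator):

```python
# Illustrative sketch of symbolic_hint vs. size_hint behavior.
import sympy

s0, u0 = sympy.Symbol("s0"), sympy.Symbol("u0")
var_to_val = {s0: 64}                     # s0 is a backed symint with a hint; u0 is unbacked

def symbolic_hint(expr):
    return expr.subs(var_to_val)          # unbacked symbols pass through untouched

def size_hint(expr):
    return int(symbolic_hint(expr))       # raises for anything still symbolic

print(symbolic_hint(s0 * 2))              # 128: fully hinted
print(symbolic_hint(u0 + 1))              # u0 + 1: stays symbolic
print(size_hint(s0 * 2))                  # 128
# size_hint(u0 + 1) raises TypeError, which the real size_hint catches in its except branch
```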