[inductor] Enable multilayer reductions with dynamic shapes #106747

Closed
wants to merge 6 commits
9 changes: 9 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -930,6 +930,15 @@ def fn(a):
 
         self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))
 
+    def test_multilayer_prime_size(self):
+        def fn(a):
+            return torch.max(a), torch.sum(a)
+
+        # Requires masked loading for the intermediate reduction
+        sample = torch.full((3999971,), 0, dtype=torch.int64)
+        sample[-1] = 1
+        self.common(fn, (sample,))
+
     def test_expanded_reduction(self):
         if self.device == "cpu":
             raise unittest.SkipTest(
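A quick sanity sketch of why this input exercises masked loading (the split of 1024 below is an arbitrary example, not necessarily the split inductor picks; per the test name, the size is prime, so no split greater than one divides it evenly): the last intermediate block is ragged and has to be loaded with a mask, and the eager results the test compares against are max = 1 and sum = 1.

import torch

numel, split = 3999971, 1024         # split is an arbitrary example value
assert numel % split != 0            # ragged tail -> masked loads in the intermediate stage
print((numel + split - 1) // split)  # 3907 intermediate blocks, the last one only partially full

# Eager reference for the test's fn: all zeros except a trailing 1.
a = torch.full((numel,), 0, dtype=torch.int64)
a[-1] = 1
print(torch.max(a).item(), torch.sum(a).item())  # 1 1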
27 changes: 9 additions & 18 deletions torch/_inductor/ir.py
@@ -595,6 +595,9 @@ def num_splits(
         def _is_static(x):
             return isinstance(x, (int, sympy.Integer))
 
+        reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel)
+        numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))
+
Comment on lines +598 to +600
Contributor

Maybe as a follow up, add guards based on the heuristics here:

    def inner_reduction_splits(reduction_numel_hint, numel_hint):

         should_split = (
             is_triton(device)
             and reduction_type
@@ -604,9 +607,9 @@ def _is_static(x):
                 "var_unnormalized",
             }
             and config.split_reductions
-            and all(_is_static(r) for r in ranges)
-            and all(_is_static(r) for r in reduction_ranges)
-            and _is_static(reduction_numel)
+            # We don't support unbacked symints
+            and _is_static(reduction_numel_hint)
+            and _is_static(numel_hint)
Comment on lines +611 to +612
Collaborator

Can it be the case that numel is static but reduction_numel is not?

Collaborator

In fact, we probably just want _is_static(reduction_numel_hint) for the general case. Then, if we also have numel_hint static even better, but it's not 100% necessary for the main optimisation, I don't think.

Collaborator Author

> Can it be the case that numel is static but reduction_numel is not?

numel here is actually the number of output elements, so it doesn't include the reduced dimensions.

> In fact, we probably just want _is_static(reduction_numel_hint) for the general case. Then, if we also have numel_hint static even better, but it's not 100% necessary for the main optimisation, I don't think.

numel_hint is used in conditionals like numel_hint >= num_sm * 2 * 32 when deciding on the number of splits, so we need to have a concrete value.
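A toy illustration of that point (the SM count and the hint below are example numbers, not values taken from inductor): the comparison can only be evaluated once numel_hint is a concrete integer, which is why an unbacked numel cannot take this path.

# Example numbers for illustration only.
num_sm = 108                  # e.g. the SM count of an A100
numel_hint = 4096             # must be a plain int here, not a symbolic expression
enough_outer_parallelism = numel_hint >= num_sm * 2 * 32  # 4096 >= 6912 -> False
print(enough_outer_parallelism)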

Collaborator Author

BTW, as a further step we could use bound_sympy to deal with unbacked SymInts, but that's more than I need at this point to get cumsum working.

Collaborator

bound_sympy does not currently deal with unbacked SymInts, but may be able to do something in some cases when #106568 is merged.

Collaborator Author

Say I had an expression s0*100, would bound_sympy not give a lower bound of 100, since shape variables are positive?

Collaborator

That currently happens if the symbol is marked as non-negative. When that PR is merged, we'll leverage all the other information we may have as part of the value range analysis and the constraints that are put in place during tracing time.
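A minimal sketch of the lower-bound idea in plain sympy, not the bound_sympy API itself (the symbol name and the assumed minimum of 1 are illustrative): if a shape symbol is known to be a positive integer, substituting its smallest admissible value bounds the expression from below.

import sympy

# Illustration only, not the inductor/bound_sympy API: model a shape symbol
# known to be a positive integer, so s0 >= 1.
s0 = sympy.Symbol("s0", positive=True, integer=True)
expr = s0 * 100

# Substituting the smallest admissible value gives the lower bound discussed above.
lower_bound = expr.subs(s0, 1)
print(lower_bound)  # 100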

         )
         if not should_split:
             return ReductionHint.DEFAULT, 1
@@ -689,8 +692,6 @@ def outer_reduction_splits(reduction_numel_hint, numel_hint):
                 rvals_per_thread * split_size
             )
 
-        reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel)
-        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
         # easy cases
         if numel_hint == 1:
             return ReductionHint.INNER, inner_reduction_splits(
@@ -994,24 +995,14 @@ def create_multilayer(
         """
         reduction_numel = sympy_product(reduction_ranges)
 
-        # TODO(jansel): convert this to dynamic shapes
-        # TODO(jansel): realize the reduction so we can do dynamic indexing
-        reduction_ranges = [
-            sympy.Integer(V.graph.sizevars.evaluate_static_shape(s))
-            for s in reduction_ranges
-        ]
-        reduction_numel = sympy.Integer(
-            V.graph.sizevars.evaluate_static_shape(reduction_numel)
+        need_mask = not V.graph.sizevars.is_expr_static_and_true(
+            sympy.Eq(reduction_numel % split, 0)
         )
 
-        if V.graph.sizevars.size_hint(reduction_numel) % split == 0:
-            need_mask = False
-        else:
-            need_mask = True
-
         split = sympy.Integer(split)
         block_size = FloorDiv(reduction_numel + (split - 1), split)
 
+        # TODO(jansel): realize the reduction so we can do dynamic indexing
         reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
 
         def wrapper_fn(index, reduction_index):
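A self-contained sketch of what the new divisibility check buys, in plain sympy rather than the real SizeVars helper (prove_divisible below is a made-up stand-in for is_expr_static_and_true(Eq(reduction_numel % split, 0))): an evenly divisible constant size can skip the mask, while a symbolic size cannot be proven divisible statically, so masked loads are used conservatively instead of forcing the shape to a concrete value up front.

import sympy

def prove_divisible(reduction_numel, split):
    # Made-up stand-in for the SizeVars check: True only if divisibility
    # simplifies to a literal sympy True.
    return sympy.simplify(sympy.Eq(reduction_numel % split, 0)) == sympy.true

split = sympy.Integer(512)

# Static size that divides evenly: no masking needed.
print(not prove_divisible(sympy.Integer(1 << 20), split))  # False -> need_mask is False

# Symbolic size: divisibility can't be proven, so need_mask stays True.
s0 = sympy.Symbol("s0", positive=True, integer=True)
print(not prove_divisible(s0 * 100, split))                # True -> masked loads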
8 changes: 6 additions & 2 deletions torch/_inductor/sizevars.py
@@ -358,7 +358,8 @@ def evaluate_static_shape(self, left: Expr) -> int:
     def evaluate_static_shapes(self, left: List[Expr]) -> List[int]:
         return [self.evaluate_static_shape(x) for x in left]
 
-    def size_hint(self, expr: Expr) -> int:
+    def symbolic_hint(self, expr: Expr) -> Expr:
+        # Substitute all hints into expr, but leave unbacked symints alone
         if not isinstance(expr, Expr):
             assert isinstance(expr, int)
             return expr
@@ -368,7 +369,10 @@ def size_hint(self, expr: Expr) -> int:
         while any(s.name.startswith("ps") for s in free_symbols):
             expr = sympy_subs(expr, self.inv_precomputed_replacements)
             free_symbols = expr.free_symbols
-        out = sympy_subs(expr, self.var_to_val)
+        return sympy_subs(expr, self.var_to_val)
+
+    def size_hint(self, expr: Expr) -> int:
+        out = self.symbolic_hint(expr)
         try:
             return int(out)
         except Exception:
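A self-contained sketch of the behavioral split this refactor introduces, using plain sympy with a made-up var_to_val mapping standing in for the real SizeVars state (and omitting the precomputed-replacement and exception-fallback paths visible in the diff): symbolic_hint substitutes known hints but still returns an expression, so an unbacked symbol simply survives, whereas size_hint additionally demands a concrete integer.

import sympy

var_to_val = {sympy.Symbol("s0"): sympy.Integer(64)}  # hypothetical backed hint

def symbolic_hint(expr):
    # Substitute hints for backed symbols; unbacked symbols are left in place.
    return expr.subs(var_to_val)

def size_hint(expr):
    # Require a concrete integer; raises if unbacked symbols remain (the real
    # method wraps this in a try/except, as shown in the diff above).
    return int(symbolic_hint(expr))

s0, i0 = sympy.Symbol("s0"), sympy.Symbol("i0")  # i0 models an unbacked symint
print(symbolic_hint(s0 * 2))    # 128, as a sympy Integer expression
print(symbolic_hint(i0 * 100))  # 100*i0, stays symbolic
print(size_hint(s0 * 2))        # 128
# size_hint(i0 * 100) would raise here, since there is no hint for i0.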