Skip to content

Commit

Permalink
Compute bounds for the variables created during codegen
Browse files Browse the repository at this point in the history
Before we would just bail out on these bounds for all variables that did
not come from the FX graph. Now we propagate the bounds whenever we have
a rule for that op.

ghstack-source-id: da22cfa72245b7608da154079399039c57ea9192
Pull Request resolved: #123100
  • Loading branch information
lezcano committed Apr 1, 2024
1 parent 0a8d613 commit 268bd16
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

from .. import config, metrics
from ..utils import (
Expand Down Expand Up @@ -1468,6 +1468,28 @@ def inner(*args, **kwargs):
buf_bounds = self.node_to_bounds.get(
fx_node, ValueRanges.unknown()
)
elif bound_handler := getattr(ValueRangeAnalysis, name, None):
# If there is no FX bound but we know how to compute one we do so
assert not kwargs
arg_bounds = []
for x in args:
if isinstance(x, CSEVariable):
arg_bounds.append(x.bounds)
elif isinstance(x, str):
# No current node here
if fx_node is None:
break
# This always comes from an index_expr, otherwise it'd be a CSEVariable
assert fx_node.target == "index_expr"
# TODO a better fix would be to already return a CSEVariable with the bound computed
sympy_expr = V.kernel.current_node._body.indexing_exprs[
fx_node.args[1].args[0]
]
arg_bounds.append(bound_sympy(sympy_expr))
else:
arg_bounds.append(x)
else:
buf_bounds = bound_handler(*arg_bounds)

value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]

Expand Down

0 comments on commit 268bd16

Please sign in to comment.