Compute bounds for the variables created during codegen (#123100)
Before, we would just bail out on computing these bounds for any variable that did
not come from the FX graph. Now we propagate the bounds whenever we have
a rule for that op.

Pull Request resolved: #123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
lezcano authored and pytorchmergebot committed May 6, 2024
1 parent 3827810 commit bb668c6
Showing 7 changed files with 109 additions and 27 deletions.
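
Note (not part of the diff): the "bounds" here are ValueRanges from torch.utils._sympy.value_ranges, and bound_sympy evaluates a sympy expression over known symbol ranges. A minimal sketch of what propagating a bound means, using only helpers the diff itself imports:

    import sympy
    from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy

    i = sympy.Symbol("i", integer=True, nonnegative=True)
    # Suppose codegen knows the loop index satisfies 0 <= i <= 127.
    vr = bound_sympy(2 * i + 1, {i: ValueRanges(0, 127)})
    print(vr.lower, vr.upper)  # 1 255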
7 changes: 6 additions & 1 deletion test/inductor/test_torchinductor.py
@@ -9611,7 +9611,12 @@ def test_randint_int64_mod(self):
# This used to not compile due to a wrong return type of randint64_cpu
# See https://github.com/pytorch/pytorch/issues/117435
def fn(n):
return torch.randint(low=-5, high=5, size=(n,), dtype=torch.int64) % 10
return (
torch.randint(
low=-5, high=5, size=(n,), dtype=torch.int64, device=self.device
)
% 10
)

res = torch.compile(fn)(20)
self.assertTrue(torch.all((0 <= res) & (res < 10)).item())
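Aside (not part of the test): the assertion relies on Python-style modulo, where the result takes the sign of the divisor, so even negative draws land in [0, 10). A quick eager-mode check:

    import torch

    x = torch.tensor([-5, -1, 0, 4], dtype=torch.int64)
    print(x % 10)  # tensor([5, 9, 0, 4]) -- always within [0, 10)
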
67 changes: 54 additions & 13 deletions torch/_inductor/codegen/common.py
@@ -27,7 +27,7 @@
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

from .. import config, metrics
from ..utils import DeferredLineBase, IndentedBuffer, sympy_dot, sympy_subs, unique
@@ -269,6 +269,7 @@ def deduce_node_dtype(self, node: torch.fx.Node):
if node.target in (
"get_index",
"index_expr",
"randint64",
):
return torch.int64

@@ -529,7 +530,7 @@ def constant(value, dtype):

@staticmethod
def reciprocal(x):
return ops.truediv("1", x)
return ops.truediv(ops.constant(1, torch.int32), x)

@staticmethod
def square(x):
@@ -566,7 +567,11 @@ def bitwise_right_shift(x, y):
@staticmethod
def remainder(a, b):
r = ops.mod(a, b)
return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)
cond = ops.and_(
ops.ne(r, ops.constant(0, torch.int32)),
ops.ne(ops.signbit(r), ops.signbit(b)),
)
return ops.where(cond, ops.add(r, b), r)

@staticmethod
def load_seed(name, offset):
@@ -1473,31 +1478,67 @@ def __enter__(self):
# TODO: hoist this to top level
class CSEProxy:
self.name = "CSEProxy"
vr_analysis = ValueRangeAnalysis()

@staticmethod
def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
def inner(*args, **kwargs):
# TritonTemplateKernel has no current_node
buf_bounds = ValueRanges.unknown()
if (
fx_node := getattr(V.interpreter, "current_node", None)
) and fx_node.target == name:
assert isinstance(self.node_to_bounds, dict)
buf_bounds = self.node_to_bounds.get(
fx_node, ValueRanges.unknown()
)
bounds = CSEProxy._bound_variable(name, *args, **kwargs)

value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]

def do_cse(v):
csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
csevar = self.cse.generate(self.compute, v, bounds=bounds)
csevar.update_on_args(name, args, kwargs)
return csevar

return pytree.tree_map(do_cse, value)

return inner

@staticmethod
def _bound_variable(name, *args, **kwargs):
"""
If the variable comes from an FX node, we forward the bound we have already computed.
Otherwise, if the variable is created while codegen'ing another op, we try to compute its bounds.
"""
from ..select_algorithm import TritonTemplateKernel

if isinstance(V.kernel, TritonTemplateKernel):
return ValueRanges.unknown()

fx_node = V.interpreter.current_node
if fx_node.target == name:
assert isinstance(self.node_to_bounds, dict)
return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
# These create lots of inner strings. We would need to compute the bounds at the ops
# We will also likely not get much from computing VRs on these nodes
if any(
s in fx_node.target
for s in ("set_indirect", "reduction", "scan")
):
return ValueRanges.unknown()

# We assume that the inputs come from `ops.` and are not strings. If you want to generate
# intermediary strings, wrap them in CSE variables with properly initialised bounds.

# If there is no FX bound but we know how to compute one we do so
assert not kwargs

def arg_to_bound(x):
if isinstance(x, CSEVariable):
return x.bounds
elif isinstance(x, sympy.Expr):
return bound_sympy(x)
else:
return x

arg_bounds = list(map(arg_to_bound, args))
return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
else:
return ValueRanges.unknown()

@staticmethod
def indirect_indexing(
var: CSEVariable, size: sympy.Expr, check: bool = True
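Aside (not part of the diff): the rewritten remainder composes ops.mod with a sign fix-up instead of emitting a raw string, so the new bounds machinery can see each intermediate. The underlying identity, sketched in plain Python with math.fmod standing in for a C-style truncated mod (and r < 0 approximating signbit):

    import math

    def remainder(a: float, b: float) -> float:
        # Truncated mod, then shift by b when the nonzero result has the wrong sign,
        # mirroring ops.mod followed by the ops.where fix-up above.
        r = math.fmod(a, b)
        if r != 0 and (r < 0) != (b < 0):
            r += b
        return r

    assert remainder(-5, 10) == -5 % 10 == 5
    assert remainder(7, -3) == 7 % -3 == -2
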
15 changes: 12 additions & 3 deletions torch/_inductor/codegen/cpp.py
@@ -34,6 +34,7 @@
)
from ..utils import (
cache_on_self,
get_bounds_index_expr,
get_fused_kernel_name,
is_welford_reduction,
parallel_num_threads,
@@ -841,7 +842,7 @@ def mod(a, b):
@staticmethod
def constant(val, dtype):
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
assert opt_ctx and opt_ctx.dtype is not None
assert opt_ctx and opt_ctx.dtype is not None, opt_ctx
dtype = opt_ctx.dtype
if dtype in DTYPE_LOWP_FP:
# Since load promotes all half-precision inputs to float, constants
@@ -854,7 +855,12 @@ def index_expr(expr, dtype):
opt_ctx: OptimizationContext = get_current_node_opt_ctx()
assert opt_ctx and opt_ctx.dtype is not None
dtype = opt_ctx.dtype
return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)

idx_str = cexpr(V.kernel.rename_indexing(expr))
var = V.kernel.cse.generate(
V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
)
return ops.to_dtype(var, dtype)

@staticmethod
def masked(mask, body, other):
@@ -1451,7 +1457,10 @@ def index_expr(expr, dtype):
if stride == 0:
return CppOverrides.index_expr(expr, dtype)
elif stride is not None:
value = ops.to_dtype(cexpr(index), dtype)
idx = V.kernel.cse.generate(
V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr)
)
value = ops.to_dtype(idx, dtype)
if isinstance(value, OpsValue):
value = value.value
csevar = V.kernel.arange(value, stride)
20 changes: 12 additions & 8 deletions torch/_inductor/codegen/triton.py
@@ -58,6 +58,7 @@
from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
from ..utils import (
cache_on_self,
get_bounds_index_expr,
get_dtype_size,
get_fused_kernel_name,
get_kernel_metadata,
@@ -619,7 +620,7 @@ def relu(x):
elif bug == "accuracy":
return f"{x} + 1"
elif bug is None:
return ops.maximum("0", x)
return ops.maximum(ops.constant(0, torch.int32), x)
else:
raise AssertionError(
f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}"
@@ -864,11 +865,9 @@ def floordiv(a, b):

@staticmethod
def sign(x):
def to_int(s):
return f"{s}.to(tl.int8)"

left = to_int(ops.lt("0", x))
right = to_int(ops.lt(x, "0"))
z = ops.constant(0, torch.int32)
left = ops.to_dtype((ops.lt(z, x)), torch.int8)
right = ops.to_dtype((ops.lt(x, z)), torch.int8)
sub = ops.sub(left, right)
return f"{sub}.to({x}.dtype)"

@@ -916,8 +915,9 @@ def constant(cls, value, dtype):
def index_expr(cls, expr, dtype):
indexing = V.kernel.indexing(expr, block_ptr=False)
assert isinstance(indexing, IndexingOptions)
# This is called from CSEProxy.__getattr__, so we'll set the bounds there
var = V.kernel.cse.generate(V.kernel.compute, indexing.index_str)
var = V.kernel.cse.generate(
V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
)

if dtype not in {torch.int32, torch.int64}:
var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
@@ -929,10 +929,14 @@ def masked(mask, body, other):
with V.kernel.mask_loads(mask) as new_mask:
result = body()

# Remove once CSEVariables track the dtype
if result.bounds.is_bool:
other = bool(other)
# Take dtype from result to prevent accidental promotion
other = V.kernel.cse.generate(
V.kernel.compute,
f"tl.full({result}.shape, {triton_constant(other)}, {result}.dtype)",
bounds=ValueRanges.wrap(other),
)
return ops.where(new_mask, result, other)

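Aside (not part of the diff): the sign rewrite uses the comparison trick sign(x) = (0 < x) - (x < 0), expressed through ops.* so each intermediate carries a bound. An eager-mode equivalent for reference:

    import torch

    def sign_via_compares(x: torch.Tensor) -> torch.Tensor:
        # (0 < x) - (x < 0), computed in int8 and cast back, as the Triton codegen now does.
        left = (0 < x).to(torch.int8)
        right = (x < 0).to(torch.int8)
        return (left - right).to(x.dtype)

    x = torch.tensor([-2.0, 0.0, 3.5])
    print(sign_via_compares(x))  # tensor([-1., 0., 1.])
    print(torch.sign(x))         # tensor([-1., 0., 1.])
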
3 changes: 3 additions & 0 deletions torch/_inductor/config.py
@@ -355,6 +355,9 @@ def is_fbcode():
# assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing = True

# compute CSE bounds on variables that do not appear in the FX graph
compute_all_bounds = False

# constant folding on the joint graph
joint_graph_constant_folding = True

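Usage sketch (assuming the change lands with this flag name): the new option defaults to False and can be flipped per run through the standard inductor config patching mechanism:

    import torch
    from torch._inductor import config as inductor_config

    # Opt in to computing value ranges for variables created during codegen as well.
    with inductor_config.patch(compute_all_bounds=True):
        fn = torch.compile(lambda x: (x % 7).relu())
        fn(torch.arange(32))
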
15 changes: 15 additions & 0 deletions torch/_inductor/utils.py
@@ -49,6 +49,7 @@
from torch.fx.passes.shape_prop import ShapeProp
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import make_symbol, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from . import config
from .runtime.runtime_utils import ceildiv as runtime_ceildiv

@@ -539,6 +540,20 @@ def sympy_str(expr: sympy.Expr) -> str:
return str(expr)


def get_bounds_index_expr(index):
from .virtualized import V

# If this expression does not come from an FX node, we compute its bounds
if (
config.compute_all_bounds
and (fx_node := getattr(V.interpreter, "current_node", None))
and fx_node.target != "index_expr"
):
return bound_sympy(index)
else:
return ValueRanges.unknown()


def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
"""
Used to generate an integer-nonnegative symbol.
9 changes: 7 additions & 2 deletions torch/utils/_sympy/value_ranges.py
@@ -138,7 +138,7 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None:
if not sympy_generic_le(lower, upper):
raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
except TypeError:
raise TypeError(f"Could not compare {lower} <= {upper}")
raise TypeError(f"Could not compare {lower} <= {upper}") # noqa: TRY200
# Because this is a frozen class
object.__setattr__(self, "lower", lower)
object.__setattr__(self, "upper", upper)
@@ -340,6 +340,9 @@ class SymPyValueRangeAnalysis:

@staticmethod
def constant(value, dtype):
if isinstance(value, ValueRanges):
assert value.is_singleton()
value = value.lower
# NB: value is NOT a sympy expression, it's a constant!
is_python = isinstance(value, (int, float, bool))
assert is_python or isinstance(
Expand Down Expand Up @@ -663,7 +666,9 @@ def where(a, b, c):
b = ValueRanges.wrap(b)
c = ValueRanges.wrap(c)
a = a.boolify()
assert b.is_bool == c.is_bool
# We sometimes write unknown without specifying the type correctly
# In particular, we do that when initialising the bounds for loads in bounds.py
assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c)
if b.is_bool:
return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
else:
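Aside (a rough sketch of the helpers referenced above, as they behave in torch.utils._sympy.value_ranges): wrap turns a scalar into a singleton range, which is what the adjusted constant handler now unwraps, and unknown() is the [-oo, oo] range tolerated by the relaxed where assert:

    from torch.utils._sympy.value_ranges import ValueRanges

    v = ValueRanges.wrap(5)    # a bare scalar becomes the singleton range [5, 5]
    print(v.is_singleton())    # True
    print(v.lower, v.upper)    # 5 5

    u = ValueRanges.unknown()
    print(u.lower, u.upper)    # -oo oo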

1 comment on commit bb668c6

@pytorchmergebot

Reverted #123100 on behalf of https://github.com/huydhn due to Sorry for reverting you change but it is failing inductor tests https://hud.pytorch.org/pytorch/pytorch/commit/bb668c6468dd4adf7737a069e7af4c3f612cfc81
