diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 35e6ed84f1d42..4a010ad836cf7 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -9611,7 +9611,12 @@ def test_randint_int64_mod(self):
         # This used to not compile due to a wrong return type of randint64_cpu
         # See https://github.com/pytorch/pytorch/issues/117435
         def fn(n):
-            return torch.randint(low=-5, high=5, size=(n,), dtype=torch.int64) % 10
+            return (
+                torch.randint(
+                    low=-5, high=5, size=(n,), dtype=torch.int64, device=self.device
+                )
+                % 10
+            )
 
         res = torch.compile(fn)(20)
         self.assertTrue(torch.all((0 <= res) & (res < 10)).item())
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index f8d278bb94d5e..2b7d6c65704e7 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -27,7 +27,7 @@
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.utils import _pytree as pytree
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
-from torch.utils._sympy.value_ranges import ValueRanges
+from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
 
 from .. import config, metrics
 from ..utils import DeferredLineBase, IndentedBuffer, sympy_dot, sympy_subs, unique
@@ -269,6 +269,7 @@ def deduce_node_dtype(self, node: torch.fx.Node):
         if node.target in (
             "get_index",
             "index_expr",
+            "randint64",
         ):
             return torch.int64
 
@@ -529,7 +530,7 @@ def constant(value, dtype):
 
     @staticmethod
     def reciprocal(x):
-        return ops.truediv("1", x)
+        return ops.truediv(ops.constant(1, torch.int32), x)
 
     @staticmethod
     def square(x):
@@ -566,7 +567,11 @@ def bitwise_right_shift(x, y):
     @staticmethod
     def remainder(a, b):
         r = ops.mod(a, b)
-        return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)
+        cond = ops.and_(
+            ops.ne(r, ops.constant(0, torch.int32)),
+            ops.ne(ops.signbit(r), ops.signbit(b)),
+        )
+        return ops.where(cond, ops.add(r, b), r)
 
     @staticmethod
     def load_seed(name, offset):
@@ -1473,24 +1478,17 @@ def __enter__(self):
         # TODO: hoist this to top level
         class CSEProxy:
             self.name = "CSEProxy"
+            vr_analysis = ValueRangeAnalysis()
 
             @staticmethod
             def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                 def inner(*args, **kwargs):
-                    # TritonTemplateKernel has no current_node
-                    buf_bounds = ValueRanges.unknown()
-                    if (
-                        fx_node := getattr(V.interpreter, "current_node", None)
-                    ) and fx_node.target == name:
-                        assert isinstance(self.node_to_bounds, dict)
-                        buf_bounds = self.node_to_bounds.get(
-                            fx_node, ValueRanges.unknown()
-                        )
+                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)
 
                     value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
 
                     def do_cse(v):
-                        csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
+                        csevar = self.cse.generate(self.compute, v, bounds=bounds)
                         csevar.update_on_args(name, args, kwargs)
                         return csevar
 
@@ -1498,6 +1496,49 @@ def do_cse(v):
 
                 return inner
 
+            @staticmethod
+            def _bound_variable(name, *args, **kwargs):
+                """
+                If the variable comes from an FX node, we forward the bound we have already computed.
+                Otherwise, if the variable is created while codegen'ing another op, we try to compute its bounds.
+                """
+                from ..select_algorithm import TritonTemplateKernel
+
+                if isinstance(V.kernel, TritonTemplateKernel):
+                    return ValueRanges.unknown()
+
+                fx_node = V.interpreter.current_node
+                if fx_node.target == name:
+                    assert isinstance(self.node_to_bounds, dict)
+                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
+                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
+                    # These create lots of inner strings. We would need to compute the bounds at the ops
+                    # We will also likely not get much from computing VRs on these nodes
+                    if any(
+                        s in fx_node.target
+                        for s in ("set_indirect", "reduction", "scan")
+                    ):
+                        return ValueRanges.unknown()
+
+                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
+                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.
+
+                    # If there is no FX bound but we know how to compute one we do so
+                    assert not kwargs
+
+                    def arg_to_bound(x):
+                        if isinstance(x, CSEVariable):
+                            return x.bounds
+                        elif isinstance(x, sympy.Expr):
+                            return bound_sympy(x)
+                        else:
+                            return x
+
+                    arg_bounds = list(map(arg_to_bound, args))
+                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
+                else:
+                    return ValueRanges.unknown()
+
             @staticmethod
             def indirect_indexing(
                 var: CSEVariable, size: sympy.Expr, check: bool = True
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 984178cc8e8b5..e7fb7a9286b45 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -34,6 +34,7 @@
 )
 from ..utils import (
     cache_on_self,
+    get_bounds_index_expr,
     get_fused_kernel_name,
     is_welford_reduction,
     parallel_num_threads,
@@ -841,7 +842,7 @@ def mod(a, b):
     @staticmethod
     def constant(val, dtype):
         opt_ctx: OptimizationContext = get_current_node_opt_ctx()
-        assert opt_ctx and opt_ctx.dtype is not None
+        assert opt_ctx and opt_ctx.dtype is not None, opt_ctx
         dtype = opt_ctx.dtype
         if dtype in DTYPE_LOWP_FP:
             # Since load promotes all half-precision inputs to float, constants
@@ -854,7 +855,12 @@ def index_expr(expr, dtype):
         opt_ctx: OptimizationContext = get_current_node_opt_ctx()
         assert opt_ctx and opt_ctx.dtype is not None
         dtype = opt_ctx.dtype
-        return ops.to_dtype(cexpr(V.kernel.rename_indexing(expr)), dtype)
+
+        idx_str = cexpr(V.kernel.rename_indexing(expr))
+        var = V.kernel.cse.generate(
+            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
+        )
+        return ops.to_dtype(var, dtype)
 
     @staticmethod
     def masked(mask, body, other):
@@ -1451,7 +1457,10 @@ def index_expr(expr, dtype):
         if stride == 0:
             return CppOverrides.index_expr(expr, dtype)
         elif stride is not None:
-            value = ops.to_dtype(cexpr(index), dtype)
+            idx = V.kernel.cse.generate(
+                V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr)
+            )
+            value = ops.to_dtype(idx, dtype)
             if isinstance(value, OpsValue):
                 value = value.value
             csevar = V.kernel.arange(value, stride)
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 81fc19a0f3f46..594d5f004e38f 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -58,6 +58,7 @@
 from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
 from ..utils import (
     cache_on_self,
+    get_bounds_index_expr,
     get_dtype_size,
     get_fused_kernel_name,
     get_kernel_metadata,
@@ -619,7 +620,7 @@ def relu(x):
     elif bug == "accuracy":
         return f"{x} + 1"
     elif bug is None:
-        return ops.maximum("0", x)
+        return ops.maximum(ops.constant(0, torch.int32), x)
     else:
         raise AssertionError(
             f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}"
@@ -864,11 +865,9 @@ def floordiv(a, b):
 
     @staticmethod
     def sign(x):
-        def to_int(s):
-            return f"{s}.to(tl.int8)"
-
-        left = to_int(ops.lt("0", x))
-        right = to_int(ops.lt(x, "0"))
+        z = ops.constant(0, torch.int32)
+        left = ops.to_dtype((ops.lt(z, x)), torch.int8)
+        right = ops.to_dtype((ops.lt(x, z)), torch.int8)
         sub = ops.sub(left, right)
         return f"{sub}.to({x}.dtype)"
 
@@ -916,8 +915,9 @@ def constant(cls, value, dtype):
     def index_expr(cls, expr, dtype):
         indexing = V.kernel.indexing(expr, block_ptr=False)
         assert isinstance(indexing, IndexingOptions)
-        # This is called from CSEProxy.__getattr__, so we'll set the bounds there
-        var = V.kernel.cse.generate(V.kernel.compute, indexing.index_str)
+        var = V.kernel.cse.generate(
+            V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
+        )
 
         if dtype not in {torch.int32, torch.int64}:
             var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
@@ -929,10 +929,14 @@ def masked(mask, body, other):
         with V.kernel.mask_loads(mask) as new_mask:
             result = body()
 
+        # Remove once CSEVariables track the dtype
+        if result.bounds.is_bool:
+            other = bool(other)
         # Take dtype from result to prevent accidental promotion
         other = V.kernel.cse.generate(
             V.kernel.compute,
             f"tl.full({result}.shape, {triton_constant(other)}, {result}.dtype)",
+            bounds=ValueRanges.wrap(other),
         )
         return ops.where(new_mask, result, other)
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index d09621affe52c..538b7e69fb4c8 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -355,6 +355,9 @@ def is_fbcode():
 # assert that indirect indexing does not read / write out of bounds
 assert_indirect_indexing = True
 
+# compute CSE bounds on variables that do not appear in the FX graph
+compute_all_bounds = False
+
 # constant folding on the joint graph
 joint_graph_constant_folding = True
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 782e9d7764531..f632f278a0cdf 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -49,6 +49,7 @@
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import make_symbol, SymT
+from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
 
 from . import config
 from .runtime.runtime_utils import ceildiv as runtime_ceildiv
@@ -539,6 +540,20 @@ def sympy_str(expr: sympy.Expr) -> str:
     return str(expr)
 
 
+def get_bounds_index_expr(index):
+    from .virtualized import V
+
+    # If this expression does not come from an FX node, we compute its bounds
+    if (
+        config.compute_all_bounds
+        and (fx_node := getattr(V.interpreter, "current_node", None))
+        and fx_node.target != "index_expr"
+    ):
+        return bound_sympy(index)
+    else:
+        return ValueRanges.unknown()
+
+
 def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
     """
    Used to generate an integer-nonnegative symbol.
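Note on the bounds computed above: `get_bounds_index_expr` defers to `bound_sympy`, which propagates known ranges for the free symbols of an index expression through the expression. A minimal sketch of that behaviour, with a made-up symbol and range (illustrative, not taken from this patch):

    import sympy

    from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

    # Hypothetical index symbol; inductor supplies its own symbols and ranges.
    i = sympy.Symbol("i", integer=True, nonnegative=True)
    # With i known to lie in [0, 127], the index 2*i + 1 is bounded by [1, 255].
    print(bound_sympy(2 * i + 1, {i: ValueRanges(0, 127)}))  # roughly VR[1, 255]

When called without an explicit ranges dict, `bound_sympy` can also pick up symbol ranges from the tracing context's shape environment, which is why the call sites in cpp.py and triton.py pass only the expression.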
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index a056db6dbb9eb..f2319e930d769 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -138,7 +138,7 @@ def __init__(self, lower: AllIn, upper: AllIn) -> None: if not sympy_generic_le(lower, upper): raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]") except TypeError: - raise TypeError(f"Could not compare {lower} <= {upper}") + raise TypeError(f"Could not compare {lower} <= {upper}") # noqa: TRY200 # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) @@ -340,6 +340,9 @@ class SymPyValueRangeAnalysis: @staticmethod def constant(value, dtype): + if isinstance(value, ValueRanges): + assert value.is_singleton() + value = value.lower # NB: value is NOT a sympy expression, it's a constant! is_python = isinstance(value, (int, float, bool)) assert is_python or isinstance( @@ -663,7 +666,9 @@ def where(a, b, c): b = ValueRanges.wrap(b) c = ValueRanges.wrap(c) a = a.boolify() - assert b.is_bool == c.is_bool + # We sometimes write unknown without specifying the type correctly + # In particular, we do that when initialising the bounds for loads in bounds.py + assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c) if b.is_bool: return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper)) else:
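For reference, the `remainder` decomposition added to common.py above encodes the usual sign fix-up: the truncated mod `r = ops.mod(a, b)` is shifted by `b` whenever it is nonzero and its sign disagrees with the divisor's. A plain-Python sketch of the same logic, using `math.fmod` as a stand-in for `ops.mod` (illustrative only, not part of the patch):

    import math

    def remainder(a, b):
        # Truncated-division mod, playing the role of ops.mod.
        r = math.fmod(a, b)
        # Fold into the divisor's sign when the signs disagree,
        # mirroring ops.where(cond, ops.add(r, b), r) above.
        if r != 0 and (r < 0) != (b < 0):
            r += b
        return r

    assert remainder(-7, 10) == 3    # matches torch.remainder(-7, 10)
    assert remainder(7, -10) == -3   # matches torch.remainder(7, -10)

Expressing the condition through `ops.and_`, `ops.ne`, and `ops.signbit` instead of a raw f-string keeps the whole computation inside the ops handlers, so the new CSE bounds machinery can see and bound each intermediate value.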