diff --git a/docs/api/kernel.md b/docs/api/kernel.md
index 21c814bab..c22c4a30d 100644
--- a/docs/api/kernel.md
+++ b/docs/api/kernel.md
@@ -90,6 +90,8 @@ result = bound_static(torch.randn(100, 50)) # Must be exactly [100, 50]
 ```{warning}
 Helion shape-specializes kernels by default (`static_shapes=True`) for the best performance. Bound kernels and caches require tensors with the exact same shapes and strides as the examples you compile against. Set `static_shapes=False` if you need the same compiled kernel to serve many shapes.
+With dynamic shapes (`static_shapes=False`), Helion also specializes on whether each tensor dimension is 0 or 1 and on whether a tensor needs 64-bit indexing (more than ``2**31 - 1`` elements).
+This 64-bit indexing specialization can be avoided by setting `index_dtype=torch.int64`.
 ```
 
 ### BoundKernel Methods
 
@@ -139,6 +141,10 @@ Kernels are automatically cached based on:
 
 By default (`static_shapes=True`), Helion treats shapes and strides as compile-time constants, baking them into generated Triton code for the best performance. To reuse a single compiled kernel across size variations, set `static_shapes=False`, which instead buckets each dimension as `{0, 1, ≥2}` and allows more inputs to share the same cache entry.
 
+```{note}
+Dynamic buckets also track whether any tensor exceeds the ``torch.int32`` indexing limit so that cache entries diverge as soon as large inputs show up. Set ``index_dtype=torch.int64`` on the kernel to avoid this.
+```
+
 ```python
 # These create separate cache entries
 tensor_float = torch.randn(100, dtype=torch.float32, device='cuda')
diff --git a/docs/api/settings.md b/docs/api/settings.md
index e72b95aa9..7b15ee565 100644
--- a/docs/api/settings.md
+++ b/docs/api/settings.md
@@ -81,8 +81,10 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
 .. autoattribute:: Settings.index_dtype
 
-   The data type used for index variables in generated code. Default is ``torch.int32``.
-   Override via ``HELION_INDEX_DTYPE=int64`` (or any ``torch.`` name).
+   The data type used for index variables in generated code. By default Helion auto-selects
+   between ``torch.int32`` and ``torch.int64`` based on whether any input tensor exceeds
+   ``torch.iinfo(torch.int32).max`` elements. Override via ``HELION_INDEX_DTYPE=int64``
+   (or any ``torch.`` name); set it to ``auto`` to keep the automatic behavior.
 
 .. autoattribute:: Settings.dot_precision
 
@@ -259,7 +261,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
 | Environment Variable | Maps To | Description |
 |----------------------|---------|-------------|
 | ``TRITON_F32_DEFAULT`` | ``dot_precision`` | Sets default floating-point precision for Triton dot products (``"tf32"``, ``"tf32x3"``, ``"ieee"``). |
-| ``HELION_INDEX_DTYPE`` | ``index_dtype`` | Choose the default index dtype (accepts any ``torch.`` name, e.g. ``int64``). |
+| ``HELION_INDEX_DTYPE`` | ``index_dtype`` | Choose the index dtype (accepts any ``torch.`` name, e.g. ``int64``), or set to ``auto``/unset to allow Helion to pick ``int32`` vs ``int64`` based on input sizes. |
 | ``HELION_STATIC_SHAPES`` | ``static_shapes`` | Set to ``0``/``false`` to disable global static shape specialization. |
 | ``HELION_PERSISTENT_RESERVED_SMS`` | ``persistent_reserved_sms`` | Reserve this many streaming multiprocessors when launching persistent kernels (``0`` uses all available SMs). |
 | ``HELION_FORCE_AUTOTUNE`` | ``force_autotune`` | Force the autotuner to run even when explicit configs are provided. |
diff --git a/docs/deployment_autotuning.md b/docs/deployment_autotuning.md
index 43fd8f87e..b79ee046d 100644
--- a/docs/deployment_autotuning.md
+++ b/docs/deployment_autotuning.md
@@ -154,8 +154,12 @@ determines when to re-benchmark. Options include:
 - **`static_shapes=False`:** switch to bucketed dynamic shapes. Helion
   reuses results as long as tensor dtypes and device types stay constant.
   Shape changes only trigger a re-selection when a dimension size crosses
-  the buckets `{0, 1, ≥2}`. Use this when you need one compiled kernel to
-  handle many input sizes.
+  the buckets `{0, 1, ≥2}`. Helion also tracks whether any tensor exceeds the
+  `torch.int32` indexing limit (more than ``2**31 - 1`` elements) and will
+  automatically regenerate code with 64-bit indexing in that case. Use this
+  mode when you need one compiled kernel to handle many input sizes, and pin
+  ``@helion.kernel(..., index_dtype=torch.int64)`` if large tensors are the norm
+  so you avoid an extra specialization boundary.
 - **Custom keys:** pass `key=` to group calls however you like. This
   custom key is in addition to the above.
 
@@ -206,10 +210,13 @@ exact shape/stride signature of the example inputs. The generated code
 has shapes baked in, which often provides a performance boost.
 
 - With `static_shapes=False` it will specialize on the input dtypes,
-device types, and whether each dynamic dimension falls into the 0, 1,
-or ≥2 bucket. Python types are also specialized. For dimensions that
-can vary across those buckets, supply representative inputs ≥2 to avoid
-excessive specialization.
+  device types, and whether each dynamic dimension falls into the 0, 1,
+  or ≥2 bucket. Python types are also specialized. For dimensions that
+  can vary across those buckets, supply representative inputs ≥2 to avoid
+  excessive specialization. Just like the autotuning flow above, Helion
+  records whether any tensor crosses the int32 indexing limit when
+  `static_shapes=False`; explicitly set `index_dtype=torch.int64` if your
+  deployment commonly exceeds that threshold to avoid recompilation.
 
 If you need to support multiple input types, bind multiple times with
 representative inputs.
diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py
index 517a89e23..f3e5e0991 100644
--- a/helion/_compiler/aten_lowering.py
+++ b/helion/_compiler/aten_lowering.py
@@ -460,9 +460,7 @@ def codegen_iota(ctx: LoweringContext, node: Node) -> object:
     """Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding."""
     start = node.kwargs.get("start", 0)
     step = node.kwargs.get("step", 1)
-    dtype = (
-        node.kwargs.get("dtype") or CompileEnvironment.current().settings.index_dtype
-    )
+    dtype = node.kwargs.get("dtype") or CompileEnvironment.current().index_dtype
     assert isinstance(dtype, torch.dtype)
     (length_arg,) = node.args  # expecting a single argument for length
diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index 4317a0e3a..268bfa31c 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -73,12 +73,21 @@ class CompileEnvironment:
 
     No config or codegen specific state should be stored here.
""" - def __init__(self, device: torch.device, settings: Settings) -> None: + def __init__( + self, + device: torch.device, + settings: Settings, + *, + index_dtype: torch.dtype | None = None, + ) -> None: from ..autotuner.config_spec import ConfigSpec super().__init__() self.device = device self.settings = settings + self.index_dtype: torch.dtype = ( + index_dtype or settings.index_dtype or torch.int32 + ) # TODO(jansel): make backend configurable self.backend = "triton" self.shape_env = ShapeEnv( @@ -383,7 +392,7 @@ def known_multiple(self, a: sympy.Expr, b: int | torch.SymInt) -> bool: def triton_index_type(self) -> str: """tl.int32 or tl.int64 depending on Settings()""" - return triton_type(self.settings.index_dtype) + return triton_type(self.index_dtype) def sympy_debug(self, expr: sympy.Expr) -> str: return str(expr.xreplace(self.debug_shape_renames)) diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index b6c8eaa66..f9d05bee6 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -665,7 +665,7 @@ def create( env = CompileEnvironment.current() dtype = env.triton_index_type() if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value): - raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype) + raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype) def _is_size_one(size: int | torch.SymInt) -> bool: return env.known_equal(size, 1) diff --git a/helion/_compiler/reduction_strategy.py b/helion/_compiler/reduction_strategy.py index 1271c77af..e4856fac8 100644 --- a/helion/_compiler/reduction_strategy.py +++ b/helion/_compiler/reduction_strategy.py @@ -332,7 +332,7 @@ def codegen_reduction( ) else: acc_index = self.fn.new_var(f"{state.fx_node.name}_acc_index", dce=True) - index_dtype = CompileEnvironment.current().settings.index_dtype + index_dtype = CompileEnvironment.current().index_dtype device_loop.outer_prefix.append( statement_from_string( f"{acc_index} = tl.full({shape}, {torch.iinfo(index_dtype).max!r}, {triton_type(index_dtype)})" diff --git a/helion/exc.py b/helion/exc.py index 3f4aea565..f29e67d15 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -136,6 +136,13 @@ class IndexOffsetOutOfRangeForInt32(BaseError): ) +class InputTensorNumelExceedsIndexType(BaseError): + message = ( + "Kernel index_dtype is {index_dtype}, but input input tensor is too large to fit. " + "Use @helion.kernel(index_dtype=torch.int64)." 
+    )
+
+
 class DataDependentOutputShapeNotSupported(BaseError):
     message = (
         "{op_desc} is not supported in Helion device loops because it produces "
diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py
index ef12d3168..62075efd1 100644
--- a/helion/language/creation_ops.py
+++ b/helion/language/creation_ops.py
@@ -209,7 +209,7 @@ def arange(
     """
     env = CompileEnvironment.current()
     if dtype is None:
-        dtype = env.settings.index_dtype
+        dtype = env.index_dtype
     return torch.arange(
         *args,
         **kwargs,
diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py
index 11bace3bd..7a97e79ff 100644
--- a/helion/language/tile_ops.py
+++ b/helion/language/tile_ops.py
@@ -50,7 +50,7 @@ def _(tile: torch.SymInt) -> torch.Tensor:
     assert isinstance(tile, torch.SymInt)
     env = CompileEnvironment.current()
     assert env.get_block_id(tile) is not None
-    return torch.empty([tile], dtype=env.settings.index_dtype, device=env.device)
+    return torch.empty([tile], dtype=env.index_dtype, device=env.device)
 
 
 @_decorators.codegen(tile_index, "triton")
diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py
index f2a853096..32f5f3ee9 100644
--- a/helion/runtime/kernel.py
+++ b/helion/runtime/kernel.py
@@ -15,6 +15,7 @@
 from typing import Callable
 from typing import Generic
 from typing import Hashable
+from typing import Sequence
 from typing import TypeVar
 from typing import cast
 from typing import overload
@@ -27,6 +28,7 @@
 from torch._inductor.codecache import PyCodeCache
 from torch._inductor.codecache import compiled_fx_graph_hash
 from torch._subclasses import FakeTensor
+from torch.utils._pytree import tree_map_only
 from torch.utils.weak import WeakIdKeyDictionary
 
 from .. import exc
@@ -64,6 +66,36 @@
 # Cache for GraphModule hashes
 _graph_module_hash_cache: WeakIdKeyDictionary = WeakIdKeyDictionary()
 
+_INT32_INDEX_LIMIT = torch.iinfo(torch.int32).max
+
+
+def _resolve_index_dtype(
+    settings: Settings,
+    args: Sequence[object] | tuple[object, ...],
+) -> torch.dtype:
+    if (index_dtype := settings.index_dtype) is not None:
+        limit = torch.iinfo(index_dtype).max
+    else:
+        limit = _INT32_INDEX_LIMIT
+    over_limit = False
+
+    def _check(tensor: torch.Tensor) -> None:
+        nonlocal over_limit
+        if over_limit:
+            return
+        try:
+            over_limit = bool(tensor.numel() > limit)
+        except RuntimeError:  # unbacked SymInt
+            if index_dtype is None:
+                over_limit = True
+
+    tree_map_only(torch.Tensor, _check, args)
+    if index_dtype is None:  # Auto-select when not provided
+        return torch.int64 if over_limit else torch.int32
+    if over_limit:
+        raise exc.InputTensorNumelExceedsIndexType(index_dtype=index_dtype)
+    return index_dtype
+
 
 class Kernel(Generic[_R]):
     def __init__(
@@ -321,7 +353,11 @@ def __init__(
         self._run: Callable[..., _R] | None = None
         self._config: Config | None = None
         self._compile_cache: dict[Config, CompiledConfig] = {}
-        self.env = CompileEnvironment(_find_device(args), self.kernel.settings)
+        self.env = CompileEnvironment(
+            _find_device(args),
+            self.kernel.settings,
+            index_dtype=_resolve_index_dtype(self.kernel.settings, args),
+        )
         if is_ref_mode_enabled(self.kernel.settings):
             self.fake_args = []  # type: ignore[assignment]
@@ -830,11 +866,22 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
             (*obj.size(),),
             (*obj.stride(),),
         )
+    bucketed = tuple([min(s, 2) for s in obj.size()])
+    if fn.settings.index_dtype is None:
+        try:
+            needs_int64 = bool(obj.numel() > _INT32_INDEX_LIMIT)
+        except RuntimeError:
+            needs_int64 = True  # unbacked SymInt
+        return (
+            obj.dtype,
+            obj.device.type,
+            bucketed,
+            needs_int64,
+        )
     return (
         obj.dtype,
         obj.device.type,
-        # 0, 1, or >=2 specialization
-        tuple([min(s, 2) for s in obj.size()]),
+        bucketed,
     )
diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index 1b0bdb2cd..b48f3b99f 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -138,10 +138,12 @@ def _env_get_str(var_name: str, default: str) -> str:
     return value
 
 
-def _get_index_dtype() -> torch.dtype:
+def _get_index_dtype() -> torch.dtype | None:
     value = os.environ.get("HELION_INDEX_DTYPE")
     if value is None or (token := value.strip()) == "":
-        return torch.int32
+        return None
+    if token.lower() == "auto":
+        return None
     try:
         dtype = getattr(torch, token)
     except AttributeError as err:
@@ -266,7 +268,9 @@ class _Settings:
     ignore_warnings: list[type[exc.BaseWarning]] = dataclasses.field(
         default_factory=_get_ignore_warnings
    )
-    index_dtype: torch.dtype = dataclasses.field(default_factory=_get_index_dtype)
+    index_dtype: torch.dtype | None = dataclasses.field(
+        default_factory=_get_index_dtype
+    )
     dot_precision: DotPrecision = dataclasses.field(
         default_factory=functools.partial(
             _env_get_literal,
@@ -407,8 +411,8 @@ class Settings(_Settings):
             "Set HELION_IGNORE_WARNINGS=WarningA,WarningB (names from helion.exc) to configure via env."
         ),
         "index_dtype": (
-            "The dtype to use for index variables. Default is torch.int32. "
-            "Override with HELION_INDEX_DTYPE=torch.int64, etc."
+            "The dtype to use for index variables. Default auto-selects torch.int32 or torch.int64 based on input sizes. "
+            "Override with HELION_INDEX_DTYPE=int64, etc., or set to 'auto' to keep the automatic behavior."
         ),
         "dot_precision": "Precision for dot products, see `triton.language.dot`. Can be 'tf32', 'tf32x3', or 'ieee'.",
         "static_shapes": (
diff --git a/test/test_indexing.py b/test/test_indexing.py
index 8e1118390..f06dac6f6 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -415,9 +415,10 @@ def test_int32_offset_out_of_range_error(self):
             range_warp_specializes=[],
         )
 
-        def make_kernel(*, index_dtype: torch.dtype):
+        def make_kernel(*, index_dtype: torch.dtype | None = None):
             kwargs = {"config": repro_config, "static_shapes": True}
-            kwargs["index_dtype"] = index_dtype
+            if index_dtype is not None:
+                kwargs["index_dtype"] = index_dtype
             decorator = helion.kernel(**kwargs)
 
             @decorator
@@ -435,15 +436,19 @@ def repro_bf16_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             return repro_bf16_add
 
         def run_case(
-            shape, *, index_dtype, expect_int64_in_code=False, expect_error=False
-        ):
+            shape,
+            *,
+            index_dtype: torch.dtype | None,
+            expect_int64_in_code: bool = False,
+            expect_error: type[Exception] | None = None,
+        ) -> None:
             kernel = make_kernel(index_dtype=index_dtype)
             x = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
             y = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
             torch.accelerator.synchronize()
-            if expect_error:
+            if expect_error is not None:
                 with self.assertRaisesRegex(
-                    helion.exc.IndexOffsetOutOfRangeForInt32,
+                    expect_error,
                     f"index_dtype is {index_dtype}",
                 ):
                     code_and_output(kernel, (x, y))
@@ -479,19 +484,47 @@ def run_case(
             small_shape,
             index_dtype=torch.int32,
             expect_int64_in_code=False,
-            expect_error=False,
+            expect_error=None,
         )
         run_case(
             large_shape,
             index_dtype=torch.int32,
             expect_int64_in_code=False,
-            expect_error=True,
+            expect_error=helion.exc.InputTensorNumelExceedsIndexType,
         )
         run_case(
             large_shape,
             index_dtype=torch.int64,
             expect_int64_in_code=True,
-            expect_error=False,
+            expect_error=None,
         )
+        run_case(
+            large_shape,
+            index_dtype=None,
+            expect_int64_in_code=True,
+            expect_error=None,
+        )
+
+    def test_dynamic_shape_specialization_key_tracks_large_tensors(self) -> None:
+        @helion.kernel(static_shapes=False)
+        def passthrough(x: torch.Tensor) -> torch.Tensor:
+            return x
+
+        @helion.kernel(static_shapes=False, index_dtype=torch.int64)
+        def passthrough_int64(x: torch.Tensor) -> torch.Tensor:
+            return x
+
+        meta = "meta"
+        small = torch.empty((4, 4), device=meta)
+        large = torch.empty((51200, 51200), device=meta)
+
+        self.assertNotEqual(
+            passthrough.specialization_key((small,)),
+            passthrough.specialization_key((large,)),
+        )
+        self.assertEqual(
+            passthrough_int64.specialization_key((small,)),
+            passthrough_int64.specialization_key((large,)),
+        )
 
     def test_assign_int(self):
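
Reviewer note: below is a minimal sketch of the user-facing contract this diff establishes, mirroring the new `test_dynamic_shape_specialization_key_tracks_large_tensors` test. The trivial kernel bodies and meta-device tensors are illustrative placeholders; `specialization_key` is used exactly as in that test.

```python
import torch
import helion

# With the new default (index_dtype=None, i.e. "auto"), Helion resolves the
# index dtype when a kernel is bound: torch.int32 unless some input tensor
# has more than torch.iinfo(torch.int32).max (2**31 - 1) elements, in which
# case it picks torch.int64.
@helion.kernel(static_shapes=False)
def auto_indexed(x: torch.Tensor) -> torch.Tensor:
    return x  # placeholder body

small = torch.empty((4, 4), device="meta")
large = torch.empty((51200, 51200), device="meta")  # numel > 2**31 - 1

# Under auto mode, crossing the int32 limit is part of the cache key, so
# small and large inputs land in separate cache entries.
assert auto_indexed.specialization_key((small,)) != auto_indexed.specialization_key(
    (large,)
)

# Pinning index_dtype removes that specialization boundary: both inputs
# share one cache entry.
@helion.kernel(static_shapes=False, index_dtype=torch.int64)
def pinned(x: torch.Tensor) -> torch.Tensor:
    return x  # placeholder body

assert pinned.specialization_key((small,)) == pinned.specialization_key((large,))

# Pinning index_dtype=torch.int32 instead makes oversized inputs raise
# helion.exc.InputTensorNumelExceedsIndexType at bind time rather than
# silently overflowing. HELION_INDEX_DTYPE=auto (or leaving the variable
# unset) keeps the automatic behavior.
```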