Commit c0e1160

Auto-select index_dtype

Fixes #1123
stack-info: PR: #1131, branch: jansel/stack/227

1 parent ba5becc commit c0e1160

File tree

13 files changed: +148 -35 lines

docs/api/kernel.md

Lines changed: 6 additions & 0 deletions
@@ -90,6 +90,8 @@ result = bound_static(torch.randn(100, 50))  # Must be exactly [100, 50]
 
 ```{warning}
 Helion shape-specializes kernels by default (`static_shapes=True`) for the best performance. Bound kernels and caches require tensors with the exact same shapes and strides as the examples you compile against. Set `static_shapes=False` if you need the same compiled kernel to serve many shapes.
+With dynamic shapes (`static_shapes=False`), Helion also specializes on whether a tensor dimension is 0 or 1 and whether a tensor needs 64-bit indexing (more than ``2**31 - 1`` elements).
+This 64-bit indexing specialization can be avoided by setting `index_dtype=torch.int64`.
 ```
 
 ### BoundKernel Methods
@@ -139,6 +141,10 @@ Kernels are automatically cached based on:
 
 By default (`static_shapes=True`), Helion treats shapes and strides as compile-time constants, baking them into generated Triton code for the best performance. To reuse a single compiled kernel across size variations, set `static_shapes=False`, which instead buckets each dimension as `{0, 1, ≥2}` and allows more inputs to share the same cache entry.
 
+```{note}
+Dynamic buckets also track whether any tensor exceeds the ``torch.int32`` indexing limit so that cache entries diverge as soon as large inputs show up. If your deployment regularly touches that regime, pin ``index_dtype=torch.int64`` on the kernel to avoid a cache miss or limit errors.
+```
+
 ```python
 # These create separate cache entries
 tensor_float = torch.randn(100, dtype=torch.float32, device='cuda')
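The bucketing rule above (`{0, 1, ≥2}` per dimension, plus an int64-indexing flag) can be sketched in a few lines of plain Python. This is a hypothetical illustration of the cache-key idea, not Helion's actual implementation; `bucket` and `cache_key` are invented names:

```python
# Hypothetical sketch of a dynamic-shape cache key: each dimension collapses
# into a {0, 1, >=2} bucket, plus one flag recording whether the tensor is
# large enough to need 64-bit indexing. Not Helion's real cache-key code.
INT32_MAX = 2**31 - 1  # torch.iinfo(torch.int32).max

def bucket(dim_size: int) -> int:
    # 0 and 1 stay distinct; every size >= 2 shares one bucket
    return dim_size if dim_size <= 1 else 2

def cache_key(shape: tuple[int, ...]) -> tuple:
    numel = 1
    for s in shape:
        numel *= s
    needs_int64 = numel > INT32_MAX
    return (tuple(bucket(s) for s in shape), needs_int64)
```

Under this sketch, shapes `(100,)` and `(5000,)` share a cache entry, while a size-1 dimension or a tensor past the int32 element limit forces a new one.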

docs/api/settings.md

Lines changed: 5 additions & 3 deletions
@@ -81,8 +81,10 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
 .. autoattribute:: Settings.index_dtype
 
-   The data type used for index variables in generated code. Default is ``torch.int32``.
-   Override via ``HELION_INDEX_DTYPE=int64`` (or any ``torch.<dtype>`` name).
+   The data type used for index variables in generated code. By default Helion auto-selects
+   between ``torch.int32`` and ``torch.int64`` based on whether any input tensor exceeds
+   ``torch.iinfo(torch.int32).max`` elements. Override via ``HELION_INDEX_DTYPE=<dtype>``
+   (or set it to ``auto`` to keep the automatic behavior).
 
 .. autoattribute:: Settings.dot_precision
@@ -259,7 +261,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
 | Environment Variable | Maps To | Description |
 |----------------------|---------|-------------|
 | ``TRITON_F32_DEFAULT`` | ``dot_precision`` | Sets default floating-point precision for Triton dot products (``"tf32"``, ``"tf32x3"``, ``"ieee"``). |
-| ``HELION_INDEX_DTYPE`` | ``index_dtype`` | Choose the default index dtype (accepts any ``torch.<dtype>`` name, e.g. ``int64``). |
+| ``HELION_INDEX_DTYPE`` | ``index_dtype`` | Choose the index dtype (accepts any ``torch.<dtype>`` name, e.g. ``int64``), or set to ``auto``/unset to allow Helion to pick ``int32`` vs ``int64`` based on input sizes. |
 | ``HELION_STATIC_SHAPES`` | ``static_shapes`` | Set to ``0``/``false`` to disable global static shape specialization. |
 | ``HELION_PERSISTENT_RESERVED_SMS`` | ``persistent_reserved_sms`` | Reserve this many streaming multiprocessors when launching persistent kernels (``0`` uses all available SMs). |
 | ``HELION_FORCE_AUTOTUNE`` | ``force_autotune`` | Force the autotuner to run even when explicit configs are provided. |
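The `auto` behavior in the table can be pictured with a small sketch; `resolve_index_dtype` is a hypothetical helper (not part of Helion's API) applying the same precedence: an explicit dtype name wins, otherwise input size decides:

```python
# Hypothetical sketch of HELION_INDEX_DTYPE resolution: an explicit dtype
# name forces that dtype; "auto" (or unset) picks int32 unless any input
# tensor exceeds the int32 element limit. Not Helion's actual code.
def resolve_index_dtype(env: dict[str, str], max_numel: int) -> str:
    raw = env.get("HELION_INDEX_DTYPE", "auto").lower()
    if raw != "auto":
        return raw  # e.g. "int64" forces 64-bit indexing everywhere
    return "int64" if max_numel > 2**31 - 1 else "int32"
```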

docs/deployment_autotuning.md

Lines changed: 13 additions & 6 deletions
@@ -154,8 +154,12 @@ determines when to re-benchmark. Options include:
 - **`static_shapes=False`:** switch to bucketed dynamic shapes. Helion
   reuses results as long as tensor dtypes and device types stay constant.
   Shape changes only trigger a re-selection when a dimension size crosses
-  the buckets `{0, 1, ≥2}`. Use this when you need one compiled kernel to
-  handle many input sizes.
+  the buckets `{0, 1, ≥2}`. Helion also tracks whether any tensor exceeds the
+  `torch.int32` indexing limit (more than ``2**31 - 1`` elements) and will
+  automatically regenerate code with 64-bit indexing in that case. Use this
+  mode when you need one compiled kernel to handle many input sizes, and pin
+  ``@helion.kernel(..., index_dtype=torch.int64)`` if large tensors are the norm
+  so you avoid an extra specialization boundary.
 
 - **Custom keys:** pass `key=` to group calls however you like.
   This custom key is in addition to the above.
@@ -206,10 +210,13 @@ exact shape/stride signature of the example inputs. The generated code
 has shapes baked in, which often provides a performance boost.
 
 - With `static_shapes=False` it will specialize on the input dtypes,
-  device types, and whether each dynamic dimension falls into the 0, 1,
-  or ≥2 bucket. Python types are also specialized. For dimensions that
-  can vary across those buckets, supply representative inputs ≥2 to avoid
-  excessive specialization.
+  device types, and whether each dynamic dimension falls into the 0, 1,
+  or ≥2 bucket. Python types are also specialized. For dimensions that
+  can vary across those buckets, supply representative inputs ≥2 to avoid
+  excessive specialization. Just like the autotuning flow above, Helion
+  records whether any tensor crosses the int32 indexing limit when
+  `static_shapes=False`; explicitly set `index_dtype=torch.int64` if your
+  deployment commonly exceeds that threshold to avoid recompilation.
 
 If you need to support multiple input types, bind multiple times with
 representative inputs.
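The re-selection rule described in this section can be approximated with a toy specialization key. This is a sketch only; `specialization_key` and `needs_reselection` are invented names, and Helion's real key also covers Python types and custom `key=` values:

```python
# Hypothetical sketch: re-benchmarking triggers only when the specialization
# key changes, i.e. dtype, device, a {0, 1, >=2} bucket, or the int32-limit
# flag differs -- not on every new shape.
INT32_MAX = 2**31 - 1

def specialization_key(dtype: str, device: str, shape: tuple[int, ...]) -> tuple:
    numel = 1
    for s in shape:
        numel *= s
    buckets = tuple(min(s, 2) for s in shape)  # {0, 1, >=2} per dimension
    return (dtype, device, buckets, numel > INT32_MAX)

def needs_reselection(prev_key: tuple, dtype: str, device: str,
                      shape: tuple[int, ...]) -> bool:
    return specialization_key(dtype, device, shape) != prev_key
```

Note how a shape change from `(128, 64)` to `(256, 32)` reuses the previous selection, while a size-1 dimension or crossing the int32 element limit forces a new one.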

helion/_compiler/aten_lowering.py

Lines changed: 1 addition & 3 deletions
@@ -460,9 +460,7 @@ def codegen_iota(ctx: LoweringContext, node: Node) -> object:
     """Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding."""
     start = node.kwargs.get("start", 0)
     step = node.kwargs.get("step", 1)
-    dtype = (
-        node.kwargs.get("dtype") or CompileEnvironment.current().settings.index_dtype
-    )
+    dtype = node.kwargs.get("dtype") or CompileEnvironment.current().index_dtype
     assert isinstance(dtype, torch.dtype)
     (length_arg,) = node.args  # expecting a single argument for length
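For reference, `prims.iota` has simple semantics that map directly onto `tl.arange`; a pure-Python model (hypothetical, just to show what the lowering computes):

```python
# Pure-Python model of torch.ops.prims.iota.default: produces
# [start, start + step, ..., start + (length - 1) * step].
# The lowering above emits tl.arange for this pattern.
def iota(length: int, start: int = 0, step: int = 1) -> list[int]:
    return [start + i * step for i in range(length)]
```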

helion/_compiler/compile_environment.py

Lines changed: 11 additions & 2 deletions
@@ -73,12 +73,21 @@ class CompileEnvironment:
     No config or codegen specific state should be stored here.
     """
 
-    def __init__(self, device: torch.device, settings: Settings) -> None:
+    def __init__(
+        self,
+        device: torch.device,
+        settings: Settings,
+        *,
+        index_dtype: torch.dtype | None = None,
+    ) -> None:
         from ..autotuner.config_spec import ConfigSpec
 
         super().__init__()
         self.device = device
         self.settings = settings
+        self.index_dtype: torch.dtype = (
+            index_dtype or settings.index_dtype or torch.int32
+        )
         # TODO(jansel): make backend configurable
         self.backend = "triton"
         self.shape_env = ShapeEnv(
@@ -383,7 +392,7 @@ def known_multiple(self, a: sympy.Expr, b: int | torch.SymInt) -> bool:
 
     def triton_index_type(self) -> str:
         """tl.int32 or tl.int64 depending on Settings()"""
-        return triton_type(self.settings.index_dtype)
+        return triton_type(self.index_dtype)
 
     def sympy_debug(self, expr: sympy.Expr) -> str:
         return str(expr.xreplace(self.debug_shape_renames))

helion/_compiler/indexing_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -665,7 +665,7 @@ def create(
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
-            raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)
+            raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)

helion/_compiler/reduction_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -332,7 +332,7 @@ def codegen_reduction(
             )
         else:
             acc_index = self.fn.new_var(f"{state.fx_node.name}_acc_index", dce=True)
-            index_dtype = CompileEnvironment.current().settings.index_dtype
+            index_dtype = CompileEnvironment.current().index_dtype
             device_loop.outer_prefix.append(
                 statement_from_string(
                     f"{acc_index} = tl.full({shape}, {torch.iinfo(index_dtype).max!r}, {triton_type(index_dtype)})"
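The `tl.full(..., iinfo(index_dtype).max, ...)` line above seeds the index accumulator with a sentinel that any real index beats. A hypothetical scalar analogue of that pattern (not the generated Triton code):

```python
# Scalar model of the sentinel pattern: start the winning index at the
# index dtype's max value so the first real candidate always replaces it
# (as in an argmin-style reduction).
INT32_MAX = 2**31 - 1  # stands in for torch.iinfo(index_dtype).max

def argmin(values: list[float]) -> int:
    best_val, best_idx = float("inf"), INT32_MAX
    for i, v in enumerate(values):
        if v < best_val:
            best_val, best_idx = v, i
    return best_idx
```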

helion/exc.py

Lines changed: 7 additions & 0 deletions
@@ -136,6 +136,13 @@ class IndexOffsetOutOfRangeForInt32(BaseError):
     )
 
 
+class InputTensorNumelExceedsIndexType(BaseError):
+    message = (
+        "Kernel index_dtype is {index_dtype}, but an input tensor is too large to fit. "
+        "Use @helion.kernel(index_dtype=torch.int64)."
+    )
+
+
 class DataDependentOutputShapeNotSupported(BaseError):
     message = (
         "{op_desc} is not supported in Helion device loops because it produces "

helion/language/creation_ops.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ def arange(
     """
     env = CompileEnvironment.current()
     if dtype is None:
-        dtype = env.settings.index_dtype
+        dtype = env.index_dtype
     return torch.arange(
         *args,
         **kwargs,

helion/language/tile_ops.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def _(tile: torch.SymInt) -> torch.Tensor:
     assert isinstance(tile, torch.SymInt)
     env = CompileEnvironment.current()
     assert env.get_block_id(tile) is not None
-    return torch.empty([tile], dtype=env.settings.index_dtype, device=env.device)
+    return torch.empty([tile], dtype=env.index_dtype, device=env.device)
 
 
 @_decorators.codegen(tile_index, "triton")
