6 changes: 6 additions & 0 deletions docs/api/kernel.md
@@ -90,6 +90,8 @@ result = bound_static(torch.randn(100, 50)) # Must be exactly [100, 50]

```{warning}
Helion shape-specializes kernels by default (`static_shapes=True`) for the best performance. Bound kernels and caches require tensors with the exact same shapes and strides as the examples you compile against. Set `static_shapes=False` if you need the same compiled kernel to serve many shapes.
With dynamic shapes (`static_shapes=False`), Helion also specializes on whether each tensor dimension is 0 or 1 and on whether a tensor needs 64-bit indexing (more than ``2**31 - 1`` elements).
This 64-bit indexing specialization can be avoided by setting `index_dtype=torch.int64`.
```
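
For example, a minimal sketch (hypothetical elementwise kernel) of pinning the index dtype so small and very large inputs share one compiled kernel:

```python
import torch
import helion
import helion.language as hl

# With static_shapes=False and an explicit index_dtype, tensors above and
# below the 2**31 - 1 element limit avoid the 64-bit specialization split.
@helion.kernel(static_shapes=False, index_dtype=torch.int64)
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        out[tile] = x[tile] + 1
    return out
```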

### BoundKernel Methods
@@ -139,6 +141,10 @@ Kernels are automatically cached based on:

By default (`static_shapes=True`), Helion treats shapes and strides as compile-time constants, baking them into generated Triton code for the best performance. To reuse a single compiled kernel across size variations, set `static_shapes=False`, which instead buckets each dimension as `{0, 1, ≥2}` and allows more inputs to share the same cache entry.

```{note}
Dynamic-shape buckets also track whether any tensor exceeds the ``torch.int32`` indexing limit, so cache entries diverge as soon as a large input appears. Set ``index_dtype=torch.int64`` on the kernel to avoid this extra split.
```

```python
# These create separate cache entries
tensor_float = torch.randn(100, dtype=torch.float32, device='cuda')
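# Hypothetical extension: with static_shapes=False and no explicit index_dtype,
# crossing the int32 element limit also creates a separate cache entry
# (meta tensors shown just to illustrate sizes without allocating memory).
small = torch.empty((4, 4), device='meta')
large = torch.empty((51200, 51200), device='meta')  # > 2**31 - 1 elements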
8 changes: 5 additions & 3 deletions docs/api/settings.md
@@ -81,8 +81,10 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:

.. autoattribute:: Settings.index_dtype

The data type used for index variables in generated code. Default is ``torch.int32``.
Override via ``HELION_INDEX_DTYPE=int64`` (or any ``torch.<dtype>`` name).
The data type used for index variables in generated code. By default Helion auto-selects
between ``torch.int32`` and ``torch.int64`` based on whether any input tensor exceeds
``torch.iinfo(torch.int32).max`` elements. Override via ``HELION_INDEX_DTYPE=<dtype>``
(or set it to ``auto`` to keep the automatic behavior).
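
For example, a minimal sketch of the two override paths (the environment variable is read when settings are constructed, so set it before defining kernels; names below are illustrative)::

    import os
    import torch
    import helion

    os.environ["HELION_INDEX_DTYPE"] = "int64"  # process-wide override

    @helion.kernel(index_dtype=torch.int64)  # per-kernel override
    def my_kernel(x: torch.Tensor) -> torch.Tensor:
        ...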

.. autoattribute:: Settings.dot_precision

@@ -259,7 +261,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
| Environment Variable | Maps To | Description |
|----------------------|---------|-------------|
| ``TRITON_F32_DEFAULT`` | ``dot_precision`` | Sets default floating-point precision for Triton dot products (``"tf32"``, ``"tf32x3"``, ``"ieee"``). |
| ``HELION_INDEX_DTYPE`` | ``index_dtype`` | Choose the default index dtype (accepts any ``torch.<dtype>`` name, e.g. ``int64``). |
| ``HELION_INDEX_DTYPE`` | ``index_dtype`` | Choose the index dtype (accepts any ``torch.<dtype>`` name, e.g. ``int64``), or set it to ``auto`` (or leave it unset) to let Helion pick ``int32`` vs ``int64`` based on input sizes. |
| ``HELION_STATIC_SHAPES`` | ``static_shapes`` | Set to ``0``/``false`` to disable global static shape specialization. |
| ``HELION_PERSISTENT_RESERVED_SMS`` | ``persistent_reserved_sms`` | Reserve this many streaming multiprocessors when launching persistent kernels (``0`` uses all available SMs). |
| ``HELION_FORCE_AUTOTUNE`` | ``force_autotune`` | Force the autotuner to run even when explicit configs are provided. |
19 changes: 13 additions & 6 deletions docs/deployment_autotuning.md
@@ -154,8 +154,12 @@ determines when to re-benchmark. Options include:
- **`static_shapes=False`:** switch to bucketed dynamic shapes. Helion
reuses results as long as tensor dtypes and device types stay constant.
Shape changes only trigger a re-selection when a dimension size crosses
the buckets `{0, 1, ≥2}`. Use this when you need one compiled kernel to
handle many input sizes.
the buckets `{0, 1, ≥2}`. Helion also tracks whether any tensor exceeds the
`torch.int32` indexing limit (more than ``2**31 - 1`` elements) and
automatically regenerates code with 64-bit indexing in that case. Use this
mode when you need one compiled kernel to handle many input sizes, and pin
the index dtype with ``@helion.kernel(..., index_dtype=torch.int64)`` if
large tensors are the norm, so you avoid an extra specialization boundary
(see the sketch after this list).

- **Custom keys:** pass `key=` to group calls however you like.
This custom key is in addition to the above.
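
For example, assuming a one-argument kernel `k` decorated with
`@helion.kernel(static_shapes=False)`, the `{0, 1, ≥2}` bucketing behaves as
sketched below:

```python
# Sizes >= 2 fall in one bucket, so these calls reuse a single autotuned config:
k(torch.randn(2, device="cuda"))
k(torch.randn(4096, device="cuda"))

# Sizes 0 and 1 each form their own bucket and trigger a fresh selection:
k(torch.randn(1, device="cuda"))
k(torch.randn(0, device="cuda"))
```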
@@ -206,10 +210,13 @@ exact shape/stride signature of the example inputs. The generated code
has shapes baked in, which often provides a performance boost.

- With `static_shapes=False` it will specialize on the input dtypes,
device types, and whether each dynamic dimension falls into the 0, 1,
or ≥2 bucket. Python types are also specialized. For dimensions that
can vary across those buckets, supply representative inputs ≥2 to avoid
excessive specialization.
device types, and whether each dynamic dimension falls into the 0, 1,
or ≥2 bucket. Python types are also specialized. For dimensions that
can vary across those buckets, supply representative inputs ≥2 to avoid
excessive specialization. Just like the autotuning flow above, Helion
records whether any tensor crosses the int32 indexing limit when
`static_shapes=False`; explicitly set `index_dtype=torch.int64` if your
deployment commonly exceeds that threshold to avoid recompilation.

If you need to support multiple input types, bind multiple times with
representative inputs.
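
For example, a minimal sketch (assuming a kernel `my_kernel` defined with
`static_shapes=False`, and that `bind` takes the argument tuple as in the
`BoundKernel` examples earlier):

```python
# Bind once with a representative input; keeping every dimension >= 2 avoids
# accidentally specializing into the 0/1 buckets.
example = torch.randn(4096, device="cuda")
bound = my_kernel.bind((example,))

# Later inputs with the same dtype, device, and buckets reuse the binding.
out = bound(torch.randn(8192, device="cuda"))
```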
4 changes: 1 addition & 3 deletions helion/_compiler/aten_lowering.py
@@ -460,9 +460,7 @@ def codegen_iota(ctx: LoweringContext, node: Node) -> object:
"""Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding."""
start = node.kwargs.get("start", 0)
step = node.kwargs.get("step", 1)
dtype = (
node.kwargs.get("dtype") or CompileEnvironment.current().settings.index_dtype
)
dtype = node.kwargs.get("dtype") or CompileEnvironment.current().index_dtype
assert isinstance(dtype, torch.dtype)
(length_arg,) = node.args # expecting a single argument for length

13 changes: 11 additions & 2 deletions helion/_compiler/compile_environment.py
@@ -73,12 +73,21 @@ class CompileEnvironment:
No config or codegen specific state should be stored here.
"""

def __init__(self, device: torch.device, settings: Settings) -> None:
def __init__(
self,
device: torch.device,
settings: Settings,
*,
index_dtype: torch.dtype | None = None,
) -> None:
from ..autotuner.config_spec import ConfigSpec

super().__init__()
self.device = device
self.settings = settings
self.index_dtype: torch.dtype = (
index_dtype or settings.index_dtype or torch.int32
)
# TODO(jansel): make backend configurable
self.backend = "triton"
self.shape_env = ShapeEnv(
@@ -383,7 +392,7 @@ def known_multiple(self, a: sympy.Expr, b: int | torch.SymInt) -> bool:

def triton_index_type(self) -> str:
"""tl.int32 or tl.int64 depending on Settings()"""
return triton_type(self.settings.index_dtype)
return triton_type(self.index_dtype)

def sympy_debug(self, expr: sympy.Expr) -> str:
return str(expr.xreplace(self.debug_shape_renames))
2 changes: 1 addition & 1 deletion helion/_compiler/indexing_strategy.py
@@ -665,7 +665,7 @@ def create(
env = CompileEnvironment.current()
dtype = env.triton_index_type()
if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)
raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)

def _is_size_one(size: int | torch.SymInt) -> bool:
return env.known_equal(size, 1)
2 changes: 1 addition & 1 deletion helion/_compiler/reduction_strategy.py
@@ -332,7 +332,7 @@ def codegen_reduction(
)
else:
acc_index = self.fn.new_var(f"{state.fx_node.name}_acc_index", dce=True)
index_dtype = CompileEnvironment.current().settings.index_dtype
index_dtype = CompileEnvironment.current().index_dtype
device_loop.outer_prefix.append(
statement_from_string(
f"{acc_index} = tl.full({shape}, {torch.iinfo(index_dtype).max!r}, {triton_type(index_dtype)})"
7 changes: 7 additions & 0 deletions helion/exc.py
@@ -136,6 +136,13 @@ class IndexOffsetOutOfRangeForInt32(BaseError):
)


class InputTensorNumelExceedsIndexType(BaseError):
message = (
"Kernel index_dtype is {index_dtype}, but input input tensor is too large to fit. "
"Use @helion.kernel(index_dtype=torch.int64)."
)


class DataDependentOutputShapeNotSupported(BaseError):
message = (
"{op_desc} is not supported in Helion device loops because it produces "
2 changes: 1 addition & 1 deletion helion/language/creation_ops.py
@@ -209,7 +209,7 @@ def arange(
"""
env = CompileEnvironment.current()
if dtype is None:
dtype = env.settings.index_dtype
dtype = env.index_dtype
return torch.arange(
*args,
**kwargs,
2 changes: 1 addition & 1 deletion helion/language/tile_ops.py
@@ -50,7 +50,7 @@ def _(tile: torch.SymInt) -> torch.Tensor:
assert isinstance(tile, torch.SymInt)
env = CompileEnvironment.current()
assert env.get_block_id(tile) is not None
return torch.empty([tile], dtype=env.settings.index_dtype, device=env.device)
return torch.empty([tile], dtype=env.index_dtype, device=env.device)


@_decorators.codegen(tile_index, "triton")
53 changes: 50 additions & 3 deletions helion/runtime/kernel.py
@@ -15,6 +15,7 @@
from typing import Callable
from typing import Generic
from typing import Hashable
from typing import Sequence
from typing import TypeVar
from typing import cast
from typing import overload
@@ -27,6 +28,7 @@
from torch._inductor.codecache import PyCodeCache
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._subclasses import FakeTensor
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary

from .. import exc
@@ -64,6 +66,36 @@
# Cache for GraphModule hashes
_graph_module_hash_cache: WeakIdKeyDictionary = WeakIdKeyDictionary()

_INT32_INDEX_LIMIT = torch.iinfo(torch.int32).max


def _resolve_index_dtype(
settings: Settings,
args: Sequence[object] | tuple[object, ...],
) -> torch.dtype:
if (index_dtype := settings.index_dtype) is not None:
limit = torch.iinfo(index_dtype).max
else:
limit = _INT32_INDEX_LIMIT
over_limit = False

def _check(tensor: torch.Tensor) -> None:
nonlocal over_limit
if over_limit:
return
try:
over_limit = bool(tensor.numel() > limit)
except RuntimeError: # unbacked SymInt
if index_dtype is None:
over_limit = True

tree_map_only(torch.Tensor, _check, args)
if index_dtype is None: # Auto-select when not provided
return torch.int64 if over_limit else torch.int32
if over_limit:
raise exc.InputTensorNumelExceedsIndexType(index_dtype=index_dtype)
return index_dtype


class Kernel(Generic[_R]):
def __init__(
@@ -321,7 +353,11 @@ def __init__(
self._run: Callable[..., _R] | None = None
self._config: Config | None = None
self._compile_cache: dict[Config, CompiledConfig] = {}
self.env = CompileEnvironment(_find_device(args), self.kernel.settings)
self.env = CompileEnvironment(
_find_device(args),
self.kernel.settings,
index_dtype=_resolve_index_dtype(self.kernel.settings, args),
)

if is_ref_mode_enabled(self.kernel.settings):
self.fake_args = [] # type: ignore[assignment]
@@ -830,11 +866,22 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
(*obj.size(),),
(*obj.stride(),),
)
bucketed = tuple([min(s, 2) for s in obj.size()])
if fn.settings.index_dtype is None:
try:
needs_int64 = bool(obj.numel() > _INT32_INDEX_LIMIT)
except RuntimeError:
needs_int64 = True # unbacked SymInt
return (
obj.dtype,
obj.device.type,
bucketed,
needs_int64,
)
return (
obj.dtype,
obj.device.type,
# 0, 1, or >=2 specialization
tuple([min(s, 2) for s in obj.size()]),
bucketed,
)


14 changes: 9 additions & 5 deletions helion/runtime/settings.py
@@ -138,10 +138,12 @@ def _env_get_str(var_name: str, default: str) -> str:
return value


def _get_index_dtype() -> torch.dtype:
def _get_index_dtype() -> torch.dtype | None:
value = os.environ.get("HELION_INDEX_DTYPE")
if value is None or (token := value.strip()) == "":
return torch.int32
return None
if token.lower() == "auto":
return None
try:
dtype = getattr(torch, token)
except AttributeError as err:
@@ -266,7 +268,9 @@ class _Settings:
ignore_warnings: list[type[exc.BaseWarning]] = dataclasses.field(
default_factory=_get_ignore_warnings
)
index_dtype: torch.dtype = dataclasses.field(default_factory=_get_index_dtype)
index_dtype: torch.dtype | None = dataclasses.field(
default_factory=_get_index_dtype
)
dot_precision: DotPrecision = dataclasses.field(
default_factory=functools.partial(
_env_get_literal,
@@ -407,8 +411,8 @@ class Settings(_Settings):
"Set HELION_IGNORE_WARNINGS=WarningA,WarningB (names from helion.exc) to configure via env."
),
"index_dtype": (
"The dtype to use for index variables. Default is torch.int32. "
"Override with HELION_INDEX_DTYPE=torch.int64, etc."
"The dtype to use for index variables. Default auto-selects torch.int32 or torch.int64 based on input sizes. "
"Override with HELION_INDEX_DTYPE=<dtype> (or set to 'auto')."
),
"dot_precision": "Precision for dot products, see `triton.language.dot`. Can be 'tf32', 'tf32x3', or 'ieee'.",
"static_shapes": (
51 changes: 42 additions & 9 deletions test/test_indexing.py
@@ -415,9 +415,10 @@ def test_int32_offset_out_of_range_error(self):
range_warp_specializes=[],
)

def make_kernel(*, index_dtype: torch.dtype):
def make_kernel(*, index_dtype: torch.dtype | None = None):
kwargs = {"config": repro_config, "static_shapes": True}
kwargs["index_dtype"] = index_dtype
if index_dtype is not None:
kwargs["index_dtype"] = index_dtype
decorator = helion.kernel(**kwargs)

@decorator
@@ -435,15 +436,19 @@ def repro_bf16_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return repro_bf16_add

def run_case(
shape, *, index_dtype, expect_int64_in_code=False, expect_error=False
):
shape,
*,
index_dtype: torch.dtype | None,
expect_int64_in_code: bool = False,
expect_error: type[Exception] | None = None,
) -> None:
kernel = make_kernel(index_dtype=index_dtype)
x = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
y = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
torch.accelerator.synchronize()
if expect_error:
if expect_error is not None:
with self.assertRaisesRegex(
helion.exc.IndexOffsetOutOfRangeForInt32,
expect_error,
f"index_dtype is {index_dtype}",
):
code_and_output(kernel, (x, y))
@@ -479,19 +484,47 @@ def run_case(
small_shape,
index_dtype=torch.int32,
expect_int64_in_code=False,
expect_error=False,
expect_error=None,
)
run_case(
large_shape,
index_dtype=torch.int32,
expect_int64_in_code=False,
expect_error=True,
expect_error=helion.exc.InputTensorNumelExceedsIndexType,
)
run_case(
large_shape,
index_dtype=torch.int64,
expect_int64_in_code=True,
expect_error=False,
expect_error=None,
)
run_case(
large_shape,
index_dtype=None,
expect_int64_in_code=True,
expect_error=None,
)

def test_dynamic_shape_specialization_key_tracks_large_tensors(self) -> None:
@helion.kernel(static_shapes=False)
def passthrough(x: torch.Tensor) -> torch.Tensor:
return x

@helion.kernel(static_shapes=False, index_dtype=torch.int64)
def passthrough_int64(x: torch.Tensor) -> torch.Tensor:
return x

meta = "meta"
small = torch.empty((4, 4), device=meta)
large = torch.empty((51200, 51200), device=meta)

self.assertNotEqual(
passthrough.specialization_key((small,)),
passthrough.specialization_key((large,)),
)
self.assertEqual(
passthrough_int64.specialization_key((small,)),
passthrough_int64.specialization_key((large,)),
)

def test_assign_int(self):
Expand Down