diff --git a/docs/api/settings.md b/docs/api/settings.md
index a9a28c0f4..03a54ebdf 100644
--- a/docs/api/settings.md
+++ b/docs/api/settings.md
@@ -37,6 +37,13 @@ Settings can be configured via:
 
 If both are provided, decorator arguments take precedence.
 
+```{note}
+Helion reads the environment variables for `Settings` when the
+`@helion.kernel` decorator is applied to the function (typically at
+import time). You can modify `Kernel.settings` to change settings
+for an already-defined kernel.
+```
+
 ## Configuration Examples
 
 ### Using Environment Variables
@@ -74,6 +81,7 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 .. autoattribute:: Settings.index_dtype
 
     The data type used for index variables in generated code. Default is ``torch.int32``.
+    Override via ``HELION_INDEX_DTYPE=int64`` (or any ``torch.*`` dtype name).
 
 .. autoattribute:: Settings.dot_precision
 
@@ -81,7 +89,8 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
 .. autoattribute:: Settings.static_shapes
 
-    When enabled, tensor shapes are treated as compile-time constants for optimization. Default is ``True``. Set this to ``False`` if you need a single compiled kernel instance to serve many shape variants.
+    When enabled, tensor shapes are treated as compile-time constants for optimization. Default is ``True``.
+    Set ``HELION_STATIC_SHAPES=0`` to override the default if you need a single compiled kernel instance to serve many shape variants.
 ```
 
 ### Autotuning Settings
@@ -100,7 +109,7 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
     - ``logging.INFO``: Standard progress messages (default)
    - ``logging.DEBUG``: Verbose debugging output
 
-    You can also use ``0`` to completely disable all autotuning output.
+    You can also use ``0`` to completely disable all autotuning output. Set via the ``HELION_AUTOTUNE_LOG_LEVEL`` environment variable.
 
 .. autoattribute:: Settings.autotune_compile_timeout
@@ -150,6 +159,7 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 .. autoattribute:: Settings.autotune_config_overrides
 
     Dict of config key/value pairs to force during autotuning. Useful for disabling problematic candidates or pinning experimental options.
+    Provide JSON via ``HELION_AUTOTUNE_CONFIG_OVERRIDES='{"num_warps": 4}'`` for global overrides.
 
 .. autoattribute:: Settings.autotune_effort
@@ -183,6 +193,7 @@ See :class:`helion.autotuner.LocalAutotuneCache` for details on cache keys and b
 .. autoattribute:: Settings.ignore_warnings
 
     List of warning types to suppress during compilation. Default is an empty list.
+    Accepts comma-separated warning class names from ``helion.exc`` via ``HELION_IGNORE_WARNINGS`` (for example, ``HELION_IGNORE_WARNINGS=TensorOperationInWrapper``).
 
 .. autoattribute:: Settings.debug_dtype_asserts
@@ -207,6 +218,7 @@ See :class:`helion.autotuner.LocalAutotuneCache` for details on cache keys and b
 .. autoattribute:: Settings.autotuner_fn
 
     Override the callable that constructs autotuner instances. Accepts the same signature as :func:`helion.runtime.settings.default_autotuner_fn`.
+    Pass a replacement callable via ``@helion.kernel(..., autotuner_fn=...)`` or ``helion.kernel(autotuner_fn=...)`` at definition time.
 ```
 
 Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"DifferentialEvolutionSearch"``, ``"FiniteSearch"``, and ``"RandomSearch"``.
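+
+For example, to supply a custom autotuner at definition time (a sketch; it
+assumes ``RandomSearch`` is importable from ``helion.autotuner`` and accepts
+``(bound_kernel, args)``, matching how
+:func:`helion.runtime.settings.default_autotuner_fn` is called):
+
+```python
+import torch
+import helion
+from helion.autotuner import RandomSearch
+
+
+def pick_autotuner(bound_kernel, args, **kwargs):
+    # Same signature as helion.runtime.settings.default_autotuner_fn.
+    return RandomSearch(bound_kernel, args)
+
+
+@helion.kernel(autotuner_fn=pick_autotuner)
+def my_kernel(x: torch.Tensor) -> torch.Tensor:
+    ...
+```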
@@ -222,9 +234,12 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"Differe
 | Environment Variable | Maps To | Description |
 |----------------------|---------|-------------|
 | ``TRITON_F32_DEFAULT`` | ``dot_precision`` | Sets default floating-point precision for Triton dot products (``"tf32"``, ``"tf32x3"``, ``"ieee"``). |
+| ``HELION_INDEX_DTYPE`` | ``index_dtype`` | Choose the default index dtype (accepts any ``torch.*`` dtype name, e.g. ``int64``). |
+| ``HELION_STATIC_SHAPES`` | ``static_shapes`` | Set to ``0``/``false`` to disable global static shape specialization. |
 | ``HELION_FORCE_AUTOTUNE`` | ``force_autotune`` | Force the autotuner to run even when explicit configs are provided. |
 | ``HELION_DISALLOW_AUTOTUNING`` | ``check_autotuning_disabled`` | Hard-disable autotuning; kernels must supply explicit configs when this is ``1``. |
 | ``HELION_AUTOTUNE_COMPILE_TIMEOUT`` | ``autotune_compile_timeout`` | Maximum seconds to wait for Triton compilation during autotuning. |
+| ``HELION_AUTOTUNE_LOG_LEVEL`` | ``autotune_log_level`` | Adjust logging verbosity; accepts names like ``INFO`` or numeric levels. |
 | ``HELION_AUTOTUNE_PRECOMPILE`` | ``autotune_precompile`` | Select the autotuner precompile mode (``"spawn"``, ``"fork"``, or disable when empty). |
 | ``HELION_AUTOTUNE_PRECOMPILE_JOBS`` | ``autotune_precompile_jobs`` | Cap the number of concurrent Triton precompile subprocesses. |
 | ``HELION_AUTOTUNE_RANDOM_SEED`` | ``autotune_random_seed`` | Seed used for randomized autotuning searches. |
@@ -234,9 +249,11 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"Differe
 | ``HELION_REBENCHMARK_THRESHOLD`` | ``autotune_rebenchmark_threshold`` | Re-run configs whose performance is within a multiplier of the current best. |
 | ``HELION_AUTOTUNE_PROGRESS_BAR`` | ``autotune_progress_bar`` | Enable or disable the progress bar UI during autotuning. |
 | ``HELION_AUTOTUNE_IGNORE_ERRORS`` | ``autotune_ignore_errors`` | Continue autotuning even when recoverable runtime errors occur. |
+| ``HELION_AUTOTUNE_CONFIG_OVERRIDES`` | ``autotune_config_overrides`` | Supply JSON forcing particular autotuner config key/value pairs. |
 | ``HELION_CACHE_DIR`` | ``LocalAutotuneCache`` | Override the on-disk directory used for cached autotuning artifacts. |
 | ``HELION_SKIP_CACHE`` | ``LocalAutotuneCache`` | When set to ``1``, ignore cached autotuning entries and rerun searches. |
 | ``HELION_PRINT_OUTPUT_CODE`` | ``print_output_code`` | Print generated Triton code to stderr for inspection. |
+| ``HELION_IGNORE_WARNINGS`` | ``ignore_warnings`` | Comma-separated warning names defined in ``helion.exc`` to suppress. |
 | ``HELION_ALLOW_WARP_SPECIALIZE`` | ``allow_warp_specialize`` | Permit warp-specialized code generation for ``tl.range``. |
 | ``HELION_DEBUG_DTYPE_ASSERTS`` | ``debug_dtype_asserts`` | Inject dtype assertions after each lowering step. |
 | ``HELION_INTERPRET`` | ``ref_mode`` | Run kernels through the reference interpreter when set to ``1`` (maps to ``RefMode.EAGER``). |
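+
+Because these variables are read when ``@helion.kernel`` runs, set them before
+importing the module that defines your kernels. A sketch (``my_kernels`` is a
+hypothetical module):
+
+```python
+import os
+
+# Must happen before the @helion.kernel decorators execute.
+os.environ["HELION_STATIC_SHAPES"] = "0"
+os.environ["HELION_AUTOTUNE_LOG_LEVEL"] = "DEBUG"
+
+import my_kernels  # noqa: E402  (kernels defined here pick up the overrides)
+```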
diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index d4e4315d5..9f794db45 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import dataclasses
+import functools
+import json
 import logging
 import os
 import time
@@ -8,6 +10,7 @@
 from typing import Literal
 from typing import Protocol
 from typing import Sequence
+from typing import TypeVar
 from typing import cast
 
 import torch
@@ -22,12 +25,161 @@ from ..autotuner.base_search import BaseAutotuner
     from .kernel import BoundKernel
 
 
+_T = TypeVar("_T")
+
+
 class AutotunerFunction(Protocol):
     def __call__(
         self, bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
     ) -> BaseAutotuner: ...
 
 
+DotPrecision = Literal["tf32", "tf32x3", "ieee"]
+PrecompileMode = Literal["spawn", "fork"] | None
+_TRUE_LITERALS = frozenset({"1", "true", "yes", "on"})
+_FALSE_LITERALS = frozenset({"0", "false", "no", "off"})
+
+
+def _resolve_warning_name(name: str) -> type[exc.BaseWarning]:
+    attr = name.strip()
+    if not attr:
+        raise ValueError("HELION_IGNORE_WARNINGS entries must be non-empty names")
+    try:
+        warning_cls = getattr(exc, attr)
+    except AttributeError as err:
+        raise ValueError(
+            f"HELION_IGNORE_WARNINGS entry {name!r} is not a warning defined in helion.exc"
+        ) from err
+    if not isinstance(warning_cls, type) or not issubclass(
+        warning_cls, exc.BaseWarning
+    ):
+        raise ValueError(
+            f"HELION_IGNORE_WARNINGS entry {name!r} does not refer to a helion.exc.BaseWarning subclass"
+        )
+    return warning_cls
+
+
+def _get_ignore_warnings() -> list[type[exc.BaseWarning]]:
+    value = os.environ.get("HELION_IGNORE_WARNINGS")
+    if not value:
+        return []
+    result: list[type[exc.BaseWarning]] = []
+    for entry in value.split(","):
+        entry = entry.strip()
+        if not entry:
+            continue
+        result.append(_resolve_warning_name(entry))
+    return result
+
+
+def _env_get_optional_int(var_name: str) -> int | None:
+    value = os.environ.get(var_name)
+    if value is None or (value := value.strip()) == "":
+        return None
+    try:
+        parsed = int(value)
+    except ValueError as err:
+        raise ValueError(f"{var_name} must be an integer, got {value!r}") from err
+    return parsed
+
+
+def _env_get_int(var_name: str, default: int) -> int:
+    result = _env_get_optional_int(var_name)
+    return default if result is None else result
+
+
+def _env_get_optional_float(var_name: str) -> float | None:
+    value = os.environ.get(var_name)
+    if value is None or (value := value.strip()) == "":
+        return None
+    try:
+        return float(value)
+    except ValueError as err:
+        raise ValueError(f"{var_name} must be a float, got {value!r}") from err
+
+
+def _env_get_bool(var_name: str, default: bool) -> bool:
+    value = os.environ.get(var_name)
+    if value is None or (value := value.strip()) == "":
+        return default
+    lowered = value.lower()
+    if lowered in _TRUE_LITERALS:
+        return True
+    if lowered in _FALSE_LITERALS:
+        return False
+    raise ValueError(
+        f"{var_name} must be one of {_TRUE_LITERALS | _FALSE_LITERALS}, got {value!r}"
+    )
+
+
+def _env_get_literal(
+    var_name: str,
+    default: _T,
+    *,
+    mapping: dict[str, object],
+) -> _T:
+    value = os.environ.get(var_name)
+    if value is None:
+        return default
+    value = value.strip()
+    if value in mapping:
+        return cast("_T", mapping[value])
+    if value == "":
+        return default
+    raise ValueError(
+        f"{var_name} must be one of {', '.join(sorted(mapping))}, got {value!r}"
+    )
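+
+
+# Example (illustrative): _env_get_bool("HELION_STATIC_SHAPES", True) returns
+# True when the variable is unset or empty, False for "0"/"false"/"no"/"off",
+# and raises ValueError for any other token.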
+
+
+def _get_index_dtype() -> torch.dtype:
+    value = os.environ.get("HELION_INDEX_DTYPE")
+    if value is None or (token := value.strip()) == "":
+        return torch.int32
+    try:
+        dtype = getattr(torch, token)
+    except AttributeError as err:
+        raise ValueError(
+            f"HELION_INDEX_DTYPE must map to a torch dtype attribute, got {value!r}"
+        ) from err
+    if not isinstance(dtype, torch.dtype):
+        raise ValueError(f"HELION_INDEX_DTYPE {value!r} is not a torch.dtype")
+    return dtype
+
+
+def _get_autotune_log_level() -> int:
+    value = os.environ.get("HELION_AUTOTUNE_LOG_LEVEL")
+    if value is None or value.strip() == "":
+        return logging.INFO
+    text = value.strip()
+    if text.lstrip("+-").isdigit():
+        return int(text)
+    upper = text.upper()
+    level = logging.getLevelName(upper)
+    if isinstance(level, int):
+        return level
+    raise ValueError(
+        f"HELION_AUTOTUNE_LOG_LEVEL must be an integer or logging level name, got {value!r}"
+    )
+
+
+def _get_autotune_config_overrides() -> dict[str, object]:
+    value = os.environ.get("HELION_AUTOTUNE_CONFIG_OVERRIDES")
+    if not value or (value := value.strip()) == "":
+        return {}
+    if not value.startswith("{") and os.path.exists(value):
+        with open(value) as f:
+            value = f.read()
+    try:
+        parsed = json.loads(value)
+    except json.JSONDecodeError as err:
+        raise ValueError(
+            "HELION_AUTOTUNE_CONFIG_OVERRIDES must be valid JSON mapping of config keys to values"
+        ) from err
+    if not isinstance(parsed, dict):
+        raise ValueError(
+            "HELION_AUTOTUNE_CONFIG_OVERRIDES must decode to a JSON dictionary"
+        )
+    return parsed
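+
+
+# Example (illustrative): HELION_AUTOTUNE_CONFIG_OVERRIDES='{"num_warps": 4}'
+# parses to {"num_warps": 4}; a value that does not start with "{" and names
+# an existing file is read from disk and parsed as JSON instead.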
+
+
 def default_autotuner_fn(
     bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
 ) -> BaseAutotuner:
@@ -74,113 +226,119 @@ def default_autotuner_fn(
 
 
 def _get_autotune_random_seed() -> int:
-    value = os.environ.get("HELION_AUTOTUNE_RANDOM_SEED")
-    if value is not None:
-        return int(value)
+    if (seed := _env_get_optional_int("HELION_AUTOTUNE_RANDOM_SEED")) is not None:
+        return seed
     return int(time.time() * 1000) % 2**32
 
 
-def _get_autotune_max_generations() -> int | None:
-    value = os.environ.get("HELION_AUTOTUNE_MAX_GENERATIONS")
-    if value is not None:
-        return int(value)
-    return None
-
-
-def _get_autotune_rebenchmark_threshold() -> float | None:
-    value = os.environ.get("HELION_REBENCHMARK_THRESHOLD")
-    if value is not None:
-        return float(value)
-    return None  # Will use effort profile default
-
-
-def _get_autotune_effort() -> AutotuneEffort:
-    return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full"))
-
-
-def _get_autotune_precompile() -> str | None:
-    value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE")
-    if value is None:
-        return "spawn"
-    mode = value.strip().lower()
-    if mode in {"", "0", "false", "none"}:
-        return None
-    if mode in {"spawn", "fork"}:
-        return mode
-    raise ValueError(
-        "HELION_AUTOTUNE_PRECOMPILE must be 'spawn', 'fork', or empty to disable precompile"
-    )
-
-
-def _get_autotune_precompile_jobs() -> int | None:
-    value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS")
-    if value is None or value.strip() == "":
-        return None
-    jobs = int(value)
-    if jobs <= 0:
-        raise ValueError("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer")
-    return jobs
-
-
-def _get_autotune_ignore_errors() -> bool:
-    return os.environ.get("HELION_AUTOTUNE_IGNORE_ERRORS", "0") == "1"
+def _get_ref_mode() -> RefMode:
+    interpret = _env_get_bool("HELION_INTERPRET", False)
+    return RefMode.EAGER if interpret else RefMode.OFF
 
 
 @dataclasses.dataclass
 class _Settings:
     # see __slots__ below for the doc strings that show up in help(Settings)
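+    # Note: each default_factory below runs when a Settings instance is
+    # constructed, so the environment is re-read for every new Settings().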
     ignore_warnings: list[type[exc.BaseWarning]] = dataclasses.field(
-        default_factory=list
-    )
-    index_dtype: torch.dtype = torch.int32
-    dot_precision: Literal["tf32", "tf32x3", "ieee"] = cast(
-        "Literal['tf32', 'tf32x3', 'ieee']",
-        os.environ.get("TRITON_F32_DEFAULT", "tf32"),
+        default_factory=_get_ignore_warnings
     )
-    static_shapes: bool = True
-    autotune_log_level: int = logging.INFO
-    autotune_compile_timeout: int = int(
-        os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60")
+    index_dtype: torch.dtype = dataclasses.field(default_factory=_get_index_dtype)
+    dot_precision: DotPrecision = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_literal,
+            "TRITON_F32_DEFAULT",
+            cast("DotPrecision", "tf32"),
+            mapping={k: k for k in ("tf32", "tf32x3", "ieee")},
+        )
+    )  # pyright: ignore[reportAssignmentType]
+    static_shapes: bool = dataclasses.field(
+        default_factory=functools.partial(_env_get_bool, "HELION_STATIC_SHAPES", True)
     )
-    autotune_precompile: str | None = dataclasses.field(
-        default_factory=_get_autotune_precompile
+    autotune_log_level: int = dataclasses.field(default_factory=_get_autotune_log_level)
+    autotune_compile_timeout: int = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_int, "HELION_AUTOTUNE_COMPILE_TIMEOUT", 60
+        )
     )
+    autotune_precompile: PrecompileMode = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_literal,
+            "HELION_AUTOTUNE_PRECOMPILE",
+            cast("PrecompileMode", "spawn"),
+            mapping={
+                "spawn": "spawn",
+                "fork": "fork",
+                "": None,
+                "0": None,
+            },
+        )
+    )  # pyright: ignore[reportAssignmentType]
     autotune_precompile_jobs: int | None = dataclasses.field(
-        default_factory=_get_autotune_precompile_jobs
+        default_factory=functools.partial(
+            _env_get_optional_int,
+            "HELION_AUTOTUNE_PRECOMPILE_JOBS",
+        )
     )
     autotune_random_seed: int = dataclasses.field(
         default_factory=_get_autotune_random_seed
     )
-    autotune_accuracy_check: bool = (
-        os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1"
+    autotune_accuracy_check: bool = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_bool, "HELION_AUTOTUNE_ACCURACY_CHECK", True
+        )
     )
     autotune_rebenchmark_threshold: float | None = dataclasses.field(
-        default_factory=_get_autotune_rebenchmark_threshold
+        default_factory=functools.partial(
+            _env_get_optional_float,
+            "HELION_REBENCHMARK_THRESHOLD",
+        )
     )
-    autotune_progress_bar: bool = (
-        os.environ.get("HELION_AUTOTUNE_PROGRESS_BAR", "1") == "1"
+    autotune_progress_bar: bool = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_bool, "HELION_AUTOTUNE_PROGRESS_BAR", True
+        )
    )
     autotune_max_generations: int | None = dataclasses.field(
-        default_factory=_get_autotune_max_generations
+        default_factory=functools.partial(
+            _env_get_optional_int,
+            "HELION_AUTOTUNE_MAX_GENERATIONS",
+        )
     )
     autotune_ignore_errors: bool = dataclasses.field(
-        default_factory=_get_autotune_ignore_errors
+        default_factory=functools.partial(
+            _env_get_bool, "HELION_AUTOTUNE_IGNORE_ERRORS", False
+        )
+    )
+    print_output_code: bool = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_bool, "HELION_PRINT_OUTPUT_CODE", False
+        )
+    )
+    force_autotune: bool = dataclasses.field(
+        default_factory=functools.partial(_env_get_bool, "HELION_FORCE_AUTOTUNE", False)
     )
-    print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
-    force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
     autotune_config_overrides: dict[str, object] = dataclasses.field(
-        default_factory=dict
+        default_factory=_get_autotune_config_overrides
     )
     autotune_effort: AutotuneEffort = dataclasses.field(
-        default_factory=_get_autotune_effort
-    )
-    allow_warp_specialize: bool = (
-        os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1"
+        default_factory=functools.partial(
+            _env_get_literal,
+            "HELION_AUTOTUNE_EFFORT",
+            cast("AutotuneEffort", "full"),
+            mapping={key: key for key in ("none", "quick", "full")},
+        )
+    )  # pyright: ignore[reportAssignmentType]
+    allow_warp_specialize: bool = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_bool, "HELION_ALLOW_WARP_SPECIALIZE", True
+        )
     )
-    debug_dtype_asserts: bool = os.environ.get("HELION_DEBUG_DTYPE_ASSERTS", "0") == "1"
-    ref_mode: RefMode = (
-        RefMode.EAGER if os.environ.get("HELION_INTERPRET", "") == "1" else RefMode.OFF
+    debug_dtype_asserts: bool = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False
+        )
     )
+    ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
     autotuner_fn: AutotunerFunction = default_autotuner_fn
 
 
@@ -191,11 +349,23 @@ class Settings(_Settings):
     """
 
     __slots__ = {
-        "ignore_warnings": "Subtypes of exc.BaseWarning to ignore when compiling.",
-        "index_dtype": "The dtype to use for index variables. Default is torch.int32.",
+        "ignore_warnings": (
+            "Subtypes of exc.BaseWarning to ignore when compiling. "
+            "Set HELION_IGNORE_WARNINGS=WarningA,WarningB (names from helion.exc) to configure via env."
+        ),
+        "index_dtype": (
+            "The dtype to use for index variables. Default is torch.int32. "
+            "Override with HELION_INDEX_DTYPE=int64, etc."
+        ),
         "dot_precision": "Precision for dot products, see `triton.language.dot`. Can be 'tf32', 'tf32x3', or 'ieee'.",
-        "static_shapes": "If True, use static shapes for all tensors. This is a performance optimization.",
-        "autotune_log_level": "Log level for autotuning using Python logging levels. Default is logging.INFO. Use 0 to disable all output.",
+        "static_shapes": (
+            "If True, use static shapes for all tensors. This is a performance optimization. "
+            "Set HELION_STATIC_SHAPES=0 to disable."
+        ),
+        "autotune_log_level": (
+            "Log level for autotuning using Python logging levels. Default is logging.INFO. "
+            "Use HELION_AUTOTUNE_LOG_LEVEL to override or set 0 to disable output."
+        ),
         "autotune_compile_timeout": "Timeout for Triton compilation in seconds used for autotuning. Default is 60 seconds.",
         "autotune_precompile": "Autotuner precompile mode: 'spawn', 'fork', or falsy/None to disable. Defaults to 'spawn' on non-Windows platforms.",
         "autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.",
@@ -210,11 +380,17 @@ class Settings(_Settings):
         ),
         "print_output_code": "If True, print the output code of the kernel to stderr.",
         "force_autotune": "If True, force autotuning even if a config is provided.",
-        "autotune_config_overrides": "Dictionary of config key/value pairs forced during autotuning.",
+        "autotune_config_overrides": (
+            "Dictionary of config key/value pairs forced during autotuning. "
+            "Accepts HELION_AUTOTUNE_CONFIG_OVERRIDES='{\"num_warps\":4}'."
+        ),
         "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
         "debug_dtype_asserts": "If True, emit tl.static_assert checks for dtype after each device node.",
         "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.",
-        "autotuner_fn": "Function to create an autotuner",
+        "autotuner_fn": (
+            "Function to create an autotuner. "
+            "Override by passing a callable to @helion.kernel(..., autotuner_fn=...)."
+        ),
         "autotune_effort": "Autotuning effort preset. One of 'none', 'quick', 'full'.",
     }