docs/api/settings.md (+6 -0)

@@ -106,6 +106,11 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:

Force autotuning even when explicit configs are provided. Default is ``False``. Controlled by ``HELION_FORCE_AUTOTUNE=1``.

.. autoattribute:: Settings.autotune_force_persistent

Restrict ``pid_type`` choices to the persistent strategies (``"persistent_blocked"`` or ``"persistent_interleaved"``).
Default is ``False``. Enable globally with ``HELION_AUTOTUNE_FORCE_PERSISTENT=1`` or per kernel via ``@helion.kernel(..., autotune_force_persistent=True)``.

.. autoattribute:: Settings.autotune_log_level

Controls verbosity of autotuning output using Python logging levels:
@@ -258,6 +263,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
| ``HELION_STATIC_SHAPES`` | ``static_shapes`` | Set to ``0``/``false`` to disable global static shape specialization. |
| ``HELION_PERSISTENT_RESERVED_SMS`` | ``persistent_reserved_sms`` | Reserve this many streaming multiprocessors when launching persistent kernels (``0`` uses all available SMs). |
| ``HELION_FORCE_AUTOTUNE`` | ``force_autotune`` | Force the autotuner to run even when explicit configs are provided. |
| ``HELION_AUTOTUNE_FORCE_PERSISTENT`` | ``autotune_force_persistent`` | Restrict ``pid_type`` to persistent kernel strategies during config search. |
| ``HELION_DISALLOW_AUTOTUNING`` | ``check_autotuning_disabled`` | Hard-disable autotuning; kernels must supply explicit configs when this is ``1``. |
| ``HELION_AUTOTUNE_COMPILE_TIMEOUT`` | ``autotune_compile_timeout`` | Maximum seconds to wait for Triton compilation during autotuning. |
| ``HELION_AUTOTUNE_LOG_LEVEL`` | ``autotune_log_level`` | Adjust logging verbosity; accepts names like ``INFO`` or numeric levels. |
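For context, a minimal sketch of the two opt-in paths the new docs describe. The kernel body follows the add-style example from the Helion docs and is an illustration, not part of this PR:

```python
import torch

import helion
import helion.language as hl


# Per-kernel opt-in: the autotuner will only explore the persistent
# pid_type strategies ("persistent_blocked", "persistent_interleaved").
@helion.kernel(autotune_force_persistent=True)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out
```

The same restriction can be applied process-wide with `HELION_AUTOTUNE_FORCE_PERSISTENT=1`, without touching any kernel definitions.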
helion/_compiler/compile_environment.py (+3 -0)

@@ -92,6 +92,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
self.block_sizes: list[BlockSizeInfo] = []
self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
self.config_spec = ConfigSpec()
if settings.autotune_force_persistent:
for pid_type in ("flat", "xyz"):
self.config_spec.disallow_pid_type(pid_type)
self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
collections.Counter()
)
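The guard works by subtraction: rather than replacing the spec's pid_type list, it disallows the two non-persistent strategies and lets the rest stand. A rough, self-contained model of that filtering, assuming the full set of pid_type values is the four names taken from this PR and its test:

```python
# Hypothetical stand-in for ConfigSpec's pid_type tracking; only the
# four strategy names are taken from the PR and its test below.
ALL_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")


def restrict_to_persistent(allowed: tuple[str, ...]) -> tuple[str, ...]:
    # Mirrors the loop above: drop "flat" and "xyz", keep everything else.
    return tuple(t for t in allowed if t not in ("flat", "xyz"))


assert restrict_to_persistent(ALL_PID_TYPES) == (
    "persistent_blocked",
    "persistent_interleaved",
)
```

Disallowing rather than overwriting means any pid_type the spec had already ruled out for other reasons stays ruled out.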
helion/runtime/settings.py (+11 -0)

@@ -285,6 +285,13 @@ class _Settings:
0,
)
)
autotune_force_persistent: bool = dataclasses.field(
default_factory=functools.partial(
_env_get_bool,
"HELION_AUTOTUNE_FORCE_PERSISTENT",
False,
)
)
autotune_log_level: int = dataclasses.field(default_factory=_get_autotune_log_level)
autotune_log: str | None = dataclasses.field(default_factory=_get_autotune_log_path)
autotune_compile_timeout: int = dataclasses.field(
@@ -412,6 +419,10 @@ class Settings(_Settings):
"Number of streaming multiprocessors to reserve when launching persistent kernels. "
"Set HELION_PERSISTENT_RESERVED_SMS=N (default 0) or pass persistent_reserved_sms=N to helion.kernel."
),
"autotune_force_persistent": (
"If True, restrict pid_type choices to persistent kernels only during config selection. "
"Set HELION_AUTOTUNE_FORCE_PERSISTENT=1 to force persistent kernel autotuning globally."
),
"autotune_log_level": (
"Log level for autotuning using Python logging levels. Default is logging.INFO. "
"Use HELION_AUTOTUNE_LOG_LEVEL to override or set 0 to disable output."
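Because the default comes from `default_factory`, the environment variable is consulted only when the field is not passed explicitly; `helion.Settings(autotune_force_persistent=True)` wins regardless of the environment. A sketch of how a `_env_get_bool` helper plausibly behaves — the real implementation is not shown in this diff, and the accepted truthy spellings are an assumption:

```python
import os


def _env_get_bool(name: str, default: bool) -> bool:
    # Assumed behavior: missing variable -> default; otherwise treat
    # common truthy spellings as True and anything else as False.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")


os.environ["HELION_AUTOTUNE_FORCE_PERSISTENT"] = "1"
assert _env_get_bool("HELION_AUTOTUNE_FORCE_PERSISTENT", False) is True
```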
test/test_config_api.py (+23 -0)

@@ -2,15 +2,19 @@

import importlib
import inspect
import os
import pickle
from typing import Any
import unittest
from unittest.mock import patch

from hypothesis import given
from hypothesis import settings
from hypothesis import strategies as st
import torch

import helion
from helion._compiler.compile_environment import CompileEnvironment
from helion._testing import TestCase


@@ -232,5 +236,24 @@ def test_pre_serialized_json_backward_compat(self) -> None:
self.assertEqual(dict(reread), expected)


class TestSettingsEnv(TestCase):
def test_persistent_reserved_sms_env_var(self) -> None:
with patch.dict(
os.environ,
{"HELION_PERSISTENT_RESERVED_SMS": "5"},
clear=False,
):
settings = helion.Settings()
self.assertEqual(settings.persistent_reserved_sms, 5)

def test_autotune_force_persistent_limits_config_spec(self) -> None:
settings = helion.Settings(autotune_force_persistent=True)
env = CompileEnvironment(torch.device("cpu"), settings)
self.assertEqual(
env.config_spec.allowed_pid_types,
("persistent_blocked", "persistent_interleaved"),
)


if __name__ == "__main__":
unittest.main()