From 6960b4c874e9e0a223363ced1cf8deb51273bd3e Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Thu, 13 Nov 2025 15:55:36 -0800
Subject: [PATCH] Add Settings.autotune_force_persistent

stack-info: PR: https://github.com/pytorch/helion/pull/1130, branch: jansel/stack/226
---
 docs/api/settings.md                    |  6 ++++++
 helion/_compiler/compile_environment.py |  3 +++
 helion/runtime/settings.py              | 11 +++++++++++
 test/test_config_api.py                 | 23 +++++++++++++++++++++++
 4 files changed, 43 insertions(+)

diff --git a/docs/api/settings.md b/docs/api/settings.md
index 6e85cd3c8..e72b95aa9 100644
--- a/docs/api/settings.md
+++ b/docs/api/settings.md
@@ -106,6 +106,11 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
 
    Force autotuning even when explicit configs are provided. Default is ``False``. Controlled by ``HELION_FORCE_AUTOTUNE=1``.
 
+.. autoattribute:: Settings.autotune_force_persistent
+
+   Restrict ``pid_type`` choices to the persistent strategies (``"persistent_blocked"`` or ``"persistent_interleaved"``).
+   Default is ``False``. Enable globally with ``HELION_AUTOTUNE_FORCE_PERSISTENT=1`` or per kernel via ``@helion.kernel(..., autotune_force_persistent=True)``.
+
 .. autoattribute:: Settings.autotune_log_level
 
    Controls verbosity of autotuning output using Python logging levels:
@@ -258,6 +263,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe
 | ``HELION_STATIC_SHAPES`` | ``static_shapes`` | Set to ``0``/``false`` to disable global static shape specialization. |
 | ``HELION_PERSISTENT_RESERVED_SMS`` | ``persistent_reserved_sms`` | Reserve this many streaming multiprocessors when launching persistent kernels (``0`` uses all available SMs). |
 | ``HELION_FORCE_AUTOTUNE`` | ``force_autotune`` | Force the autotuner to run even when explicit configs are provided. |
+| ``HELION_AUTOTUNE_FORCE_PERSISTENT`` | ``autotune_force_persistent`` | Restrict ``pid_type`` to persistent kernel strategies during config search. |
 | ``HELION_DISALLOW_AUTOTUNING`` | ``check_autotuning_disabled`` | Hard-disable autotuning; kernels must supply explicit configs when this is ``1``. |
 | ``HELION_AUTOTUNE_COMPILE_TIMEOUT`` | ``autotune_compile_timeout`` | Maximum seconds to wait for Triton compilation during autotuning. |
 | ``HELION_AUTOTUNE_LOG_LEVEL`` | ``autotune_log_level`` | Adjust logging verbosity; accepts names like ``INFO`` or numeric levels. |
diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index 26dd19859..4317a0e3a 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -92,6 +92,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         self.block_sizes: list[BlockSizeInfo] = []
         self.debug_shape_renames: dict[sympy.Expr, sympy.Expr] = {}
         self.config_spec = ConfigSpec()
+        if settings.autotune_force_persistent:
+            for pid_type in ("flat", "xyz"):
+                self.config_spec.disallow_pid_type(pid_type)
         self.kernel_tensor_sizes: dict[tuple[sympy.Expr, ...], int] = (
             collections.Counter()
         )
diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index 17417f977..1b0bdb2cd 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -285,6 +285,13 @@ class _Settings:
             0,
         )
     )
+    autotune_force_persistent: bool = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_bool,
+            "HELION_AUTOTUNE_FORCE_PERSISTENT",
+            False,
+        )
+    )
     autotune_log_level: int = dataclasses.field(default_factory=_get_autotune_log_level)
     autotune_log: str | None = dataclasses.field(default_factory=_get_autotune_log_path)
     autotune_compile_timeout: int = dataclasses.field(
@@ -412,6 +419,10 @@ class Settings(_Settings):
             "Number of streaming multiprocessors to reserve when launching persistent kernels. "
             "Set HELION_PERSISTENT_RESERVED_SMS=N (default 0) or pass persistent_reserved_sms=N to helion.kernel."
         ),
+        "autotune_force_persistent": (
+            "If True, restrict pid_type choices to persistent kernels only during config selection. "
+            "Set HELION_AUTOTUNE_FORCE_PERSISTENT=1 to force persistent kernel autotuning globally."
+        ),
         "autotune_log_level": (
             "Log level for autotuning using Python logging levels. Default is logging.INFO. "
             "Use HELION_AUTOTUNE_LOG_LEVEL to override or set 0 to disable output."
diff --git a/test/test_config_api.py b/test/test_config_api.py
index 76add20a5..690f28cd7 100644
--- a/test/test_config_api.py
+++ b/test/test_config_api.py
@@ -2,15 +2,19 @@
 
 import importlib
 import inspect
+import os
 import pickle
 from typing import Any
 import unittest
+from unittest.mock import patch
 
 from hypothesis import given
 from hypothesis import settings
 from hypothesis import strategies as st
+import torch
 
 import helion
+from helion._compiler.compile_environment import CompileEnvironment
 from helion._testing import TestCase
 
 
@@ -232,5 +236,24 @@ def test_pre_serialized_json_backward_compat(self) -> None:
         self.assertEqual(dict(reread), expected)
 
 
+class TestSettingsEnv(TestCase):
+    def test_persistent_reserved_sms_env_var(self) -> None:
+        with patch.dict(
+            os.environ,
+            {"HELION_PERSISTENT_RESERVED_SMS": "5"},
+            clear=False,
+        ):
+            settings = helion.Settings()
+            self.assertEqual(settings.persistent_reserved_sms, 5)
+
+    def test_autotune_force_persistent_limits_config_spec(self) -> None:
+        settings = helion.Settings(autotune_force_persistent=True)
+        env = CompileEnvironment(torch.device("cpu"), settings)
+        self.assertEqual(
+            env.config_spec.allowed_pid_types,
+            ("persistent_blocked", "persistent_interleaved"),
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
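
Usage sketch: the patch itself only adds the setting and narrows the autotuner's
``pid_type`` search; the kernel below is a hypothetical example (the ``add`` body
and ``my_script.py`` are illustrative, using the usual ``helion.language`` tile
API) to show where ``autotune_force_persistent`` would be passed. Only the setting
name and the ``HELION_AUTOTUNE_FORCE_PERSISTENT`` variable come from this patch.

    import torch

    import helion
    import helion.language as hl


    # Restrict this kernel's autotuning to the persistent pid_type strategies
    # ("persistent_blocked" / "persistent_interleaved"); hypothetical example kernel.
    @helion.kernel(autotune_force_persistent=True)
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(out.size()):
            out[tile] = x[tile] + y[tile]
        return out

    # Or enable it for every kernel in the process via the environment:
    #   HELION_AUTOTUNE_FORCE_PERSISTENT=1 python my_script.py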