From f2e07b6aec52b32ee7693278286a75bc7fb50b7d Mon Sep 17 00:00:00 2001 From: Alessandro Sangiorgi Date: Mon, 3 Nov 2025 16:51:20 -0600 Subject: [PATCH 1/4] feat(runtime): Add env var to enable StrictLocalAutotuneCache Adds a new boolean setting, `use_strict_cache`, to `helion.runtime.Settings`, controlled by the `HELION_USE_STRICT_CACHE` environment variable. The `default_autotuner_fn` now reads this setting and selects `StrictLocalAutotuneCache` if set to true, falling back to `LocalAutotuneCache` otherwise. This allows users to easily enable stricter cache validation without needing to write a custom `autotuner_fn`. Signed-off-by: Alessandro Sangiorgi --- helion/runtime/settings.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 54f329c7a..601477284 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -185,6 +185,7 @@ def default_autotuner_fn( bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object ) -> BaseAutotuner: from ..autotuner import LocalAutotuneCache + from ..autotuner import StrictLocalAutotuneCache from ..autotuner import search_algorithms autotuner_name = os.environ.get("HELION_AUTOTUNER", "PatternSearch") @@ -223,7 +224,13 @@ def default_autotuner_fn( assert profile.random_search is not None kwargs.setdefault("count", profile.random_search.count) - return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType] + settings = bound_kernel.settings + if settings.use_strict_cache: + cache_cls = StrictLocalAutotuneCache + else: + cache_cls = LocalAutotuneCache + + return cache_cls(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType] def _get_autotune_random_seed() -> int: @@ -347,6 +354,11 @@ class _Settings: _env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False ) ) + use_strict_cache: bool = dataclasses.field( + default_factory=functools.partial( + _env_get_bool, "HELION_USE_STRICT_CACHE", False + ) + ) ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode) autotuner_fn: AutotunerFunction = default_autotuner_fn autotune_baseline_fn: Callable[..., object] | None = None From c43a661ab333663b8b7fd563c98bd9184e497efa Mon Sep 17 00:00:00 2001 From: Alessandro Sangiorgi Date: Tue, 4 Nov 2025 07:51:42 -0600 Subject: [PATCH 2/4] feat(autotuner): Make autotune cache class configurable via env var Adds a new `HELION_AUTOTUNE_CACHE` environment variable to allow users to select the autotuner cache implementation, similar to how `HELION_AUTOTUNER` selects the search algorithm. - `helion/autotuner/__init__.py`: - Adds a `cache_classes` dictionary to register available cache implementations (LocalAutotuneCache, StrictLocalAutotuneCache). - `helion/runtime/settings.py`: - Adds a new string setting, `autotune_cache`, which defaults to "LocalAutotuneCache" and is controlled by the `HELION_AUTOTUNE_CACHE` environment variable. - Updates `default_autotuner_fn` to use this new setting to look up and instantiate the correct cache class from `autotuner.cache_classes`. - `test_autotuner.py`: - Adds the `TestAutotuneCacheSelection` test case. - Verifies that the default cache is `LocalAutotuneCache`. - Verifies that `HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache` correctly selects `StrictLocalAutotuneCache`. - Verifies that an invalid cache name raises a `ValueError`. Signed-off-by: Alessandro Sangiorgi --- helion/autotuner/__init__.py | 5 ++++ helion/runtime/settings.py | 32 +++++++++++++++++------- test/test_autotuner.py | 48 ++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 9 deletions(-) diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py index 674ac846b..541f4a787 100644 --- a/helion/autotuner/__init__.py +++ b/helion/autotuner/__init__.py @@ -25,3 +25,8 @@ "PatternSearch": PatternSearch, "RandomSearch": RandomSearch, } + +cache_classes = { + "LocalAutotuneCache": LocalAutotuneCache, + "StrictLocalAutotuneCache": StrictLocalAutotuneCache, +} diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 601477284..adf8e6d15 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -131,6 +131,13 @@ def _env_get_literal( ) +def _env_get_str(var_name: str, default: str) -> str: + value = os.environ.get(var_name) + if value is None or (value := value.strip()) == "": + return default + return value + + def _get_index_dtype() -> torch.dtype: value = os.environ.get("HELION_INDEX_DTYPE") if value is None or (token := value.strip()) == "": @@ -184,8 +191,7 @@ def _get_autotune_config_overrides() -> dict[str, object]: def default_autotuner_fn( bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object ) -> BaseAutotuner: - from ..autotuner import LocalAutotuneCache - from ..autotuner import StrictLocalAutotuneCache + from ..autotuner import cache_classes from ..autotuner import search_algorithms autotuner_name = os.environ.get("HELION_AUTOTUNER", "PatternSearch") @@ -225,10 +231,13 @@ def default_autotuner_fn( kwargs.setdefault("count", profile.random_search.count) settings = bound_kernel.settings - if settings.use_strict_cache: - cache_cls = StrictLocalAutotuneCache - else: - cache_cls = LocalAutotuneCache + cache_name = settings.autotune_cache + cache_cls = cache_classes.get(cache_name) + if cache_cls is None: + raise ValueError( + f"Unknown HELION_AUTOTUNE_CACHE value: {cache_name}, valid options are: " + f"{', '.join(cache_classes.keys())}" + ) return cache_cls(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType] @@ -354,12 +363,12 @@ class _Settings: _env_get_bool, "HELION_DEBUG_DTYPE_ASSERTS", False ) ) - use_strict_cache: bool = dataclasses.field( + ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode) + autotune_cache: str = dataclasses.field( default_factory=functools.partial( - _env_get_bool, "HELION_USE_STRICT_CACHE", False + _env_get_str, "HELION_AUTOTUNE_CACHE", "LocalAutotuneCache" ) ) - ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode) autotuner_fn: AutotunerFunction = default_autotuner_fn autotune_baseline_fn: Callable[..., object] | None = None @@ -425,6 +434,11 @@ class Settings(_Settings): "Should have the same signature as the kernel function. " "Pass as @helion.kernel(..., autotune_baseline_fn=my_baseline_fn)." ), + "autotune_cache": ( + "The name of the autotuner cache class to use. " + "Set HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache to enable strict caching. " + "Defaults to 'LocalAutotuneCache'." + ), } def __init__(self, **settings: object) -> None: diff --git a/test/test_autotuner.py b/test/test_autotuner.py index ce2d4abdb..007c5131f 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -41,6 +41,8 @@ from helion.autotuner.config_generation import ConfigGeneration from helion.autotuner.effort_profile import get_effort_profile from helion.autotuner.finite_search import FiniteSearch +from helion.autotuner.local_cache import LocalAutotuneCache +from helion.autotuner.local_cache import StrictLocalAutotuneCache from helion.autotuner.logger import LambdaLogger from helion.autotuner.random_search import RandomSearch import helion.language as hl @@ -955,5 +957,51 @@ def test_autotune_random_seed_from_settings(self) -> None: self.assertNotEqual(first, second) +class TestAutotuneCacheSelection(TestCase): + """Selection of the autotune cache via HELION_AUTOTUNE_CACHE.""" + + def _make_bound(self): + @helion.kernel() + def add(a: torch.Tensor, b: torch.Tensor): + out = torch.empty_like(a) + for tile in hl.tile(out.size()): + out[tile] = a[tile] + b[tile] + return out + + args = ( + torch.randn([8], device=DEVICE), + torch.randn([8], device=DEVICE), + ) + return add.bind(args), args + + def test_autotune_cache_default_is_local(self): + """Default (no env var set) -> LocalAutotuneCache.""" + with without_env_var("HELION_AUTOTUNE_CACHE"): + bound, args = self._make_bound() + autotuner = bound.settings.autotuner_fn(bound, args) + self.assertIsInstance(autotuner, LocalAutotuneCache) + self.assertNotIsInstance(autotuner, StrictLocalAutotuneCache) + + def test_autotune_cache_strict_selected_by_env(self): + """HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache -> StrictLocalAutotuneCache.""" + with patch.dict( + os.environ, + {"HELION_AUTOTUNE_CACHE": "StrictLocalAutotuneCache"}, + clear=False, + ): + bound, args = self._make_bound() + autotuner = bound.settings.autotuner_fn(bound, args) + self.assertIsInstance(autotuner, StrictLocalAutotuneCache) + + def test_autotune_cache_invalid_raises(self): + """Invalid HELION_AUTOTUNE_CACHE value should raise a ValueError.""" + with patch.dict( + os.environ, {"HELION_AUTOTUNE_CACHE": "InvalidCacheName"}, clear=False + ): + bound, args = self._make_bound() + with self.assertRaisesRegex(ValueError, "Unknown HELION_AUTOTUNE_CACHE"): + bound.settings.autotuner_fn(bound, args) + + if __name__ == "__main__": unittest.main() From d503ca31672fd68cbdeb9da92bfa77272a7fa710 Mon Sep 17 00:00:00 2001 From: Alessandro Sangiorgi Date: Tue, 4 Nov 2025 13:46:00 -0600 Subject: [PATCH 3/4] Fix test Signed-off-by: Alessandro Sangiorgi --- test/test_autotuner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_autotuner.py b/test/test_autotuner.py index 007c5131f..c24e2a4a9 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -7,6 +7,7 @@ import logging import math import multiprocessing as mp +import operator import os from pathlib import Path import pickle @@ -961,7 +962,7 @@ class TestAutotuneCacheSelection(TestCase): """Selection of the autotune cache via HELION_AUTOTUNE_CACHE.""" def _make_bound(self): - @helion.kernel() + @helion.kernel(autotune_baseline_fn=operator.add, autotune_log_level=0) def add(a: torch.Tensor, b: torch.Tensor): out = torch.empty_like(a) for tile in hl.tile(out.size()): From 32fc11d5e15116d0862b7d71a1be8e4fe4405504 Mon Sep 17 00:00:00 2001 From: Alessandro Sangiorgi Date: Wed, 5 Nov 2025 17:10:10 -0600 Subject: [PATCH 4/4] Fix test for cpu --- test/test_autotuner.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/test_autotuner.py b/test/test_autotuner.py index c24e2a4a9..db748aacf 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -979,7 +979,9 @@ def test_autotune_cache_default_is_local(self): """Default (no env var set) -> LocalAutotuneCache.""" with without_env_var("HELION_AUTOTUNE_CACHE"): bound, args = self._make_bound() - autotuner = bound.settings.autotuner_fn(bound, args) + with patch("torch.accelerator.synchronize", autospec=True) as sync: + sync.return_value = None + autotuner = bound.settings.autotuner_fn(bound, args) self.assertIsInstance(autotuner, LocalAutotuneCache) self.assertNotIsInstance(autotuner, StrictLocalAutotuneCache) @@ -991,7 +993,9 @@ def test_autotune_cache_strict_selected_by_env(self): clear=False, ): bound, args = self._make_bound() - autotuner = bound.settings.autotuner_fn(bound, args) + with patch("torch.accelerator.synchronize", autospec=True) as sync: + sync.return_value = None + autotuner = bound.settings.autotuner_fn(bound, args) self.assertIsInstance(autotuner, StrictLocalAutotuneCache) def test_autotune_cache_invalid_raises(self): @@ -1000,8 +1004,12 @@ def test_autotune_cache_invalid_raises(self): os.environ, {"HELION_AUTOTUNE_CACHE": "InvalidCacheName"}, clear=False ): bound, args = self._make_bound() - with self.assertRaisesRegex(ValueError, "Unknown HELION_AUTOTUNE_CACHE"): - bound.settings.autotuner_fn(bound, args) + with patch("torch.accelerator.synchronize", autospec=True) as sync: + sync.return_value = None + with self.assertRaisesRegex( + ValueError, "Unknown HELION_AUTOTUNE_CACHE" + ): + bound.settings.autotuner_fn(bound, args) if __name__ == "__main__":