diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py
index 674ac846b..541f4a787 100644
--- a/helion/autotuner/__init__.py
+++ b/helion/autotuner/__init__.py
@@ -25,3 +25,8 @@
     "PatternSearch": PatternSearch,
     "RandomSearch": RandomSearch,
 }
+
+cache_classes = {
+    "LocalAutotuneCache": LocalAutotuneCache,
+    "StrictLocalAutotuneCache": StrictLocalAutotuneCache,
+}
diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index 54f329c7a..adf8e6d15 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -131,6 +131,13 @@ def _env_get_literal(
     )
 
 
+def _env_get_str(var_name: str, default: str) -> str:
+    value = os.environ.get(var_name)
+    if value is None or (value := value.strip()) == "":
+        return default
+    return value
+
+
 def _get_index_dtype() -> torch.dtype:
     value = os.environ.get("HELION_INDEX_DTYPE")
     if value is None or (token := value.strip()) == "":
@@ -184,7 +191,7 @@ def _get_autotune_config_overrides() -> dict[str, object]:
 def default_autotuner_fn(
     bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
 ) -> BaseAutotuner:
-    from ..autotuner import LocalAutotuneCache
+    from ..autotuner import cache_classes
     from ..autotuner import search_algorithms
 
     autotuner_name = os.environ.get("HELION_AUTOTUNER", "PatternSearch")
@@ -223,7 +230,16 @@ def default_autotuner_fn(
             assert profile.random_search is not None
             kwargs.setdefault("count", profile.random_search.count)
 
-    return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]
+    settings = bound_kernel.settings
+    cache_name = settings.autotune_cache
+    cache_cls = cache_classes.get(cache_name)
+    if cache_cls is None:
+        raise ValueError(
+            f"Unknown HELION_AUTOTUNE_CACHE value: {cache_name}, valid options are: "
+            f"{', '.join(cache_classes.keys())}"
+        )
+
+    return cache_cls(autotuner_cls(bound_kernel, args, **kwargs))  # pyright: ignore[reportArgumentType]
 
 
 def _get_autotune_random_seed() -> int:
@@ -348,6 +364,11 @@ class _Settings:
         )
     )
     ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
+    autotune_cache: str = dataclasses.field(
+        default_factory=functools.partial(
+            _env_get_str, "HELION_AUTOTUNE_CACHE", "LocalAutotuneCache"
+        )
+    )
     autotuner_fn: AutotunerFunction = default_autotuner_fn
     autotune_baseline_fn: Callable[..., object] | None = None
 
@@ -413,6 +434,11 @@ class Settings(_Settings):
             "Should have the same signature as the kernel function. "
            "Pass as @helion.kernel(..., autotune_baseline_fn=my_baseline_fn)."
         ),
+        "autotune_cache": (
+            "The name of the autotuner cache class to use. "
+            "Set HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache to enable strict caching. "
+            "Defaults to 'LocalAutotuneCache'."
+        ),
     }
 
     def __init__(self, **settings: object) -> None:
diff --git a/test/test_autotuner.py b/test/test_autotuner.py
index ce2d4abdb..db748aacf 100644
--- a/test/test_autotuner.py
+++ b/test/test_autotuner.py
@@ -7,6 +7,7 @@
 import logging
 import math
 import multiprocessing as mp
+import operator
 import os
 from pathlib import Path
 import pickle
@@ -41,6 +42,8 @@
 from helion.autotuner.config_generation import ConfigGeneration
 from helion.autotuner.effort_profile import get_effort_profile
 from helion.autotuner.finite_search import FiniteSearch
+from helion.autotuner.local_cache import LocalAutotuneCache
+from helion.autotuner.local_cache import StrictLocalAutotuneCache
 from helion.autotuner.logger import LambdaLogger
 from helion.autotuner.random_search import RandomSearch
 import helion.language as hl
@@ -955,5 +958,59 @@ def test_autotune_random_seed_from_settings(self) -> None:
         self.assertNotEqual(first, second)
 
 
+class TestAutotuneCacheSelection(TestCase):
+    """Selection of the autotune cache via HELION_AUTOTUNE_CACHE."""
+
+    def _make_bound(self):
+        @helion.kernel(autotune_baseline_fn=operator.add, autotune_log_level=0)
+        def add(a: torch.Tensor, b: torch.Tensor):
+            out = torch.empty_like(a)
+            for tile in hl.tile(out.size()):
+                out[tile] = a[tile] + b[tile]
+            return out
+
+        args = (
+            torch.randn([8], device=DEVICE),
+            torch.randn([8], device=DEVICE),
+        )
+        return add.bind(args), args
+
+    def test_autotune_cache_default_is_local(self):
+        """Default (no env var set) -> LocalAutotuneCache."""
+        with without_env_var("HELION_AUTOTUNE_CACHE"):
+            bound, args = self._make_bound()
+            with patch("torch.accelerator.synchronize", autospec=True) as sync:
+                sync.return_value = None
+                autotuner = bound.settings.autotuner_fn(bound, args)
+        self.assertIsInstance(autotuner, LocalAutotuneCache)
+        self.assertNotIsInstance(autotuner, StrictLocalAutotuneCache)
+
+    def test_autotune_cache_strict_selected_by_env(self):
+        """HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache -> StrictLocalAutotuneCache."""
+        with patch.dict(
+            os.environ,
+            {"HELION_AUTOTUNE_CACHE": "StrictLocalAutotuneCache"},
+            clear=False,
+        ):
+            bound, args = self._make_bound()
+            with patch("torch.accelerator.synchronize", autospec=True) as sync:
+                sync.return_value = None
+                autotuner = bound.settings.autotuner_fn(bound, args)
+        self.assertIsInstance(autotuner, StrictLocalAutotuneCache)
+
+    def test_autotune_cache_invalid_raises(self):
+        """Invalid HELION_AUTOTUNE_CACHE value should raise a ValueError."""
+        with patch.dict(
+            os.environ, {"HELION_AUTOTUNE_CACHE": "InvalidCacheName"}, clear=False
+        ):
+            bound, args = self._make_bound()
+            with patch("torch.accelerator.synchronize", autospec=True) as sync:
+                sync.return_value = None
+                with self.assertRaisesRegex(
+                    ValueError, "Unknown HELION_AUTOTUNE_CACHE"
+                ):
+                    bound.settings.autotuner_fn(bound, args)
+
+
 if __name__ == "__main__":
     unittest.main()
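For reference, a minimal sketch of how the new cache selection could be used once this change lands. The `add` kernel mirrors the one in the test above. Setting the `HELION_AUTOTUNE_CACHE` environment variable is what the diff exercises directly; passing `autotune_cache` as a `helion.kernel(...)` keyword is an assumption, based on the pattern the Settings docstring shows for `autotune_baseline_fn`.

    import os
    import torch
    import helion
    import helion.language as hl

    # Option 1: opt into strict caching process-wide via the environment variable.
    os.environ["HELION_AUTOTUNE_CACHE"] = "StrictLocalAutotuneCache"

    # Option 2 (assumed): select the cache per kernel through the new setting,
    # passed the same way as other settings such as autotune_baseline_fn.
    @helion.kernel(autotune_cache="StrictLocalAutotuneCache")
    def add(a: torch.Tensor, b: torch.Tensor):
        out = torch.empty_like(a)
        for tile in hl.tile(out.size()):
            out[tile] = a[tile] + b[tile]
        return out

    # Any name outside cache_classes ("LocalAutotuneCache", "StrictLocalAutotuneCache")
    # causes default_autotuner_fn to raise ValueError("Unknown HELION_AUTOTUNE_CACHE ...").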