Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions helion/autotuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@
"PatternSearch": PatternSearch,
"RandomSearch": RandomSearch,
}

cache_classes = {
"LocalAutotuneCache": LocalAutotuneCache,
"StrictLocalAutotuneCache": StrictLocalAutotuneCache,
}
30 changes: 28 additions & 2 deletions helion/runtime/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def _env_get_literal(
)


def _env_get_str(var_name: str, default: str) -> str:
value = os.environ.get(var_name)
if value is None or (value := value.strip()) == "":
return default
return value


def _get_index_dtype() -> torch.dtype:
value = os.environ.get("HELION_INDEX_DTYPE")
if value is None or (token := value.strip()) == "":
Expand Down Expand Up @@ -184,7 +191,7 @@ def _get_autotune_config_overrides() -> dict[str, object]:
def default_autotuner_fn(
bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> BaseAutotuner:
from ..autotuner import LocalAutotuneCache
from ..autotuner import cache_classes
from ..autotuner import search_algorithms

autotuner_name = os.environ.get("HELION_AUTOTUNER", "PatternSearch")
Expand Down Expand Up @@ -223,7 +230,16 @@ def default_autotuner_fn(
assert profile.random_search is not None
kwargs.setdefault("count", profile.random_search.count)

return LocalAutotuneCache(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType]
settings = bound_kernel.settings
cache_name = settings.autotune_cache
cache_cls = cache_classes.get(cache_name)
if cache_cls is None:
raise ValueError(
f"Unknown HELION_AUTOTUNE_CACHE value: {cache_name}, valid options are: "
f"{', '.join(cache_classes.keys())}"
)

return cache_cls(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType]


def _get_autotune_random_seed() -> int:
Expand Down Expand Up @@ -348,6 +364,11 @@ class _Settings:
)
)
ref_mode: RefMode = dataclasses.field(default_factory=_get_ref_mode)
autotune_cache: str = dataclasses.field(
default_factory=functools.partial(
_env_get_str, "HELION_AUTOTUNE_CACHE", "LocalAutotuneCache"
)
)
autotuner_fn: AutotunerFunction = default_autotuner_fn
autotune_baseline_fn: Callable[..., object] | None = None

Expand Down Expand Up @@ -413,6 +434,11 @@ class Settings(_Settings):
"Should have the same signature as the kernel function. "
"Pass as @helion.kernel(..., autotune_baseline_fn=my_baseline_fn)."
),
"autotune_cache": (
"The name of the autotuner cache class to use. "
"Set HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache to enable strict caching. "
"Defaults to 'LocalAutotuneCache'."
),
}

def __init__(self, **settings: object) -> None:
Expand Down
57 changes: 57 additions & 0 deletions test/test_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import math
import multiprocessing as mp
import operator
import os
from pathlib import Path
import pickle
Expand Down Expand Up @@ -41,6 +42,8 @@
from helion.autotuner.config_generation import ConfigGeneration
from helion.autotuner.effort_profile import get_effort_profile
from helion.autotuner.finite_search import FiniteSearch
from helion.autotuner.local_cache import LocalAutotuneCache
from helion.autotuner.local_cache import StrictLocalAutotuneCache
from helion.autotuner.logger import LambdaLogger
from helion.autotuner.random_search import RandomSearch
import helion.language as hl
Expand Down Expand Up @@ -955,5 +958,59 @@ def test_autotune_random_seed_from_settings(self) -> None:
self.assertNotEqual(first, second)


class TestAutotuneCacheSelection(TestCase):
"""Selection of the autotune cache via HELION_AUTOTUNE_CACHE."""

def _make_bound(self):
@helion.kernel(autotune_baseline_fn=operator.add, autotune_log_level=0)
def add(a: torch.Tensor, b: torch.Tensor):
out = torch.empty_like(a)
for tile in hl.tile(out.size()):
out[tile] = a[tile] + b[tile]
return out

args = (
torch.randn([8], device=DEVICE),
torch.randn([8], device=DEVICE),
)
return add.bind(args), args

def test_autotune_cache_default_is_local(self):
"""Default (no env var set) -> LocalAutotuneCache."""
with without_env_var("HELION_AUTOTUNE_CACHE"):
bound, args = self._make_bound()
with patch("torch.accelerator.synchronize", autospec=True) as sync:
sync.return_value = None
autotuner = bound.settings.autotuner_fn(bound, args)
self.assertIsInstance(autotuner, LocalAutotuneCache)
self.assertNotIsInstance(autotuner, StrictLocalAutotuneCache)

def test_autotune_cache_strict_selected_by_env(self):
"""HELION_AUTOTUNE_CACHE=StrictLocalAutotuneCache -> StrictLocalAutotuneCache."""
with patch.dict(
os.environ,
{"HELION_AUTOTUNE_CACHE": "StrictLocalAutotuneCache"},
clear=False,
):
bound, args = self._make_bound()
with patch("torch.accelerator.synchronize", autospec=True) as sync:
sync.return_value = None
autotuner = bound.settings.autotuner_fn(bound, args)
self.assertIsInstance(autotuner, StrictLocalAutotuneCache)

def test_autotune_cache_invalid_raises(self):
"""Invalid HELION_AUTOTUNE_CACHE value should raise a ValueError."""
with patch.dict(
os.environ, {"HELION_AUTOTUNE_CACHE": "InvalidCacheName"}, clear=False
):
bound, args = self._make_bound()
with patch("torch.accelerator.synchronize", autospec=True) as sync:
sync.return_value = None
with self.assertRaisesRegex(
ValueError, "Unknown HELION_AUTOTUNE_CACHE"
):
bound.settings.autotuner_fn(bound, args)


if __name__ == "__main__":
unittest.main()
Loading