1 change: 1 addition & 0 deletions .gitignore
@@ -95,3 +95,4 @@ uv.lock
docs/examples/
docs/sg_execution_times.rst
AGENTS.md
*.csv
6 changes: 6 additions & 0 deletions docs/api/settings.md
@@ -112,6 +112,11 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:

You can also use ``0`` to completely disable all autotuning output. Controlled by ``HELION_AUTOTUNE_LOG_LEVEL``.

.. autoattribute:: Settings.autotune_log

When set, Helion writes per-config autotuning telemetry (config index, generation, status, perf, compile time, timestamp, and config JSON) to ``<value>.csv`` and mirrors the autotune log output to ``<value>.log``. This applies to population-based autotuners (currently ``PatternSearch`` and ``DifferentialEvolution``).
Controlled by ``HELION_AUTOTUNE_LOG``.
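
A minimal sketch of enabling this telemetry through the environment variable; the base name ``autotune_run`` and the load-order assumption below are illustrative, not part of the documented API:

```python
import os

# Illustrative base name: with HELION_AUTOTUNE_LOG set, a population-based
# autotuning run should produce autotune_run.csv (one row per benchmarked
# config) and autotune_run.log (a mirror of the autotune log output).
os.environ["HELION_AUTOTUNE_LOG"] = "autotune_run"

# Assumption: the variable is read when the kernel's Settings are built, so it
# is exported here before the Helion kernel is imported, defined, and autotuned.
```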

.. autoattribute:: Settings.autotune_compile_timeout

Timeout in seconds for Triton compilation during autotuning. Default is ``60``. Controlled by ``HELION_AUTOTUNE_COMPILE_TIMEOUT``.
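
Similarly, a hedged sketch of raising this limit for kernels whose Triton compiles are slow; ``300`` is an arbitrary example value:

```python
import os

# Allow each Triton compile up to 300 seconds during autotuning instead of the
# default 60. The number is illustrative; pick one that suits your kernels.
os.environ["HELION_AUTOTUNE_COMPILE_TIMEOUT"] = "300"
```
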
@@ -250,6 +255,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"DifferentialEvolution"``
| ``HELION_DISALLOW_AUTOTUNING`` | ``check_autotuning_disabled`` | Hard-disable autotuning; kernels must supply explicit configs when this is ``1``. |
| ``HELION_AUTOTUNE_COMPILE_TIMEOUT`` | ``autotune_compile_timeout`` | Maximum seconds to wait for Triton compilation during autotuning. |
| ``HELION_AUTOTUNE_LOG_LEVEL`` | ``autotune_log_level`` | Adjust logging verbosity; accepts names like ``INFO`` or numeric levels. |
| ``HELION_AUTOTUNE_LOG`` | ``autotune_log`` | Base filename for per-config CSV telemetry and mirrored autotune logs. |
| ``HELION_AUTOTUNE_PRECOMPILE`` | ``autotune_precompile`` | Select the autotuner precompile mode: ``"fork"`` (default) or ``"spawn"``; set to an empty value to disable precompiling. |
| ``HELION_AUTOTUNE_PRECOMPILE_JOBS`` | ``autotune_precompile_jobs`` | Cap the number of concurrent Triton precompile subprocesses. |
| ``HELION_AUTOTUNE_RANDOM_SEED`` | ``autotune_random_seed`` | Seed used for randomized autotuning searches. |
149 changes: 100 additions & 49 deletions helion/autotuner/base_search.py
@@ -28,7 +28,9 @@
from typing import Callable
from typing import Iterable
from typing import Literal
from typing import NamedTuple
from typing import NoReturn
from typing import Sequence
from typing import cast
from unittest.mock import patch
import uuid
@@ -47,7 +49,8 @@
from .config_generation import ConfigGeneration
from .config_generation import FlatConfig
from .logger import SUPPRESSED_TRITON_CODE_MSG
from .logger import LambdaLogger
from .logger import AutotuneLogEntry
from .logger import AutotuningLogger
from .logger import classify_triton_exception
from .logger import format_triton_compile_failure
from .logger import log_generated_triton_code_debug
@@ -63,8 +66,6 @@
from ..runtime.settings import Settings
from . import ConfigSpec

log = logging.getLogger(__name__)


class BaseAutotuner(abc.ABC):
"""
@@ -76,6 +77,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
raise NotImplementedError


class BenchmarkResult(NamedTuple):
"""Result tuple returned by parallel_benchmark."""

config: Config
fn: Callable[..., object]
perf: float
status: Literal["ok", "error", "timeout"]
compile_time: float | None


class BaseSearch(BaseAutotuner):
"""
Base class for search algorithms. This class defines the interface and utilities for all
@@ -109,7 +120,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
self.config_spec: ConfigSpec = kernel.config_spec
self.args: Sequence[object] = args
self.counters: collections.Counter[str] = collections.Counter()
self.log = LambdaLogger(self.settings.autotune_log_level)
self.log = AutotuningLogger(self.settings)
self.best_perf_so_far = inf
seed = self.settings.autotune_random_seed
random.seed(seed)
@@ -439,7 +450,7 @@ def start_precompile_and_check_for_hangs(
process.daemon = True
else:
precompiler = _prepare_precompiler_for_fork(
fn, device_args, config, self.kernel, decorator
fn, device_args, config, self.kernel, decorator, self.log
)
if precompiler is None:
return PrecompileFuture.skip(self, config, True)
@@ -463,14 +474,7 @@

def parallel_benchmark(
self, configs: list[Config], *, desc: str = "Benchmarking"
) -> list[
tuple[
Config,
Callable[..., object],
float,
Literal["ok", "error", "timeout"],
]
]:
) -> list[BenchmarkResult]:
"""
Benchmark multiple configurations in parallel.

@@ -479,24 +483,26 @@
desc: Description for the progress bar.

Returns:
A list of tuples containing configurations and their performance.
A list of BenchmarkResult entries containing the configuration, compiled
callable, measured performance, status, and compilation time.
"""
fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
precompile_status: list[Literal["ok", "error", "timeout"]]
fns: list[Callable[..., object]] = []
futures: list[PrecompileFuture] | None = None
for config in configs:
fn = self.kernel.compile_config(config, allow_print=False)
fns.append(fn)
if self.settings.autotune_precompile:
futures = [
*starmap(
futures = list(
starmap(
self.start_precompile_and_check_for_hangs,
zip(configs, fns, strict=True),
)
]
is_workings = PrecompileFuture.wait_for_all(
futures,
desc=f"{desc} precompiling"
if self.settings.autotune_progress_bar
else None,
)
precompile_status = []
precompile_desc = (
f"{desc} precompiling" if self.settings.autotune_progress_bar else None
)
is_workings = PrecompileFuture.wait_for_all(futures, desc=precompile_desc)
precompile_status: list[Literal["ok", "error", "timeout"]] = []
for future, ok in zip(futures, is_workings, strict=True):
reason = future.failure_reason
if ok:
@@ -508,29 +514,52 @@
else:
is_workings = [True] * len(configs)
precompile_status = ["ok"] * len(configs)
results: list[
tuple[
Config, Callable[..., object], float, Literal["ok", "error", "timeout"]
]
] = []

results: list[BenchmarkResult] = []

# Render a progress bar only when the user requested it.
iterator = iter_with_progress(
zip(configs, fns, is_workings, precompile_status, strict=True),
enumerate(zip(fns, is_workings, precompile_status, strict=True)),
total=len(configs),
description=f"{desc} exploring neighbors",
enabled=self.settings.autotune_progress_bar,
)
for config, fn, is_working, reason in iterator:
for index, (fn, is_working, reason) in iterator:
config = configs[index]
if futures is not None:
future = futures[index]
compile_time = (
future.elapsed
if future.process is not None and future.started
else None
)
else:
compile_time = None
status: Literal["ok", "error", "timeout"]
if is_working:
# benchmark one-by-one to avoid noisy results
perf = self.benchmark_function(config, fn)
status = "ok" if math.isfinite(perf) else "error"
results.append((config, fn, perf, status))
results.append(
BenchmarkResult(
config=config,
fn=fn,
perf=perf,
status=status,
compile_time=compile_time,
)
)
else:
status = "timeout" if reason == "timeout" else "error"
results.append((config, fn, inf, status))
results.append(
BenchmarkResult(
config=config,
fn=fn,
perf=inf,
status=status,
compile_time=compile_time,
)
)
return results

def autotune(self, *, skip_cache: bool = False) -> Config:
@@ -543,9 +572,11 @@ def autotune(self, *, skip_cache: bool = False) -> Config:
The best configuration found during autotuning.
"""
start = time.perf_counter()
self.log.reset()
exit_stack = contextlib.ExitStack()
with exit_stack:
if self.settings.autotune_log and isinstance(self, PopulationBasedSearch):
exit_stack.enter_context(self.log.autotune_logging())
self.log.reset()
# Autotuner triggers bugs in remote triton compile service
exit_stack.enter_context(
patch.dict(os.environ, {"TRITON_LOCAL_BUILD": "1"}, clear=False)
@@ -600,6 +631,7 @@ class PopulationMember:
flat_values: FlatConfig
config: Config
status: Literal["ok", "error", "timeout", "unknown"] = "unknown"
compile_time: float | None = None

@property
def perf(self) -> float:
@@ -667,6 +699,7 @@ def __init__(
"""
super().__init__(kernel, args)
self.population: list[PopulationMember] = []
self._current_generation: int = 0
overrides = self.settings.autotune_config_overrides or None
self.config_gen: ConfigGeneration = ConfigGeneration(
self.config_spec,
@@ -683,6 +716,9 @@ def best(self) -> PopulationMember:
"""
return min(self.population, key=performance)

def set_generation(self, generation: int) -> None:
self._current_generation = generation

def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
"""
Benchmark a flat configuration.
@@ -694,9 +730,9 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
A population member with the benchmark results.
"""
config = self.config_gen.unflatten(flat_values)
fn, perf = self.benchmark(config)
status: Literal["ok", "error"] = "ok" if math.isfinite(perf) else "error"
return PopulationMember(fn, [perf], flat_values, config, status=status)
member = PopulationMember(_unset_fn, [], flat_values, config)
self.parallel_benchmark_population([member], desc="Benchmarking")
return member

def parallel_benchmark_flat(
self, to_check: list[FlatConfig]
@@ -737,17 +773,31 @@ def parallel_benchmark_population(
members: The list of population members to benchmark.
desc: Description for the progress bar.
"""
for member, (config_out, fn, perf, status) in zip(
members,
self.parallel_benchmark([m.config for m in members], desc=desc),
strict=True,
):
assert config_out is member.config
member.perfs.append(perf)
member.fn = fn
member.status = status
results = self.parallel_benchmark([m.config for m in members], desc=desc)
for member, result in zip(members, results, strict=True):
assert result.config is member.config
member.perfs.append(result.perf)
member.fn = result.fn
member.status = result.status
member.compile_time = result.compile_time
self._log_population_results(members)
return members

def _log_population_results(self, members: Sequence[PopulationMember]) -> None:
for member in members:
perf_value = member.perf if member.perfs else None
if perf_value is not None and not math.isfinite(perf_value):
perf_value = None
self.log.record_autotune_entry(
AutotuneLogEntry(
generation=self._current_generation,
status=member.status,
perf_ms=perf_value,
compile_time=member.compile_time,
config=member.config,
)
)

def compare(self, a: PopulationMember, b: PopulationMember) -> int:
"""
Compare two population members based on their performance, possibly with re-benchmarking.
@@ -1320,6 +1370,7 @@ def _prepare_precompiler_for_fork(
config: Config,
kernel: BoundKernel,
decorator: str,
logger: AutotuningLogger,
) -> Callable[[], None] | None:
def extract_launcher(
triton_kernel: object,
@@ -1344,12 +1395,12 @@ def extract_launcher(
return precompiler
except Exception:
log_generated_triton_code_debug(
log,
logger,
kernel,
config,
prefix=f"Generated Triton code for {decorator}:",
)
log.warning(
logger.warning(
"Helion autotuner precompile error for %s. %s",
decorator,
SUPPRESSED_TRITON_CODE_MSG,
23 changes: 17 additions & 6 deletions helion/autotuner/differential_evolution.py
@@ -56,6 +56,7 @@ def mutate(self, x_index: int) -> FlatConfig:

def initial_two_generations(self) -> None:
# The initial population is 2x larger so we can throw out the slowest half and give the tuning process a head start
self.set_generation(0)
oversized_population = sorted(
self.parallel_benchmark_flat(
self.config_gen.random_population_flat(self.population_size * 2),
@@ -68,16 +69,25 @@ def initial_two_generations(self) -> None:
)
self.population = oversized_population[: self.population_size]

def _benchmark_mutation_batch(
self, indices: Sequence[int]
) -> list[PopulationMember]:
if not indices:
return []
flat_configs = [self.mutate(i) for i in indices]
return self.parallel_benchmark_flat(flat_configs)

def iter_candidates(self) -> Iterator[tuple[int, PopulationMember]]:
if self.immediate_update:
for i in range(len(self.population)):
yield i, self.benchmark_flat(self.mutate(i))
candidates = self._benchmark_mutation_batch([i])
if not candidates:
continue
yield i, candidates[0]
else:
yield from enumerate(
self.parallel_benchmark_flat(
[self.mutate(i) for i in range(len(self.population))]
)
)
indices = list(range(len(self.population)))
candidates = self._benchmark_mutation_batch(indices)
yield from zip(indices, candidates, strict=True)

def evolve_population(self) -> int:
replaced = 0
@@ -96,6 +106,7 @@ def _autotune(self) -> Config:
)
self.initial_two_generations()
for i in range(2, self.max_generations):
self.set_generation(i)
self.log(f"Generation {i} starting")
replaced = self.evolve_population()
self.log(f"Generation {i} complete: replaced={replaced}", self.statistics)
10 changes: 4 additions & 6 deletions helion/autotuner/finite_search.py
@@ -35,11 +35,9 @@ def __init__(
def _autotune(self) -> Config:
best_config = None
best_time = float("inf")
for config, _fn, time, _status in self.parallel_benchmark(
self.configs, desc="Benchmarking"
):
if time < best_time:
best_time = time
best_config = config
for result in self.parallel_benchmark(self.configs, desc="Benchmarking"):
if result.perf < best_time:
best_time = result.perf
best_config = result.config
assert best_config is not None
return best_config