From d618e77d8237aa916c639e2ca0c3c56c930a75ee Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 5 Nov 2025 20:22:02 -0800 Subject: [PATCH] Add autotuning log stack-info: PR: https://github.com/pytorch/helion/pull/1095, branch: jansel/stack/218 --- .gitignore | 1 + docs/api/settings.md | 6 + helion/autotuner/base_search.py | 149 ++++++---- helion/autotuner/differential_evolution.py | 23 +- helion/autotuner/finite_search.py | 10 +- helion/autotuner/logger.py | 325 ++++++++++++++++++--- helion/autotuner/pattern_search.py | 2 + helion/runtime/settings.py | 12 + test/test_autotuner.py | 109 ++++++- 9 files changed, 538 insertions(+), 99 deletions(-) diff --git a/.gitignore b/.gitignore index 45d893c18..ee2582fa6 100644 --- a/.gitignore +++ b/.gitignore @@ -95,3 +95,4 @@ uv.lock docs/examples/ docs/sg_execution_times.rst AGENTS.md +*.csv diff --git a/docs/api/settings.md b/docs/api/settings.md index 299728628..0316dee89 100644 --- a/docs/api/settings.md +++ b/docs/api/settings.md @@ -112,6 +112,11 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor: You can also use ``0`` to completely disable all autotuning output. Controlled by ``HELION_AUTOTUNE_LOG_LEVEL``. +.. autoattribute:: Settings.autotune_log + + When set, Helion writes per-config autotuning telemetry (config index, generation, status, perf, compile time, timestamp, config JSON) to ``.csv`` and mirrors the autotune log output to ``.log`` for population-based autotuners (currently ``PatternSearch`` and ``DifferentialEvolution``). + Controlled by ``HELION_AUTOTUNE_LOG``. + .. autoattribute:: Settings.autotune_compile_timeout Timeout in seconds for Triton compilation during autotuning. Default is ``60``. Controlled by ``HELION_AUTOTUNE_COMPILE_TIMEOUT``. @@ -250,6 +255,7 @@ Built-in values for ``HELION_AUTOTUNER`` include ``"PatternSearch"``, ``"Differe | ``HELION_DISALLOW_AUTOTUNING`` | ``check_autotuning_disabled`` | Hard-disable autotuning; kernels must supply explicit configs when this is ``1``. | | ``HELION_AUTOTUNE_COMPILE_TIMEOUT`` | ``autotune_compile_timeout`` | Maximum seconds to wait for Triton compilation during autotuning. | | ``HELION_AUTOTUNE_LOG_LEVEL`` | ``autotune_log_level`` | Adjust logging verbosity; accepts names like ``INFO`` or numeric levels. | +| ``HELION_AUTOTUNE_LOG`` | ``autotune_log`` | Base filename for per-config CSV telemetry and mirrored autotune logs. | | ``HELION_AUTOTUNE_PRECOMPILE`` | ``autotune_precompile`` | Select the autotuner precompile mode (``"fork"`` (default), ``"spawn"``, or disable when empty). | | ``HELION_AUTOTUNE_PRECOMPILE_JOBS`` | ``autotune_precompile_jobs`` | Cap the number of concurrent Triton precompile subprocesses. | | ``HELION_AUTOTUNE_RANDOM_SEED`` | ``autotune_random_seed`` | Seed used for randomized autotuning searches. 
| diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index b6d3ab4b0..01b96ae4a 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -28,7 +28,9 @@ from typing import Callable from typing import Iterable from typing import Literal +from typing import NamedTuple from typing import NoReturn +from typing import Sequence from typing import cast from unittest.mock import patch import uuid @@ -47,7 +49,8 @@ from .config_generation import ConfigGeneration from .config_generation import FlatConfig from .logger import SUPPRESSED_TRITON_CODE_MSG -from .logger import LambdaLogger +from .logger import AutotuneLogEntry +from .logger import AutotuningLogger from .logger import classify_triton_exception from .logger import format_triton_compile_failure from .logger import log_generated_triton_code_debug @@ -63,8 +66,6 @@ from ..runtime.settings import Settings from . import ConfigSpec -log = logging.getLogger(__name__) - class BaseAutotuner(abc.ABC): """ @@ -76,6 +77,16 @@ def autotune(self, *, skip_cache: bool = False) -> Config: raise NotImplementedError +class BenchmarkResult(NamedTuple): + """Result tuple returned by parallel_benchmark.""" + + config: Config + fn: Callable[..., object] + perf: float + status: Literal["ok", "error", "timeout"] + compile_time: float | None + + class BaseSearch(BaseAutotuner): """ Base class for search algorithms. This class defines the interface and utilities for all @@ -109,7 +120,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None: self.config_spec: ConfigSpec = kernel.config_spec self.args: Sequence[object] = args self.counters: collections.Counter[str] = collections.Counter() - self.log = LambdaLogger(self.settings.autotune_log_level) + self.log = AutotuningLogger(self.settings) self.best_perf_so_far = inf seed = self.settings.autotune_random_seed random.seed(seed) @@ -439,7 +450,7 @@ def start_precompile_and_check_for_hangs( process.daemon = True else: precompiler = _prepare_precompiler_for_fork( - fn, device_args, config, self.kernel, decorator + fn, device_args, config, self.kernel, decorator, self.log ) if precompiler is None: return PrecompileFuture.skip(self, config, True) @@ -463,14 +474,7 @@ def start_precompile_and_check_for_hangs( def parallel_benchmark( self, configs: list[Config], *, desc: str = "Benchmarking" - ) -> list[ - tuple[ - Config, - Callable[..., object], - float, - Literal["ok", "error", "timeout"], - ] - ]: + ) -> list[BenchmarkResult]: """ Benchmark multiple configurations in parallel. @@ -479,24 +483,26 @@ def parallel_benchmark( desc: Description for the progress bar. Returns: - A list of tuples containing configurations and their performance. + A list of BenchmarkResult entries containing the configuration, compiled + callable, measured performance, status, and compilation time. 
""" - fns = [self.kernel.compile_config(c, allow_print=False) for c in configs] - precompile_status: list[Literal["ok", "error", "timeout"]] + fns: list[Callable[..., object]] = [] + futures: list[PrecompileFuture] | None = None + for config in configs: + fn = self.kernel.compile_config(config, allow_print=False) + fns.append(fn) if self.settings.autotune_precompile: - futures = [ - *starmap( + futures = list( + starmap( self.start_precompile_and_check_for_hangs, zip(configs, fns, strict=True), ) - ] - is_workings = PrecompileFuture.wait_for_all( - futures, - desc=f"{desc} precompiling" - if self.settings.autotune_progress_bar - else None, ) - precompile_status = [] + precompile_desc = ( + f"{desc} precompiling" if self.settings.autotune_progress_bar else None + ) + is_workings = PrecompileFuture.wait_for_all(futures, desc=precompile_desc) + precompile_status: list[Literal["ok", "error", "timeout"]] = [] for future, ok in zip(futures, is_workings, strict=True): reason = future.failure_reason if ok: @@ -508,29 +514,52 @@ def parallel_benchmark( else: is_workings = [True] * len(configs) precompile_status = ["ok"] * len(configs) - results: list[ - tuple[ - Config, Callable[..., object], float, Literal["ok", "error", "timeout"] - ] - ] = [] + + results: list[BenchmarkResult] = [] # Render a progress bar only when the user requested it. iterator = iter_with_progress( - zip(configs, fns, is_workings, precompile_status, strict=True), + enumerate(zip(fns, is_workings, precompile_status, strict=True)), total=len(configs), description=f"{desc} exploring neighbors", enabled=self.settings.autotune_progress_bar, ) - for config, fn, is_working, reason in iterator: + for index, (fn, is_working, reason) in iterator: + config = configs[index] + if futures is not None: + future = futures[index] + compile_time = ( + future.elapsed + if future.process is not None and future.started + else None + ) + else: + compile_time = None status: Literal["ok", "error", "timeout"] if is_working: # benchmark one-by-one to avoid noisy results perf = self.benchmark_function(config, fn) status = "ok" if math.isfinite(perf) else "error" - results.append((config, fn, perf, status)) + results.append( + BenchmarkResult( + config=config, + fn=fn, + perf=perf, + status=status, + compile_time=compile_time, + ) + ) else: status = "timeout" if reason == "timeout" else "error" - results.append((config, fn, inf, status)) + results.append( + BenchmarkResult( + config=config, + fn=fn, + perf=inf, + status=status, + compile_time=compile_time, + ) + ) return results def autotune(self, *, skip_cache: bool = False) -> Config: @@ -543,9 +572,11 @@ def autotune(self, *, skip_cache: bool = False) -> Config: The best configuration found during autotuning. 
""" start = time.perf_counter() - self.log.reset() exit_stack = contextlib.ExitStack() with exit_stack: + if self.settings.autotune_log and isinstance(self, PopulationBasedSearch): + exit_stack.enter_context(self.log.autotune_logging()) + self.log.reset() # Autotuner triggers bugs in remote triton compile service exit_stack.enter_context( patch.dict(os.environ, {"TRITON_LOCAL_BUILD": "1"}, clear=False) @@ -600,6 +631,7 @@ class PopulationMember: flat_values: FlatConfig config: Config status: Literal["ok", "error", "timeout", "unknown"] = "unknown" + compile_time: float | None = None @property def perf(self) -> float: @@ -667,6 +699,7 @@ def __init__( """ super().__init__(kernel, args) self.population: list[PopulationMember] = [] + self._current_generation: int = 0 overrides = self.settings.autotune_config_overrides or None self.config_gen: ConfigGeneration = ConfigGeneration( self.config_spec, @@ -683,6 +716,9 @@ def best(self) -> PopulationMember: """ return min(self.population, key=performance) + def set_generation(self, generation: int) -> None: + self._current_generation = generation + def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember: """ Benchmark a flat configuration. @@ -694,9 +730,9 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember: A population member with the benchmark results. """ config = self.config_gen.unflatten(flat_values) - fn, perf = self.benchmark(config) - status: Literal["ok", "error"] = "ok" if math.isfinite(perf) else "error" - return PopulationMember(fn, [perf], flat_values, config, status=status) + member = PopulationMember(_unset_fn, [], flat_values, config) + self.parallel_benchmark_population([member], desc="Benchmarking") + return member def parallel_benchmark_flat( self, to_check: list[FlatConfig] @@ -737,17 +773,31 @@ def parallel_benchmark_population( members: The list of population members to benchmark. desc: Description for the progress bar. """ - for member, (config_out, fn, perf, status) in zip( - members, - self.parallel_benchmark([m.config for m in members], desc=desc), - strict=True, - ): - assert config_out is member.config - member.perfs.append(perf) - member.fn = fn - member.status = status + results = self.parallel_benchmark([m.config for m in members], desc=desc) + for member, result in zip(members, results, strict=True): + assert result.config is member.config + member.perfs.append(result.perf) + member.fn = result.fn + member.status = result.status + member.compile_time = result.compile_time + self._log_population_results(members) return members + def _log_population_results(self, members: Sequence[PopulationMember]) -> None: + for member in members: + perf_value = member.perf if member.perfs else None + if perf_value is not None and not math.isfinite(perf_value): + perf_value = None + self.log.record_autotune_entry( + AutotuneLogEntry( + generation=self._current_generation, + status=member.status, + perf_ms=perf_value, + compile_time=member.compile_time, + config=member.config, + ) + ) + def compare(self, a: PopulationMember, b: PopulationMember) -> int: """ Compare two population members based on their performance, possibly with re-benchmarking. 
@@ -1320,6 +1370,7 @@ def _prepare_precompiler_for_fork( config: Config, kernel: BoundKernel, decorator: str, + logger: AutotuningLogger, ) -> Callable[[], None] | None: def extract_launcher( triton_kernel: object, @@ -1344,12 +1395,12 @@ def extract_launcher( return precompiler except Exception: log_generated_triton_code_debug( - log, + logger, kernel, config, prefix=f"Generated Triton code for {decorator}:", ) - log.warning( + logger.warning( "Helion autotuner precompile error for %s. %s", decorator, SUPPRESSED_TRITON_CODE_MSG, diff --git a/helion/autotuner/differential_evolution.py b/helion/autotuner/differential_evolution.py index df172ecda..f17e42631 100644 --- a/helion/autotuner/differential_evolution.py +++ b/helion/autotuner/differential_evolution.py @@ -56,6 +56,7 @@ def mutate(self, x_index: int) -> FlatConfig: def initial_two_generations(self) -> None: # The initial population is 2x larger so we can throw out the slowest half and give the tuning process a head start + self.set_generation(0) oversized_population = sorted( self.parallel_benchmark_flat( self.config_gen.random_population_flat(self.population_size * 2), @@ -68,16 +69,25 @@ def initial_two_generations(self) -> None: ) self.population = oversized_population[: self.population_size] + def _benchmark_mutation_batch( + self, indices: Sequence[int] + ) -> list[PopulationMember]: + if not indices: + return [] + flat_configs = [self.mutate(i) for i in indices] + return self.parallel_benchmark_flat(flat_configs) + def iter_candidates(self) -> Iterator[tuple[int, PopulationMember]]: if self.immediate_update: for i in range(len(self.population)): - yield i, self.benchmark_flat(self.mutate(i)) + candidates = self._benchmark_mutation_batch([i]) + if not candidates: + continue + yield i, candidates[0] else: - yield from enumerate( - self.parallel_benchmark_flat( - [self.mutate(i) for i in range(len(self.population))] - ) - ) + indices = list(range(len(self.population))) + candidates = self._benchmark_mutation_batch(indices) + yield from zip(indices, candidates, strict=True) def evolve_population(self) -> int: replaced = 0 @@ -96,6 +106,7 @@ def _autotune(self) -> Config: ) self.initial_two_generations() for i in range(2, self.max_generations): + self.set_generation(i) self.log(f"Generation {i} starting") replaced = self.evolve_population() self.log(f"Generation {i} complete: replaced={replaced}", self.statistics) diff --git a/helion/autotuner/finite_search.py b/helion/autotuner/finite_search.py index 6fc49fc42..430278f07 100644 --- a/helion/autotuner/finite_search.py +++ b/helion/autotuner/finite_search.py @@ -35,11 +35,9 @@ def __init__( def _autotune(self) -> Config: best_config = None best_time = float("inf") - for config, _fn, time, _status in self.parallel_benchmark( - self.configs, desc="Benchmarking" - ): - if time < best_time: - best_time = time - best_config = config + for result in self.parallel_benchmark(self.configs, desc="Benchmarking"): + if result.perf < best_time: + best_time = result.perf + best_config = result.config assert best_config is not None return best_config diff --git a/helion/autotuner/logger.py b/helion/autotuner/logger.py index fa0729bb2..0fe03d83f 100644 --- a/helion/autotuner/logger.py +++ b/helion/autotuner/logger.py @@ -1,23 +1,59 @@ from __future__ import annotations +import contextlib +import csv import itertools import logging +import math +from pathlib import Path import re import sys import time +from types import TracebackType from typing import TYPE_CHECKING +from typing import Any from 
typing import Callable +from typing import Iterator from typing import Literal +from typing import NamedTuple +from typing import TypeAlias +from typing import TypeVar +from typing_extensions import Self from torch._inductor.runtime.triton_compat import OutOfResources from torch._inductor.runtime.triton_compat import PTXASError if TYPE_CHECKING: + from _csv import _writer as CsvWriter + import io + from ..runtime.config import Config from ..runtime.kernel import BoundKernel + from ..runtime.settings import Settings + +else: + CsvWriter = Any # type: ignore[assignment] + +SinkSelf = TypeVar("SinkSelf", bound="AutotuneLogSink") +ExcInfoParam: TypeAlias = ( + bool + | BaseException + | tuple[type[BaseException], BaseException, TracebackType | None] + | None +) -class LambdaLogger: +class _ElapsedFormatter(logging.Formatter): + def __init__(self, elapsed_fn: Callable[[], int]) -> None: + super().__init__() + self._elapsed_fn = elapsed_fn + + def format(self, record: logging.LogRecord) -> str: # type: ignore[override] + elapsed = self._elapsed_fn() + return f"[{elapsed}s] {record.getMessage()}" + + +class AutotuningLogger: """ A self-contained logger that does not propagate to the root logger and prints each record to stderr in the form: @@ -32,21 +68,100 @@ class LambdaLogger: _count: itertools.count[int] = itertools.count() - def __init__(self, level: int) -> None: + def __init__(self, settings: Settings) -> None: + self._settings = settings + level = settings.autotune_log_level self.level = level self._logger: logging.Logger = logging.getLogger( f"{__name__}.{next(self._count)}" ) self._logger.setLevel(level) self._logger.propagate = False + self._start_time: float = time.perf_counter() + self._extra_handlers: list[logging.Handler] = [] + self._active_handlers: list[logging.Handler] = [] + self._log_sink: AutotuneLogSink | None = None self.reset() def reset(self) -> None: - self._logger.handlers.clear() - self._logger.addHandler(_make_handler()) + self._start_time = time.perf_counter() + for handler in list(self._active_handlers): + self._logger.removeHandler(handler) + self._active_handlers = [] + self._register_handler(self._make_stream_handler()) + for handler in self._extra_handlers: + self._register_handler(handler) + + def add_handler(self, handler: logging.Handler) -> None: + if handler in self._extra_handlers: + return + self._extra_handlers.append(handler) + self._register_handler(handler) + + def remove_handler(self, handler: logging.Handler) -> None: + if handler in self._extra_handlers: + self._extra_handlers.remove(handler) + if handler in self._active_handlers: + self._active_handlers.remove(handler) + self._logger.removeHandler(handler) + + @contextlib.contextmanager + def autotune_logging( + self, base_path: str | None = None + ) -> Iterator[AutotuneLogSink | None]: + """Attach an :class:`AutotuneLogSink` for the duration of a tuning run.""" + + path = base_path or self._settings.autotune_log + if not path: + yield None + return + with AutotuneLogSink(path) as sink: + self._attach_sink(sink) + sink.start_run() + try: + yield sink + finally: + sink.end_run() + self._detach_sink() + + def record_autotune_entry(self, entry: AutotuneLogEntry) -> None: + """Write a structured autotune log entry when a sink is active.""" + + if self._log_sink is None: + return + self._log_sink.record(entry) + + def _attach_sink(self, sink: AutotuneLogSink) -> None: + self._log_sink = sink + self.add_handler(sink.handler) + + def _detach_sink(self) -> None: + sink = self._log_sink + if sink is None: 
+ return + self.remove_handler(sink.handler) + self._log_sink = None + + def _elapsed_seconds(self) -> int: + return int(time.perf_counter() - self._start_time) + + def _configure_handler(self, handler: logging.Handler) -> None: + handler.setFormatter(_ElapsedFormatter(self._elapsed_seconds)) + + def _register_handler(self, handler: logging.Handler) -> None: + self._configure_handler(handler) + self._logger.addHandler(handler) + self._active_handlers.append(handler) + + def _make_stream_handler(self) -> logging.Handler: + return logging.StreamHandler(sys.stderr) def __call__( - self, *msg: str | Callable[[], str], level: int = logging.INFO + self, + *msg: str | Callable[[], str], + level: int = logging.INFO, + exc_info: ExcInfoParam = None, + stacklevel: int | None = None, ) -> None: """ Log a message at a specified log level. @@ -54,31 +169,69 @@ def __call__( Args: msg: The message(s) to log. Can be strings or callables that return strings. level: The log level for the message. + exc_info: Optional exception info forwarded to ``logging.Logger``. + stacklevel: Optional stack level forwarded to ``logging.Logger``. """ if level >= self.level: - self._logger.log(level, " ".join(map(_maybe_call, msg))) - - def error(self, *msg: str | Callable[[], str]) -> None: - return self(*msg, level=logging.ERROR) - - def warning(self, *msg: str | Callable[[], str]) -> None: - return self(*msg, level=logging.WARNING) - - def debug(self, *msg: str | Callable[[], str]) -> None: - return self(*msg, level=logging.DEBUG) - - -def _make_handler() -> logging.Handler: - start = time.perf_counter() + message = " ".join(map(_maybe_call, msg)) + if stacklevel is not None: + if exc_info is not None: + self._logger.log( + level, + message, + exc_info=exc_info, + stacklevel=stacklevel, + ) + else: + self._logger.log( + level, + message, + stacklevel=stacklevel, + ) + else: + if exc_info is not None: + self._logger.log(level, message, exc_info=exc_info) + else: + self._logger.log(level, message) + + def error( + self, + *msg: str | Callable[[], str], + exc_info: ExcInfoParam = None, + stacklevel: int | None = None, + ) -> None: + return self( + *msg, + level=logging.ERROR, + exc_info=exc_info, + stacklevel=stacklevel, + ) - class _ElapsedFormatter(logging.Formatter): - def format(self, record: logging.LogRecord) -> str: # type: ignore[override] - elapsed = int(time.perf_counter() - start) - return f"[{elapsed}s] {record.getMessage()}" + def warning( + self, + *msg: str | Callable[[], str], + exc_info: ExcInfoParam = None, + stacklevel: int | None = None, + ) -> None: + return self( + *msg, + level=logging.WARNING, + exc_info=exc_info, + stacklevel=stacklevel, + ) - handler = logging.StreamHandler(sys.stderr) - handler.setFormatter(_ElapsedFormatter()) - return handler + def debug( + self, + *msg: str | Callable[[], str], + exc_info: ExcInfoParam = None, + stacklevel: int | None = None, + ) -> None: + return self( + *msg, + level=logging.DEBUG, + exc_info=exc_info, + stacklevel=stacklevel, + ) def _maybe_call(fn: Callable[[], str] | str) -> str: @@ -96,13 +249,121 @@ def _maybe_call(fn: Callable[[], str] | str) -> str: return fn +class AutotuneLogEntry(NamedTuple): + generation: int + status: str + perf_ms: float | None + compile_time: float | None + config: Config + + +class AutotuneLogSink: + """ + Writes autotune results to CSV and connects autotune logs to a file handler. 
+ """ + + def __init__(self, base_path: str) -> None: + self._base_path = Path(base_path) + self.csv_path = self._base_path.with_suffix(".csv") + self.log_path = self._base_path.with_suffix(".log") + self._csv_file: io.TextIOWrapper | None = None + self._csv_writer: CsvWriter | None = None + self._log_handler: logging.FileHandler | None = None + self._run_start_time: float | None = None + self._config_counter: int = 0 + + def __enter__(self) -> Self: + self.open() + return self + + def __exit__(self, *_exc: object) -> None: + self.close() + + @property + def handler(self) -> logging.Handler: + assert self._log_handler is not None, "Log sink not opened" + return self._log_handler + + def open(self) -> None: + if self._csv_writer is not None: + return + self.csv_path.parent.mkdir(parents=True, exist_ok=True) + self.log_path.parent.mkdir(parents=True, exist_ok=True) + self._csv_file = self.csv_path.open("w", encoding="utf-8", newline="") + self._csv_writer = csv.writer(self._csv_file) + self._csv_writer.writerow( + [ + "timestamp_s", + "config_index", + "generation", + "status", + "perf_ms", + "compile_time_s", + "config", + ] + ) + self._csv_file.flush() + handler = logging.FileHandler(self.log_path, mode="w", encoding="utf-8") + handler.setLevel(logging.DEBUG) + self._log_handler = handler + + def close(self) -> None: + if self._csv_file is not None: + self._csv_file.flush() + self._csv_file.close() + self._csv_file = None + self._csv_writer = None + if self._log_handler is not None: + self._log_handler.flush() + self._log_handler.close() + self._log_handler = None + self._run_start_time = None + self._config_counter = 0 + + def start_run(self) -> None: + self._run_start_time = time.perf_counter() + self._config_counter = 0 + + def end_run(self) -> None: + self._run_start_time = None + self._config_counter = 0 + + def record(self, entry: AutotuneLogEntry) -> None: + if self._csv_writer is None: + return + self._config_counter += 1 + timestamp_field = "" + if self._run_start_time is not None: + timestamp = time.perf_counter() - self._run_start_time + timestamp_field = f"{timestamp:.2f}" + perf_field = "" + if entry.perf_ms is not None and math.isfinite(entry.perf_ms): + perf_field = f"{entry.perf_ms:.6f}" + compile_field = "" + if entry.compile_time is not None: + compile_field = f"{entry.compile_time:.2f}" + self._csv_writer.writerow( + [ + timestamp_field, + self._config_counter, + entry.generation, + entry.status, + perf_field, + compile_field, + str(entry.config), + ] + ) + if self._csv_file is not None: + self._csv_file.flush() + + SUPPRESSED_TRITON_CODE_MSG = ( "Enable HELION_AUTOTUNE_LOG_LEVEL=DEBUG to log generated Triton code." ) def log_generated_triton_code_debug( - logger: logging.Logger | LambdaLogger, + logger: AutotuningLogger, bound_kernel: BoundKernel, config: Config, *, @@ -118,15 +379,7 @@ def log_generated_triton_code_debug( prefix: Optional prefix for the log message. 
""" message_prefix = prefix or "Generated Triton code:" - if isinstance(logger, LambdaLogger): - logger.debug(lambda: _format_triton_code(bound_kernel, config, message_prefix)) - return - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "%s\n%s", - message_prefix, - bound_kernel.to_triton_code(config), - ) + logger.debug(lambda: _format_triton_code(bound_kernel, config, message_prefix)) def _format_triton_code(bound_kernel: BoundKernel, config: Config, prefix: str) -> str: diff --git a/helion/autotuner/pattern_search.py b/helion/autotuner/pattern_search.py index 59e85f373..d8d759b31 100644 --- a/helion/autotuner/pattern_search.py +++ b/helion/autotuner/pattern_search.py @@ -61,6 +61,7 @@ def _autotune(self) -> Config: if member.config not in visited: visited.add(member.config) self.population.append(member) + self.set_generation(0) self.parallel_benchmark_population(self.population, desc="Initial population") # again with higher accuracy self.rebenchmark_population(self.population, desc="Verifying initial results") @@ -102,6 +103,7 @@ def _autotune(self) -> Config: # compile any unbenchmarked members in parallel unbenchmarked = [m for m in self.population if len(m.perfs) == 0] if unbenchmarked: + self.set_generation(generation) self.parallel_benchmark_population( unbenchmarked, desc=f"Generation {generation}:" ) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index adf8e6d15..eeb7b6278 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -169,6 +169,13 @@ def _get_autotune_log_level() -> int: ) +def _get_autotune_log_path() -> str | None: + value = os.environ.get("HELION_AUTOTUNE_LOG") + if value is None or (value := value.strip()) == "": + return None + return value + + def _get_autotune_config_overrides() -> dict[str, object]: value = os.environ.get("HELION_AUTOTUNE_CONFIG_OVERRIDES") if not value or (value := value.strip()) == "": @@ -272,6 +279,7 @@ class _Settings: default_factory=functools.partial(_env_get_bool, "HELION_STATIC_SHAPES", True) ) autotune_log_level: int = dataclasses.field(default_factory=_get_autotune_log_level) + autotune_log: str | None = dataclasses.field(default_factory=_get_autotune_log_path) autotune_compile_timeout: int = dataclasses.field( default_factory=functools.partial( _env_get_int, "HELION_AUTOTUNE_COMPILE_TIMEOUT", 60 @@ -397,6 +405,10 @@ class Settings(_Settings): "Log level for autotuning using Python logging levels. Default is logging.INFO. " "Use HELION_AUTOTUNE_LOG_LEVEL to override or set 0 to disable output." ), + "autotune_log": ( + "Base filename for autotune logs. Set HELION_AUTOTUNE_LOG=/tmp/run to write " + "/tmp/run.csv and /tmp/run.log with per-config metrics and debug logs." + ), "autotune_compile_timeout": "Timeout for Triton compilation in seconds used for autotuning. Default is 60 seconds.", "autotune_precompile": "Autotuner precompile mode: 'fork', 'spawn', or falsy/None to disable. 
Defaults to 'fork' on non-Windows platforms.", "autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.", diff --git a/test/test_autotuner.py b/test/test_autotuner.py index db748aacf..e95fde358 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -3,6 +3,7 @@ import collections from contextlib import contextmanager from contextlib import nullcontext +import csv from itertools import count import logging import math @@ -15,6 +16,7 @@ import tempfile from types import SimpleNamespace from typing import Callable +from typing import Sequence import unittest from unittest import skip from unittest.mock import patch @@ -34,6 +36,7 @@ from helion.autotuner import DifferentialEvolutionSearch from helion.autotuner import PatternSearch from helion.autotuner.base_search import BaseSearch +from helion.autotuner.base_search import PopulationMember from helion.autotuner.config_fragment import BooleanFragment from helion.autotuner.config_fragment import EnumFragment from helion.autotuner.config_fragment import IntegerFragment @@ -44,7 +47,8 @@ from helion.autotuner.finite_search import FiniteSearch from helion.autotuner.local_cache import LocalAutotuneCache from helion.autotuner.local_cache import StrictLocalAutotuneCache -from helion.autotuner.logger import LambdaLogger +from helion.autotuner.logger import AutotuneLogEntry +from helion.autotuner.logger import AutotuningLogger from helion.autotuner.random_search import RandomSearch import helion.language as hl from helion.language import loops @@ -90,7 +94,7 @@ def _make_search( ) search.args = args search.counters = collections.Counter() - search.log = LambdaLogger(logging.CRITICAL) + search.log = AutotuningLogger(settings) search._kernel_mutates_args = False search.best_perf_so_far = float("inf") tempdir = tempfile.TemporaryDirectory() @@ -142,6 +146,107 @@ def bad_fn(*_args): self.assertEqual(result, float("inf")) warn.assert_not_called() + def test_autotune_log_sink_writes_csv_and_log(self): + tmpdir = tempfile.TemporaryDirectory() + self.addCleanup(tmpdir.cleanup) + base_path = Path(tmpdir.name) / "autotune_run" + settings = Settings( + autotune_log=str(base_path), + autotune_log_level=logging.CRITICAL, + ) + logger = AutotuningLogger(settings) + with logger.autotune_logging(): + entry = AutotuneLogEntry( + generation=5, + status="ok", + perf_ms=1.234, + compile_time=0.5, + config=helion.Config(foo=1, bar=[2, 3]), + ) + logger.record_autotune_entry(entry) + logger("finalized entry", level=logging.CRITICAL) + + csv_path = base_path.with_suffix(".csv") + log_path = base_path.with_suffix(".log") + self.assertTrue(csv_path.exists()) + self.assertTrue(log_path.exists()) + rows = list(csv.reader(csv_path.read_text().splitlines())) + self.assertEqual( + rows[0], + [ + "timestamp_s", + "config_index", + "generation", + "status", + "perf_ms", + "compile_time_s", + "config", + ], + ) + self.assertEqual(rows[1][1], "1") + self.assertEqual(rows[1][2], "5") + self.assertEqual(rows[1][3], "ok") + self.assertEqual(rows[1][4], "1.234000") + log_text = log_path.read_text() + self.assertIn("finalized entry", log_text) + + def test_differential_evolution_immediate_iter_uses_batch_helper(self): + search = DifferentialEvolutionSearch.__new__(DifferentialEvolutionSearch) + search.immediate_update = True + search.population = [object(), object(), object()] + + calls: list[list[int]] = [] + + def batch(indices: Sequence[int]) -> list[PopulationMember]: + calls.append(list(indices)) + members: list[PopulationMember] = 
[] + for idx in indices: + members.append( + PopulationMember( + lambda *args, **kwargs: None, + [float(idx)], + [], + SimpleNamespace(config={"idx": idx}), + status="ok", + ) + ) + return members + + search._benchmark_mutation_batch = batch # type: ignore[assignment] + candidates = list(search.iter_candidates()) + self.assertEqual(calls, [[0], [1], [2]]) + self.assertEqual([idx for idx, _ in candidates], [0, 1, 2]) + + def test_differential_evolution_parallel_iter_uses_batch_helper(self): + search = DifferentialEvolutionSearch.__new__(DifferentialEvolutionSearch) + search.immediate_update = False + search.population = [object(), object()] + + def batch(indices: Sequence[int]) -> list[PopulationMember]: + members: list[PopulationMember] = [] + for idx in indices: + members.append( + PopulationMember( + lambda *args, **kwargs: None, + [float(idx)], + [], + SimpleNamespace(config={"idx": idx}), + status="ok", + ) + ) + return members + + calls: list[list[int]] = [] + + def recording_batch(indices: Sequence[int]) -> list[PopulationMember]: + calls.append(list(indices)) + return batch(indices) + + search._benchmark_mutation_batch = recording_batch # type: ignore[assignment] + candidates = list(search.iter_candidates()) + self.assertEqual(calls, [[0, 1]]) + self.assertEqual([idx for idx, _ in candidates], [0, 1]) + @pytest.mark.skipif( "fork" not in mp.get_all_start_methods(), reason="fork start method is unavailable on this platform",
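# --- Illustrative sketch, not part of this patch. ---
# A minimal end-to-end use of the new logging pieces, mirroring
# test_autotune_log_sink_writes_csv_and_log above. The base path and the
# Config keyword arguments are assumptions; Helion derives <base>.csv and
# <base>.log from the autotune_log setting (or HELION_AUTOTUNE_LOG).
import logging

import helion
from helion.autotuner.logger import AutotuneLogEntry, AutotuningLogger
from helion.runtime.settings import Settings

settings = Settings(
    autotune_log="/tmp/autotune_run",  # hypothetical base path
    autotune_log_level=logging.INFO,
)
logger = AutotuningLogger(settings)
with logger.autotune_logging():  # attaches the CSV/log sink for this run
    logger.record_autotune_entry(
        AutotuneLogEntry(
            generation=0,
            status="ok",
            perf_ms=1.5,
            compile_time=0.25,
            config=helion.Config(foo=1),  # placeholder config, as in the test
        )
    )
# After the block exits, /tmp/autotune_run.csv holds the per-config row and
# /tmp/autotune_run.log mirrors the autotune log output.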