From 750b20e08d432cad24f854e70dd6377b1d268d39 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Tue, 14 Oct 2025 21:23:35 -0700
Subject: [PATCH] Auto-shrink autotune_precompile_jobs based on free memory

stack-info: PR: https://github.com/pytorch/helion/pull/940, branch: jansel/stack/192
---
 docs/api/settings.md            |  2 +
 helion/autotuner/base_search.py | 99 +++++++++++++++++++++++++++++----
 helion/runtime/settings.py      | 14 ++++-
 3 files changed, 102 insertions(+), 13 deletions(-)

diff --git a/docs/api/settings.md b/docs/api/settings.md
index b4bc71aca..aad01b9f3 100644
--- a/docs/api/settings.md
+++ b/docs/api/settings.md
@@ -139,6 +139,8 @@ with helion.set_default_settings(
 .. autoattribute:: Settings.autotune_precompile_jobs
 
    Cap the number of concurrent Triton precompile subprocesses. ``None`` (default) uses the machine CPU count.
+   Controlled by the ``HELION_AUTOTUNE_PRECOMPILE_JOBS`` environment variable.
+   When using ``"spawn"`` precompile mode, Helion may automatically lower this cap if free GPU memory is limited.
 
 .. autoattribute:: Settings.autotune_max_generations
 
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index bdf21f17d..bd997a9fc 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -32,6 +32,7 @@
 import torch
 from torch.utils._pytree import tree_flatten
 from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map_only
 from triton.testing import do_bench
 
 from .. import exc
@@ -81,6 +82,10 @@ class BaseSearch(BaseAutotuner):
         counters (collections.Counter): A counter to track various metrics during the search.
     """
 
+    _baseline_output: object
+    _kernel_mutates_args: bool
+    _baseline_post_args: Sequence[object] | None
+
    def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
        """
        Initialize the BaseSearch object.
@@ -101,17 +106,14 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
         self._original_args: Sequence[object] = self._clone_args(self.args)
-        self._baseline_output: object | None = None
-        self._baseline_post_args: Sequence[object] | None = None
-        self._kernel_mutates_args: bool = False
         self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
         self._precompile_args_path: str | None = None
-        if self.settings.autotune_accuracy_check:
-            (
-                self._baseline_output,
-                self._kernel_mutates_args,
-                self._baseline_post_args,
-            ) = self._compute_baseline()
+        (
+            self._baseline_output,
+            self._kernel_mutates_args,
+            self._baseline_post_args,
+        ) = self._compute_baseline()
+        self._jobs = self._decide_num_jobs()
 
     def cleanup(self) -> None:
         if self._precompile_tmpdir is not None:
@@ -165,6 +167,55 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
         baseline_post_args = self._clone_args(new_args)
         return baseline_output, mutated, baseline_post_args
 
+    def _decide_num_jobs(self) -> int:
+        if not self.settings.autotune_precompile:
+            return 1
+
+        jobs = self.settings.autotune_precompile_jobs
+        if not jobs:
+            jobs = os.cpu_count() or 1
+
+        if self.settings.autotune_precompile != "spawn":
+            return jobs
+
+        memory_per_job = _estimate_tree_bytes(self.args) + _estimate_tree_bytes(
+            self._baseline_output
+        )
+        memory_per_job *= 2  # safety factor
+        if memory_per_job <= 0:
+            return jobs
+
+        device = self.kernel.env.device
+        if device.type != "cuda":
+            # TODO(jansel): support non-cuda devices
+            return jobs
+
+        available_memory, _ = torch.cuda.mem_get_info(device)
+        jobs_by_memory = available_memory // memory_per_job
+        if jobs_by_memory < jobs:
+            gib_per_job = memory_per_job / (1024**3)
+            available_gib = available_memory / (1024**3)
+            if jobs_by_memory > 0:
+                self.log.warning(
+                    f"Reducing autotune precompile spawn jobs from {jobs} to {jobs_by_memory} "
+                    f"due to limited GPU memory (estimated {gib_per_job:.2f} GiB per job, "
+                    f"{available_gib:.2f} GiB free). "
+                    f"Set HELION_AUTOTUNE_PRECOMPILE_JOBS={jobs_by_memory} "
+                    "to make this lower cap persistent, "
+                    'set HELION_AUTOTUNE_PRECOMPILE="fork" to disable spawning, or reduce GPU memory usage.'
+                )
+            else:
+                raise exc.AutotuneError(
+                    "Autotune precompile spawn mode requires at least one job, but estimated "
+                    "memory usage exceeds available GPU memory. "
+                    f"Estimated {gib_per_job:.2f} GiB per job, but only "
+                    f"{available_gib:.2f} GiB free. "
+                    'Set HELION_AUTOTUNE_PRECOMPILE="fork" to disable spawning, or reduce GPU memory usage.'
+                )
+            jobs = jobs_by_memory
+
+        return jobs
+
     def _validate_against_baseline(
         self, config: Config, output: object, args: Sequence[object]
     ) -> bool:
@@ -179,7 +230,7 @@ def _validate_against_baseline(
         except AssertionError as e:
             self.counters["accuracy_mismatch"] += 1
             self.log.warning(
-                f"Skipping config with accuracy mismatch: {config!r}{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
+                f"Skipping config with accuracy mismatch: {config!r}\n{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
             )
             return False
         return True
@@ -454,6 +505,31 @@ def performance(member: PopulationMember) -> float:
     return member.perf
 
 
+def _estimate_tree_bytes(obj: object) -> int:
+    """Estimate the memory usage of a pytree of objects, counting shared storage only once."""
+    total = 0
+    seen_ptrs: set[int] = set()
+
+    def _accumulate(tensor: torch.Tensor) -> torch.Tensor:
+        nonlocal total
+        size = tensor.element_size() * tensor.numel()
+        try:
+            storage = tensor.untyped_storage()
+        except RuntimeError:
+            pass
+        else:
+            ptr = storage.data_ptr()
+            if ptr in seen_ptrs:
+                return tensor
+            seen_ptrs.add(ptr)
+            size = storage.nbytes()
+        total += size
+        return tensor
+
+    tree_map_only(torch.Tensor, _accumulate, obj)
+    return total
+
+
 class PopulationBasedSearch(BaseSearch):
     """
     Base class for search algorithms that use a population of configurations.
@@ -823,8 +899,7 @@ def _wait_for_all_step(
         futures: list[PrecompileFuture],
     ) -> list[PrecompileFuture]:
         """Start up to the concurrency cap, wait for progress, and return remaining futures."""
-        # Concurrency cap from the settings of the first future's search
-        cap = futures[0].search.settings.autotune_precompile_jobs or os.cpu_count() or 1
+        cap = futures[0].search._jobs if futures else 1
         running = [f for f in futures if f.started and f.ok is None and f.is_alive()]
 
         # Start queued futures up to the cap
diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index 9362ac18a..b7c99a766 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -145,6 +145,16 @@ def _get_autotune_precompile() -> str | None:
     )
 
 
+def _get_autotune_precompile_jobs() -> int | None:
+    value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS")
+    if value is None or value.strip() == "":
+        return None
+    jobs = int(value)
+    if jobs <= 0:
+        raise ValueError("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer")
+    return jobs
+
+
 @dataclasses.dataclass
 class _Settings:
     # see __slots__ below for the doc strings that show up in help(Settings)
@@ -164,7 +174,9 @@ class _Settings:
     autotune_precompile: str | None = dataclasses.field(
         default_factory=_get_autotune_precompile
     )
-    autotune_precompile_jobs: int | None = None
+    autotune_precompile_jobs: int | None = dataclasses.field(
+        default_factory=_get_autotune_precompile_jobs
+    )
     autotune_random_seed: int = dataclasses.field(
         default_factory=_get_autotune_random_seed
     )
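
Reviewer note: a minimal standalone sketch of the capping arithmetic that
_decide_num_jobs applies in "spawn" mode. It assumes, as the patch does, that
per-job memory is roughly 2x the bytes of the kernel args plus the baseline
output (computed there by _estimate_tree_bytes); the helper name capped_jobs
and the numbers in the usage comment are illustrative only, not part of the
patch.

import torch


def capped_jobs(requested: int, estimated_bytes: int, device: torch.device) -> int:
    # Double the estimate as a safety factor, matching the patch.
    per_job = estimated_bytes * 2
    if per_job <= 0 or device.type != "cuda":
        # No usable estimate, or a non-CUDA device: keep the requested cap.
        return requested
    free_bytes, _total = torch.cuda.mem_get_info(device)  # (free, total) in bytes
    by_memory = free_bytes // per_job
    if by_memory == 0:
        # The patch raises exc.AutotuneError in this case.
        raise RuntimeError("estimated per-job memory exceeds free GPU memory")
    return min(requested, by_memory)


# Example: ~1 GiB of args + baseline output gives 2 GiB per job; with 16 GiB
# free, a requested cap of 32 shrinks to 8 jobs.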