Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/settings.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ with helion.set_default_settings(
.. autoattribute:: Settings.autotune_precompile_jobs

Cap the number of concurrent Triton precompile subprocesses. ``None`` (default) uses the machine CPU count.
Controlled by ``HELION_AUTOTUNE_PRECOMPILE_JOBS``.
When using ``"spawn"`` precompile mode, Helion may automatically lower this cap if free GPU memory is limited.

.. autoattribute:: Settings.autotune_max_generations

Expand Down
99 changes: 87 additions & 12 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import torch
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_map_only
from triton.testing import do_bench

from .. import exc
Expand Down Expand Up @@ -81,6 +82,10 @@ class BaseSearch(BaseAutotuner):
counters (collections.Counter): A counter to track various metrics during the search.
"""

_baseline_output: object
_kernel_mutates_args: bool
_baseline_post_args: Sequence[object] | None

def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
"""
Initialize the BaseSearch object.
Expand All @@ -101,17 +106,14 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
random.seed(seed)
self.log(f"Autotune random seed: {seed}")
self._original_args: Sequence[object] = self._clone_args(self.args)
self._baseline_output: object | None = None
self._baseline_post_args: Sequence[object] | None = None
self._kernel_mutates_args: bool = False
self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None
self._precompile_args_path: str | None = None
if self.settings.autotune_accuracy_check:
(
self._baseline_output,
self._kernel_mutates_args,
self._baseline_post_args,
) = self._compute_baseline()
(
self._baseline_output,
self._kernel_mutates_args,
self._baseline_post_args,
) = self._compute_baseline()
self._jobs = self._decide_num_jobs()

def cleanup(self) -> None:
if self._precompile_tmpdir is not None:
Expand Down Expand Up @@ -165,6 +167,55 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
baseline_post_args = self._clone_args(new_args)
return baseline_output, mutated, baseline_post_args

def _decide_num_jobs(self) -> int:
if not self.settings.autotune_precompile:
return 1

jobs = self.settings.autotune_precompile_jobs
if not jobs:
jobs = os.cpu_count() or 1

if self.settings.autotune_precompile != "spawn":
return jobs

memory_per_job = _estimate_tree_bytes(self.args) + _estimate_tree_bytes(
self._baseline_output
)
memory_per_job *= 2 # safety factor
if memory_per_job <= 0:
return jobs

device = self.kernel.env.device
if device.type != "cuda":
# TODO(jansel): support non-cuda devices
return jobs

available_memory, _ = torch.cuda.mem_get_info(device)
jobs_by_memory = available_memory // memory_per_job
if jobs_by_memory < jobs:
gib_per_job = memory_per_job / (1024**3)
available_gib = available_memory / (1024**3)
if jobs_by_memory > 0:
self.log.warning(
f"Reducing autotune precompile spawn jobs from {jobs} to {jobs_by_memory} "
f"due to limited GPU memory (estimated {gib_per_job:.2f} GiB per job, "
f"{available_gib:.2f} GiB free). "
f"Set HELION_AUTOTUNE_PRECOMPILE_JOBS={jobs_by_memory} "
"to make this lower cap persistent, "
'set HELION_AUTOTUNE_PRECOMPILE="fork" to disable spawning, or reduce GPU memory usage.'
)
else:
raise exc.AutotuneError(
"Autotune precompile spawn mode requires at least one job, but estimated "
"memory usage exceeds available GPU memory."
f"Estimated {gib_per_job:.2f} GiB per job, but only "
f"{available_gib:.2f} GiB free. "
'Set HELION_AUTOTUNE_PRECOMPILE="fork" to disable spawning, or reduce GPU memory usage.'
)
jobs = jobs_by_memory

return jobs

def _validate_against_baseline(
self, config: Config, output: object, args: Sequence[object]
) -> bool:
Expand All @@ -179,7 +230,7 @@ def _validate_against_baseline(
except AssertionError as e:
self.counters["accuracy_mismatch"] += 1
self.log.warning(
f"Skipping config with accuracy mismatch: {config!r}{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
f"Skipping config with accuracy mismatch: {config!r}\n{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check.\n"
)
return False
return True
Expand Down Expand Up @@ -454,6 +505,31 @@ def performance(member: PopulationMember) -> float:
return member.perf


def _estimate_tree_bytes(obj: object) -> int:
"""Estimate the memory usage of a pytree of objects, counting shared storage only once."""
total = 0
seen_ptrs: set[int] = set()

def _accumulate(tensor: torch.Tensor) -> torch.Tensor:
nonlocal total
size = tensor.element_size() * tensor.numel()
try:
storage = tensor.untyped_storage()
except RuntimeError:
pass
else:
ptr = storage.data_ptr()
if ptr in seen_ptrs:
return tensor
seen_ptrs.add(ptr)
size = storage.nbytes()
total += size
return tensor

tree_map_only(torch.Tensor, _accumulate, obj)
return total


class PopulationBasedSearch(BaseSearch):
"""
Base class for search algorithms that use a population of configurations.
Expand Down Expand Up @@ -823,8 +899,7 @@ def _wait_for_all_step(
futures: list[PrecompileFuture],
) -> list[PrecompileFuture]:
"""Start up to the concurrency cap, wait for progress, and return remaining futures."""
# Concurrency cap from the settings of the first future's search
cap = futures[0].search.settings.autotune_precompile_jobs or os.cpu_count() or 1
cap = futures[0].search._jobs if futures else 1
running = [f for f in futures if f.started and f.ok is None and f.is_alive()]

# Start queued futures up to the cap
Expand Down
14 changes: 13 additions & 1 deletion helion/runtime/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,16 @@ def _get_autotune_precompile() -> str | None:
)


def _get_autotune_precompile_jobs() -> int | None:
value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE_JOBS")
if value is None or value.strip() == "":
return None
jobs = int(value)
if jobs <= 0:
raise ValueError("HELION_AUTOTUNE_PRECOMPILE_JOBS must be a positive integer")
return jobs


@dataclasses.dataclass
class _Settings:
# see __slots__ below for the doc strings that show up in help(Settings)
Expand All @@ -164,7 +174,9 @@ class _Settings:
autotune_precompile: str | None = dataclasses.field(
default_factory=_get_autotune_precompile
)
autotune_precompile_jobs: int | None = None
autotune_precompile_jobs: int | None = dataclasses.field(
default_factory=_get_autotune_precompile_jobs
)
autotune_random_seed: int = dataclasses.field(
default_factory=_get_autotune_random_seed
)
Expand Down