From 0d489391841db7d48d5070f7510e8d1fbc978db4 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 25 Sep 2025 18:30:01 -0700 Subject: [PATCH] Support HELION_AUTOTUNE_ACCURACY_CHECK=0 stack-info: PR: https://github.com/pytorch/helion/pull/692, branch: jansel/stack/149 --- helion/autotuner/base_search.py | 23 ++++++++++++++++------- helion/runtime/settings.py | 4 ++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index a1b9f0b00..be191fff0 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -92,11 +92,15 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None: random.seed(seed) self.log(f"Autotune random seed: {seed}") self._original_args: Sequence[object] = self._clone_args(self.args) - ( - self._baseline_output, - self._kernel_mutates_args, - self._baseline_post_args, - ) = self._compute_baseline() + self._baseline_output: object | None = None + self._baseline_post_args: Sequence[object] | None = None + self._kernel_mutates_args: bool = False + if self.settings.autotune_accuracy_check: + ( + self._baseline_output, + self._kernel_mutates_args, + self._baseline_post_args, + ) = self._compute_baseline() def _clone_args(self, args: Sequence[object]) -> Sequence[object]: def _clone_leaf(leaf: object) -> object: @@ -145,7 +149,9 @@ def _validate_against_baseline( ) except AssertionError as e: self.counters["accuracy_mismatch"] += 1 - self.log.warning(f"Accuracy mismatch for {config!r}: {e!s}") + self.log.warning( + f"Skipping config with accuracy mismatch: {config!r}\n{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check." + ) return False return True @@ -186,7 +192,10 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: if self._kernel_mutates_args: self.args = self._clone_args(self._original_args) output = fn(*self.args) # make sure the kernel is compiled - if not self._validate_against_baseline(config, output, self.args): + if ( + self.settings.autotune_accuracy_check + and not self._validate_against_baseline(config, output, self.args) + ): # Accuracy check failed; reject this config return inf t1 = time.perf_counter() diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index f729074b0..82809361a 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -98,6 +98,9 @@ class _Settings: autotune_random_seed: int = dataclasses.field( default_factory=_get_autotune_random_seed ) + autotune_accuracy_check: bool = ( + os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1" + ) print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1" force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1" allow_warp_specialize: bool = ( @@ -127,6 +130,7 @@ class Settings(_Settings): "autotune_precompile": "If True, precompile the kernel before autotuning. Requires fork-safe environment.", "autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.", "autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.", + "autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.", "print_output_code": "If True, print the output code of the kernel to stderr.", "force_autotune": "If True, force autotuning even if a config is provided.", "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",