Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
random.seed(seed)
self.log(f"Autotune random seed: {seed}")
self._original_args: Sequence[object] = self._clone_args(self.args)
(
self._baseline_output,
self._kernel_mutates_args,
self._baseline_post_args,
) = self._compute_baseline()
self._baseline_output: object | None = None
self._baseline_post_args: Sequence[object] | None = None
self._kernel_mutates_args: bool = False
if self.settings.autotune_accuracy_check:
(
self._baseline_output,
self._kernel_mutates_args,
self._baseline_post_args,
) = self._compute_baseline()

def _clone_args(self, args: Sequence[object]) -> Sequence[object]:
def _clone_leaf(leaf: object) -> object:
Expand Down Expand Up @@ -145,7 +149,9 @@ def _validate_against_baseline(
)
except AssertionError as e:
self.counters["accuracy_mismatch"] += 1
self.log.warning(f"Accuracy mismatch for {config!r}: {e!s}")
self.log.warning(
f"Skipping config with accuracy mismatch: {config!r}\n{e!s}\nUse HELION_AUTOTUNE_ACCURACY_CHECK=0 to disable this check."
)
return False
return True

Expand Down Expand Up @@ -186,7 +192,10 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
if self._kernel_mutates_args:
self.args = self._clone_args(self._original_args)
output = fn(*self.args) # make sure the kernel is compiled
if not self._validate_against_baseline(config, output, self.args):
if (
self.settings.autotune_accuracy_check
and not self._validate_against_baseline(config, output, self.args)
):
# Accuracy check failed; reject this config
return inf
t1 = time.perf_counter()
Expand Down
4 changes: 4 additions & 0 deletions helion/runtime/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class _Settings:
autotune_random_seed: int = dataclasses.field(
default_factory=_get_autotune_random_seed
)
autotune_accuracy_check: bool = (
os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1"
)
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
allow_warp_specialize: bool = (
Expand Down Expand Up @@ -127,6 +130,7 @@ class Settings(_Settings):
"autotune_precompile": "If True, precompile the kernel before autotuning. Requires fork-safe environment.",
"autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.",
"autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.",
"autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.",
"print_output_code": "If True, print the output code of the kernel to stderr.",
"force_autotune": "If True, force autotuning even if a config is provided.",
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
Expand Down
Loading