diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index dff6eada7..d00e5c12c 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -21,7 +21,10 @@ if TYPE_CHECKING: from triton.runtime.jit import JITFunction +import torch import torch.multiprocessing as mp +from torch.utils._pytree import tree_flatten +from torch.utils._pytree import tree_map from triton.testing import do_bench from .. import exc @@ -82,10 +85,63 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None: self.kernel = kernel self.settings: Settings = kernel.settings self.config_spec: ConfigSpec = kernel.config_spec - self.args = args + self.args: Sequence[object] = args self.counters: collections.Counter[str] = collections.Counter() self.log = LambdaLogger(self.settings.autotune_log_level) random.seed(self.settings.autotune_random_seed) + self._original_args: Sequence[object] = self._clone_args(self.args) + ( + self._baseline_output, + self._kernel_mutates_args, + self._baseline_post_args, + ) = self._compute_baseline() + + def _clone_args(self, args: Sequence[object]) -> Sequence[object]: + def _clone_leaf(leaf: object) -> object: + if isinstance(leaf, torch.Tensor): + clone = leaf.detach().clone() + clone.requires_grad_(leaf.requires_grad) + return clone + return leaf + + return tree_map(_clone_leaf, args) + + def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]: + """ + Return output and post-run input arguments of the default-config kernel. + Also detect if the kernel mutates any of its input arguments. + """ + new_args = self._clone_args(self._original_args) + baseline_config = self.config_spec.default_config() + baseline_output = self.kernel.compile_config( + baseline_config, allow_print=False + )(*new_args) + original_args_flat, _ = tree_flatten(self._original_args) + new_args_flat, _ = tree_flatten(new_args) + mutated = False + for old, new in zip(original_args_flat, new_args_flat, strict=False): + if ( + isinstance(old, torch.Tensor) + and isinstance(new, torch.Tensor) + and (not torch.equal(new, old)) + ): + mutated = True + break + baseline_post_args = self._clone_args(new_args) + return baseline_output, mutated, baseline_post_args + + def _validate_against_baseline( + self, config: Config, output: object, args: Sequence[object] + ) -> bool: + try: + torch.testing.assert_close(output, self._baseline_output) + if self._kernel_mutates_args: + torch.testing.assert_close(args, self._baseline_post_args) + except AssertionError as e: + self.counters["accuracy_mismatch"] += 1 + self.log.warning(f"Accuracy mismatch for {config!r}: {e!s}") + return False + return True def benchmark(self, config: Config) -> float: """ @@ -121,7 +177,12 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: try: # TODO(jansel): early exit with fewer trials if early runs are slow t0 = time.perf_counter() - fn(*self.args) # make sure the kernel is compiled + if self._kernel_mutates_args: + self.args = self._clone_args(self._original_args) + output = fn(*self.args) # make sure the kernel is compiled + if not self._validate_against_baseline(config, output, self.args): + # Accuracy check failed; reject this config + return inf t1 = time.perf_counter() res = do_bench( functools.partial(fn, *self.args), diff --git a/test/test_autotuner.py b/test/test_autotuner.py index 94288d864..5c7d74b9e 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import contextmanager +import math import os from pathlib import Path import random @@ -20,6 +21,7 @@ from helion._testing import skipIfRocm from helion.autotuner import DifferentialEvolutionSearch from helion.autotuner.config_generation import ConfigGeneration +from helion.autotuner.finite_search import FiniteSearch from helion.autotuner.random_search import RandomSearch import helion.language as hl from helion.language import loops @@ -172,6 +174,105 @@ def test_differential_evolution_search(self): fn = bound_kernel.compile_config(best) torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1) + def test_accuracy_check_filters_bad_config_wrong_output(self) -> None: + bad_config = helion.Config(block_sizes=[1], num_warps=8) + good_config = helion.Config(block_sizes=[1], num_warps=4) + + @helion.kernel(configs=[bad_config, good_config], autotune_log_level=0) + def add_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + for tile in hl.tile(b.size()): + b[tile] = a[tile] + b[tile] + return b + + a = torch.randn([32], device=DEVICE) + b = torch.randn([32], device=DEVICE) + bound_kernel = add_inplace.bind((a, b)) + + original_compile = bound_kernel.compile_config + + def make_bad_config_produce_wrong_output( + config: helion.Config, *, allow_print: bool = True + ): + fn = original_compile(config, allow_print=allow_print) + if config == bad_config: + return lambda *fn_args, **fn_kwargs: fn(*fn_args, **fn_kwargs) + 1 + return fn + + with patch.object( + bound_kernel, + "compile_config", + side_effect=make_bad_config_produce_wrong_output, + ): + search = FiniteSearch( + bound_kernel, (a, b), configs=[bad_config, good_config] + ) + bad_time = search.benchmark(bad_config) + assert math.isinf(bad_time) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) + search.counters["accuracy_mismatch"] = 0 # reset counter + + good_time = search.benchmark(good_config) + assert not math.isinf(good_time) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) + search.counters["accuracy_mismatch"] = 0 # reset counter + + best = search._autotune() + self.assertEqual(best, good_config) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) + + def test_accuracy_check_filters_bad_config_wrong_arg_mutation(self) -> None: + bad_config = helion.Config(block_sizes=[1], num_warps=8) + good_config = helion.Config(block_sizes=[1], num_warps=4) + + @helion.kernel(configs=[bad_config, good_config], autotune_log_level=0) + def add_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + for tile in hl.tile(b.size()): + b[tile] = a[tile] + b[tile] + return b + + a = torch.randn([32], device=DEVICE) + b = torch.randn([32], device=DEVICE) + bound_kernel = add_inplace.bind((a, b)) + + original_compile = bound_kernel.compile_config + + def make_bad_config_produce_wrong_input_arg_mutation( + config: helion.Config, *, allow_print: bool = True + ): + fn = original_compile(config, allow_print=allow_print) + if config == bad_config: + + def wrong_fn(*fn_args, **fn_kwargs): + result = fn(*fn_args, **fn_kwargs) + # Introduce an extra mutation so inputs differ from baseline + fn_args[1].add_(1) + return result + + return wrong_fn + return fn + + with patch.object( + bound_kernel, + "compile_config", + side_effect=make_bad_config_produce_wrong_input_arg_mutation, + ): + search = FiniteSearch( + bound_kernel, (a, b), configs=[bad_config, good_config] + ) + bad_time = search.benchmark(bad_config) + assert math.isinf(bad_time) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) + search.counters["accuracy_mismatch"] = 0 # reset counter + + good_time = search.benchmark(good_config) + assert not math.isinf(good_time) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) + search.counters["accuracy_mismatch"] = 0 # reset counter + + best = search._autotune() + self.assertEqual(best, good_config) + self.assertGreaterEqual(search.counters.get("accuracy_mismatch", 0), 1) + def test_use_default_config(self): @helion.kernel(use_default_config=True) def add(a, b):