From 0ec45eed8495d92ba987f2f92e9b4de3d2868fd8 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 21 Sep 2025 19:43:57 -0700 Subject: [PATCH] Rebenchmark configs to avoid noise stack-info: PR: https://github.com/pytorch/helion/pull/654, branch: jansel/stack/146 --- helion/autotuner/base_search.py | 141 ++++++++++++++++++--- helion/autotuner/benchmarking.py | 49 +++++++ helion/autotuner/differential_evolution.py | 3 +- helion/autotuner/finite_search.py | 2 +- helion/runtime/settings.py | 4 + test/test_autotuner.py | 8 +- test/test_errors.py | 9 +- 7 files changed, 188 insertions(+), 28 deletions(-) create mode 100644 helion/autotuner/benchmarking.py diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index df0442829..0a8ac12ab 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -15,9 +15,11 @@ import sys import time from typing import TYPE_CHECKING -from typing import NamedTuple +from typing import Callable from typing import NoReturn +from .benchmarking import interleaved_bench + if TYPE_CHECKING: from triton.runtime.jit import JITFunction @@ -90,6 +92,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None: self.args: Sequence[object] = args self.counters: collections.Counter[str] = collections.Counter() self.log = LambdaLogger(self.settings.autotune_log_level) + self.best_perf_so_far = inf seed = self.settings.autotune_random_seed random.seed(seed) self.log(f"Autotune random seed: {seed}") @@ -167,7 +170,7 @@ def _validate_against_baseline( return False return True - def benchmark(self, config: Config) -> float: + def benchmark(self, config: Config) -> tuple[Callable[..., object], float]: """ Benchmark a specific configuration. @@ -177,12 +180,12 @@ def benchmark(self, config: Config) -> float: config: The configuration to benchmark. Returns: - The performance of the configuration in seconds. + The function and performance of the configuration in ms. """ fn = self.kernel.compile_config(config, allow_print=False) if self.start_precompile_and_check_for_hangs(config, fn)(): - return self.benchmark_function(config, fn) - return inf + return fn, self.benchmark_function(config, fn) + return fn, inf def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: """ @@ -194,7 +197,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: fn: A precompiled version of config. Returns: - The performance of the configuration in seconds. + The performance of the configuration in ms. """ self.counters["benchmark"] += 1 self.log.debug(lambda: f"Running benchmark for {config!r}") @@ -214,12 +217,17 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: res = do_bench( functools.partial(fn, *self.args), return_mode="median", + warmup=1, # we are already warmed up above + rep=50, ) t2 = time.perf_counter() + assert isinstance(res, float) self.log.debug( lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)", ) - return res # pyright: ignore[reportReturnType] + if res < self.best_perf_so_far: + self.best_perf_so_far = res + return res except Exception as e: action = classify_triton_exception(e) if action == "raise": @@ -286,7 +294,9 @@ def extract_launcher( timeout=self.settings.autotune_compile_timeout, ) - def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]: + def parallel_benchmark( + self, configs: list[Config] + ) -> list[tuple[Config, Callable[..., object], float]]: """ Benchmark multiple configurations in parallel. @@ -312,9 +322,9 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float] for config, fn, is_working in zip(configs, fns, is_workings, strict=True): if is_working: # benchmark one-by-one to avoid noisy results - results.append((config, self.benchmark_function(config, fn))) + results.append((config, fn, self.benchmark_function(config, fn))) else: - results.append((config, inf)) + results.append((config, fn, inf)) return results def autotune(self) -> Config: @@ -356,20 +366,26 @@ def _autotune(self) -> Config: raise NotImplementedError -class PopulationMember(NamedTuple): +@dataclasses.dataclass +class PopulationMember: """ Represents a member of the population in population-based search algorithms. Attributes: - perf (float): The performance of the configuration. + perfs (list[float]): The performance of the configuration, accumulated over multiple benchmarks. flat_values (FlatConfig): The flat representation of the configuration values. config (Config): The full configuration object. """ - perf: float + fn: Callable[..., object] + perfs: list[float] flat_values: FlatConfig config: Config + @property + def perf(self) -> float: + return self.perfs[-1] + def performance(member: PopulationMember) -> float: """ @@ -430,7 +446,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember: A population member with the benchmark results. """ config = self.config_gen.unflatten(flat_values) - return PopulationMember(self.benchmark(config), flat_values, config) + fn, perf = self.benchmark(config) + return PopulationMember(fn, [perf], flat_values, config) def parallel_benchmark_flat( self, to_check: list[FlatConfig] @@ -444,14 +461,92 @@ def parallel_benchmark_flat( Returns: A list of population members with the benchmark results. """ - configs = [*map(self.config_gen.unflatten, to_check)] - result = [] - for flat_values, config_in, (config_out, perf) in zip( - to_check, configs, self.parallel_benchmark(configs), strict=True + result = [*map(self.make_unbenchmarked, to_check)] + return self.parallel_benchmark_population(result) + + def make_unbenchmarked(self, flat_values: FlatConfig) -> PopulationMember: + """ + Create a population member with unbenchmarked configuration. You + should pass the result of this to parallel_benchmark_population. + + Args: + flat_values: The flat configuration values. + + Returns: + A population member with undefined performance. + """ + config = self.config_gen.unflatten(flat_values) + return PopulationMember(_unset_fn, [], flat_values, config) + + def parallel_benchmark_population( + self, members: list[PopulationMember] + ) -> list[PopulationMember]: + """ + Benchmark multiple population members in parallel. Members should be created with make_unbenchmarked. + """ + for member, (config_out, fn, perf) in zip( + members, self.parallel_benchmark([m.config for m in members]), strict=True ): - assert config_in is config_out - result.append(PopulationMember(perf, flat_values, config_in)) - return result + assert config_out is member.config + member.perfs.append(perf) + member.fn = fn + return members + + def compare(self, a: PopulationMember, b: PopulationMember) -> int: + """ + Compare two population members based on their performance, possibly with re-benchmarking. + + Args: + a: The first population member. + b: The second population member. + + Returns: + -1 if a is better than b, 1 if b is better than a, 0 if they are equal. + """ + if self.should_rebenchmark(a) and self.should_rebenchmark(b): + self.rebenchmark([a, b]) + return (a.perf > b.perf) - (a.perf < b.perf) + + def should_rebenchmark(self, member: PopulationMember) -> bool: + """ + Determine if a population member should be re-benchmarked to avoid outliers. + + Args: + member: The population member to check. + + Returns: + True if the member should be re-benchmarked, False otherwise. + """ + return ( + member.perf + < self.settings.autotune_rebenchmark_threshold * self.best_perf_so_far + and math.isfinite(member.perf) + ) + + def rebenchmark(self, members: list[PopulationMember]) -> None: + """ + Re-benchmark a list of population members to avoid outliers. + """ + if len(members) < 2: + return + repeat = max(3, int(200 / self.best_perf_so_far)) + new_timings = interleaved_bench( + [functools.partial(m.fn, *self.args) for m in members], repeat=repeat + ) + for m, t in zip(members, new_timings, strict=True): + m.perfs.append(t) + if t < self.best_perf_so_far: + self.best_perf_so_far = t + + def rebenchmark_population( + self, members: list[PopulationMember] | None = None + ) -> None: + """ + Re-benchmark the entire population to avoid outliers. + """ + if members is None: + members = self.population + self.rebenchmark([p for p in members if self.should_rebenchmark(p)]) def statistics(self) -> str: """ @@ -697,3 +792,7 @@ def __init__( self.grid = grid self.args = args self.kwargs = kwargs + + +def _unset_fn(*args: object) -> NoReturn: + raise RuntimeError("Uninitialized function") diff --git a/helion/autotuner/benchmarking.py b/helion/autotuner/benchmarking.py new file mode 100644 index 000000000..2533b1e2b --- /dev/null +++ b/helion/autotuner/benchmarking.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import functools +import statistics +from typing import Callable + +from triton import runtime + + +def interleaved_bench(fns: list[Callable[[], object]], *, repeat: int) -> list[float]: + """ + Benchmark multiple functions at once, interleaving their executions to reduce + the impact of external factors (e.g., load, temperature) on the + measurements. + """ + # warmup + for fn in fns: + fn() + clear_cache = functools.partial( + runtime.driver.active.clear_cache, # type: ignore[attr-defined] + runtime.driver.active.get_empty_cache_for_benchmark(), # type: ignore[attr-defined] + ) + clear_cache() + di = runtime.driver.active.get_device_interface() # type: ignore[attr-defined] + start_events = [ + [di.Event(enable_timing=True) for _ in range(repeat)] for _ in range(len(fns)) + ] + end_events = [ + [di.Event(enable_timing=True) for _ in range(repeat)] for _ in range(len(fns)) + ] + + di.synchronize() + for i in range(repeat): + for j in range(len(fns)): + clear_cache() + start_events[j][i].record() + fns[j]() + end_events[j][i].record() + di.synchronize() + + return [ + statistics.median( + [ + s.elapsed_time(e) + for s, e in zip(start_events[j], end_events[j], strict=True) + ] + ) + for j in range(len(fns)) + ] diff --git a/helion/autotuner/differential_evolution.py b/helion/autotuner/differential_evolution.py index 4e922f62c..bbddbe0b1 100644 --- a/helion/autotuner/differential_evolution.py +++ b/helion/autotuner/differential_evolution.py @@ -81,7 +81,7 @@ def iter_candidates(self) -> Iterator[tuple[int, PopulationMember]]: def evolve_population(self) -> int: replaced = 0 for i, candidate in self.iter_candidates(): - if candidate.perf < self.population[i].perf: + if self.compare(candidate, self.population[i]) < 0: self.population[i] = candidate replaced += 1 return replaced @@ -97,4 +97,5 @@ def _autotune(self) -> Config: for i in range(2, self.num_generations): replaced = self.evolve_population() self.log(f"Generation {i}: replaced={replaced}", self.statistics) + self.rebenchmark_population() return self.best.config diff --git a/helion/autotuner/finite_search.py b/helion/autotuner/finite_search.py index 5c1717c09..379c1423a 100644 --- a/helion/autotuner/finite_search.py +++ b/helion/autotuner/finite_search.py @@ -35,7 +35,7 @@ def __init__( def _autotune(self) -> Config: best_config = None best_time = float("inf") - for config, time in self.parallel_benchmark(self.configs): + for config, _fn, time in self.parallel_benchmark(self.configs): if time < best_time: best_time = time best_config = config diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 82809361a..5361b6702 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -101,6 +101,9 @@ class _Settings: autotune_accuracy_check: bool = ( os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1" ) + autotune_rebenchmark_threshold: float = float( + os.environ.get("HELION_REBENCHMARK_THRESHOLD", "1.5") + ) print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1" force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1" allow_warp_specialize: bool = ( @@ -131,6 +134,7 @@ class Settings(_Settings): "autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.", "autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.", "autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.", + "autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Default is 1.5x. Set to <1 to disable.", "print_output_code": "If True, print the output code of the kernel to stderr.", "force_autotune": "If True, force autotuning even if a config is provided.", "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.", diff --git a/test/test_autotuner.py b/test/test_autotuner.py index 5c7d74b9e..41b5f8d60 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -206,12 +206,12 @@ def make_bad_config_produce_wrong_output( search = FiniteSearch( bound_kernel, (a, b), configs=[bad_config, good_config] ) - bad_time = search.benchmark(bad_config) + _, bad_time = search.benchmark(bad_config) assert math.isinf(bad_time) self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) search.counters["accuracy_mismatch"] = 0 # reset counter - good_time = search.benchmark(good_config) + _, good_time = search.benchmark(good_config) assert not math.isinf(good_time) self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) search.counters["accuracy_mismatch"] = 0 # reset counter @@ -259,12 +259,12 @@ def wrong_fn(*fn_args, **fn_kwargs): search = FiniteSearch( bound_kernel, (a, b), configs=[bad_config, good_config] ) - bad_time = search.benchmark(bad_config) + _, bad_time = search.benchmark(bad_config) assert math.isinf(bad_time) self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) search.counters["accuracy_mismatch"] = 0 # reset counter - good_time = search.benchmark(good_config) + _, good_time = search.benchmark(good_config) assert not math.isinf(good_time) self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) search.counters["accuracy_mismatch"] = 0 # reset counter diff --git a/test/test_errors.py b/test/test_errors.py index c28e2e841..18dae0f19 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -50,7 +50,14 @@ def fake_parallel( members = [] for flat_values in to_check: cfg = self.config_gen.unflatten(flat_values) - members.append(PopulationMember(float("inf"), flat_values, cfg)) + members.append( + PopulationMember( + lambda *args: None, + [float("inf")], + flat_values, + cfg, + ) + ) return members with (