Skip to content

Commit eb406f4

Browse files
committed
Rebenchmark configs to avoid noise
stack-info: PR: #654, branch: jansel/stack/146
1 parent 9a6f893 commit eb406f4

File tree

5 files changed

+137
-16
lines changed

5 files changed

+137
-16
lines changed

helion/autotuner/base_search.py

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
import sys
1616
import time
1717
from typing import TYPE_CHECKING
18+
from typing import Callable
1819
from typing import NamedTuple
1920
from typing import NoReturn
2021

22+
from .benchmarking import interleaved_bench
23+
2124
if TYPE_CHECKING:
2225
from triton.runtime.jit import JITFunction
2326

@@ -90,6 +93,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
9093
self.args: Sequence[object] = args
9194
self.counters: collections.Counter[str] = collections.Counter()
9295
self.log = LambdaLogger(self.settings.autotune_log_level)
96+
self.best_perf_so_far = inf
9397
seed = self.settings.autotune_random_seed
9498
random.seed(seed)
9599
self.log(f"Autotune random seed: {seed}")
@@ -157,7 +161,7 @@ def _validate_against_baseline(
157161
return False
158162
return True
159163

160-
def benchmark(self, config: Config) -> float:
164+
def benchmark(self, config: Config) -> tuple[Callable[..., object], float]:
161165
"""
162166
Benchmark a specific configuration.
163167
@@ -167,12 +171,12 @@ def benchmark(self, config: Config) -> float:
167171
config: The configuration to benchmark.
168172
169173
Returns:
170-
The performance of the configuration in seconds.
174+
The function and performance of the configuration in ms.
171175
"""
172176
fn = self.kernel.compile_config(config, allow_print=False)
173177
if self.start_precompile_and_check_for_hangs(config, fn)():
174-
return self.benchmark_function(config, fn)
175-
return inf
178+
return fn, self.benchmark_function(config, fn)
179+
return fn, inf
176180

177181
def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
178182
"""
@@ -184,7 +188,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
184188
fn: A precompiled version of config.
185189
186190
Returns:
187-
The performance of the configuration in seconds.
191+
The performance of the configuration in ms.
188192
"""
189193
self.counters["benchmark"] += 1
190194
self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -206,10 +210,13 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
206210
return_mode="median",
207211
)
208212
t2 = time.perf_counter()
213+
assert isinstance(res, float)
209214
self.log.debug(
210215
lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
211216
)
212-
return res # pyright: ignore[reportReturnType]
217+
if res < self.best_perf_so_far:
218+
self.best_perf_so_far = res
219+
return res
213220
except Exception as e:
214221
action = classify_triton_exception(e)
215222
if action == "raise":
@@ -276,7 +283,9 @@ def extract_launcher(
276283
timeout=self.settings.autotune_compile_timeout,
277284
)
278285

279-
def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
286+
def parallel_benchmark(
287+
self, configs: list[Config]
288+
) -> list[tuple[Config, Callable[..., object], float]]:
280289
"""
281290
Benchmark multiple configurations in parallel.
282291
@@ -302,9 +311,9 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
302311
for config, fn, is_working in zip(configs, fns, is_workings, strict=True):
303312
if is_working:
304313
# benchmark one-by-one to avoid noisy results
305-
results.append((config, self.benchmark_function(config, fn)))
314+
results.append((config, fn, self.benchmark_function(config, fn)))
306315
else:
307-
results.append((config, inf))
316+
results.append((config, fn, inf))
308317
return results
309318

310319
def autotune(self) -> Config:
@@ -351,15 +360,20 @@ class PopulationMember(NamedTuple):
351360
Represents a member of the population in population-based search algorithms.
352361
353362
Attributes:
354-
perf (float): The performance of the configuration.
363+
perfs (list[float]): The performance of the configuration, accumulated over multiple benchmarks.
355364
flat_values (FlatConfig): The flat representation of the configuration values.
356365
config (Config): The full configuration object.
357366
"""
358367

359-
perf: float
368+
fn: Callable[..., object]
369+
perfs: list[float]
360370
flat_values: FlatConfig
361371
config: Config
362372

373+
@property
def perf(self) -> float:
    """
    Return the current performance figure for this member.

    This is the last entry of ``perfs``, i.e. the most recently recorded
    measurement rather than an aggregate of all of them.
    """
    latest = self.perfs[-1]
    return latest
376+
363377

364378
def performance(member: PopulationMember) -> float:
365379
"""
@@ -420,7 +434,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
420434
A population member with the benchmark results.
421435
"""
422436
config = self.config_gen.unflatten(flat_values)
423-
return PopulationMember(self.benchmark(config), flat_values, config)
437+
fn, perf = self.benchmark(config)
438+
return PopulationMember(fn, [perf], flat_values, config)
424439

425440
def parallel_benchmark_flat(
426441
self, to_check: list[FlatConfig]
@@ -436,13 +451,65 @@ def parallel_benchmark_flat(
436451
"""
437452
configs = [*map(self.config_gen.unflatten, to_check)]
438453
result = []
439-
for flat_values, config_in, (config_out, perf) in zip(
454+
for flat_values, config_in, (config_out, fn, perf) in zip(
440455
to_check, configs, self.parallel_benchmark(configs), strict=True
441456
):
442457
assert config_in is config_out
443-
result.append(PopulationMember(perf, flat_values, config_in))
458+
result.append(PopulationMember(fn, [perf], flat_values, config_in))
444459
return result
445460

461+
def compare(self, a: PopulationMember, b: PopulationMember) -> int:
    """
    Order two population members by performance; a lower perf ranks better.

    When both members qualify under should_rebenchmark(), they are first
    re-measured together via rebenchmark() so that the decision is based on
    their newest timings.

    Args:
        a: The first population member.
        b: The second population member.

    Returns:
        -1 if a is better than b, 1 if b is better than a, 0 if they are equal.
    """
    both_near_best = self.should_rebenchmark(a) and self.should_rebenchmark(b)
    if both_near_best:
        self.rebenchmark([a, b])
    # Read perf only after the optional re-benchmark so new timings are used.
    lhs, rhs = a.perf, b.perf
    if lhs < rhs:
        return -1
    if lhs > rhs:
        return 1
    return 0
475+
476+
def should_rebenchmark(self, member: PopulationMember) -> bool:
    """
    Decide whether a member is close enough to the best result to re-measure.

    Args:
        member: The population member to check.

    Returns:
        True when the member's perf is finite and strictly below
        ``settings.autotune_rebenchmark_threshold * best_perf_so_far``,
        False otherwise.
    """
    cutoff = self.settings.autotune_rebenchmark_threshold * self.best_perf_so_far
    current = member.perf
    # Written as "not (<)" so that a NaN perf is rejected as well.
    if not (current < cutoff):
        return False
    # A non-finite perf (e.g. inf) is never worth re-measuring.
    return math.isfinite(current)
491+
492+
def rebenchmark(self, members: list[PopulationMember]) -> None:
493+
"""
494+
Re-benchmark a list of population members to avoid outliers.
495+
"""
496+
if len(members) < 2:
497+
return
498+
repeat = max(3, int(200 / self.best_perf_so_far))
499+
new_timings = interleaved_bench(
500+
[functools.partial(m.fn, *self.args) for m in members], repeat=repeat
501+
)
502+
for m, t in zip(members, new_timings, strict=True):
503+
m.perfs.append(t)
504+
if t < self.best_perf_so_far:
505+
self.best_perf_so_far = t
506+
507+
def rebenchmark_population(self) -> None:
    """
    Re-measure the population members that pass should_rebenchmark().

    The qualifying members are collected in population order and handed to
    rebenchmark() as a single batch.
    """
    near_best = list(filter(self.should_rebenchmark, self.population))
    self.rebenchmark(near_best)
512+
446513
def statistics(self) -> str:
447514
"""
448515
Generate statistics for the current population.

helion/autotuner/benchmarking.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
import statistics
5+
from typing import Callable
6+
7+
from triton import runtime
8+
9+
10+
def interleaved_bench(fns: list[Callable[[], object]], *, repeat: int) -> list[float]:
    """
    Time several functions together, alternating between them on every round.

    Interleaving the executions spreads external disturbances (e.g., load,
    temperature) across all candidates instead of letting them bias one.

    Args:
        fns: Zero-argument callables to measure.
        repeat: Number of timed rounds; every function runs once per round.

    Returns:
        One value per function, in the order of ``fns``: the median of its
        per-round elapsed times as reported by the device timing events.
    """
    # Warmup: run every function once before anything is timed.
    for warm in fns:
        warm()

    driver = runtime.driver.active  # type: ignore[attr-defined]
    flush_cache = functools.partial(
        driver.clear_cache,
        driver.get_empty_cache_for_benchmark(),
    )
    flush_cache()
    device = driver.get_device_interface()

    def make_events() -> list[list[object]]:
        # One timing event per (function, round) pair.
        return [
            [device.Event(enable_timing=True) for _ in range(repeat)] for _ in fns
        ]

    starts = make_events()
    stops = make_events()

    device.synchronize()
    for round_idx in range(repeat):
        for fn, fn_starts, fn_stops in zip(fns, starts, stops, strict=True):
            # Clear the cache before each run, outside the timed region.
            flush_cache()
            fn_starts[round_idx].record()
            fn()
            fn_stops[round_idx].record()
    # Wait for all recorded work to finish before reading the events.
    device.synchronize()

    medians = []
    for fn_starts, fn_stops in zip(starts, stops, strict=True):
        samples = [
            begin.elapsed_time(end)
            for begin, end in zip(fn_starts, fn_stops, strict=True)
        ]
        medians.append(statistics.median(samples))
    return medians

helion/autotuner/differential_evolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def iter_candidates(self) -> Iterator[tuple[int, PopulationMember]]:
8181
def evolve_population(self) -> int:
    """
    Run one evolution step over the candidates from iter_candidates().

    Each candidate is compared against the current occupant of its slot via
    compare() and takes over the slot only when it is strictly better.
    Replacement happens immediately, while candidates are still being
    produced.

    Returns:
        The number of population slots that were replaced.
    """
    num_replaced = 0
    for slot, challenger in self.iter_candidates():
        incumbent = self.population[slot]
        if self.compare(challenger, incumbent) >= 0:
            # Not strictly better: keep the incumbent.
            continue
        self.population[slot] = challenger
        num_replaced += 1
    return num_replaced
@@ -97,4 +97,5 @@ def _autotune(self) -> Config:
9797
for i in range(2, self.num_generations):
9898
replaced = self.evolve_population()
9999
self.log(f"Generation {i}: replaced={replaced}", self.statistics)
100+
self.rebenchmark_population()
100101
return self.best.config

helion/autotuner/finite_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
def _autotune(self) -> Config:
3636
best_config = None
3737
best_time = float("inf")
38-
for config, time in self.parallel_benchmark(self.configs):
38+
for config, _fn, time in self.parallel_benchmark(self.configs):
3939
if time < best_time:
4040
best_time = time
4141
best_config = config

helion/runtime/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ class _Settings:
101101
autotune_accuracy_check: bool = (
102102
os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1"
103103
)
104+
autotune_rebenchmark_threshold: float = float(
105+
os.environ.get("HELION_REBENCHMARK_THRESHOLD", "1.5")
106+
)
104107
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
105108
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
106109
allow_warp_specialize: bool = (
@@ -131,6 +134,7 @@ class Settings(_Settings):
131134
"autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.",
132135
"autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.",
133136
"autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.",
137+
"autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Default is 1.5x. Set to <1 to disable.",
134138
"print_output_code": "If True, print the output code of the kernel to stderr.",
135139
"force_autotune": "If True, force autotuning even if a config is provided.",
136140
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",

0 commit comments

Comments
 (0)