Commit 992ebe3

Rebenchmark configs to avoid noise
stack-info: PR: #654, branch: jansel/stack/146
1 parent 4e57ce7

File tree

5 files changed: +137 -16 lines changed

  helion/autotuner/base_search.py
  helion/autotuner/benchmarking.py
  helion/autotuner/differential_evolution.py
  helion/autotuner/finite_search.py
  helion/runtime/settings.py


helion/autotuner/base_search.py

Lines changed: 81 additions & 14 deletions
@@ -15,9 +15,12 @@
 import sys
 import time
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import NamedTuple
 from typing import NoReturn
 
+from .benchmarking import interleaved_bench
+
 if TYPE_CHECKING:
     from triton.runtime.jit import JITFunction
 
@@ -90,6 +93,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         self.args: Sequence[object] = args
         self.counters: collections.Counter[str] = collections.Counter()
         self.log = LambdaLogger(self.settings.autotune_log_level)
+        self.best_perf_so_far = inf
         seed = self.settings.autotune_random_seed
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
@@ -167,7 +171,7 @@ def _validate_against_baseline(
             return False
         return True
 
-    def benchmark(self, config: Config) -> float:
+    def benchmark(self, config: Config) -> tuple[Callable[..., object], float]:
         """
         Benchmark a specific configuration.
 
@@ -177,12 +181,12 @@ def benchmark(self, config: Config) -> float:
             config: The configuration to benchmark.
 
         Returns:
-            The performance of the configuration in seconds.
+            The compiled function and its performance in ms.
         """
         fn = self.kernel.compile_config(config, allow_print=False)
         if self.start_precompile_and_check_for_hangs(config, fn)():
-            return self.benchmark_function(config, fn)
-        return inf
+            return fn, self.benchmark_function(config, fn)
+        return fn, inf
 
     def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         """
@@ -194,7 +198,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             fn: A precompiled version of config.
 
         Returns:
-            The performance of the configuration in seconds.
+            The performance of the configuration in ms.
         """
         self.counters["benchmark"] += 1
         self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -216,10 +220,13 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
                 return_mode="median",
             )
             t2 = time.perf_counter()
+            assert isinstance(res, float)
             self.log.debug(
                 lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
             )
-            return res  # pyright: ignore[reportReturnType]
+            if res < self.best_perf_so_far:
+                self.best_perf_so_far = res
+            return res
         except Exception as e:
             action = classify_triton_exception(e)
             if action == "raise":
@@ -286,7 +293,9 @@ def extract_launcher(
             timeout=self.settings.autotune_compile_timeout,
         )
 
-    def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
+    def parallel_benchmark(
+        self, configs: list[Config]
+    ) -> list[tuple[Config, Callable[..., object], float]]:
         """
         Benchmark multiple configurations in parallel.
 
@@ -312,9 +321,9 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
         for config, fn, is_working in zip(configs, fns, is_workings, strict=True):
             if is_working:
                 # benchmark one-by-one to avoid noisy results
-                results.append((config, self.benchmark_function(config, fn)))
+                results.append((config, fn, self.benchmark_function(config, fn)))
             else:
-                results.append((config, inf))
+                results.append((config, fn, inf))
         return results
 
     def autotune(self) -> Config:
@@ -361,15 +370,20 @@ class PopulationMember(NamedTuple):
     Represents a member of the population in population-based search algorithms.
 
     Attributes:
-        perf (float): The performance of the configuration.
+        perfs (list[float]): The performance of the configuration, accumulated over multiple benchmarks.
         flat_values (FlatConfig): The flat representation of the configuration values.
         config (Config): The full configuration object.
     """
 
-    perf: float
+    fn: Callable[..., object]
+    perfs: list[float]
     flat_values: FlatConfig
     config: Config
 
+    @property
+    def perf(self) -> float:
+        return self.perfs[-1]
+
 
 def performance(member: PopulationMember) -> float:
     """
@@ -430,7 +444,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
             A population member with the benchmark results.
         """
         config = self.config_gen.unflatten(flat_values)
-        return PopulationMember(self.benchmark(config), flat_values, config)
+        fn, perf = self.benchmark(config)
+        return PopulationMember(fn, [perf], flat_values, config)
 
     def parallel_benchmark_flat(
         self, to_check: list[FlatConfig]
@@ -446,13 +461,65 @@ def parallel_benchmark_flat(
         """
         configs = [*map(self.config_gen.unflatten, to_check)]
         result = []
-        for flat_values, config_in, (config_out, perf) in zip(
+        for flat_values, config_in, (config_out, fn, perf) in zip(
             to_check, configs, self.parallel_benchmark(configs), strict=True
         ):
             assert config_in is config_out
-            result.append(PopulationMember(perf, flat_values, config_in))
+            result.append(PopulationMember(fn, [perf], flat_values, config_in))
         return result
 
+    def compare(self, a: PopulationMember, b: PopulationMember) -> int:
+        """
+        Compare two population members based on their performance, possibly with re-benchmarking.
+
+        Args:
+            a: The first population member.
+            b: The second population member.
+
+        Returns:
+            -1 if a is better than b, 1 if b is better than a, 0 if they are equal.
+        """
+        if self.should_rebenchmark(a) and self.should_rebenchmark(b):
+            self.rebenchmark([a, b])
+        return (a.perf > b.perf) - (a.perf < b.perf)
+
+    def should_rebenchmark(self, member: PopulationMember) -> bool:
+        """
+        Determine if a population member should be re-benchmarked to avoid outliers.
+
+        Args:
+            member: The population member to check.
+
+        Returns:
+            True if the member should be re-benchmarked, False otherwise.
+        """
+        return (
+            member.perf
+            < self.settings.autotune_rebenchmark_threshold * self.best_perf_so_far
+            and math.isfinite(member.perf)
+        )
+
+    def rebenchmark(self, members: list[PopulationMember]) -> None:
+        """
+        Re-benchmark a list of population members to avoid outliers.
+        """
+        if len(members) < 2:
+            return
+        repeat = max(3, int(200 / self.best_perf_so_far))
+        new_timings = interleaved_bench(
+            [functools.partial(m.fn, *self.args) for m in members], repeat=repeat
+        )
+        for m, t in zip(members, new_timings, strict=True):
+            m.perfs.append(t)
+            if t < self.best_perf_so_far:
+                self.best_perf_so_far = t
+
+    def rebenchmark_population(self) -> None:
+        """
+        Re-benchmark the entire population to avoid outliers.
+        """
+        self.rebenchmark([p for p in self.population if self.should_rebenchmark(p)])
+
     def statistics(self) -> str:
         """
         Generate statistics for the current population.
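
Aside: the (a.perf > b.perf) - (a.perf < b.perf) idiom in compare() is the standard cmp-style three-way comparison. Below is a minimal standalone sketch of how the re-benchmark gate and this comparator behave, with made-up timings; THRESHOLD stands in for settings.autotune_rebenchmark_threshold and is not the real settings object.

import math

THRESHOLD = 1.5          # stands in for autotune_rebenchmark_threshold
best_perf_so_far = 2.0   # assumed best timing seen so far, in ms

def should_rebenchmark(perf: float) -> bool:
    # Only configs close enough to the best to plausibly win get re-measured.
    return perf < THRESHOLD * best_perf_so_far and math.isfinite(perf)

def cmp(a: float, b: float) -> int:
    # Three-way comparison: -1 means the first timing is better (faster).
    return (a > b) - (a < b)

print(should_rebenchmark(2.5))  # True: within 1.5x of the best
print(should_rebenchmark(9.0))  # False: clearly worse, skip extra runs
print(cmp(2.5, 3.1))            # -1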

helion/autotuner/benchmarking.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import functools
+import statistics
+from typing import Callable
+
+from triton import runtime
+
+
+def interleaved_bench(fns: list[Callable[[], object]], *, repeat: int) -> list[float]:
+    """
+    Benchmark multiple functions at once, interleaving their executions to reduce
+    the impact of external factors (e.g., load, temperature) on the
+    measurements.
+    """
+    # warmup
+    for fn in fns:
+        fn()
+    clear_cache = functools.partial(
+        runtime.driver.active.clear_cache,  # type: ignore[attr-defined]
+        runtime.driver.active.get_empty_cache_for_benchmark(),  # type: ignore[attr-defined]
+    )
+    clear_cache()
+    di = runtime.driver.active.get_device_interface()  # type: ignore[attr-defined]
+    start_events = [
+        [di.Event(enable_timing=True) for _ in range(repeat)] for _ in range(len(fns))
+    ]
+    end_events = [
+        [di.Event(enable_timing=True) for _ in range(repeat)] for _ in range(len(fns))
+    ]
+
+    di.synchronize()
+    for i in range(repeat):
+        for j in range(len(fns)):
+            clear_cache()
+            start_events[j][i].record()
+            fns[j]()
+            end_events[j][i].record()
+    di.synchronize()
+
+    return [
+        statistics.median(
+            [
+                s.elapsed_time(e)
+                for s, e in zip(start_events[j], end_events[j], strict=True)
+            ]
+        )
+        for j in range(len(fns))
+    ]
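
For intuition, here is a hedged CPU-only analogue of interleaved_bench (no Triton, no device events; interleaved_bench_cpu is an invented name): interleaving the timed runs round-robin means slow drift such as thermal throttling or background load affects every candidate roughly equally, and taking the per-function median discards the remaining outliers.

import statistics
import time
from typing import Callable

def interleaved_bench_cpu(
    fns: list[Callable[[], object]], *, repeat: int
) -> list[float]:
    for fn in fns:  # warmup, mirroring the GPU version above
        fn()
    timings: list[list[float]] = [[] for _ in fns]
    for _ in range(repeat):  # round-robin: one timed run of each fn per pass
        for j, fn in enumerate(fns):
            t0 = time.perf_counter()
            fn()
            timings[j].append((time.perf_counter() - t0) * 1000.0)  # ms
    return [statistics.median(ts) for ts in timings]

print(interleaved_bench_cpu(
    [lambda: sum(range(10_000)), lambda: sum(range(20_000))], repeat=5
))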

helion/autotuner/differential_evolution.py

Lines changed: 2 additions & 1 deletion
@@ -81,7 +81,7 @@ def iter_candidates(self) -> Iterator[tuple[int, PopulationMember]]:
     def evolve_population(self) -> int:
         replaced = 0
         for i, candidate in self.iter_candidates():
-            if candidate.perf < self.population[i].perf:
+            if self.compare(candidate, self.population[i]) < 0:
                 self.population[i] = candidate
                 replaced += 1
         return replaced
@@ -97,4 +97,5 @@ def _autotune(self) -> Config:
         for i in range(2, self.num_generations):
             replaced = self.evolve_population()
             self.log(f"Generation {i}: replaced={replaced}", self.statistics)
+            self.rebenchmark_population()
         return self.best.config
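
A toy rendering of what this change buys (made-up timings, plain floats instead of PopulationMember objects): replacement is now routed through a cmp-style compare(), which in the real class may re-benchmark both sides before declaring a winner instead of trusting a single noisy measurement.

def compare(a: float, b: float) -> int:
    # Stand-in for BaseSearch.compare(); the real one can re-benchmark first.
    return (a > b) - (a < b)

population = [5.0, 2.0, 9.0]  # incumbent timings in ms
candidates = [4.0, 2.5, 1.0]  # candidate timings in ms
replaced = 0
for i, candidate in enumerate(candidates):
    if compare(candidate, population[i]) < 0:  # strictly faster wins
        population[i] = candidate
        replaced += 1
print(population, replaced)  # [4.0, 2.0, 1.0] 2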

helion/autotuner/finite_search.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(
     def _autotune(self) -> Config:
         best_config = None
         best_time = float("inf")
-        for config, time in self.parallel_benchmark(self.configs):
+        for config, _fn, time in self.parallel_benchmark(self.configs):
             if time < best_time:
                 best_time = time
                 best_config = config
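
Since parallel_benchmark now yields (config, fn, time) triples, finite search simply discards the compiled function. A self-contained sketch with fabricated results (the config names and timings are illustrative only):

import math

results = [("cfg_a", None, 2.4), ("cfg_b", None, math.inf), ("cfg_c", None, 1.7)]
best_config, best_time = None, math.inf
for config, _fn, time in results:  # _fn is unused, matching the diff above
    if time < best_time:
        best_time, best_config = time, config
print(best_config, best_time)  # cfg_c 1.7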

helion/runtime/settings.py

Lines changed: 4 additions & 0 deletions
@@ -101,6 +101,9 @@ class _Settings:
     autotune_accuracy_check: bool = (
         os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1"
     )
+    autotune_rebenchmark_threshold: float = float(
+        os.environ.get("HELION_REBENCHMARK_THRESHOLD", "1.5")
+    )
     print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
     force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
     allow_warp_specialize: bool = (
@@ -131,6 +134,7 @@ class Settings(_Settings):
         "autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.",
         "autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.",
         "autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.",
+        "autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Default is 1.5x. Set to <1 to disable.",
         "print_output_code": "If True, print the output code of the kernel to stderr.",
         "force_autotune": "If True, force autotuning even if a config is provided.",
         "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
