Skip to content

Commit 3703e20

Browse files
committed
[wip] Rebenchmark configs to avoid noise
stack-info: PR: #654, branch: jansel/stack/146
1 parent 2626b41 commit 3703e20

File tree

5 files changed

+135
-16
lines changed

5 files changed

+135
-16
lines changed

helion/autotuner/base_search.py

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
import sys
1616
import time
1717
from typing import TYPE_CHECKING
18+
from typing import Callable
1819
from typing import NamedTuple
1920
from typing import NoReturn
2021

22+
from .benchmarking import interleaved_bench
23+
2124
if TYPE_CHECKING:
2225
from triton.runtime.jit import JITFunction
2326

@@ -85,9 +88,10 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
8588
self.args = args
8689
self.counters: collections.Counter[str] = collections.Counter()
8790
self.log = LambdaLogger(self.settings.autotune_log_level)
91+
self.best_perf_so_far = inf
8892
random.seed(self.settings.autotune_random_seed)
8993

90-
def benchmark(self, config: Config) -> float:
94+
def benchmark(self, config: Config) -> tuple[Callable[..., object], float]:
9195
"""
9296
Benchmark a specific configuration.
9397
@@ -97,12 +101,12 @@ def benchmark(self, config: Config) -> float:
97101
config: The configuration to benchmark.
98102
99103
Returns:
100-
The performance of the configuration in seconds.
104+
The function and performance of the configuration in ms.
101105
"""
102106
fn = self.kernel.compile_config(config, allow_print=False)
103107
if self.start_precompile_and_check_for_hangs(config, fn)():
104-
return self.benchmark_function(config, fn)
105-
return inf
108+
return fn, self.benchmark_function(config, fn)
109+
return fn, inf
106110

107111
def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
108112
"""
@@ -114,7 +118,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
114118
fn: A precompiled version of config.
115119
116120
Returns:
117-
The performance of the configuration in seconds.
121+
The performance of the configuration in ms.
118122
"""
119123
self.counters["benchmark"] += 1
120124
self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -128,10 +132,13 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
128132
return_mode="median",
129133
)
130134
t2 = time.perf_counter()
135+
assert isinstance(res, float)
131136
self.log.debug(
132137
lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
133138
)
134-
return res # pyright: ignore[reportReturnType]
139+
if res < self.best_perf_so_far:
140+
self.best_perf_so_far = res
141+
return res
135142
except Exception as e:
136143
action = classify_triton_exception(e)
137144
if action == "raise":
@@ -198,7 +205,9 @@ def extract_launcher(
198205
timeout=self.settings.autotune_compile_timeout,
199206
)
200207

201-
def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
208+
def parallel_benchmark(
209+
self, configs: list[Config]
210+
) -> list[tuple[Config, Callable[..., object], float]]:
202211
"""
203212
Benchmark multiple configurations in parallel.
204213
@@ -224,9 +233,9 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
224233
for config, fn, is_working in zip(configs, fns, is_workings, strict=True):
225234
if is_working:
226235
# benchmark one-by-one to avoid noisy results
227-
results.append((config, self.benchmark_function(config, fn)))
236+
results.append((config, fn, self.benchmark_function(config, fn)))
228237
else:
229-
results.append((config, inf))
238+
results.append((config, fn, inf))
230239
return results
231240

232241
def autotune(self) -> Config:
@@ -271,15 +280,20 @@ class PopulationMember(NamedTuple):
271280
Represents a member of the population in population-based search algorithms.
272281
273282
Attributes:
274-
perf (float): The performance of the configuration.
283+
perfs (list[float]): The performance of the configuration, accumulated over multiple benchmarks.
275284
flat_values (FlatConfig): The flat representation of the configuration values.
276285
config (Config): The full configuration object.
277286
"""
278287

279-
perf: float
288+
fn: Callable[..., object]
289+
perfs: list[float]
280290
flat_values: FlatConfig
281291
config: Config
282292

293+
@property
294+
def perf(self) -> float:
295+
return self.perfs[-1]
296+
283297

284298
def performance(member: PopulationMember) -> float:
285299
"""
@@ -340,7 +354,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
340354
A population member with the benchmark results.
341355
"""
342356
config = self.config_gen.unflatten(flat_values)
343-
return PopulationMember(self.benchmark(config), flat_values, config)
357+
fn, perf = self.benchmark(config)
358+
return PopulationMember(fn, [perf], flat_values, config)
344359

345360
def parallel_benchmark_flat(
346361
self, to_check: list[FlatConfig]
@@ -356,13 +371,65 @@ def parallel_benchmark_flat(
356371
"""
357372
configs = [*map(self.config_gen.unflatten, to_check)]
358373
result = []
359-
for flat_values, config_in, (config_out, perf) in zip(
374+
for flat_values, config_in, (config_out, fn, perf) in zip(
360375
to_check, configs, self.parallel_benchmark(configs), strict=True
361376
):
362377
assert config_in is config_out
363-
result.append(PopulationMember(perf, flat_values, config_in))
378+
result.append(PopulationMember(fn, [perf], flat_values, config_in))
364379
return result
365380

381+
def compare(self, a: PopulationMember, b: PopulationMember) -> int:
382+
"""
383+
Compare two population members based on their performance, possibly with re-benchmarking.
384+
385+
Args:
386+
a: The first population member.
387+
b: The second population member.
388+
389+
Returns:
390+
-1 if a is better than b, 1 if b is better than a, 0 if they are equal.
391+
"""
392+
if self.should_rebenchmark(a) and self.should_rebenchmark(b):
393+
self.rebenchmark([a, b])
394+
return (a.perf > b.perf) - (a.perf < b.perf)
395+
396+
def should_rebenchmark(self, member: PopulationMember) -> bool:
397+
"""
398+
Determine if a population member should be re-benchmarked to avoid outliers.
399+
400+
Args:
401+
member: The population member to check.
402+
403+
Returns:
404+
True if the member should be re-benchmarked, False otherwise.
405+
"""
406+
return (
407+
member.perf
408+
< self.settings.autotune_rebenchmark_threshold * self.best_perf_so_far
409+
and math.isfinite(member.perf)
410+
)
411+
412+
def rebenchmark(self, members: list[PopulationMember]) -> None:
413+
"""
414+
Re-benchmark a list of population members to avoid outliers.
415+
"""
416+
if len(members) < 2:
417+
return
418+
repeat = max(3, int(200 / self.best_perf_so_far))
419+
new_timings = interleaved_bench(
420+
[functools.partial(m.fn, *self.args) for m in members], repeat=repeat
421+
)
422+
for m, t in zip(members, new_timings, strict=True):
423+
m.perfs.append(t)
424+
if t < self.best_perf_so_far:
425+
self.best_perf_so_far = t
426+
427+
def rebenchmark_population(self) -> None:
428+
"""
429+
Re-benchmark the entire population to avoid outliers.
430+
"""
431+
self.rebenchmark([p for p in self.population if self.should_rebenchmark(p)])
432+
366433
def statistics(self) -> str:
367434
"""
368435
Generate statistics for the current population.

helion/autotuner/benchmarking.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
import statistics
5+
from typing import Callable
6+
7+
from triton import runtime
8+
9+
10+
def interleaved_bench(fns: list[Callable[[], object]], *, repeat: int) -> list[float]:
11+
"""
12+
Benchmark multiple functions at once, interleaving their executions to reduce
13+
the impact of external factors (e.g., load, temperature) on the
14+
measurements.
15+
"""
16+
# warmup
17+
for fn in fns:
18+
fn()
19+
clear_cache = functools.partial(
20+
runtime.driver.active.clear_cache, # type: ignore[attr-defined]
21+
runtime.driver.active.get_empty_cache_for_benchmark(), # type: ignore[attr-defined]
22+
)
23+
clear_cache()
24+
di = runtime.driver.active.get_device_interface() # type: ignore[attr-defined]
25+
start_events = [
26+
[di.Event(enable_timing=True) for _ in range(repeat)] for _ in range(len(fns))
27+
]
28+
end_events = [
29+
[di.Event(enable_timing=True) for _ in range(repeat)] for _ in range(len(fns))
30+
]
31+
32+
di.synchronize()
33+
for i in range(repeat):
34+
for j in range(len(fns)):
35+
clear_cache()
36+
start_events[j][i].record()
37+
fns[j]()
38+
end_events[j][i].record()
39+
di.synchronize()
40+
41+
return [
42+
statistics.median(
43+
[
44+
s.elapsed_time(e)
45+
for s, e in zip(start_events[j], end_events[j], strict=True)
46+
]
47+
)
48+
for j in range(len(fns))
49+
]

helion/autotuner/differential_evolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def iter_candidates(self) -> Iterator[tuple[int, PopulationMember]]:
8181
def evolve_population(self) -> int:
8282
replaced = 0
8383
for i, candidate in self.iter_candidates():
84-
if candidate.perf < self.population[i].perf:
84+
if self.compare(candidate, self.population[i]) < 0:
8585
self.population[i] = candidate
8686
replaced += 1
8787
return replaced
@@ -97,4 +97,5 @@ def _autotune(self) -> Config:
9797
for i in range(2, self.num_generations):
9898
replaced = self.evolve_population()
9999
self.log(f"Generation {i}: replaced={replaced}", self.statistics)
100+
self.rebenchmark_population()
100101
return self.best.config

helion/autotuner/finite_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
def _autotune(self) -> Config:
3636
best_config = None
3737
best_time = float("inf")
38-
for config, time in self.parallel_benchmark(self.configs):
38+
for config, _fn, time in self.parallel_benchmark(self.configs):
3939
if time < best_time:
4040
best_time = time
4141
best_config = config

helion/runtime/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class _Settings:
9898
autotune_random_seed: int = dataclasses.field(
9999
default_factory=_get_autotune_random_seed
100100
)
101+
autotune_rebenchmark_threshold: float = 1.5
101102
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
102103
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
103104
allow_warp_specialize: bool = (
@@ -127,6 +128,7 @@ class Settings(_Settings):
127128
"autotune_precompile": "If True, precompile the kernel before autotuning. Requires fork-safe environment.",
128129
"autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.",
129130
"autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.",
131+
"autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Default is 1.5x. Set to <1 to disable.",
130132
"print_output_code": "If True, print the output code of the kernel to stderr.",
131133
"force_autotune": "If True, force autotuning even if a config is provided.",
132134
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",

0 commit comments

Comments
 (0)