 import sys
 import time
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import NamedTuple
 from typing import NoReturn
 
+from .benchmarking import interleaved_bench
+
 if TYPE_CHECKING:
     from triton.runtime.jit import JITFunction
 
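`interleaved_bench` is imported here but its implementation is not part of this diff. The name suggests round-robin timing: alternate between candidates on every pass so that clock drift and thermal throttling hit all of them roughly equally, then summarize per candidate. A minimal CPU-timer sketch of that idea (a hypothetical illustration only; the real helper presumably synchronizes the GPU and may differ):

import statistics
import time
from typing import Callable, Sequence

def interleaved_bench_sketch(
    fns: Sequence[Callable[[], object]], repeat: int
) -> list[float]:
    # Hypothetical stand-in, not the helper imported above.
    # Visit candidates round-robin so slow changes in machine state
    # are spread evenly across all of them, then report per-candidate medians.
    timings: list[list[float]] = [[] for _ in fns]
    for _ in range(repeat):
        for i, fn in enumerate(fns):
            t0 = time.perf_counter()
            fn()
            timings[i].append((time.perf_counter() - t0) * 1000.0)  # ms
    return [statistics.median(ts) for ts in timings]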
@@ -90,6 +93,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         self.args: Sequence[object] = args
         self.counters: collections.Counter[str] = collections.Counter()
         self.log = LambdaLogger(self.settings.autotune_log_level)
+        self.best_perf_so_far = inf
         seed = self.settings.autotune_random_seed
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
@@ -157,7 +161,7 @@ def _validate_against_baseline(
             return False
         return True
 
-    def benchmark(self, config: Config) -> float:
+    def benchmark(self, config: Config) -> tuple[Callable[..., object], float]:
         """
         Benchmark a specific configuration.
 
@@ -167,12 +171,12 @@ def benchmark(self, config: Config) -> float:
             config: The configuration to benchmark.
 
         Returns:
-            The performance of the configuration in seconds.
+            The compiled function and the performance of the configuration in ms.
         """
         fn = self.kernel.compile_config(config, allow_print=False)
         if self.start_precompile_and_check_for_hangs(config, fn)():
-            return self.benchmark_function(config, fn)
-        return inf
+            return fn, self.benchmark_function(config, fn)
+        return fn, inf
 
     def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         """
@@ -184,7 +188,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             fn: A precompiled version of config.
 
         Returns:
-            The performance of the configuration in seconds.
+            The performance of the configuration in ms.
         """
         self.counters["benchmark"] += 1
         self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -206,10 +210,13 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
                 return_mode="median",
             )
             t2 = time.perf_counter()
+            assert isinstance(res, float)
             self.log.debug(
                 lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
             )
-            return res  # pyright: ignore[reportReturnType]
+            if res < self.best_perf_so_far:
+                self.best_perf_so_far = res
+            return res
         except Exception as e:
             action = classify_triton_exception(e)
             if action == "raise":
@@ -276,7 +283,9 @@ def extract_launcher(
             timeout=self.settings.autotune_compile_timeout,
         )
 
-    def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
+    def parallel_benchmark(
+        self, configs: list[Config]
+    ) -> list[tuple[Config, Callable[..., object], float]]:
         """
         Benchmark multiple configurations in parallel.
 
@@ -302,9 +311,9 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
         for config, fn, is_working in zip(configs, fns, is_workings, strict=True):
             if is_working:
                 # benchmark one-by-one to avoid noisy results
-                results.append((config, self.benchmark_function(config, fn)))
+                results.append((config, fn, self.benchmark_function(config, fn)))
             else:
-                results.append((config, fn, inf))
+                results.append((config, fn, inf))
         return results
 
     def autotune(self) -> Config:
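With the widened return type, each entry from parallel_benchmark now carries the compiled callable alongside its timing. A hypothetical caller (the search and configs names are assumed, not part of this diff) might filter and rank the triples like so:

import math

# Keep only configs that compiled and produced a finite timing, fastest (ms) first.
ranked = sorted(
    (t for t in search.parallel_benchmark(configs) if math.isfinite(t[2])),
    key=lambda t: t[2],
)
best_config, best_fn, best_ms = ranked[0]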
@@ -351,15 +360,21 @@ class PopulationMember(NamedTuple):
     Represents a member of the population in population-based search algorithms.
 
     Attributes:
-        perf (float): The performance of the configuration.
+        fn (Callable[..., object]): The compiled function for this configuration.
+        perfs (list[float]): Performance measurements in ms, accumulated over multiple benchmark runs.
         flat_values (FlatConfig): The flat representation of the configuration values.
         config (Config): The full configuration object.
     """
 
-    perf: float
+    fn: Callable[..., object]
+    perfs: list[float]
     flat_values: FlatConfig
     config: Config
 
+    @property
+    def perf(self) -> float:
+        return self.perfs[-1]
+
 
 def performance(member: PopulationMember) -> float:
     """
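A design detail worth noting here: NamedTuple fields cannot be reassigned, but perfs is a list, so later re-benchmark timings can be appended in place while the perf property always reports the most recent one. A simplified stand-in showing just that mechanism (not the real class, which also carries flat_values and config):

from typing import Callable, NamedTuple

class Member(NamedTuple):  # simplified stand-in for PopulationMember
    fn: Callable[..., object]
    perfs: list[float]  # mutable in place even though the tuple itself is not

    @property
    def perf(self) -> float:
        return self.perfs[-1]  # most recent timing wins

m = Member(fn=lambda: None, perfs=[0.42])
m.perfs.append(0.39)  # what rebenchmark() does; no tuple replacement needed
assert m.perf == 0.39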
@@ -420,7 +434,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
             A population member with the benchmark results.
         """
         config = self.config_gen.unflatten(flat_values)
-        return PopulationMember(self.benchmark(config), flat_values, config)
+        fn, perf = self.benchmark(config)
+        return PopulationMember(fn, [perf], flat_values, config)
 
     def parallel_benchmark_flat(
         self, to_check: list[FlatConfig]
@@ -436,13 +451,65 @@ def parallel_benchmark_flat(
         """
         configs = [*map(self.config_gen.unflatten, to_check)]
         result = []
-        for flat_values, config_in, (config_out, perf) in zip(
+        for flat_values, config_in, (config_out, fn, perf) in zip(
             to_check, configs, self.parallel_benchmark(configs), strict=True
         ):
             assert config_in is config_out
-            result.append(PopulationMember(perf, flat_values, config_in))
+            result.append(PopulationMember(fn, [perf], flat_values, config_in))
         return result
 
+    def compare(self, a: PopulationMember, b: PopulationMember) -> int:
+        """
+        Compare two population members by performance, re-benchmarking both first when they look promising.
+
+        Args:
+            a: The first population member.
+            b: The second population member.
+
+        Returns:
+            -1 if a is better than b, 1 if b is better than a, 0 if they are equal.
+        """
+        if self.should_rebenchmark(a) and self.should_rebenchmark(b):
+            self.rebenchmark([a, b])
+        return (a.perf > b.perf) - (a.perf < b.perf)
+
+    def should_rebenchmark(self, member: PopulationMember) -> bool:
+        """
+        Determine whether a member is close enough to the best to be worth re-benchmarking against outlier timings.
+
+        Args:
+            member: The population member to check.
+
+        Returns:
+            True if the member should be re-benchmarked, False otherwise.
+        """
+        return (
+            member.perf
+            < self.settings.autotune_rebenchmark_threshold * self.best_perf_so_far
+            and math.isfinite(member.perf)
+        )
+
+    def rebenchmark(self, members: list[PopulationMember]) -> None:
+        """
+        Re-benchmark a list of population members to reduce the impact of outlier timings.
+        """
+        if len(members) < 2:
+            return
+        repeat = max(3, int(200 / self.best_perf_so_far))
+        new_timings = interleaved_bench(
+            [functools.partial(m.fn, *self.args) for m in members], repeat=repeat
+        )
+        for m, t in zip(members, new_timings, strict=True):
+            m.perfs.append(t)
+            if t < self.best_perf_so_far:
+                self.best_perf_so_far = t
+
+    def rebenchmark_population(self) -> None:
+        """
+        Re-benchmark the entire population to reduce the impact of outlier timings.
+        """
+        self.rebenchmark([p for p in self.population if self.should_rebenchmark(p)])
+
     def statistics(self) -> str:
         """
         Generate statistics for the current population.
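Two usage notes on the new methods. First, compare returns the classic -1/0/1 contract, so it slots directly into functools.cmp_to_key for ranking a population. A hypothetical call, assuming a search object exposing the population attribute used by rebenchmark_population above:

import functools

# Sorting re-benchmarks near-best pairs on the fly via compare().
search.population.sort(key=functools.cmp_to_key(search.compare))
best = search.population[0]

Second, the repeat = max(3, int(200 / self.best_perf_so_far)) heuristic targets roughly 200 ms of measurement per candidate: a 0.1 ms kernel gets 2000 interleaved runs, while a 500 ms kernel falls back to the floor of 3 runs.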