 import sys
 import time
 from typing import TYPE_CHECKING
+from typing import Callable
 from typing import NamedTuple
 from typing import NoReturn

+from .benchmarking import interleaved_bench
+
 if TYPE_CHECKING:
     from triton.runtime.jit import JITFunction

@@ -90,6 +93,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         self.args: Sequence[object] = args
         self.counters: collections.Counter[str] = collections.Counter()
         self.log = LambdaLogger(self.settings.autotune_log_level)
+        self.best_perf_so_far = inf
         seed = self.settings.autotune_random_seed
         random.seed(seed)
         self.log(f"Autotune random seed: {seed}")
@@ -167,7 +171,7 @@ def _validate_against_baseline(
             return False
         return True

-    def benchmark(self, config: Config) -> float:
+    def benchmark(self, config: Config) -> tuple[Callable[..., object], float]:
         """
         Benchmark a specific configuration.

@@ -177,12 +181,12 @@ def benchmark(self, config: Config) -> float:
             config: The configuration to benchmark.

         Returns:
-            The performance of the configuration in seconds.
+            The compiled function and its performance in ms.
         """
         fn = self.kernel.compile_config(config, allow_print=False)
         if self.start_precompile_and_check_for_hangs(config, fn)():
-            return self.benchmark_function(config, fn)
-        return inf
+            return fn, self.benchmark_function(config, fn)
+        return fn, inf

     def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
         """
@@ -194,7 +198,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
             fn: A precompiled version of config.

         Returns:
-            The performance of the configuration in seconds.
+            The performance of the configuration in ms.
         """
         self.counters["benchmark"] += 1
         self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -216,10 +220,13 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
                 return_mode="median",
             )
             t2 = time.perf_counter()
+            assert isinstance(res, float)
             self.log.debug(
                 lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
             )
-            return res  # pyright: ignore[reportReturnType]
+            if res < self.best_perf_so_far:
+                self.best_perf_so_far = res
+            return res
         except Exception as e:
             action = classify_triton_exception(e)
             if action == "raise":
@@ -286,7 +293,9 @@ def extract_launcher(
             timeout=self.settings.autotune_compile_timeout,
         )

-    def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
+    def parallel_benchmark(
+        self, configs: list[Config]
+    ) -> list[tuple[Config, Callable[..., object], float]]:
         """
         Benchmark multiple configurations in parallel.

@@ -312,9 +321,9 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
         for config, fn, is_working in zip(configs, fns, is_workings, strict=True):
             if is_working:
                 # benchmark one-by-one to avoid noisy results
-                results.append((config, self.benchmark_function(config, fn)))
+                results.append((config, fn, self.benchmark_function(config, fn)))
             else:
-                results.append((config, inf))
+                results.append((config, fn, inf))
         return results

     def autotune(self) -> Config:
@@ -361,15 +370,20 @@ class PopulationMember(NamedTuple):
     Represents a member of the population in population-based search algorithms.

     Attributes:
-        perf (float): The performance of the configuration.
+        perfs (list[float]): Performance measurements of the configuration, accumulated over repeated benchmark runs.
         flat_values (FlatConfig): The flat representation of the configuration values.
         config (Config): The full configuration object.
     """

-    perf: float
+    fn: Callable[..., object]
+    perfs: list[float]
     flat_values: FlatConfig
     config: Config

+    @property
+    def perf(self) -> float:
+        return self.perfs[-1]
+

 def performance(member: PopulationMember) -> float:
     """
@@ -430,7 +444,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
             A population member with the benchmark results.
         """
         config = self.config_gen.unflatten(flat_values)
-        return PopulationMember(self.benchmark(config), flat_values, config)
+        fn, perf = self.benchmark(config)
+        return PopulationMember(fn, [perf], flat_values, config)

     def parallel_benchmark_flat(
         self, to_check: list[FlatConfig]
@@ -446,13 +461,65 @@ def parallel_benchmark_flat(
         """
         configs = [*map(self.config_gen.unflatten, to_check)]
         result = []
-        for flat_values, config_in, (config_out, perf) in zip(
+        for flat_values, config_in, (config_out, fn, perf) in zip(
             to_check, configs, self.parallel_benchmark(configs), strict=True
         ):
             assert config_in is config_out
-            result.append(PopulationMember(perf, flat_values, config_in))
+            result.append(PopulationMember(fn, [perf], flat_values, config_in))
         return result

+    def compare(self, a: PopulationMember, b: PopulationMember) -> int:
+        """
+        Compare two population members based on their performance, possibly with re-benchmarking.
+
+        Args:
+            a: The first population member.
+            b: The second population member.
+
+        Returns:
+            -1 if a is better than b, 1 if b is better than a, 0 if they are equal.
+        """
+        if self.should_rebenchmark(a) and self.should_rebenchmark(b):
+            self.rebenchmark([a, b])
+        return (a.perf > b.perf) - (a.perf < b.perf)
+
+    def should_rebenchmark(self, member: PopulationMember) -> bool:
+        """
+        Determine if a population member should be re-benchmarked to avoid outliers.
+
+        Args:
+            member: The population member to check.
+
+        Returns:
+            True if the member should be re-benchmarked, False otherwise.
+        """
+        return (
+            member.perf
+            < self.settings.autotune_rebenchmark_threshold * self.best_perf_so_far
+            and math.isfinite(member.perf)
+        )
+
+    def rebenchmark(self, members: list[PopulationMember]) -> None:
+        """
+        Re-benchmark a list of population members to avoid outliers.
+        """
+        if len(members) < 2:
+            return
+        repeat = max(3, int(200 / self.best_perf_so_far))
+        new_timings = interleaved_bench(
+            [functools.partial(m.fn, *self.args) for m in members], repeat=repeat
+        )
+        for m, t in zip(members, new_timings, strict=True):
+            m.perfs.append(t)
+            if t < self.best_perf_so_far:
+                self.best_perf_so_far = t
+
+    def rebenchmark_population(self) -> None:
+        """
+        Re-benchmark the entire population to avoid outliers.
+        """
+        self.rebenchmark([p for p in self.population if self.should_rebenchmark(p)])
+
     def statistics(self) -> str:
         """
         Generate statistics for the current population.
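
Note on the re-benchmarking path above: interleaved_bench is imported from .benchmarking and is called with a list of zero-argument callables (one per surviving config) plus a repeat count, returning one timing per callable. A minimal, self-contained sketch of that interleaving idea follows; the function name, the millisecond conversion, and the mean-of-samples reduction are illustrative assumptions, not the actual implementation in .benchmarking.

import time
from typing import Callable, Sequence


def interleaved_bench_sketch(
    fns: Sequence[Callable[[], object]], repeat: int
) -> list[float]:
    """Time each callable `repeat` times in round-robin order; return ms per candidate."""
    # Interleaving spreads slow drift (clock changes, thermal throttling,
    # background load) evenly across the candidates instead of biasing
    # whichever candidate happens to run last.
    samples: list[list[float]] = [[] for _ in fns]
    for _ in range(repeat):
        for i, fn in enumerate(fns):
            start = time.perf_counter()
            fn()
            samples[i].append((time.perf_counter() - start) * 1000.0)
    # Collapse each candidate's samples to one number (mean here; the real
    # helper may use a different statistic such as the median).
    return [sum(s) / len(s) for s in samples]

The repeat = max(3, int(200 / self.best_perf_so_far)) heuristic in rebenchmark then appears to size the number of rounds so that the total measurement per candidate stays around 200 ms (best_perf_so_far being in ms), while never dropping below three rounds.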