1515import sys
1616import time
1717from typing import TYPE_CHECKING
18+ from typing import Callable
1819from typing import NamedTuple
1920from typing import NoReturn
2021
22+ from .benchmarking import interleaved_bench
23+
2124if TYPE_CHECKING :
2225 from triton .runtime .jit import JITFunction
2326
@@ -85,9 +88,10 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
8588 self .args = args
8689 self .counters : collections .Counter [str ] = collections .Counter ()
8790 self .log = LambdaLogger (self .settings .autotune_log_level )
91+ self .best_perf_so_far = inf
8892 random .seed (self .settings .autotune_random_seed )
8993
90- def benchmark (self , config : Config ) -> float :
94+ def benchmark (self , config : Config ) -> tuple [ Callable [..., object ], float ] :
9195 """
9296 Benchmark a specific configuration.
9397
@@ -97,12 +101,12 @@ def benchmark(self, config: Config) -> float:
97101 config: The configuration to benchmark.
98102
99103 Returns:
100- The performance of the configuration in seconds .
        The function and the performance of the configuration in ms.
101105 """
102106 fn = self .kernel .compile_config (config , allow_print = False )
103107 if self .start_precompile_and_check_for_hangs (config , fn )():
104- return self .benchmark_function (config , fn )
105- return inf
108+ return fn , self .benchmark_function (config , fn )
109+ return fn , inf
106110
107111 def benchmark_function (self , config : Config , fn : CompiledConfig ) -> float :
108112 """
@@ -114,7 +118,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
114118 fn: A precompiled version of config.
115119
116120 Returns:
117- The performance of the configuration in seconds .
121+ The performance of the configuration in ms .
118122 """
119123 self .counters ["benchmark" ] += 1
120124 self .log .debug (lambda : f"Running benchmark for { config !r} " )
@@ -128,10 +132,13 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
128132 return_mode = "median" ,
129133 )
130134 t2 = time .perf_counter ()
135+ assert isinstance (res , float )
131136 self .log .debug (
132137 lambda : f"result: { res :.4f} ms (took { t1 - t0 :.1f} s + { t2 - t1 :.1f} s)" ,
133138 )
134- return res # pyright: ignore[reportReturnType]
139+ if res < self .best_perf_so_far :
140+ self .best_perf_so_far = res
141+ return res
135142 except Exception as e :
136143 action = classify_triton_exception (e )
137144 if action == "raise" :
@@ -198,7 +205,9 @@ def extract_launcher(
198205 timeout = self .settings .autotune_compile_timeout ,
199206 )
200207
201- def parallel_benchmark (self , configs : list [Config ]) -> list [tuple [Config , float ]]:
208+ def parallel_benchmark (
209+ self , configs : list [Config ]
210+ ) -> list [tuple [Config , Callable [..., object ], float ]]:
202211 """
203212 Benchmark multiple configurations in parallel.
204213
@@ -224,9 +233,9 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
224233 for config , fn , is_working in zip (configs , fns , is_workings , strict = True ):
225234 if is_working :
226235 # benchmark one-by-one to avoid noisy results
227- results .append ((config , self .benchmark_function (config , fn )))
236+ results .append ((config , fn , self .benchmark_function (config , fn )))
228237 else :
229- results .append ((config , inf ))
238+ results .append ((config , fn , inf ))
230239 return results
231240
232241 def autotune (self ) -> Config :
@@ -271,15 +280,20 @@ class PopulationMember(NamedTuple):
271280 Represents a member of the population in population-based search algorithms.
272281
    Attributes:
        fn (Callable[..., object]): The compiled function for the configuration.
        perfs (list[float]): Performance measurements (ms) of the configuration,
            accumulated over multiple benchmarks, most recent last.
        flat_values (FlatConfig): The flat representation of the configuration values.
        config (Config): The full configuration object.
277286 """
278287
279- perf : float
288+ fn : Callable [..., object ]
289+ perfs : list [float ]
280290 flat_values : FlatConfig
281291 config : Config
282292
293+ @property
294+ def perf (self ) -> float :
295+ return self .perfs [- 1 ]
296+
283297
284298def performance (member : PopulationMember ) -> float :
285299 """
@@ -340,7 +354,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
340354 A population member with the benchmark results.
341355 """
342356 config = self .config_gen .unflatten (flat_values )
343- return PopulationMember (self .benchmark (config ), flat_values , config )
357+ fn , perf = self .benchmark (config )
358+ return PopulationMember (fn , [perf ], flat_values , config )
344359
345360 def parallel_benchmark_flat (
346361 self , to_check : list [FlatConfig ]
@@ -356,13 +371,65 @@ def parallel_benchmark_flat(
356371 """
357372 configs = [* map (self .config_gen .unflatten , to_check )]
358373 result = []
359- for flat_values , config_in , (config_out , perf ) in zip (
374+ for flat_values , config_in , (config_out , fn , perf ) in zip (
360375 to_check , configs , self .parallel_benchmark (configs ), strict = True
361376 ):
362377 assert config_in is config_out
363- result .append (PopulationMember (perf , flat_values , config_in ))
378+ result .append (PopulationMember (fn , [ perf ] , flat_values , config_in ))
364379 return result
365380
381+ def compare (self , a : PopulationMember , b : PopulationMember ) -> int :
382+ """
383+ Compare two population members based on their performance, possibly with re-benchmarking.
384+
385+ Args:
386+ a: The first population member.
387+ b: The second population member.
388+
389+ Returns:
390+ -1 if a is better than b, 1 if b is better than a, 0 if they are equal.
391+ """
392+ if self .should_rebenchmark (a ) and self .should_rebenchmark (b ):
393+ self .rebenchmark ([a , b ])
394+ return (a .perf > b .perf ) - (a .perf < b .perf )
395+
396+ def should_rebenchmark (self , member : PopulationMember ) -> bool :
397+ """
398+ Determine if a population member should be re-benchmarked to avoid outliers.
399+
400+ Args:
401+ member: The population member to check.
402+
403+ Returns:
404+ True if the member should be re-benchmarked, False otherwise.
405+ """
406+ return (
407+ member .perf
408+ < self .settings .autotune_rebenchmark_threshold * self .best_perf_so_far
409+ and math .isfinite (member .perf )
410+ )
411+
412+ def rebenchmark (self , members : list [PopulationMember ]) -> None :
413+ """
414+ Re-benchmark a list of population members to avoid outliers.
415+ """
416+ if len (members ) < 2 :
417+ return
418+ repeat = max (3 , int (200 / self .best_perf_so_far ))
419+ new_timings = interleaved_bench (
420+ [functools .partial (m .fn , * self .args ) for m in members ], repeat = repeat
421+ )
422+ for m , t in zip (members , new_timings , strict = True ):
423+ m .perfs .append (t )
424+ if t < self .best_perf_so_far :
425+ self .best_perf_so_far = t
426+
427+ def rebenchmark_population (self ) -> None :
428+ """
429+ Re-benchmark the entire population to avoid outliers.
430+ """
431+ self .rebenchmark ([p for p in self .population if self .should_rebenchmark (p )])
432+
366433 def statistics (self ) -> str :
367434 """
368435 Generate statistics for the current population.
0 commit comments