141 changes: 120 additions & 21 deletions helion/autotuner/base_search.py
@@ -15,9 +15,11 @@
import sys
import time
from typing import TYPE_CHECKING
from typing import NamedTuple
from typing import Callable
from typing import NoReturn

from .benchmarking import interleaved_bench

if TYPE_CHECKING:
from triton.runtime.jit import JITFunction

@@ -90,6 +92,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
self.args: Sequence[object] = args
self.counters: collections.Counter[str] = collections.Counter()
self.log = LambdaLogger(self.settings.autotune_log_level)
self.best_perf_so_far = inf
seed = self.settings.autotune_random_seed
random.seed(seed)
self.log(f"Autotune random seed: {seed}")
@@ -167,7 +170,7 @@ def _validate_against_baseline(
return False
return True

def benchmark(self, config: Config) -> float:
def benchmark(self, config: Config) -> tuple[Callable[..., object], float]:
"""
Benchmark a specific configuration.

@@ -177,12 +180,12 @@ config: The configuration to benchmark.
config: The configuration to benchmark.

Returns:
The performance of the configuration in seconds.
The compiled function and the performance of the configuration in ms.
"""
fn = self.kernel.compile_config(config, allow_print=False)
if self.start_precompile_and_check_for_hangs(config, fn)():
return self.benchmark_function(config, fn)
return inf
return fn, self.benchmark_function(config, fn)
return fn, inf

def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
"""
@@ -194,7 +197,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
fn: A precompiled version of config.

Returns:
The performance of the configuration in seconds.
The performance of the configuration in ms.
"""
self.counters["benchmark"] += 1
self.log.debug(lambda: f"Running benchmark for {config!r}")
@@ -214,12 +217,17 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
res = do_bench(
functools.partial(fn, *self.args),
return_mode="median",
warmup=1, # we are already warmed up above
rep=50,
)
t2 = time.perf_counter()
assert isinstance(res, float)
self.log.debug(
lambda: f"result: {res:.4f}ms (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
)
return res # pyright: ignore[reportReturnType]
if res < self.best_perf_so_far:
self.best_perf_so_far = res
return res
except Exception as e:
action = classify_triton_exception(e)
if action == "raise":
@@ -286,7 +294,9 @@ def extract_launcher(
timeout=self.settings.autotune_compile_timeout,
)

def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]]:
def parallel_benchmark(
self, configs: list[Config]
) -> list[tuple[Config, Callable[..., object], float]]:
"""
Benchmark multiple configurations in parallel.

@@ -312,9 +322,9 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
for config, fn, is_working in zip(configs, fns, is_workings, strict=True):
if is_working:
# benchmark one-by-one to avoid noisy results
results.append((config, self.benchmark_function(config, fn)))
results.append((config, fn, self.benchmark_function(config, fn)))
else:
results.append((config, inf))
results.append((config, fn, inf))
return results

def autotune(self) -> Config:
@@ -356,20 +366,26 @@ def _autotune(self) -> Config:
raise NotImplementedError


class PopulationMember(NamedTuple):
@dataclasses.dataclass
class PopulationMember:
"""
Represents a member of the population in population-based search algorithms.

Attributes:
perf (float): The performance of the configuration.
perfs (list[float]): The performance of the configuration, accumulated over multiple benchmarks.
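fn (Callable[..., object]): The compiled callable for the configuration (a placeholder until benchmarked).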
flat_values (FlatConfig): The flat representation of the configuration values.
config (Config): The full configuration object.
"""

perf: float
fn: Callable[..., object]
perfs: list[float]
flat_values: FlatConfig
config: Config

@property
def perf(self) -> float:
return self.perfs[-1]


def performance(member: PopulationMember) -> float:
"""
@@ -430,7 +446,8 @@ def benchmark_flat(self, flat_values: FlatConfig) -> PopulationMember:
A population member with the benchmark results.
"""
config = self.config_gen.unflatten(flat_values)
return PopulationMember(self.benchmark(config), flat_values, config)
fn, perf = self.benchmark(config)
return PopulationMember(fn, [perf], flat_values, config)

def parallel_benchmark_flat(
self, to_check: list[FlatConfig]
@@ -444,14 +461,92 @@ def parallel_benchmark_flat(
Returns:
A list of population members with the benchmark results.
"""
configs = [*map(self.config_gen.unflatten, to_check)]
result = []
for flat_values, config_in, (config_out, perf) in zip(
to_check, configs, self.parallel_benchmark(configs), strict=True
result = [*map(self.make_unbenchmarked, to_check)]
return self.parallel_benchmark_population(result)

def make_unbenchmarked(self, flat_values: FlatConfig) -> PopulationMember:
"""
Create a population member with an unbenchmarked configuration. You
should pass the result of this to parallel_benchmark_population.

Args:
flat_values: The flat configuration values.

Returns:
A population member with undefined performance.
"""
config = self.config_gen.unflatten(flat_values)
return PopulationMember(_unset_fn, [], flat_values, config)

def parallel_benchmark_population(
self, members: list[PopulationMember]
) -> list[PopulationMember]:
"""
Benchmark multiple population members in parallel. Members should be created with make_unbenchmarked.
"""
for member, (config_out, fn, perf) in zip(
members, self.parallel_benchmark([m.config for m in members]), strict=True
):
assert config_in is config_out
result.append(PopulationMember(perf, flat_values, config_in))
return result
assert config_out is member.config
member.perfs.append(perf)
member.fn = fn
return members

def compare(self, a: PopulationMember, b: PopulationMember) -> int:
"""
Compare two population members based on their performance, possibly with re-benchmarking.

Args:
a: The first population member.
b: The second population member.

Returns:
-1 if a is better than b, 1 if b is better than a, 0 if they are equal.
"""
if self.should_rebenchmark(a) and self.should_rebenchmark(b):
self.rebenchmark([a, b])
return (a.perf > b.perf) - (a.perf < b.perf)

def should_rebenchmark(self, member: PopulationMember) -> bool:
"""
Determine if a population member should be re-benchmarked to avoid outliers.

Args:
member: The population member to check.

Returns:
True if the member should be re-benchmarked, False otherwise.
"""
return (
member.perf
< self.settings.autotune_rebenchmark_threshold * self.best_perf_so_far
and math.isfinite(member.perf)
)

def rebenchmark(self, members: list[PopulationMember]) -> None:
"""
Re-benchmark a list of population members to avoid outliers.
"""
if len(members) < 2:
return
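# scale repeats inversely with the best time so far (roughly 200 ms of measurement each), at least 3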
repeat = max(3, int(200 / self.best_perf_so_far))
new_timings = interleaved_bench(
[functools.partial(m.fn, *self.args) for m in members], repeat=repeat
)
for m, t in zip(members, new_timings, strict=True):
m.perfs.append(t)
if t < self.best_perf_so_far:
self.best_perf_so_far = t

def rebenchmark_population(
self, members: list[PopulationMember] | None = None
) -> None:
"""
Re-benchmark the entire population to avoid outliers.
"""
if members is None:
members = self.population
self.rebenchmark([p for p in members if self.should_rebenchmark(p)])

def statistics(self) -> str:
"""
@@ -697,3 +792,7 @@ def __init__(
self.grid = grid
self.args = args
self.kwargs = kwargs


def _unset_fn(*args: object) -> NoReturn:
raise RuntimeError("Uninitialized function")
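For reviewers, a minimal standalone sketch (not part of the PR) of the gate that should_rebenchmark applies: a configuration is only re-measured when its first timing is finite and within autotune_rebenchmark_threshold of the best timing seen so far. The helper name worth_rebenchmarking and the timings below are made up for illustration.

import math

def worth_rebenchmarking(perf_ms: float, best_ms: float, threshold: float = 1.5) -> bool:
    # Mirrors should_rebenchmark(): skip failed configs (inf) and clearly slow ones.
    return perf_ms < threshold * best_ms and math.isfinite(perf_ms)

assert worth_rebenchmarking(1.2, best_ms=1.0)       # within 1.5x of the best -> re-measure
assert not worth_rebenchmarking(3.0, best_ms=1.0)   # clearly slower -> one timing is enough
assert not worth_rebenchmarking(float("inf"), 1.0)  # failed to compile or run -> skip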
49 changes: 49 additions & 0 deletions helion/autotuner/benchmarking.py
@@ -0,0 +1,49 @@
from __future__ import annotations

import functools
import statistics
from typing import Callable

from triton import runtime


def interleaved_bench(fns: list[Callable[[], object]], *, repeat: int) -> list[float]:
"""
Benchmark multiple functions at once, interleaving their executions to reduce
the impact of external factors (e.g., load, temperature) on the
measurements.
"""
# warmup
for fn in fns:
fn()
clear_cache = functools.partial(
runtime.driver.active.clear_cache, # type: ignore[attr-defined]
runtime.driver.active.get_empty_cache_for_benchmark(), # type: ignore[attr-defined]
)
clear_cache()
di = runtime.driver.active.get_device_interface() # type: ignore[attr-defined]
start_events = [
[di.Event(enable_timing=True) for _ in range(repeat)] for _ in range(len(fns))
]
end_events = [
[di.Event(enable_timing=True) for _ in range(repeat)] for _ in range(len(fns))
]

di.synchronize()
for i in range(repeat):
for j in range(len(fns)):
clear_cache()
start_events[j][i].record()
fns[j]()
end_events[j][i].record()
di.synchronize()

return [
statistics.median(
[
s.elapsed_time(e)
for s, e in zip(start_events[j], end_events[j], strict=True)
]
)
for j in range(len(fns))
]
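As a reading aid (not part of the PR), a simplified CPU-only analogue of interleaved_bench: it keeps the round-robin structure in which every repeat runs each function once before moving on, but uses time.perf_counter instead of Triton device events and does no cache clearing. The toy workloads are invented.

import statistics
import time
from typing import Callable

def interleaved_bench_cpu(fns: list[Callable[[], object]], *, repeat: int) -> list[float]:
    timings: list[list[float]] = [[] for _ in fns]
    for _ in range(repeat):
        for j, fn in enumerate(fns):  # alternate between functions on every repeat
            t0 = time.perf_counter()
            fn()
            timings[j].append((time.perf_counter() - t0) * 1000.0)  # milliseconds
    return [statistics.median(ts) for ts in timings]

print(interleaved_bench_cpu([lambda: sum(range(10_000)), lambda: sum(range(20_000))], repeat=9))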
3 changes: 2 additions & 1 deletion helion/autotuner/differential_evolution.py
@@ -81,7 +81,7 @@ def iter_candidates(self) -> Iterator[tuple[int, PopulationMember]]:
def evolve_population(self) -> int:
replaced = 0
for i, candidate in self.iter_candidates():
if candidate.perf < self.population[i].perf:
if self.compare(candidate, self.population[i]) < 0:
self.population[i] = candidate
replaced += 1
return replaced
@@ -97,4 +97,5 @@ def _autotune(self) -> Config:
for i in range(2, self.num_generations):
replaced = self.evolve_population()
self.log(f"Generation {i}: replaced={replaced}", self.statistics)
self.rebenchmark_population()
return self.best.config
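A toy experiment (not Helion code) showing why evolve_population now routes replacement decisions through compare(), which may re-benchmark both members: with noisy timings a single sample often ranks a slightly slower config ahead of a faster one, while the median of a few repeats does so much less often. The latencies and noise model are invented.

import random
import statistics

random.seed(0)

def sample(true_ms: float) -> float:
    return true_ms + random.gauss(0.0, 0.10)  # true latency plus measurement noise

trials = 1000
single_wrong = sum(sample(1.10) < sample(1.00) for _ in range(trials))
median_wrong = sum(
    statistics.median(sample(1.10) for _ in range(5))
    < statistics.median(sample(1.00) for _ in range(5))
    for _ in range(trials)
)
print(f"slower config wins: single sample {single_wrong}/{trials}, median of 5 {median_wrong}/{trials}")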
2 changes: 1 addition & 1 deletion helion/autotuner/finite_search.py
@@ -35,7 +35,7 @@ def __init__(
def _autotune(self) -> Config:
best_config = None
best_time = float("inf")
for config, time in self.parallel_benchmark(self.configs):
for config, _fn, time in self.parallel_benchmark(self.configs):
if time < best_time:
best_time = time
best_config = config
4 changes: 4 additions & 0 deletions helion/runtime/settings.py
@@ -101,6 +101,9 @@ class _Settings:
autotune_accuracy_check: bool = (
os.environ.get("HELION_AUTOTUNE_ACCURACY_CHECK", "1") == "1"
)
autotune_rebenchmark_threshold: float = float(
os.environ.get("HELION_REBENCHMARK_THRESHOLD", "1.5")
)
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
allow_warp_specialize: bool = (
@@ -131,6 +134,7 @@ class Settings(_Settings):
"autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.",
"autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.",
"autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.",
"autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Default is 1.5x. Set to <1 to disable.",
"print_output_code": "If True, print the output code of the kernel to stderr.",
"force_autotune": "If True, force autotuning even if a config is provided.",
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
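A hedged usage sketch (not part of the diff) for the new knob, based on the HELION_REBENCHMARK_THRESHOLD variable read above; it assumes the variable is set before Helion settings are constructed.

import os

# Re-benchmark any config whose first timing lands within 2x of the best seen so far.
os.environ["HELION_REBENCHMARK_THRESHOLD"] = "2.0"
# Per the setting's docstring, a value below 1 disables re-benchmarking entirely.
# os.environ["HELION_REBENCHMARK_THRESHOLD"] = "0.5"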
8 changes: 4 additions & 4 deletions test/test_autotuner.py
@@ -206,12 +206,12 @@ def make_bad_config_produce_wrong_output(
search = FiniteSearch(
bound_kernel, (a, b), configs=[bad_config, good_config]
)
bad_time = search.benchmark(bad_config)
_, bad_time = search.benchmark(bad_config)
assert math.isinf(bad_time)
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
search.counters["accuracy_mismatch"] = 0 # reset counter

good_time = search.benchmark(good_config)
_, good_time = search.benchmark(good_config)
assert not math.isinf(good_time)
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0)
search.counters["accuracy_mismatch"] = 0 # reset counter
@@ -259,12 +259,12 @@ def wrong_fn(*fn_args, **fn_kwargs):
search = FiniteSearch(
bound_kernel, (a, b), configs=[bad_config, good_config]
)
bad_time = search.benchmark(bad_config)
_, bad_time = search.benchmark(bad_config)
assert math.isinf(bad_time)
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
search.counters["accuracy_mismatch"] = 0 # reset counter

good_time = search.benchmark(good_config)
_, good_time = search.benchmark(good_config)
assert not math.isinf(good_time)
self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0)
search.counters["accuracy_mismatch"] = 0 # reset counter
9 changes: 8 additions & 1 deletion test/test_errors.py
@@ -50,7 +50,14 @@ def fake_parallel(
members = []
for flat_values in to_check:
cfg = self.config_gen.unflatten(flat_values)
members.append(PopulationMember(float("inf"), flat_values, cfg))
members.append(
PopulationMember(
lambda *args: None,
[float("inf")],
flat_values,
cfg,
)
)
return members

with (