diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 0a739548a..9c4e73bb0 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -100,7 +100,7 @@ jobs:
       - name: Install Helion
         run: |
           source .venv/bin/activate
-          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]'
+          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
           python -c "import helion; print(helion.__name__)"
 
       - name: Install Benchmark Requirements
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index d1bbdc223..51e196fcf 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -146,7 +146,7 @@ jobs:
         run: |
           source .venv/bin/activate
           uv pip install setuptools ninja
-          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev]'
+          SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]'
           python -c "import helion; print(helion.__name__)"
 
       - name: Run Tests
diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py
index 541f4a787..52a0c672e 100644
--- a/helion/autotuner/__init__.py
+++ b/helion/autotuner/__init__.py
@@ -6,6 +6,7 @@
 from .config_fragment import ListOf as ListOf
 from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment
 from .config_spec import ConfigSpec as ConfigSpec
+from .de_surrogate_hybrid import DESurrogateHybrid as DESurrogateHybrid
 from .differential_evolution import (
     DifferentialEvolutionSearch as DifferentialEvolutionSearch,
 )
@@ -20,6 +21,7 @@
 from .random_search import RandomSearch as RandomSearch
 
 search_algorithms = {
+    "DESurrogateHybrid": DESurrogateHybrid,
     "DifferentialEvolutionSearch": DifferentialEvolutionSearch,
     "FiniteSearch": FiniteSearch,
     "PatternSearch": PatternSearch,
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 01b96ae4a..ae2c5648e 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -624,6 +624,7 @@ class PopulationMember:
         perfs (list[float]): The performance of the configuration, accumulated over multiple benchmarks.
         flat_values (FlatConfig): The flat representation of the configuration values.
         config (Config): The full configuration object.
+        compile_time (float | None): The compilation time for this configuration.
     """
 
     fn: Callable[..., object]
diff --git a/helion/autotuner/config_fragment.py b/helion/autotuner/config_fragment.py
index c8bddf2b3..c58df1265 100644
--- a/helion/autotuner/config_fragment.py
+++ b/helion/autotuner/config_fragment.py
@@ -57,6 +57,26 @@ def get_minimum(self) -> int:
         """
         raise NotImplementedError
 
+    def encode_scalar(self, value: object) -> float:
+        """
+        Encode a configuration value into a float for ML models.
+
+        This is used by surrogate-assisted algorithms to convert configurations
+        into numerical vectors for prediction models.
+
+        Args:
+            value: The configuration value to encode.
+
+        Returns:
+            A float representing the encoded value.
+        """
+        # Default: convert to float if possible
+        if not isinstance(value, (int, float, bool)):
+            raise TypeError(
+                f"Cannot encode {type(value).__name__} value {value!r} for ML"
+            )
+        return float(value)
+
 
 @dataclasses.dataclass
 class PermutationFragment(ConfigSpecFragment):
@@ -121,6 +141,14 @@ def pattern_neighbors(self, current: object) -> list[object]:
             neighbors.append(upper)
         return neighbors
 
+    def encode_scalar(self, value: object) -> float:
+        """Encode integer values directly as floats."""
+        if not isinstance(value, (int, float)):
+            raise TypeError(
+                f"Expected int/float for BaseIntegerFragment, got {type(value).__name__}: {value!r}"
+            )
+        return float(value)
+
 
 class PowerOfTwoFragment(BaseIntegerFragment):
     def random(self) -> int:
@@ -152,6 +180,20 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
             return self.clamp(ai * 2)
         return ai
 
+    def encode_scalar(self, value: object) -> float:
+        """Encode power-of-2 values using log2 transformation."""
+        import math
+
+        if not isinstance(value, (int, float)):
+            raise TypeError(
+                f"Expected int/float for PowerOfTwoFragment, got {type(value).__name__}: {value!r}"
+            )
+        if value <= 0:
+            raise ValueError(
+                f"Expected positive value for PowerOfTwoFragment, got {value}"
+            )
+        return math.log2(float(value))
+
 
 class IntegerFragment(BaseIntegerFragment):
     def random(self) -> int:
@@ -193,6 +235,17 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
             choices.remove(a)
         return random.choice(choices)
 
+    def encode_scalar(self, value: object) -> float:
+        """Encode enum values as their index."""
+        try:
+            choice_idx = self.choices.index(value)
+        except ValueError:
+            raise ValueError(
+                f"Invalid enum value {value!r} for EnumFragment. "
+                f"Valid choices: {self.choices}"
+            ) from None
+        return float(choice_idx)
+
 
 class BooleanFragment(ConfigSpecFragment):
     def default(self) -> bool:
diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py
index 505f95da5..0747f41a9 100644
--- a/helion/autotuner/config_generation.py
+++ b/helion/autotuner/config_generation.py
@@ -181,3 +181,24 @@ def differential_mutation(
         # TODO(jansel): can this be larger? (too large and Triton compile times blow up)
         self.shrink_config(result, 8192)
         return result
+
+    def encode_config(self, flat_config: FlatConfig) -> list[float]:
+        """
+        Encode a flat configuration into a numerical vector for ML models.
+
+        This is used by surrogate-assisted algorithms (e.g., DE-Surrogate) that need
+        to represent configurations as continuous vectors for prediction models.
+
+        Args:
+            flat_config: The flat configuration values to encode.
+
+        Returns:
+            A list of floats representing the encoded configuration.
+        """
+        encoded: list[float] = []
+
+        for flat_idx, spec in enumerate(self.flat_spec):
+            value = flat_config[flat_idx]
+            encoded.append(spec.encode_scalar(value))
+
+        return encoded
diff --git a/helion/autotuner/de_surrogate_hybrid.py b/helion/autotuner/de_surrogate_hybrid.py
new file mode 100644
index 000000000..b04cbfb9b
--- /dev/null
+++ b/helion/autotuner/de_surrogate_hybrid.py
@@ -0,0 +1,311 @@
+"""
+Differential Evolution with Surrogate-Assisted Selection (DE-SAS).
+
+This hybrid approach combines the robust exploration of Differential Evolution
+with the sample efficiency of surrogate models. It's designed to beat standard DE
+by making smarter decisions about which candidates to evaluate.
+
+Key idea:
+- Use DE's mutation/crossover to generate candidates (good exploration)
+- Use a Random Forest surrogate to predict which candidates are promising
+- Only evaluate the most promising candidates (sample efficiency)
+- Periodically re-fit the surrogate model
+
+This is inspired by recent work on surrogate-assisted evolutionary algorithms,
+which have shown 2-5× speedups over standard EAs on expensive optimization problems.
+
+References:
+- Jin, Y. (2011). "Surrogate-assisted evolutionary computation: Recent advances and future challenges."
+- Sun, C., et al. (2019). "A surrogate-assisted DE with an adaptive local search"
+
+Author: Francisco Geiman Thiesen
+Date: 2025-11-05
+"""
+
+from __future__ import annotations
+
+import math
+import operator
+import random
+from typing import TYPE_CHECKING
+from typing import Any
+
+from .differential_evolution import DifferentialEvolutionSearch
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from ..runtime.kernel import BoundKernel
+    from .config_generation import Config
+    from .config_generation import FlatConfig
+
+try:
+    import numpy as np  # type: ignore[import-not-found]
+    from sklearn.ensemble import RandomForestRegressor  # type: ignore[import-not-found]
+
+    HAS_ML_DEPS = True
+except ImportError:
+    HAS_ML_DEPS = False
+    np = None  # type: ignore[assignment]
+    RandomForestRegressor = None  # type: ignore[assignment,misc]
+
+
+class DESurrogateHybrid(DifferentialEvolutionSearch):
+    """
+    Hybrid Differential Evolution with Surrogate-Assisted Selection.
+
+    This algorithm uses DE for exploration but adds a surrogate model to intelligently
+    select which candidates to actually evaluate, avoiding wasted evaluations on
+    poor candidates.
+
+    Args:
+        kernel: The bound kernel to tune
+        args: Arguments for the kernel
+        population_size: Size of the DE population (default: 40)
+        max_generations: Maximum number of generations (default: 40)
+        crossover_rate: Crossover probability (default: 0.8)
+        surrogate_threshold: Use surrogate after this many evaluations (default: 100)
+        candidate_ratio: Number of candidates generated per population slot (default: 3)
+        refit_frequency: Refit surrogate every N generations (default: 5)
+        n_estimators: Number of trees in Random Forest (default: 50)
+        min_improvement_delta: Relative improvement threshold for early stopping.
+            Default: 0.001 (0.1%). Early stopping enabled by default.
+        patience: Number of generations without improvement before stopping.
+            Default: 3. Early stopping enabled by default.
+    """
+
+    def __init__(
+        self,
+        kernel: BoundKernel,
+        args: Sequence[object],
+        population_size: int = 40,
+        max_generations: int = 40,
+        crossover_rate: float = 0.8,
+        surrogate_threshold: int = 100,
+        candidate_ratio: int = 3,
+        refit_frequency: int = 5,
+        n_estimators: int = 50,
+        min_improvement_delta: float = 0.001,
+        patience: int = 3,
+    ) -> None:
+        if not HAS_ML_DEPS:
+            raise ImportError(
+                "DESurrogateHybrid requires numpy and scikit-learn. "
+                "Install them with: pip install helion[de-surrogate]"
+            )
+
+        # Initialize parent with early stopping parameters
+        super().__init__(
+            kernel,
+            args,
+            population_size=population_size,
+            max_generations=max_generations,
+            crossover_rate=crossover_rate,
+            min_improvement_delta=min_improvement_delta,
+            patience=patience,
+        )
+
+        self.surrogate_threshold = surrogate_threshold
+        self.candidate_ratio = candidate_ratio
+        self.refit_frequency = refit_frequency
+        self.n_estimators = n_estimators
+
+        # Surrogate model
+        self.surrogate: Any = None
+
+        # Track all evaluations for surrogate training
+        self.all_observations: list[tuple[FlatConfig, float]] = []
+
+    def _autotune(self) -> Config:
+        """
+        Run DE with surrogate-assisted selection.
+
+        Returns:
+            Best configuration found
+        """
+        self.log("=" * 70)
+        self.log("Differential Evolution with Surrogate-Assisted Selection")
+        self.log("=" * 70)
+        self.log(f"Population: {self.population_size}")
+        self.log(f"Generations: {self.max_generations}")
+        self.log(f"Crossover rate: {self.crossover_rate}")
+        self.log(f"Surrogate activation: after {self.surrogate_threshold} evals")
+        self.log(f"Candidate oversampling: {self.candidate_ratio}× per slot")
+        self.log(
+            f"Early stopping: delta={self.min_improvement_delta}, patience={self.patience}"
+        )
+        self.log("=" * 70)
+
+        # Initialize population
+        self.set_generation(0)
+        self.initial_two_generations()
+
+        # Track initial observations for surrogate
+        for member in self.population:
+            if math.isfinite(member.perf):
+                self.all_observations.append((member.flat_values, member.perf))
+
+        # Initialize early stopping tracking
+        self.best_perf_history = [self.best.perf]
+        self.generations_without_improvement = 0
+
+        # Evolution loop
+        for gen in range(2, self.max_generations + 1):
+            self.set_generation(gen)
+            self._evolve_generation(gen)
+
+            # Check for convergence
+            if self.check_early_stopping():
+                break
+
+        # Return best config
+        best = min(self.population, key=lambda m: m.perf)
+        self.log("=" * 70)
+        self.log(f"✓ Best configuration: {best.perf:.4f} ms")
+        self.log(f"Total evaluations: {len(self.all_observations)}")
+        self.log("=" * 70)
+
+        return best.config
+
+    def _evolve_generation(self, generation: int) -> None:
+        """Run one generation of DE with surrogate assistance."""
+
+        # Refit surrogate periodically
+        use_surrogate = len(self.all_observations) >= self.surrogate_threshold
+        if use_surrogate and (generation % self.refit_frequency == 0):
+            self._fit_surrogate()
+
+        # Generate candidates using DE mutation/crossover
+        if use_surrogate:
+            # Generate more candidates and use surrogate to select best
+            n_candidates = self.population_size * self.candidate_ratio
+            candidates = self._generate_de_candidates(n_candidates)
+            selected_candidates = self._surrogate_select(
+                candidates, self.population_size
+            )
+        else:
+            # Standard DE: generate and evaluate all
+            selected_candidates = self._generate_de_candidates(self.population_size)
+
+        # Evaluate selected candidates
+        new_members = self.parallel_benchmark_flat(selected_candidates)
+
+        # Track observations
+        for member in new_members:
+            if math.isfinite(member.perf):
+                self.all_observations.append((member.flat_values, member.perf))
+
+        # Selection: keep better of old vs new for each position
+        replacements = 0
+        for i, new_member in enumerate(new_members):
+            if new_member.perf < self.population[i].perf:
+                self.population[i] = new_member
+                replacements += 1
+
+        # Log progress
+        best_perf = min(m.perf for m in self.population)
+        surrogate_status = "SURROGATE" if use_surrogate else "STANDARD"
+        self.log(
+            f"Gen {generation}: {surrogate_status} | "
+            f"best={best_perf:.4f} ms | replaced={replacements}/{self.population_size} | "
+            f"total_evals={len(self.all_observations)}"
+        )
+
+    def _generate_de_candidates(self, n_candidates: int) -> list[FlatConfig]:
+        """Generate candidates using standard DE mutation/crossover."""
+        candidates = []
+
+        for _ in range(n_candidates):
+            # Select four distinct individuals: x (base), and a, b, c for mutation
+            x, a, b, c = random.sample(self.population, 4)
+
+            # Differential mutation: x + F(a - b + c)
+            trial = self.config_gen.differential_mutation(
+                x.flat_values,
+                a.flat_values,
+                b.flat_values,
+                c.flat_values,
+                crossover_rate=self.crossover_rate,
+            )
+
+            candidates.append(trial)
+
+        return candidates
+
+    def _fit_surrogate(self) -> None:
+        """Fit Random Forest surrogate model on all observations."""
+        if len(self.all_observations) < 10:
+            return  # Need minimum data
+
+        # Encode configs to numeric arrays
+        X = []
+        y = []
+
+        for config, perf in self.all_observations:
+            try:
+                encoded = self.config_gen.encode_config(config)
+                X.append(encoded)
+                y.append(perf)
+            except Exception:
+                continue
+
+        if len(X) < 10:
+            return
+
+        X_array = np.array(X)  # type: ignore[union-attr]
+        y_array = np.array(y)  # type: ignore[union-attr]
+
+        # Fit Random Forest
+        surrogate = RandomForestRegressor(  # type: ignore[misc]
+            n_estimators=self.n_estimators,
+            max_depth=15,
+            min_samples_split=5,
+            min_samples_leaf=2,
+            random_state=42,
+            n_jobs=-1,
+        )
+        surrogate.fit(X_array, y_array)
+        self.surrogate = surrogate
+
+    def _surrogate_select(
+        self, candidates: list[FlatConfig], n_select: int
+    ) -> list[FlatConfig]:
+        """
+        Use surrogate model to select most promising candidates.
+
+        Args:
+            candidates: Pool of candidate configurations
+            n_select: Number of candidates to select
+
+        Returns:
+            Selected candidates predicted to be best
+        """
+        if self.surrogate is None:
+            # Fallback: random selection
+            return random.sample(candidates, min(n_select, len(candidates)))
+
+        # Predict performance for all candidates
+        predictions = []
+
+        for config in candidates:
+            try:
+                encoded = self.config_gen.encode_config(config)
+                pred = self.surrogate.predict([encoded])[0]
+                predictions.append((config, pred))
+            except Exception:
+                # Skip encoding failures
+                predictions.append((config, float("inf")))
+
+        # Sort by predicted performance (lower is better)
+        predictions.sort(key=operator.itemgetter(1))
+
+        # Select top n_select candidates
+        return [config for config, pred in predictions[:n_select]]
+
+    def __repr__(self) -> str:
+        return (
+            f"DESurrogateHybrid(pop={self.population_size}, "
+            f"gen={self.max_generations}, "
+            f"cr={self.crossover_rate}, "
+            f"surrogate_threshold={self.surrogate_threshold})"
+        )
diff --git a/helion/autotuner/differential_evolution.py b/helion/autotuner/differential_evolution.py
index f17e42631..bb1e03c9a 100644
--- a/helion/autotuner/differential_evolution.py
+++ b/helion/autotuner/differential_evolution.py
@@ -31,7 +31,24 @@ def __init__(
         max_generations: int = DIFFERENTIAL_EVOLUTION_DEFAULTS.max_generations,
         crossover_rate: float = 0.8,
         immediate_update: bool | None = None,
+        min_improvement_delta: float | None = None,
+        patience: int | None = None,
     ) -> None:
+        """
+        Create a DifferentialEvolutionSearch autotuner.
+
+        Args:
+            kernel: The kernel to be autotuned.
+            args: The arguments to be passed to the kernel.
+            population_size: The size of the population.
+            max_generations: The maximum number of generations to run.
+            crossover_rate: The crossover rate for mutation.
+            immediate_update: Whether to update population immediately after each evaluation.
+            min_improvement_delta: Relative improvement threshold for early stopping.
+                If None (default), early stopping is disabled.
+            patience: Number of generations without improvement before stopping.
+                If None (default), early stopping is disabled.
+        """
         super().__init__(kernel, args)
         if immediate_update is None:
             immediate_update = not bool(kernel.settings.autotune_precompile)
@@ -39,6 +56,12 @@
         self.max_generations = max_generations
         self.crossover_rate = crossover_rate
         self.immediate_update = immediate_update
+        self.min_improvement_delta = min_improvement_delta
+        self.patience = patience
+
+        # Early stopping state
+        self.best_perf_history: list[float] = []
+        self.generations_without_improvement = 0
 
     def mutate(self, x_index: int) -> FlatConfig:
         a, b, c, *_ = [
@@ -97,18 +120,84 @@ def evolve_population(self) -> int:
                 replaced += 1
         return replaced
 
+    def check_early_stopping(self) -> bool:
+        """
+        Check if early stopping criteria are met and update state.
+
+        This method updates best_perf_history and generations_without_improvement,
+        and returns whether the optimization should stop.
+
+        Returns:
+            True if optimization should stop early, False otherwise.
+        """
+        import math
+
+        # Update history
+        current_best = self.best.perf
+        self.best_perf_history.append(current_best)
+
+        if self.patience is None or len(self.best_perf_history) <= self.patience:
+            return False
+
+        # Check improvement over last patience generations
+        past_best = self.best_perf_history[-self.patience - 1]
+
+        if not (
+            math.isfinite(current_best)
+            and math.isfinite(past_best)
+            and past_best != 0.0
+        ):
+            return False
+
+        relative_improvement = abs(current_best / past_best - 1.0)
+
+        if (
+            self.min_improvement_delta is not None
+            and relative_improvement < self.min_improvement_delta
+        ):
+            # No significant improvement
+            self.generations_without_improvement += 1
+            if self.generations_without_improvement >= self.patience:
+                self.log(
+                    f"Early stopping at generation {self._current_generation}: "
+                    f"no improvement >{self.min_improvement_delta:.1%} for {self.patience} generations"
+                )
+                return True
+            return False
+
+        # Significant improvement - reset counter
+        self.generations_without_improvement = 0
+        return False
+
     def _autotune(self) -> Config:
+        early_stopping_enabled = (
+            self.min_improvement_delta is not None and self.patience is not None
+        )
+
         self.log(
             lambda: (
                 f"Starting DifferentialEvolutionSearch with population={self.population_size}, "
-                f"generations={self.max_generations}, crossover_rate={self.crossover_rate}"
+                f"generations={self.max_generations}, crossover_rate={self.crossover_rate}, "
+                f"early_stopping=(delta={self.min_improvement_delta}, patience={self.patience})"
             )
         )
+
         self.initial_two_generations()
+
+        # Initialize early stopping tracking
+        if early_stopping_enabled:
+            self.best_perf_history = [self.best.perf]
+            self.generations_without_improvement = 0
+
         for i in range(2, self.max_generations):
             self.set_generation(i)
             self.log(f"Generation {i} starting")
             replaced = self.evolve_population()
             self.log(f"Generation {i} complete: replaced={replaced}", self.statistics)
+
+            # Check for convergence (only if early stopping enabled)
+            if early_stopping_enabled and self.check_early_stopping():
+                break
+
         self.rebenchmark_population()
         return self.best.config
diff --git a/pyproject.toml b/pyproject.toml
index 0c898b8f0..fc063289f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,12 +20,14 @@ dependencies = [
     "typing_extensions>=4.0.0",
     "filecheck",
     "psutil",
-    "filecheck",
-    "numpy",
-    "rich"
+    "rich",
 ]
 
 [project.optional-dependencies]
+de-surrogate = [
+    "numpy",
+    "scikit-learn>=1.3.0"
+]
 dev = [
     "expecttest",
     "pytest",
diff --git a/test/test_autotuner.py b/test/test_autotuner.py
index e95fde358..7c6fcd4d6 100644
--- a/test/test_autotuner.py
+++ b/test/test_autotuner.py
@@ -33,6 +33,7 @@
 from helion._testing import import_path
 from helion._testing import skipIfCpu
 from helion._testing import skipIfRocm
+from helion.autotuner import DESurrogateHybrid
 from helion.autotuner import DifferentialEvolutionSearch
 from helion.autotuner import PatternSearch
 from helion.autotuner.base_search import BaseSearch
@@ -486,6 +487,79 @@ def test_differential_evolution_search(self):
         fn = bound_kernel.compile_config(best)
         torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)
 
+    @skipIfRocm("too slow on rocm")
+    @skip("too slow")
+    def test_de_surrogate_hybrid(self):
+        args = (
+            torch.randn([512, 512], device=DEVICE),
+            torch.randn([512, 512], device=DEVICE),
+        )
+        bound_kernel = examples_matmul.bind(args)
+        random.seed(123)
+        best = DESurrogateHybrid(
+            bound_kernel, args, population_size=5, max_generations=3
+        ).autotune()
+        fn = bound_kernel.compile_config(best)
+        torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)
+
+    @skipIfRocm("too slow on rocm")
+    @skipIfCpu("fails on Triton CPU backend")
+    def test_differential_evolution_early_stopping_parameters(self):
+        """Test that early stopping is disabled by default and can be enabled."""
+        args = (
+            torch.randn([64, 64], device=DEVICE),
+            torch.randn([64, 64], device=DEVICE),
+        )
+        bound_kernel = basic_kernels.add.bind(args)
+
+        # Test 1: Default parameters (early stopping disabled)
+        search = DifferentialEvolutionSearch(
+            bound_kernel, args, population_size=5, max_generations=3
+        )
+        self.assertIsNone(search.min_improvement_delta)
+        self.assertIsNone(search.patience)
+
+        # Test 2: Enable early stopping with custom parameters
+        search_custom = DifferentialEvolutionSearch(
+            bound_kernel,
+            args,
+            population_size=5,
+            max_generations=3,
+            min_improvement_delta=0.01,
+            patience=5,
+        )
+        self.assertEqual(search_custom.min_improvement_delta, 0.01)
+        self.assertEqual(search_custom.patience, 5)
+
+    @skipIfRocm("too slow on rocm")
+    @skipIfCpu("fails on Triton CPU backend")
+    def test_de_surrogate_early_stopping_parameters(self):
+        """Test that DE-Surrogate early stopping parameters are optional with correct defaults."""
+        args = (
+            torch.randn([64, 64], device=DEVICE),
+            torch.randn([64, 64], device=DEVICE),
+        )
+        bound_kernel = basic_kernels.add.bind(args)
+
+        # Test 1: Default parameters (optional)
+        search = DESurrogateHybrid(
+            bound_kernel, args, population_size=5, max_generations=3
+        )
+        self.assertEqual(search.min_improvement_delta, 0.001)
+        self.assertEqual(search.patience, 3)
+
+        # Test 2: Custom parameters
+        search_custom = DESurrogateHybrid(
+            bound_kernel,
+            args,
+            population_size=5,
+            max_generations=3,
+            min_improvement_delta=0.01,
+            patience=5,
+        )
+        self.assertEqual(search_custom.min_improvement_delta, 0.01)
+        self.assertEqual(search_custom.patience, 5)
+
     @skip("too slow")
     def test_pattern_search(self):
         args = (
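
Usage sketch (not part of the patch): the snippet below mirrors test_de_surrogate_hybrid above and shows one way the new autotuner could be driven directly against a bound kernel. The kernel object (examples_matmul), the "cuda" device string, and the explicit keyword values are illustrative assumptions; the DESurrogateHybrid constructor signature, autotune(), and compile_config() are taken from this diff.

    import random

    import torch

    from helion.autotuner import DESurrogateHybrid

    # Hypothetical inputs; any bound Helion kernel works the same way.
    args = (
        torch.randn([512, 512], device="cuda"),
        torch.randn([512, 512], device="cuda"),
    )
    bound_kernel = examples_matmul.bind(args)  # assumes the example matmul kernel is in scope

    random.seed(123)  # reproducible population initialization
    best = DESurrogateHybrid(
        bound_kernel,
        args,
        population_size=40,       # DE population size (class default)
        max_generations=40,       # upper bound; early stopping may end the run sooner
        surrogate_threshold=100,  # switch to surrogate selection after 100 evaluations
        candidate_ratio=3,        # oversample candidates 3x, keep the best-predicted ones
    ).autotune()

    fn = bound_kernel.compile_config(best)  # compile the winning config
    out = fn(*args)

Because DESurrogateHybrid subclasses DifferentialEvolutionSearch and is registered in search_algorithms, it should also be selectable wherever Helion accepts a search-algorithm name, but the snippet above only relies on the direct-construction path exercised by the tests.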