From 9c3e56a97ae77f158798a94c3b4dc54738e2e304 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 11 Nov 2025 10:13:22 -0800 Subject: [PATCH 01/36] Adding UCBPatternSearch autotuner and fragment_encoder --- helion/autotuner/__init__.py | 2 + helion/autotuner/fragment_encoder.py | 266 ++++++++++++++++++ helion/autotuner/ucb_pattern_search.py | 355 +++++++++++++++++++++++++ 3 files changed, 623 insertions(+) create mode 100644 helion/autotuner/fragment_encoder.py create mode 100644 helion/autotuner/ucb_pattern_search.py diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py index 541f4a787..6d0602b2a 100644 --- a/helion/autotuner/__init__.py +++ b/helion/autotuner/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .fragment_encoder import ConfigEncoder from .config_fragment import BooleanFragment as BooleanFragment from .config_fragment import EnumFragment as EnumFragment from .config_fragment import IntegerFragment as IntegerFragment @@ -17,6 +18,7 @@ from .local_cache import LocalAutotuneCache as LocalAutotuneCache from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache from .pattern_search import PatternSearch as PatternSearch +from .ucb_pattern_search import UCBPatternSearch from .random_search import RandomSearch as RandomSearch search_algorithms = { diff --git a/helion/autotuner/fragment_encoder.py b/helion/autotuner/fragment_encoder.py new file mode 100644 index 000000000..e0e73270b --- /dev/null +++ b/helion/autotuner/fragment_encoder.py @@ -0,0 +1,266 @@ +"""Fragment encoding/decoding strategies for machine learning based autotuners. + +This module provides a clean abstraction for encoding different fragment types +into numerical tensors and decoding them back. Each fragment type has its own +encoder that knows how to analyze, encode, and decode itself. 
+""" + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +import math + +from .config_fragment import BooleanFragment +from .config_fragment import ConfigSpecFragment +from .config_fragment import EnumFragment +from .config_fragment import IntegerFragment +from .config_fragment import ListOf +from .config_fragment import PermutationFragment +from .config_fragment import PowerOfTwoFragment + + +class FragmentEncoder(ABC): + """Base class for encoding/decoding fragment values.""" + + def __init__(self, fragment: ConfigSpecFragment) -> None: + self.fragment = fragment + + @abstractmethod + def n_dims(self) -> int: + """Return the number of dimensions this fragment uses in encoded space.""" + + @abstractmethod + def is_categorical(self) -> bool: + """Return whether this fragment represents categorical data.""" + + @abstractmethod + def encode(self, value: object) -> list[float]: + """Encode a value into a list of floats.""" + + @abstractmethod + def decode(self, encoded: list[float]) -> object: + """Decode a list of floats back to the original value type.""" + + +class CategoricalEncoder(FragmentEncoder): + """Encoder for EnumFragment and BooleanFragment using one-hot encoding.""" + + def __init__( + self, fragment: EnumFragment | BooleanFragment, choices: list[object] + ) -> None: + super().__init__(fragment) + self.choices = choices + + def n_dims(self) -> int: + return len(self.choices) + + def is_categorical(self) -> bool: + return True + + def encode(self, value: object) -> list[float]: + idx = self.choices.index(value) + return [1.0 if i == idx else 0.0 for i in range(len(self.choices))] + + def decode(self, encoded: list[float]) -> object: + choice_idx = max(range(len(self.choices)), key=lambda i: encoded[i]) + return self.choices[choice_idx] + + +class PowerOfTwoEncoder(FragmentEncoder): + """Encoder for PowerOfTwoFragment using log2 transformation.""" + + def __init__(self, fragment: PowerOfTwoFragment) -> None: + super().__init__(fragment) + self.log_min = math.log2(fragment.low) + self.log_max = math.log2(fragment.high) + + def n_dims(self) -> int: + return 1 + + def is_categorical(self) -> bool: + return False + + def encode(self, value: int) -> list[float]: + return [math.log2(value)] + + def decode(self, encoded: list[float]) -> int: + log_val = encoded[0] + power = int(round(log_val)) + power = max(int(self.log_min), min(power, int(self.log_max))) + return 2**power + + +class IntegerEncoder(FragmentEncoder): + """Encoder for IntegerFragment using raw values.""" + + def __init__(self, fragment: IntegerFragment) -> None: + super().__init__(fragment) + self.min_val = fragment.low + self.max_val = fragment.high + + def n_dims(self) -> int: + return 1 + + def is_categorical(self) -> bool: + return False + + def encode(self, value: object) -> list[float]: + return [float(value)] + + def decode(self, encoded: list[float]) -> int: + value = int(round(encoded[0])) + return max(self.min_val, min(value, self.max_val)) + + +class PermutationEncoder(FragmentEncoder): + """Encoder for PermutationFragment using one-hot encoding for each position.""" + + def __init__(self, fragment: PermutationFragment) -> None: + super().__init__(fragment) + self.length = fragment.length + + def n_dims(self) -> int: + return self.length * self.length + + def is_categorical(self) -> bool: + return True + + def encode(self, value: list[int]) -> list[float]: + encoded = [] + for pos in range(self.length): + val = value[pos] + for v in range(self.length): + 
encoded.append(1.0 if v == val else 0.0) + return encoded + + def decode(self, encoded: list[float]) -> list[int]: + perm = [] + used = set() + + for pos in range(self.length): + start_idx = pos * self.length + one_hot = encoded[start_idx : start_idx + self.length] + val = max(range(self.length), key=lambda i: one_hot[i]) + perm.append(val) + used.add(val) + + # Fix invalid permutation (duplicates/missing values) + if len(used) != self.length: + available = [v for v in range(self.length) if v not in used] + seen = set() + fixed_perm = [] + for val in perm: + if val in seen: + fixed_val = available.pop(0) + fixed_perm.append(fixed_val) + else: + fixed_perm.append(val) + seen.add(val) + return fixed_perm + + return perm + + +class ListOfEncoder(FragmentEncoder): + """Encoder for ListOf fragments, delegates to inner encoder.""" + + def __init__(self, fragment: ListOf, inner_encoder: FragmentEncoder) -> None: + super().__init__(fragment) + self.length = fragment.length + self.inner_encoder = inner_encoder + self.inner_dims = inner_encoder.n_dims() + + def n_dims(self) -> int: + return self.length * self.inner_dims + + def is_categorical(self) -> bool: + """Return True if the inner encoder is categorical.""" + return self.inner_encoder.is_categorical() + + def encode(self, value: list[object]) -> list[float]: + encoded = [] + for v in value: + encoded.extend(self.inner_encoder.encode(v)) + return encoded + + def decode(self, encoded: list[float]) -> list[object]: + decoded = [] + for i in range(self.length): + start_idx = i * self.inner_dims + element_encoded = encoded[start_idx : start_idx + self.inner_dims] + decoded.append(self.inner_encoder.decode(element_encoded)) + return decoded + + +def create_encoder(fragment: ConfigSpecFragment) -> FragmentEncoder: + """Factory function to create the appropriate encoder for a fragment.""" + if isinstance(fragment, BooleanFragment): + return CategoricalEncoder(fragment, [False, True]) + if isinstance(fragment, EnumFragment): + return CategoricalEncoder(fragment, list(fragment.choices)) + if isinstance(fragment, PowerOfTwoFragment): + return PowerOfTwoEncoder(fragment) + if isinstance(fragment, IntegerFragment): + return IntegerEncoder(fragment) + if isinstance(fragment, PermutationFragment): + return PermutationEncoder(fragment) + if isinstance(fragment, ListOf): + inner_encoder = create_encoder(fragment.inner) + return ListOfEncoder(fragment, inner_encoder) + raise ValueError(f"Unsupported fragment type: {type(fragment).__name__}") + + +class ConfigEncoder: + """Encodes and decodes entire configurations using fragment encoders.""" + + def __init__(self, flat_spec: list[ConfigSpecFragment]) -> None: + """Initialize encoders for all fragments in the spec. + + Args: + flat_spec: List of fragment specifications + """ + self.encoders = [create_encoder(fragment) for fragment in flat_spec] + self.total_dims = sum(encoder.n_dims() for encoder in self.encoders) + + # Build categorical dimension indices (absolute positions) + self.cat_dims = [] + offset = 0 + for encoder in self.encoders: + n_dims = encoder.n_dims() + if encoder.is_categorical(): + # All dimensions of this encoder are categorical + self.cat_dims.extend(range(offset, offset + n_dims)) + offset += n_dims + + def encode(self, flat_config: list[object]) -> list[float]: + """Encode a flat configuration into a list of floats. 
+ + Args: + flat_config: List of configuration values + + Returns: + List of encoded float values + """ + encoded = [] + for value, encoder in zip(flat_config, self.encoders, strict=False): + encoded.extend(encoder.encode(value)) + return encoded + + def decode(self, encoded: list[float]) -> list[object]: + """Decode a list of floats back into a flat configuration. + + Args: + encoded: List of encoded float values + + Returns: + List of decoded configuration values + """ + decoded = [] + idx = 0 + for encoder in self.encoders: + n_dims = encoder.n_dims() + fragment_encoded = encoded[idx : idx + n_dims] + decoded.append(encoder.decode(fragment_encoded)) + idx += n_dims + return decoded diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py new file mode 100644 index 000000000..d749be0ba --- /dev/null +++ b/helion/autotuner/ucb_pattern_search.py @@ -0,0 +1,355 @@ +from __future__ import annotations + +import math +import random +from typing import TYPE_CHECKING + +import torch + +from .. import exc +from .base_search import FlatConfig +from .base_search import PopulationMember +from .base_search import performance +from .config_fragment import PowerOfTwoFragment +from .effort_profile import PATTERN_SEARCH_DEFAULTS +from .fragment_encoder import ConfigEncoder +from .pattern_search import PatternSearch + +if TYPE_CHECKING: + from collections.abc import Iterator + from collections.abc import Sequence + + from ..runtime.config import Config + from ..runtime.kernel import BoundKernel + + +import operator + +from botorch.acquisition import UpperConfidenceBound +from botorch.fit import fit_gpytorch_mll +from botorch.models import MixedSingleTaskGP +from gpytorch.mlls import ExactMarginalLogLikelihood + + +class UCBPatternSearch(PatternSearch): + def __init__( + self, + kernel: BoundKernel, + args: Sequence[object], + *, + initial_population: int = PATTERN_SEARCH_DEFAULTS.initial_population, + copies: int = PATTERN_SEARCH_DEFAULTS.copies, + max_generations: int = PATTERN_SEARCH_DEFAULTS.max_generations, + min_improvement_delta: float = 0.001, + frac_selected: float = 0.3, + num_neighbors: int = 100, + radius: int = 2, + ucb_beta: float = 2.0, + ) -> None: + super().__init__( + kernel=kernel, + args=args, + initial_population=initial_population, + copies=copies, + max_generations=max_generations, + min_improvement_delta=min_improvement_delta, + ) + + # Storage for BO + self.num_neighbors = num_neighbors + self.radius = radius + self.ucb_beta = ucb_beta + + # Initialize config encoder + self.config_encoder = ConfigEncoder(self.config_gen.flat_spec) + self.frac_selected = frac_selected + + def fit_gp( + self, train_X: torch.Tensor, train_Y: torch.Tensor, cat_dims: list + ) -> MixedSingleTaskGP: + # Filter out rows where train_Y contains inf or nan + valid_mask = torch.isfinite(train_Y) + train_X_filtered = train_X[valid_mask] + train_Y_filtered = train_Y[valid_mask] + + gp = MixedSingleTaskGP( + train_X_filtered.to(dtype=torch.float64), + -train_Y_filtered.unsqueeze(-1).to(dtype=torch.float64), + cat_dims, + ) + + with torch.enable_grad(): + mll = ExactMarginalLogLikelihood(gp.likelihood, gp) + fit_gpytorch_mll(mll) + + return gp + + def acq_fun(self, X: torch.Tensor, gp: MixedSingleTaskGP) -> torch.Tensor: + orig_dtype = X.dtype + acq_fun = UpperConfidenceBound(gp, beta=self.ucb_beta) + return ( + acq_fun(X.unsqueeze(1).to(dtype=torch.float64)) + .detach() + .to(dtype=orig_dtype) + ) + + def get_train_data_from_pop( + self, population: list[PopulationMember] + ) 
-> tuple[torch.Tensor, torch.Tensor]: + train_X = [] + train_Y = [] + for member in population: + train_X.append(torch.tensor(self.config_encoder.encode(member.flat_values))) + train_Y.append(member.perf) + + return torch.stack(train_X), torch.tensor(train_Y) + + def _autotune(self) -> Config: + self.log( + f"Starting UCBPatternSearch with initial_population={self.initial_population}, copies={self.copies}, max_generations={self.max_generations}" + ) + visited = set() + self.population = [] + for flat_config in self.config_gen.random_population_flat( + self.initial_population + ): + member = self.make_unbenchmarked(flat_config) + if member.config not in visited: + visited.add(member.config) + self.population.append(member) + self.parallel_benchmark_population(self.population, desc="Initial population") + # again with higher accuracy + self.rebenchmark_population(self.population, desc="Verifying initial results") + self.population.sort(key=performance) + starting_points = [] + for member in self.population[: self.copies]: + if math.isfinite(member.perf): # filter failed compiles + starting_points.append(member) + self.log( + f"Initial random population of {len(self.population)}, {len(starting_points)} starting points:", + self.statistics, + ) + if not starting_points: + raise exc.NoConfigFound + + # Save to training data + train_X, train_Y = self.get_train_data_from_pop(self.population) + + # Fit GP + self.log(f"Fitting GP: {len(train_X)} points, {len(train_Y)} targets") + gp = self.fit_gp( + train_X, + train_Y, + self.config_encoder.cat_dims, + ) + + search_copies = [ + self._pruned_pattern_search_from(m, visited, gp) for m in starting_points + ] + for generation in range(1, self.max_generations + 1): + prior_best = self.best + new_population = {id(prior_best): prior_best} + num_neighbors = 0 + num_active = 0 + for search_copy in search_copies: + added = next(search_copy, ()) + if added: + assert len(added) > 1 + num_active += 1 + num_neighbors += len(added) - 1 + for member in added: + new_population[id(member)] = member + if num_active == 0: + break + + # Log generation header before compiling/benchmarking + self.log( + f"Generation {generation} starting: {num_neighbors} neighbors, {num_active} active search path(s)" + ) + + self.population = [*new_population.values()] + # compile any unbenchmarked members in parallel + unbenchmarked = [m for m in self.population if len(m.perfs) == 0] + if unbenchmarked: + self.parallel_benchmark_population( + unbenchmarked, desc=f"Generation {generation}:" + ) + # higher-accuracy rebenchmark + self.rebenchmark_population( + self.population, desc=f"Generation {generation}: verifying top configs" + ) + # Log final statistics for this generation + self.log(f"Generation {generation} complete:", self.statistics) + + # Save to training data + train_X, train_Y = self.get_train_data_from_pop(self.population) + + self.log( + f"Conditioning on new data: {len(train_X)} points, {len(train_Y)} targets" + ) + gp = gp.condition_on_observations(train_X, train_Y) + + return self.best.config + + def random_log2_neighbor( + self, current_val: int, radius: int, low: int, high: int + ) -> int: + # Log the current value + current_log = int(math.log2(current_val)) + # Random log perturbation + delta = random.randint(-radius, radius) + new_log = current_log + delta + # Clamp to valid range + min_log = int(math.log2(low)) + max_log = int(math.log2(high)) + new_log = max(min_log, min(new_log, max_log)) + return int(2**new_log) + + def _generate_neighbors(self, base: FlatConfig) 
-> list[FlatConfig]: + """ + Generate neighboring configurations randomly within a specified radius. + + Strategy: + 1. Sample one block size index and change it by at most radius (in log2 space) + 2. Sample the num_warps index and change it by at most radius (in log2 space) + 3. For at most radius remaining indices, randomly select pattern neighbors + + Args: + base: The base configuration to generate neighbors from + + Returns: + A list of neighboring configurations + """ + neighbors: list[FlatConfig] = [] + + # Generate num_neighbors random neighbors + for _ in range(self.num_neighbors): + new_flat = [*base] # Copy the base configuration + modified_indices = set() + + # 1. Sample a block size index and change it by at most 1 + if self.config_gen.block_size_indices: + block_idx = random.choice(self.config_gen.block_size_indices) + modified_indices.add(block_idx) + + block_spec = self.config_gen.flat_spec[block_idx] + current_val = base[block_idx] + + if isinstance(block_spec, PowerOfTwoFragment): + # Change by at most 1 in log2 space + new_flat[block_idx] = self.random_log2_neighbor( + current_val, radius=1, low=block_spec.low, high=block_spec.high + ) + else: + raise ValueError("BlockSize should be PowerOfTwoFragment") + + # 2. Sample the num_warps index and change it by at most radius + if self.config_gen.num_warps_index: + warp_idx = self.config_gen.num_warps_index + modified_indices.add(warp_idx) + + warp_spec = self.config_gen.flat_spec[warp_idx] + current_val = base[warp_idx] + + if isinstance(warp_spec, PowerOfTwoFragment): + # Change by at most self.radius in log2 space + new_flat[warp_idx] = self.random_log2_neighbor( + current_val, + radius=self.radius, + low=warp_spec.low, + high=warp_spec.high, + ) + else: + raise ValueError("NumWarps should be PowerOfTwoFragment") + + # 3. For at most radius remaining indices, use pattern neighbors + # Exclude the already-modified block size and warp indices + + # Collect available pattern neighbors for remaining indices + remaining_pattern_neighbors = [] + for index, spec in enumerate(self.config_gen.flat_spec): + if index not in modified_indices: + pattern_neighbors = spec.pattern_neighbors(base[index]) + if pattern_neighbors: + remaining_pattern_neighbors.append((index, pattern_neighbors)) + + # Randomly select at most radius indices to change + if remaining_pattern_neighbors: + num_to_change = random.randint( + 0, min(self.radius, len(remaining_pattern_neighbors)) + ) + if num_to_change > 0: + indices_to_change = random.sample( + remaining_pattern_neighbors, num_to_change + ) + for idx, pattern_neighbors in indices_to_change: + new_flat[idx] = random.choice(pattern_neighbors) + + # Only add if it's different from the base + if new_flat != base: + neighbors.append(new_flat) + + return neighbors + + def _pruned_pattern_search_from( + self, + current: PopulationMember, + visited: set[Config], + gp: MixedSingleTaskGP, + ) -> Iterator[list[PopulationMember]]: + """ + Run a single copy of pattern search from the given starting point. + + We use a generator and yield the new population at each generation so that we can + run multiple copies of pattern search in parallel. + + Only keep self.frac_selected of the neighbors generated from the current + search_copy. Filter them using the GaussianProcess. 
+ """ + for _ in range(self.max_generations): + candidates = [current] + all_neighbors = self._generate_neighbors(current.flat_values) + self.log(f"Number of all candidate neighbors: {len(all_neighbors)}") + for flat_config in all_neighbors: + new_member = self.make_unbenchmarked(flat_config) + if new_member.config not in visited: + candidates.append(new_member) + + # score candidates + candidate_X = torch.stack( + [ + torch.tensor(self.config_encoder.encode(member.flat_values)) + for member in candidates + ] + ) + scores = self.acq_fun(candidate_X, gp) + + # filter candidates by score + candidates_sorted = sorted( + zip(candidates, scores, strict=True), + key=operator.itemgetter(1), + reverse=True, + )[: int(self.frac_selected * len(candidates))] + candidates = [member for member, score in candidates_sorted] + visited.update([member.config for member in candidates]) + + self.log( + f"Scoring {len(candidate_X)} neighbors, selecting {self.frac_selected * 100}% neighbors: {len(candidates)}" + ) + + if len(candidates) <= 1: + return # no new candidates, stop searching + yield candidates # yield new population to benchmark in parallel + best = min(candidates, key=performance) + if best is current: + return # no improvement, stop searching + # Stop if the relative improvement is smaller than a user-specified delta + if ( + self.min_improvement_delta > 0.0 + and math.isfinite(best.perf) + and math.isfinite(current.perf) + and current.perf != 0.0 + and abs(best.perf / current.perf - 1.0) < self.min_improvement_delta + ): + return + current = best From 8035c8e79d33209b90aefbb9ebc599c641646e2c Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 11 Nov 2025 11:16:33 -0800 Subject: [PATCH 02/36] add tests and import --- pyproject.toml | 3 + test/test_autotuner.py | 231 ++++++++++++++++++++++++++++++++++------- 2 files changed, 196 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0c898b8f0..e03f0396d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,9 @@ dependencies = [ ] [project.optional-dependencies] +ucb_pattern_search = [ + "botorch" +] dev = [ "expecttest", "pytest", diff --git a/test/test_autotuner.py b/test/test_autotuner.py index e95fde358..05a4643dd 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -1,56 +1,57 @@ from __future__ import annotations import collections -from contextlib import contextmanager -from contextlib import nullcontext import csv -from itertools import count import logging import math import multiprocessing as mp import operator import os -from pathlib import Path import pickle import random import tempfile -from types import SimpleNamespace -from typing import Callable -from typing import Sequence import unittest +from contextlib import contextmanager, nullcontext +from itertools import count +from pathlib import Path +from types import SimpleNamespace +from typing import Callable, Sequence from unittest import skip from unittest.mock import patch +import helion +import helion.language as hl + import pytest import torch - -import helion -from helion import _compat -from helion import exc -from helion._testing import DEVICE -from helion._testing import RefEagerTestDisabled -from helion._testing import TestCase -from helion._testing import import_path -from helion._testing import skipIfCpu -from helion._testing import skipIfRocm -from helion.autotuner import DifferentialEvolutionSearch -from helion.autotuner import PatternSearch -from helion.autotuner.base_search import BaseSearch -from 
helion.autotuner.base_search import PopulationMember -from helion.autotuner.config_fragment import BooleanFragment -from helion.autotuner.config_fragment import EnumFragment -from helion.autotuner.config_fragment import IntegerFragment -from helion.autotuner.config_fragment import ListOf -from helion.autotuner.config_fragment import PowerOfTwoFragment +from helion import _compat, exc +from helion._testing import ( + DEVICE, + import_path, + RefEagerTestDisabled, + skipIfCpu, + skipIfRocm, + TestCase, +) +from helion.autotuner import ( + DifferentialEvolutionSearch, + PatternSearch, + UCBPatternSearch, +) +from helion.autotuner.base_search import BaseSearch, PopulationMember +from helion.autotuner.config_fragment import ( + BooleanFragment, + EnumFragment, + IntegerFragment, + ListOf, + PowerOfTwoFragment, +) from helion.autotuner.config_generation import ConfigGeneration from helion.autotuner.effort_profile import get_effort_profile from helion.autotuner.finite_search import FiniteSearch -from helion.autotuner.local_cache import LocalAutotuneCache -from helion.autotuner.local_cache import StrictLocalAutotuneCache -from helion.autotuner.logger import AutotuneLogEntry -from helion.autotuner.logger import AutotuningLogger +from helion.autotuner.local_cache import LocalAutotuneCache, StrictLocalAutotuneCache +from helion.autotuner.logger import AutotuneLogEntry, AutotuningLogger from helion.autotuner.random_search import RandomSearch -import helion.language as hl from helion.language import loops from helion.runtime.settings import Settings @@ -548,6 +549,162 @@ def diff_count(flat): ] self.assertEqual(sorted(pair_neighbors), sorted(expected)) + def test_ucb_pattern_search_generate_neighbors(self): + """Test UCBPatternSearch._generate_neighbors method.""" + random.seed(123) + search = UCBPatternSearch.__new__(UCBPatternSearch) + search.num_neighbors = 50 + search.radius = 2 + search.config_gen = SimpleNamespace( + flat_spec=[ + PowerOfTwoFragment(16, 128, 32), # block_size[0] + PowerOfTwoFragment(16, 128, 64), # block_size[1] + PowerOfTwoFragment(2, 16, 4), # num_warps + EnumFragment(("a", "b", "c")), # some enum + BooleanFragment(), # some boolean + ], + block_size_indices=[0, 1], + num_warps_index=2, + ) + + base = [32, 64, 4, "b", True] + neighbors = search._generate_neighbors(base) + + # Check we generate the correct number of neighbors + self.assertEquals(len(neighbors), search.num_neighbors) + + # Check all neighbors are different from base + for neighbor in neighbors: + self.assertNotEqual(neighbor, base) + + # Verify all block sizes are valid powers of two in range + for neighbor in neighbors: + # Check block_size[0] + self.assertIn(neighbor[0], [16, 32, 64, 128]) + # Check block_size[1] + self.assertIn(neighbor[1], [16, 32, 64, 128]) + # Check num_warps + self.assertIn(neighbor[2], [2, 4, 8, 16]) + # Check enum + self.assertIn(neighbor[3], ["a", "b", "c"]) + # Check boolean + self.assertIn(neighbor[4], [True, False]) + + def test_ucb_pattern_search_generate_neighbors_radius(self): + """Test that UCBPatternSearch respects radius parameter.""" + # Test with different radius values to ensure constraint holds + test_cases = [ + { + "radius": 1, + "base": [64, 8, "y"], + "expected_warps": [4, 8, 16], # 8=2^3, radius=1 -> 2^2, 2^3, 2^4 + }, + { + "radius": 2, + "base": [32, 8, "y"], + "expected_warps": [ + 2, + 4, + 8, + 16, + ], # 8=2^3, radius=2 -> 2^1, 2^2, 2^3, 2^4 + }, + { + "radius": 0, + "base": [64, 4, "y"], + "expected_warps": [4], # radius=0 -> no change to num_warps + }, + ] + + 
for test_case in test_cases: + random.seed(123) + search = UCBPatternSearch.__new__(UCBPatternSearch) + search.num_neighbors = 100 + search.radius = test_case["radius"] + search.config_gen = SimpleNamespace( + flat_spec=[ + PowerOfTwoFragment(16, 128, 64), # block_size + PowerOfTwoFragment(2, 16, 8), # num_warps + EnumFragment(("x", "y", "z")), + ], + block_size_indices=[0], + num_warps_index=1, + ) + + base = test_case["base"] + neighbors = search._generate_neighbors(base) + + # Verify all neighbors respect strict constraints + for neighbor in neighbors: + # Block size should ALWAYS vary by at most 1 in log2 space (independent of radius) + base_log0 = int(math.log2(base[0])) + neighbor_log0 = int(math.log2(neighbor[0])) + self.assertLessEqual( + abs(neighbor_log0 - base_log0), + 1, + f"With radius={search.radius}, block size changed by {abs(neighbor_log0 - base_log0)} in log2 space (more than 1)", + ) + self.assertIn(neighbor[0], [16, 32, 64, 128]) + + # num_warps should vary by at most radius in log2 space + base_log1 = int(math.log2(base[1])) + neighbor_log1 = int(math.log2(neighbor[1])) + self.assertLessEqual( + abs(neighbor_log1 - base_log1), + search.radius, + f"With radius={search.radius}, num_warps changed by {abs(neighbor_log1 - base_log1)} in log2 space (more than radius={search.radius})", + ) + # Verify num_warps is within expected values for this radius + self.assertIn( + neighbor[1], + test_case["expected_warps"], + f"With radius={search.radius}, base num_warps={base[1]}, got {neighbor[1]} which is not in expected {test_case['expected_warps']}", + ) + + @skipIfRocm("too slow on rocm") + @skip("too slow") + def test_ucb_pattern_search(self): + args = ( + torch.randn([64, 64], device=DEVICE), + torch.randn([64, 64], device=DEVICE), + ) + bound_kernel = basic_kernels.add.bind(args) + random.seed(123) + best = UCBPatternSearch( + bound_kernel, + args, + initial_population=10, + max_generations=2, + copies=1, + num_neighbors=10, + ).autotune() + fn = bound_kernel.compile_config(best) + torch.testing.assert_close(fn(*args), sum(args), rtol=1e-2, atol=1e-1) + + def test_encoder(self): + args = ( + torch.randn([64, 64], device=DEVICE), + torch.randn([64, 64], device=DEVICE), + ) + bound_kernel = basic_kernels.add.bind(args) + random.seed(123) + + test_flat_configs = [] + + # First attach the default config + config_gen = ConfigGeneration(bound_kernel.config_spec) + default_flat = config_gen.default_flat() + test_flat_configs.append(default_flat) + + # Test random configs + random_configs = config_gen.random_population_flat(10) + test_flat_configs = test_flat_configs + random_configs + + for flat_config in test_flat_configs: + encoded = helion.ConfigEncoder().encode(config) + decoded = helion.ConfigEncoder().decode(encoded) + self.assertEqual(flat_config, decoded) + @skipIfCpu("fails on Triton CPU backend") def test_accuracy_check_filters_bad_config_wrong_output(self) -> None: bad_config = helion.Config(block_sizes=[1], num_warps=8) @@ -588,8 +745,7 @@ def make_bad_config_produce_wrong_output( start_cm = patch.object( search, "start_precompile_and_check_for_hangs", - side_effect=lambda config, - fn: base_search_module.PrecompileFuture.skip( + side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip( search, config, True ), ) @@ -669,8 +825,7 @@ def wrong_fn(*fn_args, **fn_kwargs): start_cm = patch.object( search, "start_precompile_and_check_for_hangs", - side_effect=lambda config, - fn: base_search_module.PrecompileFuture.skip( + side_effect=lambda config, fn: 
base_search_module.PrecompileFuture.skip(
                 search, config, True
             ),
         )
@@ -1019,9 +1174,9 @@ def add(a, b):
         torch.testing.assert_close(bound_kernel(*args), sum(args), rtol=1e-2, atol=1e-1)
 
         search = search_capture["search"]
-        assert search.samples, (
-            "expected RecordingRandomSearch to record a random sample"
-        )
+        assert (
+            search.samples
+        ), "expected RecordingRandomSearch to record a random sample"
         return search.samples[0]
 
     @skipIfRocm("accuracy difference")

From e24b529d680a795de3f392e034dadc9a9c36c166 Mon Sep 17 00:00:00 2001
From: Ethan Che
Date: Wed, 12 Nov 2025 18:46:26 -0800
Subject: [PATCH 03/36] moved config encoding to config_fragment

---
 helion/autotuner/config_fragment.py    | 113 ++++++++++-
 helion/autotuner/config_generation.py  |  25 ++-
 helion/autotuner/fragment_encoder.py   | 266 -------------------------
 helion/autotuner/ucb_pattern_search.py |  99 +++++----
 4 files changed, 191 insertions(+), 312 deletions(-)
 delete mode 100644 helion/autotuner/fragment_encoder.py

diff --git a/helion/autotuner/config_fragment.py b/helion/autotuner/config_fragment.py
index c8bddf2b3..bc3efa2a0 100644
--- a/helion/autotuner/config_fragment.py
+++ b/helion/autotuner/config_fragment.py
@@ -3,9 +3,7 @@
 import dataclasses
 import enum
 import random
-from typing import Iterable
-from typing import TypeGuard
-from typing import cast
+from typing import cast, Iterable, TypeGuard
 
 from ..exc import InvalidConfig
 
@@ -51,6 +49,21 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
     def is_block_size(self) -> bool:
         return False
 
+    def is_categorical(self) -> bool:
+        return True
+
+    def encode_dim(self) -> int:
+        """
+        Returns the dimension of the output of encode
+        """
+        raise NotImplementedError
+
+    def encode(self, value: object) -> list[float]:
+        """
+        Returns a list of floats that can be used to encode the value of this fragment.
+        """
+        raise NotImplementedError
+
     def get_minimum(self) -> int:
         """
         Return the minimum allowed value for this fragment.
@@ -86,6 +99,17 @@ def pattern_neighbors(self, current: object) -> list[object]:
                 neighbors.append(swapped)
         return neighbors
 
+    def encode_dim(self) -> int:
+        return self.length * self.length
+
+    def encode(self, value: object) -> list[float]:
+        encoded = []
+        for pos in range(self.length):
+            val = value[pos]
+            for v in range(self.length):
+                encoded.append(1.0 if v == val else 0.0)
+        return encoded
+
 
 @dataclasses.dataclass
 class BaseIntegerFragment(ConfigSpecFragment):
@@ -106,6 +130,9 @@ def default(self) -> int:
     def clamp(self, val: int) -> int:
         return max(min(val, self.high), self.low)
 
+    def is_categorical(self) -> bool:
+        return False
+
     def get_minimum(self) -> int:
         return self.low
 
@@ -121,6 +148,20 @@ def pattern_neighbors(self, current: object) -> list[object]:
             neighbors.append(upper)
         return neighbors
 
+    def encode_dim(self) -> int:
+        return 1
+
+    def encode(self, value: object) -> list[float]:
+        """Encode integer values as raw floats."""
+        # Raw value encoding; subclasses such as PowerOfTwoFragment
+        # override this with a log2-scaled encoding.
+        if not isinstance(value, int):
+            raise TypeError(
+                f"Invalid value {value!r} for {type(self).__name__}. "
+                f"Expected an int between {self.low} and {self.high}."
+            )
+        return [float(value)]
+
 
 class PowerOfTwoFragment(BaseIntegerFragment):
     def random(self) -> int:
@@ -152,6 +193,23 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
             return self.clamp(ai * 2)
         return ai
 
+    def encode_dim(self) -> int:
+        return 1
+
+    def encode(self, value: object) -> list[float]:
+        """Encode power-of-2 values using log2 transformation."""
+        import math
+
+        if not isinstance(value, (int, float)):
+            raise TypeError(
+                f"Expected int/float for PowerOfTwoFragment, got {type(value).__name__}: {value!r}"
+            )
+        if value <= 0:
+            raise ValueError(
+                f"Expected positive value for PowerOfTwoFragment, got {value}"
+            )
+        return [math.log2(float(value))]
+
 
 class IntegerFragment(BaseIntegerFragment):
     def random(self) -> int:
@@ -169,6 +227,20 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
             return self.clamp(a + 1)
         return a
 
+    def encode_dim(self) -> int:
+        return 1
+
+    def encode(self, value: object) -> list[float]:
+        """Encode enum values as their index."""
+        try:
+            choice_idx = self.choices.index(value)
+        except ValueError:
+            raise ValueError(
+                f"Invalid enum value {value!r} for EnumFragment. "
+                f"Valid choices: {self.choices}"
+            ) from None
+        return [float(choice_idx)]
+
 
 @dataclasses.dataclass
 class EnumFragment(ConfigSpecFragment):
@@ -193,6 +265,20 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
             choices.remove(a)
         return random.choice(choices)
 
+    def encode_dim(self) -> int:
+        return len(self.choices)
+
+    def encode(self, value: object) -> list[float]:
+        """One-hot encode enum values."""
+        try:
+            choice_idx = self.choices.index(value)
+        except ValueError:
+            raise ValueError(
+                f"Invalid enum value {value!r} for EnumFragment. "
+                f"Valid choices: {self.choices}"
+            ) from None
+        return [1.0 if i == choice_idx else 0.0 for i in range(len(self.choices))]
+
 
 class BooleanFragment(ConfigSpecFragment):
     def default(self) -> bool:
@@ -212,6 +298,14 @@ def differential_mutation(self, a: object, b: object, c: object) -> bool:
             return a
         return not a
 
+    def encode_dim(self) -> int:
+        return 1
+
+    def encode(self, value: object) -> list[float]:
+        """Encode booleans as a single 0.0/1.0 float."""
+        assert isinstance(value, bool)
+        return [1.0] if value else [0.0]
+
 
 class BlockSizeFragment(PowerOfTwoFragment):
     def category(self) -> Category:
@@ -243,6 +337,9 @@ def random(self) -> list[object]:
         """Return a list of random values."""
         return [self.inner.random() for _ in range(self.length)]
 
+    def is_categorical(self) -> bool:
+        return self.inner.is_categorical()
+
     def pattern_neighbors(self, current: object) -> list[object]:
         """Return neighbors by changing one element at a time."""
         if not isinstance(current, list) or len(current) != self.length:
@@ -267,3 +364,13 @@ def differential_mutation(self, a: object, b: object, c: object) -> list[object]
             self.inner.differential_mutation(a[i], b[i], c[i])
             for i in range(self.length)
         ]
+
+    def encode_dim(self):
+        return self.length * self.inner.encode_dim()
+
+    def encode(self, value: object) -> list[float]:
+        assert isinstance(value, list[object])
+        encoded = []
+        for v in value:
+            encoded.extend(self.inner_encoder.encode(v))
+        return encoded
diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py
index 505f95da5..30cb00e73 100644
--- a/helion/autotuner/config_generation.py
+++ b/helion/autotuner/config_generation.py
@@ -5,13 +5,10 @@
 import itertools
 import operator
 import random
-from typing import TYPE_CHECKING
-from typing import cast
+from typing import cast, TYPE_CHECKING
 
 from .._compat import warps_to_threads
-from .config_fragment import Category
-from .config_fragment import ConfigSpecFragment
-from .config_fragment import PowerOfTwoFragment
+from .config_fragment import Category, ConfigSpecFragment, PowerOfTwoFragment
 
 if TYPE_CHECKING:
     from collections.abc import Mapping
@@ -181,3 +178,21 @@ def differential_mutation(
         # TODO(jansel): can this be larger? (too large and Triton compile times blow up)
         self.shrink_config(result, 8192)
         return result
+
+    def encode_config(self, flat_config: FlatConfig) -> list[float]:
+        """
+        Encode a flat configuration into a numerical vector for ML models.
+        This is used by surrogate-assisted algorithms (e.g., DE-Surrogate) that need
+        to represent configurations as continuous vectors for prediction models.
+        Args:
+            flat_config: The flat configuration values to encode.
+        Returns:
+            A list of floats representing the encoded configuration.
+        """
+        encoded: list[float] = []
+
+        for flat_idx, spec in enumerate(self.flat_spec):
+            value = flat_config[flat_idx]
+            encoded.extend(spec.encode(value))
+
+        return encoded
diff --git a/helion/autotuner/fragment_encoder.py b/helion/autotuner/fragment_encoder.py
deleted file mode 100644
index e0e73270b..000000000
--- a/helion/autotuner/fragment_encoder.py
+++ /dev/null
@@ -1,266 +0,0 @@
-"""Fragment encoding/decoding strategies for machine learning based autotuners.
-
-This module provides a clean abstraction for encoding different fragment types
-into numerical tensors and decoding them back. Each fragment type has its own
-encoder that knows how to analyze, encode, and decode itself.
-""" - -from __future__ import annotations - -from abc import ABC -from abc import abstractmethod -import math - -from .config_fragment import BooleanFragment -from .config_fragment import ConfigSpecFragment -from .config_fragment import EnumFragment -from .config_fragment import IntegerFragment -from .config_fragment import ListOf -from .config_fragment import PermutationFragment -from .config_fragment import PowerOfTwoFragment - - -class FragmentEncoder(ABC): - """Base class for encoding/decoding fragment values.""" - - def __init__(self, fragment: ConfigSpecFragment) -> None: - self.fragment = fragment - - @abstractmethod - def n_dims(self) -> int: - """Return the number of dimensions this fragment uses in encoded space.""" - - @abstractmethod - def is_categorical(self) -> bool: - """Return whether this fragment represents categorical data.""" - - @abstractmethod - def encode(self, value: object) -> list[float]: - """Encode a value into a list of floats.""" - - @abstractmethod - def decode(self, encoded: list[float]) -> object: - """Decode a list of floats back to the original value type.""" - - -class CategoricalEncoder(FragmentEncoder): - """Encoder for EnumFragment and BooleanFragment using one-hot encoding.""" - - def __init__( - self, fragment: EnumFragment | BooleanFragment, choices: list[object] - ) -> None: - super().__init__(fragment) - self.choices = choices - - def n_dims(self) -> int: - return len(self.choices) - - def is_categorical(self) -> bool: - return True - - def encode(self, value: object) -> list[float]: - idx = self.choices.index(value) - return [1.0 if i == idx else 0.0 for i in range(len(self.choices))] - - def decode(self, encoded: list[float]) -> object: - choice_idx = max(range(len(self.choices)), key=lambda i: encoded[i]) - return self.choices[choice_idx] - - -class PowerOfTwoEncoder(FragmentEncoder): - """Encoder for PowerOfTwoFragment using log2 transformation.""" - - def __init__(self, fragment: PowerOfTwoFragment) -> None: - super().__init__(fragment) - self.log_min = math.log2(fragment.low) - self.log_max = math.log2(fragment.high) - - def n_dims(self) -> int: - return 1 - - def is_categorical(self) -> bool: - return False - - def encode(self, value: int) -> list[float]: - return [math.log2(value)] - - def decode(self, encoded: list[float]) -> int: - log_val = encoded[0] - power = int(round(log_val)) - power = max(int(self.log_min), min(power, int(self.log_max))) - return 2**power - - -class IntegerEncoder(FragmentEncoder): - """Encoder for IntegerFragment using raw values.""" - - def __init__(self, fragment: IntegerFragment) -> None: - super().__init__(fragment) - self.min_val = fragment.low - self.max_val = fragment.high - - def n_dims(self) -> int: - return 1 - - def is_categorical(self) -> bool: - return False - - def encode(self, value: object) -> list[float]: - return [float(value)] - - def decode(self, encoded: list[float]) -> int: - value = int(round(encoded[0])) - return max(self.min_val, min(value, self.max_val)) - - -class PermutationEncoder(FragmentEncoder): - """Encoder for PermutationFragment using one-hot encoding for each position.""" - - def __init__(self, fragment: PermutationFragment) -> None: - super().__init__(fragment) - self.length = fragment.length - - def n_dims(self) -> int: - return self.length * self.length - - def is_categorical(self) -> bool: - return True - - def encode(self, value: list[int]) -> list[float]: - encoded = [] - for pos in range(self.length): - val = value[pos] - for v in range(self.length): - 
encoded.append(1.0 if v == val else 0.0) - return encoded - - def decode(self, encoded: list[float]) -> list[int]: - perm = [] - used = set() - - for pos in range(self.length): - start_idx = pos * self.length - one_hot = encoded[start_idx : start_idx + self.length] - val = max(range(self.length), key=lambda i: one_hot[i]) - perm.append(val) - used.add(val) - - # Fix invalid permutation (duplicates/missing values) - if len(used) != self.length: - available = [v for v in range(self.length) if v not in used] - seen = set() - fixed_perm = [] - for val in perm: - if val in seen: - fixed_val = available.pop(0) - fixed_perm.append(fixed_val) - else: - fixed_perm.append(val) - seen.add(val) - return fixed_perm - - return perm - - -class ListOfEncoder(FragmentEncoder): - """Encoder for ListOf fragments, delegates to inner encoder.""" - - def __init__(self, fragment: ListOf, inner_encoder: FragmentEncoder) -> None: - super().__init__(fragment) - self.length = fragment.length - self.inner_encoder = inner_encoder - self.inner_dims = inner_encoder.n_dims() - - def n_dims(self) -> int: - return self.length * self.inner_dims - - def is_categorical(self) -> bool: - """Return True if the inner encoder is categorical.""" - return self.inner_encoder.is_categorical() - - def encode(self, value: list[object]) -> list[float]: - encoded = [] - for v in value: - encoded.extend(self.inner_encoder.encode(v)) - return encoded - - def decode(self, encoded: list[float]) -> list[object]: - decoded = [] - for i in range(self.length): - start_idx = i * self.inner_dims - element_encoded = encoded[start_idx : start_idx + self.inner_dims] - decoded.append(self.inner_encoder.decode(element_encoded)) - return decoded - - -def create_encoder(fragment: ConfigSpecFragment) -> FragmentEncoder: - """Factory function to create the appropriate encoder for a fragment.""" - if isinstance(fragment, BooleanFragment): - return CategoricalEncoder(fragment, [False, True]) - if isinstance(fragment, EnumFragment): - return CategoricalEncoder(fragment, list(fragment.choices)) - if isinstance(fragment, PowerOfTwoFragment): - return PowerOfTwoEncoder(fragment) - if isinstance(fragment, IntegerFragment): - return IntegerEncoder(fragment) - if isinstance(fragment, PermutationFragment): - return PermutationEncoder(fragment) - if isinstance(fragment, ListOf): - inner_encoder = create_encoder(fragment.inner) - return ListOfEncoder(fragment, inner_encoder) - raise ValueError(f"Unsupported fragment type: {type(fragment).__name__}") - - -class ConfigEncoder: - """Encodes and decodes entire configurations using fragment encoders.""" - - def __init__(self, flat_spec: list[ConfigSpecFragment]) -> None: - """Initialize encoders for all fragments in the spec. - - Args: - flat_spec: List of fragment specifications - """ - self.encoders = [create_encoder(fragment) for fragment in flat_spec] - self.total_dims = sum(encoder.n_dims() for encoder in self.encoders) - - # Build categorical dimension indices (absolute positions) - self.cat_dims = [] - offset = 0 - for encoder in self.encoders: - n_dims = encoder.n_dims() - if encoder.is_categorical(): - # All dimensions of this encoder are categorical - self.cat_dims.extend(range(offset, offset + n_dims)) - offset += n_dims - - def encode(self, flat_config: list[object]) -> list[float]: - """Encode a flat configuration into a list of floats. 
- - Args: - flat_config: List of configuration values - - Returns: - List of encoded float values - """ - encoded = [] - for value, encoder in zip(flat_config, self.encoders, strict=False): - encoded.extend(encoder.encode(value)) - return encoded - - def decode(self, encoded: list[float]) -> list[object]: - """Decode a list of floats back into a flat configuration. - - Args: - encoded: List of encoded float values - - Returns: - List of decoded configuration values - """ - decoded = [] - idx = 0 - for encoder in self.encoders: - n_dims = encoder.n_dims() - fragment_encoded = encoded[idx : idx + n_dims] - decoded.append(encoder.decode(fragment_encoded)) - idx += n_dims - return decoded diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index d749be0ba..8f7149b2a 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -1,34 +1,33 @@ from __future__ import annotations import math +import operator import random from typing import TYPE_CHECKING import torch from .. import exc -from .base_search import FlatConfig -from .base_search import PopulationMember -from .base_search import performance +from .base_search import FlatConfig, performance, PopulationMember from .config_fragment import PowerOfTwoFragment from .effort_profile import PATTERN_SEARCH_DEFAULTS -from .fragment_encoder import ConfigEncoder from .pattern_search import PatternSearch if TYPE_CHECKING: - from collections.abc import Iterator - from collections.abc import Sequence + from collections.abc import Iterator, Sequence from ..runtime.config import Config from ..runtime.kernel import BoundKernel +try: + from botorch.acquisition import UpperConfidenceBound + from botorch.fit import fit_gpytorch_mll + from botorch.models import MixedSingleTaskGP + from gpytorch.mlls import ExactMarginalLogLikelihood -import operator - -from botorch.acquisition import UpperConfidenceBound -from botorch.fit import fit_gpytorch_mll -from botorch.models import MixedSingleTaskGP -from gpytorch.mlls import ExactMarginalLogLikelihood + HAS_BO_DEPS = True +except ImportError: + HAS_BO_DEPS = False class UCBPatternSearch(PatternSearch): @@ -54,44 +53,58 @@ def __init__( max_generations=max_generations, min_improvement_delta=min_improvement_delta, ) - # Storage for BO self.num_neighbors = num_neighbors self.radius = radius self.ucb_beta = ucb_beta # Initialize config encoder - self.config_encoder = ConfigEncoder(self.config_gen.flat_spec) self.frac_selected = frac_selected + self.cat_dims = [] + offset = 0 + for spec in self.config_gen.flat_spec: + n_dims = spec.encode_dim() + if encoder.is_categorical(): + # All dimensions of this encoder are categorical + self.cat_dims.extend(range(offset, offset + n_dims)) + offset += n_dims + def fit_gp( self, train_X: torch.Tensor, train_Y: torch.Tensor, cat_dims: list ) -> MixedSingleTaskGP: # Filter out rows where train_Y contains inf or nan - valid_mask = torch.isfinite(train_Y) - train_X_filtered = train_X[valid_mask] - train_Y_filtered = train_Y[valid_mask] - - gp = MixedSingleTaskGP( - train_X_filtered.to(dtype=torch.float64), - -train_Y_filtered.unsqueeze(-1).to(dtype=torch.float64), - cat_dims, - ) - with torch.enable_grad(): - mll = ExactMarginalLogLikelihood(gp.likelihood, gp) - fit_gpytorch_mll(mll) + if HAS_BO_DEPS: + valid_mask = torch.isfinite(train_Y) + train_X_filtered = train_X[valid_mask] + train_Y_filtered = train_Y[valid_mask] + + gp = MixedSingleTaskGP( + train_X_filtered.to(dtype=torch.float64), + 
-train_Y_filtered.unsqueeze(-1).to(dtype=torch.float64), + cat_dims, + ) + + with torch.enable_grad(): + mll = ExactMarginalLogLikelihood(gp.likelihood, gp) + fit_gpytorch_mll(mll) - return gp + return gp + else: + return None - def acq_fun(self, X: torch.Tensor, gp: MixedSingleTaskGP) -> torch.Tensor: + def acq_fun(self, X: torch.Tensor, gp: MixedSingleTaskGP | None) -> torch.Tensor: orig_dtype = X.dtype - acq_fun = UpperConfidenceBound(gp, beta=self.ucb_beta) - return ( - acq_fun(X.unsqueeze(1).to(dtype=torch.float64)) - .detach() - .to(dtype=orig_dtype) - ) + if HAS_BO_DEPS: + acq_fun = UpperConfidenceBound(gp, beta=self.ucb_beta) + return ( + acq_fun(X.unsqueeze(1).to(dtype=torch.float64)) + .detach() + .to(dtype=orig_dtype) + ) + else: + return torch.zeros(X.shape[0], dtype=orig_dtype) def get_train_data_from_pop( self, population: list[PopulationMember] @@ -99,7 +112,9 @@ def get_train_data_from_pop( train_X = [] train_Y = [] for member in population: - train_X.append(torch.tensor(self.config_encoder.encode(member.flat_values))) + train_X.append( + torch.tensor(self.config_gen.encode_config(member.flat_values)) + ) train_Y.append(member.perf) return torch.stack(train_X), torch.tensor(train_Y) @@ -187,7 +202,10 @@ def _autotune(self) -> Config: self.log( f"Conditioning on new data: {len(train_X)} points, {len(train_Y)} targets" ) - gp = gp.condition_on_observations(train_X, train_Y) + if HAS_BO_DEPS: + gp = gp.condition_on_observations(train_X, train_Y) + else: + gp = None return self.best.config @@ -234,11 +252,15 @@ def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]: block_spec = self.config_gen.flat_spec[block_idx] current_val = base[block_idx] + assert type(current_val) is int if isinstance(block_spec, PowerOfTwoFragment): # Change by at most 1 in log2 space new_flat[block_idx] = self.random_log2_neighbor( - current_val, radius=1, low=block_spec.low, high=block_spec.high + current_val, + radius=self.radius, + low=block_spec.low, + high=block_spec.high, ) else: raise ValueError("BlockSize should be PowerOfTwoFragment") @@ -250,6 +272,7 @@ def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]: warp_spec = self.config_gen.flat_spec[warp_idx] current_val = base[warp_idx] + assert type(current_val) is int if isinstance(warp_spec, PowerOfTwoFragment): # Change by at most self.radius in log2 space @@ -295,7 +318,7 @@ def _pruned_pattern_search_from( self, current: PopulationMember, visited: set[Config], - gp: MixedSingleTaskGP, + gp: MixedSingleTaskGP | None, ) -> Iterator[list[PopulationMember]]: """ Run a single copy of pattern search from the given starting point. From 140b3b2add5f85f8ce31ee33315b95a042a85d70 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Thu, 13 Nov 2025 16:17:38 -0800 Subject: [PATCH 04/36] adapt ucb_pattern_search to new encoder --- helion/autotuner/ucb_pattern_search.py | 11 +++- test/test_autotuner.py | 71 -------------------------- 2 files changed, 10 insertions(+), 72 deletions(-) diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index 8f7149b2a..c6bb252c0 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -31,6 +31,15 @@ class UCBPatternSearch(PatternSearch): + """ + Modifies PatternSearch to (1) generate random neighbors from each search copy + within a set radius, (2) filter the neighbors to benchmark using a fitted GaussianProcess + with the UCB acquisition function. 
+
+    Uses the MixedSingleTaskGP model from botorch, which supports continuous
+    and categorical variables. It only fits the GP once to avoid long runtimes.
+    """
+
     def __init__(
         self,
         kernel: BoundKernel,
@@ -327,7 +336,7 @@ def _pruned_pattern_search_from(
         run multiple copies of pattern search in parallel.
 
         Only keep self.frac_selected of the neighbors generated from the current
-        search_copy. Filter them using the GaussianProcess.
+        search_copy. Filter them using the GaussianProcess + UCB acquisition function.
         """
         for _ in range(self.max_generations):
             candidates = [current]
diff --git a/test/test_autotuner.py b/test/test_autotuner.py
index 61e70fc13..5353295c1 100644
--- a/test/test_autotuner.py
+++ b/test/test_autotuner.py
@@ -663,77 +663,6 @@ def test_ucb_pattern_search_generate_neighbors(self):
             # Check boolean
             self.assertIn(neighbor[4], [True, False])
 
-    def test_ucb_pattern_search_generate_neighbors_radius(self):
-        """Test that UCBPatternSearch respects radius parameter."""
-        # Test with different radius values to ensure constraint holds
-        test_cases = [
-            {
-                "radius": 1,
-                "base": [64, 8, "y"],
-                "expected_warps": [4, 8, 16],  # 8=2^3, radius=1 -> 2^2, 2^3, 2^4
-            },
-            {
-                "radius": 2,
-                "base": [32, 8, "y"],
-                "expected_warps": [
-                    2,
-                    4,
-                    8,
-                    16,
-                ],  # 8=2^3, radius=2 -> 2^1, 2^2, 2^3, 2^4
-            },
-            {
-                "radius": 0,
-                "base": [64, 4, "y"],
-                "expected_warps": [4],  # radius=0 -> no change to num_warps
-            },
-        ]
-
-        for test_case in test_cases:
-            random.seed(123)
-            search = UCBPatternSearch.__new__(UCBPatternSearch)
-            search.num_neighbors = 100
-            search.radius = test_case["radius"]
-            search.config_gen = SimpleNamespace(
-                flat_spec=[
-                    PowerOfTwoFragment(16, 128, 64),  # block_size
-                    PowerOfTwoFragment(2, 16, 8),  # num_warps
-                    EnumFragment(("x", "y", "z")),
-                ],
-                block_size_indices=[0],
-                num_warps_index=1,
-            )
-
-            base = test_case["base"]
-            neighbors = search._generate_neighbors(base)
-
-            # Verify all neighbors respect strict constraints
-            for neighbor in neighbors:
-                # Block size should ALWAYS vary by at most 1 in log2 space (independent of radius)
-                base_log0 = int(math.log2(base[0]))
-                neighbor_log0 = int(math.log2(neighbor[0]))
-                self.assertLessEqual(
-                    abs(neighbor_log0 - base_log0),
-                    1,
-                    f"With radius={search.radius}, block size changed by {abs(neighbor_log0 - base_log0)} in log2 space (more than 1)",
-                )
-                self.assertIn(neighbor[0], [16, 32, 64, 128])
-
-                # num_warps should vary by at most radius in log2 space
-                base_log1 = int(math.log2(base[1]))
-                neighbor_log1 = int(math.log2(neighbor[1]))
-                self.assertLessEqual(
-                    abs(neighbor_log1 - base_log1),
-                    search.radius,
-                    f"With radius={search.radius}, num_warps changed by {abs(neighbor_log1 - base_log1)} in log2 space (more than radius={search.radius})",
-                )
-                # Verify num_warps is within expected values for this radius
-                self.assertIn(
-                    neighbor[1],
-                    test_case["expected_warps"],
-                    f"With radius={search.radius}, base num_warps={base[1]}, got {neighbor[1]} which is not in expected {test_case['expected_warps']}",
-                )
-
     @skipIfRocm("too slow on rocm")
     @skip("too slow")
     def test_ucb_pattern_search(self):

From ca9925209eb4be65c49f8609054e373f85446511 Mon Sep 17 00:00:00 2001
From: Ethan Che
Date: Thu, 13 Nov 2025 16:53:40 -0800
Subject: [PATCH 05/36] merged new config fragment

---
 helion/autotuner/config_fragment.py    | 15 ++----
 helion/autotuner/ucb_pattern_search.py | 27 ++++++----
 test/test_autotuner.py                 | 75 +++++++++++++++++---------
 3 files changed, 70 insertions(+), 47 deletions(-)

diff --git a/helion/autotuner/config_fragment.py
b/helion/autotuner/config_fragment.py index 5dbb93b6b..39698d281 100644 --- a/helion/autotuner/config_fragment.py +++ b/helion/autotuner/config_fragment.py @@ -251,15 +251,8 @@ def encode_dim(self) -> int: return 1 def encode(self, value: object) -> list[float]: - """Encode enum values as their index.""" - try: - choice_idx = self.choices.index(value) - except ValueError: - raise ValueError( - f"Invalid enum value {value!r} for EnumFragment. " - f"Valid choices: {self.choices}" - ) from None - return [float(choice_idx)] + assert isinstance(value, int) + return [float(value)] @dataclasses.dataclass @@ -389,8 +382,8 @@ def encode_dim(self): return self.length * self.inner.encode_dim() def encode(self, value: object) -> list[float]: - assert isinstance(value, list[object]) + assert isinstance(value, list) encoded = [] for v in value: - encoded.extend(self.inner_encoder.encode(v)) + encoded.extend(self.inner.encode(v)) return encoded diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index c6bb252c0..dac04bf64 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -32,6 +32,8 @@ class UCBPatternSearch(PatternSearch): """ + Upper Confidence Bound (UCB) Pattern Search + Modifies PatternSearch to (1) generate random neighbors from each search copy within a set radius, (2) filter the neighbors to benchmark using a fitted GaussianProcess with the UCB acquisition function. @@ -74,7 +76,7 @@ def __init__( offset = 0 for spec in self.config_gen.flat_spec: n_dims = spec.encode_dim() - if encoder.is_categorical(): + if spec.is_categorical(): # All dimensions of this encoder are categorical self.cat_dims.extend(range(offset, offset + n_dims)) offset += n_dims @@ -85,13 +87,9 @@ def fit_gp( # Filter out rows where train_Y contains inf or nan if HAS_BO_DEPS: - valid_mask = torch.isfinite(train_Y) - train_X_filtered = train_X[valid_mask] - train_Y_filtered = train_Y[valid_mask] - gp = MixedSingleTaskGP( - train_X_filtered.to(dtype=torch.float64), - -train_Y_filtered.unsqueeze(-1).to(dtype=torch.float64), + train_X, + -train_Y.unsqueeze(-1), cat_dims, ) @@ -126,7 +124,14 @@ def get_train_data_from_pop( ) train_Y.append(member.perf) - return torch.stack(train_X), torch.tensor(train_Y) + train_X = torch.stack(train_X) + train_Y = torch.tensor(train_Y) + + valid_mask = torch.isfinite(train_Y) + train_X_filtered = train_X[valid_mask].to(dtype=torch.float64) + train_Y_filtered = train_Y[valid_mask].to(dtype=torch.float64) + + return train_X_filtered, train_Y_filtered def _autotune(self) -> Config: self.log( @@ -164,7 +169,7 @@ def _autotune(self) -> Config: gp = self.fit_gp( train_X, train_Y, - self.config_encoder.cat_dims, + self.cat_dims, ) search_copies = [ @@ -212,7 +217,7 @@ def _autotune(self) -> Config: f"Conditioning on new data: {len(train_X)} points, {len(train_Y)} targets" ) if HAS_BO_DEPS: - gp = gp.condition_on_observations(train_X, train_Y) + gp = gp.condition_on_observations(train_X, -train_Y.unsqueeze(1)) else: gp = None @@ -350,7 +355,7 @@ def _pruned_pattern_search_from( # score candidates candidate_X = torch.stack( [ - torch.tensor(self.config_encoder.encode(member.flat_values)) + torch.tensor(self.config_gen.encode_config(member.flat_values)) for member in candidates ] ) diff --git a/test/test_autotuner.py b/test/test_autotuner.py index 5353295c1..db12c2cfd 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -37,6 +37,7 @@ DESurrogateHybrid, DifferentialEvolutionSearch, PatternSearch, + 
UCBPatternSearch, ) from helion.autotuner.base_search import BaseSearch, PopulationMember from helion.autotuner.config_fragment import ( @@ -44,6 +45,7 @@ EnumFragment, IntegerFragment, ListOf, + PermutationFragment, PowerOfTwoFragment, ) from helion.autotuner.config_generation import ConfigGeneration @@ -644,7 +646,7 @@ def test_ucb_pattern_search_generate_neighbors(self): neighbors = search._generate_neighbors(base) # Check we generate the correct number of neighbors - self.assertEquals(len(neighbors), search.num_neighbors) + self.assertEqual(len(neighbors), search.num_neighbors) # Check all neighbors are different from base for neighbor in neighbors: @@ -683,30 +685,6 @@ def test_ucb_pattern_search(self): fn = bound_kernel.compile_config(best) torch.testing.assert_close(fn(*args), sum(args), rtol=1e-2, atol=1e-1) - def test_encoder(self): - args = ( - torch.randn([64, 64], device=DEVICE), - torch.randn([64, 64], device=DEVICE), - ) - bound_kernel = basic_kernels.add.bind(args) - random.seed(123) - - test_flat_configs = [] - - # First attach the default config - config_gen = ConfigGeneration(bound_kernel.config_spec) - default_flat = config_gen.default_flat() - test_flat_configs.append(default_flat) - - # Test random configs - random_configs = config_gen.random_population_flat(10) - test_flat_configs = test_flat_configs + random_configs - - for flat_config in test_flat_configs: - encoded = helion.ConfigEncoder().encode(config) - decoded = helion.ConfigEncoder().decode(encoded) - self.assertEqual(flat_config, decoded) - @skipIfCpu("fails on Triton CPU backend") def test_accuracy_check_filters_bad_config_wrong_output(self) -> None: bad_config = helion.Config(block_sizes=[1], num_warps=8) @@ -1145,6 +1123,53 @@ def add(a, b): ): add(*args) + def test_fragment_encoding(self): + """Test encoding functionality for all ConfigSpecFragment types.""" + # Test BooleanFragment + bool_frag = BooleanFragment() + self.assertEqual(bool_frag.encode_dim(), 1) + self.assertEqual(bool_frag.encode(True), [1.0]) + self.assertEqual(bool_frag.encode(False), [0.0]) + + # Test IntegerFragment + int_frag = IntegerFragment(low=1, high=10, default_val=5) + self.assertEqual(int_frag.encode_dim(), 1) + self.assertEqual(int_frag.encode(5), [5.0]) + + # Test PowerOfTwoFragment (log2 transformation) + pow2_frag = PowerOfTwoFragment(low=2, high=128, default_val=8) + self.assertEqual(pow2_frag.encode_dim(), 1) + self.assertEqual(pow2_frag.encode(8), [3.0]) # log2(8) = 3 + self.assertEqual(pow2_frag.encode(16), [4.0]) # log2(16) = 4 + + # Test EnumFragment (one-hot encoding) + enum_frag = EnumFragment(choices=("a", "b", "c")) + self.assertEqual(enum_frag.encode_dim(), 3) + self.assertEqual(enum_frag.encode("a"), [1.0, 0.0, 0.0]) + self.assertEqual(enum_frag.encode("b"), [0.0, 1.0, 0.0]) + + # Test PermutationFragment + perm_frag = PermutationFragment(length=3) + self.assertEqual(perm_frag.encode_dim(), 9) + encoded = perm_frag.encode([0, 1, 2]) + self.assertEqual(encoded, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]) + + # Test ListOf with BooleanFragment + list_frag = ListOf(inner=BooleanFragment(), length=3) + self.assertEqual(list_frag.encode_dim(), 3) + self.assertEqual(list_frag.encode([True, False, True]), [1.0, 0.0, 1.0]) + + # Test encode_dim consistency + for fragment, value in [ + (BooleanFragment(), True), + (IntegerFragment(1, 10, 5), 5), + (PowerOfTwoFragment(2, 128, 8), 16), + (EnumFragment(choices=("a", "b")), "b"), + ]: + encode_dim = fragment.encode_dim() + encoded = fragment.encode(value) + 
self.assertEqual(len(encoded), encode_dim) + class TestAutotuneRandomSeed(RefEagerTestDisabled, TestCase): def _autotune_and_record(self, **settings: object) -> float: From 899d30dfd3514d4007522b18b0a710780669bf21 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Thu, 13 Nov 2025 17:01:24 -0800 Subject: [PATCH 06/36] imports --- .github/workflows/benchmark.yml | 2 +- .github/workflows/test.yml | 2 +- helion/autotuner/__init__.py | 32 +++++++++++++++++++------------- pyproject.toml | 3 +++ 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 9c4e73bb0..18dd35d6b 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -100,7 +100,7 @@ jobs: - name: Install Helion run: | source .venv/bin/activate - SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]' + SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate,bayesopt]' python -c "import helion; print(helion.__name__)" - name: Install Benchmark Requirements diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 51e196fcf..573751dad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -146,7 +146,7 @@ jobs: run: | source .venv/bin/activate uv pip install setuptools ninja - SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate]' + SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate,bayesopt]' python -c "import helion; print(helion.__name__)" - name: Run Tests diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py index f2c1873db..9350abf96 100644 --- a/helion/autotuner/__init__.py +++ b/helion/autotuner/__init__.py @@ -1,29 +1,35 @@ from __future__ import annotations -from .fragment_encoder import ConfigEncoder -from .config_fragment import BooleanFragment as BooleanFragment -from .config_fragment import EnumFragment as EnumFragment -from .config_fragment import IntegerFragment as IntegerFragment -from .config_fragment import ListOf as ListOf -from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment +from .config_fragment import ( + BooleanFragment as BooleanFragment, + EnumFragment as EnumFragment, + IntegerFragment as IntegerFragment, + ListOf as ListOf, + PowerOfTwoFragment as PowerOfTwoFragment, +) from .config_spec import ConfigSpec as ConfigSpec from .de_surrogate_hybrid import DESurrogateHybrid as DESurrogateHybrid from .differential_evolution import ( DifferentialEvolutionSearch as DifferentialEvolutionSearch, ) -from .effort_profile import AutotuneEffortProfile as AutotuneEffortProfile -from .effort_profile import DifferentialEvolutionConfig as DifferentialEvolutionConfig -from .effort_profile import PatternSearchConfig as PatternSearchConfig -from .effort_profile import RandomSearchConfig as RandomSearchConfig +from .effort_profile import ( + AutotuneEffortProfile as AutotuneEffortProfile, + DifferentialEvolutionConfig as DifferentialEvolutionConfig, + PatternSearchConfig as PatternSearchConfig, + RandomSearchConfig as RandomSearchConfig, +) from .finite_search import FiniteSearch as FiniteSearch -from .local_cache import LocalAutotuneCache as LocalAutotuneCache -from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache +from .local_cache import ( + LocalAutotuneCache as LocalAutotuneCache, + StrictLocalAutotuneCache as StrictLocalAutotuneCache, +) from .pattern_search import PatternSearch as PatternSearch -from .ucb_pattern_search import 
UCBPatternSearch
 from .random_search import RandomSearch as RandomSearch
+from .ucb_pattern_search import UCBPatternSearch

 search_algorithms = {
     "DESurrogateHybrid": DESurrogateHybrid,
+    "UCBPatternSearch": UCBPatternSearch,
     "DifferentialEvolutionSearch": DifferentialEvolutionSearch,
     "FiniteSearch": FiniteSearch,
     "PatternSearch": PatternSearch,
diff --git a/pyproject.toml b/pyproject.toml
index fc063289f..951bb4eb7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,9 @@ de-surrogate = [
     "numpy",
     "scikit-learn>=1.3.0"
 ]
+bayesopt = [
+    "botorch>=0.16.0"
+]
 dev = [
     "expecttest",
     "pytest",

From 13ea3b4b3f9cb57ffc115c9499a0ee4228779d63 Mon Sep 17 00:00:00 2001
From: Ethan Che
Date: Thu, 13 Nov 2025 19:25:51 -0800
Subject: [PATCH 07/36] remove encode_scalar

---
 helion/autotuner/config_fragment.py | 20 --------------------
 1 file changed, 20 deletions(-)

diff --git a/helion/autotuner/config_fragment.py b/helion/autotuner/config_fragment.py
index 39698d281..15b7da461 100644
--- a/helion/autotuner/config_fragment.py
+++ b/helion/autotuner/config_fragment.py
@@ -70,26 +70,6 @@ def get_minimum(self) -> int:
         """
         raise NotImplementedError

-    def encode_scalar(self, value: object) -> float:
-        """
-        Encode a configuration value into a float for ML models.
-
-        This is used by surrogate-assisted algorithms to convert configurations
-        into numerical vectors for prediction models.
-
-        Args:
-            value: The configuration value to encode.
-
-        Returns:
-            A float representing the encoded value.
-        """
-        # Default: convert to float if possible
-        if not isinstance(value, (int, float, bool)):
-            raise TypeError(
-                f"Cannot encode {type(value).__name__} value {value!r} for ML"
-            )
-        return float(value)
-

 @dataclasses.dataclass
 class PermutationFragment(ConfigSpecFragment):

From 3b739f5a05689422a1446622fdff70c70cfdd7a5 Mon Sep 17 00:00:00 2001
From: Ethan Che
Date: Fri, 14 Nov 2025 19:11:01 -0800
Subject: [PATCH 08/36] fix imports

---
 helion/autotuner/config_fragment.py    | 51 +++++++---------
 helion/autotuner/config_generation.py  |  7 ++-
 helion/autotuner/ucb_pattern_search.py | 81 ++++++++++++++------------
 3 files changed, 68 insertions(+), 71 deletions(-)

diff --git a/helion/autotuner/config_fragment.py b/helion/autotuner/config_fragment.py
index 15b7da461..9f005fbcd 100644
--- a/helion/autotuner/config_fragment.py
+++ b/helion/autotuner/config_fragment.py
@@ -3,7 +3,9 @@
 import dataclasses
 import enum
 import random
-from typing import cast, Iterable, TypeGuard
+from typing import Iterable
+from typing import TypeGuard
+from typing import cast

 from ..exc import InvalidConfig

@@ -52,7 +54,7 @@ def is_block_size(self) -> bool:
     def is_categorical(self) -> bool:
         return True

-    def encode_dim(self) -> int:
+    def dim(self) -> int:
         """
         Returns the dimension of the output of encode
         """
@@ -99,16 +101,16 @@ def pattern_neighbors(self, current: object) -> list[object]:
             neighbors.append(swapped)
         return neighbors

-    def encode_dim(self) -> int:
-        return self.length * self.length
+    def dim(self) -> int:
+        return self.length

     def encode(self, value: object) -> list[float]:
+        assert isinstance(value, list)
         encoded = []
-        for pos in range(self.length):
-            val = value[pos]
-            for v in range(self.length):
-                encoded.append(1.0 if v == val else 0.0)
-        return encoded
+        for val in value:
+            assert isinstance(val, int)
+            encoded.append(float(val))
+        return encoded


 @dataclasses.dataclass
@@ -136,6 +138,9 @@ def is_categorical(self) -> bool:
     def get_minimum(self) -> int:
         return self.low

+    def dim(self) -> int:
+        return 1
+
     def 
pattern_neighbors(self, current: object) -> list[object]: if type(current) is not int: # bool is not allowed raise TypeError(f"Expected int, got {type(current).__name__}") @@ -148,19 +153,9 @@ def pattern_neighbors(self, current: object) -> list[object]: neighbors.append(upper) return neighbors - def encode_dim(self) -> int: - return 1 - def encode(self, value: object) -> list[float]: - """Encode enum values as their index.""" - try: - choice_idx = self.choices.index(value) - except ValueError: - raise ValueError( - f"Invalid enum value {value!r} for EnumFragment. " - f"Valid choices: {self.choices}" - ) from None - return [float(choice_idx)] + assert isinstance(value, int) + return [float(value)] class PowerOfTwoFragment(BaseIntegerFragment): @@ -193,9 +188,6 @@ def differential_mutation(self, a: object, b: object, c: object) -> int: return self.clamp(ai * 2) return ai - def encode_dim(self) -> int: - return 1 - def encode(self, value: object) -> list[float]: """Encode power-of-2 values using log2 transformation.""" import math @@ -227,9 +219,6 @@ def differential_mutation(self, a: object, b: object, c: object) -> int: return self.clamp(a + 1) return a - def encode_dim(self) -> int: - return 1 - def encode(self, value: object) -> list[float]: assert isinstance(value, int) return [float(value)] @@ -258,7 +247,7 @@ def differential_mutation(self, a: object, b: object, c: object) -> object: choices.remove(a) return random.choice(choices) - def encode_dim(self) -> int: + def dim(self) -> int: return len(self.choices) def encode(self, value: object) -> list[float]: @@ -291,7 +280,7 @@ def differential_mutation(self, a: object, b: object, c: object) -> bool: return a return not a - def encode_dim(self) -> int: + def dim(self) -> int: return 1 def encode(self, value: object) -> list[float]: @@ -358,8 +347,8 @@ def differential_mutation(self, a: object, b: object, c: object) -> list[object] for i in range(self.length) ] - def encode_dim(self): - return self.length * self.inner.encode_dim() + def dim(self): + return self.length * self.inner.dim() def encode(self, value: object) -> list[float]: assert isinstance(value, list) diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py index 30cb00e73..da8a83bc6 100644 --- a/helion/autotuner/config_generation.py +++ b/helion/autotuner/config_generation.py @@ -5,10 +5,13 @@ import itertools import operator import random -from typing import cast, TYPE_CHECKING +from typing import TYPE_CHECKING +from typing import cast from .._compat import warps_to_threads -from .config_fragment import Category, ConfigSpecFragment, PowerOfTwoFragment +from .config_fragment import Category +from .config_fragment import ConfigSpecFragment +from .config_fragment import PowerOfTwoFragment if TYPE_CHECKING: from collections.abc import Mapping diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index dac04bf64..b06ccf8a0 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -1,5 +1,6 @@ from __future__ import annotations +from itertools import accumulate import math import operator import random @@ -8,13 +9,16 @@ import torch from .. 
import exc -from .base_search import FlatConfig, performance, PopulationMember +from .base_search import FlatConfig +from .base_search import PopulationMember +from .base_search import performance from .config_fragment import PowerOfTwoFragment from .effort_profile import PATTERN_SEARCH_DEFAULTS from .pattern_search import PatternSearch if TYPE_CHECKING: - from collections.abc import Iterator, Sequence + from collections.abc import Iterator + from collections.abc import Sequence from ..runtime.config import Config from ..runtime.kernel import BoundKernel @@ -26,8 +30,9 @@ from gpytorch.mlls import ExactMarginalLogLikelihood HAS_BO_DEPS = True -except ImportError: +except ImportError as e: HAS_BO_DEPS = False + _IMPORT_ERROR = e class UCBPatternSearch(PatternSearch): @@ -56,6 +61,11 @@ def __init__( radius: int = 2, ucb_beta: float = 2.0, ) -> None: + if not HAS_BO_DEPS: + raise exc.MissingDependency( + "UCBPatternSearch requires botorch>=0.16.0.Install before using." + ) from _IMPORT_ERROR + super().__init__( kernel=kernel, args=args, @@ -72,46 +82,43 @@ def __init__( # Initialize config encoder self.frac_selected = frac_selected - self.cat_dims = [] - offset = 0 - for spec in self.config_gen.flat_spec: - n_dims = spec.encode_dim() - if spec.is_categorical(): - # All dimensions of this encoder are categorical - self.cat_dims.extend(range(offset, offset + n_dims)) - offset += n_dims + # compute offsets from the flat_spec + dim_sizes = [spec.dim() for spec in self.config_gen.flat_spec] + offsets = [0, *list(accumulate(dim_sizes))] + + self.cat_dims = [ + idx + for i, spec in enumerate(self.config_gen.flat_spec) + if spec.is_categorical() + for idx in range(offsets[i], offsets[i + 1]) + ] def fit_gp( self, train_X: torch.Tensor, train_Y: torch.Tensor, cat_dims: list ) -> MixedSingleTaskGP: # Filter out rows where train_Y contains inf or nan - if HAS_BO_DEPS: - gp = MixedSingleTaskGP( - train_X, - -train_Y.unsqueeze(-1), - cat_dims, - ) + gp = MixedSingleTaskGP( + train_X, + -train_Y.unsqueeze(-1), + cat_dims, + ) - with torch.enable_grad(): - mll = ExactMarginalLogLikelihood(gp.likelihood, gp) - fit_gpytorch_mll(mll) + with torch.enable_grad(): + mll = ExactMarginalLogLikelihood(gp.likelihood, gp) + fit_gpytorch_mll(mll) - return gp - else: - return None + return gp - def acq_fun(self, X: torch.Tensor, gp: MixedSingleTaskGP | None) -> torch.Tensor: + def acq_fun(self, X: torch.Tensor, gp: MixedSingleTaskGP) -> torch.Tensor: orig_dtype = X.dtype - if HAS_BO_DEPS: - acq_fun = UpperConfidenceBound(gp, beta=self.ucb_beta) - return ( - acq_fun(X.unsqueeze(1).to(dtype=torch.float64)) - .detach() - .to(dtype=orig_dtype) - ) - else: - return torch.zeros(X.shape[0], dtype=orig_dtype) + + acq_fun = UpperConfidenceBound(gp, beta=self.ucb_beta) + return ( + acq_fun(X.unsqueeze(1).to(dtype=torch.float64)) + .detach() + .to(dtype=orig_dtype) + ) def get_train_data_from_pop( self, population: list[PopulationMember] @@ -216,10 +223,7 @@ def _autotune(self) -> Config: self.log( f"Conditioning on new data: {len(train_X)} points, {len(train_Y)} targets" ) - if HAS_BO_DEPS: - gp = gp.condition_on_observations(train_X, -train_Y.unsqueeze(1)) - else: - gp = None + gp = gp.condition_on_observations(train_X, -train_Y.unsqueeze(1)) return self.best.config @@ -332,7 +336,7 @@ def _pruned_pattern_search_from( self, current: PopulationMember, visited: set[Config], - gp: MixedSingleTaskGP | None, + gp: MixedSingleTaskGP, ) -> Iterator[list[PopulationMember]]: """ Run a single copy of pattern search from the given 
starting point.
@@ -351,6 +355,7 @@ def _pruned_pattern_search_from(
             new_member = self.make_unbenchmarked(flat_config)
             if new_member.config not in visited:
                 candidates.append(new_member)
+                visited.add(new_member.config)

         # score candidates
         candidate_X = torch.stack(

From 62aaf7700689e00bc8d156861c95b87c89038d64 Mon Sep 17 00:00:00 2001
From: Ethan Che
Date: Fri, 14 Nov 2025 19:25:36 -0800
Subject: [PATCH 09/36] early stopping helper for pattern search

---
 helion/autotuner/pattern_search.py     | 35 ++++++++++++++++++--------
 helion/autotuner/ucb_pattern_search.py | 12 ++-------
 2 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/helion/autotuner/pattern_search.py b/helion/autotuner/pattern_search.py
index d8d759b31..6e36ea36d 100644
--- a/helion/autotuner/pattern_search.py
+++ b/helion/autotuner/pattern_search.py
@@ -134,20 +134,35 @@ def _pattern_search_from(
             if len(candidates) <= 1:
                 return  # no new candidates, stop searching
             yield candidates  # yield new population to benchmark in parallel
+
+            # update search copy and check early stopping criteria
             best = min(candidates, key=performance)
-            if best is current:
-                return  # no improvement, stop searching
-            # Stop if the relative improvement is smaller than a user-specified delta
-            if (
-                self.min_improvement_delta > 0.0
-                and math.isfinite(best.perf)
-                and math.isfinite(current.perf)
-                and current.perf != 0.0
-                and abs(best.perf / current.perf - 1.0) < self.min_improvement_delta
-            ):
+            if self._check_early_stopping(best, current):
                 return
             current = best

+    def _check_early_stopping(
+        self, best: PopulationMember, current: PopulationMember
+    ) -> bool:
+        """
+        Check whether the early stopping criteria are met for this search copy.
+
+        Stops early if either the best config has not changed or the relative
+        improvement is smaller than a user-specified delta.
+
+        Returns:
+            True if the search copy should be terminated, False otherwise.
+        """
+        if best is current:
+            return True  # no improvement, stop searching
+        # Stop if the relative improvement is smaller than a user-specified delta
+        return bool(
+            self.min_improvement_delta > 0.0
+            and math.isfinite(best.perf)
+            and math.isfinite(current.perf)
+            and current.perf != 0.0
+            and abs(best.perf / current.perf - 1.0) < self.min_improvement_delta
+        )
+
     def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]:
         """
         Generate neighboring configurations by changing one or two parameters at a time.
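For intuition, the early-stopping predicate factored out above reduces to a small standalone check. A minimal sketch with made-up perf values (illustrative only, not part of the patch):

    import math

    def should_stop(best_perf: float, current_perf: float,
                    min_improvement_delta: float = 0.001) -> bool:
        # Mirrors _check_early_stopping: stop once the relative improvement
        # abs(best / current - 1) drops below the configured delta.
        return (
            min_improvement_delta > 0.0
            and math.isfinite(best_perf)
            and math.isfinite(current_perf)
            and current_perf != 0.0
            and abs(best_perf / current_perf - 1.0) < min_improvement_delta
        )

    assert should_stop(0.99995, 1.0)   # 0.005% improvement -> stop
    assert not should_stop(0.95, 1.0)  # 5% improvement -> keep searching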
diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index b06ccf8a0..f037bfe8b 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -382,16 +382,8 @@ def _pruned_pattern_search_from( if len(candidates) <= 1: return # no new candidates, stop searching yield candidates # yield new population to benchmark in parallel + # update search copy and check early stopping criteria best = min(candidates, key=performance) - if best is current: - return # no improvement, stop searching - # Stop if the relative improvement is smaller than a user-specified delta - if ( - self.min_improvement_delta > 0.0 - and math.isfinite(best.perf) - and math.isfinite(current.perf) - and current.perf != 0.0 - and abs(best.perf / current.perf - 1.0) < self.min_improvement_delta - ): + if self._check_early_stopping(best, current): return current = best From 97963f5f1085069f5ef45a84bbe3838463f4055d Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:25:22 -0800 Subject: [PATCH 10/36] fix tests --- test/test_autotuner.py | 98 +++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/test/test_autotuner.py b/test/test_autotuner.py index db12c2cfd..95049156f 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -1,59 +1,59 @@ from __future__ import annotations import collections +from contextlib import contextmanager +from contextlib import nullcontext import csv +from itertools import count import logging import math import multiprocessing as mp import operator import os +from pathlib import Path import pickle import random import tempfile -import unittest -from contextlib import contextmanager, nullcontext -from itertools import count -from pathlib import Path from types import SimpleNamespace -from typing import Callable, Sequence +from typing import Callable +from typing import Sequence +import unittest from unittest import skip from unittest.mock import patch -import helion -import helion.language as hl - import pytest import torch -from helion import _compat, exc -from helion._testing import ( - DEVICE, - import_path, - RefEagerTestDisabled, - skipIfCpu, - skipIfRocm, - TestCase, -) -from helion.autotuner import ( - DESurrogateHybrid, - DifferentialEvolutionSearch, - PatternSearch, - UCBPatternSearch, -) -from helion.autotuner.base_search import BaseSearch, PopulationMember -from helion.autotuner.config_fragment import ( - BooleanFragment, - EnumFragment, - IntegerFragment, - ListOf, - PermutationFragment, - PowerOfTwoFragment, -) + +import helion +from helion import _compat +from helion import exc +from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled +from helion._testing import TestCase +from helion._testing import import_path +from helion._testing import skipIfCpu +from helion._testing import skipIfRocm +from helion.autotuner import DESurrogateHybrid +from helion.autotuner import DifferentialEvolutionSearch +from helion.autotuner import PatternSearch +from helion.autotuner import UCBPatternSearch +from helion.autotuner.base_search import BaseSearch +from helion.autotuner.base_search import PopulationMember +from helion.autotuner.config_fragment import BooleanFragment +from helion.autotuner.config_fragment import EnumFragment +from helion.autotuner.config_fragment import IntegerFragment +from helion.autotuner.config_fragment import ListOf +from helion.autotuner.config_fragment import PermutationFragment +from 
helion.autotuner.config_fragment import PowerOfTwoFragment from helion.autotuner.config_generation import ConfigGeneration from helion.autotuner.effort_profile import get_effort_profile from helion.autotuner.finite_search import FiniteSearch -from helion.autotuner.local_cache import LocalAutotuneCache, StrictLocalAutotuneCache -from helion.autotuner.logger import AutotuneLogEntry, AutotuningLogger +from helion.autotuner.local_cache import LocalAutotuneCache +from helion.autotuner.local_cache import StrictLocalAutotuneCache +from helion.autotuner.logger import AutotuneLogEntry +from helion.autotuner.logger import AutotuningLogger from helion.autotuner.random_search import RandomSearch +import helion.language as hl from helion.language import loops from helion.runtime.settings import Settings @@ -725,7 +725,8 @@ def make_bad_config_produce_wrong_output( start_cm = patch.object( search, "start_precompile_and_check_for_hangs", - side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip( + side_effect=lambda config, + fn: base_search_module.PrecompileFuture.skip( search, config, True ), ) @@ -805,7 +806,8 @@ def wrong_fn(*fn_args, **fn_kwargs): start_cm = patch.object( search, "start_precompile_and_check_for_hangs", - side_effect=lambda config, fn: base_search_module.PrecompileFuture.skip( + side_effect=lambda config, + fn: base_search_module.PrecompileFuture.skip( search, config, True ), ) @@ -1127,36 +1129,36 @@ def test_fragment_encoding(self): """Test encoding functionality for all ConfigSpecFragment types.""" # Test BooleanFragment bool_frag = BooleanFragment() - self.assertEqual(bool_frag.encode_dim(), 1) + self.assertEqual(bool_frag.dim(), 1) self.assertEqual(bool_frag.encode(True), [1.0]) self.assertEqual(bool_frag.encode(False), [0.0]) # Test IntegerFragment int_frag = IntegerFragment(low=1, high=10, default_val=5) - self.assertEqual(int_frag.encode_dim(), 1) + self.assertEqual(int_frag.dim(), 1) self.assertEqual(int_frag.encode(5), [5.0]) # Test PowerOfTwoFragment (log2 transformation) pow2_frag = PowerOfTwoFragment(low=2, high=128, default_val=8) - self.assertEqual(pow2_frag.encode_dim(), 1) + self.assertEqual(pow2_frag.dim(), 1) self.assertEqual(pow2_frag.encode(8), [3.0]) # log2(8) = 3 self.assertEqual(pow2_frag.encode(16), [4.0]) # log2(16) = 4 # Test EnumFragment (one-hot encoding) enum_frag = EnumFragment(choices=("a", "b", "c")) - self.assertEqual(enum_frag.encode_dim(), 3) + self.assertEqual(enum_frag.dim(), 3) self.assertEqual(enum_frag.encode("a"), [1.0, 0.0, 0.0]) self.assertEqual(enum_frag.encode("b"), [0.0, 1.0, 0.0]) # Test PermutationFragment perm_frag = PermutationFragment(length=3) - self.assertEqual(perm_frag.encode_dim(), 9) + self.assertEqual(perm_frag.dim(), 3) encoded = perm_frag.encode([0, 1, 2]) - self.assertEqual(encoded, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]) + self.assertEqual(encoded, [0, 1, 2]) # Test ListOf with BooleanFragment list_frag = ListOf(inner=BooleanFragment(), length=3) - self.assertEqual(list_frag.encode_dim(), 3) + self.assertEqual(list_frag.dim(), 3) self.assertEqual(list_frag.encode([True, False, True]), [1.0, 0.0, 1.0]) # Test encode_dim consistency @@ -1166,9 +1168,9 @@ def test_fragment_encoding(self): (PowerOfTwoFragment(2, 128, 8), 16), (EnumFragment(choices=("a", "b")), "b"), ]: - encode_dim = fragment.encode_dim() + dim = fragment.dim() encoded = fragment.encode(value) - self.assertEqual(len(encoded), encode_dim) + self.assertEqual(len(encoded), dim) class TestAutotuneRandomSeed(RefEagerTestDisabled, TestCase): 
@@ -1201,9 +1203,9 @@ def add(a, b): torch.testing.assert_close(bound_kernel(*args), sum(args), rtol=1e-2, atol=1e-1) search = search_capture["search"] - assert ( - search.samples - ), "expected RecordingRandomSearch to record a random sample" + assert search.samples, ( + "expected RecordingRandomSearch to record a random sample" + ) return search.samples[0] @skipIfRocm("accuracy difference") From a0ef224a4effd1350ec43c926c75a8624e2783a0 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:25:59 -0800 Subject: [PATCH 11/36] fix dim --- helion/autotuner/config_fragment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helion/autotuner/config_fragment.py b/helion/autotuner/config_fragment.py index 9f005fbcd..b028478f5 100644 --- a/helion/autotuner/config_fragment.py +++ b/helion/autotuner/config_fragment.py @@ -347,7 +347,7 @@ def differential_mutation(self, a: object, b: object, c: object) -> list[object] for i in range(self.length) ] - def dim(self): + def dim(self) -> int: return self.length * self.inner.dim() def encode(self, value: object) -> list[float]: From 018a626e45acc2e35c7e66cf1a3712723bf250ba Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:29:38 -0800 Subject: [PATCH 12/36] ucb fix lints and better hyperparams --- helion/autotuner/ucb_pattern_search.py | 29 ++++++++++++++------------ 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index f037bfe8b..8e701521d 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -24,10 +24,14 @@ from ..runtime.kernel import BoundKernel try: - from botorch.acquisition import UpperConfidenceBound - from botorch.fit import fit_gpytorch_mll - from botorch.models import MixedSingleTaskGP - from gpytorch.mlls import ExactMarginalLogLikelihood + from botorch.acquisition import ( + UpperConfidenceBound, # type: ignore[import-not-found] + ) + from botorch.fit import fit_gpytorch_mll # type: ignore[import-not-found] + from botorch.models import MixedSingleTaskGP # type: ignore[import-not-found] + from gpytorch.mlls import ( + ExactMarginalLogLikelihood, # type: ignore[import-not-found] + ) HAS_BO_DEPS = True except ImportError as e: @@ -55,15 +59,15 @@ def __init__( initial_population: int = PATTERN_SEARCH_DEFAULTS.initial_population, copies: int = PATTERN_SEARCH_DEFAULTS.copies, max_generations: int = PATTERN_SEARCH_DEFAULTS.max_generations, - min_improvement_delta: float = 0.001, + min_improvement_delta: float = 0.0005, frac_selected: float = 0.3, num_neighbors: int = 100, radius: int = 2, ucb_beta: float = 2.0, ) -> None: if not HAS_BO_DEPS: - raise exc.MissingDependency( - "UCBPatternSearch requires botorch>=0.16.0.Install before using." + raise exc.AutotuneError( + "UCBPatternSearch requires botorch. Install before using." 
) from _IMPORT_ERROR super().__init__( @@ -98,23 +102,23 @@ def fit_gp( ) -> MixedSingleTaskGP: # Filter out rows where train_Y contains inf or nan - gp = MixedSingleTaskGP( + gp = MixedSingleTaskGP( # type: ignore[misc] train_X, -train_Y.unsqueeze(-1), cat_dims, ) with torch.enable_grad(): - mll = ExactMarginalLogLikelihood(gp.likelihood, gp) - fit_gpytorch_mll(mll) + mll = ExactMarginalLogLikelihood(gp.likelihood, gp) # type: ignore[misc] + fit_gpytorch_mll(mll) # type: ignore[misc] return gp def acq_fun(self, X: torch.Tensor, gp: MixedSingleTaskGP) -> torch.Tensor: orig_dtype = X.dtype - acq_fun = UpperConfidenceBound(gp, beta=self.ucb_beta) - return ( + acq_fun = UpperConfidenceBound(gp, beta=self.ucb_beta) # type: ignore[misc] + return ( # type: ignore[misc] acq_fun(X.unsqueeze(1).to(dtype=torch.float64)) .detach() .to(dtype=orig_dtype) @@ -373,7 +377,6 @@ def _pruned_pattern_search_from( reverse=True, )[: int(self.frac_selected * len(candidates))] candidates = [member for member, score in candidates_sorted] - visited.update([member.config for member in candidates]) self.log( f"Scoring {len(candidate_X)} neighbors, selecting {self.frac_selected * 100}% neighbors: {len(candidates)}" From 2a6570186ca7664f239905a36e0c68ca85a8c238 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:32:24 -0800 Subject: [PATCH 13/36] revert linter changes --- helion/autotuner/__init__.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py index 9350abf96..83ce962ca 100644 --- a/helion/autotuner/__init__.py +++ b/helion/autotuner/__init__.py @@ -1,31 +1,25 @@ from __future__ import annotations -from .config_fragment import ( - BooleanFragment as BooleanFragment, - EnumFragment as EnumFragment, - IntegerFragment as IntegerFragment, - ListOf as ListOf, - PowerOfTwoFragment as PowerOfTwoFragment, -) +from .bayes_opt import UCBPatternSearch +from .config_fragment import BooleanFragment as BooleanFragment +from .config_fragment import EnumFragment as EnumFragment +from .config_fragment import IntegerFragment as IntegerFragment +from .config_fragment import ListOf as ListOf +from .config_fragment import PowerOfTwoFragment as PowerOfTwoFragment from .config_spec import ConfigSpec as ConfigSpec from .de_surrogate_hybrid import DESurrogateHybrid as DESurrogateHybrid from .differential_evolution import ( DifferentialEvolutionSearch as DifferentialEvolutionSearch, ) -from .effort_profile import ( - AutotuneEffortProfile as AutotuneEffortProfile, - DifferentialEvolutionConfig as DifferentialEvolutionConfig, - PatternSearchConfig as PatternSearchConfig, - RandomSearchConfig as RandomSearchConfig, -) +from .effort_profile import AutotuneEffortProfile as AutotuneEffortProfile +from .effort_profile import DifferentialEvolutionConfig as DifferentialEvolutionConfig +from .effort_profile import PatternSearchConfig as PatternSearchConfig +from .effort_profile import RandomSearchConfig as RandomSearchConfig from .finite_search import FiniteSearch as FiniteSearch -from .local_cache import ( - LocalAutotuneCache as LocalAutotuneCache, - StrictLocalAutotuneCache as StrictLocalAutotuneCache, -) +from .local_cache import LocalAutotuneCache as LocalAutotuneCache +from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache from .pattern_search import PatternSearch as PatternSearch from .random_search import RandomSearch as RandomSearch -from .ucb_pattern_search import UCBPatternSearch 
search_algorithms = { "DESurrogateHybrid": DESurrogateHybrid, From 79c0fa36e9c8c0763da8035f90adca967cec3d29 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:33:10 -0800 Subject: [PATCH 14/36] name change --- helion/autotuner/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py index 83ce962ca..751e7bffe 100644 --- a/helion/autotuner/__init__.py +++ b/helion/autotuner/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -from .bayes_opt import UCBPatternSearch from .config_fragment import BooleanFragment as BooleanFragment from .config_fragment import EnumFragment as EnumFragment from .config_fragment import IntegerFragment as IntegerFragment @@ -20,6 +19,7 @@ from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache from .pattern_search import PatternSearch as PatternSearch from .random_search import RandomSearch as RandomSearch +from .ucb_pattern_search import UCBPatternSearch search_algorithms = { "DESurrogateHybrid": DESurrogateHybrid, From c2a578fc4c916a056951452b4aeb8a41d54b1b4e Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:35:04 -0800 Subject: [PATCH 15/36] revert unrelated changes in config_generation --- helion/autotuner/config_generation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py index da8a83bc6..0a11b218e 100644 --- a/helion/autotuner/config_generation.py +++ b/helion/autotuner/config_generation.py @@ -187,8 +187,10 @@ def encode_config(self, flat_config: FlatConfig) -> list[float]: Encode a flat configuration into a numerical vector for ML models. This is used by surrogate-assisted algorithms (e.g., DE-Surrogate) that need to represent configurations as continuous vectors for prediction models. + Args: flat_config: The flat configuration values to encode. + Returns: A list of floats representing the encoded configuration. """ From fc2929d78878fd58c1cca77ec7af4bd03aa10e18 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:35:35 -0800 Subject: [PATCH 16/36] revert unrelated changes in config_generation --- helion/autotuner/config_generation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/helion/autotuner/config_generation.py b/helion/autotuner/config_generation.py index 0a11b218e..9e1533be4 100644 --- a/helion/autotuner/config_generation.py +++ b/helion/autotuner/config_generation.py @@ -185,6 +185,7 @@ def differential_mutation( def encode_config(self, flat_config: FlatConfig) -> list[float]: """ Encode a flat configuration into a numerical vector for ML models. + This is used by surrogate-assisted algorithms (e.g., DE-Surrogate) that need to represent configurations as continuous vectors for prediction models. 
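Taken together, the fragment-level `dim()`/`encode()` API and `ConfigGeneration.encode_config` turn a flat config into one numerical feature vector; judging from the offset bookkeeping in `UCBPatternSearch.__init__`, this amounts to concatenating per-fragment encodings in flat-spec order. A minimal sketch with hypothetical fragment instances (illustrative only, not part of the patch):

    from helion.autotuner.config_fragment import (
        BooleanFragment,
        EnumFragment,
        PowerOfTwoFragment,
    )

    flat_spec = [
        PowerOfTwoFragment(16, 128, 64),        # 1 dim: log2 of the value
        EnumFragment(choices=("x", "y", "z")),  # 3 dims: one-hot
        BooleanFragment(),                      # 1 dim: 0.0 / 1.0
    ]
    flat_config = [64, "y", True]

    encoded: list[float] = []
    for spec, value in zip(flat_spec, flat_config):
        encoded.extend(spec.encode(value))

    assert encoded == [6.0, 0.0, 1.0, 0.0, 1.0]  # log2(64)=6, one-hot "y", True
    assert len(encoded) == sum(spec.dim() for spec in flat_spec)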
From 44c925db03ef7f71153bbf31559cd33dbdf2e87f Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:44:15 -0800 Subject: [PATCH 17/36] save gp state --- helion/autotuner/ucb_pattern_search.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index 8e701521d..39d5322fd 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -59,7 +59,7 @@ def __init__( initial_population: int = PATTERN_SEARCH_DEFAULTS.initial_population, copies: int = PATTERN_SEARCH_DEFAULTS.copies, max_generations: int = PATTERN_SEARCH_DEFAULTS.max_generations, - min_improvement_delta: float = 0.0005, + min_improvement_delta: float = 0.001, frac_selected: float = 0.3, num_neighbors: int = 100, radius: int = 2, @@ -177,14 +177,15 @@ def _autotune(self) -> Config: # Fit GP self.log(f"Fitting GP: {len(train_X)} points, {len(train_Y)} targets") - gp = self.fit_gp( + self.gp = self.fit_gp( train_X, train_Y, self.cat_dims, ) search_copies = [ - self._pruned_pattern_search_from(m, visited, gp) for m in starting_points + self._pruned_pattern_search_from(m, visited, self.gp) + for m in starting_points ] for generation in range(1, self.max_generations + 1): prior_best = self.best @@ -227,7 +228,7 @@ def _autotune(self) -> Config: self.log( f"Conditioning on new data: {len(train_X)} points, {len(train_Y)} targets" ) - gp = gp.condition_on_observations(train_X, -train_Y.unsqueeze(1)) + self.gp = self.gp.condition_on_observations(train_X, -train_Y.unsqueeze(1)) return self.best.config From 70a46fe444ae281dd25a9d86037318b6487b758d Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:48:09 -0800 Subject: [PATCH 18/36] better ucb docstring --- helion/autotuner/ucb_pattern_search.py | 56 ++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 8 deletions(-) diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index 39d5322fd..6a6c3ec34 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -41,14 +41,54 @@ class UCBPatternSearch(PatternSearch): """ - Upper Confidence Bound (UCB) Pattern Search - - Modifies PatternSearch to (1) generate random neighbors from each search copy - within a set radius, (2) filter the neighbors to benchmark using a fitted GaussianProcess - with the UCB acquisition function. - - Uses the MixedSingleTaskGP model from botorch, which supports continuous - and categorical variables. It only fits the GP once to avoid long runtimes. + Upper Confidence Bound (UCB) Pattern Search - A Bayesian optimization-guided autotuner. + + This algorithm enhances PatternSearch by using Gaussian Process surrogate models + with Upper Confidence Bound (UCB) acquisition to intelligently select which + configurations to benchmark, reducing the number of kernel compilations and runs + needed to find optimal configurations. + + Algorithm Overview: + 1. Generate an initial random population and benchmark all configurations + 2. Fit a Gaussian Process (GP) model on the benchmarked data + 3. 
For each generation: + - Generate random neighbors around the current best configurations + - Score all neighbors using UCB acquisition function + - Benchmark only the top frac_selected fraction of neighbors + - Condition the GP on new observations (rather than refitting) + - Update search trajectories based on new results + + Key Differences from PatternSearch: + - Generates num_neighbors random neighbors (within radius) instead of + systematic single-parameter perturbations + - Uses GP+UCB to filter which neighbors to actually benchmark, significantly + reducing compilation/benchmark overhead + - Supports both continuous (power-of-two) and categorical parameters via + MixedSingleTaskGP from BoTorch + + Args: + kernel: The kernel to be autotuned. + args: The arguments to be passed to the kernel during benchmarking. + initial_population: Number of random configurations in initial population. + Default from PATTERN_SEARCH_DEFAULTS. + copies: Number of top configurations to run pattern search from. + Default from PATTERN_SEARCH_DEFAULTS. + max_generations: Maximum number of search iterations per copy. + Default from PATTERN_SEARCH_DEFAULTS. + min_improvement_delta: Early stopping threshold. Search stops if the relative + improvement abs(best/current - 1) < min_improvement_delta. + Default: 0.0005 (0.05% improvement threshold). + frac_selected: Fraction of generated neighbors to actually benchmark, after + filtering by UCB score. Range: (0, 1]. Lower values reduce benchmarking + cost but may miss good configurations. Default: 0.3. + num_neighbors: Number of random neighbor configurations to generate around + each search point per generation. Default: 100. + radius: Maximum perturbation distance in configuration space. For power-of-two + parameters, this is the max change in log2 space. For other parameters, + this limits how many parameters can be changed. Default: 2. + ucb_beta: Exploration/exploitation trade-off parameter for UCB acquisition. + Higher values favor exploration of uncertain regions. Typical range: [1, 5]. + Default: 2.0. 
""" def __init__( From f837953c065c66f31e9259bfdcfebd2bbdeb3717 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:49:39 -0800 Subject: [PATCH 19/36] combined dependencies --- .github/workflows/benchmark.yml | 2 +- .github/workflows/test.yml | 2 +- pyproject.toml | 6 ++---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 18dd35d6b..5e941d87f 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -100,7 +100,7 @@ jobs: - name: Install Helion run: | source .venv/bin/activate - SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate,bayesopt]' + SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,surrogate]' python -c "import helion; print(helion.__name__)" - name: Install Benchmark Requirements diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 573751dad..c831aaf53 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -146,7 +146,7 @@ jobs: run: | source .venv/bin/activate uv pip install setuptools ninja - SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,de-surrogate,bayesopt]' + SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0" uv pip install -e .'[dev,surrogate]' python -c "import helion; print(helion.__name__)" - name: Run Tests diff --git a/pyproject.toml b/pyproject.toml index 951bb4eb7..757020929 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,10 @@ dependencies = [ ] [project.optional-dependencies] -de-surrogate = [ +surrogate = [ "numpy", "scikit-learn>=1.3.0" -] -bayesopt = [ - "botorch>=0.16.0" + "botorch>=0.14.0" ] dev = [ "expecttest", From 2c2eec909d97e14a8649d0384c78a5085aa66305 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 10:57:19 -0800 Subject: [PATCH 20/36] fix pyproject --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 757020929..fd6cdcb15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ [project.optional-dependencies] surrogate = [ "numpy", - "scikit-learn>=1.3.0" + "scikit-learn>=1.3.0", "botorch>=0.14.0" ] dev = [ From 74b0754b98274dd94a4c34fcf5962ad16b651842 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 11:02:31 -0800 Subject: [PATCH 21/36] reverting unrelated changes to comments --- helion/autotuner/config_fragment.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/helion/autotuner/config_fragment.py b/helion/autotuner/config_fragment.py index b028478f5..bf8cead94 100644 --- a/helion/autotuner/config_fragment.py +++ b/helion/autotuner/config_fragment.py @@ -62,7 +62,16 @@ def dim(self) -> int: def encode(self, value: object) -> list[float]: """ - Returns a list of floats that can be used to encode the value of this fragment. + Encode a configuration value into a list of floats for ML models. + + This is used by surrogate-assisted algorithms to convert configurations + into numerical vectors for prediction models. + + Args: + value: The configuration value to encode. + + Returns: + A list of floats representing the encoded value. 
""" raise NotImplementedError From d9cce1e1387d45701930d0c92d7d4bef88e1bfb1 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Sun, 16 Nov 2025 11:03:39 -0800 Subject: [PATCH 22/36] no need for encode for integer fragment, inherit from base integer --- helion/autotuner/config_fragment.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/helion/autotuner/config_fragment.py b/helion/autotuner/config_fragment.py index bf8cead94..1dcc84d71 100644 --- a/helion/autotuner/config_fragment.py +++ b/helion/autotuner/config_fragment.py @@ -228,10 +228,6 @@ def differential_mutation(self, a: object, b: object, c: object) -> int: return self.clamp(a + 1) return a - def encode(self, value: object) -> list[float]: - assert isinstance(value, int) - return [float(value)] - @dataclasses.dataclass class EnumFragment(ConfigSpecFragment): From 4a791a94fa178e3630d9db38cb772633733449cf Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Mon, 17 Nov 2025 07:36:00 -0800 Subject: [PATCH 23/36] optimize batch UCB function, simplify batch selection --- helion/autotuner/ucb_pattern_search.py | 124 +++++++++++++++++++------ 1 file changed, 95 insertions(+), 29 deletions(-) diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index 6a6c3ec34..afce0fbab 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -25,7 +25,7 @@ try: from botorch.acquisition import ( - UpperConfidenceBound, # type: ignore[import-not-found] + qUpperConfidenceBound, # type: ignore[import-not-found] ) from botorch.fit import fit_gpytorch_mll # type: ignore[import-not-found] from botorch.models import MixedSingleTaskGP # type: ignore[import-not-found] @@ -89,6 +89,11 @@ class UCBPatternSearch(PatternSearch): ucb_beta: Exploration/exploitation trade-off parameter for UCB acquisition. Higher values favor exploration of uncertain regions. Typical range: [1, 5]. Default: 2.0. + use_greedy_batch: If True, use greedy batch acquisition where points are + selected sequentially, conditioning the GP on each selected point before + choosing the next. This produces more diverse batches but is slower. + If False, all points are scored independently (standard UCB). + Default: False. """ def __init__( @@ -100,8 +105,8 @@ def __init__( copies: int = PATTERN_SEARCH_DEFAULTS.copies, max_generations: int = PATTERN_SEARCH_DEFAULTS.max_generations, min_improvement_delta: float = 0.001, - frac_selected: float = 0.3, - num_neighbors: int = 100, + frac_selected: float = 0.1, + num_neighbors: int = 300, radius: int = 2, ucb_beta: float = 2.0, ) -> None: @@ -137,7 +142,7 @@ def __init__( for idx in range(offsets[i], offsets[i + 1]) ] - def fit_gp( + def _fit_gp( self, train_X: torch.Tensor, train_Y: torch.Tensor, cat_dims: list ) -> MixedSingleTaskGP: # Filter out rows where train_Y contains inf or nan @@ -154,16 +159,89 @@ def fit_gp( return gp - def acq_fun(self, X: torch.Tensor, gp: MixedSingleTaskGP) -> torch.Tensor: - orig_dtype = X.dtype + def _optimize_batch_acq( + self, + candidates: list[PopulationMember], + gp: MixedSingleTaskGP, + num_select: int, + ) -> list[PopulationMember]: + """ + Greedily optimize the set-valued UCB acquisition function. + + This treats the acquisition function as set-valued: it evaluates the value + of acquiring a batch of points together. We greedily build up the batch by: + 1. Start with empty selected set + 2. For each candidate, evaluate acq_fun(selected_set ∪ {candidate}) + 3. Select the candidate that maximizes this set value + 4. 
Add it to selected set and repeat + + This encourages diversity since adding a point near already-selected points + typically yields lower marginal gain. + + Args: + candidates: List of candidate configurations to select from + gp: The Gaussian Process model + num_select: Number of candidates to select - acq_fun = UpperConfidenceBound(gp, beta=self.ucb_beta) # type: ignore[misc] - return ( # type: ignore[misc] - acq_fun(X.unsqueeze(1).to(dtype=torch.float64)) - .detach() - .to(dtype=orig_dtype) + Returns: + List of selected candidates (in order of selection) + """ + selected: list[PopulationMember] = [] + selected_indices: list[int] = [] + remaining_indices = list(range(len(candidates))) + + acq_fn = qUpperConfidenceBound(gp, beta=self.ucb_beta) # type: ignore[misc] + + candidate_X = torch.stack( + [ + torch.tensor(self.config_gen.encode_config(member.flat_values)) + for member in candidates + ] ) + for _ in range(num_select): + if not remaining_indices: + break + + # Batch evaluate all remaining candidates at once + if selected_indices: + # Build batch: for each remaining, create [selected + remaining[i]] + # Shape: [num_remaining, num_selected + 1, D] + num_remaining = len(remaining_indices) + + # Expand selected points to [num_remaining, num_selected, D] + selected_X = candidate_X[selected_indices] # [num_selected, D] + expanded_selected = selected_X.unsqueeze(0).expand( + num_remaining, -1, -1 + ) + + # Get remaining candidates as [num_remaining, 1, D] + remaining_X = candidate_X[remaining_indices].unsqueeze(1) + + # Concatenate to get [num_remaining, num_selected+1, D] + batch_X = torch.cat([expanded_selected, remaining_X], dim=1) + + # Evaluate all sets at once: [num_remaining] + set_values = acq_fn(batch_X.to(dtype=torch.float64)) # [num_remaining] + else: + # First selection: evaluate each candidate independently + remaining_X = candidate_X[remaining_indices].unsqueeze( + 1 + ) # [num_remaining, 1, D] + set_values = acq_fn( + remaining_X.to(dtype=torch.float64) + ) # [num_remaining] + + # Select the best + best_idx_in_remaining = int(set_values.argmax()) + best_idx = remaining_indices[best_idx_in_remaining] + + selected.append(candidates[best_idx]) + selected_indices.append(best_idx) + remaining_indices.pop(best_idx_in_remaining) + + return selected + def get_train_data_from_pop( self, population: list[PopulationMember] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -217,7 +295,7 @@ def _autotune(self) -> Config: # Fit GP self.log(f"Fitting GP: {len(train_X)} points, {len(train_Y)} targets") - self.gp = self.fit_gp( + self.gp = self._fit_gp( train_X, train_Y, self.cat_dims, @@ -402,25 +480,13 @@ def _pruned_pattern_search_from( candidates.append(new_member) visited.add(new_member.config) - # score candidates - candidate_X = torch.stack( - [ - torch.tensor(self.config_gen.encode_config(member.flat_values)) - for member in candidates - ] - ) - scores = self.acq_fun(candidate_X, gp) - - # filter candidates by score - candidates_sorted = sorted( - zip(candidates, scores, strict=True), - key=operator.itemgetter(1), - reverse=True, - )[: int(self.frac_selected * len(candidates))] - candidates = [member for member, score in candidates_sorted] + # Select candidates using greedy batch or standard independent scoring + num_neighbors = len(candidates) + num_to_select = int(self.frac_selected * num_neighbors) + candidates = self._optimize_batch_acq(candidates, gp, num_to_select) self.log( - f"Scoring {len(candidate_X)} neighbors, selecting {self.frac_selected * 100}% neighbors: {len(candidates)}" 
+ f"Scoring {num_neighbors} neighbors, selecting {self.frac_selected * 100}% neighbors: {len(candidates)}" ) if len(candidates) <= 1: From 047509d01f7e04d08ad52c063a339f96bf42f74a Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Mon, 17 Nov 2025 07:38:38 -0800 Subject: [PATCH 24/36] batch optimization by default --- helion/autotuner/ucb_pattern_search.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/ucb_pattern_search.py index afce0fbab..70832c5f2 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/ucb_pattern_search.py @@ -89,11 +89,6 @@ class UCBPatternSearch(PatternSearch): ucb_beta: Exploration/exploitation trade-off parameter for UCB acquisition. Higher values favor exploration of uncertain regions. Typical range: [1, 5]. Default: 2.0. - use_greedy_batch: If True, use greedy batch acquisition where points are - selected sequentially, conditioning the GP on each selected point before - choosing the next. This produces more diverse batches but is slower. - If False, all points are scored independently (standard UCB). - Default: False. """ def __init__( From eb430cef95f58bb101ad6c1430adda6f9e41245b Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 11:26:00 -0800 Subject: [PATCH 25/36] LFBO instead of ucb_pattern_search --- helion/autotuner/__init__.py | 4 +- helion/autotuner/de_surrogate_hybrid.py | 2 +- ..._search.py => surrogate_pattern_search.py} | 312 +++++++----------- pyproject.toml | 1 - 4 files changed, 117 insertions(+), 202 deletions(-) rename helion/autotuner/{ucb_pattern_search.py => surrogate_pattern_search.py} (56%) diff --git a/helion/autotuner/__init__.py b/helion/autotuner/__init__.py index 751e7bffe..7d08e9603 100644 --- a/helion/autotuner/__init__.py +++ b/helion/autotuner/__init__.py @@ -19,11 +19,11 @@ from .local_cache import StrictLocalAutotuneCache as StrictLocalAutotuneCache from .pattern_search import PatternSearch as PatternSearch from .random_search import RandomSearch as RandomSearch -from .ucb_pattern_search import UCBPatternSearch +from .surrogate_pattern_search import LFBOPatternSearch search_algorithms = { "DESurrogateHybrid": DESurrogateHybrid, - "UCBPatternSearch": UCBPatternSearch, + "LFBOPatternSearch": LFBOPatternSearch, "DifferentialEvolutionSearch": DifferentialEvolutionSearch, "FiniteSearch": FiniteSearch, "PatternSearch": PatternSearch, diff --git a/helion/autotuner/de_surrogate_hybrid.py b/helion/autotuner/de_surrogate_hybrid.py index b04cbfb9b..e6a36df5e 100644 --- a/helion/autotuner/de_surrogate_hybrid.py +++ b/helion/autotuner/de_surrogate_hybrid.py @@ -91,7 +91,7 @@ def __init__( if not HAS_ML_DEPS: raise ImportError( "DESurrogateHybrid requires numpy and scikit-learn. " - "Install them with: pip install helion[de-surrogate]" + "Install them with: pip install helion[surrogate]" ) # Initialize parent with early stopping parameters diff --git a/helion/autotuner/ucb_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py similarity index 56% rename from helion/autotuner/ucb_pattern_search.py rename to helion/autotuner/surrogate_pattern_search.py index 70832c5f2..708b8d3a2 100644 --- a/helion/autotuner/ucb_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -1,13 +1,10 @@ from __future__ import annotations -from itertools import accumulate import math import operator import random from typing import TYPE_CHECKING -import torch - from .. 
 import exc
 from .base_search import FlatConfig
 from .base_search import PopulationMember
 from .base_search import performance
 from .config_fragment import PowerOfTwoFragment
 from .effort_profile import PATTERN_SEARCH_DEFAULTS
 from .pattern_search import PatternSearch
@@ -24,47 +21,43 @@
 from ..runtime.kernel import BoundKernel

 try:
-    from botorch.acquisition import (
-        qUpperConfidenceBound,  # type: ignore[import-not-found]
-    )
-    from botorch.fit import fit_gpytorch_mll  # type: ignore[import-not-found]
-    from botorch.models import MixedSingleTaskGP  # type: ignore[import-not-found]
-    from gpytorch.mlls import (
-        ExactMarginalLogLikelihood,  # type: ignore[import-not-found]
+    import numpy as np  # type: ignore[import-not-found]
+    from sklearn.ensemble import (
+        RandomForestClassifier,  # type: ignore[import-not-found]
     )

-    HAS_BO_DEPS = True
+    HAS_ML_DEPS = True
 except ImportError as e:
-    HAS_BO_DEPS = False
+    HAS_ML_DEPS = False
     _IMPORT_ERROR = e


-class UCBPatternSearch(PatternSearch):
+class LFBOPatternSearch(PatternSearch):
     """
-    Upper Confidence Bound (UCB) Pattern Search - A Bayesian optimization-guided autotuner.
+    Likelihood-Free Bayesian Optimization (LFBO) Pattern Search.

-    This algorithm enhances PatternSearch by using Gaussian Process surrogate models
-    with Upper Confidence Bound (UCB) acquisition to intelligently select which
-    configurations to benchmark, reducing the number of kernel compilations and runs
-    needed to find optimal configurations.
+    This algorithm enhances PatternSearch by using a Random Forest classifier as a surrogate
+    model to select which configurations to benchmark, reducing the number of
+    kernel compilations and runs needed to find optimal configurations.

     Algorithm Overview:
     1. Generate an initial random population and benchmark all configurations
-    2. Fit a Gaussian Process (GP) model on the benchmarked data
+    2. Fit a Random Forest classifier to predict "good" vs "bad" configurations:
+       - Configs with performance < quantile threshold are labeled as "good" (class 1)
+       - Configs with performance >= quantile threshold are labeled as "bad" (class 0)
+       - Weighted classification emphasizes configs that are much better than the threshold
     3. For each generation:
        - Generate random neighbors around the current best configurations
-       - Score all neighbors using UCB acquisition function
+       - Score all neighbors using the classifier's predicted probability of being "good"
        - Benchmark only the top frac_selected fraction of neighbors
-       - Condition the GP on new observations (rather than refitting)
+       - Retrain the classifier on all observed data (not incremental)
       - Update search trajectories based on new results

-    Key Differences from PatternSearch:
-    - Generates num_neighbors random neighbors (within radius) instead of
-      systematic single-parameter perturbations
-    - Uses GP+UCB to filter which neighbors to actually benchmark, significantly
-      reducing compilation/benchmark overhead
-    - Supports both continuous (power-of-two) and categorical parameters via
-      MixedSingleTaskGP from BoTorch
+    The weighted classification model learns an acquisition function: it identifies
+    which configs maximize the expected improvement over the current best config.
+    Unlike fitting a surrogate to the config performances themselves, this
+    classification-based method can also incorporate configs that time out or
+    have unacceptable accuracy.

     Args:
         kernel: The kernel to be autotuned.
         args: The arguments to be passed to the kernel during benchmarking.
         initial_population: Number of random configurations in initial population.
             Default from PATTERN_SEARCH_DEFAULTS.
         copies: Number of top configurations to run pattern search from.
             Default from PATTERN_SEARCH_DEFAULTS.
         max_generations: Maximum number of search iterations per copy.
             Default from PATTERN_SEARCH_DEFAULTS.
         min_improvement_delta: Early stopping threshold. Search stops if the relative
             improvement abs(best/current - 1) < min_improvement_delta.
- Default: 0.0005 (0.05% improvement threshold). + Default: 0.001 (0.1% improvement threshold). frac_selected: Fraction of generated neighbors to actually benchmark, after - filtering by UCB score. Range: (0, 1]. Lower values reduce benchmarking - cost but may miss good configurations. Default: 0.3. + filtering by classifier score. Range: (0, 1]. Lower values reduce benchmarking + cost but may miss good configurations. Default: 0.15. num_neighbors: Number of random neighbor configurations to generate around - each search point per generation. Default: 100. + each search point per generation. Default: 300. radius: Maximum perturbation distance in configuration space. For power-of-two parameters, this is the max change in log2 space. For other parameters, this limits how many parameters can be changed. Default: 2. - ucb_beta: Exploration/exploitation trade-off parameter for UCB acquisition. - Higher values favor exploration of uncertain regions. Typical range: [1, 5]. - Default: 2.0. + quantile: Threshold for labeling configs as "good" (class 1) vs "bad" (class 0). + Configs with performance below this quantile are labeled as good. + Range: (0, 1). Lower values create a more selective definition of "good". + Default: 0.3 (top 30% are considered good). """ def __init__( @@ -100,14 +94,15 @@ def __init__( copies: int = PATTERN_SEARCH_DEFAULTS.copies, max_generations: int = PATTERN_SEARCH_DEFAULTS.max_generations, min_improvement_delta: float = 0.001, - frac_selected: float = 0.1, + frac_selected: float = 0.15, num_neighbors: int = 300, radius: int = 2, - ucb_beta: float = 2.0, + quantile: float = 0.3, ) -> None: - if not HAS_BO_DEPS: + if not HAS_ML_DEPS: raise exc.AutotuneError( - "UCBPatternSearch requires botorch. Install before using." + "LFBOPatternSearch requires numpy and scikit-learn. "
+ "Install them with: pip install helion[surrogate]" ) from _IMPORT_ERROR super().__init__( @@ -121,145 +116,74 @@ def __init__( # Storage for BO self.num_neighbors = num_neighbors self.radius = radius - self.ucb_beta = ucb_beta - - # Initialize config encoder - self.frac_selected = frac_selected - # compute offsets from the flat_spec - dim_sizes = [spec.dim() for spec in self.config_gen.flat_spec] - offsets = [0, *list(accumulate(dim_sizes))] + self.model = RandomForestClassifier( # type: ignore[misc] + criterion="log_loss", + random_state=42, + n_estimators=100, + max_depth=15, + min_samples_split=2, + min_samples_leaf=1, + n_jobs=-1, + ) - self.cat_dims = [ - idx - for i, spec in enumerate(self.config_gen.flat_spec) - if spec.is_categorical() - for idx in range(offsets[i], offsets[i + 1]) - ] + self.train_X = [] + self.train_Y = [] - def _fit_gp( - self, train_X: torch.Tensor, train_Y: torch.Tensor, cat_dims: list - ) -> MixedSingleTaskGP: - # Filter out rows where train_Y contains inf or nan + # Initialize config encoder + self.frac_selected = frac_selected + self.quantile = quantile - gp = MixedSingleTaskGP( # type: ignore[misc] - train_X, - -train_Y.unsqueeze(-1), - cat_dims, + def _fit_surrogate(self) -> None: + train_X = np.array(self.train_X) # type: ignore[union-attr] + train_Y = np.array(self.train_Y) # type: ignore[union-attr] + self.log.debug( + f"Fitting surrogate: {len(train_X)} points, {len(train_Y)} targets" ) + train_Y_quantile = np.quantile(train_Y, self.quantile) # type: ignore[union-attr] - with torch.enable_grad(): - mll = ExactMarginalLogLikelihood(gp.likelihood, gp) # type: ignore[misc] - fit_gpytorch_mll(mll) # type: ignore[misc] + # Labels are generated by which are configs better than the quantile + train_labels = 1.0 * (train_Y < train_Y_quantile) + pos_weights = np.maximum(0, train_Y_quantile - train_Y) # type: ignore[union-attr] + normalizing_factor = np.mean( # type: ignore[union-attr] + np.array([weight for weight in pos_weights if weight > 0.0]) # type: ignore[union-attr] + ) + pos_weights = pos_weights / normalizing_factor + sample_weight = np.where(train_Y < train_Y_quantile, pos_weights, 1.0) # type: ignore[union-attr] - return gp + self.model.fit(train_X, train_labels, sample_weight=sample_weight) - def _optimize_batch_acq( - self, - candidates: list[PopulationMember], - gp: MixedSingleTaskGP, - num_select: int, + def _surrogate_select( + self, candidates: list[PopulationMember], n_sorted: int ) -> list[PopulationMember]: - """ - Greedily optimize the set-valued UCB acquisition function. - - This treats the acquisition function as set-valued: it evaluates the value - of acquiring a batch of points together. We greedily build up the batch by: - 1. Start with empty selected set - 2. For each candidate, evaluate acq_fun(selected_set ∪ {candidate}) - 3. Select the candidate that maximizes this set value - 4. Add it to selected set and repeat - - This encourages diversity since adding a point near already-selected points - typically yields lower marginal gain. 
- - Args: - candidates: List of candidate configurations to select from - gp: The Gaussian Process model - num_select: Number of candidates to select - - Returns: - List of selected candidates (in order of selection) - """ - selected: list[PopulationMember] = [] - selected_indices: list[int] = [] - remaining_indices = list(range(len(candidates))) - - acq_fn = qUpperConfidenceBound(gp, beta=self.ucb_beta) # type: ignore[misc] - - candidate_X = torch.stack( - [ - torch.tensor(self.config_gen.encode_config(member.flat_values)) - for member in candidates - ] + # Score candidates + candidate_X = np.array( # type: ignore[union-attr] + [self.config_gen.encode_config(member.flat_values) for member in candidates] + ) + scores = self.model.predict_proba(candidate_X) # type: ignore[assignment] + + if scores.shape[1] == 2: # type: ignore[union-attr] + scores = scores[:, 1] # type: ignore[index] + elif scores.shape[1] == 1: # type: ignore[union-attr] + scores = scores[:, 0] # type: ignore[index] + else: + raise ValueError("Unexpected shape for scores") + + candidates_sorted = sorted( + zip(candidates, scores, strict=True), + key=operator.itemgetter(1), + reverse=True, # higher scores are better + )[:n_sorted] + + self.log.debug( + f"Scoring {len(candidate_X)} neighbors, selecting {(n_sorted / len(candidate_X)) * 100:.0f}% neighbors: {len(candidates_sorted)}" ) - for _ in range(num_select): - if not remaining_indices: - break - - # Batch evaluate all remaining candidates at once - if selected_indices: - # Build batch: for each remaining, create [selected + remaining[i]] - # Shape: [num_remaining, num_selected + 1, D] - num_remaining = len(remaining_indices) - - # Expand selected points to [num_remaining, num_selected, D] - selected_X = candidate_X[selected_indices] # [num_selected, D] - expanded_selected = selected_X.unsqueeze(0).expand( - num_remaining, -1, -1 - ) - - # Get remaining candidates as [num_remaining, 1, D] - remaining_X = candidate_X[remaining_indices].unsqueeze(1) - - # Concatenate to get [num_remaining, num_selected+1, D] - batch_X = torch.cat([expanded_selected, remaining_X], dim=1) - - # Evaluate all sets at once: [num_remaining] - set_values = acq_fn(batch_X.to(dtype=torch.float64)) # [num_remaining] - else: - # First selection: evaluate each candidate independently - remaining_X = candidate_X[remaining_indices].unsqueeze( - 1 - ) # [num_remaining, 1, D] - set_values = acq_fn( - remaining_X.to(dtype=torch.float64) - ) # [num_remaining] - - # Select the best - best_idx_in_remaining = int(set_values.argmax()) - best_idx = remaining_indices[best_idx_in_remaining] - - selected.append(candidates[best_idx]) - selected_indices.append(best_idx) - remaining_indices.pop(best_idx_in_remaining) - - return selected - - def get_train_data_from_pop( - self, population: list[PopulationMember] - ) -> tuple[torch.Tensor, torch.Tensor]: - train_X = [] - train_Y = [] - for member in population: - train_X.append( - torch.tensor(self.config_gen.encode_config(member.flat_values)) - ) - train_Y.append(member.perf) - - train_X = torch.stack(train_X) - train_Y = torch.tensor(train_Y) - - valid_mask = torch.isfinite(train_Y) - train_X_filtered = train_X[valid_mask].to(dtype=torch.float64) - train_Y_filtered = train_Y[valid_mask].to(dtype=torch.float64) - - return train_X_filtered, train_Y_filtered + return [member for member, score in candidates_sorted] def _autotune(self) -> Config: self.log( - f"Starting UCBPatternSearch with initial_population={self.initial_population}, copies={self.copies}, 
max_generations={self.max_generations}" + f"Starting LFBOPatternSearch with initial_population={self.initial_population}, copies={self.copies}, max_generations={self.max_generations}" ) visited = set() self.population = [] @@ -286,19 +210,15 @@ def _autotune(self) -> Config: raise exc.NoConfigFound # Save to training data - train_X, train_Y = self.get_train_data_from_pop(self.population) - - # Fit GP - self.log(f"Fitting GP: {len(train_X)} points, {len(train_Y)} targets") - self.gp = self._fit_gp( - train_X, - train_Y, - self.cat_dims, - ) + for member in self.population: + self.train_X.append(self.config_gen.encode_config(member.flat_values)) + self.train_Y.append(member.perf) + + # Fit model + self._fit_surrogate() search_copies = [ - self._pruned_pattern_search_from(m, visited, self.gp) - for m in starting_points + self._pruned_pattern_search_from(m, visited) for m in starting_points ] for generation in range(1, self.max_generations + 1): prior_best = self.best @@ -336,16 +256,15 @@ def _autotune(self) -> Config: self.log(f"Generation {generation} complete:", self.statistics) # Save to training data - train_X, train_Y = self.get_train_data_from_pop(self.population) + for member in self.population: + self.train_X.append(self.config_gen.encode_config(member.flat_values)) + self.train_Y.append(member.perf) - self.log( - f"Conditioning on new data: {len(train_X)} points, {len(train_Y)} targets" - ) - self.gp = self.gp.condition_on_observations(train_X, -train_Y.unsqueeze(1)) + self._fit_surrogate() return self.best.config - def random_log2_neighbor( + def _random_log2_neighbor( self, current_val: int, radius: int, low: int, high: int ) -> int: # Log the current value @@ -388,11 +307,11 @@ def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]: block_spec = self.config_gen.flat_spec[block_idx] current_val = base[block_idx] - assert type(current_val) is int + assert isinstance(current_val, int) if isinstance(block_spec, PowerOfTwoFragment): # Change by at most 1 in log2 space - new_flat[block_idx] = self.random_log2_neighbor( + new_flat[block_idx] = self._random_log2_neighbor( current_val, radius=self.radius, low=block_spec.low, @@ -408,11 +327,11 @@ def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]: warp_spec = self.config_gen.flat_spec[warp_idx] current_val = base[warp_idx] - assert type(current_val) is int + assert isinstance(current_val, int) if isinstance(warp_spec, PowerOfTwoFragment): # Change by at most self.radius in log2 space - new_flat[warp_idx] = self.random_log2_neighbor( + new_flat[warp_idx] = self._random_log2_neighbor( current_val, radius=self.radius, low=warp_spec.low, @@ -454,7 +373,6 @@ def _pruned_pattern_search_from( self, current: PopulationMember, visited: set[Config], - gp: MixedSingleTaskGP, ) -> Iterator[list[PopulationMember]]: """ Run a single copy of pattern search from the given starting point. @@ -463,32 +381,30 @@ def _pruned_pattern_search_from( run multiple copies of pattern search in parallel. Only keep self.frac_selected of the neighbors generated from the current - search_copy. Filter them using the GaussianProcess + UCB acqusition function. + search_copy. Filter them using the GaussianProcess. 
""" + patience = 0 for _ in range(self.max_generations): candidates = [current] all_neighbors = self._generate_neighbors(current.flat_values) - self.log(f"Number of all candidate neighbors: {len(all_neighbors)}") for flat_config in all_neighbors: new_member = self.make_unbenchmarked(flat_config) if new_member.config not in visited: candidates.append(new_member) visited.add(new_member.config) - # Select candidates using greedy batch or standard independent scoring - num_neighbors = len(candidates) - num_to_select = int(self.frac_selected * num_neighbors) - candidates = self._optimize_batch_acq(candidates, gp, num_to_select) - - self.log( - f"Scoring {num_neighbors} neighbors, selecting {self.frac_selected * 100}% neighbors: {len(candidates)}" - ) + # score candidates + n_sorted = int(len(candidates) * self.frac_selected) + candidates = self._surrogate_select(candidates, n_sorted) if len(candidates) <= 1: return # no new candidates, stop searching yield candidates # yield new population to benchmark in parallel - # update search copy and check early stopping criteria best = min(candidates, key=performance) if self._check_early_stopping(best, current): - return + if patience > 0: + patience -= 1 + self.log.debug(f"Failed to improve. Patience remaining: {patience}") + else: + return current = best diff --git a/pyproject.toml b/pyproject.toml index fd6cdcb15..e80005da5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ dependencies = [ surrogate = [ "numpy", "scikit-learn>=1.3.0", - "botorch>=0.14.0" ] dev = [ "expecttest", From 33803df9a63d41b8f80ca3bb17c23c76154549dd Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 11:29:25 -0800 Subject: [PATCH 26/36] LFBO tests --- helion/autotuner/surrogate_pattern_search.py | 20 ++++++++++---------- test/test_autotuner.py | 10 +++++----- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index 708b8d3a2..a47f2651a 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -117,16 +117,6 @@ def __init__( self.num_neighbors = num_neighbors self.radius = radius - self.model = RandomForestClassifier( # type: ignore[misc] - criterion="log_loss", - random_state=42, - n_estimators=100, - max_depth=15, - min_samples_split=2, - min_samples_leaf=1, - n_jobs=-1, - ) - self.train_X = [] self.train_Y = [] @@ -151,6 +141,15 @@ def _fit_surrogate(self) -> None: pos_weights = pos_weights / normalizing_factor sample_weight = np.where(train_Y < train_Y_quantile, pos_weights, 1.0) # type: ignore[union-attr] + self.model = RandomForestClassifier( # type: ignore[misc] + criterion="log_loss", + random_state=42, + n_estimators=100, + max_depth=15, + min_samples_split=2, + min_samples_leaf=1, + n_jobs=-1, + ) self.model.fit(train_X, train_labels, sample_weight=sample_weight) def _surrogate_select( @@ -165,6 +164,7 @@ def _surrogate_select( if scores.shape[1] == 2: # type: ignore[union-attr] scores = scores[:, 1] # type: ignore[index] elif scores.shape[1] == 1: # type: ignore[union-attr] + # If probabilities are all 1, then the model outputs a 1D vector. 
scores = scores[:, 0] # type: ignore[index] else: raise ValueError("Unexpected shape for scores") diff --git a/test/test_autotuner.py b/test/test_autotuner.py index 95049156f..b62e99575 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -35,8 +35,8 @@ from helion._testing import skipIfRocm from helion.autotuner import DESurrogateHybrid from helion.autotuner import DifferentialEvolutionSearch +from helion.autotuner import LFBOPatternSearch from helion.autotuner import PatternSearch -from helion.autotuner import UCBPatternSearch from helion.autotuner.base_search import BaseSearch from helion.autotuner.base_search import PopulationMember from helion.autotuner.config_fragment import BooleanFragment @@ -625,9 +625,9 @@ def diff_count(flat): self.assertEqual(sorted(pair_neighbors), sorted(expected)) def test_ucb_pattern_search_generate_neighbors(self): - """Test UCBPatternSearch._generate_neighbors method.""" + """Test LFBOPatternSearch._generate_neighbors method.""" random.seed(123) - search = UCBPatternSearch.__new__(UCBPatternSearch) + search = LFBOPatternSearch.__new__(LFBOPatternSearch) search.num_neighbors = 50 search.radius = 2 search.config_gen = SimpleNamespace( @@ -667,14 +667,14 @@ def test_ucb_pattern_search_generate_neighbors(self): @skipIfRocm("too slow on rocm") @skip("too slow") - def test_ucb_pattern_search(self): + def test_lfbo_pattern_search(self): args = ( torch.randn([64, 64], device=DEVICE), torch.randn([64, 64], device=DEVICE), ) bound_kernel = basic_kernels.add.bind(args) random.seed(123) - best = UCBPatternSearch( + best = LFBOPatternSearch( bound_kernel, args, initial_population=10, From b6c24cc9fc2ee23dc0e4b0880e2e4a1119c3f0db Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 11:35:30 -0800 Subject: [PATCH 27/36] LFBO better docstring --- helion/autotuner/surrogate_pattern_search.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index a47f2651a..1f81f597b 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -53,11 +53,13 @@ class LFBOPatternSearch(PatternSearch): - Retrain the classifier on all observed data (not incremental) - Update search trajectories based on new results - The weighted classification model learns an acquisition function. Namely it helps - to identify which configs maximize expected improvement over the current best config. - Compared to fitting a surrogate to fit the config performances themselves, - since this method is based on classification, it can also incorporate configs - that timeout or have unacceptable accuracy. + The weighted classification model learns to identify which configs maximize + expected improvement over the current best config. Compared to fitting a surrogate + to the config performances themselves, since this method is based on classification, + it can also learn from configs that time out or have unacceptable accuracy. + + References: + - Song, J., et al. (2022). "A General Recipe for Likelihood-free Bayesian Optimization." Args: kernel: The kernel to be autotuned.
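To make the quantile labeling and margin weighting described in the docstring above concrete, here is a minimal standalone sketch of the scheme (illustrative only, not part of the patch; the function name and the sample latencies are invented for the example):

    import numpy as np

    def lfbo_labels_and_weights(latencies, quantile=0.3):
        # Configs faster than the quantile threshold are labeled "good" (class 1).
        threshold = np.quantile(latencies, quantile)
        labels = (latencies < threshold).astype(float)
        # Positive samples are weighted by how far they beat the threshold, so the
        # classifier's predicted probability behaves like an expected-improvement score.
        margins = np.maximum(0.0, threshold - latencies)
        positive = margins[margins > 0.0]  # assumes at least one config beats the threshold
        weights = np.where(labels == 1.0, margins / positive.mean(), 1.0)
        return labels, weights

    latencies = np.array([0.8, 1.0, 1.2, 2.0, 5.0])  # benchmark times in ms (lower is better)
    labels, weights = lfbo_labels_and_weights(latencies)
    # labels  -> [1., 1., 0., 0., 0.]: the two fastest configs are the positives
    # weights -> the 0.8 ms config gets the largest weight, as in _fit_surrogate

This mirrors the train_labels / pos_weights / sample_weight computation in _fit_surrogate; because the targets are binary labels rather than raw latencies, configs that crash or time out can simply be labeled as class 0.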
From e3106a81854bf27cc98dc0d46289bb46dd59b692 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 11:38:08 -0800 Subject: [PATCH 28/36] LFBO remove patience feature --- helion/autotuner/surrogate_pattern_search.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index 1f81f597b..10970a76e 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -385,7 +385,6 @@ def _pruned_pattern_search_from( Only keep self.frac_selected of the neighbors generated from the current search_copy. Filter them using the GaussianProcess. """ - patience = 0 for _ in range(self.max_generations): candidates = [current] all_neighbors = self._generate_neighbors(current.flat_values) @@ -404,9 +403,5 @@ def _pruned_pattern_search_from( yield candidates # yield new population to benchmark in parallel best = min(candidates, key=performance) if self._check_early_stopping(best, current): - if patience > 0: - patience -= 1 - self.log.debug(f"Failed to improve. Patience remaining: {patience}") - else: - return + return current = best From 1c810f074c25435874768ce459cf8c0e3e3a2539 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 11:57:37 -0800 Subject: [PATCH 29/36] LFBO imports --- helion/autotuner/surrogate_pattern_search.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index 10970a76e..c608215af 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -22,8 +22,8 @@ try: import numpy as np # type: ignore[import-not-found] - from sklearn.ensemble import ( - RandomForestClassifier, # type: ignore[import-not-found] + from sklearn.ensemble import ( # type: ignore[import-not-found] + RandomForestClassifier, ) HAS_ML_DEPS = True @@ -147,7 +147,6 @@ def _fit_surrogate(self) -> None: criterion="log_loss", random_state=42, n_estimators=100, - max_depth=15, min_samples_split=2, min_samples_leaf=1, n_jobs=-1, From c25362d64765c7ba64f3570b7e823779b95947ca Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 12:04:45 -0800 Subject: [PATCH 30/36] Fix comments --- helion/autotuner/surrogate_pattern_search.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index c608215af..ba18c5c47 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -115,15 +115,15 @@ def __init__( max_generations=max_generations, min_improvement_delta=min_improvement_delta, ) - # Storage for BO + + # Number of neighbors to generate and how many to evaluate self.num_neighbors = num_neighbors self.radius = radius + self.frac_selected = frac_selected + # Save training data self.train_X = [] self.train_Y = [] - - # Initialize config encoder - self.frac_selected = frac_selected self.quantile = quantile def _fit_surrogate(self) -> None: @@ -170,6 +170,7 @@ def _surrogate_select( else: raise ValueError("Unexpected shape for scores") + # Sort candidates by score candidates_sorted = sorted( zip(candidates, scores, strict=True), key=operator.itemgetter(1), reverse=True, # higher scores are better )[:n_sorted] self.log.debug( f"Scoring {len(candidate_X)} neighbors, selecting {(n_sorted / len(candidate_X)) * 100:.0f}% neighbors: {len(candidates_sorted)}" ) @@ -256,11 +257,12 @@ def _autotune(self) -> Config: # Log final statistics for this generation self.log(f"Generation {generation} complete:", self.statistics) - # Save to training data
+ # Update training data for member in self.population: self.train_X.append(self.config_gen.encode_config(member.flat_values)) self.train_Y.append(member.perf) + # Fit model self._fit_surrogate() return self.best.config From b5a65bee94f2cea6e5cb9c65397a1161c89863a3 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 12:07:27 -0800 Subject: [PATCH 31/36] Fix test names --- test/test_autotuner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_autotuner.py b/test/test_autotuner.py index b62e99575..d7f1258df 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -624,7 +624,7 @@ def diff_count(flat): ] self.assertEqual(sorted(pair_neighbors), sorted(expected)) - def test_ucb_pattern_search_generate_neighbors(self): + def test_lfbo_pattern_search_generate_neighbors(self): """Test LFBOPatternSearch._generate_neighbors method.""" random.seed(123) search = LFBOPatternSearch.__new__(LFBOPatternSearch) From 062be0cca0ae620dc5655b1daab57801779b781b Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 12:28:56 -0800 Subject: [PATCH 32/36] remove comma --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 68f6b4f9a..e02adb30c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ [project.optional-dependencies] surrogate = [ "numpy", - "scikit-learn>=1.3.0", + "scikit-learn>=1.3.0" ] dev = [ "expecttest", From fcb070dad8526d36c7a79fefd6e81af33c04e71b Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 12:58:16 -0800 Subject: [PATCH 33/36] Fix comments --- helion/autotuner/surrogate_pattern_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index ba18c5c47..708d794ea 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -303,7 +303,7 @@ def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]: new_flat = [*base] # Copy the base configuration modified_indices = set() - # 1. Sample a block size index and change it by at most 1 + # 1. Sample a block size index and change it by at most radius if self.config_gen.block_size_indices: block_idx = random.choice(self.config_gen.block_size_indices) modified_indices.add(block_idx) @@ -313,7 +313,7 @@ def _generate_neighbors(self, base: FlatConfig) -> list[FlatConfig]: assert isinstance(current_val, int) if isinstance(block_spec, PowerOfTwoFragment): - # Change by at most 1 in log2 space + # Change by at most radius in log2 space new_flat[block_idx] = self._random_log2_neighbor( current_val, radius=self.radius, From 157217250a8103f2c71a5f9b7331413ef3b7274d Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Tue, 18 Nov 2025 12:59:19 -0800 Subject: [PATCH 34/36] Fix comments --- helion/autotuner/surrogate_pattern_search.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index 708d794ea..5070bda0f 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -384,7 +384,14 @@ def _pruned_pattern_search_from( run multiple copies of pattern search in parallel. Only keep self.frac_selected of the neighbors generated from the current - search_copy. Filter them using the GaussianProcess. + search_copy using _surrogate_select. 
+ + Args: + current: The current best configuration. + visited: A set of visited configurations. + + Returns: + A generator that yields the new population at each generation. """ for _ in range(self.max_generations): candidates = [current] From b6b191e9790f30a6d4669c90f68e3ed28f64a941 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Wed, 19 Nov 2025 10:06:09 -0800 Subject: [PATCH 35/36] better lfbo hyperparams --- helion/autotuner/surrogate_pattern_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index 5070bda0f..8be6e643f 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -96,8 +96,8 @@ def __init__( copies: int = PATTERN_SEARCH_DEFAULTS.copies, max_generations: int = PATTERN_SEARCH_DEFAULTS.max_generations, min_improvement_delta: float = 0.001, - frac_selected: float = 0.15, - num_neighbors: int = 300, + frac_selected: float = 0.4, + num_neighbors: int = 100, radius: int = 2, quantile: float = 0.3, ) -> None: From 430166912acebe97501ed3f0a10372ad62d3aa42 Mon Sep 17 00:00:00 2001 From: Ethan Che Date: Wed, 19 Nov 2025 19:49:42 -0800 Subject: [PATCH 36/36] rename to surrogate --- helion/autotuner/surrogate_pattern_search.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/helion/autotuner/surrogate_pattern_search.py b/helion/autotuner/surrogate_pattern_search.py index 8be6e643f..afc9b6865 100644 --- a/helion/autotuner/surrogate_pattern_search.py +++ b/helion/autotuner/surrogate_pattern_search.py @@ -125,6 +125,7 @@ def __init__( self.train_X = [] self.train_Y = [] self.quantile = quantile + self.surrogate = None def _fit_surrogate(self) -> None: train_X = np.array(self.train_X) # type: ignore[union-attr] @@ -143,7 +144,7 @@ def _fit_surrogate(self) -> None: pos_weights = pos_weights / normalizing_factor sample_weight = np.where(train_Y < train_Y_quantile, pos_weights, 1.0) # type: ignore[union-attr] - self.model = RandomForestClassifier( # type: ignore[misc] + self.surrogate = RandomForestClassifier( # type: ignore[misc] criterion="log_loss", random_state=42, n_estimators=100, @@ -151,7 +152,7 @@ def _fit_surrogate(self) -> None: min_samples_leaf=1, n_jobs=-1, ) - self.model.fit(train_X, train_labels, sample_weight=sample_weight) + self.surrogate.fit(train_X, train_labels, sample_weight=sample_weight) def _surrogate_select( self, candidates: list[PopulationMember], n_sorted: int @@ -160,7 +161,7 @@ def _surrogate_select( candidate_X = np.array( # type: ignore[union-attr] [self.config_gen.encode_config(member.flat_values) for member in candidates] ) - scores = self.model.predict_proba(candidate_X) # type: ignore[assignment] + scores = self.surrogate.predict_proba(candidate_X) # type: ignore[assignment] if scores.shape[1] == 2: # type: ignore[union-attr] scores = scores[:, 1] # type: ignore[index]
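To illustrate why the scores.shape[1] == 1 branch above exists, a small standalone sketch (illustrative, not part of the patch): scikit-learn's predict_proba returns one column per class seen during fit, so if every observed config happened to receive the same label, the result has a single column instead of the usual two.

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    rng = np.random.default_rng(0)
    X = rng.random((8, 3))  # stand-in for encoded configs
    y = np.ones(8)          # degenerate case: every config labeled "good"

    clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
    print(clf.predict_proba(X).shape)  # (8, 1) -- hence the scores[:, 0] fallback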