diff --git a/docs/conf.py b/docs/conf.py
index c22f18285..13f2f8506 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -129,10 +129,11 @@ def connect(self, event: str, callback: Callable[..., None]) -> None:
 }
 
 theme_variables = pytorch_sphinx_theme2.get_theme_variables()
-templates_path = [
-    "_templates",
-    os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
-]
+templates_path = ["_templates"]
+if pytorch_sphinx_theme2.__file__ is not None:
+    templates_path.append(
+        os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates")
+    )
 
 html_context = {
     "theme_variables": theme_variables,
diff --git a/helion/autotuner/base_cache.py b/helion/autotuner/base_cache.py
index 361e83791..574001938 100644
--- a/helion/autotuner/base_cache.py
+++ b/helion/autotuner/base_cache.py
@@ -153,6 +153,10 @@ def get(self) -> Config | None:
     def put(self, config: Config) -> None:
         raise NotImplementedError
 
+    def _get_cache_info_message(self) -> str:
+        """Return a message describing where the cache is and how to clear it."""
+        return ""
+
     def autotune(self) -> Config:
         if os.environ.get("HELION_SKIP_CACHE", "") not in {"", "0", "false", "False"}:
             return self.autotuner.autotune()
@@ -160,11 +164,17 @@ def autotune(self) -> Config:
         if (config := self.get()) is not None:
             counters["autotune"]["cache_hit"] += 1
             log.debug("cache hit: %s", str(config))
+            cache_info = self._get_cache_info_message()
+            self.autotuner.log(
+                f"Found cached config for {self.kernel.kernel.name}, skipping autotuning.\n{cache_info}"
+            )
             return config
 
         counters["autotune"]["cache_miss"] += 1
         log.debug("cache miss")
 
+        self.autotuner.log("Starting autotuning process, this may take a while...")
+
         config = self.autotuner.autotune()
         self.put(config)
 
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index dd5ea72ea..252771e9c 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -29,7 +29,6 @@
 import torch.multiprocessing as mp
 from torch.utils._pytree import tree_flatten
 from torch.utils._pytree import tree_map
-from tqdm.rich import tqdm
 from triton.testing import do_bench
 
 from .. import exc
@@ -40,6 +39,7 @@
 from .logger import LambdaLogger
 from .logger import classify_triton_exception
 from .logger import format_triton_compile_failure
+from .progress_bar import iter_with_progress
 
 log = logging.getLogger(__name__)
 
@@ -321,15 +321,14 @@ def parallel_benchmark(
         else:
             is_workings = [True] * len(configs)
         results = []
-        iterator = zip(configs, fns, is_workings, strict=True)
-        if self.settings.autotune_progress_bar:
-            iterator = tqdm(
-                iterator,
-                total=len(configs),
-                desc=desc,
-                unit="config",
-                disable=not self.settings.autotune_progress_bar,
-            )
+
+        # Render a progress bar only when the user requested it.
+        iterator = iter_with_progress(
+            zip(configs, fns, is_workings, strict=True),
+            total=len(configs),
+            description=desc,
+            enabled=self.settings.autotune_progress_bar,
+        )
         for config, fn, is_working in iterator:
             if is_working:
                 # benchmark one-by-one to avoid noisy results
diff --git a/helion/autotuner/benchmarking.py b/helion/autotuner/benchmarking.py
index e8d6313b0..b9cacbc1d 100644
--- a/helion/autotuner/benchmarking.py
+++ b/helion/autotuner/benchmarking.py
@@ -4,9 +4,10 @@
 import statistics
 from typing import Callable
 
-from tqdm.rich import tqdm
 from triton import runtime
 
+from .progress_bar import iter_with_progress
+
 
 def interleaved_bench(
     fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None
@@ -38,9 +39,15 @@ def interleaved_bench(
     ]
 
     di.synchronize()
-    iterator = range(repeat)
-    if desc is not None:
-        iterator = tqdm(iterator, desc=desc, total=repeat, unit="round")
+
+    # When a description is supplied, we show a progress bar so the user can
+    # track the repeated benchmarking loop.
+    iterator = iter_with_progress(
+        range(repeat),
+        total=repeat,
+        description=desc,
+        enabled=desc is not None,
+    )
     for i in iterator:
         for j in range(len(fns)):
             clear_cache()
diff --git a/helion/autotuner/differential_evolution.py b/helion/autotuner/differential_evolution.py
index b1655b301..41a6a3fc1 100644
--- a/helion/autotuner/differential_evolution.py
+++ b/helion/autotuner/differential_evolution.py
@@ -95,7 +95,8 @@ def _autotune(self) -> Config:
         )
         self.initial_two_generations()
         for i in range(2, self.max_generations):
+            self.log(f"Generation {i} starting")
             replaced = self.evolve_population()
-            self.log(f"Generation {i}: replaced={replaced}", self.statistics)
+            self.log(f"Generation {i} complete: replaced={replaced}", self.statistics)
         self.rebenchmark_population()
         return self.best.config
diff --git a/helion/autotuner/local_cache.py b/helion/autotuner/local_cache.py
index 99afbfec6..22543dd71 100644
--- a/helion/autotuner/local_cache.py
+++ b/helion/autotuner/local_cache.py
@@ -94,6 +94,10 @@ def put(self, config: Config) -> None:
         path = self._get_local_cache_path()
         config.save(path)
 
+    def _get_cache_info_message(self) -> str:
+        cache_dir = self._get_local_cache_path().parent
+        return f"Cache directory: {cache_dir}. To run autotuning again, delete the cache directory or set HELION_SKIP_CACHE=1."
+
 
 class StrictLocalAutotuneCache(LocalAutotuneCache):
     """
diff --git a/helion/autotuner/pattern_search.py b/helion/autotuner/pattern_search.py
index de23f6218..a1b322aed 100644
--- a/helion/autotuner/pattern_search.py
+++ b/helion/autotuner/pattern_search.py
@@ -46,7 +46,7 @@ def __init__(
 
     def _autotune(self) -> Config:
         self.log(
-            f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}"
+            f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}, max_generations={self.max_generations}"
         )
         visited = set()
         self.population = []
@@ -59,7 +59,7 @@ def _autotune(self) -> Config:
             self.population.append(member)
         self.parallel_benchmark_population(self.population, desc="Initial population")
         # again with higher accuracy
-        self.rebenchmark_population(self.population, desc="Initial rebench")
+        self.rebenchmark_population(self.population, desc="Verifying initial results")
         self.population.sort(key=performance)
         starting_points = []
         for member in self.population[: self.copies]:
@@ -88,21 +88,25 @@ def _autotune(self) -> Config:
                     new_population[id(member)] = member
             if num_active == 0:
                 break
+
+            # Log generation header before compiling/benchmarking
+            self.log(
+                f"Generation {generation} starting: {num_neighbors} neighbors, {num_active} active search path(s)"
+            )
+
             self.population = [*new_population.values()]
             # compile any unbenchmarked members in parallel
             unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
             if unbenchmarked:
                 self.parallel_benchmark_population(
-                    unbenchmarked, desc=f"Gen {generation} neighbors"
+                    unbenchmarked, desc=f"Generation {generation}: Exploring neighbors"
                 )
             # higher-accuracy rebenchmark
             self.rebenchmark_population(
-                self.population, desc=f"Gen {generation} rebench"
-            )
-            self.log(
-                f"Generation {generation}, {num_neighbors} neighbors, {num_active} active:",
-                self.statistics,
+                self.population, desc=f"Generation {generation}: Verifying top configs"
             )
+            # Log final statistics for this generation
+            self.log(f"Generation {generation} complete:", self.statistics)
         return self.best.config
 
     def _pattern_search_from(
diff --git a/helion/autotuner/progress_bar.py b/helion/autotuner/progress_bar.py
new file mode 100644
index 000000000..f69896837
--- /dev/null
+++ b/helion/autotuner/progress_bar.py
@@ -0,0 +1,70 @@
+"""Progress-bar utilities used by the autotuner.
+
+We rely on `rich` to render colored, full-width progress bars that
+show the description, percentage complete, and how many items have been
+processed.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing import TypeVar
+
+from rich.progress import BarColumn
+from rich.progress import MofNCompleteColumn
+from rich.progress import Progress
+from rich.progress import ProgressColumn
+from rich.progress import TextColumn
+from rich.text import Text
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+    from collections.abc import Iterator
+
+    from rich.progress import Task
+
+T = TypeVar("T")
+
+
+class SpeedColumn(ProgressColumn):
+    """Render the processing speed in configs per second."""
+
+    def render(self, task: Task) -> Text:
+        return Text(
+            f"{task.speed:.1f} configs/s" if task.speed is not None else "- configs/s",
+            style="magenta",
+        )
+
+
+def iter_with_progress(
+    iterable: Iterable[T], *, total: int, description: str | None = None, enabled: bool
+) -> Iterator[T]:
+    """Yield items from *iterable*, optionally showing a progress bar.
+
+    Parameters
+    ----------
+    iterable:
+        Any iterable whose items should be yielded.
+    total:
+        Total number of items expected from the iterable.
+    description:
+        Text displayed on the left side of the bar. Defaults to ``"Progress"``.
+    enabled:
+        When ``False`` the items are yielded directly without a progress bar;
+        when ``True`` a Rich progress bar is rendered.
+    """
+    if not enabled:
+        yield from iterable
+        return
+
+    if description is None:
+        description = "Progress"
+
+    with Progress(
+        TextColumn("[progress.description]{task.description}"),
+        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+        BarColumn(bar_width=None, complete_style="yellow", finished_style="green"),
+        MofNCompleteColumn(),
+        SpeedColumn(),
+    ) as progress:
+        yield from progress.track(iterable, total=total, description=description)
diff --git a/requirements.txt b/requirements.txt
index 894173feb..965858186 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,5 +4,4 @@ pre-commit
 filecheck
 expecttest
 numpy
-tqdm
 rich
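Example usage of the new helper (a sketch, not part of the patch above; assumes the patch is applied, and the loop body is a hypothetical stand-in for real benchmarking work):

    import time

    from helion.autotuner.progress_bar import iter_with_progress

    items = list(range(10))
    for item in iter_with_progress(
        items,
        total=len(items),
        description="Benchmarking",
        enabled=True,  # pass False to disable the bar with no other code changes
    ):
        time.sleep(0.1)  # placeholder for per-config compile/benchmark work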