Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,11 @@ def connect(self, event: str, callback: Callable[..., None]) -> None:
}

theme_variables = pytorch_sphinx_theme2.get_theme_variables()
templates_path = [
"_templates",
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
]
templates_path = ["_templates"]
if pytorch_sphinx_theme2.__file__ is not None:
templates_path.append(
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates")
)

html_context = {
"theme_variables": theme_variables,
Expand Down
10 changes: 10 additions & 0 deletions helion/autotuner/base_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,28 @@ def get(self) -> Config | None:
def put(self, config: Config) -> None:
raise NotImplementedError

def _get_cache_info_message(self) -> str:
"""Return a message describing where the cache is and how to clear it."""
return ""

def autotune(self) -> Config:
if os.environ.get("HELION_SKIP_CACHE", "") not in {"", "0", "false", "False"}:
return self.autotuner.autotune()

if (config := self.get()) is not None:
counters["autotune"]["cache_hit"] += 1
log.debug("cache hit: %s", str(config))
cache_info = self._get_cache_info_message()
self.autotuner.log(
f"Found cached config for {self.kernel.kernel.name}, skipping autotuning.\n{cache_info}"
)
return config

counters["autotune"]["cache_miss"] += 1
log.debug("cache miss")

self.autotuner.log("Starting autotuning process, this may take a while...")

config = self.autotuner.autotune()

self.put(config)
Expand Down
19 changes: 9 additions & 10 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import torch.multiprocessing as mp
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_map
from tqdm.rich import tqdm
from triton.testing import do_bench

from .. import exc
Expand All @@ -40,6 +39,7 @@
from .logger import LambdaLogger
from .logger import classify_triton_exception
from .logger import format_triton_compile_failure
from .progress_bar import iter_with_progress

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -321,15 +321,14 @@ def parallel_benchmark(
else:
is_workings = [True] * len(configs)
results = []
iterator = zip(configs, fns, is_workings, strict=True)
if self.settings.autotune_progress_bar:
iterator = tqdm(
iterator,
total=len(configs),
desc=desc,
unit="config",
disable=not self.settings.autotune_progress_bar,
)

# Render a progress bar only when the user requested it.
iterator = iter_with_progress(
zip(configs, fns, is_workings, strict=True),
total=len(configs),
description=desc,
enabled=self.settings.autotune_progress_bar,
)
for config, fn, is_working in iterator:
if is_working:
# benchmark one-by-one to avoid noisy results
Expand Down
15 changes: 11 additions & 4 deletions helion/autotuner/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import statistics
from typing import Callable

from tqdm.rich import tqdm
from triton import runtime

from .progress_bar import iter_with_progress


def interleaved_bench(
fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None
Expand Down Expand Up @@ -38,9 +39,15 @@ def interleaved_bench(
]

di.synchronize()
iterator = range(repeat)
if desc is not None:
iterator = tqdm(iterator, desc=desc, total=repeat, unit="round")

# When a description is supplied we show a progress bar so the user can
# track the repeated benchmarking loop.
iterator = iter_with_progress(
range(repeat),
total=repeat,
description=desc,
enabled=desc is not None,
)
for i in iterator:
for j in range(len(fns)):
clear_cache()
Expand Down
3 changes: 2 additions & 1 deletion helion/autotuner/differential_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def _autotune(self) -> Config:
)
self.initial_two_generations()
for i in range(2, self.max_generations):
self.log(f"Generation {i} starting")
replaced = self.evolve_population()
self.log(f"Generation {i}: replaced={replaced}", self.statistics)
self.log(f"Generation {i} complete: replaced={replaced}", self.statistics)
self.rebenchmark_population()
return self.best.config
4 changes: 4 additions & 0 deletions helion/autotuner/local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def put(self, config: Config) -> None:
path = self._get_local_cache_path()
config.save(path)

def _get_cache_info_message(self) -> str:
cache_dir = self._get_local_cache_path().parent
return f"Cache directory: {cache_dir}. To run autotuning again, delete the cache directory or set HELION_SKIP_CACHE=1."


class StrictLocalAutotuneCache(LocalAutotuneCache):
"""
Expand Down
20 changes: 12 additions & 8 deletions helion/autotuner/pattern_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(

def _autotune(self) -> Config:
self.log(
f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}"
f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}, max_generations={self.max_generations}"
)
visited = set()
self.population = []
Expand All @@ -59,7 +59,7 @@ def _autotune(self) -> Config:
self.population.append(member)
self.parallel_benchmark_population(self.population, desc="Initial population")
# again with higher accuracy
self.rebenchmark_population(self.population, desc="Initial rebench")
self.rebenchmark_population(self.population, desc="Verifying initial results")
self.population.sort(key=performance)
starting_points = []
for member in self.population[: self.copies]:
Expand Down Expand Up @@ -88,21 +88,25 @@ def _autotune(self) -> Config:
new_population[id(member)] = member
if num_active == 0:
break

# Log generation header before compiling/benchmarking
self.log(
f"Generation {generation} starting: {num_neighbors} neighbors, {num_active} active search path(s)"
)

self.population = [*new_population.values()]
# compile any unbenchmarked members in parallel
unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
if unbenchmarked:
self.parallel_benchmark_population(
unbenchmarked, desc=f"Gen {generation} neighbors"
unbenchmarked, desc=f"Generation {generation}: Exploring neighbors"
)
# higher-accuracy rebenchmark
self.rebenchmark_population(
self.population, desc=f"Gen {generation} rebench"
)
self.log(
f"Generation {generation}, {num_neighbors} neighbors, {num_active} active:",
self.statistics,
self.population, desc=f"Generation {generation}: Verifying top configs"
)
# Log final statistics for this generation
self.log(f"Generation {generation} complete:", self.statistics)
return self.best.config

def _pattern_search_from(
Expand Down
70 changes: 70 additions & 0 deletions helion/autotuner/progress_bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Progress-bar utilities used by the autotuner.

We rely on `rich` to render colored, full-width progress bars that
show the description, percentage complete, and how many items have been
processed.
"""

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TypeVar

from rich.progress import BarColumn
from rich.progress import MofNCompleteColumn
from rich.progress import Progress
from rich.progress import ProgressColumn
from rich.progress import TextColumn
from rich.text import Text

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterator

from rich.progress import Task

T = TypeVar("T")


class SpeedColumn(ProgressColumn):
    """Progress column that reports throughput in configs per second."""

    def render(self, task: Task) -> Text:
        # ``task.speed`` is None until Rich has enough samples to estimate it;
        # show a placeholder dash in that case.
        speed = task.speed
        label = "- configs/s" if speed is None else f"{speed:.1f} configs/s"
        return Text(label, style="magenta")


def iter_with_progress(
    iterable: Iterable[T], *, total: int, description: str | None = None, enabled: bool
) -> Iterator[T]:
    """Yield items from *iterable*, optionally showing a progress bar.

    Parameters
    ----------
    iterable:
        Any iterable whose items should be yielded.
    total:
        Total number of items expected from the iterable.
    description:
        Text displayed on the left side of the bar. Defaults to ``"Progress"``.
    enabled:
        When ``False`` the iterable's iterator is returned directly — no
        generator wrapper, no Rich machinery, genuinely zero overhead.
        When ``True`` a Rich progress bar is rendered around iteration.
    """
    # BUGFIX: the previous implementation was itself a generator, so even the
    # disabled path wrapped the iterable in an extra generator frame (the
    # docstring's "returned unchanged / zero overhead" claim was false), and
    # the enabled path deferred opening the Progress context until the first
    # next() call. Dispatching here keeps the disabled path a true passthrough.
    if not enabled:
        return iter(iterable)
    return _track_with_progress(
        iterable,
        total=total,
        description="Progress" if description is None else description,
    )


def _track_with_progress(
    iterable: Iterable[T], *, total: int, description: str
) -> Iterator[T]:
    """Generator that drives a Rich progress bar while yielding items."""
    with Progress(
        TextColumn("[progress.description]{task.description}"),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        BarColumn(bar_width=None, complete_style="yellow", finished_style="green"),
        MofNCompleteColumn(),
        SpeedColumn(),
    ) as progress:
        yield from progress.track(iterable, total=total, description=description)
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ pre-commit
filecheck
expecttest
numpy
tqdm
rich
Loading