98 changes: 68 additions & 30 deletions torch/_inductor/async_compile.py
@@ -7,6 +7,7 @@
import multiprocessing
import os
import sys
import hashlib
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
@@ -18,15 +19,16 @@
from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
from torch._inductor import config
from torch._inductor.codecache import (
_load_triton_kernel_from_source,
CodeCacheFuture,
CppCodeCache,
CppPythonBindingsCodeCache,
CUDACodeCache,
HalideCodeCache,
LambdaFuture,
ROCmCodeCache,
TritonCodeCache,
TritonFuture,
torch_key,
code_hash,
)
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
@@ -39,9 +41,9 @@
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import has_triton_package


if TYPE_CHECKING:
from torch._inductor.runtime.hints import HalideMeta
from torch._inductor.runtime.triton_heuristics import CachingAutotuner

# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
@@ -128,12 +130,45 @@
config.compile_threads = config.decide_compile_threads()
return config.compile_threads


@clear_on_fresh_inductor_cache
@functools.lru_cache(None)
def get_future_cache():
return {}
class CompiledTritonKernels:
    """
    In-memory cache for storing compiled triton kernels.

    Each triton kernel is keyed by the hash of its source code. Each value stored
    in the cache is a return value of AsyncCompile.triton().

    Currently, the cache stores Future objects, but it should be generalizable
    to any kind of compiled kernel.
    """

    _cache = {}

Check failure on line 143 in torch/_inductor/async_compile.py (GitHub Actions / lintrunner-noclang): MYPY [var-annotated]: Need type annotation for "_cache" (hint: "_cache: dict[<type>, <type>] = ...")
    @staticmethod
    def key(kernel_src: str):
        """
        Generates a cache key from the kernel's source code.
        """
        return code_hash(kernel_src, extra=str(torch_key()))

    @staticmethod
    def save(kernel_src: str, future: LambdaFuture):
        """
        Saves a compiled triton kernel to the cache.

        TODO: We store a LambdaFuture because that's what async_compile.triton returns,
        but the real type we want to return here is an abstract triton kernel.

        TODO: The source code here is not just the kernel's source code; it also includes
        the inductor preamble, etc., so the key could be less strict.
        """
        key = CompiledTritonKernels.key(kernel_src)
        CompiledTritonKernels._cache[key] = future

    @staticmethod
    def get(kernel_src: str, default: Any) -> LambdaFuture:
        key = CompiledTritonKernels.key(kernel_src)
        return CompiledTritonKernels._cache.get(key, default)

    @staticmethod
    def cache_clear():
        CompiledTritonKernels._cache = {}

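For illustration, here is a minimal, self-contained sketch of the source-keyed cache above. The class name, the plain `concurrent.futures.Future`, and the use of `hashlib.sha256` are stand-ins; the real implementation keys on `code_hash(kernel_src, extra=str(torch_key()))` and stores `LambdaFuture` objects.

```python
import hashlib
from concurrent.futures import Future
from typing import Dict, Optional


class InMemoryKernelCache:
    """Illustrative stand-in for CompiledTritonKernels: a process-wide dict keyed by source hash."""

    _cache: Dict[str, Future] = {}

    @staticmethod
    def key(kernel_src: str) -> str:
        # Stand-in for code_hash(kernel_src, extra=str(torch_key())).
        return hashlib.sha256(kernel_src.encode("utf-8")).hexdigest()

    @classmethod
    def save(cls, kernel_src: str, future: Future) -> None:
        cls._cache[cls.key(kernel_src)] = future

    @classmethod
    def get(cls, kernel_src: str, default: Optional[Future] = None) -> Optional[Future]:
        return cls._cache.get(cls.key(kernel_src), default)

    @classmethod
    def cache_clear(cls) -> None:
        cls._cache = {}
```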
class AsyncCompile:
def __init__(self) -> None:
Expand Down Expand Up @@ -208,51 +243,54 @@
)

def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
if future := CompiledTritonKernels.get(source_code, None):
counters["inductor"]["async_compile_cache_hit"] += 1
return future

counters["inductor"]["async_compile_cache_miss"] += 1

kernel_code_log.info("Triton Kernel:\n%s", source_code)
_compile_start()
_set_triton_ptxas_path()

if os.environ.get("TRITON_INTERPRET", "0") == "1":
return getattr(
torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
)

kernel = TritonCodeCache.load(kernel_name, source_code)
if self.use_process_pool():
set_feature_use("parallel_compile_post_warmup", True)
load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)
is_parallel = self.use_process_pool()
set_feature_use("parallel_compile_post_warmup", is_parallel)
if is_parallel:
# We want to support changing these env vars after (and while) the
# process pool is running, so pass them to the subprocess to reset.
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}

future_cache = get_future_cache()

if future := future_cache.get(source_code, None):
counters["inductor"]["async_compile_cache_hit"] += 1
return future

counters["inductor"]["async_compile_cache_miss"] += 1
future = TritonFuture(
kernel,
self.process_pool().submit(
_worker_compile_triton,
kernel._reload_in_subproc,
extra_env,
),
task = self.process_pool().submit(
_worker_compile_triton,
load_kernel,
extra_env,
)
future_cache[source_code] = future
def get_result() -> CachingAutotuner:
kernel = task.result()
kernel.precompile(warm_cache_only=False, reload_in_parent=load_kernel)
return kernel
future = LambdaFuture(get_result, future=task)
CompiledTritonKernels.save(source_code, future)
return future

else:
set_feature_use("parallel_compile_post_warmup", False)
with dynamo_timed(
"async_compile.precompile",
log_pt2_compile_event=True,
dynamo_compile_column_us="triton_compile_time_us",
log_waitcounter=True,
):
kernel.precompile()
return kernel
_set_triton_ptxas_path()
kernel = load_kernel()
kernel.precompile(warm_cache_only=False)
return kernel
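A condensed, runnable sketch of the hit/miss flow in `triton()` above; the thread pool, the `"kernel"` string, and the counter dict are stand-ins for the process pool, the `CachingAutotuner`, and dynamo's counters.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict

_futures: Dict[str, Any] = {}  # stands in for CompiledTritonKernels
_counters = {"async_compile_cache_hit": 0, "async_compile_cache_miss": 0}


def triton_like(source_code: str, compile_fn: Callable[[], Any], pool: ThreadPoolExecutor) -> Any:
    """Reuse the in-flight future when the same source was already submitted."""
    if (future := _futures.get(source_code)) is not None:
        _counters["async_compile_cache_hit"] += 1
        return future

    _counters["async_compile_cache_miss"] += 1
    future = pool.submit(compile_fn)   # the real code submits _worker_compile_triton
    _futures[source_code] = future     # the real code stores a LambdaFuture instead
    return future


with ThreadPoolExecutor() as pool:
    f1 = triton_like("kernel-src", lambda: "kernel", pool)
    f2 = triton_like("kernel-src", lambda: "kernel", pool)
    assert f1 is f2 and f1.result() == "kernel"  # second call is a cache hit
```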

def multi_kernel(self, *args, **kwargs) -> Any:
from torch._inductor.codegen.multi_kernel import MultiKernelCall
33 changes: 7 additions & 26 deletions torch/_inductor/codecache.py
@@ -68,7 +68,6 @@
from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
from torch._inductor.runtime.compile_tasks import (
_module_to_triton_kernel,
_reload_python_module,
_reload_python_module_in_subproc,
)
@@ -2815,10 +2814,10 @@ def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:
return parse_stack_trace(entry)


class TritonCodeCache:
@classmethod
def load(cls, kernel_name: str, source_code: str) -> ModuleType:
return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
def _load_triton_kernel_from_source(
kernel_name: str, source_code: str
) -> CachingAutotuner:
return getattr(PyCodeCache.load(source_code), kernel_name)


def _cuda_compiler() -> Optional[str]:
@@ -3222,30 +3221,12 @@ def result(self) -> Callable[..., Any]:
raise NotImplementedError


class TritonFuture(CodeCacheFuture):
kernel: CachingAutotuner

class LambdaFuture(CodeCacheFuture):
def __init__(
self,
kernel: Any,
future: Optional[Future[Any]],
self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None
) -> None:
self.kernel = kernel
self.future = future

def result(self) -> Callable[..., Any]:
if self.future is not None:
# If the worker failed this will throw an exception.
result = self.future.result()
assert result is None
self.future = None
self.kernel.precompile()
return self.kernel


class LambdaFuture(CodeCacheFuture):
def __init__(self, result_fn: Callable[..., Any]) -> None:
self.result_fn = result_fn
self.future = future

def result(self) -> Callable[..., Any]: # type: ignore[override]
return self.result_fn()
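A small runnable sketch of the pattern `LambdaFuture` enables: wrap an underlying pool task plus a parent-side finishing step (the `precompile` call in the real code) behind a single `result()` call. The class and values here are illustrative stand-ins, not the codecache implementation.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Optional


class DeferredResult:
    """Illustrative stand-in for LambdaFuture: result() runs an arbitrary callable."""

    def __init__(self, result_fn: Callable[..., Any], future: Optional[Future] = None) -> None:
        self.result_fn = result_fn
        self.future = future  # kept so callers can still await/inspect the raw task

    def result(self) -> Any:
        return self.result_fn()


with ThreadPoolExecutor() as pool:
    task = pool.submit(lambda: 21)                          # "worker compiles the kernel"
    fut = DeferredResult(lambda: task.result() * 2, task)   # "parent finishes precompile"
    assert fut.result() == 42
```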
21 changes: 16 additions & 5 deletions torch/_inductor/codegen/triton.py
@@ -13,19 +13,19 @@
from functools import lru_cache
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union


import sympy
from sympy.printing.precedence import PRECEDENCE

import torch
import torch._logging
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import identity, preserve_rng_state
from torch._dynamo.utils import identity, preserve_rng_state, dynamo_timed
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._triton import has_triton_package

from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
@@ -95,6 +95,7 @@
should_unwrap_unspec_arg,
signature_to_meta,
)
from ..async_compile import AsyncCompile


if TYPE_CHECKING:
@@ -110,7 +111,7 @@
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")

async_compile = AsyncCompile()

class OpDtypeSupport:
"""
@@ -3843,7 +3844,6 @@ def iteration_ranges_codegen_header(
else:
code.writeline(f"{x}mask = {entry.name} < {x}numel")


class TritonScheduling(SIMDScheduling):
kernel_type: type[Any] = TritonKernel
backend_features = OrderedSet(
@@ -3939,9 +3939,20 @@ def define_kernel(self, src_code, node_schedule, kernel):
src_code = src_code.replace("#pragma CMT", "#")

_basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py")

compile_wrapper = IndentedBuffer()
# TODO: Refactor this code so that instead of calling async_compile.triton after the entire code has been generated, we
# kick off the worker process here to start compiling the code, and save the Future object to await later.

# If it's a TritonBundler cache hit, we can avoid that altogether and return the compiled kernel directly.
if async_compile.use_process_pool():
            # The process pool is warm, so we can shell out to workers right away. This
            # lets us save the result in async_compile.CompiledTritonKernels, so that
            # the second time async_compile.triton is called, we do no work.
async_compile.triton(subs_name, src_code)

compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")


compile_wrapper.splice(src_code, strip=True)
current_device = V.graph.get_current_device_or_throw()
compile_wrapper.writeline(f"''', device_str='{current_device.type}')")
7 changes: 5 additions & 2 deletions torch/_inductor/runtime/compile_tasks.py
@@ -68,7 +68,10 @@ def _set_triton_ptxas_path() -> None:

def _worker_compile_triton(
load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str]
) -> None:
) -> CachingAutotuner:
_set_triton_ptxas_path()
os.environ.update(extra_env)
load_kernel().precompile(warm_cache_only=True)
kernel = load_kernel()
kernel.precompile(warm_cache_only=True)
kernel.prepare_for_pickle()
return kernel
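A self-contained sketch of the worker round trip this change enables: the subprocess builds and warms the kernel, strips unpicklable state, and returns the object to the parent instead of only warming the on-disk cache. `FakeKernel` and the cache-dir value are illustrative stand-ins for `CachingAutotuner` and the real environment variables.

```python
import os
from concurrent.futures import ProcessPoolExecutor
from typing import Callable, Dict


class FakeKernel:
    """Illustrative, picklable stand-in for CachingAutotuner."""

    def __init__(self) -> None:
        self.precompiled = False

    def precompile(self, warm_cache_only: bool = False) -> None:
        self.precompiled = True

    def prepare_for_pickle(self) -> None:
        pass  # the real kernel drops unpicklable state here


def worker_compile(load_kernel: Callable[[], FakeKernel], extra_env: Dict[str, str]) -> FakeKernel:
    os.environ.update(extra_env)
    kernel = load_kernel()
    kernel.precompile(warm_cache_only=True)
    kernel.prepare_for_pickle()
    return kernel  # shipped back to the parent process


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=1) as pool:
        fut = pool.submit(worker_compile, FakeKernel, {"TRITON_CACHE_DIR": "/tmp/triton"})
        assert fut.result().precompiled
```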
10 changes: 6 additions & 4 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -268,6 +268,10 @@ def precompile(

def _precompile_worker(self):
        if self.compile_results:
            for result in self.compile_results:
                TritonBundler.put(
                    triton_hash_to_path_key(result.kernel.hash),
                    self.triton_meta.get("device", 0),
                )
            return
assert not self.launchers
if not self.configs:
@@ -415,6 +419,7 @@ def _make_launchers(self):
for result in self.compile_results:
try:
launchers.append(result.make_launcher())

except (OutOfResources, PTXASError) as e:
exc = e
if len(launchers) == 0:
@@ -519,10 +524,7 @@ def _precompile_config(self, cfg: Config) -> TritonCompileResult:
compile_meta,
)
raise

TritonBundler.put(
triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
)
TritonBundler.put(triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0))
return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta)

def _get_args_with_constexprs(self, args, launcher):
8 changes: 4 additions & 4 deletions torch/_inductor/scheduler.py
@@ -36,7 +36,7 @@
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.codecache import PyCodeCache, TritonFuture
from torch._inductor.codecache import PyCodeCache, LambdaFuture
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
@@ -2734,7 +2734,7 @@ def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None:

def compile_kernel(
nodes: Sequence[BaseSchedulerNode],
) -> tuple[Optional[TritonFuture], ModuleType]:
) -> tuple[Optional[LambdaFuture], ModuleType]:
src_code = self.generate_kernel_code_from_nodes(
nodes, benchmark_kernel=True
)
@@ -2743,7 +2743,7 @@ def compile_kernel(
fut = None
else:
fut = async_compile.triton(kernel_name="triton_", source_code=src_code)
assert isinstance(fut, TritonFuture)
assert isinstance(fut, LambdaFuture)

return (fut, mod)
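To illustrate how the scheduler consumes these futures, here is a toy version of the compile-then-await pattern; the thread pool and string results are stand-ins for `async_compile.triton` and the benchmark modules.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Optional, Tuple


def compile_kernel_sketch(src: str, pool: ThreadPoolExecutor) -> Tuple[Optional[Any], str]:
    """Mirror of compile_kernel above: returns (future-or-None, module handle)."""
    if not src:  # stands in for the cases that return no future
        return None, ""
    fut = pool.submit(lambda: f"compiled:{src}")  # the real code gets a LambdaFuture
    return fut, f"module-for-{src}"


with ThreadPoolExecutor() as pool:
    # Kick off all candidate kernels first, then await only when benchmarking.
    future_choices: List[Tuple[Optional[Any], str]] = [
        compile_kernel_sketch(src, pool) for src in ["choice_a", "choice_b", ""]
    ]
    compiled = [fut.result() if fut is not None else mod for fut, mod in future_choices]
```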

@@ -2772,7 +2772,7 @@ def compile_kernel(
)

# Start compiling choices in parallel
future_choices: List[tuple[Any, Optional[TritonFuture], ModuleType]] = []
future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
triton_choices = 0
for choice, unfused_time in sorted(
choice_timings.items(), key=lambda x: x[1]