diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index 1fc362d547788..9731a081f0414 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -1,13 +1,19 @@
 # Owner(s): ["module: inductor"]
+from typing import Optional
 
 import torch
 from torch import multiprocessing as mp
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._inductor import config
+from torch._inductor.autotune_process import BenchmarkRequest, TuningProcessPool
 from torch._inductor.graph import GraphLowering
 from torch._inductor.ir import Buffer, FixedLayout
 from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
-from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
+from torch._inductor.select_algorithm import (
+    AlgorithmSelectorCache,
+    ChoiceCaller,
+    TritonTemplateCaller,
+)
 from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -118,7 +124,8 @@ def test_benchmark_choice_fail_in_subproc(self):
         self.assertNotEqual(0, child.exitcode)
 
     @parametrize("autotune_in_subproc", (True, False))
-    def test_max_autotune_mm_plus_mm(self, autotune_in_subproc):
+    @parametrize("autotune_multi_device", (True, False))
+    def test_max_autotune_mm_plus_mm(self, autotune_in_subproc, autotune_multi_device):
         """
         This crash previously due to a triton issue:
         https://github.com/openai/triton/issues/1298 . With autotuning in
         subprocess, we don't crash anymore.
         """
@@ -134,7 +141,11 @@ def mm_plus_mm(a, b, c, d):
         d = torch.randn(k, n).cuda()
 
         with config.patch(
-            {"max_autotune": True, "autotune_in_subproc": autotune_in_subproc}
+            {
+                "max_autotune": True,
+                "autotune_in_subproc": autotune_in_subproc,
+                "autotune_multi_device": autotune_multi_device,
+            }
         ):
             torch.compile(mm_plus_mm)(a, b, c, d)
 
@@ -318,6 +329,53 @@ def fn(
         torch.testing.assert_close(y1, y1_expected)
 
 
+class TestBenchmarkRequest(BenchmarkRequest):
+    def __init__(self, value: Optional[float] = None) -> None:
+        self.value = value
+
+    def benchmark(
+        self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None
+    ) -> float:
+        if self.value is None:
+            raise Exception("Failed to run")
+        return self.value
+
+
+class TestTritonTemplateCaller(TritonTemplateCaller):
+    def __init__(self, bmreq: TestBenchmarkRequest):
+        self.bmreq = bmreq
+
+    def __str__(self) -> str:
+        return "test"
+
+
+class TestTuningProcess(TestCase):
+    def test_tuning_pool(self):
+        # Use only one device:
+        with config.patch({"autotune_multi_device": False}):
+            tuning_pool = TuningProcessPool()
+            tuning_pool.initialize()
+
+            # First cause the tuning process to crash.
+            bmreq = TestBenchmarkRequest(value=None)
+            choice = TestTritonTemplateCaller(bmreq)
+
+            timings = tuning_pool.benchmark([choice])
+            self.assertTrue(choice in timings)
+            self.assertEqual(timings[choice], float("inf"))
+
+            # Then send another request and make sure the sub-process
+            # has restarted and is operational.
+            value = 3.14
+            choice.bmreq.value = value
+
+            timings = tuning_pool.benchmark([choice])
+            self.assertTrue(choice in timings)
+            self.assertEqual(timings[choice], value)
+
+            tuning_pool.terminate()
+
+
 if __name__ == "__main__":
     if HAS_CUDA:
         run_tests()
diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py
index 6392706c33541..011109f6f61f2 100644
--- a/test/inductor/test_select_algorithm.py
+++ b/test/inductor/test_select_algorithm.py
@@ -20,7 +20,7 @@ def patches(fn):
     def skip_cache(self, choices, name, key, generate):
-        return {choice: generate(choice) for choice in choices}
+        return generate(choices)
 
     for patcher in [
         dynamo_config.patch(verbose=True),
diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py
index 92f5a314537db..d09da015a2c25 100644
--- a/torch/_inductor/autotune_process.py
+++ b/torch/_inductor/autotune_process.py
@@ -1,8 +1,12 @@
+from __future__ import annotations
+
 import dataclasses
 import logging
+import os
 import queue
 import time
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from multiprocessing.process import BaseProcess
 from multiprocessing.queues import Queue
 from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
@@ -17,9 +21,11 @@
 if TYPE_CHECKING:
     from torch._inductor.select_algorithm import TritonTemplateCaller
 
+from . import config
 from .utils import do_bench
 from .virtualized import V
 
+CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
 EXIT_HANDLER_REGISTERED = False
 
 log = logging.getLogger(__name__)
@@ -42,29 +48,30 @@ class TuningProcess:
     requests and return results.
     """
 
+    device: Optional[int] = None
     process: Optional[BaseProcess] = None
-    request_queue: Optional["Queue[Any]"] = None
-    response_queue: Optional["Queue[Any]"] = None
+    request_queue: Optional[Queue[Any]] = None
+    response_queue: Optional[Queue[Any]] = None
 
     @staticmethod
     def process_main(
-        request_queue: "Queue[Any]",
-        response_queue: "Queue[Any]",
+        device: Optional[int],
+        request_queue: Queue[Any],
+        response_queue: Queue[Any],
     ) -> None:
         """
         Entry point for the child process.
         """
-        log.debug("Entering TuningProcess child main")
+        log.debug("Entering TuningProcess child main: %s", device)
+        if device is not None:
+            os.environ[CUDA_VISIBLE_DEVICES] = str(device)
         try:
             TuningProcess.workloop(request_queue, response_queue)
         except Exception as ex:
             log.exception("Exception in TuningProcess: %s", ex)
 
     @staticmethod
-    def workloop(
-        request_queue: "Queue[Any]",
-        response_queue: "Queue[Any]",
-    ) -> None:
+    def workloop(request_queue: Queue[Any], response_queue: Queue[Any]) -> None:
         """
         Work loop for the benchmarking subprocess.
         """
@@ -99,6 +106,8 @@ def clear(self) -> None:
     def initialize(self) -> None:
         """
         Create child process, request/response queues, and do the warm up.
+        Set the environment to make only the provided GPU device visible
+        to the process.
""" if self.valid(): return @@ -111,6 +120,7 @@ def initialize(self) -> None: self.process = ctx.Process( target=self.process_main, args=( + self.device, self.request_queue, self.response_queue, ), @@ -118,19 +128,6 @@ def initialize(self) -> None: assert self.process is not None self.process.start() - # register the exit handler for the parent process so it will terminate - # the child processes - global EXIT_HANDLER_REGISTERED - if not EXIT_HANDLER_REGISTERED: - EXIT_HANDLER_REGISTERED = True - import atexit - - atexit.register(lambda: self.terminate()) - - # wait for the initialization to be done - self.put(Ping()) - assert isinstance(self.get(), Pong) - def put(self, obj: Any) -> None: """ Push a work item to the child process. @@ -160,17 +157,147 @@ def get(self) -> Any: def terminate(self) -> None: """ - Signal the child process to terminate and wait for it to exit. + Signal the child process to terminate. """ if self.valid(): assert self.process is not None assert self.request_queue is not None self.request_queue.put(None) + + def wait(self) -> None: + """ + Wait for the child process to exit. + """ + if self.process is not None: self.process.join() self.clear() -tuning_process = TuningProcess() +@dataclasses.dataclass +class TuningProcessPool: + """ + Maintains a pool of TuningProcesses to benchmark kernels in parallel + across devices. By default, we create one TuningProcess per device and + set the sub-process environment to make only that device visible. + """ + + processes: Optional[queue.Queue[TuningProcess]] = None + executor: Optional[ThreadPoolExecutor] = None + + def initialize(self) -> None: + """ + Start the child processes. + """ + assert (self.processes is None) == (self.executor is None) + if self.processes is not None: + return + + devices = self.get_device_list() + log.debug("Device list: %s", devices) + + # Launch the child processes and push a msg to "warm up" + self.processes = queue.Queue() + for device in devices: + p = TuningProcess(device=device) + p.initialize() + p.put(Ping()) + self.processes.put(p) + + # Wait for the initialization to finish + for p in self.processes.queue: + assert isinstance(p.get(), Pong) + + # Use a thread pool to manage distributing work to the subprocesses. + # Threads block on an available process, so it makes sense to match + # the number of threads with the number of devices. + self.executor = ThreadPoolExecutor(max_workers=len(devices)) + + # Register the exit handler for the parent process so it will terminate + # the child processes. + global EXIT_HANDLER_REGISTERED + if not EXIT_HANDLER_REGISTERED: + EXIT_HANDLER_REGISTERED = True + import atexit + + atexit.register(lambda: self.terminate()) + + def get_device_list(self) -> List[Optional[int]]: + """ + Gather the list of devices to be used in the pool. + """ + if not config.autotune_multi_device: + # Don't use multiple devices + return [None] + + count = torch.cuda.device_count() + + # If the user specified the visible devices in the env, use those. + if CUDA_VISIBLE_DEVICES in os.environ: + devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")] + assert len(devices) <= count + return devices # type: ignore[return-value] + + return list(range(count)) + + def terminate(self) -> None: + """ + Signal all child processes to terminate. 
+ """ + if self.executor is not None: + self.executor.shutdown() + self.executor = None + + if self.processes is not None: + for p in self.processes.queue: + p.terminate() + for p in self.processes.queue: + p.wait() + self.processes = None + + def target(self, choice: TritonTemplateCaller) -> float: + """ + Entry point for the thread-pool helper threads: Wait for an open TuningProcess, + remove it from the queue, execute the benchmark in that subprocess, and return + the TuningProcess to the queue. + """ + assert choice.bmreq is not None + assert self.processes is not None + + process = self.processes.get() + process.put(choice.bmreq) + try: + return process.get() + except queue.Empty: + warnings.warn( + f"Failed to benchmark choice '{choice}'. It will be ignored. " + "Please debug the root cause in case the choice can bring perf gains." + ) + # set to INF so this choice will be ignored + return float("inf") + finally: + self.processes.put(process) + + def benchmark( + self, + choices: List[TritonTemplateCaller], + ) -> Dict[TritonTemplateCaller, float]: + """ + Benchmark each choice in a separate process. + """ + assert self.processes is not None, "Tuning process pool is not initialized" + assert self.executor is not None + + results = {} + + # Use a ThreadExecutorPool to spread the work across the subproccesses and + # to grab subprocesses as soon as they're free. + for choice, result in zip(choices, self.executor.map(self.target, choices)): + results[choice] = result + + return results + + +tuning_pool = TuningProcessPool() LayoutOrBuffer = Union[ir.Layout, ir.Buffer] @@ -187,7 +314,7 @@ class TensorMeta: @classmethod def from_irnodes( cls, irnodes: Union[LayoutOrBuffer, Tuple[LayoutOrBuffer], List[LayoutOrBuffer]] - ) -> Union["TensorMeta", List["TensorMeta"]]: + ) -> Union[TensorMeta, List[TensorMeta]]: if isinstance(irnodes, (tuple, list)): result: List[Any] = [cls.from_irnodes(x) for x in irnodes] assert all(isinstance(x, TensorMeta) for x in result) @@ -233,8 +360,8 @@ class BenchmarkRequest: num_stages: int num_warps: int - input_tensors: Union["TensorMeta", List["TensorMeta"]] - output_tensor: Union["TensorMeta", List["TensorMeta"]] + input_tensors: Union[TensorMeta, List[TensorMeta]] + output_tensor: Union[TensorMeta, List[TensorMeta]] def benchmark( self, *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None @@ -296,22 +423,9 @@ def worker() -> float: def benchmark_in_sub_process( - choice: "TritonTemplateCaller", -) -> float: + choices: List[TritonTemplateCaller], +) -> Dict[TritonTemplateCaller, float]: """ Do benchmarking in a subprocess and return the perf number (latency). """ - assert choice.bmreq is not None - tuning_process.initialize() - assert tuning_process.valid() - - tuning_process.put(choice.bmreq) - try: - return tuning_process.get() - except queue.Empty: - warnings.warn( - f"Fail to benchmark choice '{choice}'. It will be ignored. " - "Please debug the root cause in case the choice can bring perf gains." 
-        )
-        # return INF so this choice will be ignored
-        return float("inf")
+    return tuning_pool.benchmark(choices)
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index a290f44fa7abf..39946674858a4 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -270,11 +270,13 @@ def check_cache(cache, callback=None):
             ):
                 try:
                     # re-benchmark everything to try to get consistent numbers from the same machine
-                    for choice in choices:
-                        timings[choice] = benchmark(choice)
-                        local_cache.setdefault(name, {})
-                        local_cache[name].setdefault(inputs, {})
-                        local_cache[name][inputs][choice.hash_key()] = timings[choice]
+                    timings = benchmark(choices)
+                    assert all(choice in timings for choice in choices)
+
+                    local_cache.setdefault(name, {})
+                    local_cache[name].setdefault(inputs, {})
+                    for choice, timing in timings.items():
+                        local_cache[name][inputs][choice.hash_key()] = timing
                 except RuntimeError as e:
                     # catch and log autotuning failures
                     log_errors(e)
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 17699bec3c722..67e465a4c95aa 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -112,6 +112,9 @@
 # We will disable creating subprocess for autotuning if this is False
 autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
 
+# If autotuning in subprocess, whether to use multiple devices
+autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"
+
 coordinate_descent_tuning = (
     os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
 )
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 64495368a6ca2..ee0ed79fef7f5 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -774,31 +774,14 @@ def __call__(
         def make_benchmark_fn():
            return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)
 
-        def autotune(choice):
-            benchmark_fn = make_benchmark_fn()
-            try:
-                timing = benchmark_fn(
-                    choice,
-                )
-            except RuntimeError as e:
-                msg = str(e)
-                if "invalid argument" in msg:
-                    msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n"
-                    log.warning(msg)
-                    return float("inf")
-                elif "illegal memory access" in msg:
-                    msg += "\n\nEither error in template or triton bug.\n"
-                raise ErrorFromChoice(msg, choice, benchmark_fn.debug_str())
-            except AssertionError as e:
-                raise AssertionError(f"Incorrect result from choice {choice}\n\n{e}")
-            return timing
+        def autotune(choices):
+            return make_benchmark_fn()(choices)
 
         if config.autotune_in_subproc:
-            from .autotune_process import tuning_process
+            from .autotune_process import tuning_pool
 
             # do the optional warmup
-            tuning_process.initialize()
-            assert tuning_process.valid()
+            tuning_pool.initialize()
 
         autotune_start_ts = time.time()
         timings = self.lookup(
@@ -854,9 +837,22 @@ def make_benchmark_fn(
         if DEBUG:
             print(f"{len(choices)} tuning requests:")
 
-        def benchmark_in_current_process(choice):
-            if DEBUG:
-                start_ts = time.time()
+        def debug_str():
+            def tensor_repr(x):
+                return (
+                    f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
+                    f"dtype={x.dtype!r}, device={x.device.type!r})"
+                )
+
+            lines = [
+                "inputs = [",
+            ]
+            for x in example_inputs:
+                lines.append(f"    {tensor_repr(x)},")
+            lines += ["]", f"out = {tensor_repr(out)}", ""]
+            return "\n".join(lines)
+
+        def benchmark_choice_in_current_process(choice):
             out.zero_()
             if isinstance(choice, ExternKernelCaller):
                 # aten kernels want the offset baked in for sliced tensors
@@ -869,24 +865,41 @@ def benchmark_in_current_process(choice):
             torch.cuda.synchronize()  # shake out any CUDA errors
             return result
 
-        def benchmark_in_sub_process(choice):
-            # only benchmark triton kernel in sub process for now.
-            # ATen/Extern kernel are still benchmarked in the current process.
-            if isinstance(choice, ExternKernelCaller):
-                return benchmark_in_current_process(choice)
-
+        def benchmark_in_current_process(choices):
+            timings = {}
+            for choice in choices:
+                try:
+                    timing = benchmark_choice_in_current_process(choice)
+                except RuntimeError as e:
+                    msg = str(e)
+                    if "invalid argument" in msg:
+                        msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n"
+                        log.warning(msg)
+                        timing = float("inf")
+                    else:
+                        if "illegal memory access" in msg:
+                            msg += "\n\nEither error in template or triton bug.\n"
+                        raise ErrorFromChoice(msg, choice, debug_str())
+                except AssertionError as e:
+                    raise AssertionError(
+                        f"Incorrect result from choice {choice}\n\n{e}"
+                    )
+
+                timings[choice] = timing
+
+            return timings
+
+        def benchmark_in_sub_process(choices):
             from . import autotune_process
 
-            if DEBUG:
-                start_ts = time.time()
+            # only benchmark triton kernel in sub process for now.
+            # ATen/Extern kernel are still benchmarked in the current process.
+            extern = [c for c in choices if isinstance(c, ExternKernelCaller)]
+            triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
 
-            out = autotune_process.benchmark_in_sub_process(
-                choice,
-            )
-            if DEBUG:
-                elapse = time.time() - start_ts
-                print(f"MultiProcessTuning {choice}: {elapse}")
-            return out
+            timings = benchmark_in_current_process(extern)
+            timings.update(autotune_process.benchmark_in_sub_process(triton))
+            return timings
 
         benchmark = (
             benchmark_in_sub_process
@@ -894,22 +907,6 @@ def benchmark_in_sub_process(choice):
             else benchmark_in_current_process
         )
 
-        def debug_str():
-            def tensor_repr(x):
-                return (
-                    f"torch.empty_strided({tuple(x.size())!r}, {tuple(x.stride())!r}, "
-                    f"dtype={x.dtype!r}, device={x.device.type!r})"
-                )
-
-            lines = [
-                "inputs = [",
-            ]
-            for x in example_inputs:
-                lines.append(f"    {tensor_repr(x)},")
-            lines += ["]", f"out = {tensor_repr(out)}", ""]
-            return "\n".join(lines)
-
-        benchmark.debug_str = debug_str  # type: ignore[attr-defined]
         return benchmark
 
     @staticmethod
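
For reference, a minimal sketch of how the new multi-device path can be exercised end to end. It is not part of the patch; it mirrors the parametrized test above, the mm-based function is only illustrative, and a CUDA build with at least one visible device is assumed. The same behavior can be requested through the TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1 and TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE=1 environment variables read in torch/_inductor/config.py.

import torch
from torch._inductor import config

def mm_plus_mm(a, b, c, d):
    # Any compiled function works; this one just exercises the GEMM templates.
    return torch.mm(a, b) + torch.mm(c, d)

a, b, c, d = (torch.randn(256, 256, device="cuda") for _ in range(4))

with config.patch(
    {
        "max_autotune": True,
        # Benchmark candidate kernels in subprocesses instead of in-process.
        "autotune_in_subproc": True,
        # One TuningProcess per visible CUDA device (see TuningProcessPool).
        "autotune_multi_device": True,
    }
):
    torch.compile(mm_plus_mm)(a, b, c, d)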