From 6e7588fdd379f444e39646952797d1f3b429160a Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Wed, 15 Oct 2025 19:39:32 -0700
Subject: [PATCH] Exit autotuning faster on KeyboardInterrupt

stack-info: PR: https://github.com/pytorch/helion/pull/963, branch: jansel/stack/202
---
Reviewer notes (this region after the three-dash line is not part of the
commit message and is not read as diff content):

  * Both precompile subprocess branches visible in this diff (the one
    preceding `else:` and the `mp.get_context("fork")` one) now set
    `process.daemon = True` right after the process object is created.
  * `wait_for_all` now catches `BaseException` instead of `Exception`, so
    the cleanup path also runs for KeyboardInterrupt (see subject), and it
    delegates to `PrecompileFuture._cancel_all(futures)` instead of calling
    `terminate()` on each remaining process. The exception is still re-raised.
  * `_cancel_all` selects the futures whose `ok` is still None, calls
    `_kill_without_wait()` on every one of them first, and only then calls
    `cancel()` on each; every call is wrapped in
    `contextlib.suppress(Exception)`.
  * `_kill_without_wait` returns early when there is no process or the
    future was not started; otherwise it calls `process.kill()` if the
    process is alive.
  * `cancel` records `end_time`, kills (if alive) and joins a started
    process, closes and clears `child_conn`, marks an undecided future as
    failed (`ok = False`, `failure_reason` defaulting to "error"), then calls
    `_recv_result(block=False)` and `_handle_remote_error(raise_on_raise=False)`.
  * NOTE(review): this copy of the patch was restored from a
    whitespace-collapsed paste. Line breaks follow the hunk counts, but the
    indentation of context and added lines is reconstructed from Python
    syntax -- in particular the nesting of `process.join()`, of the
    `child_conn` block, of the `failure_reason` check, and of
    `remaining.append(f)` is presumed. Verify against base_search.py at
    2f37bf8dc / cabbed8ab before applying.

 helion/autotuner/base_search.py | 51 +++++++++++++++++++++++++++++----
 1 file changed, 46 insertions(+), 5 deletions(-)

diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 2f37bf8dc..cabbed8ab 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -24,6 +24,7 @@
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Callable
+from typing import Iterable
 from typing import Literal
 from typing import NoReturn
 from typing import cast
@@ -362,6 +363,7 @@ def start_precompile_and_check_for_hangs(
                     args=(fn_spec, self._precompile_args_path, child_conn, decorator),
                 ),
             )
+            process.daemon = True
         else:
             ctx = mp.get_context("fork")
             parent_conn, child_conn = ctx.Pipe()
@@ -372,6 +374,7 @@ def start_precompile_and_check_for_hangs(
                     args=(fn, device_args, config, self.kernel, child_conn, decorator),
                 ),
             )
+            process.daemon = True
         return PrecompileFuture(
             search=self,
             config=config,
@@ -937,11 +940,8 @@ def wait_for_all(
                 while progress_left > len(remaining):
                     next(progress, None)
                     progress_left -= 1
-        except Exception:
-            for f in remaining:
-                if (p := f.process) is not None:
-                    with contextlib.suppress(Exception):
-                        p.terminate()
+        except BaseException:
+            PrecompileFuture._cancel_all(futures)
             raise
         result = []
         for f in futures:
@@ -983,6 +983,47 @@ def _wait_for_all_step(
                 remaining.append(f)
         return remaining
 
+    @staticmethod
+    def _cancel_all(futures: Iterable[PrecompileFuture]) -> None:
+        """Cancel any futures that have not completed."""
+        active = [future for future in futures if future.ok is None]
+        for future in active:
+            with contextlib.suppress(Exception):
+                future._kill_without_wait()
+        for future in active:
+            with contextlib.suppress(Exception):
+                future.cancel()
+
+    def _kill_without_wait(self) -> None:
+        """Issue a hard kill to the underlying process without waiting for exit."""
+        process = self.process
+        if process is None or not self.started:
+            return
+        if process.is_alive():
+            with contextlib.suppress(Exception):
+                process.kill()
+
+    def cancel(self) -> None:
+        """Terminate the underlying process (if any) without waiting for success."""
+        self.end_time = time.time()
+        process = self.process
+        if process is not None:
+            if self.started:
+                with contextlib.suppress(Exception):
+                    if process.is_alive():
+                        process.kill()
+                    process.join()
+        if self.child_conn is not None:
+            with contextlib.suppress(Exception):
+                self.child_conn.close()
+            self.child_conn = None
+        if self.ok is None:
+            self.ok = False
+            if self.failure_reason is None:
+                self.failure_reason = "error"
+        self._recv_result(block=False)
+        self._handle_remote_error(raise_on_raise=False)
+
     def _mark_complete(self) -> bool:
         """
         Mark the precompile future as complete and kill the process if needed.