Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Literal
from typing import NoReturn
from typing import cast
Expand Down Expand Up @@ -362,6 +363,7 @@ def start_precompile_and_check_for_hangs(
args=(fn_spec, self._precompile_args_path, child_conn, decorator),
),
)
process.daemon = True
else:
ctx = mp.get_context("fork")
parent_conn, child_conn = ctx.Pipe()
Expand All @@ -372,6 +374,7 @@ def start_precompile_and_check_for_hangs(
args=(fn, device_args, config, self.kernel, child_conn, decorator),
),
)
process.daemon = True
return PrecompileFuture(
search=self,
config=config,
Expand Down Expand Up @@ -937,11 +940,8 @@ def wait_for_all(
while progress_left > len(remaining):
next(progress, None)
progress_left -= 1
except Exception:
for f in remaining:
if (p := f.process) is not None:
with contextlib.suppress(Exception):
p.terminate()
except BaseException:
PrecompileFuture._cancel_all(futures)
raise
result = []
for f in futures:
Expand Down Expand Up @@ -983,6 +983,47 @@ def _wait_for_all_step(
remaining.append(f)
return remaining

@staticmethod
def _cancel_all(futures: Iterable[PrecompileFuture]) -> None:
    """Best-effort cancellation of every future that has no outcome yet.

    Only futures whose ``ok`` is still ``None`` are touched. The work is
    done in two passes so that a kill is issued to every pending process
    before any individual ``cancel()`` call is made.

    Args:
        futures: The futures to inspect; consumed exactly once.
    """
    # Snapshot the unfinished futures up front so both passes below
    # operate on the same set, even when ``futures`` is a one-shot iterator.
    unfinished = [candidate for candidate in futures if candidate.ok is None]

    # Pass 1: fire off hard kills without waiting on any process.
    for pending in unfinished:
        with contextlib.suppress(Exception):
            pending._kill_without_wait()

    # Pass 2: run the full per-future cancellation. A failure on one
    # future must not prevent the remaining ones from being cancelled.
    for pending in unfinished:
        with contextlib.suppress(Exception):
            pending.cancel()

def _kill_without_wait(self) -> None:
    """Issue a hard kill to the underlying process without waiting for exit.

    This is a best-effort operation: it returns immediately when there is
    no process or the process was never started, and it swallows any
    ``Exception`` raised while probing or killing the process. It never
    joins the process.
    """
    process = self.process
    if process is None or not self.started:
        return
    # The liveness probe can itself raise (for example ``ValueError`` once
    # the Process object has been closed), so it lives inside the suppress
    # block together with ``kill()`` -- the same way ``cancel`` guards
    # these two calls.
    with contextlib.suppress(Exception):
        if process.is_alive():
            process.kill()

def cancel(self) -> None:
    """Terminate the underlying process (if any) without waiting for success.

    Stamps ``end_time`` with the current time, hard-kills and joins the
    process when one was started, closes and drops the retained
    ``child_conn``, and fills in ``ok`` / ``failure_reason`` when they are
    still unset. Errors raised while killing, joining or closing are
    suppressed.
    """
    # Record when the cancellation happened.
    self.end_time = time.time()
    process = self.process
    if process is not None:
        if self.started:
            # Best-effort: a failure to probe, kill or join must not
            # abort the rest of the cleanup below.
            with contextlib.suppress(Exception):
                if process.is_alive():
                    process.kill()
                # No timeout is passed, so this waits for the process to
                # exit.
                process.join()
        if self.child_conn is not None:
            with contextlib.suppress(Exception):
                self.child_conn.close()
            # Drop the reference even if close() failed so it is not
            # closed a second time.
            self.child_conn = None
    # Only fill in an outcome that has not already been recorded.
    if self.ok is None:
        self.ok = False
    if self.failure_reason is None:
        self.failure_reason = "error"
    # NOTE(review): presumably these pick up any result or error the child
    # already sent, without blocking and without re-raising -- confirm
    # against _recv_result / _handle_remote_error.
    self._recv_result(block=False)
    self._handle_remote_error(raise_on_raise=False)

def _mark_complete(self) -> bool:
"""
Mark the precompile future as complete and kill the process if needed.
Expand Down
Loading