Mux prediction events (#1405)
* race utility for racing awaitables
* start mux, tag events with id, read pipe in a task, get events from mux
* use async pipe for async child loop
* _shutting_down vs _terminating
* race with shutdown event
* keep reading events during shutdown, but call terminate after the last Done
* emit heartbeats from mux.read
* don't use _wait. instead, setup reads events from the mux too
* worker semaphore and prediction ctx
* where _wait used to raise a fatal error, have _read_events set an error on Mux so that Mux.read can raise it in the right context; otherwise the exception is stuck in a task and doesn't propagate correctly
* fix event loop errors for <3.9
* keep track of predictions in flight explicitly and use that to route logs
* don't wait for executor shutdown
* progress: check for cancellation in task done_handler
* let mux check if child is alive and set mux shutdown after leaving read event loop
* close pipe when exiting
* predict requires IDLE or PROCESSING
* try adding a BUSY state distinct from PROCESSING when we no longer have capacity
* move resetting events to setup() instead of _read_events()

previously this was in _read_events because it's a coroutine that will have the correct event loop. however, _read_events actually gets created in a task, which can run *after* the first mux.read call by setup. since setup is now the first async entrypoint in worker and in tests, we can safely move it there

* state_from_predictions_in_flight instead of checking the value of semaphore
* make prediction_ctx "private"

Signed-off-by: technillogue <technillogue@gmail.com>
technillogue committed Feb 13, 2024
1 parent 9efc0e4 commit f57474d
Showing 4 changed files with 258 additions and 78 deletions.
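
For orientation, here is a minimal, illustrative sketch of the multiplexing pattern the commit message describes. It is not the code added in this commit (the real Mux lives in the worker module, which is not part of the excerpts below), and the Heartbeat name is only a placeholder: a single reader task drains the child's event pipe, each event carries the id of the prediction it belongs to, and Mux.read(id) yields only that prediction's events, turning idle waits into heartbeats and re-raising any fatal reader error in the caller's context.

import asyncio
from typing import Any, AsyncIterator, Dict, Optional


class Heartbeat:
    """Placeholder event yielded when no real event arrives within the poll interval."""


class Mux:
    # one queue per in-flight prediction, keyed by the id stamped on each event
    def __init__(self) -> None:
        self.outs: Dict[str, "asyncio.Queue[Any]"] = {}
        self.fatal: Optional[BaseException] = None

    def write(self, id: str, event: Any) -> None:
        # called by the single task that drains the child pipe
        self.outs.setdefault(id, asyncio.Queue()).put_nowait(event)

    async def read(self, id: str, poll: float = 0.1) -> AsyncIterator[Any]:
        # a prediction only ever sees its own events; callers break on their terminal event
        queue = self.outs.setdefault(id, asyncio.Queue())
        try:
            while True:
                if self.fatal is not None:
                    # raising here, rather than inside the reader task, lets the error
                    # propagate in the context of the caller awaiting this prediction
                    raise self.fatal
                try:
                    yield await asyncio.wait_for(queue.get(), timeout=poll)
                except asyncio.TimeoutError:
                    yield Heartbeat()
        finally:
            self.outs.pop(id, None)

A caller would iterate with async for event in mux.read(prediction_id) and break once it sees the terminal event for its prediction.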
2 changes: 2 additions & 0 deletions python/cog/server/eventtypes.py
@@ -1,3 +1,4 @@
import secrets
from typing import Any, Dict

from attrs import define, field, validators
@@ -8,6 +9,7 @@
@define
class PredictionInput:
payload: Dict[str, Any]
id: str = field(factory=lambda: secrets.token_hex(4))


@define
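
The effect of the new field, roughly (the "prompt" payload and the printed value are only illustrative; secrets.token_hex(4) produces 8 hex characters, and an explicit id can still be passed):

from cog.server.eventtypes import PredictionInput

event = PredictionInput(payload={"prompt": "hello"})
print(event.id)  # e.g. "a3f91c02": four random bytes, generated per event by default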
48 changes: 45 additions & 3 deletions python/cog/server/helpers.py
@@ -3,17 +3,20 @@
import io
import os
import selectors
import sys
import threading
import uuid
from multiprocessing.connection import Connection
from typing import (
Any,
Callable,
Coroutine,
Generic,
Optional,
Sequence,
TextIO,
TypeVar,
Union,
)


@@ -160,13 +163,44 @@ def run(self) -> None:
self.drain_event.set()
drain_tokens_seen = 0


X = TypeVar("X")
Y = TypeVar("Y")


async def race(
x: Coroutine[None, None, X],
y: Coroutine[None, None, Y],
timeout: Optional[float] = None,
) -> Union[X, Y]:
tasks = [asyncio.create_task(x), asyncio.create_task(y)]
wait = asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
done, pending = await wait
for task in pending:
task.cancel()
if not done:
raise TimeoutError
# done is an unordered set but we want to preserve original order
result_task, *others = (t for t in tasks if t in done)
# during shutdown, some of the other completed tasks might be an error
# cancel them instead of handling the error to avoid the warning
# "Task exception was never retrieved"
for task in others:
msg = "was completed at the same time as another selected task, canceling"
# FIXME: use a logger?
print(task, msg, file=sys.stderr)
task.cancel()
return result_task.result()


# functionally this is the exact same thing as aioprocessing but 0.1% the code
# however it's still worse than just using actual asynchronous io
class AsyncPipe(Generic[X]):
def __init__(self, conn: Connection) -> None:
def __init__(
self, conn: Connection, alive: Callable[[], bool] = lambda: True
) -> None:
self.conn = conn
self.alive = alive
self.exiting = threading.Event()
self.executor = concurrent.futures.ThreadPoolExecutor(1)

@@ -175,7 +209,7 @@ def send(self, obj: Any) -> None:

def shutdown(self) -> None:
self.exiting.set()
self.executor.shutdown(wait=False)
self.executor.shutdown(wait=True)
# if we ever need cancel_futures (introduced 3.9), we can copy it in from
# https://github.com/python/cpython/blob/3.11/Lib/concurrent/futures/thread.py#L216-L235

@@ -185,12 +219,20 @@ def poll(self, timeout: float = 0.0) -> bool:
def _recv(self) -> Optional[X]:
# this ugly mess could easily be avoided with loop.connect_read_pipe
# even loop.add_reader would help but we don't want to mess with a thread-local loop
while not self.exiting.is_set():
while not self.exiting.is_set() and not self.conn.closed and self.alive():
if self.conn.poll(0.01):
if self.conn.closed or not self.alive():
print("caught conn closed or unalive")
return
return self.conn.recv()
return None

async def coro_recv(self) -> Optional[X]:
loop = asyncio.get_running_loop()
# believe it or not this can still deadlock!
return await loop.run_in_executor(self.executor, self._recv)

async def coro_recv_with_exit(self, exit: asyncio.Event) -> Optional[X]:
result = await race(self.coro_recv(), exit.wait())
if result is not True: # wait() would return True
return result
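
Taken together, a hedged sketch of how these helpers are meant to compose (the read_events name and the mux.write call are illustrative, not part of this diff): race returns whichever coroutine finishes first and cancels the loser, and coro_recv_with_exit uses it so a blocking pipe read can be abandoned as soon as a shutdown event is set.

import asyncio

from cog.server.helpers import AsyncPipe


async def read_events(pipe: AsyncPipe, shutdown: asyncio.Event) -> None:
    # drain the child's event pipe until the pipe is closed, the child is no
    # longer alive, or shutdown is set (coro_recv_with_exit returns None then)
    while True:
        event = await pipe.coro_recv_with_exit(shutdown)
        if event is None:
            break
        # route by the id stamped on the event, e.g. mux.write(event.id, event)

race can also be used directly, e.g. await race(coro_a(), coro_b(), timeout=5.0), which raises TimeoutError if neither coroutine finishes within the timeout.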
