Skip to content

Commit

Permalink
fix(api): track currently active worker for each device
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 6, 2023
1 parent 57fed94 commit c0a01ef
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
10 changes: 10 additions & 0 deletions api/onnx_web/worker/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from logging import getLogger
from os import getpid
from typing import Any, Callable, Tuple

from torch.multiprocessing import Queue, Value
Expand All @@ -15,6 +16,7 @@ class WorkerContext:
cancel: "Value[bool]"
job: str
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
current: "Value[int]"
progress: "Queue[Tuple[str, str, int]]"

def __init__(
Expand All @@ -26,6 +28,7 @@ def __init__(
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
progress: "Queue[Tuple[str, str, int]]",
finished: "Queue[Tuple[str, str]]",
current: "Value[int]",
):
self.job = job
self.device = device
Expand All @@ -34,10 +37,17 @@ def __init__(
self.finished = finished
self.logs = logs
self.pending = pending
self.current = current

def is_cancelled(self) -> bool:
    """
    Report the cancellation flag for this context.

    Returns whatever is currently stored in the `value` attribute of
    the `cancel` object held by this context.
    """
    cancel_flag = self.cancel
    return cancel_flag.value

def is_current(self) -> bool:
    """
    Check whether the calling process is the one recorded in `current`.

    While the recorded value is not greater than zero, the check passes
    unconditionally. Otherwise it passes only when the recorded value
    equals the pid of the calling process, as reported by `getpid()`.
    """
    # Guard clause: a non-positive value means there is nothing to compare against.
    if self.current.value <= 0:
        return True

    return self.current.value == getpid()

def get_device(self) -> DeviceParams:
"""
Get the device assigned to this job.
Expand Down
17 changes: 15 additions & 2 deletions api/onnx_web/worker/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class DevicePoolExecutor:

leaking: List[Tuple[str, Process]]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"]
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
threads: Dict[str, Thread]
workers: Dict[str, Process]
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(

self.leaking = []
self.context = {}
self.current = {}
self.pending = {}
self.threads = {}
self.workers = {}
Expand Down Expand Up @@ -85,6 +87,14 @@ def create_device_worker(self, device: DeviceParams) -> None:
pending = Queue(self.max_pending_per_worker)
self.pending[name] = pending

if name in self.current:
logger.debug("using existing current worker value")
current = self.current[name]
else:
logger.debug("creating new current worker value")
current = Value("L", 0)
self.current[name] = current

context = WorkerContext(
name,
device,
Expand All @@ -93,16 +103,19 @@ def create_device_worker(self, device: DeviceParams) -> None:
finished=self.finished,
logs=self.logs,
pending=pending,
current=current,
)
self.context[name] = context
self.workers[name] = Process(
worker = Process(
name=f"onnx-web worker: {name}",
target=worker_main,
args=(context, self.server),
)

logger.debug("starting worker for device %s", device)
self.workers[name].start()
worker.start()
self.workers[name] = worker
current.value = worker.pid

def create_logger_worker(self) -> None:
def logger_worker(logs: Queue):
Expand Down
4 changes: 4 additions & 0 deletions api/onnx_web/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def worker_main(context: WorkerContext, server: ServerContext):

while True:
try:
if not context.is_current():
logger.warning("worker has been replaced, exiting")
exit(3)

name, fn, args, kwargs = context.pending.get(timeout=1.0)
logger.info("worker for %s got job: %s", context.device.device, name)

Expand Down

0 comments on commit c0a01ef

Please sign in to comment.