Skip to content

Commit

Permalink
fix(api): maintain list of pending jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 18, 2023
1 parent 588c8c7 commit 15b6e03
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 38 deletions.
2 changes: 2 additions & 0 deletions api/onnx_web/worker/command.py
Expand Up @@ -27,6 +27,7 @@ def __init__(


class JobCommand:
device: str
name: str
fn: Callable[..., None]
args: Any
Expand All @@ -35,6 +36,7 @@ class JobCommand:
def __init__(
self,
name: str,
device: str,
fn: Callable[..., None],
args: Any,
kwargs: dict[str, Any],
Expand Down
99 changes: 61 additions & 38 deletions api/onnx_web/worker/pool.py
Expand Up @@ -23,15 +23,16 @@ class DevicePoolExecutor:
join_timeout: float

leaking: List[Tuple[str, Process]]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"] # Device -> pid
pending: Dict[str, "Queue[JobCommand]"]
threads: Dict[str, Thread]
workers: Dict[str, Process]

active_jobs: Dict[str, ProgressCommand] # Device -> job progress
cancelled_jobs: List[str]
finished_jobs: List[ProgressCommand]
pending_jobs: List[JobCommand]
running_jobs: Dict[str, ProgressCommand] # Device -> job progress
total_jobs: Dict[str, int] # Device -> job count

logs: "Queue[str]"
Expand All @@ -57,7 +58,7 @@ def __init__(
self.threads = {}
self.workers = {}

self.active_jobs = {}
self.running_jobs = {}
self.cancelled_jobs = []
self.finished_jobs = []
self.total_jobs = {}
Expand Down Expand Up @@ -139,31 +140,12 @@ def logger_worker(logs: Queue):
logger_thread.start()

def create_progress_worker(self) -> None:
def update_job(progress: ProgressCommand):
if progress.finished:
logger.info("job has finished: %s", progress.job)
self.finished_jobs.append(progress)
del self.active_jobs[progress.job]
self.join_leaking()
else:
logger.debug(
"progress update for job: %s to %s", progress.job, progress.progress
)
self.active_jobs[progress.job] = progress
if progress.job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s",
progress.job,
progress.device,
)
self.context[progress.device].set_cancel()

def progress_worker(queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread")
while True:
try:
progress = queue.get(timeout=(self.join_timeout / 2))
update_job(progress)
self.update_job(progress)
except Empty:
pass
except ValueError:
Expand All @@ -183,7 +165,7 @@ def progress_worker(queue: "Queue[ProgressCommand]"):
progress_thread.start()

def get_job_context(self, key: str) -> WorkerContext:
device, _progress = self.active_jobs[key]
device, _progress = self.running_jobs[key]
return self.context[device]

def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
Expand Down Expand Up @@ -217,31 +199,43 @@ def cancel(self, key: str) -> bool:
logger.debug("cannot cancel finished job: %s", key)
return False

if key not in self.active_jobs:
for job in self.pending_jobs:
if job.name == key:
self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != key]
logger.info("cancelled pending job: %s", key)
return True

if key not in self.running_jobs:
logger.debug("cancelled job is not active: %s", key)
else:
job = self.active_jobs[key]
job = self.running_jobs[key]
logger.info("cancelling job %s, active on device %s", key, job.device)

self.cancelled_jobs.append(key)
return True

def done(self, key: str) -> Optional[ProgressCommand]:
def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
"""
Check if a job has been finished and report the last progress update.
If the job is still active or pending, the first item will be False.
If the job is not finished or active, the first item will be None.
If the job is still pending, the first item will be True and there will be no ProgressCommand.
"""
if key in self.running_jobs:
logger.debug("checking status for running job: %s", key)
return (False, self.running_jobs[key])

for job in self.finished_jobs:
if job.job == key:
return job
logger.debug("checking status for finished job: %s", key)
return (False, job)

if key not in self.active_jobs:
logger.debug("checking status for unknown job: %s", key)
return None
for job in self.pending_jobs:
if job.name == key:
logger.debug("checking status for pending job: %s", key)
return (True, ProgressCommand(job.name, job.device, False, 0))

return self.active_jobs[key]
logger.trace("checking status for unknown job: %s", key)
return (False, None)

def join(self):
logger.info("stopping worker pool")
Expand Down Expand Up @@ -355,17 +349,21 @@ def submit(
self.devices[device_idx],
)

# increment job count before recycling (why tho?)
device = self.devices[device_idx].device

if device in self.total_jobs:
self.total_jobs[device] += 1
else:
self.total_jobs[device] = 1

# recycle before attempting to run
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
self.recycle()

self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)
# build and queue job
job = JobCommand(key, device, fn, args, kwargs)
self.pending_jobs.append(job)
self.pending[device].put(job, block=False)

def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
history = [
Expand All @@ -376,7 +374,7 @@ def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
job.cancel,
job.error,
)
for name, job in self.active_jobs.items()
for name, job in self.running_jobs.items()
]
history.extend(
[
Expand All @@ -391,3 +389,28 @@ def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
]
)
return history

def update_job(self, progress: ProgressCommand):
if progress.finished:
# move from running to finished
logger.info("job has finished: %s", progress.job)
self.finished_jobs.append(progress)
del self.running_jobs[progress.job]
self.join_leaking()
if progress.job in self.cancelled_jobs:
self.cancelled_jobs.remove(progress.job)
else:
# move from pending to running
logger.debug(
"progress update for job: %s to %s", progress.job, progress.progress
)
self.running_jobs[progress.job] = progress
self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != progress.job]

if progress.job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s",
progress.job,
progress.device,
)
self.context[progress.device].set_cancel()

0 comments on commit 15b6e03

Please sign in to comment.