Skip to content

Commit

Permalink
feat(api): add error flag to image ready response
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 18, 2023
1 parent 0ab52f0 commit 7cf5554
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 78 deletions.
12 changes: 7 additions & 5 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@
logger = getLogger(__name__)


def ready_reply(ready: bool, progress: int = 0):
def ready_reply(ready: bool, progress: int = 0, error: bool = False, cancel: bool = False):
return jsonify(
{
"cancel": cancel,
"error": error,
"progress": progress,
"ready": ready,
}
Expand Down Expand Up @@ -437,7 +439,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
output_file = sanitize_name(output_file)
cancel = pool.cancel(output_file)

return ready_reply(cancel)
return ready_reply(cancel == False, cancel=cancel)


def ready(context: ServerContext, pool: DevicePoolExecutor):
Expand All @@ -446,14 +448,14 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
return error_reply("output name is required")

output_file = sanitize_name(output_file)
done, progress = pool.done(output_file)
progress = pool.done(output_file)

if done is None:
if progress is None:
output = base_join(context.output_path, output_file)
if path.exists(output):
return ready_reply(True)

return ready_reply(done or False, progress=progress)
return ready_reply(progress.finished, progress=progress.progress, error=progress.error, cancel=progress.cancel)


def status(context: ServerContext, pool: DevicePoolExecutor):
Expand Down
43 changes: 43 additions & 0 deletions api/onnx_web/worker/command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Any, Callable, Dict

class ProgressCommand:
    """
    Status update for a single job, passed from a worker to the pool.

    Worker contexts put one of these on the shared progress queue for each
    progress report, and once more when the job finishes or fails.
    """

    device: str  # name of the device the job is running on
    job: str  # name of the job this update belongs to
    finished: bool  # True once the job has stopped, whether it succeeded or failed
    progress: int  # last progress value reported for the job
    cancel: bool  # True when the job was flagged as cancelled
    error: bool  # True when the job stopped because of a failure

    def __init__(
        self,
        job: str,
        device: str,
        finished: bool,
        progress: int,
        cancel: bool = False,
        error: bool = False,
    ) -> None:
        """
        Build a status update.

        job: name of the job being reported.
        device: name of the device running the job.
        finished: whether the job has stopped running.
        progress: last progress value reported for the job.
        cancel: whether the job has been cancelled.
        error: whether the job failed.
        """
        self.job = job
        self.device = device
        self.finished = finished
        self.progress = progress
        self.cancel = cancel
        self.error = error

    def __repr__(self) -> str:
        """
        Return a readable summary of every field, for logs and debugging.
        """
        return (
            f"{type(self).__name__}(job={self.job!r}, device={self.device!r}, "
            f"finished={self.finished!r}, progress={self.progress!r}, "
            f"cancel={self.cancel!r}, error={self.error!r})"
        )

class JobCommand:
    """
    A unit of work submitted to a device's pending queue.

    The worker that picks the command up calls ``fn(context, *args, **kwargs)``.
    """

    name: str  # job name, used to track progress and cancellation
    fn: Callable[..., None]  # function the worker runs for this job
    args: Any  # positional arguments passed to fn after the worker context
    # typing.Dict rather than dict[...]: the builtin is only subscriptable on
    # Python 3.9+, and these annotations are evaluated when the class is defined
    kwargs: Dict[str, Any]  # keyword arguments passed to fn

    def __init__(
        self,
        name: str,
        fn: Callable[..., None],
        args: Any,
        kwargs: Dict[str, Any],
    ) -> None:
        """
        Build a job command.

        name: name of the job.
        fn: function to run; it is called with the worker context first.
        args: extra positional arguments for fn.
        kwargs: keyword arguments for fn.
        """
        self.name = name
        self.fn = fn
        self.args = args
        self.kwargs = kwargs
37 changes: 20 additions & 17 deletions api/onnx_web/worker/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from torch.multiprocessing import Queue, Value

from .command import JobCommand, ProgressCommand
from ..params import DeviceParams

logger = getLogger(__name__)
Expand All @@ -15,26 +16,24 @@
class WorkerContext:
cancel: "Value[bool]"
job: str
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
pending: "Queue[JobCommand]"
current: "Value[int]"
progress: "Queue[Tuple[str, str, int]]"
progress: "Queue[ProgressCommand]"

def __init__(
self,
job: str,
device: DeviceParams,
cancel: "Value[bool]",
logs: "Queue[str]",
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
progress: "Queue[Tuple[str, str, int]]",
finished: "Queue[Tuple[str, str]]",
pending: "Queue[JobCommand]",
progress: "Queue[ProgressCommand]",
current: "Value[int]",
):
self.job = job
self.device = device
self.cancel = cancel
self.progress = progress
self.finished = finished
self.logs = logs
self.pending = pending
self.current = current
Expand All @@ -61,11 +60,7 @@ def get_progress(self) -> int:
def get_progress_callback(self) -> ProgressCallback:
def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step
if self.is_cancelled():
raise RuntimeError("job has been cancelled")
else:
logger.debug("setting progress for job %s to %s", self.job, step)
self.set_progress(step)
self.set_progress(step)

return on_progress

Expand All @@ -74,14 +69,22 @@ def set_cancel(self, cancel: bool = True) -> None:
self.cancel.value = cancel

def set_progress(self, progress: int) -> None:
self.progress.put((self.job, self.device.device, progress), block=False)
if self.is_cancelled():
raise RuntimeError("job has been cancelled")
else:
logger.debug("setting progress for job %s to %s", self.job, progress)
self.progress.put(ProgressCommand(self.job, self.device.device, False, progress, self.is_cancelled(), False), block=False)

def set_finished(self) -> None:
self.finished.put((self.job, self.device.device), block=False)

def clear_flags(self) -> None:
self.set_cancel(False)
self.set_progress(0)
logger.debug("setting finished for job %s", self.job)
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), False), block=False)

def set_failed(self) -> None:
    """
    Report the current job as finished with an error.

    Best-effort: this is called from the worker's exception handlers, so a
    failure to queue the update is logged instead of raised.
    """
    logger.warning("setting failure for job %s", self.job)
    try:
        self.progress.put(
            ProgressCommand(
                self.job,
                self.device.device,
                finished=True,
                progress=self.get_progress(),
                cancel=self.is_cancelled(),
                error=True,
            ),
            block=False,
        )
    except Exception:
        # a bare except would also swallow KeyboardInterrupt and SystemExit,
        # which need to keep propagating so the worker can still be stopped
        logger.exception("error setting failure on job %s", self.job)


class JobStatus:
Expand Down
102 changes: 56 additions & 46 deletions api/onnx_web/worker/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..params import DeviceParams
from ..server import ServerContext
from .command import JobCommand, ProgressCommand
from .context import WorkerContext
from .worker import worker_main

Expand All @@ -24,18 +25,17 @@ class DevicePoolExecutor:
leaking: List[Tuple[str, Process]]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"]
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
pending: Dict[str, "Queue[JobCommand]"]
threads: Dict[str, Thread]
workers: Dict[str, Process]

active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
cancelled_jobs: List[str]
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
finished_jobs: List[ProgressCommand]
total_jobs: Dict[str, int] # Device -> job count

logs: "Queue"
progress: "Queue[Tuple[str, str, int]]"
finished: "Queue[Tuple[str, str]]"
logs: "Queue[str]"
progress: "Queue[ProgressCommand]"

def __init__(
self,
Expand Down Expand Up @@ -142,18 +142,27 @@ def logger_worker(logs: Queue):
logger_thread.start()

def create_progress_worker(self) -> None:
def progress_worker(progress: Queue):
def update_job(progress: ProgressCommand):
if progress.finished:
logger.info("job has finished: %s", progress.job)
self.finished_jobs.append(progress)
del self.active_jobs[progress.job]
self.join_leaking()
else:
logger.debug("progress update for job: %s to %s", progress.job, progress.progress)
self.active_jobs[progress.job] = progress
if progress.job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s", progress.job, progress.device
)
self.context[progress.device].set_cancel()

def progress_worker(queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread")
while True:
try:
job, device, value = progress.get(timeout=(self.join_timeout / 2))
logger.debug("progress update for job: %s to %s", job, value)
self.active_jobs[job] = (device, value)
if job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s", job, device
)
self.context[device].set_cancel()
progress = queue.get(timeout=(self.join_timeout / 2))
update_job(progress)
except Empty:
pass
except ValueError:
Expand All @@ -178,12 +187,7 @@ def finished_worker(finished: Queue):
while True:
try:
job, device = finished.get(timeout=(self.join_timeout / 2))
logger.info("job has been finished: %s", job)
context = self.context[device]
_device, progress = self.active_jobs[job]
self.finished_jobs.append((job, progress, context.cancel.value))
del self.active_jobs[job]
self.join_leaking()

except Empty:
pass
except ValueError:
Expand Down Expand Up @@ -232,37 +236,36 @@ def cancel(self, key: str) -> bool:
should be cancelled on the next progress callback.
"""

self.cancelled_jobs.append(key)
for job in self.finished_jobs:
if job.job == key:
logger.debug("cannot cancel finished job: %s", key)
return False

if key not in self.active_jobs:
logger.debug("cancelled job has not been started yet: %s", key)
return True

device, _progress = self.active_jobs[key]
logger.info("cancelling job %s, active on device %s", key, device)

context = self.context[device]
context.set_cancel()
logger.debug("cancelled job is not active: %s", key)
else:
job = self.active_jobs[key]
logger.info("cancelling job %s, active on device %s", key, job.device)

self.cancelled_jobs.append(key)
return True

def done(self, key: str) -> Tuple[Optional[bool], int]:
def done(self, key: str) -> Optional[ProgressCommand]:
"""
Check if a job has been finished and report the last progress update.
If the job is still active or pending, the first item will be False.
If the job is not finished or active, the first item will be None.
"""
for k, p, c in self.finished_jobs:
if k == key:
return (True, p)
for job in self.finished_jobs:
if job.job == key:
return job

if key not in self.active_jobs:
logger.debug("checking status for unknown job: %s", key)
return (None, 0)
return None

_device, progress = self.active_jobs[key]
return (False, progress)
return self.active_jobs[key]

def join(self):
logger.info("stopping worker pool")
Expand Down Expand Up @@ -387,22 +390,29 @@ def submit(
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
self.recycle()

self.pending[device].put((key, fn, args, kwargs), block=False)
self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)

def status(self) -> List[Tuple[str, int, bool, bool]]:
def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
history = [
(name, progress, False, name in self.cancelled_jobs)
for name, (_device, progress) in self.active_jobs.items()
(
name,
job.progress,
job.finished,
job.cancel,
job.error,
)
for name, job in self.active_jobs.items()
]
history.extend(
[
(
name,
progress,
True,
cancel,
job.job,
job.progress,
job.finished,
job.cancel,
job.error,
)
for name, progress, cancel in self.finished_jobs
for job in self.finished_jobs
]
)
return history
24 changes: 14 additions & 10 deletions api/onnx_web/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,36 @@ def worker_main(context: WorkerContext, server: ServerContext):
)
exit(EXIT_REPLACED)

name, fn, args, kwargs = context.pending.get(timeout=1.0)
logger.info("worker for %s got job: %s", context.device.device, name)
job = context.pending.get(timeout=1.0)
logger.info("worker for %s got job: %s", context.device.device, job.name)

context.job = name # TODO: hax
context.clear_flags()
logger.info("starting job: %s", name)
fn(context, *args, **kwargs)
logger.info("job succeeded: %s", name)
context.job = job.name # TODO: hax
logger.info("starting job: %s", job.name)
context.set_progress(0)
job.fn(context, *job.args, **job.kwargs)
logger.info("job succeeded: %s", job.name)
context.set_finished()
except Empty:
pass
except KeyboardInterrupt:
logger.info("worker got keyboard interrupt")
context.set_failed()
exit(EXIT_INTERRUPT)
except ValueError as e:
logger.info(
"value error in worker, exiting: %s",
format_exception(type(e), e, e.__traceback__),
logger.exception(
"value error in worker, exiting: %s"
)
context.set_failed()
exit(EXIT_ERROR)
except Exception as e:
e_str = str(e)
if "Failed to allocate memory" in e_str or "out of memory" in e_str:
logger.error("detected out-of-memory error, exiting: %s", e)
context.set_failed()
exit(EXIT_MEMORY)
else:
logger.exception(
"error while running job",
)
context.set_failed()
# carry on
2 changes: 2 additions & 0 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ export interface ImageResponse {
* Status response from the ready endpoint.
*/
export interface ReadyResponse {
/** Mirrors the `cancel` key of the server's ready reply; presumably true once the job has been cancelled. */
cancel: boolean;
/** Mirrors the `error` key of the server's ready reply; presumably true when the job failed. */
error: boolean;
/** Last progress value the server reported for the job. */
progress: number;
/** Whether the server reports the job as ready (finished). */
ready: boolean;
}
Expand Down

0 comments on commit 7cf5554

Please sign in to comment.