Skip to content

Commit

Permalink
feat(api): add progress to ready endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 4, 2023
1 parent 1491a9e commit 294c831
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
16 changes: 12 additions & 4 deletions api/onnx_web/device_pool.py
@@ -1,7 +1,7 @@
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
from logging import getLogger
from multiprocessing import Value
from typing import Any, Callable, List, Union, Optional
from typing import Any, Callable, List, Optional, Tuple, Union

logger = getLogger(__name__)

Expand Down Expand Up @@ -35,6 +35,9 @@ def get_device(self) -> str:
else:
return self.devices[device_index]

def get_progress(self) -> int:
return self.progress.value

def get_progress_callback(self) -> Callable[..., None]:
def on_progress(step: int, timestep: int, latents: Any):
if self.is_cancelled():
Expand Down Expand Up @@ -64,6 +67,9 @@ def __init__(
self.future = future
self.key = key

def get_progress(self) -> int:
self.context.get_progress()

def set_cancel(self, cancel: bool = True):
self.context.set_cancel(cancel)

Expand Down Expand Up @@ -94,13 +100,15 @@ def cancel(self, key: str) -> bool:
else:
job.set_cancel()

def done(self, key: str) -> bool:
def done(self, key: str) -> Tuple[bool, int]:
for job in self.jobs:
if job.key == key:
return job.future.done()
done = job.future.done()
progress = job.get_progress()
return (done, progress)

logger.warn('checking status for unknown key: %s', key)
return None
return (None, 0)

def prune(self):
self.jobs[:] = [job for job in self.jobs if job.future.done()]
Expand Down
7 changes: 4 additions & 3 deletions api/onnx_web/serve.py
Expand Up @@ -331,8 +331,9 @@ def load_platforms():
gc.set_debug(gc.DEBUG_STATS)


def ready_reply(ready: bool):
def ready_reply(ready: bool, progress: int = 0):
return jsonify({
'progress': progress,
'ready': ready,
})

Expand Down Expand Up @@ -609,14 +610,14 @@ def chain():
def ready():
output_file = request.args.get('output', None)

done = executor.done(output_file)
done, progress = executor.done(output_file)

if done is None:
file = base_join(context.output_path, output_file)
if path.exists(file):
return ready_reply(True)

return ready_reply(done)
return ready_reply(done, progress=progress)


@app.route('/api/cancel', methods=['PUT'])
Expand Down

0 comments on commit 294c831

Please sign in to comment.