From e46a1e5fd0e1c2dd018e31e40f1115608586f1bd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 25 Feb 2023 23:16:32 -0600 Subject: [PATCH 01/40] begin switching to per-device torch mp workers --- api/onnx_web/server/device_pool.py | 157 ++++++++++------------------- 1 file changed, 56 insertions(+), 101 deletions(-) diff --git a/api/onnx_web/server/device_pool.py b/api/onnx_web/server/device_pool.py index d6799b755..152e1d740 100644 --- a/api/onnx_web/server/device_pool.py +++ b/api/onnx_web/server/device_pool.py @@ -1,9 +1,11 @@ from collections import Counter -from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures import Future from logging import getLogger -from multiprocessing import Value +from multiprocessing import Queue +from torch.multiprocessing import Lock, Process, SimpleQueue, Value from traceback import format_exception -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from time import sleep from ..params import DeviceParams from ..utils import run_gc @@ -13,6 +15,18 @@ ProgressCallback = Callable[[int, int, Any], None] +def worker_init(lock: Lock, job_queue: SimpleQueue): + logger.info("checking in from worker") + + while True: + if job_queue.empty(): + logger.info("no jobs, sleeping") + sleep(5) + else: + job = job_queue.get() + logger.info("got job: %s", job) + + class JobContext: cancel: Value = None device_index: Value = None @@ -104,38 +118,31 @@ def set_progress(self, progress: int): class DevicePoolExecutor: devices: List[DeviceParams] = None - jobs: List[Job] = None - next_device: int = 0 - pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None - recent: List[Tuple[str, int]] = None + finished: List[Tuple[str, int]] = None + pending: Dict[str, "Queue[Job]"] = None + progress: Dict[str, Value] = None + workers: Dict[str, Process] = None def __init__( self, devices: List[DeviceParams], - pool: 
Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None, - recent_limit: int = 10, + finished_limit: int = 10, ): self.devices = devices - self.jobs = [] - self.next_device = 0 - self.recent = [] - self.recent_limit = recent_limit - - device_count = len(devices) - if pool is None: - logger.info( - "creating thread pool executor for %s devices: %s", - device_count, - [d.device for d in devices], - ) - self.pool = ThreadPoolExecutor(device_count) - else: - logger.info( - "using existing pool for %s devices: %s", - device_count, - [d.device for d in devices], - ) - self.pool = pool + self.finished = [] + self.finished_limit = finished_limit + self.lock = Lock() + self.pending = {} + self.progress = {} + self.workers = {} + + # create a pending queue and progress value for each device + for device in devices: + name = device.device + job_queue = Queue() + self.pending[name] = job_queue + self.progress[name] = Value("I", 0, lock=self.lock) + self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue)) def cancel(self, key: str) -> bool: """ @@ -143,31 +150,13 @@ def cancel(self, key: str) -> bool: the future and never execute it. If the job has been started, it should be cancelled on the next progress callback. 
""" - for job in self.jobs: - if job.key == key: - if job.future.cancel(): - return True - else: - job.set_cancel() - return True - - return False + raise NotImplementedError() def done(self, key: str) -> Tuple[Optional[bool], int]: - for k, progress in self.recent: + for k, progress in self.finished: if key == k: return (True, progress) - for job in self.jobs: - if job.key == key: - done = job.future.done() - progress = job.get_progress() - - if done: - self.prune() - - return (done, progress) - logger.warn("checking status for unknown key: %s", key) return (None, 0) @@ -198,24 +187,14 @@ def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: return lowest_devices[0] def prune(self): - pending_jobs = [job for job in self.jobs if job.future.done()] - logger.debug("pruning %s of %s pending jobs", len(pending_jobs), len(self.jobs)) - - for job in pending_jobs: - self.recent.append((job.key, job.get_progress())) - try: - self.jobs.remove(job) - except ValueError as e: - logger.warning("error removing pruned job from pending: %s", e) - - recent_count = len(self.recent) - if recent_count > self.recent_limit: + finished_count = len(self.finished) + if finished_count > self.finished_limit: logger.debug( - "pruning %s of %s recent jobs", - recent_count - self.recent_limit, - recent_count, + "pruning %s of %s finished jobs", + finished_count - self.finished_limit, + finished_count, ) - self.recent[:] = self.recent[-self.recent_limit :] + self.finished[:] = self.finished[-self.finished_limit:] def submit( self, @@ -227,49 +206,25 @@ def submit( **kwargs, ) -> None: self.prune() - device = self.get_next_device(needs_device=needs_device) + device_idx = self.get_next_device(needs_device=needs_device) logger.info( - "assigning job %s to device %s: %s", key, device, self.devices[device] + "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx] ) - context = JobContext(key, self.devices, device_index=device) - future = 
self.pool.submit(fn, context, *args, **kwargs) - job = Job(key, future, context) - self.jobs.append(job) - - def job_done(f: Future): - try: - f.result() - logger.info("job %s finished successfully", key) - except Exception as err: - logger.warn( - "job %s failed with an error: %s", - key, - format_exception(type(err), err, err.__traceback__), - ) - run_gc([self.devices[device]]) - - future.add_done_callback(job_done) + context = JobContext(key, self.devices, device_index=device_idx) + device = self.devices[device_idx] + + queue = self.pending[device.device] + queue.put((fn, context, args, kwargs)) + def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ ( - job.key, - job.context.device_index.value, - job.future.done(), - job.get_progress(), + device.device, + self.pending[device.device].qsize(), ) - for job in self.jobs + for device in self.devices ] - recent = [ - ( - key, - None, - True, - progress, - ) - for key, progress in self.recent - ] - - pending.extend(recent) + pending.extend(self.finished) return pending From f898de8c5490673f60bc91129f8fedb07e7495f6 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 25 Feb 2023 23:49:39 -0600 Subject: [PATCH 02/40] background workers, logger --- api/logging.yaml | 6 +- api/onnx_web/__init__.py | 6 +- api/onnx_web/chain/base.py | 7 +- api/onnx_web/chain/blend_img2img.py | 5 +- api/onnx_web/chain/blend_inpaint.py | 5 +- api/onnx_web/chain/blend_mask.py | 5 +- api/onnx_web/chain/correct_codeformer.py | 5 +- api/onnx_web/chain/correct_gfpgan.py | 6 +- api/onnx_web/chain/persist_disk.py | 6 +- api/onnx_web/chain/persist_s3.py | 6 +- api/onnx_web/chain/reduce_crop.py | 6 +- api/onnx_web/chain/reduce_thumbnail.py | 6 +- api/onnx_web/chain/source_noise.py | 6 +- api/onnx_web/chain/source_txt2img.py | 6 +- api/onnx_web/chain/upscale_outpaint.py | 5 +- api/onnx_web/chain/upscale_resrgan.py | 5 +- .../chain/upscale_stable_diffusion.py | 5 +- api/onnx_web/diffusion/run.py | 13 +- api/onnx_web/serve.py | 3 +- 
api/onnx_web/server/__init__.py | 7 - api/onnx_web/server/device_pool.py | 230 ------------------ api/onnx_web/transformers.py | 5 +- api/onnx_web/upscale.py | 5 +- api/onnx_web/worker/__init__.py | 2 + api/onnx_web/worker/context.py | 60 +++++ api/onnx_web/worker/logging.py | 1 + api/onnx_web/worker/pool.py | 136 +++++++++++ api/onnx_web/worker/worker.py | 32 +++ 28 files changed, 306 insertions(+), 284 deletions(-) delete mode 100644 api/onnx_web/server/device_pool.py create mode 100644 api/onnx_web/worker/__init__.py create mode 100644 api/onnx_web/worker/context.py create mode 100644 api/onnx_web/worker/logging.py create mode 100644 api/onnx_web/worker/pool.py create mode 100644 api/onnx_web/worker/worker.py diff --git a/api/logging.yaml b/api/logging.yaml index 0bf543100..24bd3c292 100644 --- a/api/logging.yaml +++ b/api/logging.yaml @@ -5,14 +5,14 @@ formatters: handlers: console: class: logging.StreamHandler - level: INFO + level: DEBUG formatter: simple stream: ext://sys.stdout loggers: '': - level: INFO + level: DEBUG handlers: [console] propagate: True root: - level: INFO + level: DEBUG handlers: [console] diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index 3afedb07f..7316bb87e 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -25,6 +25,7 @@ from .onnx import OnnxNet, OnnxTensor from .params import ( Border, + DeviceParams, ImageParams, Param, Point, @@ -33,8 +34,6 @@ UpscaleParams, ) from .server import ( - DeviceParams, - DevicePoolExecutor, ModelCache, ServerContext, apply_patch_basicsr, @@ -51,3 +50,6 @@ get_from_map, get_not_empty, ) +from .worker import ( + DevicePoolExecutor, +) \ No newline at end of file diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 7e77bab7c..dc8073223 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -7,7 +7,8 @@ from ..output import save_image from ..params import ImageParams, StageParams -from ..server import JobContext, 
ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import is_debug from .utils import process_tile_order @@ -17,7 +18,7 @@ class StageCallback(Protocol): def __call__( self, - job: JobContext, + job: WorkerContext, ctx: ServerContext, stage: StageParams, params: ImageParams, @@ -77,7 +78,7 @@ def append(self, stage: PipelineStage): def __call__( self, - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, source: Image.Image, diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 0ef9ef960..675311036 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -7,13 +7,14 @@ from ..diffusion.load import load_pipeline from ..params import ImageParams, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext logger = getLogger(__name__) def blend_img2img( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 16422cce3..51b6d983b 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -10,7 +10,8 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import is_debug from .utils import process_tile_order @@ -18,7 +19,7 @@ def blend_inpaint( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/chain/blend_mask.py 
b/api/onnx_web/chain/blend_mask.py index be5d11597..bfe11aabe 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -7,14 +7,15 @@ from onnx_web.output import save_image from ..params import ImageParams, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import is_debug logger = getLogger(__name__) def blend_mask( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 6b4e235d7..01d61db7c 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -3,7 +3,8 @@ from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext logger = getLogger(__name__) @@ -11,7 +12,7 @@ def correct_codeformer( - job: JobContext, + job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 6796c1bab..afcae86b3 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -5,8 +5,10 @@ from PIL import Image from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ServerContext from ..utils import run_gc +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) @@ -46,7 +48,7 @@ def load_gfpgan( def correct_gfpgan( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 9a5f0cd0c..58020b57a 100644 --- 
a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -4,13 +4,15 @@ from ..output import save_image from ..params import ImageParams, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def persist_disk( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index bf3682bf3..926f1598b 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -5,13 +5,15 @@ from PIL import Image from ..params import ImageParams, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def persist_s3( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 226f6cf26..4cd715b16 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -3,13 +3,15 @@ from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def reduce_crop( - _job: JobContext, + _job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 0037084c0..4950a9731 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -3,13 +3,15 @@ from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import 
ServerContext + logger = getLogger(__name__) def reduce_thumbnail( - _job: JobContext, + _job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index f6267b26b..9ab302b1e 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -4,13 +4,15 @@ from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext + logger = getLogger(__name__) def source_noise( - _job: JobContext, + _job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index a5cbb07fe..b933ecc97 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -7,13 +7,15 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext + logger = getLogger(__name__) def source_txt2img( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 7919e3251..23393491e 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -10,7 +10,8 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import 
is_debug from .utils import process_tile_grid, process_tile_order @@ -18,7 +19,7 @@ def upscale_outpaint( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 178360c36..ccbb3644f 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -6,7 +6,8 @@ from ..onnx import OnnxNet from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext from ..utils import run_gc logger = getLogger(__name__) @@ -96,7 +97,7 @@ def load_resrgan( def upscale_resrgan( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 5747f13fe..00c1b9d4a 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -10,7 +10,8 @@ OnnxStableDiffusionUpscalePipeline, ) from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..worker import WorkerContext, ProgressCallback +from ..server import ServerContext from ..utils import run_gc logger = getLogger(__name__) @@ -62,7 +63,7 @@ def load_stable_diffusion( def upscale_stable_diffusion( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index ec988d660..0e44b6cc4 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -12,7 +12,8 @@ from ..chain import upscale_outpaint from ..output import save_image, save_params from ..params import Border, ImageParams, Size, StageParams, 
UpscaleParams -from ..server import JobContext, ServerContext +from ..worker import WorkerContext +from ..server import ServerContext from ..upscale import run_upscale_correction from ..utils import run_gc from .load import get_latents_from_seed, load_pipeline @@ -21,7 +22,7 @@ def run_txt2img_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, @@ -95,7 +96,7 @@ def run_txt2img_pipeline( def run_img2img_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, outputs: List[str], @@ -167,7 +168,7 @@ def run_img2img_pipeline( def run_inpaint_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, @@ -217,7 +218,7 @@ def run_inpaint_pipeline( def run_upscale_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, @@ -243,7 +244,7 @@ def run_upscale_pipeline( def run_blend_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index fb95d8ef3..4d20f2960 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -63,7 +63,7 @@ TileOrder, UpscaleParams, ) -from .server import DevicePoolExecutor, ServerContext, apply_patches +from .server import ServerContext, apply_patches from .transformers import run_txt2txt_pipeline from .utils import ( base_join, @@ -75,6 +75,7 @@ get_size, is_debug, ) +from .worker import DevicePoolExecutor logger = getLogger(__name__) diff --git a/api/onnx_web/server/__init__.py b/api/onnx_web/server/__init__.py index 0403746c5..f02fa35a6 100644 --- a/api/onnx_web/server/__init__.py +++ b/api/onnx_web/server/__init__.py @@ -1,10 +1,3 @@ -from .device_pool import ( - DeviceParams, - DevicePoolExecutor, - Job, - JobContext, - ProgressCallback, -) from .hacks import ( apply_patch_basicsr, apply_patch_codeformer, diff --git 
a/api/onnx_web/server/device_pool.py b/api/onnx_web/server/device_pool.py deleted file mode 100644 index 152e1d740..000000000 --- a/api/onnx_web/server/device_pool.py +++ /dev/null @@ -1,230 +0,0 @@ -from collections import Counter -from concurrent.futures import Future -from logging import getLogger -from multiprocessing import Queue -from torch.multiprocessing import Lock, Process, SimpleQueue, Value -from traceback import format_exception -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from time import sleep - -from ..params import DeviceParams -from ..utils import run_gc - -logger = getLogger(__name__) - -ProgressCallback = Callable[[int, int, Any], None] - - -def worker_init(lock: Lock, job_queue: SimpleQueue): - logger.info("checking in from worker") - - while True: - if job_queue.empty(): - logger.info("no jobs, sleeping") - sleep(5) - else: - job = job_queue.get() - logger.info("got job: %s", job) - - -class JobContext: - cancel: Value = None - device_index: Value = None - devices: List[DeviceParams] = None - key: str = None - progress: Value = None - - def __init__( - self, - key: str, - devices: List[DeviceParams], - cancel: bool = False, - device_index: int = -1, - progress: int = 0, - ): - self.key = key - self.devices = list(devices) - self.cancel = Value("B", cancel) - self.device_index = Value("i", device_index) - self.progress = Value("I", progress) - - def is_cancelled(self) -> bool: - return self.cancel.value - - def get_device(self) -> DeviceParams: - """ - Get the device assigned to this job. 
- """ - with self.device_index.get_lock(): - device_index = self.device_index.value - if device_index < 0: - raise ValueError("job has not been assigned to a device") - else: - device = self.devices[device_index] - logger.debug("job %s assigned to device %s", self.key, device) - return device - - def get_progress(self) -> int: - return self.progress.value - - def get_progress_callback(self) -> ProgressCallback: - def on_progress(step: int, timestep: int, latents: Any): - on_progress.step = step - if self.is_cancelled(): - raise RuntimeError("job has been cancelled") - else: - logger.debug("setting progress for job %s to %s", self.key, step) - self.set_progress(step) - - return on_progress - - def set_cancel(self, cancel: bool = True) -> None: - with self.cancel.get_lock(): - self.cancel.value = cancel - - def set_progress(self, progress: int) -> None: - with self.progress.get_lock(): - self.progress.value = progress - - -class Job: - """ - Link a future to its context. - """ - - context: JobContext = None - future: Future = None - key: str = None - - def __init__( - self, - key: str, - future: Future, - context: JobContext, - ): - self.context = context - self.future = future - self.key = key - - def get_progress(self) -> int: - return self.context.get_progress() - - def set_cancel(self, cancel: bool = True): - return self.context.set_cancel(cancel) - - def set_progress(self, progress: int): - return self.context.set_progress(progress) - - -class DevicePoolExecutor: - devices: List[DeviceParams] = None - finished: List[Tuple[str, int]] = None - pending: Dict[str, "Queue[Job]"] = None - progress: Dict[str, Value] = None - workers: Dict[str, Process] = None - - def __init__( - self, - devices: List[DeviceParams], - finished_limit: int = 10, - ): - self.devices = devices - self.finished = [] - self.finished_limit = finished_limit - self.lock = Lock() - self.pending = {} - self.progress = {} - self.workers = {} - - # create a pending queue and progress value for each 
device - for device in devices: - name = device.device - job_queue = Queue() - self.pending[name] = job_queue - self.progress[name] = Value("I", 0, lock=self.lock) - self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue)) - - def cancel(self, key: str) -> bool: - """ - Cancel a job. If the job has not been started, this will cancel - the future and never execute it. If the job has been started, it - should be cancelled on the next progress callback. - """ - raise NotImplementedError() - - def done(self, key: str) -> Tuple[Optional[bool], int]: - for k, progress in self.finished: - if key == k: - return (True, progress) - - logger.warn("checking status for unknown key: %s", key) - return (None, 0) - - def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: - # respect overrides if possible - if needs_device is not None: - for i in range(len(self.devices)): - if self.devices[i].device == needs_device.device: - return i - - # use the first/default device if there are no jobs - if len(self.jobs) == 0: - return 0 - - job_devices = [ - job.context.device_index.value for job in self.jobs if not job.future.done() - ] - job_counts = Counter(range(len(self.devices))) - job_counts.update(job_devices) - - queued = job_counts.most_common() - logger.debug("jobs queued by device: %s", queued) - - lowest_count = queued[-1][1] - lowest_devices = [d[0] for d in queued if d[1] == lowest_count] - lowest_devices.sort() - - return lowest_devices[0] - - def prune(self): - finished_count = len(self.finished) - if finished_count > self.finished_limit: - logger.debug( - "pruning %s of %s finished jobs", - finished_count - self.finished_limit, - finished_count, - ) - self.finished[:] = self.finished[-self.finished_limit:] - - def submit( - self, - key: str, - fn: Callable[..., None], - /, - *args, - needs_device: Optional[DeviceParams] = None, - **kwargs, - ) -> None: - self.prune() - device_idx = self.get_next_device(needs_device=needs_device) - 
logger.info( - "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx] - ) - - context = JobContext(key, self.devices, device_index=device_idx) - device = self.devices[device_idx] - - queue = self.pending[device.device] - queue.put((fn, context, args, kwargs)) - - - def status(self) -> List[Tuple[str, int, bool, int]]: - pending = [ - ( - device.device, - self.pending[device.device].qsize(), - ) - for device in self.devices - ] - pending.extend(self.finished) - return pending diff --git a/api/onnx_web/transformers.py b/api/onnx_web/transformers.py index f7a70693b..18d90f0a1 100644 --- a/api/onnx_web/transformers.py +++ b/api/onnx_web/transformers.py @@ -1,13 +1,14 @@ from logging import getLogger from .params import ImageParams, Size -from .server import JobContext, ServerContext +from .server import ServerContext +from .worker import WorkerContext logger = getLogger(__name__) def run_txt2txt_pipeline( - job: JobContext, + job: WorkerContext, _server: ServerContext, params: ImageParams, _size: Size, diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index c04d6efb4..8636f8c1d 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -10,13 +10,14 @@ upscale_stable_diffusion, ) from .params import ImageParams, SizeChart, StageParams, UpscaleParams -from .server import JobContext, ProgressCallback, ServerContext +from .server import ServerContext +from .worker import WorkerContext, ProgressCallback logger = getLogger(__name__) def run_upscale_correction( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, diff --git a/api/onnx_web/worker/__init__.py b/api/onnx_web/worker/__init__.py new file mode 100644 index 000000000..0ca5eefc7 --- /dev/null +++ b/api/onnx_web/worker/__init__.py @@ -0,0 +1,2 @@ +from .context import WorkerContext, ProgressCallback +from .pool import DevicePoolExecutor \ No newline at end of file diff --git a/api/onnx_web/worker/context.py 
b/api/onnx_web/worker/context.py new file mode 100644 index 000000000..bf927f03c --- /dev/null +++ b/api/onnx_web/worker/context.py @@ -0,0 +1,60 @@ +from logging import getLogger +from torch.multiprocessing import Queue, Value +from typing import Any, Callable + +from ..params import DeviceParams + +logger = getLogger(__name__) + + +ProgressCallback = Callable[[int, int, Any], None] + +class WorkerContext: + cancel: "Value[bool]" = None + key: str = None + progress: "Value[int]" = None + + def __init__( + self, + key: str, + cancel: "Value[bool]", + device: DeviceParams, + pending: "Queue[Any]", + progress: "Value[int]", + ): + self.key = key + self.cancel = cancel + self.device = device + self.pending = pending + self.progress = progress + + def is_cancelled(self) -> bool: + return self.cancel.value + + def get_device(self) -> DeviceParams: + """ + Get the device assigned to this job. + """ + return self.device + + def get_progress(self) -> int: + return self.progress.value + + def get_progress_callback(self) -> ProgressCallback: + def on_progress(step: int, timestep: int, latents: Any): + on_progress.step = step + if self.is_cancelled(): + raise RuntimeError("job has been cancelled") + else: + logger.debug("setting progress for job %s to %s", self.key, step) + self.set_progress(step) + + return on_progress + + def set_cancel(self, cancel: bool = True) -> None: + with self.cancel.get_lock(): + self.cancel.value = cancel + + def set_progress(self, progress: int) -> None: + with self.progress.get_lock(): + self.progress.value = progress diff --git a/api/onnx_web/worker/logging.py b/api/onnx_web/worker/logging.py new file mode 100644 index 000000000..39808a640 --- /dev/null +++ b/api/onnx_web/worker/logging.py @@ -0,0 +1 @@ +# TODO: queue-based logger \ No newline at end of file diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py new file mode 100644 index 000000000..7d83427c8 --- /dev/null +++ b/api/onnx_web/worker/pool.py @@ -0,0 +1,136 @@ +from 
collections import Counter +from logging import getLogger +from multiprocessing import Queue +from torch.multiprocessing import Lock, Process, Value +from typing import Callable, Dict, List, Optional, Tuple + +from ..params import DeviceParams +from .context import WorkerContext +from .worker import logger_init, worker_init + +logger = getLogger(__name__) + + +class DevicePoolExecutor: + devices: List[DeviceParams] = None + finished: List[Tuple[str, int]] = None + pending: Dict[str, "Queue[WorkerContext]"] = None + progress: Dict[str, Value] = None + workers: Dict[str, Process] = None + + def __init__( + self, + devices: List[DeviceParams], + finished_limit: int = 10, + ): + self.devices = devices + self.finished = [] + self.finished_limit = finished_limit + self.lock = Lock() + self.pending = {} + self.progress = {} + self.workers = {} + + log_queue = Queue() + logger_context = WorkerContext("logger", None, None, log_queue, None) + + logger.debug("starting log worker") + self.logger = Process(target=logger_init, args=(self.lock, logger_context)) + self.logger.start() + + # create a pending queue and progress value for each device + for device in devices: + name = device.device + cancel = Value("B", False, lock=self.lock) + progress = Value("I", 0, lock=self.lock) + pending = Queue() + context = WorkerContext(name, cancel, device, pending, progress) + self.pending[name] = pending + self.progress[name] = pending + + logger.debug("starting worker for device %s", device) + self.workers[name] = Process(target=worker_init, args=(self.lock, context)) + self.workers[name].start() + + logger.debug("testing log worker") + log_queue.put("testing") + + def cancel(self, key: str) -> bool: + """ + Cancel a job. If the job has not been started, this will cancel + the future and never execute it. If the job has been started, it + should be cancelled on the next progress callback. 
+ """ + raise NotImplementedError() + + def done(self, key: str) -> Tuple[Optional[bool], int]: + for k, progress in self.finished: + if key == k: + return (True, progress) + + logger.warn("checking status for unknown key: %s", key) + return (None, 0) + + def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: + # respect overrides if possible + if needs_device is not None: + for i in range(len(self.devices)): + if self.devices[i].device == needs_device.device: + return i + + pending = [ + self.pending[d.device].qsize() for d in self.devices + ] + jobs = Counter(range(len(self.devices))) + jobs.update(pending) + + queued = jobs.most_common() + logger.debug("jobs queued by device: %s", queued) + + lowest_count = queued[-1][1] + lowest_devices = [d[0] for d in queued if d[1] == lowest_count] + lowest_devices.sort() + + return lowest_devices[0] + + def prune(self): + finished_count = len(self.finished) + if finished_count > self.finished_limit: + logger.debug( + "pruning %s of %s finished jobs", + finished_count - self.finished_limit, + finished_count, + ) + self.finished[:] = self.finished[-self.finished_limit:] + + def submit( + self, + key: str, + fn: Callable[..., None], + /, + *args, + needs_device: Optional[DeviceParams] = None, + **kwargs, + ) -> None: + self.prune() + device_idx = self.get_next_device(needs_device=needs_device) + logger.info( + "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx] + ) + + device = self.devices[device_idx] + queue = self.pending[device.device] + queue.put((fn, args, kwargs)) + + + def status(self) -> List[Tuple[str, int, bool, int]]: + pending = [ + ( + device.device, + self.pending[device.device].qsize(), + self.workers[device.device].is_alive(), + ) + for device in self.devices + ] + pending.extend(self.finished) + return pending diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py new file mode 100644 index 000000000..8f6adccb0 --- /dev/null +++ 
def logger_init(lock: Lock, context: "WorkerContext"):
    """Log worker entry point: drain the log queue into worker.log forever."""
    logger.info("checking in from logger")

    with open("worker.log", "w") as f:
        while True:
            # BUG FIX: Queue.empty() is unreliable for multiprocessing queues
            # and the empty/sleep(5) poll added up to 5s of latency per entry;
            # a blocking get() delivers entries as soon as they arrive
            job = context.pending.get()
            logger.info("got log: %s", job)
            f.write(str(job) + "\n\n")
            # flush eagerly: this loop never exits, so buffered entries
            # would otherwise never reach disk
            f.flush()


def worker_init(lock: Lock, context: "WorkerContext"):
    """Device worker entry point: consume jobs from the pending queue forever."""
    logger.info("checking in from worker")

    while True:
        # block until a job is queued instead of polling with sleep(5)
        job = context.pending.get()
        logger.info("got job: %s", job)
print_exception from .context import WorkerContext @@ -30,3 +31,10 @@ def worker_init(lock: Lock, context: WorkerContext): else: job = context.pending.get() logger.info("got job: %s", job) + try: + fn, args, kwargs = job + fn(context, *args, **kwargs) + logger.info("finished job") + except Exception as e: + print_exception(type(e), e, e.__traceback__) + From 06c74a7a96b73facd19a7ed575252679748eac41 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 10:15:12 -0600 Subject: [PATCH 04/40] feat(api): remove Flask app from global scope --- api/.gitignore | 1 + api/launch-extras.sh | 2 +- api/launch.sh | 2 +- api/onnx_web/__init__.py | 7 +- api/onnx_web/chain/__init__.py | 17 + api/onnx_web/diffusion/load.py | 17 +- api/onnx_web/main.py | 53 ++ api/onnx_web/output.py | 5 +- api/onnx_web/params.py | 4 +- api/onnx_web/serve.py | 881 --------------------------------- api/onnx_web/server/api.py | 477 ++++++++++++++++++ api/onnx_web/server/config.py | 224 +++++++++ api/onnx_web/server/params.py | 183 +++++++ api/onnx_web/server/static.py | 34 ++ api/onnx_web/server/utils.py | 32 ++ 15 files changed, 1044 insertions(+), 895 deletions(-) create mode 100644 api/onnx_web/main.py delete mode 100644 api/onnx_web/serve.py create mode 100644 api/onnx_web/server/api.py create mode 100644 api/onnx_web/server/config.py create mode 100644 api/onnx_web/server/params.py create mode 100644 api/onnx_web/server/static.py create mode 100644 api/onnx_web/server/utils.py diff --git a/api/.gitignore b/api/.gitignore index a315070ee..2ba1650c7 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -1,6 +1,7 @@ .coverage coverage.xml +*.log *.swp *.pyc diff --git a/api/launch-extras.sh b/api/launch-extras.sh index 96db9bb2c..f18e14c01 100755 --- a/api/launch-extras.sh +++ b/api/launch-extras.sh @@ -25,4 +25,4 @@ python3 -m onnx_web.convert \ --token=${HF_TOKEN:-} echo "Launching API server..." 
-flask --app=onnx_web.serve run --host=0.0.0.0 +flask --app='onnx_web.main:main()' run --host=0.0.0.0 diff --git a/api/launch.sh b/api/launch.sh index 50863ba88..983e09308 100755 --- a/api/launch.sh +++ b/api/launch.sh @@ -24,4 +24,4 @@ python3 -m onnx_web.convert \ --token=${HF_TOKEN:-} echo "Launching API server..." -flask --app=onnx_web.serve run --host=0.0.0.0 +flask --app='onnx_web.main:main()' run --host=0.0.0.0 diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index 7316bb87e..b019bb3db 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -1,5 +1,10 @@ from . import logging -from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion +from .chain import ( + correct_codeformer, + correct_gfpgan, + upscale_resrgan, + upscale_stable_diffusion, +) from .diffusion.load import get_latents_from_seed, load_pipeline, optimize_pipeline from .diffusion.run import ( run_blend_pipeline, diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index 5aa56c567..44fdd6c1e 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -13,3 +13,20 @@ from .upscale_outpaint import upscale_outpaint from .upscale_resrgan import upscale_resrgan from .upscale_stable_diffusion import upscale_stable_diffusion + +CHAIN_STAGES = { + "blend-img2img": blend_img2img, + "blend-inpaint": blend_inpaint, + "blend-mask": blend_mask, + "correct-codeformer": correct_codeformer, + "correct-gfpgan": correct_gfpgan, + "persist-disk": persist_disk, + "persist-s3": persist_s3, + "reduce-crop": reduce_crop, + "reduce-thumbnail": reduce_thumbnail, + "source-noise": source_noise, + "source-txt2img": source_txt2img, + "upscale-outpaint": upscale_outpaint, + "upscale-resrgan": upscale_resrgan, + "upscale-stable-diffusion": upscale_stable_diffusion, +} \ No newline at end of file diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 47c35fcdd..1fff3f1d1 100644 --- 
a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -4,9 +4,11 @@ import numpy as np from diffusers import ( + DiffusionPipeline, + OnnxRuntimeModel, + StableDiffusionPipeline, DDIMScheduler, DDPMScheduler, - DiffusionPipeline, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, @@ -17,15 +19,13 @@ KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LMSDiscreteScheduler, - OnnxRuntimeModel, PNDMScheduler, - StableDiffusionPipeline, ) try: from diffusers import DEISMultistepScheduler except ImportError: - from .stub_scheduler import StubScheduler as DEISMultistepScheduler + from ..diffusion.stub_scheduler import StubScheduler as DEISMultistepScheduler from ..params import DeviceParams, Size from ..server import ServerContext @@ -54,6 +54,10 @@ } +def get_pipeline_schedulers(): + return pipeline_schedulers + + def get_scheduler_name(scheduler: Any) -> Optional[str]: for k, v in pipeline_schedulers.items(): if scheduler == v or scheduler == v.__name__: @@ -137,13 +141,14 @@ def load_pipeline( server: ServerContext, pipeline: DiffusionPipeline, model: str, - scheduler_type: Any, + scheduler_name: str, device: DeviceParams, lpw: bool, inversion: Optional[str], ): pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion) - scheduler_key = (scheduler_type, model) + scheduler_key = (scheduler_name, model) + scheduler_type = get_pipeline_schedulers()[scheduler_name] cache_pipe = server.cache.get("diffusion", pipe_key) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py new file mode 100644 index 000000000..6381a2416 --- /dev/null +++ b/api/onnx_web/main.py @@ -0,0 +1,53 @@ +import gc + +from diffusers.utils.logging import disable_progress_bar +from flask import Flask +from flask_cors import CORS +from huggingface_hub.utils.tqdm import disable_progress_bars + +from .server.api import register_api_routes +from .server.static import register_static_routes +from .server.config import 
def main():
    """Build and return the Flask app with all routes and workers wired up."""
    server = ServerContext.from_environ()
    apply_patches(server)
    check_paths(server)
    load_models(server)
    load_params(server)
    load_platforms(server)

    if is_debug():
        gc.set_debug(gc.DEBUG_STATS)

    if not server.show_progress:
        disable_progress_bar()
        disable_progress_bars()

    app = Flask(__name__)
    CORS(app, origins=server.cors_origin)

    # any is a fake device, should not be in the pool
    real_devices = [p for p in get_available_platforms() if p.device != "any"]
    pool = DevicePoolExecutor(real_devices)

    # register routes
    register_static_routes(app, server, pool)
    register_api_routes(app, server, pool)

    return app #, context, pool


if __name__ == "__main__":
    # app, context, pool = main()
    app = main()
    app.run("0.0.0.0", 5000, debug=is_debug())
    # pool.join()
a/api/onnx_web/params.py b/api/onnx_web/params.py index db23414aa..07cf082a1 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -148,7 +148,7 @@ class ImageParams: def __init__( self, model: str, - scheduler: Any, + scheduler: str, prompt: str, cfg: float, steps: int, @@ -174,7 +174,7 @@ def __init__( def tojson(self) -> Dict[str, Optional[Param]]: return { "model": self.model, - "scheduler": self.scheduler.__name__, + "scheduler": self.scheduler, "prompt": self.prompt, "negative_prompt": self.negative_prompt, "cfg": self.cfg, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py deleted file mode 100644 index 4d20f2960..000000000 --- a/api/onnx_web/serve.py +++ /dev/null @@ -1,881 +0,0 @@ -import gc -from functools import cmp_to_key -from glob import glob -from io import BytesIO -from logging import getLogger -from os import makedirs, path -from typing import Dict, List, Tuple, Union - -import numpy as np -import torch -import yaml -from diffusers.utils.logging import disable_progress_bar -from flask import Flask, jsonify, make_response, request, send_from_directory, url_for -from flask_cors import CORS -from huggingface_hub.utils.tqdm import disable_progress_bars -from jsonschema import validate -from onnxruntime import get_available_providers -from PIL import Image - -from .chain import ( - ChainPipeline, - blend_img2img, - blend_inpaint, - correct_codeformer, - correct_gfpgan, - persist_disk, - persist_s3, - reduce_crop, - reduce_thumbnail, - source_noise, - source_txt2img, - upscale_outpaint, - upscale_resrgan, - upscale_stable_diffusion, -) -from .diffusion.load import pipeline_schedulers -from .diffusion.run import ( - run_blend_pipeline, - run_img2img_pipeline, - run_inpaint_pipeline, - run_txt2img_pipeline, - run_upscale_pipeline, -) -from .image import ( # mask filters; noise sources - mask_filter_gaussian_multiply, - mask_filter_gaussian_screen, - mask_filter_none, - noise_source_fill_edge, - noise_source_fill_mask, - 
noise_source_gaussian, - noise_source_histogram, - noise_source_normal, - noise_source_uniform, - valid_image, -) -from .output import json_params, make_output_name -from .params import ( - Border, - DeviceParams, - ImageParams, - Size, - StageParams, - TileOrder, - UpscaleParams, -) -from .server import ServerContext, apply_patches -from .transformers import run_txt2txt_pipeline -from .utils import ( - base_join, - get_and_clamp_float, - get_and_clamp_int, - get_from_list, - get_from_map, - get_not_empty, - get_size, - is_debug, -) -from .worker import DevicePoolExecutor - -logger = getLogger(__name__) - -# config caching -config_params: Dict[str, Dict[str, Union[float, int, str]]] = {} - -# pipeline params -platform_providers = { - "cpu": "CPUExecutionProvider", - "cuda": "CUDAExecutionProvider", - "directml": "DmlExecutionProvider", - "rocm": "ROCMExecutionProvider", -} - -noise_sources = { - "fill-edge": noise_source_fill_edge, - "fill-mask": noise_source_fill_mask, - "gaussian": noise_source_gaussian, - "histogram": noise_source_histogram, - "normal": noise_source_normal, - "uniform": noise_source_uniform, -} -mask_filters = { - "none": mask_filter_none, - "gaussian-multiply": mask_filter_gaussian_multiply, - "gaussian-screen": mask_filter_gaussian_screen, -} -chain_stages = { - "blend-img2img": blend_img2img, - "blend-inpaint": blend_inpaint, - "correct-codeformer": correct_codeformer, - "correct-gfpgan": correct_gfpgan, - "persist-disk": persist_disk, - "persist-s3": persist_s3, - "reduce-crop": reduce_crop, - "reduce-thumbnail": reduce_thumbnail, - "source-noise": source_noise, - "source-txt2img": source_txt2img, - "upscale-outpaint": upscale_outpaint, - "upscale-resrgan": upscale_resrgan, - "upscale-stable-diffusion": upscale_stable_diffusion, -} - -# Available ORT providers -available_platforms: List[DeviceParams] = [] - -# loaded from model_path -correction_models: List[str] = [] -diffusion_models: List[str] = [] -inversion_models: List[str] = [] 
-upscaling_models: List[str] = [] - - -def get_config_value(key: str, subkey: str = "default", default=None): - return config_params.get(key, {}).get(subkey, default) - - -def url_from_rule(rule) -> str: - options = {} - for arg in rule.arguments: - options[arg] = ":%s" % (arg) - - return url_for(rule.endpoint, **options) - - -def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: - user = request.remote_addr - - # platform stuff - device = None - device_name = request.args.get("platform") - - if device_name is not None and device_name != "any": - for platform in available_platforms: - if platform.device == device_name: - device = platform - - # pipeline stuff - lpw = get_not_empty(request.args, "lpw", "false") == "true" - model = get_not_empty(request.args, "model", get_config_value("model")) - model_path = get_model_path(model) - scheduler = get_from_map( - request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler") - ) - - inversion = request.args.get("inversion", None) - inversion_path = None - if inversion is not None and inversion.strip() != "": - inversion_path = get_model_path(inversion) - - # image params - prompt = get_not_empty(request.args, "prompt", get_config_value("prompt")) - negative_prompt = request.args.get("negativePrompt", None) - - if negative_prompt is not None and negative_prompt.strip() == "": - negative_prompt = None - - batch = get_and_clamp_int( - request.args, - "batch", - get_config_value("batch"), - get_config_value("batch", "max"), - get_config_value("batch", "min"), - ) - cfg = get_and_clamp_float( - request.args, - "cfg", - get_config_value("cfg"), - get_config_value("cfg", "max"), - get_config_value("cfg", "min"), - ) - eta = get_and_clamp_float( - request.args, - "eta", - get_config_value("eta"), - get_config_value("eta", "max"), - get_config_value("eta", "min"), - ) - steps = get_and_clamp_int( - request.args, - "steps", - get_config_value("steps"), - get_config_value("steps", "max"), - 
get_config_value("steps", "min"), - ) - height = get_and_clamp_int( - request.args, - "height", - get_config_value("height"), - get_config_value("height", "max"), - get_config_value("height", "min"), - ) - width = get_and_clamp_int( - request.args, - "width", - get_config_value("width"), - get_config_value("width", "max"), - get_config_value("width", "min"), - ) - - seed = int(request.args.get("seed", -1)) - if seed == -1: - # this one can safely use np.random because it produces a single value - seed = np.random.randint(np.iinfo(np.int32).max) - - logger.info( - "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", - user, - steps, - scheduler.__name__, - model_path, - device or "any device", - width, - height, - cfg, - seed, - prompt, - ) - - params = ImageParams( - model_path, - scheduler, - prompt, - cfg, - steps, - seed, - eta=eta, - lpw=lpw, - negative_prompt=negative_prompt, - batch=batch, - inversion=inversion_path, - ) - size = Size(width, height) - return (device, params, size) - - -def border_from_request() -> Border: - left = get_and_clamp_int( - request.args, "left", 0, get_config_value("width", "max"), 0 - ) - right = get_and_clamp_int( - request.args, "right", 0, get_config_value("width", "max"), 0 - ) - top = get_and_clamp_int( - request.args, "top", 0, get_config_value("height", "max"), 0 - ) - bottom = get_and_clamp_int( - request.args, "bottom", 0, get_config_value("height", "max"), 0 - ) - - return Border(left, right, top, bottom) - - -def upscale_from_request() -> UpscaleParams: - denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0) - scale = get_and_clamp_int(request.args, "scale", 1, 4, 1) - outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1) - upscaling = get_from_list(request.args, "upscaling", upscaling_models) - correction = get_from_list(request.args, "correction", correction_models) - faces = get_not_empty(request.args, "faces", "false") == "true" - face_outscale = 
get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) - face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) - upscale_order = request.args.get("upscaleOrder", "correction-first") - - return UpscaleParams( - upscaling, - correction_model=correction, - denoise=denoise, - faces=faces, - face_outscale=face_outscale, - face_strength=face_strength, - format="onnx", - outscale=outscale, - scale=scale, - upscale_order=upscale_order, - ) - - -def check_paths(context: ServerContext) -> None: - if not path.exists(context.model_path): - raise RuntimeError("model path must exist") - - if not path.exists(context.output_path): - makedirs(context.output_path) - - -def get_model_name(model: str) -> str: - base = path.basename(model) - (file, _ext) = path.splitext(base) - return file - - -def load_models(context: ServerContext) -> None: - global correction_models - global diffusion_models - global inversion_models - global upscaling_models - - diffusion_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*")) - ] - diffusion_models.extend( - [ - get_model_name(f) - for f in glob(path.join(context.model_path, "stable-diffusion-*")) - ] - ) - diffusion_models = list(set(diffusion_models)) - diffusion_models.sort() - - correction_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) - ] - correction_models = list(set(correction_models)) - correction_models.sort() - - inversion_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) - ] - inversion_models = list(set(inversion_models)) - inversion_models.sort() - - upscaling_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) - ] - upscaling_models = list(set(upscaling_models)) - upscaling_models.sort() - - -def load_params(context: ServerContext) -> None: - global config_params - params_file = path.join(context.params_path, "params.json") - with 
open(params_file, "r") as f: - config_params = yaml.safe_load(f) - - if "platform" in config_params and context.default_platform is not None: - logger.info( - "Overriding default platform from environment: %s", - context.default_platform, - ) - config_platform = config_params.get("platform", {}) - config_platform["default"] = context.default_platform - - -def load_platforms(context: ServerContext) -> None: - global available_platforms - - providers = list(get_available_providers()) - - for potential in platform_providers: - if ( - platform_providers[potential] in providers - and potential not in context.block_platforms - ): - if potential == "cuda": - for i in range(torch.cuda.device_count()): - available_platforms.append( - DeviceParams( - potential, - platform_providers[potential], - { - "device_id": i, - }, - context.optimizations, - ) - ) - else: - available_platforms.append( - DeviceParams( - potential, - platform_providers[potential], - None, - context.optimizations, - ) - ) - - if context.any_platform: - # the platform should be ignored when the job is scheduled, but set to CPU just in case - available_platforms.append( - DeviceParams( - "any", - platform_providers["cpu"], - None, - context.optimizations, - ) - ) - - # make sure CPU is last on the list - def any_first_cpu_last(a: DeviceParams, b: DeviceParams): - if a.device == b.device: - return 0 - - # any should be first, if it's available - if a.device == "any": - return -1 - - # cpu should be last, if it's available - if a.device == "cpu": - return 1 - - return -1 - - available_platforms = sorted( - available_platforms, key=cmp_to_key(any_first_cpu_last) - ) - - logger.info( - "available acceleration platforms: %s", - ", ".join([str(p) for p in available_platforms]), - ) - - -context = ServerContext.from_environ() -apply_patches(context) -check_paths(context) -load_models(context) -load_params(context) -load_platforms(context) - -if not context.show_progress: - disable_progress_bar() - 
disable_progress_bars() - -app = Flask(__name__) -CORS(app, origins=context.cors_origin) - -# any is a fake device, should not be in the pool -executor = DevicePoolExecutor([p for p in available_platforms if p.device != "any"]) - -if is_debug(): - gc.set_debug(gc.DEBUG_STATS) - - -def ready_reply(ready: bool, progress: int = 0): - return jsonify( - { - "progress": progress, - "ready": ready, - } - ) - - -def error_reply(err: str): - response = make_response( - jsonify( - { - "error": err, - } - ) - ) - response.status_code = 400 - return response - - -def get_model_path(model: str): - return base_join(context.model_path, model) - - -def serve_bundle_file(filename="index.html"): - return send_from_directory(path.join("..", context.bundle_path), filename) - - -# routes - - -@app.route("/") -def index(): - return serve_bundle_file() - - -@app.route("/") -def index_path(filename): - return serve_bundle_file(filename) - - -@app.route("/api") -def introspect(): - return { - "name": "onnx-web", - "routes": [ - {"path": url_from_rule(rule), "methods": list(rule.methods).sort()} - for rule in app.url_map.iter_rules() - ], - } - - -@app.route("/api/settings/masks") -def list_mask_filters(): - return jsonify(list(mask_filters.keys())) - - -@app.route("/api/settings/models") -def list_models(): - return jsonify( - { - "correction": correction_models, - "diffusion": diffusion_models, - "inversion": inversion_models, - "upscaling": upscaling_models, - } - ) - - -@app.route("/api/settings/noises") -def list_noise_sources(): - return jsonify(list(noise_sources.keys())) - - -@app.route("/api/settings/params") -def list_params(): - return jsonify(config_params) - - -@app.route("/api/settings/platforms") -def list_platforms(): - return jsonify([p.device for p in available_platforms]) - - -@app.route("/api/settings/schedulers") -def list_schedulers(): - return jsonify(list(pipeline_schedulers.keys())) - - -@app.route("/api/img2img", methods=["POST"]) -def img2img(): - if "source" not 
in request.files: - return error_reply("source image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - strength = get_and_clamp_float( - request.args, - "strength", - get_config_value("strength"), - get_config_value("strength", "max"), - get_config_value("strength", "min"), - ) - - output = make_output_name(context, "img2img", params, size, extras=(strength,)) - job_name = output[0] - logger.info("img2img job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_img2img_pipeline, - context, - params, - output, - upscale, - source, - strength, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/txt2img", methods=["POST"]) -def txt2img(): - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "txt2img", params, size) - job_name = output[0] - logger.info("txt2img job queued for: %s", job_name) - - executor.submit( - job_name, - run_txt2img_pipeline, - context, - params, - size, - output, - upscale, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/inpaint", methods=["POST"]) -def inpaint(): - if "source" not in request.files: - return error_reply("source image is required") - - if "mask" not in request.files: - return error_reply("mask image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - mask_file = request.files.get("mask") - mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - expand = border_from_request() - upscale = upscale_from_request() - - fill_color = get_not_empty(request.args, 
"fillColor", "white") - mask_filter = get_from_map(request.args, "filter", mask_filters, "none") - noise_source = get_from_map(request.args, "noise", noise_sources, "histogram") - tile_order = get_from_list( - request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] - ) - - output = make_output_name( - context, - "inpaint", - params, - size, - extras=( - expand.left, - expand.right, - expand.top, - expand.bottom, - mask_filter.__name__, - noise_source.__name__, - fill_color, - tile_order, - ), - ) - job_name = output[0] - logger.info("inpaint job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - mask = valid_image(mask, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_inpaint_pipeline, - context, - params, - size, - output, - upscale, - source, - mask, - expand, - noise_source, - mask_filter, - fill_color, - tile_order, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) - - -@app.route("/api/upscale", methods=["POST"]) -def upscale(): - if "source" not in request.files: - return error_reply("source image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "upscale", params, size) - job_name = output[0] - logger.info("upscale job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_upscale_pipeline, - context, - params, - size, - output, - upscale, - source, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/chain", methods=["POST"]) -def chain(): - logger.debug( - "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() - ) - body = request.form.get("chain") or 
request.files.get("chain") - if body is None: - return error_reply("chain pipeline must have a body") - - data = yaml.safe_load(body) - with open("./schemas/chain.yaml", "r") as f: - schema = yaml.safe_load(f.read()) - - logger.debug("validating chain request: %s against %s", data, schema) - validate(data, schema) - - # get defaults from the regular parameters - device, params, size = pipeline_from_request() - output = make_output_name(context, "chain", params, size) - job_name = output[0] - - pipeline = ChainPipeline() - for stage_data in data.get("stages", []): - callback = chain_stages[stage_data.get("type")] - kwargs = stage_data.get("params", {}) - logger.info("request stage: %s, %s", callback.__name__, kwargs) - - stage = StageParams( - stage_data.get("name", callback.__name__), - tile_size=get_size(kwargs.get("tile_size")), - outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), - ) - - if "border" in kwargs: - border = Border.even(int(kwargs.get("border"))) - kwargs["border"] = border - - if "upscale" in kwargs: - upscale = UpscaleParams(kwargs.get("upscale")) - kwargs["upscale"] = upscale - - stage_source_name = "source:%s" % (stage.name) - stage_mask_name = "mask:%s" % (stage.name) - - if stage_source_name in request.files: - logger.debug( - "loading source image %s for pipeline stage %s", - stage_source_name, - stage.name, - ) - source_file = request.files.get(stage_source_name) - source = Image.open(BytesIO(source_file.read())).convert("RGB") - source = valid_image(source, max_dims=(size.width, size.height)) - kwargs["stage_source"] = source - - if stage_mask_name in request.files: - logger.debug( - "loading mask image %s for pipeline stage %s", - stage_mask_name, - stage.name, - ) - mask_file = request.files.get(stage_mask_name) - mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - mask = valid_image(mask, max_dims=(size.width, size.height)) - kwargs["stage_mask"] = mask - - pipeline.append((callback, stage, kwargs)) - - logger.info("running 
chain pipeline with %s stages", len(pipeline.stages)) - - # build and run chain pipeline - empty_source = Image.new("RGB", (size.width, size.height)) - executor.submit( - job_name, - pipeline, - context, - params, - empty_source, - output=output[0], - size=size, - needs_device=device, - ) - - return jsonify(json_params(output, params, size)) - - -@app.route("/api/blend", methods=["POST"]) -def blend(): - if "mask" not in request.files: - return error_reply("mask image is required") - - mask_file = request.files.get("mask") - mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") - mask = valid_image(mask) - - max_sources = 2 - sources = [] - - for i in range(max_sources): - source_file = request.files.get("source:%s" % (i)) - source = Image.open(BytesIO(source_file.read())).convert("RGBA") - source = valid_image(source, mask.size, mask.size) - sources.append(source) - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "upscale", params, size) - job_name = output[0] - logger.info("upscale job queued for: %s", job_name) - - executor.submit( - job_name, - run_blend_pipeline, - context, - params, - size, - output, - upscale, - sources, - mask, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/txt2txt", methods=["POST"]) -def txt2txt(): - device, params, size = pipeline_from_request() - - output = make_output_name(context, "upscale", params, size) - logger.info("upscale job queued for: %s", output) - - executor.submit( - output, - run_txt2txt_pipeline, - context, - params, - size, - output, - needs_device=device, - ) - - return jsonify(json_params(output, params, size)) - - -@app.route("/api/cancel", methods=["PUT"]) -def cancel(): - output_file = request.args.get("output", None) - - cancel = executor.cancel(output_file) - - return ready_reply(cancel) - - -@app.route("/api/ready") -def ready(): - output_file = 
request.args.get("output", None) - - done, progress = executor.done(output_file) - - if done is None: - output = base_join(context.output_path, output_file) - if path.exists(output): - return ready_reply(True) - - return ready_reply(done, progress=progress) - - -@app.route("/api/status") -def status(): - return jsonify(executor.status()) - - -@app.route("/output/") -def output(filename: str): - return send_from_directory( - path.join("..", context.output_path), filename, as_attachment=False - ) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py new file mode 100644 index 000000000..ed7dc8c05 --- /dev/null +++ b/api/onnx_web/server/api.py @@ -0,0 +1,477 @@ +from io import BytesIO +from logging import getLogger +from os import path + +import yaml +from flask import Flask, jsonify, make_response, request, url_for +from jsonschema import validate +from PIL import Image + +from .context import ServerContext +from .utils import wrap_route +from ..worker.pool import DevicePoolExecutor + +from .config import ( + get_available_platforms, + get_config_params, + get_config_value, + get_correction_models, + get_diffusion_models, + get_inversion_models, + get_mask_filters, + get_noise_sources, + get_upscaling_models, +) +from .params import border_from_request, pipeline_from_request, upscale_from_request + +from ..chain import ( + CHAIN_STAGES, + ChainPipeline, +) +from ..diffusion.load import get_pipeline_schedulers +from ..diffusion.run import ( + run_blend_pipeline, + run_img2img_pipeline, + run_inpaint_pipeline, + run_txt2img_pipeline, + run_upscale_pipeline, +) +from ..image import ( # mask filters; noise sources + valid_image, +) +from ..output import json_params, make_output_name +from ..params import ( + Border, + StageParams, + TileOrder, + UpscaleParams, +) +from ..transformers import run_txt2txt_pipeline +from ..utils import ( + base_join, + get_and_clamp_float, + get_and_clamp_int, + get_from_list, + get_from_map, + get_not_empty, + get_size, +) + 
+logger = getLogger(__name__) + + +def ready_reply(ready: bool, progress: int = 0): + return jsonify( + { + "progress": progress, + "ready": ready, + } + ) + + +def error_reply(err: str): + response = make_response( + jsonify( + { + "error": err, + } + ) + ) + response.status_code = 400 + return response + + +def url_from_rule(rule) -> str: + options = {} + for arg in rule.arguments: + options[arg] = ":%s" % (arg) + + return url_for(rule.endpoint, **options) + + +def introspect(context: ServerContext, app: Flask): + return { + "name": "onnx-web", + "routes": [ + {"path": url_from_rule(rule), "methods": list(rule.methods).sort()} + for rule in app.url_map.iter_rules() + ], + } + + +def list_mask_filters(context: ServerContext): + return jsonify(list(get_mask_filters().keys())) + + +def list_models(context: ServerContext): + return jsonify( + { + "correction": get_correction_models(), + "diffusion": get_diffusion_models(), + "inversion": get_inversion_models(), + "upscaling": get_upscaling_models(), + } + ) + + +def list_noise_sources(context: ServerContext): + return jsonify(list(get_noise_sources().keys())) + + +def list_params(context: ServerContext): + return jsonify(get_config_params()) + + +def list_platforms(context: ServerContext): + return jsonify([p.device for p in get_available_platforms()]) + + +def list_schedulers(context: ServerContext): + return jsonify(list(get_pipeline_schedulers().keys())) + + +def img2img(context: ServerContext, pool: DevicePoolExecutor): + if "source" not in request.files: + return error_reply("source image is required") + + source_file = request.files.get("source") + source = Image.open(BytesIO(source_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + strength = get_and_clamp_float( + request.args, + "strength", + get_config_value("strength"), + get_config_value("strength", "max"), + get_config_value("strength", "min"), + ) + + output = 
make_output_name(context, "img2img", params, size, extras=(strength,)) + job_name = output[0] + logger.info("img2img job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_img2img_pipeline, + context, + params, + output, + upscale, + source, + strength, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def txt2img(context: ServerContext, pool: DevicePoolExecutor): + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "txt2img", params, size) + job_name = output[0] + logger.info("txt2img job queued for: %s", job_name) + + pool.submit( + job_name, + run_txt2img_pipeline, + context, + params, + size, + output, + upscale, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def inpaint(context: ServerContext, pool: DevicePoolExecutor): + if "source" not in request.files: + return error_reply("source image is required") + + if "mask" not in request.files: + return error_reply("mask image is required") + + source_file = request.files.get("source") + source = Image.open(BytesIO(source_file.read())).convert("RGB") + + mask_file = request.files.get("mask") + mask = Image.open(BytesIO(mask_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + expand = border_from_request() + upscale = upscale_from_request() + + fill_color = get_not_empty(request.args, "fillColor", "white") + mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") + noise_source = get_from_map(request.args, "noise", get_noise_sources(), "histogram") + tile_order = get_from_list( + request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] + ) + + output = make_output_name( + context, + "inpaint", + params, + size, + extras=( + expand.left, + expand.right, + expand.top, + expand.bottom, 
+ mask_filter.__name__, + noise_source.__name__, + fill_color, + tile_order, + ), + ) + job_name = output[0] + logger.info("inpaint job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + mask = valid_image(mask, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_inpaint_pipeline, + context, + params, + size, + output, + upscale, + source, + mask, + expand, + noise_source, + mask_filter, + fill_color, + tile_order, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) + + +def upscale(context: ServerContext, pool: DevicePoolExecutor): + if "source" not in request.files: + return error_reply("source image is required") + + source_file = request.files.get("source") + source = Image.open(BytesIO(source_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "upscale", params, size) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_upscale_pipeline, + context, + params, + size, + output, + upscale, + source, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def chain(context: ServerContext, pool: DevicePoolExecutor): + logger.debug( + "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() + ) + body = request.form.get("chain") or request.files.get("chain") + if body is None: + return error_reply("chain pipeline must have a body") + + data = yaml.safe_load(body) + with open("./schemas/chain.yaml", "r") as f: + schema = yaml.safe_load(f.read()) + + logger.debug("validating chain request: %s against %s", data, schema) + validate(data, schema) + + # get defaults from the regular parameters + device, params, size = pipeline_from_request(context) + output = 
make_output_name(context, "chain", params, size) + job_name = output[0] + + pipeline = ChainPipeline() + for stage_data in data.get("stages", []): + callback = CHAIN_STAGES[stage_data.get("type")] + kwargs = stage_data.get("params", {}) + logger.info("request stage: %s, %s", callback.__name__, kwargs) + + stage = StageParams( + stage_data.get("name", callback.__name__), + tile_size=get_size(kwargs.get("tile_size")), + outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), + ) + + if "border" in kwargs: + border = Border.even(int(kwargs.get("border"))) + kwargs["border"] = border + + if "upscale" in kwargs: + upscale = UpscaleParams(kwargs.get("upscale")) + kwargs["upscale"] = upscale + + stage_source_name = "source:%s" % (stage.name) + stage_mask_name = "mask:%s" % (stage.name) + + if stage_source_name in request.files: + logger.debug( + "loading source image %s for pipeline stage %s", + stage_source_name, + stage.name, + ) + source_file = request.files.get(stage_source_name) + source = Image.open(BytesIO(source_file.read())).convert("RGB") + source = valid_image(source, max_dims=(size.width, size.height)) + kwargs["stage_source"] = source + + if stage_mask_name in request.files: + logger.debug( + "loading mask image %s for pipeline stage %s", + stage_mask_name, + stage.name, + ) + mask_file = request.files.get(stage_mask_name) + mask = Image.open(BytesIO(mask_file.read())).convert("RGB") + mask = valid_image(mask, max_dims=(size.width, size.height)) + kwargs["stage_mask"] = mask + + pipeline.append((callback, stage, kwargs)) + + logger.info("running chain pipeline with %s stages", len(pipeline.stages)) + + # build and run chain pipeline + empty_source = Image.new("RGB", (size.width, size.height)) + pool.submit( + job_name, + pipeline, + context, + params, + empty_source, + output=output[0], + size=size, + needs_device=device, + ) + + return jsonify(json_params(output, params, size)) + + +def blend(context: ServerContext, pool: DevicePoolExecutor): + if "mask" not 
in request.files: + return error_reply("mask image is required") + + mask_file = request.files.get("mask") + mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") + mask = valid_image(mask) + + max_sources = 2 + sources = [] + + for i in range(max_sources): + source_file = request.files.get("source:%s" % (i)) + source = Image.open(BytesIO(source_file.read())).convert("RGBA") + source = valid_image(source, mask.size, mask.size) + sources.append(source) + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "upscale", params, size) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) + + pool.submit( + job_name, + run_blend_pipeline, + context, + params, + size, + output, + upscale, + sources, + mask, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def txt2txt(context: ServerContext, pool: DevicePoolExecutor): + device, params, size = pipeline_from_request(context) + + output = make_output_name(context, "upscale", params, size) + logger.info("upscale job queued for: %s", output) + + pool.submit( + output, + run_txt2txt_pipeline, + context, + params, + size, + output, + needs_device=device, + ) + + return jsonify(json_params(output, params, size)) + + +def cancel(context: ServerContext, pool: DevicePoolExecutor): + output_file = request.args.get("output", None) + + cancel = pool.cancel(output_file) + + return ready_reply(cancel) + + +def ready(context: ServerContext, pool: DevicePoolExecutor): + output_file = request.args.get("output", None) + + done, progress = pool.done(output_file) + + if done is None: + output = base_join(context.output_path, output_file) + if path.exists(output): + return ready_reply(True) + + return ready_reply(done, progress=progress) + + +def status(context: ServerContext, pool: DevicePoolExecutor): + return jsonify(pool.status()) + + +def register_api_routes(app: Flask, context: 
ServerContext, pool: DevicePoolExecutor): + return [ + app.route("/api")(wrap_route(introspect, context, app=app)), + app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)), + app.route("/api/settings/models")(wrap_route(list_models, context)), + app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)), + app.route("/api/settings/params")(wrap_route(list_params, context)), + app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), + app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), + app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)), + app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)), + app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)), + app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)), + app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)), + app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)), + app.route("/api/blend", methods=["POST"])(wrap_route(blend, context, pool=pool)), + app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)), + app.route("/api/ready")(wrap_route(ready, context, pool=pool)), + app.route("/api/status")(wrap_route(status, context, pool=pool)), + ] diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py new file mode 100644 index 000000000..011105489 --- /dev/null +++ b/api/onnx_web/server/config.py @@ -0,0 +1,224 @@ +from functools import cmp_to_key +from glob import glob +from logging import getLogger +from os import path +from typing import Dict, List, Optional, Union + +import torch +import yaml +from onnxruntime import get_available_providers + +from .context import ServerContext +from ..image import ( # mask filters; noise sources + mask_filter_gaussian_multiply, + mask_filter_gaussian_screen, + mask_filter_none, 
+ noise_source_fill_edge, + noise_source_fill_mask, + noise_source_gaussian, + noise_source_histogram, + noise_source_normal, + noise_source_uniform, +) +from ..params import ( + DeviceParams, +) + +logger = getLogger(__name__) + +# config caching +config_params: Dict[str, Dict[str, Union[float, int, str]]] = {} + +# pipeline params +platform_providers = { + "cpu": "CPUExecutionProvider", + "cuda": "CUDAExecutionProvider", + "directml": "DmlExecutionProvider", + "rocm": "ROCMExecutionProvider", +} +noise_sources = { + "fill-edge": noise_source_fill_edge, + "fill-mask": noise_source_fill_mask, + "gaussian": noise_source_gaussian, + "histogram": noise_source_histogram, + "normal": noise_source_normal, + "uniform": noise_source_uniform, +} +mask_filters = { + "none": mask_filter_none, + "gaussian-multiply": mask_filter_gaussian_multiply, + "gaussian-screen": mask_filter_gaussian_screen, +} + + +# Available ORT providers +available_platforms: List[DeviceParams] = [] + +# loaded from model_path +correction_models: List[str] = [] +diffusion_models: List[str] = [] +inversion_models: List[str] = [] +upscaling_models: List[str] = [] + + +def get_config_params(): + return config_params + + +def get_available_platforms(): + return available_platforms + + +def get_correction_models(): + return correction_models + + +def get_diffusion_models(): + return diffusion_models + + +def get_inversion_models(): + return inversion_models + + +def get_upscaling_models(): + return upscaling_models + + +def get_mask_filters(): + return mask_filters + + +def get_noise_sources(): + return noise_sources + + +def get_config_value(key: str, subkey: str = "default", default=None): + return config_params.get(key, {}).get(subkey, default) + + +def get_model_name(model: str) -> str: + base = path.basename(model) + (file, _ext) = path.splitext(base) + return file + + +def load_models(context: ServerContext) -> None: + global correction_models + global diffusion_models + global inversion_models + 
global upscaling_models + + diffusion_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*")) + ] + diffusion_models.extend( + [ + get_model_name(f) + for f in glob(path.join(context.model_path, "stable-diffusion-*")) + ] + ) + diffusion_models = list(set(diffusion_models)) + diffusion_models.sort() + + correction_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) + ] + correction_models = list(set(correction_models)) + correction_models.sort() + + inversion_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) + ] + inversion_models = list(set(inversion_models)) + inversion_models.sort() + + upscaling_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) + ] + upscaling_models = list(set(upscaling_models)) + upscaling_models.sort() + + +def load_params(context: ServerContext) -> None: + global config_params + params_file = path.join(context.params_path, "params.json") + with open(params_file, "r") as f: + config_params = yaml.safe_load(f) + + if "platform" in config_params and context.default_platform is not None: + logger.info( + "Overriding default platform from environment: %s", + context.default_platform, + ) + config_platform = config_params.get("platform", {}) + config_platform["default"] = context.default_platform + + +def load_platforms(context: ServerContext) -> None: + global available_platforms + + providers = list(get_available_providers()) + + for potential in platform_providers: + if ( + platform_providers[potential] in providers + and potential not in context.block_platforms + ): + if potential == "cuda": + for i in range(torch.cuda.device_count()): + available_platforms.append( + DeviceParams( + potential, + platform_providers[potential], + { + "device_id": i, + }, + context.optimizations, + ) + ) + else: + available_platforms.append( + DeviceParams( + potential, + platform_providers[potential], + 
None, + context.optimizations, + ) + ) + + if context.any_platform: + # the platform should be ignored when the job is scheduled, but set to CPU just in case + available_platforms.append( + DeviceParams( + "any", + platform_providers["cpu"], + None, + context.optimizations, + ) + ) + + # make sure CPU is last on the list + def any_first_cpu_last(a: DeviceParams, b: DeviceParams): + if a.device == b.device: + return 0 + + # any should be first, if it's available + if a.device == "any": + return -1 + + # cpu should be last, if it's available + if a.device == "cpu": + return 1 + + return -1 + + available_platforms = sorted( + available_platforms, key=cmp_to_key(any_first_cpu_last) + ) + + logger.info( + "available acceleration platforms: %s", + ", ".join([str(p) for p in available_platforms]), + ) + diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py new file mode 100644 index 000000000..b70ef79a1 --- /dev/null +++ b/api/onnx_web/server/params.py @@ -0,0 +1,183 @@ +from logging import getLogger +from typing import Tuple + +import numpy as np +from flask import request + +from .context import ServerContext + +from .config import get_available_platforms, get_config_value, get_correction_models, get_upscaling_models +from .utils import get_model_path + +from ..diffusion.load import pipeline_schedulers +from ..params import ( + Border, + DeviceParams, + ImageParams, + Size, + UpscaleParams, +) +from ..utils import ( + get_and_clamp_float, + get_and_clamp_int, + get_from_list, + get_not_empty, +) + +logger = getLogger(__name__) + + +def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]: + user = request.remote_addr + + # platform stuff + device = None + device_name = request.args.get("platform") + + if device_name is not None and device_name != "any": + for platform in get_available_platforms(): + if platform.device == device_name: + device = platform + + # pipeline stuff + lpw = get_not_empty(request.args, 
"lpw", "false") == "true" + model = get_not_empty(request.args, "model", get_config_value("model")) + model_path = get_model_path(context, model) + scheduler = get_from_list( + request.args, "scheduler", pipeline_schedulers.keys() + ) + + if scheduler is None: + scheduler = get_config_value("scheduler") + + inversion = request.args.get("inversion", None) + inversion_path = None + if inversion is not None and inversion.strip() != "": + inversion_path = get_model_path(context, inversion) + + # image params + prompt = get_not_empty(request.args, "prompt", get_config_value("prompt")) + negative_prompt = request.args.get("negativePrompt", None) + + if negative_prompt is not None and negative_prompt.strip() == "": + negative_prompt = None + + batch = get_and_clamp_int( + request.args, + "batch", + get_config_value("batch"), + get_config_value("batch", "max"), + get_config_value("batch", "min"), + ) + cfg = get_and_clamp_float( + request.args, + "cfg", + get_config_value("cfg"), + get_config_value("cfg", "max"), + get_config_value("cfg", "min"), + ) + eta = get_and_clamp_float( + request.args, + "eta", + get_config_value("eta"), + get_config_value("eta", "max"), + get_config_value("eta", "min"), + ) + steps = get_and_clamp_int( + request.args, + "steps", + get_config_value("steps"), + get_config_value("steps", "max"), + get_config_value("steps", "min"), + ) + height = get_and_clamp_int( + request.args, + "height", + get_config_value("height"), + get_config_value("height", "max"), + get_config_value("height", "min"), + ) + width = get_and_clamp_int( + request.args, + "width", + get_config_value("width"), + get_config_value("width", "max"), + get_config_value("width", "min"), + ) + + seed = int(request.args.get("seed", -1)) + if seed == -1: + # this one can safely use np.random because it produces a single value + seed = np.random.randint(np.iinfo(np.int32).max) + + logger.info( + "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", + user, + steps, + 
scheduler, + model_path, + device or "any device", + width, + height, + cfg, + seed, + prompt, + ) + + params = ImageParams( + model_path, + scheduler, + prompt, + cfg, + steps, + seed, + eta=eta, + lpw=lpw, + negative_prompt=negative_prompt, + batch=batch, + inversion=inversion_path, + ) + size = Size(width, height) + return (device, params, size) + + +def border_from_request() -> Border: + left = get_and_clamp_int( + request.args, "left", 0, get_config_value("width", "max"), 0 + ) + right = get_and_clamp_int( + request.args, "right", 0, get_config_value("width", "max"), 0 + ) + top = get_and_clamp_int( + request.args, "top", 0, get_config_value("height", "max"), 0 + ) + bottom = get_and_clamp_int( + request.args, "bottom", 0, get_config_value("height", "max"), 0 + ) + + return Border(left, right, top, bottom) + + +def upscale_from_request() -> UpscaleParams: + denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0) + scale = get_and_clamp_int(request.args, "scale", 1, 4, 1) + outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1) + upscaling = get_from_list(request.args, "upscaling", get_upscaling_models()) + correction = get_from_list(request.args, "correction", get_correction_models()) + faces = get_not_empty(request.args, "faces", "false") == "true" + face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) + face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) + upscale_order = request.args.get("upscaleOrder", "correction-first") + + return UpscaleParams( + upscaling, + correction_model=correction, + denoise=denoise, + faces=faces, + face_outscale=face_outscale, + face_strength=face_strength, + format="onnx", + outscale=outscale, + scale=scale, + upscale_order=upscale_order, + ) diff --git a/api/onnx_web/server/static.py b/api/onnx_web/server/static.py new file mode 100644 index 000000000..296c8deb6 --- /dev/null +++ b/api/onnx_web/server/static.py @@ -0,0 +1,34 @@ +from os import path + +from 
flask import Flask, send_from_directory + +from .utils import wrap_route +from .context import ServerContext +from ..worker.pool import DevicePoolExecutor + + +def serve_bundle_file(context: ServerContext, filename="index.html"): + return send_from_directory(path.join("..", context.bundle_path), filename) + + +# non-API routes +def index(context: ServerContext): + return serve_bundle_file(context) + + +def index_path(context: ServerContext, filename: str): + return serve_bundle_file(context, filename) + + +def output(context: ServerContext, filename: str): + return send_from_directory( + path.join("..", context.output_path), filename, as_attachment=False + ) + + +def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): + return [ + app.route("/")(wrap_route(index, context)), + app.route("/")(wrap_route(index_path, context)), + app.route("/output/")(wrap_route(output, context)), + ] diff --git a/api/onnx_web/server/utils.py b/api/onnx_web/server/utils.py new file mode 100644 index 000000000..8dd359a1b --- /dev/null +++ b/api/onnx_web/server/utils.py @@ -0,0 +1,32 @@ +from os import makedirs, path +from typing import Callable, Dict, List, Tuple +from functools import partial, update_wrapper + +from flask import Flask + +from onnx_web.utils import base_join +from onnx_web.worker.pool import DevicePoolExecutor + +from .context import ServerContext + + +def check_paths(context: ServerContext) -> None: + if not path.exists(context.model_path): + raise RuntimeError("model path must exist") + + if not path.exists(context.output_path): + makedirs(context.output_path) + + +def get_model_path(context: ServerContext, model: str): + return base_join(context.model_path, model) + + +def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]): + pass + + +def wrap_route(func, *args, **kwargs): + partial_func = partial(func, *args, **kwargs) + update_wrapper(partial_func, func) + return 
partial_func From 6998e8735ce059208ed04a2320867a7089edbc95 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 10:47:31 -0600 Subject: [PATCH 05/40] rejoin worker pool --- api/onnx_web/main.py | 10 ++++++---- api/onnx_web/worker/pool.py | 6 ++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 6381a2416..bb19466c4 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -4,6 +4,7 @@ from flask import Flask from flask_cors import CORS from huggingface_hub.utils.tqdm import disable_progress_bars +from torch.multiprocessing import set_start_method from .server.api import register_api_routes from .server.static import register_static_routes @@ -18,6 +19,8 @@ def main(): + set_start_method("spawn", force=True) + context = ServerContext.from_environ() apply_patches(context) check_paths(context) @@ -42,12 +45,11 @@ def main(): register_static_routes(app, context, pool) register_api_routes(app, context, pool) - return app #, context, pool + return app, pool if __name__ == "__main__": - # app, context, pool = main() - app = main() + app, pool = main() app.run("0.0.0.0", 5000, debug=is_debug()) - # pool.join() + pool.join() diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 7d83427c8..b7eaf3a39 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -93,6 +93,12 @@ def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: return lowest_devices[0] + def join(self): + for device, worker in self.workers.items(): + if worker.is_alive(): + logger.info("stopping worker for device %s", device) + worker.join(5) + def prune(self): finished_count = len(self.finished) if finished_count > self.finished_limit: From d765a6f01ba9fd343005a4be3185793f6091b674 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 11:16:33 -0600 Subject: [PATCH 06/40] make logger start up well --- api/onnx_web/worker/pool.py | 15 ++++++------ 
api/onnx_web/worker/worker.py | 46 +++++++++++++++-------------------- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index b7eaf3a39..fa2e5f57f 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -31,13 +31,14 @@ def __init__( self.progress = {} self.workers = {} - log_queue = Queue() - logger_context = WorkerContext("logger", None, None, log_queue, None) - logger.debug("starting log worker") - self.logger = Process(target=logger_init, args=(self.lock, logger_context)) + self.log_queue = Queue() + self.logger = Process(target=logger_init, args=(self.lock, self.log_queue)) self.logger.start() + logger.debug("testing log worker") + self.log_queue.put("testing") + # create a pending queue and progress value for each device for device in devices: name = device.device @@ -52,9 +53,6 @@ def __init__( self.workers[name] = Process(target=worker_init, args=(self.lock, context)) self.workers[name].start() - logger.debug("testing log worker") - log_queue.put("testing") - def cancel(self, key: str) -> bool: """ Cancel a job. 
If the job has not been started, this will cancel @@ -99,6 +97,9 @@ def join(self): logger.info("stopping worker for device %s", device) worker.join(5) + if self.logger.is_alive(): + self.logger.join(5) + def prune(self): finished_count = len(self.finished) if finished_count > self.finished_limit: diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index cf47f85e6..e5d463067 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -1,40 +1,34 @@ from logging import getLogger -from torch.multiprocessing import Lock -from time import sleep +from onnxruntime import get_available_providers +from torch.multiprocessing import Lock, Queue from traceback import print_exception from .context import WorkerContext logger = getLogger(__name__) -def logger_init(lock: Lock, context: WorkerContext): - logger.info("checking in from logger") +def logger_init(lock: Lock, logs: Queue): + with lock: + logger.info("checking in from logger, %s", lock) - with open("worker.log", "w") as f: - while True: - if context.pending.empty(): - logger.info("no logs, sleeping") - sleep(5) - else: - job = context.pending.get() - logger.info("got log: %s", job) - f.write(str(job) + "\n\n") + while True: + job = logs.get() + with open("worker.log", "w") as f: + logger.info("got log: %s", job) + f.write(str(job) + "\n\n") def worker_init(lock: Lock, context: WorkerContext): - logger.info("checking in from worker") + with lock: + logger.info("checking in from worker, %s, %s", lock, get_available_providers()) while True: - if context.pending.empty(): - logger.info("no jobs, sleeping") - sleep(5) - else: - job = context.pending.get() - logger.info("got job: %s", job) - try: - fn, args, kwargs = job - fn(context, *args, **kwargs) - logger.info("finished job") - except Exception as e: - print_exception(type(e), e, e.__traceback__) + job = context.pending.get() + logger.info("got job: %s", job) + try: + fn, args, kwargs = job + fn(context, *args, **kwargs) + 
logger.info("finished job") + except Exception as e: + print_exception(type(e), e, e.__traceback__) From e1d0ad54b7781665d7c90206903266dc3e205e04 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 12:24:51 -0600 Subject: [PATCH 07/40] lock per worker, torch before ORT --- api/onnx_web/params.py | 1 + api/onnx_web/worker/pool.py | 17 +++++++++++------ api/onnx_web/worker/worker.py | 5 +++-- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 07cf082a1..9bd5e819e 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -2,6 +2,7 @@ from logging import getLogger from typing import Any, Dict, List, Literal, Optional, Tuple, Union +import torch from onnxruntime import GraphOptimizationLevel, SessionOptions logger = getLogger(__name__) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index fa2e5f57f..8f6a6d132 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -26,7 +26,8 @@ def __init__( self.devices = devices self.finished = [] self.finished_limit = finished_limit - self.lock = Lock() + self.context = {} + self.locks = {} self.pending = {} self.progress = {} self.workers = {} @@ -42,15 +43,18 @@ def __init__( # create a pending queue and progress value for each device for device in devices: name = device.device - cancel = Value("B", False, lock=self.lock) - progress = Value("I", 0, lock=self.lock) + lock = Lock() + self.locks[name] = lock + cancel = Value("B", False, lock=lock) + progress = Value("I", 0, lock=lock) + self.progress[name] = progress pending = Queue() - context = WorkerContext(name, cancel, device, pending, progress) self.pending[name] = pending - self.progress[name] = pending + context = WorkerContext(name, cancel, device, pending, progress) + self.context[name] = context logger.debug("starting worker for device %s", device) - self.workers[name] = Process(target=worker_init, args=(self.lock, context)) + 
self.workers[name] = Process(target=worker_init, args=(lock, context)) self.workers[name].start() def cancel(self, key: str) -> bool: @@ -135,6 +139,7 @@ def status(self) -> List[Tuple[str, int, bool, int]]: ( device.device, self.pending[device.device].qsize(), + self.progress[device.device].value, self.workers[device.device].is_alive(), ) for device in self.devices diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index e5d463067..f5d3689c4 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -1,7 +1,8 @@ from logging import getLogger +import torch # has to come before ORT from onnxruntime import get_available_providers from torch.multiprocessing import Lock, Queue -from traceback import print_exception +from traceback import format_exception from .context import WorkerContext @@ -30,5 +31,5 @@ def worker_init(lock: Lock, context: WorkerContext): fn(context, *args, **kwargs) logger.info("finished job") except Exception as e: - print_exception(type(e), e, e.__traceback__) + logger.error(format_exception(type(e), e, e.__traceback__)) From f115326da78113d3ec38d4aa72c127ba67bc804d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 12:32:48 -0600 Subject: [PATCH 08/40] apply patches within workers --- api/onnx_web/main.py | 2 +- api/onnx_web/worker/pool.py | 13 ++++++++++--- api/onnx_web/worker/worker.py | 6 +++++- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index bb19466c4..86a5896d2 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -39,7 +39,7 @@ def main(): CORS(app, origins=context.cors_origin) # any is a fake device, should not be in the pool - pool = DevicePoolExecutor([p for p in get_available_platforms() if p.device != "any"]) + pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"]) # register routes register_static_routes(app, context, pool) diff --git 
a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 8f6a6d132..7dad4bee6 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -5,6 +5,7 @@ from typing import Callable, Dict, List, Optional, Tuple from ..params import DeviceParams +from ..server import ServerContext from .context import WorkerContext from .worker import logger_init, worker_init @@ -20,9 +21,11 @@ class DevicePoolExecutor: def __init__( self, + server: ServerContext, devices: List[DeviceParams], finished_limit: int = 10, ): + self.server = server self.devices = devices self.finished = [] self.finished_limit = finished_limit @@ -32,9 +35,12 @@ def __init__( self.progress = {} self.workers = {} + # TODO: make this a method logger.debug("starting log worker") self.log_queue = Queue() - self.logger = Process(target=logger_init, args=(self.lock, self.log_queue)) + log_lock = Lock() + self.locks["logger"] = log_lock + self.logger = Process(target=logger_init, args=(log_lock, self.log_queue)) self.logger.start() logger.debug("testing log worker") @@ -43,10 +49,11 @@ def __init__( # create a pending queue and progress value for each device for device in devices: name = device.device + # TODO: make this a method lock = Lock() self.locks[name] = lock cancel = Value("B", False, lock=lock) - progress = Value("I", 0, lock=lock) + progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why? 
self.progress[name] = progress pending = Queue() self.pending[name] = pending @@ -54,7 +61,7 @@ def __init__( self.context[name] = context logger.debug("starting worker for device %s", device) - self.workers[name] = Process(target=worker_init, args=(lock, context)) + self.workers[name] = Process(target=worker_init, args=(lock, context, server)) self.workers[name].start() def cancel(self, key: str) -> bool: diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index f5d3689c4..3497da708 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -5,9 +5,11 @@ from traceback import format_exception from .context import WorkerContext +from ..server import ServerContext, apply_patches logger = getLogger(__name__) + def logger_init(lock: Lock, logs: Queue): with lock: logger.info("checking in from logger, %s", lock) @@ -19,10 +21,12 @@ def logger_init(lock: Lock, logs: Queue): f.write(str(job) + "\n\n") -def worker_init(lock: Lock, context: WorkerContext): +def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): with lock: logger.info("checking in from worker, %s, %s", lock, get_available_providers()) + apply_patches(server) + while True: job = context.pending.get() logger.info("got job: %s", job) From e0737e9e08bdf8b66514fdf8f4611ef1c2283915 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 12:51:11 -0600 Subject: [PATCH 09/40] update progress and finished flag from worker --- api/onnx_web/worker/context.py | 4 ++++ api/onnx_web/worker/pool.py | 28 +++++++++++++++++++--------- api/onnx_web/worker/worker.py | 14 +++++++++++++- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 8dfb77155..59f55fddc 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -22,12 +22,16 @@ def __init__( device: DeviceParams, pending: "Queue[Any]", progress: "Value[int]", + logs: "Queue[str]", + finished: 
"Value[bool]", ): self.key = key self.cancel = cancel self.device = device self.pending = pending self.progress = progress + self.logs = logs + self.finished = finished def is_cancelled(self) -> bool: return self.cancel.value diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 7dad4bee6..0721896e1 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -14,10 +14,11 @@ class DevicePoolExecutor: devices: List[DeviceParams] = None - finished: List[Tuple[str, int]] = None + finished: Dict[str, "Value[bool]"] = None pending: Dict[str, "Queue[WorkerContext]"] = None - progress: Dict[str, Value] = None + progress: Dict[str, "Value[int]"] = None workers: Dict[str, Process] = None + jobs: Dict[str, str] = None def __init__( self, @@ -27,13 +28,14 @@ def __init__( ): self.server = server self.devices = devices - self.finished = [] + self.finished = {} self.finished_limit = finished_limit self.context = {} self.locks = {} self.pending = {} self.progress = {} self.workers = {} + self.jobs = {} # Dict[Output, Device] # TODO: make this a method logger.debug("starting log worker") @@ -53,11 +55,13 @@ def __init__( lock = Lock() self.locks[name] = lock cancel = Value("B", False, lock=lock) + finished = Value("B", False) + self.finished[name] = finished progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why? 
self.progress[name] = progress pending = Queue() self.pending[name] = pending - context = WorkerContext(name, cancel, device, pending, progress) + context = WorkerContext(name, cancel, device, pending, progress, self.log_queue, finished) self.context[name] = context logger.debug("starting worker for device %s", device) @@ -73,12 +77,16 @@ def cancel(self, key: str) -> bool: raise NotImplementedError() def done(self, key: str) -> Tuple[Optional[bool], int]: - for k, progress in self.finished: - if key == k: - return (True, progress) + if not key in self.jobs: + logger.warn("checking status for unknown key: %s", key) + return (None, 0) + + device = self.jobs[key] + finished = self.finished[device] + progress = self.progress[device] + + return (finished.value, progress.value) - logger.warn("checking status for unknown key: %s", key) - return (None, 0) def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible @@ -140,6 +148,8 @@ def submit( queue = self.pending[device.device] queue.put((fn, args, kwargs)) + self.jobs[key] = device.device + def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 3497da708..4edfb4bb2 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -32,8 +32,20 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): logger.info("got job: %s", job) try: fn, args, kwargs = job + name = args[3][0] + logger.info("starting job: %s", name) + with context.finished.get_lock(): + context.finished.value = False + + with context.progress.get_lock(): + context.progress.value = 0 + fn(context, *args, **kwargs) - logger.info("finished job") + logger.info("finished job: %s", name) + + with context.finished.get_lock(): + context.finished.value = True + except Exception as e: logger.error(format_exception(type(e), e, e.__traceback__)) From 
6502e1e3c8365a25f2712c9e6c3e16ba59ccf52d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 12:58:38 -0600 Subject: [PATCH 10/40] recycle worker pool after 10 jobs --- api/onnx_web/worker/pool.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 0721896e1..cd78051dc 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -36,6 +36,7 @@ def __init__( self.progress = {} self.workers = {} self.jobs = {} # Dict[Output, Device] + self.job_count = 0 # TODO: make this a method logger.debug("starting log worker") @@ -129,6 +130,27 @@ def prune(self): ) self.finished[:] = self.finished[-self.finished_limit:] + def recycle(self): + for name, proc in self.workers.items(): + if proc.is_alive(): + logger.debug("shutting down worker for device %s", name) + proc.join(5) + else: + logger.warning("worker for device %s has died", name) + + self.workers[name] = None + + logger.info("starting new workers") + + for name in self.workers.keys(): + context = self.context[name] + lock = self.locks[name] + + logger.debug("starting worker for device %s", name) + self.workers[name] = Process(target=worker_init, args=(lock, context, self.server)) + self.workers[name].start() + + def submit( self, key: str, @@ -138,6 +160,11 @@ def submit( needs_device: Optional[DeviceParams] = None, **kwargs, ) -> None: + self.job_count += 1 + if self.job_count > 10: + self.recycle() + self.job_count = 0 + self.prune() device_idx = self.get_next_device(needs_device=needs_device) logger.info( From b880b7a121e9adbe88d30a457c0bfbd54a45d6d9 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 13:09:24 -0600 Subject: [PATCH 11/40] set process titles, terminate workers --- api/onnx_web/worker/pool.py | 3 +++ api/onnx_web/worker/worker.py | 4 ++++ api/requirements.txt | 1 + api/scripts/test-memory-leak.sh | 2 +- 4 files changed, 9 insertions(+), 1 deletion(-) diff --git 
a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index cd78051dc..34e5b70c1 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -135,10 +135,12 @@ def recycle(self): if proc.is_alive(): logger.debug("shutting down worker for device %s", name) proc.join(5) + proc.terminate() else: logger.warning("worker for device %s has died", name) self.workers[name] = None + del proc logger.info("starting new workers") @@ -161,6 +163,7 @@ def submit( **kwargs, ) -> None: self.job_count += 1 + logger.debug("pool job count: %s", self.job_count) if self.job_count > 10: self.recycle() self.job_count = 0 diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 4edfb4bb2..07c6bb025 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -3,6 +3,7 @@ from onnxruntime import get_available_providers from torch.multiprocessing import Lock, Queue from traceback import format_exception +from setproctitle import setproctitle from .context import WorkerContext from ..server import ServerContext, apply_patches @@ -14,6 +15,8 @@ def logger_init(lock: Lock, logs: Queue): with lock: logger.info("checking in from logger, %s", lock) + setproctitle("onnx-web logger") + while True: job = logs.get() with open("worker.log", "w") as f: @@ -26,6 +29,7 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): logger.info("checking in from worker, %s, %s", lock, get_available_providers()) apply_patches(server) + setproctitle("onnx-web worker: %s", context.device.device) while True: job = context.pending.get() diff --git a/api/requirements.txt b/api/requirements.txt index e08b7c312..2ac35af81 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -23,3 +23,4 @@ flask flask-cors jsonschema pyyaml +setproctitle \ No newline at end of file diff --git a/api/scripts/test-memory-leak.sh b/api/scripts/test-memory-leak.sh index 5e432e433..d7545e606 100755 --- a/api/scripts/test-memory-leak.sh +++ 
b/api/scripts/test-memory-leak.sh @@ -14,5 +14,5 @@ do --insecure || break; ((test_images++)); echo "waiting after $test_images"; - sleep 10; + sleep 30; done From 584dddb5d69c5a7091617b5e3278665070a386ab Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 14:15:30 -0600 Subject: [PATCH 12/40] lint all the new stuff --- api/onnx_web/__init__.py | 2 +- api/onnx_web/chain/__init__.py | 2 +- api/onnx_web/chain/base.py | 2 +- api/onnx_web/chain/blend_img2img.py | 2 +- api/onnx_web/chain/blend_inpaint.py | 2 +- api/onnx_web/chain/blend_mask.py | 2 +- api/onnx_web/chain/correct_codeformer.py | 2 +- api/onnx_web/chain/correct_gfpgan.py | 3 +- api/onnx_web/chain/persist_disk.py | 3 +- api/onnx_web/chain/persist_s3.py | 3 +- api/onnx_web/chain/reduce_crop.py | 3 +- api/onnx_web/chain/reduce_thumbnail.py | 3 +- api/onnx_web/chain/source_noise.py | 3 +- api/onnx_web/chain/source_txt2img.py | 3 +- api/onnx_web/chain/upscale_outpaint.py | 2 +- api/onnx_web/chain/upscale_resrgan.py | 2 +- .../chain/upscale_stable_diffusion.py | 2 +- api/onnx_web/diffusion/load.py | 6 +- api/onnx_web/diffusion/run.py | 2 +- api/onnx_web/main.py | 64 +++++------ api/onnx_web/onnx/onnx_net.py | 2 +- api/onnx_web/onnx/torch_before_ort.py | 5 + api/onnx_web/params.py | 3 +- api/onnx_web/server/api.py | 102 +++++++++--------- api/onnx_web/server/config.py | 11 +- api/onnx_web/server/params.py | 34 +++--- api/onnx_web/server/static.py | 18 ++-- api/onnx_web/server/utils.py | 9 +- api/onnx_web/upscale.py | 2 +- api/onnx_web/worker/__init__.py | 2 +- api/onnx_web/worker/context.py | 4 +- api/onnx_web/worker/logging.py | 2 +- api/onnx_web/worker/pool.py | 37 ++++--- api/onnx_web/worker/worker.py | 11 +- 34 files changed, 182 insertions(+), 173 deletions(-) create mode 100644 api/onnx_web/onnx/torch_before_ort.py diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index b019bb3db..5294121c7 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -57,4 +57,4 @@ ) 
from .worker import ( DevicePoolExecutor, -) \ No newline at end of file +) diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index 44fdd6c1e..a983c8498 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -29,4 +29,4 @@ "upscale-outpaint": upscale_outpaint, "upscale-resrgan": upscale_resrgan, "upscale-stable-diffusion": upscale_stable_diffusion, -} \ No newline at end of file +} diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index dc8073223..2ddc59aea 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -7,9 +7,9 @@ from ..output import save_image from ..params import ImageParams, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext from .utils import process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 675311036..f7c516058 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -7,8 +7,8 @@ from ..diffusion.load import load_pipeline from ..params import ImageParams, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 51b6d983b..7d864b5e8 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -10,9 +10,9 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext 
from .utils import process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index bfe11aabe..f7b68e6f3 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -7,9 +7,9 @@ from onnx_web.output import save_image from ..params import ImageParams, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 01d61db7c..c3eaec65c 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -3,8 +3,8 @@ from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams -from ..worker import WorkerContext from ..server import ServerContext +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index afcae86b3..2cff2e181 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -5,10 +5,9 @@ from PIL import Image from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..server import ServerContext from ..utils import run_gc from ..worker import WorkerContext -from ..server import ServerContext - logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 58020b57a..eac0f36cb 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -4,9 +4,8 @@ from ..output import save_image from ..params import ImageParams, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_s3.py 
b/api/onnx_web/chain/persist_s3.py index 926f1598b..3e01b9cec 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -5,9 +5,8 @@ from PIL import Image from ..params import ImageParams, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 4cd715b16..cce82f0ca 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -3,9 +3,8 @@ from PIL import Image from ..params import ImageParams, Size, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 4950a9731..6df2ed6e1 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -3,9 +3,8 @@ from PIL import Image from ..params import ImageParams, Size, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 9ab302b1e..0092292c7 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -4,9 +4,8 @@ from PIL import Image from ..params import ImageParams, Size, StageParams -from ..worker import WorkerContext from ..server import ServerContext - +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index b933ecc97..1dec32439 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -7,9 +7,8 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline from ..params import ImageParams, Size, 
StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext - +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 23393491e..695652050 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -10,9 +10,9 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext from .utils import process_tile_grid, process_tile_order logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index ccbb3644f..055319d0d 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -6,9 +6,9 @@ from ..onnx import OnnxNet from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..worker import WorkerContext from ..server import ServerContext from ..utils import run_gc +from ..worker import WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 00c1b9d4a..0accc8548 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -10,9 +10,9 @@ OnnxStableDiffusionUpscalePipeline, ) from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..worker import WorkerContext, ProgressCallback from ..server import ServerContext from ..utils import run_gc +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/diffusion/load.py 
b/api/onnx_web/diffusion/load.py index 1fff3f1d1..e3b545105 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -4,11 +4,9 @@ import numpy as np from diffusers import ( - DiffusionPipeline, - OnnxRuntimeModel, - StableDiffusionPipeline, DDIMScheduler, DDPMScheduler, + DiffusionPipeline, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, @@ -19,7 +17,9 @@ KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LMSDiscreteScheduler, + OnnxRuntimeModel, PNDMScheduler, + StableDiffusionPipeline, ) try: diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 0e44b6cc4..2e92294cd 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -12,10 +12,10 @@ from ..chain import upscale_outpaint from ..output import save_image, save_params from ..params import Border, ImageParams, Size, StageParams, UpscaleParams -from ..worker import WorkerContext from ..server import ServerContext from ..upscale import run_upscale_correction from ..utils import run_gc +from ..worker import WorkerContext from .load import get_latents_from_seed, load_pipeline logger = getLogger(__name__) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 86a5896d2..3f7c9a844 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -7,49 +7,53 @@ from torch.multiprocessing import set_start_method from .server.api import register_api_routes -from .server.static import register_static_routes -from .server.config import get_available_platforms, load_models, load_params, load_platforms -from .server.utils import check_paths +from .server.config import ( + get_available_platforms, + load_models, + load_params, + load_platforms, +) from .server.context import ServerContext from .server.hacks import apply_patches -from .utils import ( - is_debug, -) +from .server.static import register_static_routes +from .server.utils import check_paths +from .utils import is_debug from .worker 
import DevicePoolExecutor def main(): - set_start_method("spawn", force=True) + set_start_method("spawn", force=True) - context = ServerContext.from_environ() - apply_patches(context) - check_paths(context) - load_models(context) - load_params(context) - load_platforms(context) + context = ServerContext.from_environ() + apply_patches(context) + check_paths(context) + load_models(context) + load_params(context) + load_platforms(context) - if is_debug(): - gc.set_debug(gc.DEBUG_STATS) + if is_debug(): + gc.set_debug(gc.DEBUG_STATS) - if not context.show_progress: - disable_progress_bar() - disable_progress_bars() + if not context.show_progress: + disable_progress_bar() + disable_progress_bars() - app = Flask(__name__) - CORS(app, origins=context.cors_origin) + app = Flask(__name__) + CORS(app, origins=context.cors_origin) - # any is a fake device, should not be in the pool - pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"]) + # any is a fake device, should not be in the pool + pool = DevicePoolExecutor( + context, [p for p in get_available_platforms() if p.device != "any"] + ) - # register routes - register_static_routes(app, context, pool) - register_api_routes(app, context, pool) + # register routes + register_static_routes(app, context, pool) + register_api_routes(app, context, pool) - return app, pool + return app, pool if __name__ == "__main__": - app, pool = main() - app.run("0.0.0.0", 5000, debug=is_debug()) - pool.join() - + app, pool = main() + app.run("0.0.0.0", 5000, debug=is_debug()) + pool.join() diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 97d5c8b06..42f00a4d4 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -3,9 +3,9 @@ import numpy as np import torch -from onnxruntime import InferenceSession, SessionOptions from ..server import ServerContext +from .torch_before_ort import InferenceSession, SessionOptions class OnnxTensor: diff --git 
a/api/onnx_web/onnx/torch_before_ort.py b/api/onnx_web/onnx/torch_before_ort.py new file mode 100644 index 000000000..506c14783 --- /dev/null +++ b/api/onnx_web/onnx/torch_before_ort.py @@ -0,0 +1,5 @@ +# this file exists to make sure torch is always imported before onnxruntime +# to work around https://github.com/microsoft/onnxruntime/issues/11092 + +import torch # NOQA +from onnxruntime import * # NOQA diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 9bd5e819e..f92328c67 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -2,8 +2,7 @@ from logging import getLogger from typing import Any, Dict, List, Literal, Optional, Tuple, Union -import torch -from onnxruntime import GraphOptimizationLevel, SessionOptions +from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions logger = getLogger(__name__) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index ed7dc8c05..7ff611927 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -7,27 +7,7 @@ from jsonschema import validate from PIL import Image -from .context import ServerContext -from .utils import wrap_route -from ..worker.pool import DevicePoolExecutor - -from .config import ( - get_available_platforms, - get_config_params, - get_config_value, - get_correction_models, - get_diffusion_models, - get_inversion_models, - get_mask_filters, - get_noise_sources, - get_upscaling_models, -) -from .params import border_from_request, pipeline_from_request, upscale_from_request - -from ..chain import ( - CHAIN_STAGES, - ChainPipeline, -) +from ..chain import CHAIN_STAGES, ChainPipeline from ..diffusion.load import get_pipeline_schedulers from ..diffusion.run import ( run_blend_pipeline, @@ -36,16 +16,9 @@ run_txt2img_pipeline, run_upscale_pipeline, ) -from ..image import ( # mask filters; noise sources - valid_image, -) +from ..image import valid_image # mask filters; noise sources from ..output import json_params, make_output_name -from 
..params import ( - Border, - StageParams, - TileOrder, - UpscaleParams, -) +from ..params import Border, StageParams, TileOrder, UpscaleParams from ..transformers import run_txt2txt_pipeline from ..utils import ( base_join, @@ -56,6 +29,21 @@ get_not_empty, get_size, ) +from ..worker.pool import DevicePoolExecutor +from .config import ( + get_available_platforms, + get_config_params, + get_config_value, + get_correction_models, + get_diffusion_models, + get_inversion_models, + get_mask_filters, + get_noise_sources, + get_upscaling_models, +) +from .context import ServerContext +from .params import border_from_request, pipeline_from_request, upscale_from_request +from .utils import wrap_route logger = getLogger(__name__) @@ -456,22 +444,38 @@ def status(context: ServerContext, pool: DevicePoolExecutor): def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): - return [ - app.route("/api")(wrap_route(introspect, context, app=app)), - app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)), - app.route("/api/settings/models")(wrap_route(list_models, context)), - app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)), - app.route("/api/settings/params")(wrap_route(list_params, context)), - app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), - app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), - app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)), - app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)), - app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)), - app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)), - app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)), - app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)), - app.route("/api/blend", 
methods=["POST"])(wrap_route(blend, context, pool=pool)), - app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)), - app.route("/api/ready")(wrap_route(ready, context, pool=pool)), - app.route("/api/status")(wrap_route(status, context, pool=pool)), - ] + return [ + app.route("/api")(wrap_route(introspect, context, app=app)), + app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)), + app.route("/api/settings/models")(wrap_route(list_models, context)), + app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)), + app.route("/api/settings/params")(wrap_route(list_params, context)), + app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), + app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), + app.route("/api/img2img", methods=["POST"])( + wrap_route(img2img, context, pool=pool) + ), + app.route("/api/txt2img", methods=["POST"])( + wrap_route(txt2img, context, pool=pool) + ), + app.route("/api/txt2txt", methods=["POST"])( + wrap_route(txt2txt, context, pool=pool) + ), + app.route("/api/inpaint", methods=["POST"])( + wrap_route(inpaint, context, pool=pool) + ), + app.route("/api/upscale", methods=["POST"])( + wrap_route(upscale, context, pool=pool) + ), + app.route("/api/chain", methods=["POST"])( + wrap_route(chain, context, pool=pool) + ), + app.route("/api/blend", methods=["POST"])( + wrap_route(blend, context, pool=pool) + ), + app.route("/api/cancel", methods=["PUT"])( + wrap_route(cancel, context, pool=pool) + ), + app.route("/api/ready")(wrap_route(ready, context, pool=pool)), + app.route("/api/status")(wrap_route(status, context, pool=pool)), + ] diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py index 011105489..71b24709d 100644 --- a/api/onnx_web/server/config.py +++ b/api/onnx_web/server/config.py @@ -2,13 +2,11 @@ from glob import glob from logging import getLogger from os import path -from typing import Dict, List, Optional, 
Union +from typing import Dict, List, Union import torch import yaml -from onnxruntime import get_available_providers -from .context import ServerContext from ..image import ( # mask filters; noise sources mask_filter_gaussian_multiply, mask_filter_gaussian_screen, @@ -20,9 +18,9 @@ noise_source_normal, noise_source_uniform, ) -from ..params import ( - DeviceParams, -) +from ..onnx.torch_before_ort import get_available_providers +from ..params import DeviceParams +from .context import ServerContext logger = getLogger(__name__) @@ -221,4 +219,3 @@ def any_first_cpu_last(a: DeviceParams, b: DeviceParams): "available acceleration platforms: %s", ", ".join([str(p) for p in available_platforms]), ) - diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index b70ef79a1..c16d19efa 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -4,30 +4,24 @@ import numpy as np from flask import request -from .context import ServerContext - -from .config import get_available_platforms, get_config_value, get_correction_models, get_upscaling_models -from .utils import get_model_path - from ..diffusion.load import pipeline_schedulers -from ..params import ( - Border, - DeviceParams, - ImageParams, - Size, - UpscaleParams, -) -from ..utils import ( - get_and_clamp_float, - get_and_clamp_int, - get_from_list, - get_not_empty, +from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams +from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty +from .config import ( + get_available_platforms, + get_config_value, + get_correction_models, + get_upscaling_models, ) +from .context import ServerContext +from .utils import get_model_path logger = getLogger(__name__) -def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]: +def pipeline_from_request( + context: ServerContext, +) -> Tuple[DeviceParams, ImageParams, Size]: user = request.remote_addr # platform 
stuff @@ -43,9 +37,7 @@ def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImagePa lpw = get_not_empty(request.args, "lpw", "false") == "true" model = get_not_empty(request.args, "model", get_config_value("model")) model_path = get_model_path(context, model) - scheduler = get_from_list( - request.args, "scheduler", pipeline_schedulers.keys() - ) + scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys()) if scheduler is None: scheduler = get_config_value("scheduler") diff --git a/api/onnx_web/server/static.py b/api/onnx_web/server/static.py index 296c8deb6..9a67bf0ed 100644 --- a/api/onnx_web/server/static.py +++ b/api/onnx_web/server/static.py @@ -2,9 +2,9 @@ from flask import Flask, send_from_directory -from .utils import wrap_route -from .context import ServerContext from ..worker.pool import DevicePoolExecutor +from .context import ServerContext +from .utils import wrap_route def serve_bundle_file(context: ServerContext, filename="index.html"): @@ -26,9 +26,11 @@ def output(context: ServerContext, filename: str): ) -def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): - return [ - app.route("/")(wrap_route(index, context)), - app.route("/")(wrap_route(index_path, context)), - app.route("/output/")(wrap_route(output, context)), - ] +def register_static_routes( + app: Flask, context: ServerContext, pool: DevicePoolExecutor +): + return [ + app.route("/")(wrap_route(index, context)), + app.route("/")(wrap_route(index_path, context)), + app.route("/output/")(wrap_route(output, context)), + ] diff --git a/api/onnx_web/server/utils.py b/api/onnx_web/server/utils.py index 8dd359a1b..582b1c6a5 100644 --- a/api/onnx_web/server/utils.py +++ b/api/onnx_web/server/utils.py @@ -1,6 +1,6 @@ +from functools import partial, update_wrapper from os import makedirs, path from typing import Callable, Dict, List, Tuple -from functools import partial, update_wrapper from flask import Flask @@ -22,7 
+22,12 @@ def get_model_path(context: ServerContext, model: str): return base_join(context.model_path, model) -def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]): +def register_routes( + app: Flask, + context: ServerContext, + pool: DevicePoolExecutor, + routes: List[Tuple[str, Dict, Callable]], +): pass diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 8636f8c1d..098ce817c 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -11,7 +11,7 @@ ) from .params import ImageParams, SizeChart, StageParams, UpscaleParams from .server import ServerContext -from .worker import WorkerContext, ProgressCallback +from .worker import ProgressCallback, WorkerContext logger = getLogger(__name__) diff --git a/api/onnx_web/worker/__init__.py b/api/onnx_web/worker/__init__.py index 0ca5eefc7..c1f2d7949 100644 --- a/api/onnx_web/worker/__init__.py +++ b/api/onnx_web/worker/__init__.py @@ -1,2 +1,2 @@ from .context import WorkerContext, ProgressCallback -from .pool import DevicePoolExecutor \ No newline at end of file +from .pool import DevicePoolExecutor diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 59f55fddc..ae083509d 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -1,7 +1,8 @@ from logging import getLogger -from torch.multiprocessing import Queue, Value from typing import Any, Callable, Tuple +from torch.multiprocessing import Queue, Value + from ..params import DeviceParams logger = getLogger(__name__) @@ -9,6 +10,7 @@ ProgressCallback = Callable[[int, int, Any], None] + class WorkerContext: cancel: "Value[bool]" = None key: str = None diff --git a/api/onnx_web/worker/logging.py b/api/onnx_web/worker/logging.py index 39808a640..ab90a266f 100644 --- a/api/onnx_web/worker/logging.py +++ b/api/onnx_web/worker/logging.py @@ -1 +1 @@ -# TODO: queue-based logger \ No newline at end of file +# TODO: queue-based 
logger diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 34e5b70c1..810bb91a2 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -1,9 +1,10 @@ from collections import Counter from logging import getLogger from multiprocessing import Queue -from torch.multiprocessing import Lock, Process, Value from typing import Callable, Dict, List, Optional, Tuple +from torch.multiprocessing import Lock, Process, Value + from ..params import DeviceParams from ..server import ServerContext from .context import WorkerContext @@ -35,7 +36,7 @@ def __init__( self.pending = {} self.progress = {} self.workers = {} - self.jobs = {} # Dict[Output, Device] + self.jobs = {} # Dict[Output, Device] self.job_count = 0 # TODO: make this a method @@ -58,15 +59,21 @@ def __init__( cancel = Value("B", False, lock=lock) finished = Value("B", False) self.finished[name] = finished - progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why? + progress = Value( + "I", 0 + ) # , lock=lock) # needs its own lock for some reason. TODO: why? 
self.progress[name] = progress pending = Queue() self.pending[name] = pending - context = WorkerContext(name, cancel, device, pending, progress, self.log_queue, finished) + context = WorkerContext( + name, cancel, device, pending, progress, self.log_queue, finished + ) self.context[name] = context logger.debug("starting worker for device %s", device) - self.workers[name] = Process(target=worker_init, args=(lock, context, server)) + self.workers[name] = Process( + target=worker_init, args=(lock, context, server) + ) self.workers[name].start() def cancel(self, key: str) -> bool: @@ -78,7 +85,7 @@ def cancel(self, key: str) -> bool: raise NotImplementedError() def done(self, key: str) -> Tuple[Optional[bool], int]: - if not key in self.jobs: + if key not in self.jobs: logger.warn("checking status for unknown key: %s", key) return (None, 0) @@ -88,7 +95,6 @@ def done(self, key: str) -> Tuple[Optional[bool], int]: return (finished.value, progress.value) - def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible if needs_device is not None: @@ -96,9 +102,7 @@ def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: if self.devices[i].device == needs_device.device: return i - pending = [ - self.pending[d.device].qsize() for d in self.devices - ] + pending = [self.pending[d.device].qsize() for d in self.devices] jobs = Counter(range(len(self.devices))) jobs.update(pending) @@ -128,7 +132,7 @@ def prune(self): finished_count - self.finished_limit, finished_count, ) - self.finished[:] = self.finished[-self.finished_limit:] + self.finished[:] = self.finished[-self.finished_limit :] def recycle(self): for name, proc in self.workers.items(): @@ -149,10 +153,11 @@ def recycle(self): lock = self.locks[name] logger.debug("starting worker for device %s", name) - self.workers[name] = Process(target=worker_init, args=(lock, context, self.server)) + self.workers[name] = Process( + target=worker_init, 
args=(lock, context, self.server) + ) self.workers[name].start() - def submit( self, key: str, @@ -171,7 +176,10 @@ def submit( self.prune() device_idx = self.get_next_device(needs_device=needs_device) logger.info( - "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx] + "assigning job %s to device %s: %s", + key, + device_idx, + self.devices[device_idx], ) device = self.devices[device_idx] @@ -180,7 +188,6 @@ def submit( self.jobs[key] = device.device - def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ ( diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 07c6bb025..dbd86896c 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -1,12 +1,12 @@ from logging import getLogger -import torch # has to come before ORT -from onnxruntime import get_available_providers -from torch.multiprocessing import Lock, Queue from traceback import format_exception + from setproctitle import setproctitle +from torch.multiprocessing import Lock, Queue -from .context import WorkerContext +from ..onnx.torch_before_ort import get_available_providers from ..server import ServerContext, apply_patches +from .context import WorkerContext logger = getLogger(__name__) @@ -29,7 +29,7 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): logger.info("checking in from worker, %s, %s", lock, get_available_providers()) apply_patches(server) - setproctitle("onnx-web worker: %s", context.device.device) + setproctitle("onnx-web worker: %s" % (context.device.device)) while True: job = context.pending.get() @@ -52,4 +52,3 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): except Exception as e: logger.error(format_exception(type(e), e, e.__traceback__)) - From d1961afdbc9dcf89c2c68c1967c223be8d07d1e6 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 14:36:32 -0600 Subject: [PATCH 13/40] re-implement cancellation --- 
api/onnx_web/worker/pool.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 810bb91a2..a58fd3c08 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -82,11 +82,22 @@ def cancel(self, key: str) -> bool: the future and never execute it. If the job has been started, it should be cancelled on the next progress callback. """ - raise NotImplementedError() + if key not in self.jobs: + logger.warn("attempting to cancel unknown job: %s", key) + return False + + device = self.jobs[key] + cancel = self.context[device].cancel + logger.info("cancelling job %s on device %s", key, device) + + if cancel.get_lock(): + cancel.value = True + + return True def done(self, key: str) -> Tuple[Optional[bool], int]: if key not in self.jobs: - logger.warn("checking status for unknown key: %s", key) + logger.warn("checking status for unknown job: %s", key) return (None, 0) device = self.jobs[key] From 85118d17c656450126ec09f0c13310cc644db491 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 15:06:40 -0600 Subject: [PATCH 14/40] clear worker flags between jobs, attempt to record finished jobs again --- api/onnx_web/worker/context.py | 19 +++- api/onnx_web/worker/pool.py | 162 +++++++++++++++------------------ api/onnx_web/worker/worker.py | 38 ++++---- 3 files changed, 106 insertions(+), 113 deletions(-) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index ae083509d..056a2517f 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -20,12 +20,12 @@ class WorkerContext: def __init__( self, key: str, - cancel: "Value[bool]", device: DeviceParams, - pending: "Queue[Any]", - progress: "Value[int]", - logs: "Queue[str]", - finished: "Value[bool]", + cancel: "Value[bool]" = None, + finished: "Value[bool]" = None, + progress: "Value[int]" = None, + logs: "Queue[str]" = None, + pending: "Queue[Any]" = None, 
): self.key = key self.cancel = cancel @@ -62,6 +62,15 @@ def set_cancel(self, cancel: bool = True) -> None: with self.cancel.get_lock(): self.cancel.value = cancel + def set_finished(self, finished: bool = True) -> None: + with self.finished.get_lock(): + self.finished.value = finished + def set_progress(self, progress: int) -> None: with self.progress.get_lock(): self.progress.value = progress + + def clear_flags(self) -> None: + self.set_cancel(False) + self.set_finished(False) + self.set_progress(0) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index a58fd3c08..fc58ba57c 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -3,7 +3,7 @@ from multiprocessing import Queue from typing import Callable, Dict, List, Optional, Tuple -from torch.multiprocessing import Lock, Process, Value +from torch.multiprocessing import Process, Value from ..params import DeviceParams from ..server import ServerContext @@ -14,67 +14,67 @@ class DevicePoolExecutor: + context: Dict[str, WorkerContext] = None devices: List[DeviceParams] = None - finished: Dict[str, "Value[bool]"] = None pending: Dict[str, "Queue[WorkerContext]"] = None - progress: Dict[str, "Value[int]"] = None workers: Dict[str, Process] = None - jobs: Dict[str, str] = None + active_job: Dict[str, str] = None + finished: List[Tuple[str, int, bool]] = None def __init__( self, server: ServerContext, devices: List[DeviceParams], - finished_limit: int = 10, + max_jobs_per_worker: int = 10, + join_timeout: float = 5.0, ): self.server = server self.devices = devices - self.finished = {} - self.finished_limit = finished_limit + self.max_jobs_per_worker = max_jobs_per_worker + self.join_timeout = join_timeout + self.context = {} - self.locks = {} self.pending = {} - self.progress = {} self.workers = {} - self.jobs = {} # Dict[Output, Device] - self.job_count = 0 + self.active_job = {} + self.finished_jobs = 0 # TODO: turn this into a Dict per-worker - # TODO: make this a method 
- logger.debug("starting log worker") - self.log_queue = Queue() - log_lock = Lock() - self.locks["logger"] = log_lock - self.logger = Process(target=logger_init, args=(log_lock, self.log_queue)) - self.logger.start() + self.create_logger_worker() + for device in devices: + self.create_device_worker(device) logger.debug("testing log worker") self.log_queue.put("testing") - # create a pending queue and progress value for each device - for device in devices: - name = device.device - # TODO: make this a method - lock = Lock() - self.locks[name] = lock - cancel = Value("B", False, lock=lock) - finished = Value("B", False) - self.finished[name] = finished - progress = Value( - "I", 0 - ) # , lock=lock) # needs its own lock for some reason. TODO: why? - self.progress[name] = progress - pending = Queue() - self.pending[name] = pending - context = WorkerContext( - name, cancel, device, pending, progress, self.log_queue, finished - ) - self.context[name] = context + def create_logger_worker(self) -> None: + self.log_queue = Queue() + self.logger = Process(target=logger_init, args=(self.log_queue)) - logger.debug("starting worker for device %s", device) - self.workers[name] = Process( - target=worker_init, args=(lock, context, server) - ) - self.workers[name].start() + logger.debug("starting log worker") + self.logger.start() + + def create_device_worker(self, device: DeviceParams) -> None: + name = device.device + pending = Queue() + self.pending[name] = pending + context = WorkerContext( + name, + device, + cancel=Value("B", False), + finished=Value("B", False), + progress=Value("I", 0), + pending=pending, + logs=self.log_queue, + ) + self.context[name] = context + self.workers[name] = Process(target=worker_init, args=(context, self.server)) + + logger.debug("starting worker for device %s", device) + self.workers[name].start() + + def create_prune_worker(self) -> None: + # TODO: create a background thread to prune completed jobs + pass def cancel(self, key: str) -> bool: 
""" @@ -82,29 +82,34 @@ def cancel(self, key: str) -> bool: the future and never execute it. If the job has been started, it should be cancelled on the next progress callback. """ - if key not in self.jobs: + if key not in self.active_job: logger.warn("attempting to cancel unknown job: %s", key) return False - device = self.jobs[key] - cancel = self.context[device].cancel + device = self.active_job[key] + context = self.context[device] logger.info("cancelling job %s on device %s", key, device) - if cancel.get_lock(): - cancel.value = True + if context.cancel.get_lock(): + context.cancel.value = True + # self.finished.append((key, context.progress.value, context.cancel.value)) maybe? return True def done(self, key: str) -> Tuple[Optional[bool], int]: - if key not in self.jobs: + if key not in self.active_job: logger.warn("checking status for unknown job: %s", key) return (None, 0) - device = self.jobs[key] - finished = self.finished[device] - progress = self.progress[device] + # TODO: prune here, maybe? 
- return (finished.value, progress.value) + device = self.active_job[key] + context = self.context[device] + + if context.finished.value is True: + self.finished.append((key, context.progress.value, context.cancel.value)) + + return (context.finished.value, context.progress.value) def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible @@ -113,9 +118,8 @@ def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: if self.devices[i].device == needs_device.device: return i - pending = [self.pending[d.device].qsize() for d in self.devices] jobs = Counter(range(len(self.devices))) - jobs.update(pending) + jobs.update([self.pending[d.device].qsize() for d in self.devices]) queued = jobs.most_common() logger.debug("jobs queued by device: %s", queued) @@ -130,26 +134,16 @@ def join(self): for device, worker in self.workers.items(): if worker.is_alive(): logger.info("stopping worker for device %s", device) - worker.join(5) + worker.join(self.join_timeout) if self.logger.is_alive(): - self.logger.join(5) - - def prune(self): - finished_count = len(self.finished) - if finished_count > self.finished_limit: - logger.debug( - "pruning %s of %s finished jobs", - finished_count - self.finished_limit, - finished_count, - ) - self.finished[:] = self.finished[-self.finished_limit :] + self.logger.join(self.join_timeout) def recycle(self): for name, proc in self.workers.items(): if proc.is_alive(): logger.debug("shutting down worker for device %s", name) - proc.join(5) + proc.join(self.join_timeout) proc.terminate() else: logger.warning("worker for device %s has died", name) @@ -159,15 +153,8 @@ def recycle(self): logger.info("starting new workers") - for name in self.workers.keys(): - context = self.context[name] - lock = self.locks[name] - - logger.debug("starting worker for device %s", name) - self.workers[name] = Process( - target=worker_init, args=(lock, context, self.server) - ) - 
self.workers[name].start() + for device in self.devices: + self.create_device_worker(device) def submit( self, @@ -178,13 +165,12 @@ def submit( needs_device: Optional[DeviceParams] = None, **kwargs, ) -> None: - self.job_count += 1 - logger.debug("pool job count: %s", self.job_count) - if self.job_count > 10: + self.finished_jobs += 1 + logger.debug("pool job count: %s", self.finished_jobs) + if self.finished_jobs > self.max_jobs_per_worker: self.recycle() - self.job_count = 0 + self.finished_jobs = 0 - self.prune() device_idx = self.get_next_device(needs_device=needs_device) logger.info( "assigning job %s to device %s: %s", @@ -197,17 +183,19 @@ def submit( queue = self.pending[device.device] queue.put((fn, args, kwargs)) - self.jobs[key] = device.device + self.active_job[key] = device.device def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ ( - device.device, - self.pending[device.device].qsize(), - self.progress[device.device].value, - self.workers[device.device].is_alive(), + name, + self.workers[name].is_alive(), + context.pending.qsize(), + context.cancel.value, + context.finished.value, + context.progress.value, ) - for device in self.devices + for name, context in self.context.items() ] pending.extend(self.finished) return pending diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index dbd86896c..efb598f6f 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -2,7 +2,7 @@ from traceback import format_exception from setproctitle import setproctitle -from torch.multiprocessing import Lock, Queue +from torch.multiprocessing import Queue from ..onnx.torch_before_ort import get_available_providers from ..server import ServerContext, apply_patches @@ -11,12 +11,11 @@ logger = getLogger(__name__) -def logger_init(lock: Lock, logs: Queue): - with lock: - logger.info("checking in from logger, %s", lock) - +def logger_init(logs: Queue): setproctitle("onnx-web logger") + logger.info("checking in 
from logger, %s") + while True: job = logs.get() with open("worker.log", "w") as f: @@ -24,31 +23,28 @@ def logger_init(lock: Lock, logs: Queue): f.write(str(job) + "\n\n") -def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): - with lock: - logger.info("checking in from worker, %s, %s", lock, get_available_providers()) - +def worker_init(context: WorkerContext, server: ServerContext): apply_patches(server) setproctitle("onnx-web worker: %s" % (context.device.device)) + logger.info("checking in from worker, %s, %s", get_available_providers()) + while True: job = context.pending.get() logger.info("got job: %s", job) try: fn, args, kwargs = job name = args[3][0] - logger.info("starting job: %s", name) - with context.finished.get_lock(): - context.finished.value = False - - with context.progress.get_lock(): - context.progress.value = 0 + logger.info("starting job: %s", name) + context.clear_flags() fn(context, *args, **kwargs) - logger.info("finished job: %s", name) - - with context.finished.get_lock(): - context.finished.value = True - + logger.info("job succeeded: %s", name) except Exception as e: - logger.error(format_exception(type(e), e, e.__traceback__)) + logger.error( + "error while running job: %s", + format_exception(type(e), e, e.__traceback__), + ) + finally: + context.set_finished() + logger.info("finished job: %s", name) From b931da1d2c11cde4c0369a8482478c042cf5aadd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 15:21:58 -0600 Subject: [PATCH 15/40] fix imports, lint --- api/launch-extras.bat | 2 +- api/launch-extras.sh | 2 +- api/launch.bat | 2 +- api/launch.sh | 2 +- api/onnx_web/main.py | 14 ++++++++++++-- api/onnx_web/onnx/onnx_net.py | 2 +- api/onnx_web/params.py | 2 +- api/onnx_web/server/config.py | 2 +- api/onnx_web/{onnx => }/torch_before_ort.py | 0 api/onnx_web/worker/pool.py | 2 +- api/onnx_web/worker/worker.py | 6 +++--- 11 files changed, 23 insertions(+), 13 deletions(-) rename api/onnx_web/{onnx => 
}/torch_before_ort.py (100%) diff --git a/api/launch-extras.bat b/api/launch-extras.bat index fa3c89083..2f0b95c0e 100644 --- a/api/launch-extras.bat +++ b/api/launch-extras.bat @@ -3,4 +3,4 @@ IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json) python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN% echo "Launching API server..." -flask --app=onnx_web.serve run --host=0.0.0.0 +flask --app="onnx_web.serve:run" run --host=0.0.0.0 diff --git a/api/launch-extras.sh b/api/launch-extras.sh index f18e14c01..50572aa43 100755 --- a/api/launch-extras.sh +++ b/api/launch-extras.sh @@ -25,4 +25,4 @@ python3 -m onnx_web.convert \ --token=${HF_TOKEN:-} echo "Launching API server..." -flask --app='onnx_web.main:main()' run --host=0.0.0.0 +flask --app='onnx_web.main:run' run --host=0.0.0.0 diff --git a/api/launch.bat b/api/launch.bat index f589bd119..4ee27fd63 100644 --- a/api/launch.bat +++ b/api/launch.bat @@ -2,4 +2,4 @@ echo "Downloading and converting models to ONNX format..." python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN% echo "Launching API server..." -flask --app=onnx_web.serve run --host=0.0.0.0 +flask --app="onnx_web.serve:run" run --host=0.0.0.0 diff --git a/api/launch.sh b/api/launch.sh index 983e09308..55b6ff729 100755 --- a/api/launch.sh +++ b/api/launch.sh @@ -24,4 +24,4 @@ python3 -m onnx_web.convert \ --token=${HF_TOKEN:-} echo "Launching API server..." 
-flask --app='onnx_web.main:main()' run --host=0.0.0.0 +flask --app='onnx_web.main:run' run --host=0.0.0.0 diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 3f7c9a844..bbfa10398 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -1,4 +1,6 @@ +import atexit import gc +from logging import getLogger from diffusers.utils.logging import disable_progress_bar from flask import Flask @@ -20,6 +22,8 @@ from .utils import is_debug from .worker import DevicePoolExecutor +logger = getLogger(__name__) + def main(): set_start_method("spawn", force=True) @@ -53,7 +57,13 @@ def main(): return app, pool -if __name__ == "__main__": +def run(): app, pool = main() + atexit.register(lambda: pool.join()) + return app + + +if __name__ == "__main__": + app = run() app.run("0.0.0.0", 5000, debug=is_debug()) - pool.join() + logger.info("shutting down app") diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 42f00a4d4..a974aff48 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -5,7 +5,7 @@ import torch from ..server import ServerContext -from .torch_before_ort import InferenceSession, SessionOptions +from ..torch_before_ort import InferenceSession, SessionOptions class OnnxTensor: diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index f92328c67..32440c072 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -2,7 +2,7 @@ from logging import getLogger from typing import Any, Dict, List, Literal, Optional, Tuple, Union -from .onnx.torch_before_ort import GraphOptimizationLevel, SessionOptions +from .torch_before_ort import GraphOptimizationLevel, SessionOptions logger = getLogger(__name__) diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py index 71b24709d..a5dc1a31e 100644 --- a/api/onnx_web/server/config.py +++ b/api/onnx_web/server/config.py @@ -18,8 +18,8 @@ noise_source_normal, noise_source_uniform, ) -from ..onnx.torch_before_ort import 
get_available_providers from ..params import DeviceParams +from ..torch_before_ort import get_available_providers from .context import ServerContext logger = getLogger(__name__) diff --git a/api/onnx_web/onnx/torch_before_ort.py b/api/onnx_web/torch_before_ort.py similarity index 100% rename from api/onnx_web/onnx/torch_before_ort.py rename to api/onnx_web/torch_before_ort.py diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index fc58ba57c..4f2b66ff8 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -48,7 +48,7 @@ def __init__( def create_logger_worker(self) -> None: self.log_queue = Queue() - self.logger = Process(target=logger_init, args=(self.log_queue)) + self.logger = Process(target=logger_init, args=(self.log_queue,)) logger.debug("starting log worker") self.logger.start() diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index efb598f6f..24a1c4f23 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -4,8 +4,8 @@ from setproctitle import setproctitle from torch.multiprocessing import Queue -from ..onnx.torch_before_ort import get_available_providers from ..server import ServerContext, apply_patches +from ..torch_before_ort import get_available_providers from .context import WorkerContext logger = getLogger(__name__) @@ -14,7 +14,7 @@ def logger_init(logs: Queue): setproctitle("onnx-web logger") - logger.info("checking in from logger, %s") + logger.info("checking in from logger") while True: job = logs.get() @@ -27,7 +27,7 @@ def worker_init(context: WorkerContext, server: ServerContext): apply_patches(server) setproctitle("onnx-web worker: %s" % (context.device.device)) - logger.info("checking in from worker, %s, %s", get_available_providers()) + logger.info("checking in from worker, %s", get_available_providers()) while True: job = context.pending.get() From eb82e73e599d06e00f67c76c875e3e0d31f1fac3 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 
2023 15:26:54 -0600 Subject: [PATCH 16/40] initialize list of finished jobs --- api/onnx_web/worker/pool.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 4f2b66ff8..ec5ac480b 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -37,6 +37,7 @@ def __init__( self.pending = {} self.workers = {} self.active_job = {} + self.finished = [] self.finished_jobs = 0 # TODO: turn this into a Dict per-worker self.create_logger_worker() From 525ee24e916e404449929ee72752568ac66a3487 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 20:09:42 -0600 Subject: [PATCH 17/40] track started and finished jobs --- api/onnx_web/chain/blend_mask.py | 5 +- api/onnx_web/diffusion/run.py | 6 +- api/onnx_web/server/config.py | 10 ++- api/onnx_web/server/utils.py | 8 +-- api/onnx_web/worker/context.py | 21 +++--- api/onnx_web/worker/pool.py | 113 ++++++++++++++++++++++--------- api/onnx_web/worker/worker.py | 12 ++-- 7 files changed, 116 insertions(+), 59 deletions(-) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index f7b68e6f3..5c53bd12b 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -3,9 +3,8 @@ from PIL import Image -from onnx_web.image import valid_image -from onnx_web.output import save_image - +from ..image import valid_image +from ..output import save_image from ..params import ImageParams, StageParams from ..server import ServerContext from ..utils import is_debug diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 2e92294cd..765d7de8d 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -6,10 +6,8 @@ from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from PIL import Image -from onnx_web.chain import blend_mask -from onnx_web.chain.base import ChainProgress - -from ..chain import upscale_outpaint +from ..chain import 
blend_mask, upscale_outpaint +from ..chain.base import ChainProgress from ..output import save_image, save_params from ..params import Border, ImageParams, Size, StageParams, UpscaleParams from ..server import ServerContext diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py index a5dc1a31e..c03921046 100644 --- a/api/onnx_web/server/config.py +++ b/api/onnx_web/server/config.py @@ -118,35 +118,42 @@ def load_models(context: ServerContext) -> None: ) diffusion_models = list(set(diffusion_models)) diffusion_models.sort() + logger.debug("loaded diffusion models from disk: %s", diffusion_models) correction_models = [ get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) ] correction_models = list(set(correction_models)) correction_models.sort() + logger.debug("loaded correction models from disk: %s", correction_models) inversion_models = [ get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) ] inversion_models = list(set(inversion_models)) inversion_models.sort() + logger.debug("loaded inversion models from disk: %s", inversion_models) upscaling_models = [ get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) ] upscaling_models = list(set(upscaling_models)) upscaling_models.sort() + logger.debug("loaded upscaling models from disk: %s", upscaling_models) def load_params(context: ServerContext) -> None: global config_params + params_file = path.join(context.params_path, "params.json") + logger.debug("loading server parameters from file: %s", params_file) + with open(params_file, "r") as f: config_params = yaml.safe_load(f) if "platform" in config_params and context.default_platform is not None: logger.info( - "Overriding default platform from environment: %s", + "overriding default platform from environment: %s", context.default_platform, ) config_platform = config_params.get("platform", {}) @@ -157,6 +164,7 @@ def load_platforms(context: ServerContext) -> None: global 
available_platforms providers = list(get_available_providers()) + logger.debug("loading available platforms from providers: %s", providers) for potential in platform_providers: if ( diff --git a/api/onnx_web/server/utils.py b/api/onnx_web/server/utils.py index 582b1c6a5..56cc33e0a 100644 --- a/api/onnx_web/server/utils.py +++ b/api/onnx_web/server/utils.py @@ -4,9 +4,8 @@ from flask import Flask -from onnx_web.utils import base_join -from onnx_web.worker.pool import DevicePoolExecutor - +from ..utils import base_join +from ..worker.pool import DevicePoolExecutor from .context import ServerContext @@ -28,7 +27,8 @@ def register_routes( pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]], ): - pass + for route, kwargs, method in routes: + app.route(route, **kwargs)(wrap_route(method, context, pool=pool)) def wrap_route(func, *args, **kwargs): diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 056a2517f..67bc4eb65 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -22,18 +22,20 @@ def __init__( key: str, device: DeviceParams, cancel: "Value[bool]" = None, - finished: "Value[bool]" = None, progress: "Value[int]" = None, + finished: "Queue[str]" = None, logs: "Queue[str]" = None, pending: "Queue[Any]" = None, + started: "Queue[Tuple[str, str]]" = None, ): self.key = key - self.cancel = cancel self.device = device - self.pending = pending + self.cancel = cancel self.progress = progress - self.logs = logs self.finished = finished + self.logs = logs + self.pending = pending + self.started = started def is_cancelled(self) -> bool: return self.cancel.value @@ -62,15 +64,16 @@ def set_cancel(self, cancel: bool = True) -> None: with self.cancel.get_lock(): self.cancel.value = cancel - def set_finished(self, finished: bool = True) -> None: - with self.finished.get_lock(): - self.finished.value = finished - def set_progress(self, progress: int) -> None: with self.progress.get_lock(): 
self.progress.value = progress + def put_finished(self, job: str) -> None: + self.finished.put((job, self.device.device)) + + def put_started(self, job: str) -> None: + self.started.put((job, self.device.device)) + def clear_flags(self) -> None: self.set_cancel(False) - self.set_finished(False) self.set_progress(0) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index ec5ac480b..6b4bfd96d 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -1,9 +1,9 @@ from collections import Counter from logging import getLogger -from multiprocessing import Queue +from threading import Thread from typing import Callable, Dict, List, Optional, Tuple -from torch.multiprocessing import Process, Value +from torch.multiprocessing import Process, Queue, Value from ..params import DeviceParams from ..server import ServerContext @@ -14,12 +14,12 @@ class DevicePoolExecutor: - context: Dict[str, WorkerContext] = None + context: Dict[str, WorkerContext] = None # Device -> Context devices: List[DeviceParams] = None pending: Dict[str, "Queue[WorkerContext]"] = None workers: Dict[str, Process] = None - active_job: Dict[str, str] = None - finished: List[Tuple[str, int, bool]] = None + active_jobs: Dict[str, str] = None + finished_jobs: List[Tuple[str, int, bool]] = None def __init__( self, @@ -36,11 +36,15 @@ def __init__( self.context = {} self.pending = {} self.workers = {} - self.active_job = {} - self.finished = [] - self.finished_jobs = 0 # TODO: turn this into a Dict per-worker + self.active_jobs = {} + self.finished_jobs = [] + self.total_jobs = 0 # TODO: turn this into a Dict per-worker + + self.started = Queue() + self.finished = Queue() self.create_logger_worker() + self.create_queue_workers() for device in devices: self.create_device_worker(device) @@ -56,16 +60,23 @@ def create_logger_worker(self) -> None: def create_device_worker(self, device: DeviceParams) -> None: name = device.device - pending = Queue() - self.pending[name] = 
pending + + # reuse the queue if possible, to keep queued jobs + if name in self.pending: + pending = self.pending[name] + else: + pending = Queue() + self.pending[name] = pending + context = WorkerContext( name, device, cancel=Value("B", False), - finished=Value("B", False), progress=Value("I", 0), - pending=pending, + finished=self.finished, logs=self.log_queue, + pending=pending, + started=self.started, ) self.context[name] = context self.workers[name] = Process(target=worker_init, args=(context, self.server)) @@ -73,9 +84,32 @@ def create_device_worker(self, device: DeviceParams) -> None: logger.debug("starting worker for device %s", device) self.workers[name].start() - def create_prune_worker(self) -> None: - # TODO: create a background thread to prune completed jobs - pass + def create_queue_workers(self) -> None: + def started_worker(pending: Queue): + logger.info("checking in from started thread") + while True: + job, device = pending.get() + logger.info("job has been started: %s", job) + self.active_jobs[device] = job + + def finished_worker(finished: Queue): + logger.info("checking in from finished thread") + while True: + job, device = finished.get() + logger.info("job has been finished: %s", job) + context = self.get_job_context(job) + self.finished_jobs.append( + (job, context.progress.value, context.cancel.value) + ) + + self.started_thread = Thread(target=started_worker, args=(self.started,)) + self.started_thread.start() + self.finished_thread = Thread(target=finished_worker, args=(self.finished,)) + self.finished_thread.start() + + def get_job_context(self, key: str) -> WorkerContext: + device = self.active_jobs[key] + return self.context[device] def cancel(self, key: str) -> bool: """ @@ -83,11 +117,11 @@ def cancel(self, key: str) -> bool: the future and never execute it. If the job has been started, it should be cancelled on the next progress callback. 
""" - if key not in self.active_job: + if key not in self.active_jobs: logger.warn("attempting to cancel unknown job: %s", key) return False - device = self.active_job[key] + device = self.active_jobs[key] context = self.context[device] logger.info("cancelling job %s on device %s", key, device) @@ -98,19 +132,17 @@ def cancel(self, key: str) -> bool: return True def done(self, key: str) -> Tuple[Optional[bool], int]: - if key not in self.active_job: + for k, p, c in self.finished_jobs: + if k == key: + return (c, p) + + if key not in self.active_jobs: logger.warn("checking status for unknown job: %s", key) return (None, 0) # TODO: prune here, maybe? - - device = self.active_job[key] - context = self.context[device] - - if context.finished.value is True: - self.finished.append((key, context.progress.value, context.cancel.value)) - - return (context.finished.value, context.progress.value) + context = self.get_job_context(key) + return (False, context.progress.value) def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible @@ -132,6 +164,9 @@ def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: return lowest_devices[0] def join(self): + self.started_thread.join(self.join_timeout) + self.finished_thread.join(self.join_timeout) + for device, worker in self.workers.items(): if worker.is_alive(): logger.info("stopping worker for device %s", device) @@ -166,11 +201,11 @@ def submit( needs_device: Optional[DeviceParams] = None, **kwargs, ) -> None: - self.finished_jobs += 1 - logger.debug("pool job count: %s", self.finished_jobs) - if self.finished_jobs > self.max_jobs_per_worker: + self.total_jobs += 1 + logger.debug("pool job count: %s", self.total_jobs) + if self.total_jobs > self.max_jobs_per_worker: self.recycle() - self.finished_jobs = 0 + self.total_jobs = 0 device_idx = self.get_next_device(needs_device=needs_device) logger.info( @@ -184,7 +219,7 @@ def submit( queue = 
self.pending[device.device] queue.put((fn, args, kwargs)) - self.active_job[key] = device.device + self.active_jobs[key] = device.device def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ @@ -193,10 +228,22 @@ def status(self) -> List[Tuple[str, int, bool, int]]: self.workers[name].is_alive(), context.pending.qsize(), context.cancel.value, - context.finished.value, + False, context.progress.value, ) for name, context in self.context.items() ] - pending.extend(self.finished) + pending.extend( + [ + ( + name, + False, + 0, + cancel, + True, + progress, + ) + for name, progress, cancel in self.finished_jobs + ] + ) return pending diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 24a1c4f23..db23540f2 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -32,12 +32,14 @@ def worker_init(context: WorkerContext, server: ServerContext): while True: job = context.pending.get() logger.info("got job: %s", job) - try: - fn, args, kwargs = job - name = args[3][0] - logger.info("starting job: %s", name) + fn, args, kwargs = job + name = args[3][0] + + try: context.clear_flags() + logger.info("starting job: %s", name) + context.put_started(name) fn(context, *args, **kwargs) logger.info("job succeeded: %s", name) except Exception as e: @@ -46,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext): format_exception(type(e), e, e.__traceback__), ) finally: - context.set_finished() + context.put_finished(name) logger.info("finished job: %s", name) From 401ee20526207557fc69a3e2bfa48a83bf8b7c15 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 20:13:16 -0600 Subject: [PATCH 18/40] fix finished flag --- api/onnx_web/worker/pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 6b4bfd96d..97cc683a6 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -134,7 +134,7 @@ def 
cancel(self, key: str) -> bool: def done(self, key: str) -> Tuple[Optional[bool], int]: for k, p, c in self.finished_jobs: if k == key: - return (c, p) + return (True, p) if key not in self.active_jobs: logger.warn("checking status for unknown job: %s", key) From a37d1a455015e6328d4718fa38082aa337cf34f7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 20:37:22 -0600 Subject: [PATCH 19/40] use progress queue --- api/onnx_web/worker/context.py | 16 ++++-------- api/onnx_web/worker/pool.py | 45 +++++++++++++++++----------------- api/onnx_web/worker/worker.py | 4 +-- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 67bc4eb65..2daf23d01 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -22,11 +22,10 @@ def __init__( key: str, device: DeviceParams, cancel: "Value[bool]" = None, - progress: "Value[int]" = None, - finished: "Queue[str]" = None, logs: "Queue[str]" = None, pending: "Queue[Any]" = None, - started: "Queue[Tuple[str, str]]" = None, + progress: "Queue[Tuple[str, int]]" = None, + finished: "Queue[str]" = None, ): self.key = key self.device = device @@ -35,7 +34,6 @@ def __init__( self.finished = finished self.logs = logs self.pending = pending - self.started = started def is_cancelled(self) -> bool: return self.cancel.value @@ -65,14 +63,10 @@ def set_cancel(self, cancel: bool = True) -> None: self.cancel.value = cancel def set_progress(self, progress: int) -> None: - with self.progress.get_lock(): - self.progress.value = progress - - def put_finished(self, job: str) -> None: - self.finished.put((job, self.device.device)) + self.progress.put((self.key, self.device.device, progress)) - def put_started(self, job: str) -> None: - self.started.put((job, self.device.device)) + def set_finished(self) -> None: + self.finished.put((self.key, self.device.device)) def clear_flags(self) -> None: self.set_cancel(False) diff --git 
a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 97cc683a6..589b79389 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -18,7 +18,7 @@ class DevicePoolExecutor: devices: List[DeviceParams] = None pending: Dict[str, "Queue[WorkerContext]"] = None workers: Dict[str, Process] = None - active_jobs: Dict[str, str] = None + active_jobs: Dict[str, Tuple[str, int]] = None finished_jobs: List[Tuple[str, int, bool]] = None def __init__( @@ -40,7 +40,7 @@ def __init__( self.finished_jobs = [] self.total_jobs = 0 # TODO: turn this into a Dict per-worker - self.started = Queue() + self.progress = Queue() self.finished = Queue() self.create_logger_worker() @@ -72,11 +72,10 @@ def create_device_worker(self, device: DeviceParams) -> None: name, device, cancel=Value("B", False), - progress=Value("I", 0), + progress=self.progress, finished=self.finished, logs=self.log_queue, pending=pending, - started=self.started, ) self.context[name] = context self.workers[name] = Process(target=worker_init, args=(context, self.server)) @@ -85,30 +84,32 @@ def create_device_worker(self, device: DeviceParams) -> None: self.workers[name].start() def create_queue_workers(self) -> None: - def started_worker(pending: Queue): - logger.info("checking in from started thread") + def progress_worker(progress: Queue): + logger.info("checking in from progress worker thread") while True: - job, device = pending.get() - logger.info("job has been started: %s", job) - self.active_jobs[device] = job + job, device, value = progress.get() + logger.info("progress update for job: %s, %s", job, value) + self.active_jobs[job] = (device, value) def finished_worker(finished: Queue): - logger.info("checking in from finished thread") + logger.info("checking in from finished worker thread") while True: job, device = finished.get() logger.info("job has been finished: %s", job) - context = self.get_job_context(job) + context = self.context[device] + _device, progress = 
self.active_jobs[job] self.finished_jobs.append( - (job, context.progress.value, context.cancel.value) + (job, progress, context.cancel.value) ) + del self.active_jobs[job] - self.started_thread = Thread(target=started_worker, args=(self.started,)) - self.started_thread.start() + self.progress_thread = Thread(target=progress_worker, args=(self.progress,)) + self.progress_thread.start() self.finished_thread = Thread(target=finished_worker, args=(self.finished,)) self.finished_thread.start() def get_job_context(self, key: str) -> WorkerContext: - device = self.active_jobs[key] + device, _progress = self.active_jobs[key] return self.context[device] def cancel(self, key: str) -> bool: @@ -141,8 +142,8 @@ def done(self, key: str) -> Tuple[Optional[bool], int]: return (None, 0) # TODO: prune here, maybe? - context = self.get_job_context(key) - return (False, context.progress.value) + _device, progress = self.active_jobs[key] + return (False, progress) def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible @@ -164,7 +165,7 @@ def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: return lowest_devices[0] def join(self): - self.started_thread.join(self.join_timeout) + self.progress_thread.join(self.join_timeout) self.finished_thread.join(self.join_timeout) for device, worker in self.workers.items(): @@ -226,12 +227,12 @@ def status(self) -> List[Tuple[str, int, bool, int]]: ( name, self.workers[name].is_alive(), - context.pending.qsize(), - context.cancel.value, + self.context[device].pending.qsize(), + self.context[device].cancel.value, False, - context.progress.value, + progress, ) - for name, context in self.context.items() + for name, device, progress in self.active_jobs ] pending.extend( [ diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index db23540f2..cbd3afa7f 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -37,9 +37,9 @@ def 
worker_init(context: WorkerContext, server: ServerContext): name = args[3][0] try: + context.key = name # TODO: hax context.clear_flags() logger.info("starting job: %s", name) - context.put_started(name) fn(context, *args, **kwargs) logger.info("job succeeded: %s", name) except Exception as e: @@ -48,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext): format_exception(type(e), e, e.__traceback__), ) finally: - context.put_finished(name) + context.set_finished() logger.info("finished job: %s", name) From 13395933dc5ea12b82f12a00dae14654a6307fc6 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 20:41:16 -0600 Subject: [PATCH 20/40] always put progress in active jobs --- api/onnx_web/worker/pool.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 589b79389..ce45ed8af 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -122,7 +122,7 @@ def cancel(self, key: str) -> bool: logger.warn("attempting to cancel unknown job: %s", key) return False - device = self.active_jobs[key] + device, _progress = self.active_jobs[key] context = self.context[device] logger.info("cancelling job %s on device %s", key, device) @@ -220,7 +220,7 @@ def submit( queue = self.pending[device.device] queue.put((fn, args, kwargs)) - self.active_jobs[key] = device.device + self.active_jobs[key] = (device.device, 0) def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ @@ -232,7 +232,7 @@ def status(self) -> List[Tuple[str, int, bool, int]]: False, progress, ) - for name, device, progress in self.active_jobs + for name, device, progress in self.active_jobs.items() ] pending.extend( [ From 66a20e60fef15234301902a079f7e1959c49f3df Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 17:14:53 -0600 Subject: [PATCH 21/40] run logger in a thread, clean up status --- api/onnx_web/server/api.py | 2 +- api/onnx_web/worker/context.py | 14 
+++ api/onnx_web/worker/pool.py | 177 +++++++++++++++++++-------------- api/onnx_web/worker/worker.py | 16 +-- 4 files changed, 117 insertions(+), 92 deletions(-) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 7ff611927..599212463 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -402,7 +402,7 @@ def blend(context: ServerContext, pool: DevicePoolExecutor): def txt2txt(context: ServerContext, pool: DevicePoolExecutor): device, params, size = pipeline_from_request(context) - output = make_output_name(context, "upscale", params, size) + output = make_output_name(context, "txt2txt", params, size) logger.info("upscale job queued for: %s", output) pool.submit( diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 2daf23d01..1ef564a78 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -71,3 +71,17 @@ def set_finished(self) -> None: def clear_flags(self) -> None: self.set_cancel(False) self.set_progress(0) + + +class JobStatus: + def __init__( + self, + name: str, + progress: int = 0, + cancelled: bool = False, + finished: bool = False, + ) -> None: + self.name = name + self.progress = progress + self.cancelled = cancelled + self.finished = finished diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index ce45ed8af..dced875da 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -8,7 +8,7 @@ from ..params import DeviceParams from ..server import ServerContext from .context import WorkerContext -from .worker import logger_init, worker_init +from .worker import worker_main logger = getLogger(__name__) @@ -18,15 +18,15 @@ class DevicePoolExecutor: devices: List[DeviceParams] = None pending: Dict[str, "Queue[WorkerContext]"] = None workers: Dict[str, Process] = None - active_jobs: Dict[str, Tuple[str, int]] = None - finished_jobs: List[Tuple[str, int, bool]] = None + active_jobs: Dict[str, Tuple[str, int]] = None # 
should be Dict[Device, JobStatus] + finished_jobs: List[Tuple[str, int, bool]] = None # should be List[JobStatus] def __init__( self, server: ServerContext, devices: List[DeviceParams], max_jobs_per_worker: int = 10, - join_timeout: float = 5.0, + join_timeout: float = 1.0, ): self.server = server self.devices = devices @@ -35,28 +35,27 @@ def __init__( self.context = {} self.pending = {} + self.threads = {} self.workers = {} + self.active_jobs = {} + self.cancelled_jobs = [] self.finished_jobs = [] self.total_jobs = 0 # TODO: turn this into a Dict per-worker + self.logs = Queue() self.progress = Queue() self.finished = Queue() self.create_logger_worker() - self.create_queue_workers() + self.create_progress_worker() + self.create_finished_worker() + for device in devices: self.create_device_worker(device) logger.debug("testing log worker") - self.log_queue.put("testing") - - def create_logger_worker(self) -> None: - self.log_queue = Queue() - self.logger = Process(target=logger_init, args=(self.log_queue,)) - - logger.debug("starting log worker") - self.logger.start() + self.logs.put("testing") def create_device_worker(self, device: DeviceParams) -> None: name = device.device @@ -74,23 +73,54 @@ def create_device_worker(self, device: DeviceParams) -> None: cancel=Value("B", False), progress=self.progress, finished=self.finished, - logs=self.log_queue, + logs=self.logs, pending=pending, ) self.context[name] = context - self.workers[name] = Process(target=worker_init, args=(context, self.server)) + self.workers[name] = Process(target=worker_main, args=(context, self.server)) logger.debug("starting worker for device %s", device) self.workers[name].start() - def create_queue_workers(self) -> None: + def create_logger_worker(self) -> None: + def logger_worker(logs: Queue): + logger.info("checking in from logger worker thread") + + while True: + job = logs.get() + with open("worker.log", "w") as f: + logger.info("got log: %s", job) + f.write(str(job) + "\n\n") + + 
logger_thread = Thread(target=logger_worker, args=(self.logs,)) + self.threads["logger"] = logger_thread + + logger.debug("starting logger worker") + logger_thread.start() + + def create_progress_worker(self) -> None: def progress_worker(progress: Queue): logger.info("checking in from progress worker thread") while True: - job, device, value = progress.get() - logger.info("progress update for job: %s, %s", job, value) - self.active_jobs[job] = (device, value) - + try: + job, device, value = progress.get() + logger.info("progress update for job: %s to %s", job, value) + self.active_jobs[job] = (device, value) + if job in self.cancelled_jobs: + logger.debug( + "setting flag for cancelled job: %s on %s", job, device + ) + self.context[device].set_cancel() + except Exception as err: + logger.error("error during progress update", err) + + progress_thread = Thread(target=progress_worker, args=(self.progress,)) + self.threads["progress"] = progress_thread + + logger.debug("starting progress worker") + progress_thread.start() + + def create_finished_worker(self) -> None: def finished_worker(finished: Queue): logger.info("checking in from finished worker thread") while True: @@ -98,41 +128,66 @@ def finished_worker(finished: Queue): logger.info("job has been finished: %s", job) context = self.context[device] _device, progress = self.active_jobs[job] - self.finished_jobs.append( - (job, progress, context.cancel.value) - ) + self.finished_jobs.append((job, progress, context.cancel.value)) del self.active_jobs[job] - self.progress_thread = Thread(target=progress_worker, args=(self.progress,)) - self.progress_thread.start() - self.finished_thread = Thread(target=finished_worker, args=(self.finished,)) - self.finished_thread.start() + finished_thread = Thread(target=finished_worker, args=(self.finished,)) + self.thread["finished"] = finished_thread + + logger.debug("started finished worker") + finished_thread.start() def get_job_context(self, key: str) -> WorkerContext: device, 
_progress = self.active_jobs[key] return self.context[device] + def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: + # respect overrides if possible + if needs_device is not None: + for i in range(len(self.devices)): + if self.devices[i].device == needs_device.device: + return i + + jobs = Counter(range(len(self.devices))) + jobs.update([self.pending[d.device].qsize() for d in self.devices]) + + queued = jobs.most_common() + logger.debug("jobs queued by device: %s", queued) + + lowest_count = queued[-1][1] + lowest_devices = [d[0] for d in queued if d[1] == lowest_count] + lowest_devices.sort() + + return lowest_devices[0] + def cancel(self, key: str) -> bool: """ Cancel a job. If the job has not been started, this will cancel the future and never execute it. If the job has been started, it should be cancelled on the next progress callback. """ + + self.cancelled_jobs.append(key) + if key not in self.active_jobs: - logger.warn("attempting to cancel unknown job: %s", key) - return False + logger.debug("cancelled job has not been started yet: %s", key) + return True device, _progress = self.active_jobs[key] - context = self.context[device] - logger.info("cancelling job %s on device %s", key, device) + logger.info("cancelling job %s, active on device %s", key, device) - if context.cancel.get_lock(): - context.cancel.value = True + context = self.context[device] + context.set_cancel() - # self.finished.append((key, context.progress.value, context.cancel.value)) maybe? return True def done(self, key: str) -> Tuple[Optional[bool], int]: + """ + Check if a job has been finished and report the last progress update. + + If the job is still active or pending, the first item will be False. + If the job is not finished or active, the first item will be None. 
+ """ for k, p, c in self.finished_jobs: if k == key: return (True, p) @@ -141,29 +196,9 @@ def done(self, key: str) -> Tuple[Optional[bool], int]: logger.warn("checking status for unknown job: %s", key) return (None, 0) - # TODO: prune here, maybe? _device, progress = self.active_jobs[key] return (False, progress) - def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: - # respect overrides if possible - if needs_device is not None: - for i in range(len(self.devices)): - if self.devices[i].device == needs_device.device: - return i - - jobs = Counter(range(len(self.devices))) - jobs.update([self.pending[d.device].qsize() for d in self.devices]) - - queued = jobs.most_common() - logger.debug("jobs queued by device: %s", queued) - - lowest_count = queued[-1][1] - lowest_devices = [d[0] for d in queued if d[1] == lowest_count] - lowest_devices.sort() - - return lowest_devices[0] - def join(self): self.progress_thread.join(self.join_timeout) self.finished_thread.join(self.join_timeout) @@ -216,35 +251,23 @@ def submit( self.devices[device_idx], ) - device = self.devices[device_idx] - queue = self.pending[device.device] - queue.put((fn, args, kwargs)) - - self.active_jobs[key] = (device.device, 0) - - def status(self) -> List[Tuple[str, int, bool, int]]: - pending = [ - ( - name, - self.workers[name].is_alive(), - self.context[device].pending.qsize(), - self.context[device].cancel.value, - False, - progress, - ) - for name, device, progress in self.active_jobs.items() + device = self.devices[device_idx].device + self.pending[device].put((fn, args, kwargs)) + + def status(self) -> List[Tuple[str, int, bool, bool]]: + history = [ + (name, progress, False, name in self.cancelled_jobs) + for name, _device, progress in self.active_jobs.items() ] - pending.extend( + history.extend( [ ( name, - False, - 0, - cancel, - True, progress, + True, + cancel, ) for name, progress, cancel in self.finished_jobs ] ) - return pending + return history diff --git 
a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index cbd3afa7f..9518f948b 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -11,19 +11,7 @@ logger = getLogger(__name__) -def logger_init(logs: Queue): - setproctitle("onnx-web logger") - - logger.info("checking in from logger") - - while True: - job = logs.get() - with open("worker.log", "w") as f: - logger.info("got log: %s", job) - f.write(str(job) + "\n\n") - - -def worker_init(context: WorkerContext, server: ServerContext): +def worker_main(context: WorkerContext, server: ServerContext): apply_patches(server) setproctitle("onnx-web worker: %s" % (context.device.device)) @@ -37,7 +25,7 @@ def worker_init(context: WorkerContext, server: ServerContext): name = args[3][0] try: - context.key = name # TODO: hax + context.key = name # TODO: hax context.clear_flags() logger.info("starting job: %s", name) fn(context, *args, **kwargs) From 2327b2402217208f5be6c3ee109bbb3c5a00ee2f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 17:35:31 -0600 Subject: [PATCH 22/40] join all threads --- api/onnx_web/worker/context.py | 12 ++++++------ api/onnx_web/worker/pool.py | 8 +++----- api/onnx_web/worker/worker.py | 2 +- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 1ef564a78..a69e28a21 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -13,13 +13,13 @@ class WorkerContext: cancel: "Value[bool]" = None - key: str = None + job: str = None pending: "Queue[Tuple[Callable, Any, Any]]" = None progress: "Value[int]" = None def __init__( self, - key: str, + job: str, device: DeviceParams, cancel: "Value[bool]" = None, logs: "Queue[str]" = None, @@ -27,7 +27,7 @@ def __init__( progress: "Queue[Tuple[str, int]]" = None, finished: "Queue[str]" = None, ): - self.key = key + self.job = job self.device = device self.cancel = cancel self.progress = progress @@ -53,7 
+53,7 @@ def on_progress(step: int, timestep: int, latents: Any): if self.is_cancelled(): raise RuntimeError("job has been cancelled") else: - logger.debug("setting progress for job %s to %s", self.key, step) + logger.debug("setting progress for job %s to %s", self.job, step) self.set_progress(step) return on_progress @@ -63,10 +63,10 @@ def set_cancel(self, cancel: bool = True) -> None: self.cancel.value = cancel def set_progress(self, progress: int) -> None: - self.progress.put((self.key, self.device.device, progress)) + self.progress.put((self.job, self.device.device, progress)) def set_finished(self) -> None: - self.finished.put((self.key, self.device.device)) + self.finished.put((self.job, self.device.device)) def clear_flags(self) -> None: self.set_cancel(False) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index dced875da..522533d7d 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -200,16 +200,14 @@ def done(self, key: str) -> Tuple[Optional[bool], int]: return (False, progress) def join(self): - self.progress_thread.join(self.join_timeout) - self.finished_thread.join(self.join_timeout) - for device, worker in self.workers.items(): if worker.is_alive(): logger.info("stopping worker for device %s", device) worker.join(self.join_timeout) - if self.logger.is_alive(): - self.logger.join(self.join_timeout) + for name, thread in self.threads.items(): + logger.info("stopping worker thread: %s", name) + thread.join(self.join_timeout) def recycle(self): for name, proc in self.workers.items(): diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 9518f948b..d6e2d0c98 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -25,7 +25,7 @@ def worker_main(context: WorkerContext, server: ServerContext): name = args[3][0] try: - context.key = name # TODO: hax + context.job = name # TODO: hax context.clear_flags() logger.info("starting job: %s", name) fn(context, 
*args, **kwargs) From 113ad052933cea9043bf6d319cc7a2119cea883a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 17:36:26 -0600 Subject: [PATCH 23/40] typo --- api/onnx_web/worker/pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 522533d7d..e9119c2a6 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -132,7 +132,7 @@ def finished_worker(finished: Queue): del self.active_jobs[job] finished_thread = Thread(target=finished_worker, args=(self.finished,)) - self.thread["finished"] = finished_thread + self.threads["finished"] = finished_thread logger.debug("started finished worker") finished_thread.start() From 06f06f5a112b1f8f2680879b56d6bdcb992d05be Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 19:48:51 -0600 Subject: [PATCH 24/40] error handling in all threads --- api/onnx_web/worker/pool.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index e9119c2a6..803d47738 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -87,10 +87,13 @@ def logger_worker(logs: Queue): logger.info("checking in from logger worker thread") while True: - job = logs.get() - with open("worker.log", "w") as f: - logger.info("got log: %s", job) - f.write(str(job) + "\n\n") + try: + job = logs.get() + with open("worker.log", "w") as f: + logger.info("got log: %s", job) + f.write(str(job) + "\n\n") + except Exception as err: + logger.error("error in log worker: %s", err) logger_thread = Thread(target=logger_worker, args=(self.logs,)) self.threads["logger"] = logger_thread @@ -112,7 +115,7 @@ def progress_worker(progress: Queue): ) self.context[device].set_cancel() except Exception as err: - logger.error("error during progress update", err) + logger.error("error in progress worker: %s", err) progress_thread = 
Thread(target=progress_worker, args=(self.progress,)) self.threads["progress"] = progress_thread @@ -124,12 +127,15 @@ def create_finished_worker(self) -> None: def finished_worker(finished: Queue): logger.info("checking in from finished worker thread") while True: - job, device = finished.get() - logger.info("job has been finished: %s", job) - context = self.context[device] - _device, progress = self.active_jobs[job] - self.finished_jobs.append((job, progress, context.cancel.value)) - del self.active_jobs[job] + try: + job, device = finished.get() + logger.info("job has been finished: %s", job) + context = self.context[device] + _device, progress = self.active_jobs[job] + self.finished_jobs.append((job, progress, context.cancel.value)) + del self.active_jobs[job] + except Exception as err: + logger.error("error in finished worker: %s", err) finished_thread = Thread(target=finished_worker, args=(self.finished,)) self.threads["finished"] = finished_thread From 61373d530a4c1d21d46eb0194460f64482b73796 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 20:08:52 -0600 Subject: [PATCH 25/40] fix Windows entrypoint --- api/launch-extras.bat | 2 +- api/launch.bat | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/launch-extras.bat b/api/launch-extras.bat index 2f0b95c0e..126df2c69 100644 --- a/api/launch-extras.bat +++ b/api/launch-extras.bat @@ -3,4 +3,4 @@ IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json) python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN% echo "Launching API server..." -flask --app="onnx_web.serve:run" run --host=0.0.0.0 +flask --app="onnx_web.main:run" run --host=0.0.0.0 diff --git a/api/launch.bat b/api/launch.bat index 4ee27fd63..09ddfccd2 100644 --- a/api/launch.bat +++ b/api/launch.bat @@ -2,4 +2,4 @@ echo "Downloading and converting models to ONNX format..." 
python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN% echo "Launching API server..." -flask --app="onnx_web.serve:run" run --host=0.0.0.0 +flask --app="onnx_web.main:run" run --host=0.0.0.0 From 0793b61c3a9f0a637db829646c0bb08b45ad0f81 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 22:25:53 -0600 Subject: [PATCH 26/40] consistently pass job key to workers --- api/onnx_web/worker/pool.py | 2 +- api/onnx_web/worker/worker.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 803d47738..9b8925f33 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -256,7 +256,7 @@ def submit( ) device = self.devices[device_idx].device - self.pending[device].put((fn, args, kwargs)) + self.pending[device].put((key, fn, args, kwargs)) def status(self) -> List[Tuple[str, int, bool, bool]]: history = [ diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index d6e2d0c98..6755cd1f7 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -18,11 +18,8 @@ def worker_main(context: WorkerContext, server: ServerContext): logger.info("checking in from worker, %s", get_available_providers()) while True: - job = context.pending.get() - logger.info("got job: %s", job) - - fn, args, kwargs = job - name = args[3][0] + name, fn, args, kwargs = context.pending.get() + logger.info("worker for %s got job: %s", context.device.device, name) try: context.job = name # TODO: hax From 136759285ddf9e06975727e4dc32db659f858e9d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 22:37:43 -0600 Subject: [PATCH 27/40] set queue timeouts --- api/onnx_web/worker/context.py | 2 +- api/onnx_web/worker/pool.py | 8 ++++---- api/onnx_web/worker/worker.py | 5 ++++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 
a69e28a21..0f1af7e2c 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -63,7 +63,7 @@ def set_cancel(self, cancel: bool = True) -> None: self.cancel.value = cancel def set_progress(self, progress: int) -> None: - self.progress.put((self.job, self.device.device, progress)) + self.progress.put((self.job, self.device.device, progress), block=False) def set_finished(self) -> None: self.finished.put((self.job, self.device.device)) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 9b8925f33..4105ac572 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -88,7 +88,7 @@ def logger_worker(logs: Queue): while True: try: - job = logs.get() + job = logs.get(timeout=(self.join_timeout / 2)) with open("worker.log", "w") as f: logger.info("got log: %s", job) f.write(str(job) + "\n\n") @@ -106,7 +106,7 @@ def progress_worker(progress: Queue): logger.info("checking in from progress worker thread") while True: try: - job, device, value = progress.get() + job, device, value = progress.get(timeout=(self.join_timeout / 2)) logger.info("progress update for job: %s to %s", job, value) self.active_jobs[job] = (device, value) if job in self.cancelled_jobs: @@ -128,7 +128,7 @@ def finished_worker(finished: Queue): logger.info("checking in from finished worker thread") while True: try: - job, device = finished.get() + job, device = finished.get(timeout=(self.join_timeout / 2)) logger.info("job has been finished: %s", job) context = self.context[device] _device, progress = self.active_jobs[job] @@ -256,7 +256,7 @@ def submit( ) device = self.devices[device_idx].device - self.pending[device].put((key, fn, args, kwargs)) + self.pending[device].put((key, fn, args, kwargs), block=False) def status(self) -> List[Tuple[str, int, bool, bool]]: history = [ diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 6755cd1f7..ef925a337 100644 --- a/api/onnx_web/worker/worker.py +++ 
b/api/onnx_web/worker/worker.py @@ -1,4 +1,5 @@ from logging import getLogger +from queue import Empty from traceback import format_exception from setproctitle import setproctitle @@ -18,7 +19,7 @@ def worker_main(context: WorkerContext, server: ServerContext): logger.info("checking in from worker, %s", get_available_providers()) while True: - name, fn, args, kwargs = context.pending.get() + name, fn, args, kwargs = context.pending.get(timeout=1.0) logger.info("worker for %s got job: %s", context.device.device, name) try: @@ -27,6 +28,8 @@ def worker_main(context: WorkerContext, server: ServerContext): logger.info("starting job: %s", name) fn(context, *args, **kwargs) logger.info("job succeeded: %s", name) + except Empty: + pass except Exception as e: logger.error( "error while running job: %s", From 953e5abd3631d5723ae21662c11b1ceff3bc39fd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 22:45:29 -0600 Subject: [PATCH 28/40] handle empty errors --- api/onnx_web/worker/pool.py | 7 +++++++ api/onnx_web/worker/worker.py | 10 ++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 4105ac572..3bede110a 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -1,5 +1,6 @@ from collections import Counter from logging import getLogger +from queue import Empty from threading import Thread from typing import Callable, Dict, List, Optional, Tuple @@ -92,6 +93,8 @@ def logger_worker(logs: Queue): with open("worker.log", "w") as f: logger.info("got log: %s", job) f.write(str(job) + "\n\n") + except Empty: + pass except Exception as err: logger.error("error in log worker: %s", err) @@ -114,6 +117,8 @@ def progress_worker(progress: Queue): "setting flag for cancelled job: %s on %s", job, device ) self.context[device].set_cancel() + except Empty: + pass except Exception as err: logger.error("error in progress worker: %s", err) @@ -134,6 +139,8 @@ def 
finished_worker(finished: Queue): _device, progress = self.active_jobs[job] self.finished_jobs.append((job, progress, context.cancel.value)) del self.active_jobs[job] + except Empty: + pass except Exception as err: logger.error("error in finished worker: %s", err) diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index ef925a337..ff0e9aae3 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -19,15 +19,16 @@ def worker_main(context: WorkerContext, server: ServerContext): logger.info("checking in from worker, %s", get_available_providers()) while True: - name, fn, args, kwargs = context.pending.get(timeout=1.0) - logger.info("worker for %s got job: %s", context.device.device, name) - try: + name, fn, args, kwargs = context.pending.get(timeout=1.0) + logger.info("worker for %s got job: %s", context.device.device, name) + context.job = name # TODO: hax context.clear_flags() logger.info("starting job: %s", name) fn(context, *args, **kwargs) logger.info("job succeeded: %s", name) + context.set_finished() except Empty: pass except Exception as e: @@ -35,6 +36,3 @@ def worker_main(context: WorkerContext, server: ServerContext): "error while running job: %s", format_exception(type(e), e, e.__traceback__), ) - finally: - context.set_finished() - logger.info("finished job: %s", name) From 988088d64efdef3af8c576e9ee96d1ed5f5e777d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 22:52:43 -0600 Subject: [PATCH 29/40] quit workers on keyboard signal --- api/onnx_web/worker/worker.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index ff0e9aae3..691a0fad5 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -1,5 +1,6 @@ from logging import getLogger from queue import Empty +from sys import exit from traceback import format_exception from setproctitle import setproctitle @@ -31,6 +32,8 @@ def worker_main(context: 
WorkerContext, server: ServerContext): context.set_finished() except Empty: pass + except KeyboardInterrupt: + exit(0) except Exception as e: logger.error( "error while running job: %s", From da6ae5d62f04da4cc00067c2772a6d8c91b17164 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 23:01:26 -0600 Subject: [PATCH 30/40] more logging around shutdown, close queues --- api/onnx_web/main.py | 7 ++++++- api/onnx_web/worker/pool.py | 18 ++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index bbfa10398..54f2f7df5 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -59,7 +59,12 @@ def main(): def run(): app, pool = main() - atexit.register(lambda: pool.join()) + + def quit(): + logger.info("shutting down workers") + pool.join() + + atexit.register(quit) return app diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 3bede110a..d2db88ea1 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -213,15 +213,29 @@ def done(self, key: str) -> Tuple[Optional[bool], int]: return (False, progress) def join(self): + logger.debug("stopping worker pool") + for device, worker in self.workers.items(): if worker.is_alive(): - logger.info("stopping worker for device %s", device) + logger.debug("stopping worker for device %s", device) worker.join(self.join_timeout) + else: + logger.debug("worker for device %s has died", device) for name, thread in self.threads.items(): - logger.info("stopping worker thread: %s", name) + logger.debug("stopping worker thread: %s", name) thread.join(self.join_timeout) + logger.debug("closing queues") + self.logs.close() + self.finished.close() + self.progress.close() + for key, queue in self.pending.items(): + queue.close() + del self.pending[key] + + logger.debug("worker pool fully joined") + def recycle(self): for name, proc in self.workers.items(): if proc.is_alive(): From 
f7f438e767bd2efae8e997730bd869a2e138d332 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 23:03:42 -0600 Subject: [PATCH 31/40] directly rejoin pool --- api/onnx_web/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 54f2f7df5..d1b20a147 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -69,6 +69,7 @@ def quit(): if __name__ == "__main__": - app = run() + app, pool = main() app.run("0.0.0.0", 5000, debug=is_debug()) logger.info("shutting down app") + pool.join() From 1ce98ace33a79bf384d05e8717a18389e9652f62 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 23:12:53 -0600 Subject: [PATCH 32/40] add value error handling --- api/onnx_web/worker/pool.py | 7 +++++++ api/onnx_web/worker/worker.py | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index d2db88ea1..41be9c972 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -95,6 +95,8 @@ def logger_worker(logs: Queue): f.write(str(job) + "\n\n") except Empty: pass + except ValueError: + break except Exception as err: logger.error("error in log worker: %s", err) @@ -119,6 +121,8 @@ def progress_worker(progress: Queue): self.context[device].set_cancel() except Empty: pass + except ValueError: + break except Exception as err: logger.error("error in progress worker: %s", err) @@ -141,6 +145,8 @@ def finished_worker(finished: Queue): del self.active_jobs[job] except Empty: pass + except ValueError: + break except Exception as err: logger.error("error in finished worker: %s", err) @@ -219,6 +225,7 @@ def join(self): if worker.is_alive(): logger.debug("stopping worker for device %s", device) worker.join(self.join_timeout) + worker.terminate() else: logger.debug("worker for device %s has died", device) diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 691a0fad5..0cd180b42 100644 --- 
a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -29,11 +29,16 @@ def worker_main(context: WorkerContext, server: ServerContext): logger.info("starting job: %s", name) fn(context, *args, **kwargs) logger.info("job succeeded: %s", name) + context.pending.task_done() context.set_finished() except Empty: pass except KeyboardInterrupt: + logger.info("worker got keyboard interrupt") exit(0) + except ValueError as e: + logger.info("value error in worker: %s", e) + exit(1) except Exception as e: logger.error( "error while running job: %s", From 7e0ccdb1af196d3468f57df577caf5c0db295882 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 23:14:20 -0600 Subject: [PATCH 33/40] remove pending queues after joining --- api/onnx_web/worker/pool.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 41be9c972..dbd626f88 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -237,9 +237,10 @@ def join(self): self.logs.close() self.finished.close() self.progress.close() - for key, queue in self.pending.items(): + for queue in self.pending.values(): queue.close() - del self.pending[key] + + self.pending.clear() logger.debug("worker pool fully joined") From 4ae3d9caa2cecec677ef2450329ac6fac1c08bea Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 23:18:37 -0600 Subject: [PATCH 34/40] remove task done --- api/onnx_web/worker/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 0cd180b42..797b8de3f 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -29,7 +29,6 @@ def worker_main(context: WorkerContext, server: ServerContext): logger.info("starting job: %s", name) fn(context, *args, **kwargs) logger.info("job succeeded: %s", name) - context.pending.task_done() context.set_finished() except Empty: pass From 
cad0d37604ccbec7d0b95fcfa25993dc3bdf0dda Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 23:43:38 -0600 Subject: [PATCH 35/40] some pending queue logging --- api/onnx_web/worker/pool.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index dbd626f88..938de257d 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -63,8 +63,10 @@ def create_device_worker(self, device: DeviceParams) -> None: # reuse the queue if possible, to keep queued jobs if name in self.pending: + logger.debug("using existing pending job queue") pending = self.pending[name] else: + logger.debug("creating new pending job queue") pending = Queue() self.pending[name] = pending From 0011f079d4fb8c4c1b1afce7e89856c90e1dc7bc Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 28 Feb 2023 06:55:15 -0600 Subject: [PATCH 36/40] daemonize queue collectors --- api/onnx_web/worker/pool.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 938de257d..ebfc0dd1b 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -102,7 +102,7 @@ def logger_worker(logs: Queue): except Exception as err: logger.error("error in log worker: %s", err) - logger_thread = Thread(target=logger_worker, args=(self.logs,)) + logger_thread = Thread(target=logger_worker, args=(self.logs,), daemon=True) self.threads["logger"] = logger_thread logger.debug("starting logger worker") @@ -128,7 +128,7 @@ def progress_worker(progress: Queue): except Exception as err: logger.error("error in progress worker: %s", err) - progress_thread = Thread(target=progress_worker, args=(self.progress,)) + progress_thread = Thread(target=progress_worker, args=(self.progress,), daemon=True) self.threads["progress"] = progress_thread logger.debug("starting progress worker") @@ -152,7 +152,7 @@ def finished_worker(finished: Queue): except Exception as err: 
logger.error("error in finished worker: %s", err) - finished_thread = Thread(target=finished_worker, args=(self.finished,)) + finished_thread = Thread(target=finished_worker, args=(self.finished,), daemon=True) self.threads["finished"] = finished_thread logger.debug("started finished worker") From c95ac1fbddd1c0821441af04f7ca129c5b7952b9 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 28 Feb 2023 08:53:17 -0600 Subject: [PATCH 37/40] avoid terminating workers because it breaks their queues --- api/onnx_web/worker/pool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index ebfc0dd1b..8caec6375 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -227,7 +227,7 @@ def join(self): if worker.is_alive(): logger.debug("stopping worker for device %s", device) worker.join(self.join_timeout) - worker.terminate() + # worker.terminate() else: logger.debug("worker for device %s has died", device) @@ -251,7 +251,7 @@ def recycle(self): if proc.is_alive(): logger.debug("shutting down worker for device %s", name) proc.join(self.join_timeout) - proc.terminate() + # proc.terminate() else: logger.warning("worker for device %s has died", name) From c99aa67220b113aa082606dcd61d3a957dca61fc Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 28 Feb 2023 21:44:52 -0600 Subject: [PATCH 38/40] name threads, max queues, type/lint fixes --- api/Makefile | 2 +- api/logging.yaml | 2 +- api/onnx_web/chain/base.py | 4 +- api/onnx_web/chain/blend_img2img.py | 3 +- api/onnx_web/chain/blend_inpaint.py | 2 +- api/onnx_web/chain/blend_mask.py | 2 +- api/onnx_web/chain/correct_codeformer.py | 3 +- api/onnx_web/chain/persist_s3.py | 7 +- api/onnx_web/chain/reduce_crop.py | 3 +- api/onnx_web/chain/source_txt2img.py | 3 +- api/onnx_web/chain/upscale_outpaint.py | 2 +- api/onnx_web/chain/upscale_resrgan.py | 6 +- .../chain/upscale_stable_diffusion.py | 5 +- api/onnx_web/chain/utils.py | 3 + 
api/onnx_web/convert/__main__.py | 4 +- api/onnx_web/convert/utils.py | 8 +- api/onnx_web/main.py | 2 + api/onnx_web/output.py | 4 +- api/onnx_web/params.py | 4 +- api/onnx_web/server/api.py | 63 ++++++++------ api/onnx_web/server/context.py | 12 +-- api/onnx_web/server/hacks.py | 5 +- api/onnx_web/transformers.py | 6 +- api/onnx_web/upscale.py | 3 +- api/onnx_web/utils.py | 6 +- api/onnx_web/worker/context.py | 18 ++-- api/onnx_web/worker/pool.py | 84 +++++++++++++------ api/onnx_web/worker/worker.py | 1 - api/pyproject.toml | 21 ++++- api/scripts/test-memory-leak.sh | 2 +- 30 files changed, 179 insertions(+), 111 deletions(-) diff --git a/api/Makefile b/api/Makefile index 236aacabd..c5d598502 100644 --- a/api/Makefile +++ b/api/Makefile @@ -39,4 +39,4 @@ lint-fix: flake8 onnx_web typecheck: - mypy -m onnx_web.serve + mypy onnx_web diff --git a/api/logging.yaml b/api/logging.yaml index 24bd3c292..861d480a1 100644 --- a/api/logging.yaml +++ b/api/logging.yaml @@ -1,7 +1,7 @@ version: 1 formatters: simple: - format: '[%(asctime)s] %(levelname)s: %(name)s: %(message)s' + format: '[%(asctime)s] %(levelname)s: %(processName)s %(threadName)s %(name)s: %(message)s' handlers: console: class: logging.StreamHandler diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 2ddc59aea..45644496a 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -62,7 +62,7 @@ class ChainPipeline: def __init__( self, - stages: List[PipelineStage] = None, + stages: Optional[List[PipelineStage]] = None, ): """ Create a new pipeline that will run the given stages. 
@@ -82,7 +82,7 @@ def __call__( server: ServerContext, params: ImageParams, source: Image.Image, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, **pipeline_kwargs ) -> Image.Image: """ diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index f7c516058..87b5ce94d 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -9,6 +9,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext +from typing import Optional logger = getLogger(__name__) @@ -20,7 +21,7 @@ def blend_img2img( params: ImageParams, source: Image.Image, *, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, stage_source: Image.Image, **kwargs, ) -> Image.Image: diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 7d864b5e8..3fbbbe3c7 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -31,7 +31,7 @@ def blend_inpaint( fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 5c53bd12b..8fe22220f 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -21,7 +21,7 @@ def blend_mask( *, sources: Optional[List[Image.Image]] = None, stage_mask: Optional[Image.Image] = None, - _callback: ProgressCallback = None, + _callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: logger.info("blending image using mask") diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index c3eaec65c..d9a070376 100644 --- 
a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -5,6 +5,7 @@ from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext +from typing import Optional logger = getLogger(__name__) @@ -18,7 +19,7 @@ def correct_codeformer( _params: ImageParams, source: Image.Image, *, - stage_source: Image.Image = None, + stage_source: Optional[Image.Image] = None, upscale: UpscaleParams, **kwargs, ) -> Image.Image: diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 3e01b9cec..e4eb96b09 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -7,6 +7,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext +from typing import Optional logger = getLogger(__name__) @@ -20,9 +21,9 @@ def persist_s3( *, output: str, bucket: str, - endpoint_url: str = None, - profile_name: str = None, - stage_source: Image.Image = None, + endpoint_url: Optional[str] = None, + profile_name: Optional[str] = None, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> Image.Image: source = stage_source or source diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index cce82f0ca..0668429bc 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -5,6 +5,7 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext +from typing import Optional logger = getLogger(__name__) @@ -18,7 +19,7 @@ def reduce_crop( *, origin: Size, size: Size, - stage_source: Image.Image = None, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> Image.Image: source = stage_source or source diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 1dec32439..0e22d552a 100644 --- 
a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -9,6 +9,7 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext +from typing import Optional logger = getLogger(__name__) @@ -21,7 +22,7 @@ def source_txt2img( _source: Image.Image, *, size: Size, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 695652050..a40de6715 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -31,7 +31,7 @@ def upscale_outpaint( fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: source = stage_source or source diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 055319d0d..bde23ec37 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -9,12 +9,10 @@ from ..server import ServerContext from ..utils import run_gc from ..worker import WorkerContext +from typing import Optional logger = getLogger(__name__) -last_pipeline_instance = None -last_pipeline_params = (None, None) - TAG_X4_V3 = "real-esrgan-x4-v3" @@ -104,7 +102,7 @@ def upscale_resrgan( source: Image.Image, *, upscale: UpscaleParams, - stage_source: Image.Image = None, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> Image.Image: source = stage_source or source diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 0accc8548..298b547aa 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ 
b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -13,6 +13,7 @@ from ..server import ServerContext from ..utils import run_gc from ..worker import ProgressCallback, WorkerContext +from typing import Optional logger = getLogger(__name__) @@ -70,8 +71,8 @@ def upscale_stable_diffusion( source: Image.Image, *, upscale: UpscaleParams, - stage_source: Image.Image = None, - callback: ProgressCallback = None, + stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index 6598f10b3..81317226a 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -110,3 +110,6 @@ def process_tile_order( elif order == TileOrder.spiral: logger.debug("using spiral tile order with tile size: %s", tile) return process_tile_spiral(source, tile, scale, filters, **kwargs) + else: + logger.warn("unknown tile order: %s", order) + raise ValueError() diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 2a670511d..a9c760b2c 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -3,7 +3,7 @@ from logging import getLogger from os import makedirs, path from sys import exit -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse from jsonschema import ValidationError, validate @@ -36,7 +36,7 @@ ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*", ) -Models = Dict[str, List[Tuple[str, str, Optional[int]]]] +Models = Dict[str, List[Any]] logger = getLogger(__name__) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 9d0f96aba..9b3edf662 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -23,7 +23,7 @@ class ConversionContext(ServerContext): def __init__( self, - model_path: 
Optional[str] = None, + model_path: str, cache_path: Optional[str] = None, device: Optional[str] = None, half: Optional[bool] = False, @@ -31,7 +31,7 @@ def __init__( token: Optional[str] = None, **kwargs, ) -> None: - super().__init__(self, model_path=model_path, cache_path=cache_path) + super().__init__(model_path=model_path, cache_path=cache_path) self.half = half self.opset = opset @@ -153,7 +153,7 @@ def source_format(model: Dict) -> Optional[str]: return model["format"] if "source" in model: - ext = path.splitext(model["source"]) + _name, ext = path.splitext(model["source"]) if ext in model_formats: return ext @@ -183,7 +183,7 @@ def config_from_key(cls, target, k, v): setattr(target, k, v) -def load_yaml(file: str) -> str: +def load_yaml(file: str) -> Config: with open(file, "r") as f: data = safe_load(f.read()) return Config(data) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index d1b20a147..7b8ea3255 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -6,6 +6,7 @@ from flask import Flask from flask_cors import CORS from huggingface_hub.utils.tqdm import disable_progress_bars +from setproctitle import setproctitle from torch.multiprocessing import set_start_method from .server.api import register_api_routes @@ -26,6 +27,7 @@ def main(): + setproctitle("onnx-web server") set_start_method("spawn", force=True) context = ServerContext.from_environ() diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 749008a2a..ae35d06ee 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -15,7 +15,7 @@ logger = getLogger(__name__) -def hash_value(sha, param: Param): +def hash_value(sha, param: Optional[Param]): if param is None: return elif isinstance(param, bool): @@ -63,7 +63,7 @@ def make_output_name( mode: str, params: ImageParams, size: Size, - extras: Optional[Tuple[Param]] = None, + extras: Optional[List[Optional[Param]]] = None, ) -> List[str]: now = int(time()) sha = sha256() diff --git a/api/onnx_web/params.py 
b/api/onnx_web/params.py index 32440c072..9912d6649 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -101,12 +101,12 @@ def __init__( self.device = device self.provider = provider self.options = options - self.optimizations = optimizations + self.optimizations = optimizations or [] def __str__(self) -> str: return "%s - %s (%s)" % (self.device, self.provider, self.options) - def ort_provider(self) -> Tuple[str, Any]: + def ort_provider(self) -> Union[str, Tuple[str, Any]]: if self.options is None: return self.provider else: diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 599212463..7f69c1c1f 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -81,7 +81,7 @@ def introspect(context: ServerContext, app: Flask): return { "name": "onnx-web", "routes": [ - {"path": url_from_rule(rule), "methods": list(rule.methods).sort()} + {"path": url_from_rule(rule), "methods": list(rule.methods or []).sort()} for rule in app.url_map.iter_rules() ], } @@ -119,10 +119,10 @@ def list_schedulers(context: ServerContext): def img2img(context: ServerContext, pool: DevicePoolExecutor): - if "source" not in request.files: + source_file = request.files.get("source") + if source_file is None: return error_reply("source image is required") - source_file = request.files.get("source") source = Image.open(BytesIO(source_file.read())).convert("RGB") device, params, size = pipeline_from_request(context) @@ -136,7 +136,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor): get_config_value("strength", "min"), ) - output = make_output_name(context, "img2img", params, size, extras=(strength,)) + output = make_output_name(context, "img2img", params, size, extras=[strength]) job_name = output[0] logger.info("img2img job queued for: %s", job_name) @@ -179,16 +179,15 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor): def inpaint(context: ServerContext, pool: DevicePoolExecutor): - if "source" not in 
request.files: + source_file = request.files.get("source") + if source_file is None: return error_reply("source image is required") - if "mask" not in request.files: + mask_file = request.files.get("mask") + if mask_file is None: return error_reply("mask image is required") - source_file = request.files.get("source") source = Image.open(BytesIO(source_file.read())).convert("RGB") - - mask_file = request.files.get("mask") mask = Image.open(BytesIO(mask_file.read())).convert("RGB") device, params, size = pipeline_from_request(context) @@ -207,7 +206,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor): "inpaint", params, size, - extras=( + extras=[ expand.left, expand.right, expand.top, @@ -216,7 +215,7 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor): noise_source.__name__, fill_color, tile_order, - ), + ], ) job_name = output[0] logger.info("inpaint job queued for: %s", job_name) @@ -245,10 +244,10 @@ def inpaint(context: ServerContext, pool: DevicePoolExecutor): def upscale(context: ServerContext, pool: DevicePoolExecutor): - if "source" not in request.files: + source_file = request.files.get("source") + if source_file is None: return error_reply("source image is required") - source_file = request.files.get("source") source = Image.open(BytesIO(source_file.read())).convert("RGB") device, params, size = pipeline_from_request(context) @@ -324,9 +323,10 @@ def chain(context: ServerContext, pool: DevicePoolExecutor): stage.name, ) source_file = request.files.get(stage_source_name) - source = Image.open(BytesIO(source_file.read())).convert("RGB") - source = valid_image(source, max_dims=(size.width, size.height)) - kwargs["stage_source"] = source + if source_file is not None: + source = Image.open(BytesIO(source_file.read())).convert("RGB") + source = valid_image(source, max_dims=(size.width, size.height)) + kwargs["stage_source"] = source if stage_mask_name in request.files: logger.debug( @@ -335,9 +335,10 @@ def chain(context: 
ServerContext, pool: DevicePoolExecutor): stage.name, ) mask_file = request.files.get(stage_mask_name) - mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - mask = valid_image(mask, max_dims=(size.width, size.height)) - kwargs["stage_mask"] = mask + if mask_file is not None: + mask = Image.open(BytesIO(mask_file.read())).convert("RGB") + mask = valid_image(mask, max_dims=(size.width, size.height)) + kwargs["stage_mask"] = mask pipeline.append((callback, stage, kwargs)) @@ -360,10 +361,10 @@ def chain(context: ServerContext, pool: DevicePoolExecutor): def blend(context: ServerContext, pool: DevicePoolExecutor): - if "mask" not in request.files: + mask_file = request.files.get("mask") + if mask_file is None: return error_reply("mask image is required") - mask_file = request.files.get("mask") mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") mask = valid_image(mask) @@ -372,9 +373,12 @@ def blend(context: ServerContext, pool: DevicePoolExecutor): for i in range(max_sources): source_file = request.files.get("source:%s" % (i)) - source = Image.open(BytesIO(source_file.read())).convert("RGBA") - source = valid_image(source, mask.size, mask.size) - sources.append(source) + if source_file is None: + logger.warning("missing source %s", i) + else: + source = Image.open(BytesIO(source_file.read())).convert("RGBA") + source = valid_image(source, mask.size, mask.size) + sources.append(source) device, params, size = pipeline_from_request(context) upscale = upscale_from_request() @@ -403,10 +407,11 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor): device, params, size = pipeline_from_request(context) output = make_output_name(context, "txt2txt", params, size) - logger.info("upscale job queued for: %s", output) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) pool.submit( - output, + job_name, run_txt2txt_pipeline, context, params, @@ -420,6 +425,8 @@ def txt2txt(context: ServerContext, pool: DevicePoolExecutor): def 
cancel(context: ServerContext, pool: DevicePoolExecutor): output_file = request.args.get("output", None) + if output_file is None: + return error_reply("output name is required") cancel = pool.cancel(output_file) @@ -428,6 +435,8 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor): def ready(context: ServerContext, pool: DevicePoolExecutor): output_file = request.args.get("output", None) + if output_file is None: + return error_reply("output name is required") done, progress = pool.done(output_file) @@ -436,7 +445,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor): if path.exists(output): return ready_reply(True) - return ready_reply(done, progress=progress) + return ready_reply(done or False, progress=progress) def status(context: ServerContext, pool: DevicePoolExecutor): diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 4e1f79601..3107a79b7 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -1,6 +1,6 @@ from logging import getLogger from os import environ, path -from typing import List +from typing import List, Optional from ..utils import get_boolean from .model_cache import ModelCache @@ -18,13 +18,13 @@ def __init__( cors_origin: str = "*", num_workers: int = 1, any_platform: bool = True, - block_platforms: List[str] = None, - default_platform: str = None, + block_platforms: Optional[List[str]] = None, + default_platform: Optional[str] = None, image_format: str = "png", - cache: ModelCache = None, - cache_path: str = None, + cache: Optional[ModelCache] = None, + cache_path: Optional[str] = None, show_progress: bool = True, - optimizations: List[str] = None, + optimizations: Optional[List[str]] = None, ) -> None: self.bundle_path = bundle_path self.model_path = model_path diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index 4b7ac0b99..7847ccf4c 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -119,9 +119,8 @@ def 
patch_not_impl(): def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str: - if url in cache_path_map: - cache_path = cache_path_map.get(url) - else: + cache_path = cache_path_map.get(url, None) + if cache_path is None: parsed = urlparse(url) cache_path = path.basename(parsed.path) diff --git a/api/onnx_web/transformers.py b/api/onnx_web/transformers.py index 18d90f0a1..299d38544 100644 --- a/api/onnx_web/transformers.py +++ b/api/onnx_web/transformers.py @@ -22,13 +22,13 @@ def run_txt2txt_pipeline( device = job.get_device() - model = GPTJForCausalLM.from_pretrained(model).to(device.torch_device()) + pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str()) tokenizer = AutoTokenizer.from_pretrained(model) input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to( - device.torch_device() + device.torch_str() ) - results = model.generate( + results = pipe.generate( input_ids, do_sample=True, max_length=tokens, diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 098ce817c..4d103a2a5 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -12,6 +12,7 @@ from .params import ImageParams, SizeChart, StageParams, UpscaleParams from .server import ServerContext from .worker import ProgressCallback, WorkerContext +from typing import Optional logger = getLogger(__name__) @@ -24,7 +25,7 @@ def run_upscale_correction( image: Image.Image, *, upscale: UpscaleParams, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, ) -> Image.Image: """ This is a convenience method for a chain pipeline that will run upscaling and diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 69012af96..f794eca0d 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -2,7 +2,7 @@ import threading from logging import getLogger from os import environ, path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union import torch @@ 
-36,7 +36,7 @@ def get_and_clamp_int( return min(max(int(args.get(key, default_value)), min_value), max_value) -def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]: +def get_from_list(args: Any, key: str, values: Sequence[Any]) -> Optional[Any]: selected = args.get(key, None) if selected in values: return selected @@ -82,7 +82,7 @@ def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]: raise ValueError("invalid size") -def run_gc(devices: List[DeviceParams] = None): +def run_gc(devices: Optional[List[DeviceParams]] = None): logger.debug( "running garbage collection with %s active threads", threading.active_count() ) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 0f1af7e2c..1b14d80e7 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -12,20 +12,20 @@ class WorkerContext: - cancel: "Value[bool]" = None - job: str = None - pending: "Queue[Tuple[Callable, Any, Any]]" = None - progress: "Value[int]" = None + cancel: "Value[bool]" + job: str + pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]" + progress: "Value[int]" def __init__( self, job: str, device: DeviceParams, - cancel: "Value[bool]" = None, - logs: "Queue[str]" = None, - pending: "Queue[Any]" = None, - progress: "Queue[Tuple[str, int]]" = None, - finished: "Queue[str]" = None, + cancel: "Value[bool]", + logs: "Queue[str]", + pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]", + progress: "Queue[Tuple[str, str, int]]", + finished: "Queue[Tuple[str, str]]", ): self.job = job self.device = device diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 8caec6375..38285df2c 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -2,7 +2,7 @@ from logging import getLogger from queue import Empty from threading import Thread -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from 
torch.multiprocessing import Process, Queue, Value @@ -15,23 +15,38 @@ class DevicePoolExecutor: - context: Dict[str, WorkerContext] = None # Device -> Context - devices: List[DeviceParams] = None - pending: Dict[str, "Queue[WorkerContext]"] = None - workers: Dict[str, Process] = None - active_jobs: Dict[str, Tuple[str, int]] = None # should be Dict[Device, JobStatus] - finished_jobs: List[Tuple[str, int, bool]] = None # should be List[JobStatus] + server: ServerContext + devices: List[DeviceParams] + max_jobs_per_worker: int + max_pending_per_worker: int + join_timeout: float + + context: Dict[str, WorkerContext] # Device -> Context + pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"] + threads: Dict[str, Thread] + workers: Dict[str, Process] + + active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus] + cancelled_jobs: List[str] + finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus] + total_jobs: int + + logs: "Queue" + progress: "Queue[Tuple[str, str, int]]" + finished: "Queue[Tuple[str, str]]" def __init__( self, server: ServerContext, devices: List[DeviceParams], max_jobs_per_worker: int = 10, + max_pending_per_worker: int = 100, join_timeout: float = 1.0, ): self.server = server self.devices = devices self.max_jobs_per_worker = max_jobs_per_worker + self.max_pending_per_worker = max_pending_per_worker self.join_timeout = join_timeout self.context = {} @@ -44,9 +59,9 @@ def __init__( self.finished_jobs = [] self.total_jobs = 0 # TODO: turn this into a Dict per-worker - self.logs = Queue() - self.progress = Queue() - self.finished = Queue() + self.logs = Queue(self.max_pending_per_worker) + self.progress = Queue(self.max_pending_per_worker) + self.finished = Queue(self.max_pending_per_worker) self.create_logger_worker() self.create_progress_worker() @@ -67,7 +82,7 @@ def create_device_worker(self, device: DeviceParams) -> None: pending = self.pending[name] else: logger.debug("creating new pending job 
queue") - pending = Queue() + pending = Queue(self.max_pending_per_worker) self.pending[name] = pending context = WorkerContext( @@ -80,7 +95,11 @@ def create_device_worker(self, device: DeviceParams) -> None: pending=pending, ) self.context[name] = context - self.workers[name] = Process(target=worker_main, args=(context, self.server)) + self.workers[name] = Process( + name=f"onnx-web worker: {name}", + target=worker_main, + args=(context, self.server), + ) logger.debug("starting worker for device %s", device) self.workers[name].start() @@ -102,7 +121,9 @@ def logger_worker(logs: Queue): except Exception as err: logger.error("error in log worker: %s", err) - logger_thread = Thread(target=logger_worker, args=(self.logs,), daemon=True) + logger_thread = Thread( + name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True + ) self.threads["logger"] = logger_thread logger.debug("starting logger worker") @@ -128,7 +149,12 @@ def progress_worker(progress: Queue): except Exception as err: logger.error("error in progress worker: %s", err) - progress_thread = Thread(target=progress_worker, args=(self.progress,), daemon=True) + progress_thread = Thread( + name="onnx-web progress", + target=progress_worker, + args=(self.progress,), + daemon=True, + ) self.threads["progress"] = progress_thread logger.debug("starting progress worker") @@ -152,7 +178,12 @@ def finished_worker(finished: Queue): except Exception as err: logger.error("error in finished worker: %s", err) - finished_thread = Thread(target=finished_worker, args=(self.finished,), daemon=True) + finished_thread = Thread( + name="onnx-web finished", + target=finished_worker, + args=(self.finished,), + daemon=True, + ) self.threads["finished"] = finished_thread logger.debug("started finished worker") @@ -221,8 +252,18 @@ def done(self, key: str) -> Tuple[Optional[bool], int]: return (False, progress) def join(self): - logger.debug("stopping worker pool") + logger.info("stopping worker pool") + + 
logger.debug("closing queues") + self.logs.close() + self.finished.close() + self.progress.close() + for queue in self.pending.values(): + queue.close() + self.pending.clear() + + logger.debug("stopping device workers") for device, worker in self.workers.items(): if worker.is_alive(): logger.debug("stopping worker for device %s", device) @@ -235,15 +276,6 @@ def join(self): logger.debug("stopping worker thread: %s", name) thread.join(self.join_timeout) - logger.debug("closing queues") - self.logs.close() - self.finished.close() - self.progress.close() - for queue in self.pending.values(): - queue.close() - - self.pending.clear() - logger.debug("worker pool fully joined") def recycle(self): @@ -292,7 +324,7 @@ def submit( def status(self) -> List[Tuple[str, int, bool, bool]]: history = [ (name, progress, False, name in self.cancelled_jobs) - for name, _device, progress in self.active_jobs.items() + for name, (_device, progress) in self.active_jobs.items() ] history.extend( [ diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 797b8de3f..94c570205 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -4,7 +4,6 @@ from traceback import format_exception from setproctitle import setproctitle -from torch.multiprocessing import Queue from ..server import ServerContext, apply_patches from ..torch_before_ort import get_available_providers diff --git a/api/pyproject.toml b/api/pyproject.toml index 04c7bca40..e337c1f21 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -8,16 +8,35 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion [tool.mypy] # ignore_missing_imports = true +exclude = [ + "onnx_web.diffusion.lpw_stable_diffusion_onnx", + "onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale" +] [[tool.mypy.overrides]] module = [ "basicsr.archs.rrdbnet_arch", + "basicsr.utils.download_util", + "basicsr.utils", + "basicsr", "boto3", "codeformer", + 
"codeformer.facelib.utils.misc", + "codeformer.facelib.utils", + "codeformer.facelib", "diffusers", + "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion", + "diffusers.pipelines.paint_by_example", + "diffusers.pipelines.stable_diffusion", "diffusers.pipeline_utils", + "diffusers.utils.logging", + "facexlib.utils", + "facexlib", "gfpgan", "onnxruntime", - "realesrgan" + "realesrgan", + "realesrgan.archs.srvgg_arch", + "safetensors", + "transformers" ] ignore_missing_imports = true \ No newline at end of file diff --git a/api/scripts/test-memory-leak.sh b/api/scripts/test-memory-leak.sh index d7545e606..41953b4c5 100755 --- a/api/scripts/test-memory-leak.sh +++ b/api/scripts/test-memory-leak.sh @@ -4,7 +4,7 @@ test_images=0 while true; do curl "http://${test_host}:5000/api/txt2img?"\ -'cfg=16.00&steps=3&scheduler=deis-multi&seed=-1&'\ +'cfg=16.00&steps=3&scheduler=ddim&seed=-1&'\ 'prompt=an+astronaut+eating+a+hamburger&negativePrompt=&'\ 'model=stable-diffusion-onnx-v1-5&platform=any&'\ 'upscaling=upscaling-real-esrgan-x2-plus&correction=correction-codeformer&'\ From 12fb7f52bb17208f82cf6b3274acb4b90bfdc717 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 28 Feb 2023 21:56:12 -0600 Subject: [PATCH 39/40] fix(api): sanitize filenames in user input --- api/onnx_web/convert/diffusion/original.py | 3 ++- api/onnx_web/convert/utils.py | 7 ------- api/onnx_web/server/api.py | 3 +++ api/onnx_web/utils.py | 6 ++++++ 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py index c331ee741..dbcaa330e 100644 --- a/api/onnx_web/convert/diffusion/original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -53,7 +53,8 @@ CLIPVisionConfig, ) -from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name +from ...utils import sanitize_name +from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml from .diffusers import 
convert_diffusion_diffusers logger = getLogger(__name__) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 9b3edf662..295fdd878 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -189,13 +189,6 @@ def load_yaml(file: str) -> Config: return Config(data) -safe_chars = "._-" - - -def sanitize_name(name): - return "".join(x for x in name if (x.isalnum() or x in safe_chars)) - - def remove_prefix(name, prefix): if name.startswith(prefix): return name[len(prefix) :] diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 7f69c1c1f..ed70c76fa 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -28,6 +28,7 @@ get_from_map, get_not_empty, get_size, + sanitize_name, ) from ..worker.pool import DevicePoolExecutor from .config import ( @@ -428,6 +429,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor): if output_file is None: return error_reply("output name is required") + output_file = sanitize_name(output_file) cancel = pool.cancel(output_file) return ready_reply(cancel) @@ -438,6 +440,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor): if output_file is None: return error_reply("output name is required") + output_file = sanitize_name(output_file) done, progress = pool.done(output_file) if done is None: diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index f794eca0d..74f998f8a 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -10,6 +10,8 @@ logger = getLogger(__name__) +SAFE_CHARS = "._-" + def base_join(base: str, tail: str) -> str: tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") @@ -100,3 +102,7 @@ def run_gc(devices: Optional[List[DeviceParams]] = None): (mem_total - mem_free), mem_total, ) + + +def sanitize_name(name): + return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS)) From 1f9efb433aa37a8929666036426ba359ff262386 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 28 Feb 2023 
22:04:55 -0600 Subject: [PATCH 40/40] apply lint --- api/onnx_web/chain/blend_img2img.py | 2 +- api/onnx_web/chain/correct_codeformer.py | 2 +- api/onnx_web/chain/persist_s3.py | 2 +- api/onnx_web/chain/reduce_crop.py | 2 +- api/onnx_web/chain/source_txt2img.py | 2 +- api/onnx_web/chain/upscale_resrgan.py | 2 +- api/onnx_web/chain/upscale_stable_diffusion.py | 2 +- api/onnx_web/output.py | 2 +- api/onnx_web/upscale.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 87b5ce94d..e25dc6d93 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -1,4 +1,5 @@ from logging import getLogger +from typing import Optional import numpy as np import torch @@ -9,7 +10,6 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from typing import Optional logger = getLogger(__name__) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index d9a070376..09ecce9e1 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -1,11 +1,11 @@ from logging import getLogger +from typing import Optional from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext -from typing import Optional logger = getLogger(__name__) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index e4eb96b09..4c620c4c5 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -1,5 +1,6 @@ from io import BytesIO from logging import getLogger +from typing import Optional from boto3 import Session from PIL import Image @@ -7,7 +8,6 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext -from typing 
import Optional logger = getLogger(__name__) diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 0668429bc..3f2d82db7 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -1,11 +1,11 @@ from logging import getLogger +from typing import Optional from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext -from typing import Optional logger = getLogger(__name__) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 0e22d552a..35ea37309 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -1,4 +1,5 @@ from logging import getLogger +from typing import Optional import numpy as np import torch @@ -9,7 +10,6 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from typing import Optional logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index bde23ec37..1f8c8f07e 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -1,5 +1,6 @@ from logging import getLogger from os import path +from typing import Optional import numpy as np from PIL import Image @@ -9,7 +10,6 @@ from ..server import ServerContext from ..utils import run_gc from ..worker import WorkerContext -from typing import Optional logger = getLogger(__name__) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 298b547aa..9049758e3 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -1,5 +1,6 @@ from logging import getLogger from os import path +from typing import Optional import torch from diffusers import StableDiffusionUpscalePipeline @@ -13,7 +14,6 
@@ from ..server import ServerContext from ..utils import run_gc from ..worker import ProgressCallback, WorkerContext -from typing import Optional logger = getLogger(__name__) diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index ae35d06ee..399154b11 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -4,7 +4,7 @@ from os import path from struct import pack from time import time -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional from PIL import Image diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 4d103a2a5..c9b7308ae 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -1,4 +1,5 @@ from logging import getLogger +from typing import Optional from PIL import Image @@ -12,7 +13,6 @@ from .params import ImageParams, SizeChart, StageParams, UpscaleParams from .server import ServerContext from .worker import ProgressCallback, WorkerContext -from typing import Optional logger = getLogger(__name__)