diff --git a/api/.gitignore b/api/.gitignore
index a315070ee..2ba1650c7 100644
--- a/api/.gitignore
+++ b/api/.gitignore
@@ -1,6 +1,7 @@
 .coverage
 coverage.xml
+*.log
 *.swp
 *.pyc
diff --git a/api/Makefile b/api/Makefile
index 236aacabd..c5d598502 100644
--- a/api/Makefile
+++ b/api/Makefile
@@ -39,4 +39,4 @@ lint-fix:
 	flake8 onnx_web

 typecheck:
-	mypy -m onnx_web.serve
+	mypy onnx_web
diff --git a/api/launch-extras.bat b/api/launch-extras.bat
index fa3c89083..126df2c69 100644
--- a/api/launch-extras.bat
+++ b/api/launch-extras.bat
@@ -3,4 +3,4 @@
 IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
 python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN%
 echo "Launching API server..."
-flask --app=onnx_web.serve run --host=0.0.0.0
+flask --app="onnx_web.main:run" run --host=0.0.0.0
diff --git a/api/launch-extras.sh b/api/launch-extras.sh
index 96db9bb2c..50572aa43 100755
--- a/api/launch-extras.sh
+++ b/api/launch-extras.sh
@@ -25,4 +25,4 @@ python3 -m onnx_web.convert \
   --token=${HF_TOKEN:-}

 echo "Launching API server..."
-flask --app=onnx_web.serve run --host=0.0.0.0
+flask --app='onnx_web.main:run' run --host=0.0.0.0
diff --git a/api/launch.bat b/api/launch.bat
index f589bd119..09ddfccd2 100644
--- a/api/launch.bat
+++ b/api/launch.bat
@@ -2,4 +2,4 @@
 echo "Downloading and converting models to ONNX format..."
 python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN%
 echo "Launching API server..."
-flask --app=onnx_web.serve run --host=0.0.0.0
+flask --app="onnx_web.main:run" run --host=0.0.0.0
diff --git a/api/launch.sh b/api/launch.sh
index 50863ba88..55b6ff729 100755
--- a/api/launch.sh
+++ b/api/launch.sh
@@ -24,4 +24,4 @@ python3 -m onnx_web.convert \
   --token=${HF_TOKEN:-}

 echo "Launching API server..."
-flask --app=onnx_web.serve run --host=0.0.0.0
+flask --app='onnx_web.main:run' run --host=0.0.0.0
diff --git a/api/logging.yaml b/api/logging.yaml
index 0bf543100..861d480a1 100644
--- a/api/logging.yaml
+++ b/api/logging.yaml
@@ -1,18 +1,18 @@
 version: 1
 formatters:
   simple:
-    format: '[%(asctime)s] %(levelname)s: %(name)s: %(message)s'
+    format: '[%(asctime)s] %(levelname)s: %(processName)s %(threadName)s %(name)s: %(message)s'
 handlers:
   console:
     class: logging.StreamHandler
-    level: INFO
+    level: DEBUG
     formatter: simple
     stream: ext://sys.stdout
 loggers:
   '':
-    level: INFO
+    level: DEBUG
     handlers: [console]
     propagate: True
 root:
-  level: INFO
+  level: DEBUG
   handlers: [console]
diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py
index 3afedb07f..5294121c7 100644
--- a/api/onnx_web/__init__.py
+++ b/api/onnx_web/__init__.py
@@ -1,5 +1,10 @@
 from . import logging
-from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion
+from .chain import (
+    correct_codeformer,
+    correct_gfpgan,
+    upscale_resrgan,
+    upscale_stable_diffusion,
+)
 from .diffusion.load import get_latents_from_seed, load_pipeline, optimize_pipeline
 from .diffusion.run import (
     run_blend_pipeline,
@@ -25,6 +30,7 @@
 from .onnx import OnnxNet, OnnxTensor
 from .params import (
     Border,
+    DeviceParams,
     ImageParams,
     Param,
     Point,
@@ -33,8 +39,6 @@
     UpscaleParams,
 )
 from .server import (
-    DeviceParams,
-    DevicePoolExecutor,
     ModelCache,
     ServerContext,
     apply_patch_basicsr,
@@ -51,3 +55,6 @@
     get_from_map,
     get_not_empty,
 )
+from .worker import (
+    DevicePoolExecutor,
+)
diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py
index 5aa56c567..a983c8498 100644
--- a/api/onnx_web/chain/__init__.py
+++ b/api/onnx_web/chain/__init__.py
@@ -13,3 +13,20 @@
 from .upscale_outpaint import upscale_outpaint
 from .upscale_resrgan import upscale_resrgan
 from .upscale_stable_diffusion import upscale_stable_diffusion
+
+CHAIN_STAGES = {
+    "blend-img2img": blend_img2img,
+    "blend-inpaint": blend_inpaint,
+    "blend-mask": blend_mask,
+    "correct-codeformer": correct_codeformer,
+    "correct-gfpgan": correct_gfpgan,
+    "persist-disk": persist_disk,
+    "persist-s3": persist_s3,
+    "reduce-crop": reduce_crop,
+    "reduce-thumbnail": reduce_thumbnail,
+    "source-noise": source_noise,
+    "source-txt2img": source_txt2img,
+    "upscale-outpaint": upscale_outpaint,
+    "upscale-resrgan": upscale_resrgan,
+    "upscale-stable-diffusion": upscale_stable_diffusion,
+}
diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py
index 7e77bab7c..45644496a 100644
--- a/api/onnx_web/chain/base.py
+++ b/api/onnx_web/chain/base.py
@@ -7,8 +7,9 @@
 from ..output import save_image
 from ..params import ImageParams, StageParams
-from ..server import JobContext, ProgressCallback, ServerContext
+from ..server import ServerContext
 from ..utils import is_debug
+from ..worker import ProgressCallback, WorkerContext
 from .utils import process_tile_order

 logger = getLogger(__name__)

@@ -17,7 +18,7 @@
 class StageCallback(Protocol):
     def __call__(
         self,
-        job: JobContext,
+        job: WorkerContext,
         ctx: ServerContext,
         stage: StageParams,
         params: ImageParams,
@@ -61,7 +62,7 @@
 class ChainPipeline:
     def __init__(
         self,
-        stages: List[PipelineStage] = None,
+        stages: Optional[List[PipelineStage]] = None,
     ):
         """
         Create a new pipeline that will run the given stages.
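For orientation, here is a rough usage sketch of the refactored chain API shown above: a stage callable written against the new WorkerContext-based StageCallback protocol, plus a pipeline assembled by resolving stage types through CHAIN_STAGES. This is an illustrative aside, not part of the diff; the helper names (echo_stage, build_pipeline, stage_data) are hypothetical, and it assumes the remaining StageParams fields have usable defaults, as in the server's chain() handler.

# hedged sketch: new stage-callable signature and CHAIN_STAGES lookup
from typing import Any, Dict, List

from PIL import Image

from onnx_web.chain import CHAIN_STAGES, ChainPipeline
from onnx_web.params import ImageParams, StageParams
from onnx_web.server import ServerContext
from onnx_web.worker import WorkerContext


def echo_stage(
    job: WorkerContext,      # per-job worker context (formerly JobContext)
    server: ServerContext,
    stage: StageParams,
    params: ImageParams,
    source: Image.Image,
    **kwargs,
) -> Image.Image:
    # a no-op stage: returns the source image unchanged
    return source


def build_pipeline(stages: List[Dict[str, Any]]) -> ChainPipeline:
    # hypothetical helper: resolve each stage "type" through CHAIN_STAGES,
    # mirroring what the /api/chain handler does with request data
    pipeline = ChainPipeline()
    for stage_data in stages:
        callback = CHAIN_STAGES[stage_data["type"]]
        stage = StageParams(stage_data.get("name", callback.__name__))
        pipeline.append((callback, stage, stage_data.get("params", {})))
    return pipeline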
@@ -77,11 +78,11 @@ def append(self, stage: PipelineStage): def __call__( self, - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, source: Image.Image, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, **pipeline_kwargs ) -> Image.Image: """ diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 0ef9ef960..e25dc6d93 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -1,4 +1,5 @@ from logging import getLogger +from typing import Optional import numpy as np import torch @@ -7,19 +8,20 @@ from ..diffusion.load import load_pipeline from ..params import ImageParams, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) def blend_img2img( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, source: Image.Image, *, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, stage_source: Image.Image, **kwargs, ) -> Image.Image: diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 16422cce3..3fbbbe3c7 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -10,15 +10,16 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext from .utils import process_tile_order logger = getLogger(__name__) def blend_inpaint( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, @@ -30,7 +31,7 @@ def blend_inpaint( fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index be5d11597..8fe22220f 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -3,25 +3,25 @@ from PIL import Image -from onnx_web.image import valid_image -from onnx_web.output import save_image - +from ..image import valid_image +from ..output import save_image from ..params import ImageParams, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) def blend_mask( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, *, sources: Optional[List[Image.Image]] = None, stage_mask: Optional[Image.Image] = None, - _callback: ProgressCallback = None, + _callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: logger.info("blending image using mask") diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 6b4e235d7..09ecce9e1 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -1,9 +1,11 @@ from logging 
import getLogger +from typing import Optional from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..server import ServerContext +from ..worker import WorkerContext logger = getLogger(__name__) @@ -11,13 +13,13 @@ def correct_codeformer( - job: JobContext, + job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, source: Image.Image, *, - stage_source: Image.Image = None, + stage_source: Optional[Image.Image] = None, upscale: UpscaleParams, **kwargs, ) -> Image.Image: diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 6796c1bab..2cff2e181 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -5,8 +5,9 @@ from PIL import Image from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..server import ServerContext from ..utils import run_gc +from ..worker import WorkerContext logger = getLogger(__name__) @@ -46,7 +47,7 @@ def load_gfpgan( def correct_gfpgan( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 9a5f0cd0c..eac0f36cb 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -4,13 +4,14 @@ from ..output import save_image from ..params import ImageParams, StageParams -from ..server import JobContext, ServerContext +from ..server import ServerContext +from ..worker import WorkerContext logger = getLogger(__name__) def persist_disk( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index bf3682bf3..4c620c4c5 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -1,17 +1,19 @@ from io import BytesIO from logging import getLogger +from typing import Optional from boto3 import Session from PIL import Image from ..params import ImageParams, StageParams -from ..server import JobContext, ServerContext +from ..server import ServerContext +from ..worker import WorkerContext logger = getLogger(__name__) def persist_s3( - _job: JobContext, + _job: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, @@ -19,9 +21,9 @@ def persist_s3( *, output: str, bucket: str, - endpoint_url: str = None, - profile_name: str = None, - stage_source: Image.Image = None, + endpoint_url: Optional[str] = None, + profile_name: Optional[str] = None, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> Image.Image: source = stage_source or source diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 226f6cf26..3f2d82db7 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -1,15 +1,17 @@ from logging import getLogger +from typing import Optional from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..server import ServerContext +from ..worker import WorkerContext logger = getLogger(__name__) def reduce_crop( - _job: JobContext, + _job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, @@ -17,7 +19,7 @@ def reduce_crop( *, origin: Size, size: Size, - stage_source: Image.Image = None, + 
stage_source: Optional[Image.Image] = None, **kwargs, ) -> Image.Image: source = stage_source or source diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 0037084c0..6df2ed6e1 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -3,13 +3,14 @@ from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..server import ServerContext +from ..worker import WorkerContext logger = getLogger(__name__) def reduce_thumbnail( - _job: JobContext, + _job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index f6267b26b..0092292c7 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -4,13 +4,14 @@ from PIL import Image from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ServerContext +from ..server import ServerContext +from ..worker import WorkerContext logger = getLogger(__name__) def source_noise( - _job: JobContext, + _job: WorkerContext, _server: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index a5cbb07fe..35ea37309 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -1,4 +1,5 @@ from logging import getLogger +from typing import Optional import numpy as np import torch @@ -7,20 +8,21 @@ from ..diffusion.load import get_latents_from_seed, load_pipeline from ..params import ImageParams, Size, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) def source_txt2img( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, _source: Image.Image, *, size: Size, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 7919e3251..a40de6715 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -10,15 +10,16 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..server import ServerContext from ..utils import is_debug +from ..worker import ProgressCallback, WorkerContext from .utils import process_tile_grid, process_tile_order logger = getLogger(__name__) def upscale_outpaint( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, @@ -30,7 +31,7 @@ def upscale_outpaint( fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: source = stage_source or source diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 178360c36..1f8c8f07e 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -1,19 +1,18 
@@ from logging import getLogger from os import path +from typing import Optional import numpy as np from PIL import Image from ..onnx import OnnxNet from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..server import ServerContext from ..utils import run_gc +from ..worker import WorkerContext logger = getLogger(__name__) -last_pipeline_instance = None -last_pipeline_params = (None, None) - TAG_X4_V3 = "real-esrgan-x4-v3" @@ -96,14 +95,14 @@ def load_resrgan( def upscale_resrgan( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, _params: ImageParams, source: Image.Image, *, upscale: UpscaleParams, - stage_source: Image.Image = None, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> Image.Image: source = stage_source or source diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 5747f13fe..9049758e3 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -1,5 +1,6 @@ from logging import getLogger from os import path +from typing import Optional import torch from diffusers import StableDiffusionUpscalePipeline @@ -10,8 +11,9 @@ OnnxStableDiffusionUpscalePipeline, ) from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams -from ..server import JobContext, ProgressCallback, ServerContext +from ..server import ServerContext from ..utils import run_gc +from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) @@ -62,15 +64,15 @@ def load_stable_diffusion( def upscale_stable_diffusion( - job: JobContext, + job: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, source: Image.Image, *, upscale: UpscaleParams, - stage_source: Image.Image = None, - callback: ProgressCallback = None, + stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index 6598f10b3..81317226a 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -110,3 +110,6 @@ def process_tile_order( elif order == TileOrder.spiral: logger.debug("using spiral tile order with tile size: %s", tile) return process_tile_spiral(source, tile, scale, filters, **kwargs) + else: + logger.warn("unknown tile order: %s", order) + raise ValueError() diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 2a670511d..a9c760b2c 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -3,7 +3,7 @@ from logging import getLogger from os import makedirs, path from sys import exit -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse from jsonschema import ValidationError, validate @@ -36,7 +36,7 @@ ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*", ) -Models = Dict[str, List[Tuple[str, str, Optional[int]]]] +Models = Dict[str, List[Any]] logger = getLogger(__name__) diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py index c331ee741..dbcaa330e 100644 --- a/api/onnx_web/convert/diffusion/original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -53,7 +53,8 @@ CLIPVisionConfig, ) -from ..utils import ConversionContext, ModelDict, 
load_tensor, load_yaml, sanitize_name +from ...utils import sanitize_name +from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml from .diffusers import convert_diffusion_diffusers logger = getLogger(__name__) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 9d0f96aba..295fdd878 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -23,7 +23,7 @@ class ConversionContext(ServerContext): def __init__( self, - model_path: Optional[str] = None, + model_path: str, cache_path: Optional[str] = None, device: Optional[str] = None, half: Optional[bool] = False, @@ -31,7 +31,7 @@ def __init__( token: Optional[str] = None, **kwargs, ) -> None: - super().__init__(self, model_path=model_path, cache_path=cache_path) + super().__init__(model_path=model_path, cache_path=cache_path) self.half = half self.opset = opset @@ -153,7 +153,7 @@ def source_format(model: Dict) -> Optional[str]: return model["format"] if "source" in model: - ext = path.splitext(model["source"]) + _name, ext = path.splitext(model["source"]) if ext in model_formats: return ext @@ -183,19 +183,12 @@ def config_from_key(cls, target, k, v): setattr(target, k, v) -def load_yaml(file: str) -> str: +def load_yaml(file: str) -> Config: with open(file, "r") as f: data = safe_load(f.read()) return Config(data) -safe_chars = "._-" - - -def sanitize_name(name): - return "".join(x for x in name if (x.isalnum() or x in safe_chars)) - - def remove_prefix(name, prefix): if name.startswith(prefix): return name[len(prefix) :] diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 47c35fcdd..e3b545105 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -25,7 +25,7 @@ try: from diffusers import DEISMultistepScheduler except ImportError: - from .stub_scheduler import StubScheduler as DEISMultistepScheduler + from ..diffusion.stub_scheduler import StubScheduler as DEISMultistepScheduler from ..params import DeviceParams, Size from ..server import ServerContext @@ -54,6 +54,10 @@ } +def get_pipeline_schedulers(): + return pipeline_schedulers + + def get_scheduler_name(scheduler: Any) -> Optional[str]: for k, v in pipeline_schedulers.items(): if scheduler == v or scheduler == v.__name__: @@ -137,13 +141,14 @@ def load_pipeline( server: ServerContext, pipeline: DiffusionPipeline, model: str, - scheduler_type: Any, + scheduler_name: str, device: DeviceParams, lpw: bool, inversion: Optional[str], ): pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion) - scheduler_key = (scheduler_type, model) + scheduler_key = (scheduler_name, model) + scheduler_type = get_pipeline_schedulers()[scheduler_name] cache_pipe = server.cache.get("diffusion", pipe_key) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index ec988d660..765d7de8d 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -6,22 +6,21 @@ from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from PIL import Image -from onnx_web.chain import blend_mask -from onnx_web.chain.base import ChainProgress - -from ..chain import upscale_outpaint +from ..chain import blend_mask, upscale_outpaint +from ..chain.base import ChainProgress from ..output import save_image, save_params from ..params import Border, ImageParams, Size, StageParams, UpscaleParams -from ..server import JobContext, ServerContext +from ..server import ServerContext from ..upscale import run_upscale_correction 
from ..utils import run_gc +from ..worker import WorkerContext from .load import get_latents_from_seed, load_pipeline logger = getLogger(__name__) def run_txt2img_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, @@ -95,7 +94,7 @@ def run_txt2img_pipeline( def run_img2img_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, outputs: List[str], @@ -167,7 +166,7 @@ def run_img2img_pipeline( def run_inpaint_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, @@ -217,7 +216,7 @@ def run_inpaint_pipeline( def run_upscale_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, @@ -243,7 +242,7 @@ def run_upscale_pipeline( def run_blend_pipeline( - job: JobContext, + job: WorkerContext, server: ServerContext, params: ImageParams, size: Size, diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py new file mode 100644 index 000000000..7b8ea3255 --- /dev/null +++ b/api/onnx_web/main.py @@ -0,0 +1,77 @@ +import atexit +import gc +from logging import getLogger + +from diffusers.utils.logging import disable_progress_bar +from flask import Flask +from flask_cors import CORS +from huggingface_hub.utils.tqdm import disable_progress_bars +from setproctitle import setproctitle +from torch.multiprocessing import set_start_method + +from .server.api import register_api_routes +from .server.config import ( + get_available_platforms, + load_models, + load_params, + load_platforms, +) +from .server.context import ServerContext +from .server.hacks import apply_patches +from .server.static import register_static_routes +from .server.utils import check_paths +from .utils import is_debug +from .worker import DevicePoolExecutor + +logger = getLogger(__name__) + + +def main(): + setproctitle("onnx-web server") + set_start_method("spawn", force=True) + + context = ServerContext.from_environ() + apply_patches(context) + check_paths(context) + load_models(context) + load_params(context) + load_platforms(context) + + if is_debug(): + gc.set_debug(gc.DEBUG_STATS) + + if not context.show_progress: + disable_progress_bar() + disable_progress_bars() + + app = Flask(__name__) + CORS(app, origins=context.cors_origin) + + # any is a fake device, should not be in the pool + pool = DevicePoolExecutor( + context, [p for p in get_available_platforms() if p.device != "any"] + ) + + # register routes + register_static_routes(app, context, pool) + register_api_routes(app, context, pool) + + return app, pool + + +def run(): + app, pool = main() + + def quit(): + logger.info("shutting down workers") + pool.join() + + atexit.register(quit) + return app + + +if __name__ == "__main__": + app, pool = main() + app.run("0.0.0.0", 5000, debug=is_debug()) + logger.info("shutting down app") + pool.join() diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 97d5c8b06..a974aff48 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -3,9 +3,9 @@ import numpy as np import torch -from onnxruntime import InferenceSession, SessionOptions from ..server import ServerContext +from ..torch_before_ort import InferenceSession, SessionOptions class OnnxTensor: diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index f60b3941f..399154b11 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -4,11 +4,10 @@ from os import path from struct import pack from time import time -from 
typing import Any, List, Optional, Tuple +from typing import Any, List, Optional from PIL import Image -from .diffusion.load import get_scheduler_name from .params import Border, ImageParams, Param, Size, UpscaleParams from .server import ServerContext from .utils import base_join @@ -16,7 +15,7 @@ logger = getLogger(__name__) -def hash_value(sha, param: Param): +def hash_value(sha, param: Optional[Param]): if param is None: return elif isinstance(param, bool): @@ -44,7 +43,7 @@ def json_params( } json["params"]["model"] = path.basename(params.model) - json["params"]["scheduler"] = get_scheduler_name(params.scheduler) + json["params"]["scheduler"] = params.scheduler if border is not None: json["border"] = border.tojson() @@ -64,14 +63,14 @@ def make_output_name( mode: str, params: ImageParams, size: Size, - extras: Optional[Tuple[Param]] = None, + extras: Optional[List[Optional[Param]]] = None, ) -> List[str]: now = int(time()) sha = sha256() hash_value(sha, mode) hash_value(sha, params.model) - hash_value(sha, params.scheduler.__name__) + hash_value(sha, params.scheduler) hash_value(sha, params.prompt) hash_value(sha, params.negative_prompt) hash_value(sha, params.cfg) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index db23414aa..9912d6649 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -2,7 +2,7 @@ from logging import getLogger from typing import Any, Dict, List, Literal, Optional, Tuple, Union -from onnxruntime import GraphOptimizationLevel, SessionOptions +from .torch_before_ort import GraphOptimizationLevel, SessionOptions logger = getLogger(__name__) @@ -101,12 +101,12 @@ def __init__( self.device = device self.provider = provider self.options = options - self.optimizations = optimizations + self.optimizations = optimizations or [] def __str__(self) -> str: return "%s - %s (%s)" % (self.device, self.provider, self.options) - def ort_provider(self) -> Tuple[str, Any]: + def ort_provider(self) -> Union[str, Tuple[str, Any]]: if self.options is None: return self.provider else: @@ -148,7 +148,7 @@ class ImageParams: def __init__( self, model: str, - scheduler: Any, + scheduler: str, prompt: str, cfg: float, steps: int, @@ -174,7 +174,7 @@ def __init__( def tojson(self) -> Dict[str, Optional[Param]]: return { "model": self.model, - "scheduler": self.scheduler.__name__, + "scheduler": self.scheduler, "prompt": self.prompt, "negative_prompt": self.negative_prompt, "cfg": self.cfg, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py deleted file mode 100644 index fb95d8ef3..000000000 --- a/api/onnx_web/serve.py +++ /dev/null @@ -1,880 +0,0 @@ -import gc -from functools import cmp_to_key -from glob import glob -from io import BytesIO -from logging import getLogger -from os import makedirs, path -from typing import Dict, List, Tuple, Union - -import numpy as np -import torch -import yaml -from diffusers.utils.logging import disable_progress_bar -from flask import Flask, jsonify, make_response, request, send_from_directory, url_for -from flask_cors import CORS -from huggingface_hub.utils.tqdm import disable_progress_bars -from jsonschema import validate -from onnxruntime import get_available_providers -from PIL import Image - -from .chain import ( - ChainPipeline, - blend_img2img, - blend_inpaint, - correct_codeformer, - correct_gfpgan, - persist_disk, - persist_s3, - reduce_crop, - reduce_thumbnail, - source_noise, - source_txt2img, - upscale_outpaint, - upscale_resrgan, - upscale_stable_diffusion, -) -from .diffusion.load import pipeline_schedulers 
-from .diffusion.run import ( - run_blend_pipeline, - run_img2img_pipeline, - run_inpaint_pipeline, - run_txt2img_pipeline, - run_upscale_pipeline, -) -from .image import ( # mask filters; noise sources - mask_filter_gaussian_multiply, - mask_filter_gaussian_screen, - mask_filter_none, - noise_source_fill_edge, - noise_source_fill_mask, - noise_source_gaussian, - noise_source_histogram, - noise_source_normal, - noise_source_uniform, - valid_image, -) -from .output import json_params, make_output_name -from .params import ( - Border, - DeviceParams, - ImageParams, - Size, - StageParams, - TileOrder, - UpscaleParams, -) -from .server import DevicePoolExecutor, ServerContext, apply_patches -from .transformers import run_txt2txt_pipeline -from .utils import ( - base_join, - get_and_clamp_float, - get_and_clamp_int, - get_from_list, - get_from_map, - get_not_empty, - get_size, - is_debug, -) - -logger = getLogger(__name__) - -# config caching -config_params: Dict[str, Dict[str, Union[float, int, str]]] = {} - -# pipeline params -platform_providers = { - "cpu": "CPUExecutionProvider", - "cuda": "CUDAExecutionProvider", - "directml": "DmlExecutionProvider", - "rocm": "ROCMExecutionProvider", -} - -noise_sources = { - "fill-edge": noise_source_fill_edge, - "fill-mask": noise_source_fill_mask, - "gaussian": noise_source_gaussian, - "histogram": noise_source_histogram, - "normal": noise_source_normal, - "uniform": noise_source_uniform, -} -mask_filters = { - "none": mask_filter_none, - "gaussian-multiply": mask_filter_gaussian_multiply, - "gaussian-screen": mask_filter_gaussian_screen, -} -chain_stages = { - "blend-img2img": blend_img2img, - "blend-inpaint": blend_inpaint, - "correct-codeformer": correct_codeformer, - "correct-gfpgan": correct_gfpgan, - "persist-disk": persist_disk, - "persist-s3": persist_s3, - "reduce-crop": reduce_crop, - "reduce-thumbnail": reduce_thumbnail, - "source-noise": source_noise, - "source-txt2img": source_txt2img, - "upscale-outpaint": upscale_outpaint, - "upscale-resrgan": upscale_resrgan, - "upscale-stable-diffusion": upscale_stable_diffusion, -} - -# Available ORT providers -available_platforms: List[DeviceParams] = [] - -# loaded from model_path -correction_models: List[str] = [] -diffusion_models: List[str] = [] -inversion_models: List[str] = [] -upscaling_models: List[str] = [] - - -def get_config_value(key: str, subkey: str = "default", default=None): - return config_params.get(key, {}).get(subkey, default) - - -def url_from_rule(rule) -> str: - options = {} - for arg in rule.arguments: - options[arg] = ":%s" % (arg) - - return url_for(rule.endpoint, **options) - - -def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: - user = request.remote_addr - - # platform stuff - device = None - device_name = request.args.get("platform") - - if device_name is not None and device_name != "any": - for platform in available_platforms: - if platform.device == device_name: - device = platform - - # pipeline stuff - lpw = get_not_empty(request.args, "lpw", "false") == "true" - model = get_not_empty(request.args, "model", get_config_value("model")) - model_path = get_model_path(model) - scheduler = get_from_map( - request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler") - ) - - inversion = request.args.get("inversion", None) - inversion_path = None - if inversion is not None and inversion.strip() != "": - inversion_path = get_model_path(inversion) - - # image params - prompt = get_not_empty(request.args, "prompt", 
get_config_value("prompt")) - negative_prompt = request.args.get("negativePrompt", None) - - if negative_prompt is not None and negative_prompt.strip() == "": - negative_prompt = None - - batch = get_and_clamp_int( - request.args, - "batch", - get_config_value("batch"), - get_config_value("batch", "max"), - get_config_value("batch", "min"), - ) - cfg = get_and_clamp_float( - request.args, - "cfg", - get_config_value("cfg"), - get_config_value("cfg", "max"), - get_config_value("cfg", "min"), - ) - eta = get_and_clamp_float( - request.args, - "eta", - get_config_value("eta"), - get_config_value("eta", "max"), - get_config_value("eta", "min"), - ) - steps = get_and_clamp_int( - request.args, - "steps", - get_config_value("steps"), - get_config_value("steps", "max"), - get_config_value("steps", "min"), - ) - height = get_and_clamp_int( - request.args, - "height", - get_config_value("height"), - get_config_value("height", "max"), - get_config_value("height", "min"), - ) - width = get_and_clamp_int( - request.args, - "width", - get_config_value("width"), - get_config_value("width", "max"), - get_config_value("width", "min"), - ) - - seed = int(request.args.get("seed", -1)) - if seed == -1: - # this one can safely use np.random because it produces a single value - seed = np.random.randint(np.iinfo(np.int32).max) - - logger.info( - "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", - user, - steps, - scheduler.__name__, - model_path, - device or "any device", - width, - height, - cfg, - seed, - prompt, - ) - - params = ImageParams( - model_path, - scheduler, - prompt, - cfg, - steps, - seed, - eta=eta, - lpw=lpw, - negative_prompt=negative_prompt, - batch=batch, - inversion=inversion_path, - ) - size = Size(width, height) - return (device, params, size) - - -def border_from_request() -> Border: - left = get_and_clamp_int( - request.args, "left", 0, get_config_value("width", "max"), 0 - ) - right = get_and_clamp_int( - request.args, "right", 0, get_config_value("width", "max"), 0 - ) - top = get_and_clamp_int( - request.args, "top", 0, get_config_value("height", "max"), 0 - ) - bottom = get_and_clamp_int( - request.args, "bottom", 0, get_config_value("height", "max"), 0 - ) - - return Border(left, right, top, bottom) - - -def upscale_from_request() -> UpscaleParams: - denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0) - scale = get_and_clamp_int(request.args, "scale", 1, 4, 1) - outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1) - upscaling = get_from_list(request.args, "upscaling", upscaling_models) - correction = get_from_list(request.args, "correction", correction_models) - faces = get_not_empty(request.args, "faces", "false") == "true" - face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) - face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) - upscale_order = request.args.get("upscaleOrder", "correction-first") - - return UpscaleParams( - upscaling, - correction_model=correction, - denoise=denoise, - faces=faces, - face_outscale=face_outscale, - face_strength=face_strength, - format="onnx", - outscale=outscale, - scale=scale, - upscale_order=upscale_order, - ) - - -def check_paths(context: ServerContext) -> None: - if not path.exists(context.model_path): - raise RuntimeError("model path must exist") - - if not path.exists(context.output_path): - makedirs(context.output_path) - - -def get_model_name(model: str) -> str: - base = path.basename(model) - (file, _ext) = path.splitext(base) - return 
file - - -def load_models(context: ServerContext) -> None: - global correction_models - global diffusion_models - global inversion_models - global upscaling_models - - diffusion_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*")) - ] - diffusion_models.extend( - [ - get_model_name(f) - for f in glob(path.join(context.model_path, "stable-diffusion-*")) - ] - ) - diffusion_models = list(set(diffusion_models)) - diffusion_models.sort() - - correction_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) - ] - correction_models = list(set(correction_models)) - correction_models.sort() - - inversion_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) - ] - inversion_models = list(set(inversion_models)) - inversion_models.sort() - - upscaling_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) - ] - upscaling_models = list(set(upscaling_models)) - upscaling_models.sort() - - -def load_params(context: ServerContext) -> None: - global config_params - params_file = path.join(context.params_path, "params.json") - with open(params_file, "r") as f: - config_params = yaml.safe_load(f) - - if "platform" in config_params and context.default_platform is not None: - logger.info( - "Overriding default platform from environment: %s", - context.default_platform, - ) - config_platform = config_params.get("platform", {}) - config_platform["default"] = context.default_platform - - -def load_platforms(context: ServerContext) -> None: - global available_platforms - - providers = list(get_available_providers()) - - for potential in platform_providers: - if ( - platform_providers[potential] in providers - and potential not in context.block_platforms - ): - if potential == "cuda": - for i in range(torch.cuda.device_count()): - available_platforms.append( - DeviceParams( - potential, - platform_providers[potential], - { - "device_id": i, - }, - context.optimizations, - ) - ) - else: - available_platforms.append( - DeviceParams( - potential, - platform_providers[potential], - None, - context.optimizations, - ) - ) - - if context.any_platform: - # the platform should be ignored when the job is scheduled, but set to CPU just in case - available_platforms.append( - DeviceParams( - "any", - platform_providers["cpu"], - None, - context.optimizations, - ) - ) - - # make sure CPU is last on the list - def any_first_cpu_last(a: DeviceParams, b: DeviceParams): - if a.device == b.device: - return 0 - - # any should be first, if it's available - if a.device == "any": - return -1 - - # cpu should be last, if it's available - if a.device == "cpu": - return 1 - - return -1 - - available_platforms = sorted( - available_platforms, key=cmp_to_key(any_first_cpu_last) - ) - - logger.info( - "available acceleration platforms: %s", - ", ".join([str(p) for p in available_platforms]), - ) - - -context = ServerContext.from_environ() -apply_patches(context) -check_paths(context) -load_models(context) -load_params(context) -load_platforms(context) - -if not context.show_progress: - disable_progress_bar() - disable_progress_bars() - -app = Flask(__name__) -CORS(app, origins=context.cors_origin) - -# any is a fake device, should not be in the pool -executor = DevicePoolExecutor([p for p in available_platforms if p.device != "any"]) - -if is_debug(): - gc.set_debug(gc.DEBUG_STATS) - - -def ready_reply(ready: bool, progress: int = 0): - return jsonify( - { - "progress": progress, - "ready": 
ready, - } - ) - - -def error_reply(err: str): - response = make_response( - jsonify( - { - "error": err, - } - ) - ) - response.status_code = 400 - return response - - -def get_model_path(model: str): - return base_join(context.model_path, model) - - -def serve_bundle_file(filename="index.html"): - return send_from_directory(path.join("..", context.bundle_path), filename) - - -# routes - - -@app.route("/") -def index(): - return serve_bundle_file() - - -@app.route("/") -def index_path(filename): - return serve_bundle_file(filename) - - -@app.route("/api") -def introspect(): - return { - "name": "onnx-web", - "routes": [ - {"path": url_from_rule(rule), "methods": list(rule.methods).sort()} - for rule in app.url_map.iter_rules() - ], - } - - -@app.route("/api/settings/masks") -def list_mask_filters(): - return jsonify(list(mask_filters.keys())) - - -@app.route("/api/settings/models") -def list_models(): - return jsonify( - { - "correction": correction_models, - "diffusion": diffusion_models, - "inversion": inversion_models, - "upscaling": upscaling_models, - } - ) - - -@app.route("/api/settings/noises") -def list_noise_sources(): - return jsonify(list(noise_sources.keys())) - - -@app.route("/api/settings/params") -def list_params(): - return jsonify(config_params) - - -@app.route("/api/settings/platforms") -def list_platforms(): - return jsonify([p.device for p in available_platforms]) - - -@app.route("/api/settings/schedulers") -def list_schedulers(): - return jsonify(list(pipeline_schedulers.keys())) - - -@app.route("/api/img2img", methods=["POST"]) -def img2img(): - if "source" not in request.files: - return error_reply("source image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - strength = get_and_clamp_float( - request.args, - "strength", - get_config_value("strength"), - get_config_value("strength", "max"), - get_config_value("strength", "min"), - ) - - output = make_output_name(context, "img2img", params, size, extras=(strength,)) - job_name = output[0] - logger.info("img2img job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_img2img_pipeline, - context, - params, - output, - upscale, - source, - strength, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/txt2img", methods=["POST"]) -def txt2img(): - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "txt2img", params, size) - job_name = output[0] - logger.info("txt2img job queued for: %s", job_name) - - executor.submit( - job_name, - run_txt2img_pipeline, - context, - params, - size, - output, - upscale, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/inpaint", methods=["POST"]) -def inpaint(): - if "source" not in request.files: - return error_reply("source image is required") - - if "mask" not in request.files: - return error_reply("mask image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - mask_file = request.files.get("mask") - mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - expand = border_from_request() - upscale = 
upscale_from_request() - - fill_color = get_not_empty(request.args, "fillColor", "white") - mask_filter = get_from_map(request.args, "filter", mask_filters, "none") - noise_source = get_from_map(request.args, "noise", noise_sources, "histogram") - tile_order = get_from_list( - request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] - ) - - output = make_output_name( - context, - "inpaint", - params, - size, - extras=( - expand.left, - expand.right, - expand.top, - expand.bottom, - mask_filter.__name__, - noise_source.__name__, - fill_color, - tile_order, - ), - ) - job_name = output[0] - logger.info("inpaint job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - mask = valid_image(mask, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_inpaint_pipeline, - context, - params, - size, - output, - upscale, - source, - mask, - expand, - noise_source, - mask_filter, - fill_color, - tile_order, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) - - -@app.route("/api/upscale", methods=["POST"]) -def upscale(): - if "source" not in request.files: - return error_reply("source image is required") - - source_file = request.files.get("source") - source = Image.open(BytesIO(source_file.read())).convert("RGB") - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "upscale", params, size) - job_name = output[0] - logger.info("upscale job queued for: %s", job_name) - - source = valid_image(source, min_dims=size, max_dims=size) - executor.submit( - job_name, - run_upscale_pipeline, - context, - params, - size, - output, - upscale, - source, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/chain", methods=["POST"]) -def chain(): - logger.debug( - "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() - ) - body = request.form.get("chain") or request.files.get("chain") - if body is None: - return error_reply("chain pipeline must have a body") - - data = yaml.safe_load(body) - with open("./schemas/chain.yaml", "r") as f: - schema = yaml.safe_load(f.read()) - - logger.debug("validating chain request: %s against %s", data, schema) - validate(data, schema) - - # get defaults from the regular parameters - device, params, size = pipeline_from_request() - output = make_output_name(context, "chain", params, size) - job_name = output[0] - - pipeline = ChainPipeline() - for stage_data in data.get("stages", []): - callback = chain_stages[stage_data.get("type")] - kwargs = stage_data.get("params", {}) - logger.info("request stage: %s, %s", callback.__name__, kwargs) - - stage = StageParams( - stage_data.get("name", callback.__name__), - tile_size=get_size(kwargs.get("tile_size")), - outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), - ) - - if "border" in kwargs: - border = Border.even(int(kwargs.get("border"))) - kwargs["border"] = border - - if "upscale" in kwargs: - upscale = UpscaleParams(kwargs.get("upscale")) - kwargs["upscale"] = upscale - - stage_source_name = "source:%s" % (stage.name) - stage_mask_name = "mask:%s" % (stage.name) - - if stage_source_name in request.files: - logger.debug( - "loading source image %s for pipeline stage %s", - stage_source_name, - stage.name, - ) - source_file = request.files.get(stage_source_name) - source = Image.open(BytesIO(source_file.read())).convert("RGB") - source = 
valid_image(source, max_dims=(size.width, size.height)) - kwargs["stage_source"] = source - - if stage_mask_name in request.files: - logger.debug( - "loading mask image %s for pipeline stage %s", - stage_mask_name, - stage.name, - ) - mask_file = request.files.get(stage_mask_name) - mask = Image.open(BytesIO(mask_file.read())).convert("RGB") - mask = valid_image(mask, max_dims=(size.width, size.height)) - kwargs["stage_mask"] = mask - - pipeline.append((callback, stage, kwargs)) - - logger.info("running chain pipeline with %s stages", len(pipeline.stages)) - - # build and run chain pipeline - empty_source = Image.new("RGB", (size.width, size.height)) - executor.submit( - job_name, - pipeline, - context, - params, - empty_source, - output=output[0], - size=size, - needs_device=device, - ) - - return jsonify(json_params(output, params, size)) - - -@app.route("/api/blend", methods=["POST"]) -def blend(): - if "mask" not in request.files: - return error_reply("mask image is required") - - mask_file = request.files.get("mask") - mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") - mask = valid_image(mask) - - max_sources = 2 - sources = [] - - for i in range(max_sources): - source_file = request.files.get("source:%s" % (i)) - source = Image.open(BytesIO(source_file.read())).convert("RGBA") - source = valid_image(source, mask.size, mask.size) - sources.append(source) - - device, params, size = pipeline_from_request() - upscale = upscale_from_request() - - output = make_output_name(context, "upscale", params, size) - job_name = output[0] - logger.info("upscale job queued for: %s", job_name) - - executor.submit( - job_name, - run_blend_pipeline, - context, - params, - size, - output, - upscale, - sources, - mask, - needs_device=device, - ) - - return jsonify(json_params(output, params, size, upscale=upscale)) - - -@app.route("/api/txt2txt", methods=["POST"]) -def txt2txt(): - device, params, size = pipeline_from_request() - - output = make_output_name(context, "upscale", params, size) - logger.info("upscale job queued for: %s", output) - - executor.submit( - output, - run_txt2txt_pipeline, - context, - params, - size, - output, - needs_device=device, - ) - - return jsonify(json_params(output, params, size)) - - -@app.route("/api/cancel", methods=["PUT"]) -def cancel(): - output_file = request.args.get("output", None) - - cancel = executor.cancel(output_file) - - return ready_reply(cancel) - - -@app.route("/api/ready") -def ready(): - output_file = request.args.get("output", None) - - done, progress = executor.done(output_file) - - if done is None: - output = base_join(context.output_path, output_file) - if path.exists(output): - return ready_reply(True) - - return ready_reply(done, progress=progress) - - -@app.route("/api/status") -def status(): - return jsonify(executor.status()) - - -@app.route("/output/") -def output(filename: str): - return send_from_directory( - path.join("..", context.output_path), filename, as_attachment=False - ) diff --git a/api/onnx_web/server/__init__.py b/api/onnx_web/server/__init__.py index 0403746c5..f02fa35a6 100644 --- a/api/onnx_web/server/__init__.py +++ b/api/onnx_web/server/__init__.py @@ -1,10 +1,3 @@ -from .device_pool import ( - DeviceParams, - DevicePoolExecutor, - Job, - JobContext, - ProgressCallback, -) from .hacks import ( apply_patch_basicsr, apply_patch_codeformer, diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py new file mode 100644 index 000000000..ed70c76fa --- /dev/null +++ b/api/onnx_web/server/api.py @@ -0,0 +1,493 
@@ +from io import BytesIO +from logging import getLogger +from os import path + +import yaml +from flask import Flask, jsonify, make_response, request, url_for +from jsonschema import validate +from PIL import Image + +from ..chain import CHAIN_STAGES, ChainPipeline +from ..diffusion.load import get_pipeline_schedulers +from ..diffusion.run import ( + run_blend_pipeline, + run_img2img_pipeline, + run_inpaint_pipeline, + run_txt2img_pipeline, + run_upscale_pipeline, +) +from ..image import valid_image # mask filters; noise sources +from ..output import json_params, make_output_name +from ..params import Border, StageParams, TileOrder, UpscaleParams +from ..transformers import run_txt2txt_pipeline +from ..utils import ( + base_join, + get_and_clamp_float, + get_and_clamp_int, + get_from_list, + get_from_map, + get_not_empty, + get_size, + sanitize_name, +) +from ..worker.pool import DevicePoolExecutor +from .config import ( + get_available_platforms, + get_config_params, + get_config_value, + get_correction_models, + get_diffusion_models, + get_inversion_models, + get_mask_filters, + get_noise_sources, + get_upscaling_models, +) +from .context import ServerContext +from .params import border_from_request, pipeline_from_request, upscale_from_request +from .utils import wrap_route + +logger = getLogger(__name__) + + +def ready_reply(ready: bool, progress: int = 0): + return jsonify( + { + "progress": progress, + "ready": ready, + } + ) + + +def error_reply(err: str): + response = make_response( + jsonify( + { + "error": err, + } + ) + ) + response.status_code = 400 + return response + + +def url_from_rule(rule) -> str: + options = {} + for arg in rule.arguments: + options[arg] = ":%s" % (arg) + + return url_for(rule.endpoint, **options) + + +def introspect(context: ServerContext, app: Flask): + return { + "name": "onnx-web", + "routes": [ + {"path": url_from_rule(rule), "methods": list(rule.methods or []).sort()} + for rule in app.url_map.iter_rules() + ], + } + + +def list_mask_filters(context: ServerContext): + return jsonify(list(get_mask_filters().keys())) + + +def list_models(context: ServerContext): + return jsonify( + { + "correction": get_correction_models(), + "diffusion": get_diffusion_models(), + "inversion": get_inversion_models(), + "upscaling": get_upscaling_models(), + } + ) + + +def list_noise_sources(context: ServerContext): + return jsonify(list(get_noise_sources().keys())) + + +def list_params(context: ServerContext): + return jsonify(get_config_params()) + + +def list_platforms(context: ServerContext): + return jsonify([p.device for p in get_available_platforms()]) + + +def list_schedulers(context: ServerContext): + return jsonify(list(get_pipeline_schedulers().keys())) + + +def img2img(context: ServerContext, pool: DevicePoolExecutor): + source_file = request.files.get("source") + if source_file is None: + return error_reply("source image is required") + + source = Image.open(BytesIO(source_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + strength = get_and_clamp_float( + request.args, + "strength", + get_config_value("strength"), + get_config_value("strength", "max"), + get_config_value("strength", "min"), + ) + + output = make_output_name(context, "img2img", params, size, extras=[strength]) + job_name = output[0] + logger.info("img2img job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_img2img_pipeline, + context, + 
params, + output, + upscale, + source, + strength, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def txt2img(context: ServerContext, pool: DevicePoolExecutor): + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "txt2img", params, size) + job_name = output[0] + logger.info("txt2img job queued for: %s", job_name) + + pool.submit( + job_name, + run_txt2img_pipeline, + context, + params, + size, + output, + upscale, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def inpaint(context: ServerContext, pool: DevicePoolExecutor): + source_file = request.files.get("source") + if source_file is None: + return error_reply("source image is required") + + mask_file = request.files.get("mask") + if mask_file is None: + return error_reply("mask image is required") + + source = Image.open(BytesIO(source_file.read())).convert("RGB") + mask = Image.open(BytesIO(mask_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + expand = border_from_request() + upscale = upscale_from_request() + + fill_color = get_not_empty(request.args, "fillColor", "white") + mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") + noise_source = get_from_map(request.args, "noise", get_noise_sources(), "histogram") + tile_order = get_from_list( + request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] + ) + + output = make_output_name( + context, + "inpaint", + params, + size, + extras=[ + expand.left, + expand.right, + expand.top, + expand.bottom, + mask_filter.__name__, + noise_source.__name__, + fill_color, + tile_order, + ], + ) + job_name = output[0] + logger.info("inpaint job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + mask = valid_image(mask, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_inpaint_pipeline, + context, + params, + size, + output, + upscale, + source, + mask, + expand, + noise_source, + mask_filter, + fill_color, + tile_order, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) + + +def upscale(context: ServerContext, pool: DevicePoolExecutor): + source_file = request.files.get("source") + if source_file is None: + return error_reply("source image is required") + + source = Image.open(BytesIO(source_file.read())).convert("RGB") + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "upscale", params, size) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) + + source = valid_image(source, min_dims=size, max_dims=size) + pool.submit( + job_name, + run_upscale_pipeline, + context, + params, + size, + output, + upscale, + source, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def chain(context: ServerContext, pool: DevicePoolExecutor): + logger.debug( + "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() + ) + body = request.form.get("chain") or request.files.get("chain") + if body is None: + return error_reply("chain pipeline must have a body") + + data = yaml.safe_load(body) + with open("./schemas/chain.yaml", "r") as f: + schema = yaml.safe_load(f.read()) + + logger.debug("validating chain request: %s against %s", data, schema) + 
validate(data, schema) + + # get defaults from the regular parameters + device, params, size = pipeline_from_request(context) + output = make_output_name(context, "chain", params, size) + job_name = output[0] + + pipeline = ChainPipeline() + for stage_data in data.get("stages", []): + callback = CHAIN_STAGES[stage_data.get("type")] + kwargs = stage_data.get("params", {}) + logger.info("request stage: %s, %s", callback.__name__, kwargs) + + stage = StageParams( + stage_data.get("name", callback.__name__), + tile_size=get_size(kwargs.get("tile_size")), + outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), + ) + + if "border" in kwargs: + border = Border.even(int(kwargs.get("border"))) + kwargs["border"] = border + + if "upscale" in kwargs: + upscale = UpscaleParams(kwargs.get("upscale")) + kwargs["upscale"] = upscale + + stage_source_name = "source:%s" % (stage.name) + stage_mask_name = "mask:%s" % (stage.name) + + if stage_source_name in request.files: + logger.debug( + "loading source image %s for pipeline stage %s", + stage_source_name, + stage.name, + ) + source_file = request.files.get(stage_source_name) + if source_file is not None: + source = Image.open(BytesIO(source_file.read())).convert("RGB") + source = valid_image(source, max_dims=(size.width, size.height)) + kwargs["stage_source"] = source + + if stage_mask_name in request.files: + logger.debug( + "loading mask image %s for pipeline stage %s", + stage_mask_name, + stage.name, + ) + mask_file = request.files.get(stage_mask_name) + if mask_file is not None: + mask = Image.open(BytesIO(mask_file.read())).convert("RGB") + mask = valid_image(mask, max_dims=(size.width, size.height)) + kwargs["stage_mask"] = mask + + pipeline.append((callback, stage, kwargs)) + + logger.info("running chain pipeline with %s stages", len(pipeline.stages)) + + # build and run chain pipeline + empty_source = Image.new("RGB", (size.width, size.height)) + pool.submit( + job_name, + pipeline, + context, + params, + empty_source, + output=output[0], + size=size, + needs_device=device, + ) + + return jsonify(json_params(output, params, size)) + + +def blend(context: ServerContext, pool: DevicePoolExecutor): + mask_file = request.files.get("mask") + if mask_file is None: + return error_reply("mask image is required") + + mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") + mask = valid_image(mask) + + max_sources = 2 + sources = [] + + for i in range(max_sources): + source_file = request.files.get("source:%s" % (i)) + if source_file is None: + logger.warning("missing source %s", i) + else: + source = Image.open(BytesIO(source_file.read())).convert("RGBA") + source = valid_image(source, mask.size, mask.size) + sources.append(source) + + device, params, size = pipeline_from_request(context) + upscale = upscale_from_request() + + output = make_output_name(context, "upscale", params, size) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) + + pool.submit( + job_name, + run_blend_pipeline, + context, + params, + size, + output, + upscale, + sources, + mask, + needs_device=device, + ) + + return jsonify(json_params(output, params, size, upscale=upscale)) + + +def txt2txt(context: ServerContext, pool: DevicePoolExecutor): + device, params, size = pipeline_from_request(context) + + output = make_output_name(context, "txt2txt", params, size) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) + + pool.submit( + job_name, + run_txt2txt_pipeline, + context, + params, + size, + output, + needs_device=device, + 
) + + return jsonify(json_params(output, params, size)) + + +def cancel(context: ServerContext, pool: DevicePoolExecutor): + output_file = request.args.get("output", None) + if output_file is None: + return error_reply("output name is required") + + output_file = sanitize_name(output_file) + cancel = pool.cancel(output_file) + + return ready_reply(cancel) + + +def ready(context: ServerContext, pool: DevicePoolExecutor): + output_file = request.args.get("output", None) + if output_file is None: + return error_reply("output name is required") + + output_file = sanitize_name(output_file) + done, progress = pool.done(output_file) + + if done is None: + output = base_join(context.output_path, output_file) + if path.exists(output): + return ready_reply(True) + + return ready_reply(done or False, progress=progress) + + +def status(context: ServerContext, pool: DevicePoolExecutor): + return jsonify(pool.status()) + + +def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor): + return [ + app.route("/api")(wrap_route(introspect, context, app=app)), + app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)), + app.route("/api/settings/models")(wrap_route(list_models, context)), + app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)), + app.route("/api/settings/params")(wrap_route(list_params, context)), + app.route("/api/settings/platforms")(wrap_route(list_platforms, context)), + app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)), + app.route("/api/img2img", methods=["POST"])( + wrap_route(img2img, context, pool=pool) + ), + app.route("/api/txt2img", methods=["POST"])( + wrap_route(txt2img, context, pool=pool) + ), + app.route("/api/txt2txt", methods=["POST"])( + wrap_route(txt2txt, context, pool=pool) + ), + app.route("/api/inpaint", methods=["POST"])( + wrap_route(inpaint, context, pool=pool) + ), + app.route("/api/upscale", methods=["POST"])( + wrap_route(upscale, context, pool=pool) + ), + app.route("/api/chain", methods=["POST"])( + wrap_route(chain, context, pool=pool) + ), + app.route("/api/blend", methods=["POST"])( + wrap_route(blend, context, pool=pool) + ), + app.route("/api/cancel", methods=["PUT"])( + wrap_route(cancel, context, pool=pool) + ), + app.route("/api/ready")(wrap_route(ready, context, pool=pool)), + app.route("/api/status")(wrap_route(status, context, pool=pool)), + ] diff --git a/api/onnx_web/server/config.py b/api/onnx_web/server/config.py new file mode 100644 index 000000000..c03921046 --- /dev/null +++ b/api/onnx_web/server/config.py @@ -0,0 +1,229 @@ +from functools import cmp_to_key +from glob import glob +from logging import getLogger +from os import path +from typing import Dict, List, Union + +import torch +import yaml + +from ..image import ( # mask filters; noise sources + mask_filter_gaussian_multiply, + mask_filter_gaussian_screen, + mask_filter_none, + noise_source_fill_edge, + noise_source_fill_mask, + noise_source_gaussian, + noise_source_histogram, + noise_source_normal, + noise_source_uniform, +) +from ..params import DeviceParams +from ..torch_before_ort import get_available_providers +from .context import ServerContext + +logger = getLogger(__name__) + +# config caching +config_params: Dict[str, Dict[str, Union[float, int, str]]] = {} + +# pipeline params +platform_providers = { + "cpu": "CPUExecutionProvider", + "cuda": "CUDAExecutionProvider", + "directml": "DmlExecutionProvider", + "rocm": "ROCMExecutionProvider", +} +noise_sources = { + "fill-edge": 
noise_source_fill_edge, + "fill-mask": noise_source_fill_mask, + "gaussian": noise_source_gaussian, + "histogram": noise_source_histogram, + "normal": noise_source_normal, + "uniform": noise_source_uniform, +} +mask_filters = { + "none": mask_filter_none, + "gaussian-multiply": mask_filter_gaussian_multiply, + "gaussian-screen": mask_filter_gaussian_screen, +} + + +# Available ORT providers +available_platforms: List[DeviceParams] = [] + +# loaded from model_path +correction_models: List[str] = [] +diffusion_models: List[str] = [] +inversion_models: List[str] = [] +upscaling_models: List[str] = [] + + +def get_config_params(): + return config_params + + +def get_available_platforms(): + return available_platforms + + +def get_correction_models(): + return correction_models + + +def get_diffusion_models(): + return diffusion_models + + +def get_inversion_models(): + return inversion_models + + +def get_upscaling_models(): + return upscaling_models + + +def get_mask_filters(): + return mask_filters + + +def get_noise_sources(): + return noise_sources + + +def get_config_value(key: str, subkey: str = "default", default=None): + return config_params.get(key, {}).get(subkey, default) + + +def get_model_name(model: str) -> str: + base = path.basename(model) + (file, _ext) = path.splitext(base) + return file + + +def load_models(context: ServerContext) -> None: + global correction_models + global diffusion_models + global inversion_models + global upscaling_models + + diffusion_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*")) + ] + diffusion_models.extend( + [ + get_model_name(f) + for f in glob(path.join(context.model_path, "stable-diffusion-*")) + ] + ) + diffusion_models = list(set(diffusion_models)) + diffusion_models.sort() + logger.debug("loaded diffusion models from disk: %s", diffusion_models) + + correction_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) + ] + correction_models = list(set(correction_models)) + correction_models.sort() + logger.debug("loaded correction models from disk: %s", correction_models) + + inversion_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) + ] + inversion_models = list(set(inversion_models)) + inversion_models.sort() + logger.debug("loaded inversion models from disk: %s", inversion_models) + + upscaling_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) + ] + upscaling_models = list(set(upscaling_models)) + upscaling_models.sort() + logger.debug("loaded upscaling models from disk: %s", upscaling_models) + + +def load_params(context: ServerContext) -> None: + global config_params + + params_file = path.join(context.params_path, "params.json") + logger.debug("loading server parameters from file: %s", params_file) + + with open(params_file, "r") as f: + config_params = yaml.safe_load(f) + + if "platform" in config_params and context.default_platform is not None: + logger.info( + "overriding default platform from environment: %s", + context.default_platform, + ) + config_platform = config_params.get("platform", {}) + config_platform["default"] = context.default_platform + + +def load_platforms(context: ServerContext) -> None: + global available_platforms + + providers = list(get_available_providers()) + logger.debug("loading available platforms from providers: %s", providers) + + for potential in platform_providers: + if ( + platform_providers[potential] in providers + and potential not in 
context.block_platforms + ): + if potential == "cuda": + for i in range(torch.cuda.device_count()): + available_platforms.append( + DeviceParams( + potential, + platform_providers[potential], + { + "device_id": i, + }, + context.optimizations, + ) + ) + else: + available_platforms.append( + DeviceParams( + potential, + platform_providers[potential], + None, + context.optimizations, + ) + ) + + if context.any_platform: + # the platform should be ignored when the job is scheduled, but set to CPU just in case + available_platforms.append( + DeviceParams( + "any", + platform_providers["cpu"], + None, + context.optimizations, + ) + ) + + # make sure CPU is last on the list + def any_first_cpu_last(a: DeviceParams, b: DeviceParams): + if a.device == b.device: + return 0 + + # any should be first, if it's available + if a.device == "any": + return -1 + + # cpu should be last, if it's available + if a.device == "cpu": + return 1 + + return -1 + + available_platforms = sorted( + available_platforms, key=cmp_to_key(any_first_cpu_last) + ) + + logger.info( + "available acceleration platforms: %s", + ", ".join([str(p) for p in available_platforms]), + ) diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 4e1f79601..3107a79b7 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -1,6 +1,6 @@ from logging import getLogger from os import environ, path -from typing import List +from typing import List, Optional from ..utils import get_boolean from .model_cache import ModelCache @@ -18,13 +18,13 @@ def __init__( cors_origin: str = "*", num_workers: int = 1, any_platform: bool = True, - block_platforms: List[str] = None, - default_platform: str = None, + block_platforms: Optional[List[str]] = None, + default_platform: Optional[str] = None, image_format: str = "png", - cache: ModelCache = None, - cache_path: str = None, + cache: Optional[ModelCache] = None, + cache_path: Optional[str] = None, show_progress: bool = True, - optimizations: List[str] = None, + optimizations: Optional[List[str]] = None, ) -> None: self.bundle_path = bundle_path self.model_path = model_path diff --git a/api/onnx_web/server/device_pool.py b/api/onnx_web/server/device_pool.py deleted file mode 100644 index d6799b755..000000000 --- a/api/onnx_web/server/device_pool.py +++ /dev/null @@ -1,275 +0,0 @@ -from collections import Counter -from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor -from logging import getLogger -from multiprocessing import Value -from traceback import format_exception -from typing import Any, Callable, List, Optional, Tuple, Union - -from ..params import DeviceParams -from ..utils import run_gc - -logger = getLogger(__name__) - -ProgressCallback = Callable[[int, int, Any], None] - - -class JobContext: - cancel: Value = None - device_index: Value = None - devices: List[DeviceParams] = None - key: str = None - progress: Value = None - - def __init__( - self, - key: str, - devices: List[DeviceParams], - cancel: bool = False, - device_index: int = -1, - progress: int = 0, - ): - self.key = key - self.devices = list(devices) - self.cancel = Value("B", cancel) - self.device_index = Value("i", device_index) - self.progress = Value("I", progress) - - def is_cancelled(self) -> bool: - return self.cancel.value - - def get_device(self) -> DeviceParams: - """ - Get the device assigned to this job. 
- """ - with self.device_index.get_lock(): - device_index = self.device_index.value - if device_index < 0: - raise ValueError("job has not been assigned to a device") - else: - device = self.devices[device_index] - logger.debug("job %s assigned to device %s", self.key, device) - return device - - def get_progress(self) -> int: - return self.progress.value - - def get_progress_callback(self) -> ProgressCallback: - def on_progress(step: int, timestep: int, latents: Any): - on_progress.step = step - if self.is_cancelled(): - raise RuntimeError("job has been cancelled") - else: - logger.debug("setting progress for job %s to %s", self.key, step) - self.set_progress(step) - - return on_progress - - def set_cancel(self, cancel: bool = True) -> None: - with self.cancel.get_lock(): - self.cancel.value = cancel - - def set_progress(self, progress: int) -> None: - with self.progress.get_lock(): - self.progress.value = progress - - -class Job: - """ - Link a future to its context. - """ - - context: JobContext = None - future: Future = None - key: str = None - - def __init__( - self, - key: str, - future: Future, - context: JobContext, - ): - self.context = context - self.future = future - self.key = key - - def get_progress(self) -> int: - return self.context.get_progress() - - def set_cancel(self, cancel: bool = True): - return self.context.set_cancel(cancel) - - def set_progress(self, progress: int): - return self.context.set_progress(progress) - - -class DevicePoolExecutor: - devices: List[DeviceParams] = None - jobs: List[Job] = None - next_device: int = 0 - pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None - recent: List[Tuple[str, int]] = None - - def __init__( - self, - devices: List[DeviceParams], - pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None, - recent_limit: int = 10, - ): - self.devices = devices - self.jobs = [] - self.next_device = 0 - self.recent = [] - self.recent_limit = recent_limit - - device_count = len(devices) - if pool is None: - logger.info( - "creating thread pool executor for %s devices: %s", - device_count, - [d.device for d in devices], - ) - self.pool = ThreadPoolExecutor(device_count) - else: - logger.info( - "using existing pool for %s devices: %s", - device_count, - [d.device for d in devices], - ) - self.pool = pool - - def cancel(self, key: str) -> bool: - """ - Cancel a job. If the job has not been started, this will cancel - the future and never execute it. If the job has been started, it - should be cancelled on the next progress callback. 
- """ - for job in self.jobs: - if job.key == key: - if job.future.cancel(): - return True - else: - job.set_cancel() - return True - - return False - - def done(self, key: str) -> Tuple[Optional[bool], int]: - for k, progress in self.recent: - if key == k: - return (True, progress) - - for job in self.jobs: - if job.key == key: - done = job.future.done() - progress = job.get_progress() - - if done: - self.prune() - - return (done, progress) - - logger.warn("checking status for unknown key: %s", key) - return (None, 0) - - def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: - # respect overrides if possible - if needs_device is not None: - for i in range(len(self.devices)): - if self.devices[i].device == needs_device.device: - return i - - # use the first/default device if there are no jobs - if len(self.jobs) == 0: - return 0 - - job_devices = [ - job.context.device_index.value for job in self.jobs if not job.future.done() - ] - job_counts = Counter(range(len(self.devices))) - job_counts.update(job_devices) - - queued = job_counts.most_common() - logger.debug("jobs queued by device: %s", queued) - - lowest_count = queued[-1][1] - lowest_devices = [d[0] for d in queued if d[1] == lowest_count] - lowest_devices.sort() - - return lowest_devices[0] - - def prune(self): - pending_jobs = [job for job in self.jobs if job.future.done()] - logger.debug("pruning %s of %s pending jobs", len(pending_jobs), len(self.jobs)) - - for job in pending_jobs: - self.recent.append((job.key, job.get_progress())) - try: - self.jobs.remove(job) - except ValueError as e: - logger.warning("error removing pruned job from pending: %s", e) - - recent_count = len(self.recent) - if recent_count > self.recent_limit: - logger.debug( - "pruning %s of %s recent jobs", - recent_count - self.recent_limit, - recent_count, - ) - self.recent[:] = self.recent[-self.recent_limit :] - - def submit( - self, - key: str, - fn: Callable[..., None], - /, - *args, - needs_device: Optional[DeviceParams] = None, - **kwargs, - ) -> None: - self.prune() - device = self.get_next_device(needs_device=needs_device) - logger.info( - "assigning job %s to device %s: %s", key, device, self.devices[device] - ) - - context = JobContext(key, self.devices, device_index=device) - future = self.pool.submit(fn, context, *args, **kwargs) - job = Job(key, future, context) - self.jobs.append(job) - - def job_done(f: Future): - try: - f.result() - logger.info("job %s finished successfully", key) - except Exception as err: - logger.warn( - "job %s failed with an error: %s", - key, - format_exception(type(err), err, err.__traceback__), - ) - run_gc([self.devices[device]]) - - future.add_done_callback(job_done) - - def status(self) -> List[Tuple[str, int, bool, int]]: - pending = [ - ( - job.key, - job.context.device_index.value, - job.future.done(), - job.get_progress(), - ) - for job in self.jobs - ] - recent = [ - ( - key, - None, - True, - progress, - ) - for key, progress in self.recent - ] - - pending.extend(recent) - return pending diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index 4b7ac0b99..7847ccf4c 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -119,9 +119,8 @@ def patch_not_impl(): def patch_cache_path(ctx: ServerContext, url: str, **kwargs) -> str: - if url in cache_path_map: - cache_path = cache_path_map.get(url) - else: + cache_path = cache_path_map.get(url, None) + if cache_path is None: parsed = urlparse(url) cache_path = path.basename(parsed.path) diff --git 
a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py new file mode 100644 index 000000000..c16d19efa --- /dev/null +++ b/api/onnx_web/server/params.py @@ -0,0 +1,175 @@ +from logging import getLogger +from typing import Tuple + +import numpy as np +from flask import request + +from ..diffusion.load import pipeline_schedulers +from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams +from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty +from .config import ( + get_available_platforms, + get_config_value, + get_correction_models, + get_upscaling_models, +) +from .context import ServerContext +from .utils import get_model_path + +logger = getLogger(__name__) + + +def pipeline_from_request( + context: ServerContext, +) -> Tuple[DeviceParams, ImageParams, Size]: + user = request.remote_addr + + # platform stuff + device = None + device_name = request.args.get("platform") + + if device_name is not None and device_name != "any": + for platform in get_available_platforms(): + if platform.device == device_name: + device = platform + + # pipeline stuff + lpw = get_not_empty(request.args, "lpw", "false") == "true" + model = get_not_empty(request.args, "model", get_config_value("model")) + model_path = get_model_path(context, model) + scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys()) + + if scheduler is None: + scheduler = get_config_value("scheduler") + + inversion = request.args.get("inversion", None) + inversion_path = None + if inversion is not None and inversion.strip() != "": + inversion_path = get_model_path(context, inversion) + + # image params + prompt = get_not_empty(request.args, "prompt", get_config_value("prompt")) + negative_prompt = request.args.get("negativePrompt", None) + + if negative_prompt is not None and negative_prompt.strip() == "": + negative_prompt = None + + batch = get_and_clamp_int( + request.args, + "batch", + get_config_value("batch"), + get_config_value("batch", "max"), + get_config_value("batch", "min"), + ) + cfg = get_and_clamp_float( + request.args, + "cfg", + get_config_value("cfg"), + get_config_value("cfg", "max"), + get_config_value("cfg", "min"), + ) + eta = get_and_clamp_float( + request.args, + "eta", + get_config_value("eta"), + get_config_value("eta", "max"), + get_config_value("eta", "min"), + ) + steps = get_and_clamp_int( + request.args, + "steps", + get_config_value("steps"), + get_config_value("steps", "max"), + get_config_value("steps", "min"), + ) + height = get_and_clamp_int( + request.args, + "height", + get_config_value("height"), + get_config_value("height", "max"), + get_config_value("height", "min"), + ) + width = get_and_clamp_int( + request.args, + "width", + get_config_value("width"), + get_config_value("width", "max"), + get_config_value("width", "min"), + ) + + seed = int(request.args.get("seed", -1)) + if seed == -1: + # this one can safely use np.random because it produces a single value + seed = np.random.randint(np.iinfo(np.int32).max) + + logger.info( + "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", + user, + steps, + scheduler, + model_path, + device or "any device", + width, + height, + cfg, + seed, + prompt, + ) + + params = ImageParams( + model_path, + scheduler, + prompt, + cfg, + steps, + seed, + eta=eta, + lpw=lpw, + negative_prompt=negative_prompt, + batch=batch, + inversion=inversion_path, + ) + size = Size(width, height) + return (device, params, size) + + +def border_from_request() -> Border: + left = 
get_and_clamp_int( + request.args, "left", 0, get_config_value("width", "max"), 0 + ) + right = get_and_clamp_int( + request.args, "right", 0, get_config_value("width", "max"), 0 + ) + top = get_and_clamp_int( + request.args, "top", 0, get_config_value("height", "max"), 0 + ) + bottom = get_and_clamp_int( + request.args, "bottom", 0, get_config_value("height", "max"), 0 + ) + + return Border(left, right, top, bottom) + + +def upscale_from_request() -> UpscaleParams: + denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0) + scale = get_and_clamp_int(request.args, "scale", 1, 4, 1) + outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1) + upscaling = get_from_list(request.args, "upscaling", get_upscaling_models()) + correction = get_from_list(request.args, "correction", get_correction_models()) + faces = get_not_empty(request.args, "faces", "false") == "true" + face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) + face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) + upscale_order = request.args.get("upscaleOrder", "correction-first") + + return UpscaleParams( + upscaling, + correction_model=correction, + denoise=denoise, + faces=faces, + face_outscale=face_outscale, + face_strength=face_strength, + format="onnx", + outscale=outscale, + scale=scale, + upscale_order=upscale_order, + ) diff --git a/api/onnx_web/server/static.py b/api/onnx_web/server/static.py new file mode 100644 index 000000000..9a67bf0ed --- /dev/null +++ b/api/onnx_web/server/static.py @@ -0,0 +1,36 @@ +from os import path + +from flask import Flask, send_from_directory + +from ..worker.pool import DevicePoolExecutor +from .context import ServerContext +from .utils import wrap_route + + +def serve_bundle_file(context: ServerContext, filename="index.html"): + return send_from_directory(path.join("..", context.bundle_path), filename) + + +# non-API routes +def index(context: ServerContext): + return serve_bundle_file(context) + + +def index_path(context: ServerContext, filename: str): + return serve_bundle_file(context, filename) + + +def output(context: ServerContext, filename: str): + return send_from_directory( + path.join("..", context.output_path), filename, as_attachment=False + ) + + +def register_static_routes( + app: Flask, context: ServerContext, pool: DevicePoolExecutor +): + return [ + app.route("/")(wrap_route(index, context)), + app.route("/")(wrap_route(index_path, context)), + app.route("/output/")(wrap_route(output, context)), + ] diff --git a/api/onnx_web/server/utils.py b/api/onnx_web/server/utils.py new file mode 100644 index 000000000..56cc33e0a --- /dev/null +++ b/api/onnx_web/server/utils.py @@ -0,0 +1,37 @@ +from functools import partial, update_wrapper +from os import makedirs, path +from typing import Callable, Dict, List, Tuple + +from flask import Flask + +from ..utils import base_join +from ..worker.pool import DevicePoolExecutor +from .context import ServerContext + + +def check_paths(context: ServerContext) -> None: + if not path.exists(context.model_path): + raise RuntimeError("model path must exist") + + if not path.exists(context.output_path): + makedirs(context.output_path) + + +def get_model_path(context: ServerContext, model: str): + return base_join(context.model_path, model) + + +def register_routes( + app: Flask, + context: ServerContext, + pool: DevicePoolExecutor, + routes: List[Tuple[str, Dict, Callable]], +): + for route, kwargs, method in routes: + app.route(route, **kwargs)(wrap_route(method, context, 
pool=pool)) + + +def wrap_route(func, *args, **kwargs): + partial_func = partial(func, *args, **kwargs) + update_wrapper(partial_func, func) + return partial_func diff --git a/api/onnx_web/torch_before_ort.py b/api/onnx_web/torch_before_ort.py new file mode 100644 index 000000000..506c14783 --- /dev/null +++ b/api/onnx_web/torch_before_ort.py @@ -0,0 +1,5 @@ +# this file exists to make sure torch is always imported before onnxruntime +# to work around https://github.com/microsoft/onnxruntime/issues/11092 + +import torch # NOQA +from onnxruntime import * # NOQA diff --git a/api/onnx_web/transformers.py b/api/onnx_web/transformers.py index f7a70693b..299d38544 100644 --- a/api/onnx_web/transformers.py +++ b/api/onnx_web/transformers.py @@ -1,13 +1,14 @@ from logging import getLogger from .params import ImageParams, Size -from .server import JobContext, ServerContext +from .server import ServerContext +from .worker import WorkerContext logger = getLogger(__name__) def run_txt2txt_pipeline( - job: JobContext, + job: WorkerContext, _server: ServerContext, params: ImageParams, _size: Size, @@ -21,13 +22,13 @@ def run_txt2txt_pipeline( device = job.get_device() - model = GPTJForCausalLM.from_pretrained(model).to(device.torch_device()) + pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str()) tokenizer = AutoTokenizer.from_pretrained(model) input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to( - device.torch_device() + device.torch_str() ) - results = model.generate( + results = pipe.generate( input_ids, do_sample=True, max_length=tokens, diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index c04d6efb4..c9b7308ae 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -1,4 +1,5 @@ from logging import getLogger +from typing import Optional from PIL import Image @@ -10,20 +11,21 @@ upscale_stable_diffusion, ) from .params import ImageParams, SizeChart, StageParams, UpscaleParams -from .server import JobContext, ProgressCallback, ServerContext +from .server import ServerContext +from .worker import ProgressCallback, WorkerContext logger = getLogger(__name__) def run_upscale_correction( - job: JobContext, + job: WorkerContext, server: ServerContext, stage: StageParams, params: ImageParams, image: Image.Image, *, upscale: UpscaleParams, - callback: ProgressCallback = None, + callback: Optional[ProgressCallback] = None, ) -> Image.Image: """ This is a convenience method for a chain pipeline that will run upscaling and diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 69012af96..74f998f8a 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -2,7 +2,7 @@ import threading from logging import getLogger from os import environ, path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union import torch @@ -10,6 +10,8 @@ logger = getLogger(__name__) +SAFE_CHARS = "._-" + def base_join(base: str, tail: str) -> str: tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") @@ -36,7 +38,7 @@ def get_and_clamp_int( return min(max(int(args.get(key, default_value)), min_value), max_value) -def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]: +def get_from_list(args: Any, key: str, values: Sequence[Any]) -> Optional[Any]: selected = args.get(key, None) if selected in values: return selected @@ -82,7 +84,7 @@ def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]: raise ValueError("invalid size") -def run_gc(devices: List[DeviceParams] 
= None): +def run_gc(devices: Optional[List[DeviceParams]] = None): logger.debug( "running garbage collection with %s active threads", threading.active_count() ) @@ -100,3 +102,7 @@ def run_gc(devices: List[DeviceParams] = None): (mem_total - mem_free), mem_total, ) + + +def sanitize_name(name): + return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS)) diff --git a/api/onnx_web/worker/__init__.py b/api/onnx_web/worker/__init__.py new file mode 100644 index 000000000..c1f2d7949 --- /dev/null +++ b/api/onnx_web/worker/__init__.py @@ -0,0 +1,2 @@ +from .context import WorkerContext, ProgressCallback +from .pool import DevicePoolExecutor diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py new file mode 100644 index 000000000..1b14d80e7 --- /dev/null +++ b/api/onnx_web/worker/context.py @@ -0,0 +1,87 @@ +from logging import getLogger +from typing import Any, Callable, Tuple + +from torch.multiprocessing import Queue, Value + +from ..params import DeviceParams + +logger = getLogger(__name__) + + +ProgressCallback = Callable[[int, int, Any], None] + + +class WorkerContext: + cancel: "Value[bool]" + job: str + pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]" + progress: "Value[int]" + + def __init__( + self, + job: str, + device: DeviceParams, + cancel: "Value[bool]", + logs: "Queue[str]", + pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]", + progress: "Queue[Tuple[str, str, int]]", + finished: "Queue[Tuple[str, str]]", + ): + self.job = job + self.device = device + self.cancel = cancel + self.progress = progress + self.finished = finished + self.logs = logs + self.pending = pending + + def is_cancelled(self) -> bool: + return self.cancel.value + + def get_device(self) -> DeviceParams: + """ + Get the device assigned to this job. 
+ """ + return self.device + + def get_progress(self) -> int: + return self.progress.value + + def get_progress_callback(self) -> ProgressCallback: + def on_progress(step: int, timestep: int, latents: Any): + on_progress.step = step + if self.is_cancelled(): + raise RuntimeError("job has been cancelled") + else: + logger.debug("setting progress for job %s to %s", self.job, step) + self.set_progress(step) + + return on_progress + + def set_cancel(self, cancel: bool = True) -> None: + with self.cancel.get_lock(): + self.cancel.value = cancel + + def set_progress(self, progress: int) -> None: + self.progress.put((self.job, self.device.device, progress), block=False) + + def set_finished(self) -> None: + self.finished.put((self.job, self.device.device)) + + def clear_flags(self) -> None: + self.set_cancel(False) + self.set_progress(0) + + +class JobStatus: + def __init__( + self, + name: str, + progress: int = 0, + cancelled: bool = False, + finished: bool = False, + ) -> None: + self.name = name + self.progress = progress + self.cancelled = cancelled + self.finished = finished diff --git a/api/onnx_web/worker/logging.py b/api/onnx_web/worker/logging.py new file mode 100644 index 000000000..ab90a266f --- /dev/null +++ b/api/onnx_web/worker/logging.py @@ -0,0 +1 @@ +# TODO: queue-based logger diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py new file mode 100644 index 000000000..38285df2c --- /dev/null +++ b/api/onnx_web/worker/pool.py @@ -0,0 +1,340 @@ +from collections import Counter +from logging import getLogger +from queue import Empty +from threading import Thread +from typing import Any, Callable, Dict, List, Optional, Tuple + +from torch.multiprocessing import Process, Queue, Value + +from ..params import DeviceParams +from ..server import ServerContext +from .context import WorkerContext +from .worker import worker_main + +logger = getLogger(__name__) + + +class DevicePoolExecutor: + server: ServerContext + devices: List[DeviceParams] + max_jobs_per_worker: int + max_pending_per_worker: int + join_timeout: float + + context: Dict[str, WorkerContext] # Device -> Context + pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"] + threads: Dict[str, Thread] + workers: Dict[str, Process] + + active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus] + cancelled_jobs: List[str] + finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus] + total_jobs: int + + logs: "Queue" + progress: "Queue[Tuple[str, str, int]]" + finished: "Queue[Tuple[str, str]]" + + def __init__( + self, + server: ServerContext, + devices: List[DeviceParams], + max_jobs_per_worker: int = 10, + max_pending_per_worker: int = 100, + join_timeout: float = 1.0, + ): + self.server = server + self.devices = devices + self.max_jobs_per_worker = max_jobs_per_worker + self.max_pending_per_worker = max_pending_per_worker + self.join_timeout = join_timeout + + self.context = {} + self.pending = {} + self.threads = {} + self.workers = {} + + self.active_jobs = {} + self.cancelled_jobs = [] + self.finished_jobs = [] + self.total_jobs = 0 # TODO: turn this into a Dict per-worker + + self.logs = Queue(self.max_pending_per_worker) + self.progress = Queue(self.max_pending_per_worker) + self.finished = Queue(self.max_pending_per_worker) + + self.create_logger_worker() + self.create_progress_worker() + self.create_finished_worker() + + for device in devices: + self.create_device_worker(device) + + logger.debug("testing log worker") + self.logs.put("testing") + + def 
create_device_worker(self, device: DeviceParams) -> None: + name = device.device + + # reuse the queue if possible, to keep queued jobs + if name in self.pending: + logger.debug("using existing pending job queue") + pending = self.pending[name] + else: + logger.debug("creating new pending job queue") + pending = Queue(self.max_pending_per_worker) + self.pending[name] = pending + + context = WorkerContext( + name, + device, + cancel=Value("B", False), + progress=self.progress, + finished=self.finished, + logs=self.logs, + pending=pending, + ) + self.context[name] = context + self.workers[name] = Process( + name=f"onnx-web worker: {name}", + target=worker_main, + args=(context, self.server), + ) + + logger.debug("starting worker for device %s", device) + self.workers[name].start() + + def create_logger_worker(self) -> None: + def logger_worker(logs: Queue): + logger.info("checking in from logger worker thread") + + while True: + try: + job = logs.get(timeout=(self.join_timeout / 2)) + with open("worker.log", "w") as f: + logger.info("got log: %s", job) + f.write(str(job) + "\n\n") + except Empty: + pass + except ValueError: + break + except Exception as err: + logger.error("error in log worker: %s", err) + + logger_thread = Thread( + name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True + ) + self.threads["logger"] = logger_thread + + logger.debug("starting logger worker") + logger_thread.start() + + def create_progress_worker(self) -> None: + def progress_worker(progress: Queue): + logger.info("checking in from progress worker thread") + while True: + try: + job, device, value = progress.get(timeout=(self.join_timeout / 2)) + logger.info("progress update for job: %s to %s", job, value) + self.active_jobs[job] = (device, value) + if job in self.cancelled_jobs: + logger.debug( + "setting flag for cancelled job: %s on %s", job, device + ) + self.context[device].set_cancel() + except Empty: + pass + except ValueError: + break + except Exception as err: + logger.error("error in progress worker: %s", err) + + progress_thread = Thread( + name="onnx-web progress", + target=progress_worker, + args=(self.progress,), + daemon=True, + ) + self.threads["progress"] = progress_thread + + logger.debug("starting progress worker") + progress_thread.start() + + def create_finished_worker(self) -> None: + def finished_worker(finished: Queue): + logger.info("checking in from finished worker thread") + while True: + try: + job, device = finished.get(timeout=(self.join_timeout / 2)) + logger.info("job has been finished: %s", job) + context = self.context[device] + _device, progress = self.active_jobs[job] + self.finished_jobs.append((job, progress, context.cancel.value)) + del self.active_jobs[job] + except Empty: + pass + except ValueError: + break + except Exception as err: + logger.error("error in finished worker: %s", err) + + finished_thread = Thread( + name="onnx-web finished", + target=finished_worker, + args=(self.finished,), + daemon=True, + ) + self.threads["finished"] = finished_thread + + logger.debug("started finished worker") + finished_thread.start() + + def get_job_context(self, key: str) -> WorkerContext: + device, _progress = self.active_jobs[key] + return self.context[device] + + def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: + # respect overrides if possible + if needs_device is not None: + for i in range(len(self.devices)): + if self.devices[i].device == needs_device.device: + return i + + jobs = Counter(range(len(self.devices))) + 
jobs.update([self.pending[d.device].qsize() for d in self.devices]) + + queued = jobs.most_common() + logger.debug("jobs queued by device: %s", queued) + + lowest_count = queued[-1][1] + lowest_devices = [d[0] for d in queued if d[1] == lowest_count] + lowest_devices.sort() + + return lowest_devices[0] + + def cancel(self, key: str) -> bool: + """ + Cancel a job. If the job has not been started, this will cancel + the future and never execute it. If the job has been started, it + should be cancelled on the next progress callback. + """ + + self.cancelled_jobs.append(key) + + if key not in self.active_jobs: + logger.debug("cancelled job has not been started yet: %s", key) + return True + + device, _progress = self.active_jobs[key] + logger.info("cancelling job %s, active on device %s", key, device) + + context = self.context[device] + context.set_cancel() + + return True + + def done(self, key: str) -> Tuple[Optional[bool], int]: + """ + Check if a job has been finished and report the last progress update. + + If the job is still active or pending, the first item will be False. + If the job is not finished or active, the first item will be None. + """ + for k, p, c in self.finished_jobs: + if k == key: + return (True, p) + + if key not in self.active_jobs: + logger.warn("checking status for unknown job: %s", key) + return (None, 0) + + _device, progress = self.active_jobs[key] + return (False, progress) + + def join(self): + logger.info("stopping worker pool") + + logger.debug("closing queues") + self.logs.close() + self.finished.close() + self.progress.close() + for queue in self.pending.values(): + queue.close() + + self.pending.clear() + + logger.debug("stopping device workers") + for device, worker in self.workers.items(): + if worker.is_alive(): + logger.debug("stopping worker for device %s", device) + worker.join(self.join_timeout) + # worker.terminate() + else: + logger.debug("worker for device %s has died", device) + + for name, thread in self.threads.items(): + logger.debug("stopping worker thread: %s", name) + thread.join(self.join_timeout) + + logger.debug("worker pool fully joined") + + def recycle(self): + for name, proc in self.workers.items(): + if proc.is_alive(): + logger.debug("shutting down worker for device %s", name) + proc.join(self.join_timeout) + # proc.terminate() + else: + logger.warning("worker for device %s has died", name) + + self.workers[name] = None + del proc + + logger.info("starting new workers") + + for device in self.devices: + self.create_device_worker(device) + + def submit( + self, + key: str, + fn: Callable[..., None], + /, + *args, + needs_device: Optional[DeviceParams] = None, + **kwargs, + ) -> None: + self.total_jobs += 1 + logger.debug("pool job count: %s", self.total_jobs) + if self.total_jobs > self.max_jobs_per_worker: + self.recycle() + self.total_jobs = 0 + + device_idx = self.get_next_device(needs_device=needs_device) + logger.info( + "assigning job %s to device %s: %s", + key, + device_idx, + self.devices[device_idx], + ) + + device = self.devices[device_idx].device + self.pending[device].put((key, fn, args, kwargs), block=False) + + def status(self) -> List[Tuple[str, int, bool, bool]]: + history = [ + (name, progress, False, name in self.cancelled_jobs) + for name, (_device, progress) in self.active_jobs.items() + ] + history.extend( + [ + ( + name, + progress, + True, + cancel, + ) + for name, progress, cancel in self.finished_jobs + ] + ) + return history diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py new 
file mode 100644 index 000000000..94c570205 --- /dev/null +++ b/api/onnx_web/worker/worker.py @@ -0,0 +1,44 @@ +from logging import getLogger +from queue import Empty +from sys import exit +from traceback import format_exception + +from setproctitle import setproctitle + +from ..server import ServerContext, apply_patches +from ..torch_before_ort import get_available_providers +from .context import WorkerContext + +logger = getLogger(__name__) + + +def worker_main(context: WorkerContext, server: ServerContext): + apply_patches(server) + setproctitle("onnx-web worker: %s" % (context.device.device)) + + logger.info("checking in from worker, %s", get_available_providers()) + + while True: + try: + name, fn, args, kwargs = context.pending.get(timeout=1.0) + logger.info("worker for %s got job: %s", context.device.device, name) + + context.job = name # TODO: hax + context.clear_flags() + logger.info("starting job: %s", name) + fn(context, *args, **kwargs) + logger.info("job succeeded: %s", name) + context.set_finished() + except Empty: + pass + except KeyboardInterrupt: + logger.info("worker got keyboard interrupt") + exit(0) + except ValueError as e: + logger.info("value error in worker: %s", e) + exit(1) + except Exception as e: + logger.error( + "error while running job: %s", + format_exception(type(e), e, e.__traceback__), + ) diff --git a/api/pyproject.toml b/api/pyproject.toml index 04c7bca40..e337c1f21 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -8,16 +8,35 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion [tool.mypy] # ignore_missing_imports = true +exclude = [ + "onnx_web.diffusion.lpw_stable_diffusion_onnx", + "onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale" +] [[tool.mypy.overrides]] module = [ "basicsr.archs.rrdbnet_arch", + "basicsr.utils.download_util", + "basicsr.utils", + "basicsr", "boto3", "codeformer", + "codeformer.facelib.utils.misc", + "codeformer.facelib.utils", + "codeformer.facelib", "diffusers", + "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion", + "diffusers.pipelines.paint_by_example", + "diffusers.pipelines.stable_diffusion", "diffusers.pipeline_utils", + "diffusers.utils.logging", + "facexlib.utils", + "facexlib", "gfpgan", "onnxruntime", - "realesrgan" + "realesrgan", + "realesrgan.archs.srvgg_arch", + "safetensors", + "transformers" ] ignore_missing_imports = true \ No newline at end of file diff --git a/api/requirements.txt b/api/requirements.txt index e08b7c312..2ac35af81 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -23,3 +23,4 @@ flask flask-cors jsonschema pyyaml +setproctitle \ No newline at end of file diff --git a/api/scripts/test-memory-leak.sh b/api/scripts/test-memory-leak.sh index 5e432e433..41953b4c5 100755 --- a/api/scripts/test-memory-leak.sh +++ b/api/scripts/test-memory-leak.sh @@ -4,7 +4,7 @@ test_images=0 while true; do curl "http://${test_host}:5000/api/txt2img?"\ -'cfg=16.00&steps=3&scheduler=deis-multi&seed=-1&'\ +'cfg=16.00&steps=3&scheduler=ddim&seed=-1&'\ 'prompt=an+astronaut+eating+a+hamburger&negativePrompt=&'\ 'model=stable-diffusion-onnx-v1-5&platform=any&'\ 'upscaling=upscaling-real-esrgan-x2-plus&correction=correction-codeformer&'\ @@ -14,5 +14,5 @@ do --insecure || break; ((test_images++)); echo "waiting after $test_images"; - sleep 10; + sleep 30; done
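
Taken together, these changes replace the old in-process executor with one worker process per device, fed by a per-device pending queue and reporting back over shared progress and finished queues. A rough sketch of how the new pieces fit together, assuming the api package and its dependencies are installed and that ServerContext's constructor defaults are usable; the job function, device list, and job key below are illustrative stand-ins for the real pipelines and the platforms discovered by load_platforms:

from time import sleep

from onnx_web.params import DeviceParams
from onnx_web.server import ServerContext
from onnx_web.worker import DevicePoolExecutor, WorkerContext


def run_example_job(context: WorkerContext, server: ServerContext) -> None:
    # stand-in for run_txt2img_pipeline and friends: report progress through the
    # shared queue (and respect the cancel flag) between steps
    callback = context.get_progress_callback()
    for step in range(10):
        callback(step, 0, None)


if __name__ == "__main__":
    server = ServerContext()  # assumes the default paths are acceptable
    devices = [DeviceParams("cpu", "CPUExecutionProvider", None, [])]

    # one worker process per device, fed from that device's pending queue
    pool = DevicePoolExecutor(server, devices)
    pool.submit("example-job", run_example_job, server, needs_device=devices[0])

    # done() reports (finished, progress) as relayed by the monitor threads
    done, progress = pool.done("example-job")
    while not done:
        sleep(1.0)
        done, progress = pool.done("example-job")

    print("finished at step", progress)
    pool.join()

    # join() leaves worker.terminate() commented out, so the device workers keep
    # polling their queues; stop them explicitly so the sketch exits cleanly
    for proc in pool.workers.values():
        proc.terminate()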