Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

per device worker #205

Merged
merged 40 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e46a1e5
begin switching to per-device torch mp workers
ssube Feb 26, 2023
f898de8
background workers, logger
ssube Feb 26, 2023
943281f
wire up worker jobs
ssube Feb 26, 2023
06c74a7
feat(api): remove Flask app from global scope
ssube Feb 26, 2023
6998e87
rejoin worker pool
ssube Feb 26, 2023
d765a6f
make logger start up well
ssube Feb 26, 2023
e1d0ad5
lock per worker, torch before ORT
ssube Feb 26, 2023
f115326
apply patches within workers
ssube Feb 26, 2023
e0737e9
update progress and finished flag from worker
ssube Feb 26, 2023
6502e1e
recycle worker pool after 10 jobs
ssube Feb 26, 2023
b880b7a
set process titles, terminate workers
ssube Feb 26, 2023
584dddb
lint all the new stuff
ssube Feb 26, 2023
d1961af
re-implement cancellation
ssube Feb 26, 2023
85118d1
clear worker flags between jobs, attempt to record finished jobs again
ssube Feb 26, 2023
b931da1
fix imports, lint
ssube Feb 26, 2023
eb82e73
initialize list of finished jobs
ssube Feb 26, 2023
525ee24
track started and finished jobs
ssube Feb 27, 2023
401ee20
fix finished flag
ssube Feb 27, 2023
a37d1a4
use progress queue
ssube Feb 27, 2023
1339593
always put progress in active jobs
ssube Feb 27, 2023
66a20e6
run logger in a thread, clean up status
ssube Feb 27, 2023
2327b24
join all threads
ssube Feb 27, 2023
113ad05
typo
ssube Feb 27, 2023
06f06f5
error handling in all threads
ssube Feb 28, 2023
61373d5
fix Windows entrypoint
ssube Feb 28, 2023
0793b61
consistently pass job key to workers
ssube Feb 28, 2023
1367592
set queue timeouts
ssube Feb 28, 2023
953e5ab
handle empty errors
ssube Feb 28, 2023
988088d
quit workers on keyboard signal
ssube Feb 28, 2023
da6ae5d
more logging around shutdown, close queues
ssube Feb 28, 2023
f7f438e
directly rejoin pool
ssube Feb 28, 2023
1ce98ac
add value error handling
ssube Feb 28, 2023
7e0ccdb
remove pending queues after joining
ssube Feb 28, 2023
4ae3d9c
remove task done
ssube Feb 28, 2023
cad0d37
some pending queue logging
ssube Feb 28, 2023
0011f07
daemonize queue collectors
ssube Feb 28, 2023
c95ac1f
avoid terminating workers because it breaks their queues
ssube Feb 28, 2023
c99aa67
name threads, max queues, type/lint fixes
ssube Mar 1, 2023
12fb7f5
fix(api): sanitize filenames in user input
ssube Mar 1, 2023
1f9efb4
apply lint
ssube Mar 1, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/.gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
.coverage
coverage.xml

*.log
*.swp
*.pyc

Expand Down
2 changes: 1 addition & 1 deletion api/launch-extras.bat
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=extras.json)
python -m onnx_web.convert --sources --diffusion --upscaling --correction --extras=%ONNX_WEB_EXTRA_MODELS% --token=%HF_TOKEN%

echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0
flask --app="onnx_web.serve:run" run --host=0.0.0.0
2 changes: 1 addition & 1 deletion api/launch-extras.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ python3 -m onnx_web.convert \
--token=${HF_TOKEN:-}

echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0
flask --app='onnx_web.main:run' run --host=0.0.0.0
2 changes: 1 addition & 1 deletion api/launch.bat
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ echo "Downloading and converting models to ONNX format..."
python -m onnx_web.convert --sources --diffusion --upscaling --correction --token=%HF_TOKEN%

echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0
flask --app="onnx_web.serve:run" run --host=0.0.0.0
2 changes: 1 addition & 1 deletion api/launch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ python3 -m onnx_web.convert \
--token=${HF_TOKEN:-}

echo "Launching API server..."
flask --app=onnx_web.serve run --host=0.0.0.0
flask --app='onnx_web.main:run' run --host=0.0.0.0
6 changes: 3 additions & 3 deletions api/logging.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ formatters:
handlers:
console:
class: logging.StreamHandler
level: INFO
level: DEBUG
formatter: simple
stream: ext://sys.stdout
loggers:
'':
level: INFO
level: DEBUG
handlers: [console]
propagate: True
root:
level: INFO
level: DEBUG
handlers: [console]
13 changes: 10 additions & 3 deletions api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from . import logging
from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion
from .chain import (
correct_codeformer,
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
)
from .diffusion.load import get_latents_from_seed, load_pipeline, optimize_pipeline
from .diffusion.run import (
run_blend_pipeline,
Expand All @@ -25,6 +30,7 @@
from .onnx import OnnxNet, OnnxTensor
from .params import (
Border,
DeviceParams,
ImageParams,
Param,
Point,
Expand All @@ -33,8 +39,6 @@
UpscaleParams,
)
from .server import (
DeviceParams,
DevicePoolExecutor,
ModelCache,
ServerContext,
apply_patch_basicsr,
Expand All @@ -51,3 +55,6 @@
get_from_map,
get_not_empty,
)
from .worker import (
DevicePoolExecutor,
)
17 changes: 17 additions & 0 deletions api/onnx_web/chain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,20 @@
from .upscale_outpaint import upscale_outpaint
from .upscale_resrgan import upscale_resrgan
from .upscale_stable_diffusion import upscale_stable_diffusion

CHAIN_STAGES = {
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"blend-mask": blend_mask,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,
"persist-s3": persist_s3,
"reduce-crop": reduce_crop,
"reduce-thumbnail": reduce_thumbnail,
"source-noise": source_noise,
"source-txt2img": source_txt2img,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
}
7 changes: 4 additions & 3 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from ..output import save_image
from ..params import ImageParams, StageParams
from ..server import JobContext, ProgressCallback, ServerContext
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_order

logger = getLogger(__name__)
Expand All @@ -17,7 +18,7 @@
class StageCallback(Protocol):
def __call__(
self,
job: JobContext,
job: WorkerContext,
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
Expand Down Expand Up @@ -77,7 +78,7 @@ def append(self, stage: PipelineStage):

def __call__(
self,
job: JobContext,
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: Image.Image,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@

from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..server import JobContext, ProgressCallback, ServerContext
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext

logger = getLogger(__name__)


def blend_img2img(
job: JobContext,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import JobContext, ProgressCallback, ServerContext
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_order

logger = getLogger(__name__)


def blend_inpaint(
job: JobContext,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
Expand Down
10 changes: 5 additions & 5 deletions api/onnx_web/chain/blend_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@

from PIL import Image

from onnx_web.image import valid_image
from onnx_web.output import save_image

from ..image import valid_image
from ..output import save_image
from ..params import ImageParams, StageParams
from ..server import JobContext, ProgressCallback, ServerContext
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext

logger = getLogger(__name__)


def blend_mask(
_job: JobContext,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
from PIL import Image

from ..params import ImageParams, StageParams, UpscaleParams
from ..server import JobContext, ServerContext
from ..server import ServerContext
from ..worker import WorkerContext

logger = getLogger(__name__)

device = "cpu"


def correct_codeformer(
job: JobContext,
job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from PIL import Image

from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import JobContext, ServerContext
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext

logger = getLogger(__name__)

Expand Down Expand Up @@ -46,7 +47,7 @@ def load_gfpgan(


def correct_gfpgan(
job: JobContext,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/persist_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

from ..output import save_image
from ..params import ImageParams, StageParams
from ..server import JobContext, ServerContext
from ..server import ServerContext
from ..worker import WorkerContext

logger = getLogger(__name__)


def persist_disk(
_job: JobContext,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/persist_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from PIL import Image

from ..params import ImageParams, StageParams
from ..server import JobContext, ServerContext
from ..server import ServerContext
from ..worker import WorkerContext

logger = getLogger(__name__)


def persist_s3(
_job: JobContext,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/reduce_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from PIL import Image

from ..params import ImageParams, Size, StageParams
from ..server import JobContext, ServerContext
from ..server import ServerContext
from ..worker import WorkerContext

logger = getLogger(__name__)


def reduce_crop(
_job: JobContext,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/reduce_thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from PIL import Image

from ..params import ImageParams, Size, StageParams
from ..server import JobContext, ServerContext
from ..server import ServerContext
from ..worker import WorkerContext

logger = getLogger(__name__)


def reduce_thumbnail(
_job: JobContext,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/source_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from PIL import Image

from ..params import ImageParams, Size, StageParams
from ..server import JobContext, ServerContext
from ..server import ServerContext
from ..worker import WorkerContext

logger = getLogger(__name__)


def source_noise(
_job: JobContext,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@

from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams
from ..server import JobContext, ProgressCallback, ServerContext
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext

logger = getLogger(__name__)


def source_txt2img(
job: JobContext,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import JobContext, ProgressCallback, ServerContext
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_grid, process_tile_order

logger = getLogger(__name__)


def upscale_outpaint(
job: JobContext,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from ..onnx import OnnxNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import JobContext, ServerContext
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext

logger = getLogger(__name__)

Expand Down Expand Up @@ -96,7 +97,7 @@ def load_resrgan(


def upscale_resrgan(
job: JobContext,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
OnnxStableDiffusionUpscalePipeline,
)
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import JobContext, ProgressCallback, ServerContext
from ..server import ServerContext
from ..utils import run_gc
from ..worker import ProgressCallback, WorkerContext

logger = getLogger(__name__)

Expand Down Expand Up @@ -62,7 +63,7 @@ def load_stable_diffusion(


def upscale_stable_diffusion(
job: JobContext,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
Expand Down
Loading