Skip to content

Commit

Permalink
feat(api): switch to device pool for background workers
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 4, 2023
1 parent ecec0a2 commit 6426cff
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 44 deletions.
113 changes: 113 additions & 0 deletions api/onnx_web/device_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
from logging import getLogger
from multiprocessing import Value
from typing import Any, Callable, List, Union, Optional

logger = getLogger(__name__)


class JobContext:
    '''
    Shared, process-safe state for one queued job: a cancellation flag,
    the index of the device the job was assigned to, and a progress
    counter. All three are multiprocessing Values so that workers and
    the server can read and update them concurrently.
    '''

    def __init__(
        self,
        key: str,
        devices: List[str],
        cancel: bool = False,
        device_index: int = -1,
        progress: int = 0,
    ):
        self.key = key
        self.devices = list(devices)
        # type codes: 'B' unsigned char, 'i' signed int, 'I' unsigned int
        self.cancel = Value('B', cancel)
        self.device_index = Value('i', device_index)
        self.progress = Value('I', progress)

    def is_cancelled(self) -> bool:
        return self.cancel.value

    def get_device(self) -> str:
        '''
        Get the device assigned to this job.
        '''
        with self.device_index.get_lock():
            index = self.device_index.value

        if index < 0:
            raise Exception('job has not been assigned to a device')

        return self.devices[index]

    def get_progress_callback(self) -> Callable[..., None]:
        '''
        Build a pipeline-style callback that records the current step and
        aborts the job by raising once the cancel flag has been set.
        '''
        def on_progress(step: int, timestep: int, latents: Any):
            if self.is_cancelled():
                raise Exception('job has been cancelled')
            self.set_progress(step)

        return on_progress

    def set_cancel(self, cancel: bool = True) -> None:
        with self.cancel.get_lock():
            self.cancel.value = cancel

    def set_progress(self, progress: int) -> None:
        with self.progress.get_lock():
            self.progress.value = progress


class Job:
    '''
    Pairing of a pending Future with the JobContext it was submitted
    under, tracked by key so the server can cancel it or update its
    progress later.
    '''

    def __init__(
        self,
        key: str,
        future: Future,
        context: JobContext,
    ):
        self.key = key
        self.future = future
        self.context = context

    def set_cancel(self, cancel: bool = True):
        # forward to the shared context so running workers see the flag
        self.context.set_cancel(cancel)

    def set_progress(self, progress: int):
        self.context.set_progress(progress)


class DevicePoolExecutor:
    '''
    Submits jobs to a worker pool and tracks them by key so they can be
    cancelled or polled for completion.

    Device scheduling is not implemented yet: every job is pinned to the
    first device in the list (see submit).
    '''

    devices: List[str] = None
    jobs: List[Job] = None
    pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None

    def __init__(
        self,
        devices: List[str],
        pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
    ):
        self.devices = devices
        self.jobs = []
        # default to one worker thread per device
        self.pool = pool or ThreadPoolExecutor(len(devices))

    def cancel(self, key: str) -> bool:
        '''
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.

        Returns True if a job with this key was found, False otherwise.
        '''
        for job in self.jobs:
            if job.key == key:
                if job.future.cancel():
                    return True

                # already running: the flag lives on the shared context,
                # not on the Job itself; the progress callback raises on
                # the next step once it is set
                job.set_cancel()
                return True

        return False

    def done(self, key: str) -> Optional[bool]:
        '''
        Check whether a job's future has completed. Returns None when
        the key is unknown.
        '''
        for job in self.jobs:
            if job.key == key:
                return job.future.done()

        logger.warning('checking status for unknown key: %s', key)
        return None

    def prune(self):
        '''
        Drop finished jobs from the tracking list, keeping pending and
        running ones so they can still be cancelled or polled.
        '''
        self.jobs[:] = [job for job in self.jobs if not job.future.done()]

    def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
        '''
        Queue a job. The worker receives a JobContext as its first
        positional argument, ahead of *args.
        '''
        # TODO: pick the least-busy device instead of always the first
        context = JobContext(key, self.devices, device_index=0)
        future = self.pool.submit(fn, context, *args, **kwargs)
        self.jobs.append(Job(key, future, context))
60 changes: 41 additions & 19 deletions api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from ..chain import (
upscale_outpaint,
)
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Border,
Expand Down Expand Up @@ -38,18 +41,21 @@


def run_txt2img_pipeline(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
upscale: UpscaleParams
) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler)
params.model, params.provider, params.scheduler, device=device)

latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)

progress = job.get_progress_callback()
result = pipe(
params.prompt,
height=size.height,
Expand All @@ -59,13 +65,14 @@ def run_txt2img_pipeline(
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=progress,
)
image = result.images[0]
image = run_upscale_correction(
ctx, StageParams(), params, image, upscale=upscale)
server, StageParams(), params, image, upscale=upscale)

dest = save_image(ctx, output, image)
save_params(ctx, output, params, size, upscale=upscale)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)

del image
del result
Expand All @@ -75,18 +82,21 @@ def run_txt2img_pipeline(


def run_img2img_pipeline(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
params: ImageParams,
output: str,
upscale: UpscaleParams,
source_image: Image.Image,
strength: float,
) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
params.model, params.provider, params.scheduler, device=device)

rng = np.random.RandomState(params.seed)

progress = job.get_progress_callback()
result = pipe(
params.prompt,
generator=rng,
Expand All @@ -95,14 +105,15 @@ def run_img2img_pipeline(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
callback=progress,
)
image = result.images[0]
image = run_upscale_correction(
ctx, StageParams(), params, image, upscale=upscale)
server, StageParams(), params, image, upscale=upscale)

dest = save_image(ctx, output, image)
dest = save_image(server, output, image)
size = Size(*source_image.size)
save_params(ctx, output, params, size, upscale=upscale)
save_params(server, output, params, size, upscale=upscale)

del image
del result
Expand All @@ -112,7 +123,8 @@ def run_img2img_pipeline(


def run_inpaint_pipeline(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
Expand All @@ -125,9 +137,13 @@ def run_inpaint_pipeline(
strength: float,
fill_color: str,
) -> None:
device = job.get_device()
progress = job.get_progress_callback()
stage = StageParams()

# TODO: pass device, progress
image = upscale_outpaint(
ctx,
server,
stage,
params,
source_image,
Expand All @@ -146,10 +162,10 @@ def run_inpaint_pipeline(
'output image size does not match source, skipping post-blend')

image = run_upscale_correction(
ctx, stage, params, image, upscale=upscale)
server, stage, params, image, upscale=upscale)

dest = save_image(ctx, output, image)
save_params(ctx, output, params, size, upscale=upscale, border=border)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale, border=border)

del image
run_gc()
Expand All @@ -158,18 +174,24 @@ def run_inpaint_pipeline(


def run_upscale_pipeline(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
upscale: UpscaleParams,
source_image: Image.Image,
) -> None:
device = job.get_device()
progress = job.get_progress_callback()
stage = StageParams()

# TODO: pass device, progress
image = run_upscale_correction(
ctx, StageParams(), params, source_image, upscale=upscale)
server, stage, params, source_image, upscale=upscale)

dest = save_image(ctx, output, image)
save_params(ctx, output, params, size, upscale=upscale)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)

del image
run_gc()
Expand Down

0 comments on commit 6426cff

Please sign in to comment.