Skip to content

Commit

Permalink
fix(api): pass job context and device to upscaling
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 4, 2023
1 parent 8a81e8b commit 3637f64
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 14 deletions.
10 changes: 4 additions & 6 deletions api/onnx_web/diffusion/run.py
Expand Up @@ -69,7 +69,7 @@ def run_txt2img_pipeline(
)
image = result.images[0]
image = run_upscale_correction(
server, StageParams(), params, image, upscale=upscale)
job, server, StageParams(), params, image, upscale=upscale)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
Expand Down Expand Up @@ -109,7 +109,7 @@ def run_img2img_pipeline(
)
image = result.images[0]
image = run_upscale_correction(
server, StageParams(), params, image, upscale=upscale)
job, server, StageParams(), params, image, upscale=upscale)

dest = save_image(server, output, image)
size = Size(*source_image.size)
Expand Down Expand Up @@ -141,7 +141,6 @@ def run_inpaint_pipeline(
# progress = job.get_progress_callback()
stage = StageParams()

# TODO: pass device, progress
image = upscale_outpaint(
server,
stage,
Expand All @@ -162,7 +161,7 @@ def run_inpaint_pipeline(
'output image size does not match source, skipping post-blend')

image = run_upscale_correction(
server, stage, params, image, upscale=upscale)
job, server, stage, params, image, upscale=upscale)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale, border=border)
Expand All @@ -186,9 +185,8 @@ def run_upscale_pipeline(
# progress = job.get_progress_callback()
stage = StageParams()

# TODO: pass device, progress
image = run_upscale_correction(
server, stage, params, source_image, upscale=upscale)
job, server, stage, params, source_image, upscale=upscale)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
Expand Down
22 changes: 16 additions & 6 deletions api/onnx_web/serve.py
Expand Up @@ -41,6 +41,7 @@
ChainPipeline,
)
from .device_pool import (
DeviceParams,
DevicePoolExecutor,
)
from .diffusion.run import (
Expand Down Expand Up @@ -168,14 +169,23 @@ def url_from_rule(rule) -> str:
return url_for(rule.endpoint, **options)


def pipeline_from_request() -> Tuple[ImageParams, Size]:
def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr

# platform stuff
device_name = request.args.get('platform', available_platforms[0].device)
device = None

for platform in available_platforms:
if platform.device == device_name:
device = available_platforms[0]

if device is None:
raise Exception('unknown device')

# pipeline stuff
model = get_not_empty(request.args, 'model', get_config_value('model'))
model_path = get_model_path(model)
provider = get_from_map(request.args, 'platform',
platform_providers, get_config_value('platform'))
scheduler = get_from_map(request.args, 'scheduler',
pipeline_schedulers, get_config_value('scheduler'))

Expand Down Expand Up @@ -213,12 +223,12 @@ def pipeline_from_request() -> Tuple[ImageParams, Size]:
seed = np.random.randint(np.iinfo(np.int32).max)

logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt)
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)

params = ImageParams(model_path, provider, scheduler, prompt,
params = ImageParams(model_path, scheduler, prompt,
negative_prompt, cfg, steps, seed)
size = Size(width, height)
return (params, size)
return (device, params, size)


def border_from_request() -> Border:
Expand Down
8 changes: 6 additions & 2 deletions api/onnx_web/upscale.py
Expand Up @@ -7,6 +7,9 @@
upscale_resrgan,
ChainPipeline,
)
from .device_pool import (
JobContext,
)
from .params import (
ImageParams,
SizeChart,
Expand All @@ -21,7 +24,8 @@


def run_upscale_correction(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
image: Image.Image,
Expand Down Expand Up @@ -51,4 +55,4 @@ def run_upscale_correction(
outscale=1)
chain.append((correct_gfpgan, stage, None))

return chain(ctx, params, image, prompt=params.prompt, upscale=upscale)
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
1 change: 1 addition & 0 deletions common/pipelines/outpaint.json
Expand Up @@ -12,6 +12,7 @@
"type": "upscale-outpaint",
"params": {
"border": 256,
"model": "stable-diffusion-onnx-v1-inpainting",
"prompt": "a magical wizard in a robe fighting a dragon"
}
},
Expand Down

0 comments on commit 3637f64

Please sign in to comment.