Skip to content

Commit

Permalink
feat(api): start implementing chain pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 27, 2023
1 parent 260f7a2 commit 71ff3bb
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 96 deletions.
6 changes: 3 additions & 3 deletions api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
noise_source_uniform,
)
from .upscale import (
make_resrgan,
load_resrgan,
run_upscale_correction,
upscale_gfpgan,
upscale_resrgan,
Expand All @@ -30,8 +30,8 @@
get_from_list,
get_from_map,
get_not_empty,
safer_join,
BaseParams,
base_join,
ImageParams,
Border,
Point,
ServerContext,
Expand Down
95 changes: 95 additions & 0 deletions api/onnx_web/chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from PIL import Image
from os import path
from typing import Any, List, Optional, Protocol, Tuple

from .image import (
process_tiles,
)
from .utils import (
ImageParams,
ServerContext,
)


class StageParams:
'''
Parameters for a pipeline stage, assuming they can be chained.
'''

def __init__(
self,
tile_size: int = 512,
outscale: int = 1,
) -> None:
self.tile_size = tile_size
self.outscale = outscale


class StageCallback(Protocol):
def __call__(
self,
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
**kwargs: Any
) -> Image.Image:
pass


PipelineStage = Tuple[StageCallback, StageParams, Optional[Any]]


class ChainPipeline:
'''
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
'''

def __init__(
self,
stages: List[PipelineStage],
):
'''
Create a new pipeline that will run the given stages.
'''
self.stages = stages

def append(self, stage: PipelineStage):
'''
Append an additional stage to this pipeline.
'''
self.stages.append(stage)

def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image) -> Image.Image:
'''
TODO: handle List[Image] outputs
'''
print('running pipeline on source image with dimensions %sx%s' %
source.size)
image = source

for stage_fn, stage_params, stage_kwargs in self.stages:
print('running pipeline stage on result image with dimensions %sx%s' %
image.size)
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
print('source image larger than tile size, tiling stage',
stage_params.tile_size)

def stage_tile(tile: Image.Image) -> Image.Image:
tile = stage_fn(ctx, stage_params, tile,
params, **stage_kwargs)
tile.save(path.join(ctx.output_path, 'last-tile.png'))
return tile

image = process_tiles(
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
print('source image within tile size, run stage')
image = stage_fn(ctx, stage_params, image,
params, **stage_kwargs)

print('finished running pipeline stage, result size: %sx%s' % image.size)

print('finished running pipeline, result size: %sx%s' % image.size)
return image
4 changes: 2 additions & 2 deletions api/onnx_web/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from shutil import copyfile, rmtree
from sys import exit
from torch.onnx import export
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple

import torch
import warnings
Expand All @@ -18,7 +18,7 @@
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
warnings.filterwarnings('ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')

Models = Dict[str, List[Tuple[str, str, Union[int, None]]]]
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]

# recommended models
base_models: Models = {
Expand Down
40 changes: 21 additions & 19 deletions api/onnx_web/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
OnnxStableDiffusionInpaintPipeline,
)
from PIL import Image, ImageChops
from typing import Any, Union
from typing import Any, Optional

import gc
import numpy as np
Expand All @@ -21,10 +21,11 @@
)
from .utils import (
is_debug,
safer_join,
BaseParams,
base_join,
ImageParams,
Border,
ServerContext,
StageParams,
Size,
)

Expand All @@ -45,7 +46,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
return image_latents


def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None] = None):
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
Expand Down Expand Up @@ -95,7 +96,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu

def run_txt2img_pipeline(
ctx: ServerContext,
params: BaseParams,
params: ImageParams,
size: Size,
output: str,
upscale: UpscaleParams
Expand All @@ -117,9 +118,9 @@ def run_txt2img_pipeline(
num_inference_steps=params.steps,
)
image = result.images[0]
image = run_upscale_correction(ctx, upscale, image)
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)

dest = safer_join(ctx.output_path, output)
dest = base_join(ctx.output_path, output)
image.save(dest)

del image
Expand All @@ -130,7 +131,7 @@ def run_txt2img_pipeline(

def run_img2img_pipeline(
ctx: ServerContext,
params: BaseParams,
params: ImageParams,
output: str,
upscale: UpscaleParams,
source_image: Image,
Expand All @@ -151,9 +152,9 @@ def run_img2img_pipeline(
strength=strength,
)
image = result.images[0]
image = run_upscale_correction(ctx, upscale, image)
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)

dest = safer_join(ctx.output_path, output)
dest = base_join(ctx.output_path, output)
image.save(dest)

del image
Expand All @@ -164,7 +165,8 @@ def run_img2img_pipeline(

def run_inpaint_pipeline(
ctx: ServerContext,
params: BaseParams,
stage: StageParams,
params: ImageParams,
size: Size,
output: str,
upscale: UpscaleParams,
Expand Down Expand Up @@ -192,9 +194,9 @@ def run_inpaint_pipeline(
mask_filter=mask_filter)

if is_debug():
source_image.save(safer_join(ctx.output_path, 'last-source.png'))
mask_image.save(safer_join(ctx.output_path, 'last-mask.png'))
noise_image.save(safer_join(ctx.output_path, 'last-noise.png'))
source_image.save(base_join(ctx.output_path, 'last-source.png'))
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
noise_image.save(base_join(ctx.output_path, 'last-noise.png'))

result = pipe(
params.prompt,
Expand All @@ -215,9 +217,9 @@ def run_inpaint_pipeline(
else:
print('output image size does not match source, skipping post-blend')

image = run_upscale_correction(ctx, upscale, image)
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)

dest = safer_join(ctx.output_path, output)
dest = base_join(ctx.output_path, output)
image.save(dest)

del image
Expand All @@ -228,15 +230,15 @@ def run_inpaint_pipeline(

def run_upscale_pipeline(
ctx: ServerContext,
_params: BaseParams,
params: ImageParams,
_size: Size,
output: str,
upscale: UpscaleParams,
source_image: Image
):
image = run_upscale_correction(ctx, upscale, source_image)
image = run_upscale_correction(ctx, StageParams(), params, source_image, upscale=upscale)

dest = safer_join(ctx.output_path, output)
dest = base_join(ctx.output_path, output)
image.save(dest)

del image
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def process_tiles(
idx = (y * tiles_x) + x
left = x * tile
top = y * tile
print('processing tile %s of %s, %s.%s', idx, total, x, y)
print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
tile_image = source.crop((left, top, left + tile, top + tile))

for filter in filters:
Expand Down
18 changes: 12 additions & 6 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
get_from_map,
get_not_empty,
make_output_name,
safer_join,
BaseParams,
base_join,
ImageParams,
Border,
ServerContext,
Size,
Expand Down Expand Up @@ -124,7 +124,7 @@ def url_from_rule(rule) -> str:
return url_for(rule.endpoint, **options)


def pipeline_from_request() -> Tuple[BaseParams, Size]:
def pipeline_from_request() -> Tuple[ImageParams, Size]:
user = request.remote_addr

# pipeline stuff
Expand Down Expand Up @@ -171,7 +171,7 @@ def pipeline_from_request() -> Tuple[BaseParams, Size]:
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
(user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))

params = BaseParams(model_path, provider, scheduler, prompt,
params = ImageParams(model_path, provider, scheduler, prompt,
negative_prompt, cfg, steps, seed)
size = Size(width, height)
return (params, size)
Expand Down Expand Up @@ -288,7 +288,7 @@ def load_platforms():
# TODO: these two use context

def get_model_path(model: str):
return safer_join(context.model_path, model)
return base_join(context.model_path, model)


def ready_reply(ready: bool):
Expand Down Expand Up @@ -523,14 +523,20 @@ def upscale():
})


@app.route('/api/chain', methods=['POST'])
def chain():
print('TODO: run chain pipeline')
return jsonify({})


@app.route('/api/ready')
def ready():
output_file = request.args.get('output', None)

done = executor.futures.done(output_file)

if done is None:
file = safer_join(context.output_path, output_file)
file = base_join(context.output_path, output_file)
if path.exists(file):
return ready_reply(True)

Expand Down

0 comments on commit 71ff3bb

Please sign in to comment.