Skip to content

Commit

Permalink
feat(api): add save-to-disk stage
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 28, 2023
1 parent ce6cf08 commit 779457b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 9 deletions.
3 changes: 3 additions & 0 deletions api/onnx_web/chain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from .generate_txt2img import (
generate_txt2img,
)
from .persist_disk import (
persist_disk,
)
from .upscale_outpaint import (
upscale_outpaint,
)
Expand Down
23 changes: 23 additions & 0 deletions api/onnx_web/chain/persist_disk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from PIL import Image


from ..params import (
ImageParams,
StageParams,
)
from ..utils import (
ServerContext,
)


def persist_disk(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
output: str,
) -> Image.Image:
source_image.save(output)
print('saved image to %s' % (output,))
return source_image
6 changes: 3 additions & 3 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def upscale_outpaint(
fill_color: str = 'white',
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
) -> Image:
) -> Image.Image:
print('upscaling image by expanding borders', expand)

if mask_image is None:
Expand All @@ -51,21 +51,21 @@ def upscale_outpaint(
draw.rectangle((expand.left, expand.top, expand.left +
source_image.width, expand.top + source_image.height), fill='black')

source_image, mask_image, noise_image, full_dims = expand_image(
source_image, mask_image, noise_image, _full_dims = expand_image(
source_image,
mask_image,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter)
size = Size(*full_dims)

if is_debug():
source_image.save(base_join(ctx.output_path, 'last-source.png'))
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
noise_image.save(base_join(ctx.output_path, 'last-noise.png'))

def outpaint(image: Image.Image):
size = Size(*image.size)
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.provider, params.scheduler)

Expand Down
18 changes: 12 additions & 6 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .chain import (
correct_gfpgan,
generate_txt2img,
persist_disk,
upscale_outpaint,
upscale_resrgan,
upscale_stable_diffusion,
Expand Down Expand Up @@ -540,25 +541,30 @@ def upscale():

@app.route('/api/chain', methods=['POST'])
def chain():
print('TODO: run chain pipeline')

params, size = pipeline_from_request()
output = make_output_name('chain', params, size)

# parse body as json, list of stages
example = ChainPipeline(stages=[
(generate_txt2img, StageParams(), {
'size': size,
}),
(upscale_outpaint, StageParams(outscale=4), {
'expand': Border(256, 256, 256, 256),
}),
(persist_disk, StageParams(), {
'output': output,
})
])

output = make_output_name('chain', params, size)
# build and run chain pipeline
executor.submit_stored(output, example, context, params, Image.new('RGB', (1, 1)))

# parse body as json, list of stages
# build and run chain pipeline
return jsonify({})
return jsonify({
'output': output,
'params': params.tojson(),
'size': upscale.resize(size).tojson(),
})


@app.route('/api/ready')
Expand Down

0 comments on commit 779457b

Please sign in to comment.