Skip to content

Commit

Permalink
feat(api): add strength param to inpaint, remove same from upscale
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 17, 2023
1 parent b496e71 commit 5ba752e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
14 changes: 8 additions & 6 deletions api/onnx_web/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None

# from https://www.travelneil.com/stable-diffusion-updates.html


def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
'''
From https://www.travelneil.com/stable-diffusion-updates.html
'''
# 1 is batch size
latents_shape = (1, 4, size.height // 8, size.width // 8)
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
Expand Down Expand Up @@ -147,7 +148,8 @@ def run_inpaint_pipeline(
mask_image: Image,
expand: Border,
noise_source: Any,
mask_filter: Any
mask_filter: Any,
strength: float,
):
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.provider, params.scheduler)
Expand Down Expand Up @@ -180,6 +182,7 @@ def run_inpaint_pipeline(
num_inference_steps=params.steps,
width=size.width,
).images[0]
image = ImageChops.blend(source_image, image, strength)

if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, upscale, image)
Expand All @@ -189,17 +192,16 @@ def run_inpaint_pipeline(

print('saved inpaint output: %s' % (dest))


def run_upscale_pipeline(
ctx: ServerContext,
_params: BaseParams,
_size: Size,
output: str,
upscale: UpscaleParams,
source_image: Image,
strength: float,
source_image: Image
):
image = upscale_resrgan(ctx, upscale, source_image)
image = ImageChops.blend(source_image, image, strength)

dest = safer_join(ctx.output_path, output)
image.save(dest)
Expand Down
26 changes: 13 additions & 13 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def img2img():
'img2img',
params,
size,
extras=(strength))
extras=(strength,))
print("img2img output: %s" % (output))

source_image.thumbnail((size.width, size.height))
Expand Down Expand Up @@ -393,6 +393,11 @@ def inpaint():
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
noise_source = get_from_map(
request.args, 'noise', noise_sources, 'histogram')
strength = get_and_clamp_float(
request.args,
'strength',
config_params.get('strength').get('default'),
config_params.get('strength').get('max'))

output = make_output_name(
'inpaint',
Expand All @@ -405,6 +410,7 @@ def inpaint():
expand.bottom,
mask_filter.__name__,
noise_source.__name__,
strength,
)
)
print("inpaint output: %s" % output)
Expand All @@ -423,7 +429,8 @@ def inpaint():
mask_image,
expand,
noise_source,
mask_filter)
mask_filter,
strength)

return jsonify({
'output': output,
Expand All @@ -440,22 +447,15 @@ def upscale():
params, size = pipeline_from_request()
upscale = upscale_from_request()

strength = get_and_clamp_float(
request.args,
'strength',
config_params.get('strength').get('default'),
config_params.get('strength').get('max'))

output = make_output_name(
'img2img',
'upscale',
params,
size,
extras=(strength))
print("img2img output: %s" % (output))
size)
print("upscale output: %s" % (output))

source_image.thumbnail((size.width, size.height))
executor.submit_stored(output, run_upscale_pipeline,
context, params, output, upscale, source_image, strength)
context, params, output, upscale, source_image)

return jsonify({
'output': output,
Expand Down

0 comments on commit 5ba752e

Please sign in to comment.