Skip to content

Commit

Permalink
fix(api): get default params from file, enforce minimum params
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 14, 2023
1 parent c09eb75 commit e8b580a
Showing 1 changed file with 30 additions and 32 deletions.
62 changes: 30 additions & 32 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# types
DiffusionPipeline,
)
from flask import Flask, jsonify, request, send_file, send_from_directory, url_for
from flask import Flask, jsonify, request, send_from_directory, url_for
from flask_cors import CORS
from flask_executor import Executor
from hashlib import sha256
Expand All @@ -32,22 +32,6 @@
import json
import numpy as np

# defaults
default_model = 'stable-diffusion-onnx-v1-5'
default_platform = 'amd'
default_scheduler = 'euler-a'
default_prompt = "a photo of an astronaut eating a hamburger"
default_cfg = 8
default_steps = 20
default_height = 512
default_width = 512
default_strength = 0.5

max_cfg = 30
max_steps = 150
max_height = 512
max_width = 512

# paths
bundle_path = environ.get('ONNX_WEB_BUNDLE_PATH',
path.join('..', 'gui', 'out'))
Expand Down Expand Up @@ -190,27 +174,41 @@ def pipeline_from_request():
user = request.remote_addr

# pipeline stuff
model = get_model_path(request.args.get('model', default_model))
model = get_model_path(request.args.get(
'model', config_params.get('model').get('default')))
provider = get_from_map(request.args, 'platform',
platform_providers, default_platform)
platform_providers, config_params.get('provider').get('default'))
scheduler = get_from_map(request.args, 'scheduler',
pipeline_schedulers, default_scheduler)
pipeline_schedulers, config_params.get('scheduler').get('default'))

# image params
prompt = request.args.get('prompt', default_prompt)
prompt = request.args.get(
'prompt', config_params.get('prompt').get('default'))
negative_prompt = request.args.get('negativePrompt', None)

if negative_prompt is not None and negative_prompt.strip() == '':
negative_prompt = None

cfg = get_and_clamp_float(
request.args, 'cfg', default_cfg, config_params.get('cfg').get('max'), 0)
request.args, 'cfg',
config_params.get('cfg').get('default'),
config_params.get('cfg').get('max'),
config_params.get('cfg').get('min'))
steps = get_and_clamp_int(
request.args, 'steps', default_steps, config_params.get('steps').get('max'))
request.args, 'steps',
config_params.get('steps').get('default'),
config_params.get('steps').get('max'),
config_params.get('steps').get('min'))
height = get_and_clamp_int(
request.args, 'height', default_height, config_params.get('height').get('max'))
request.args, 'height',
config_params.get('height').get('default'),
config_params.get('height').get('max'),
config_params.get('height').get('min'))
width = get_and_clamp_int(
request.args, 'width', default_width, config_params.get('width').get('max'))
request.args, 'width',
config_params.get('width').get('default'),
config_params.get('width').get('max'),
config_params.get('width').get('min'))

seed = int(request.args.get('seed', -1))
if seed == -1:
Expand Down Expand Up @@ -369,7 +367,6 @@ def list_schedulers():
def img2img():
input_file = request.files.get('source')
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
input_image.thumbnail((default_width, default_height))

strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0)

Expand All @@ -380,6 +377,7 @@ def img2img():
(prompt, cfg, negative_prompt, steps, strength, height, width))
print("img2img output: %s" % (output_full))

input_image.thumbnail((width, height))
executor.submit_stored(output_file, run_img2img_pipeline, model, provider,
scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, strength, input_image)

Expand All @@ -394,8 +392,8 @@ def img2img():
'cfg': cfg,
'negativePrompt': negative_prompt,
'steps': steps,
'height': default_height,
'width': default_width,
'height': height,
'width': width,
}
})

Expand Down Expand Up @@ -433,11 +431,9 @@ def txt2img():
def inpaint():
source_file = request.files.get('source')
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
source_image.thumbnail((default_width, default_height))

mask_file = request.files.get('mask')
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
mask_image.thumbnail((default_width, default_height))

(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
width, seed) = pipeline_from_request()
Expand All @@ -446,6 +442,8 @@ def inpaint():
'inpaint', seed, (prompt, cfg, steps, height, width, seed))
print("inpaint output: %s" % output_full)

source_image.thumbnail((width, height))
mask_image.thumbnail((width, height))
executor.submit_stored(output_file, run_inpaint_pipeline, model, provider, scheduler, prompt, negative_prompt,
cfg, steps, seed, output_full, height, width, source_image, mask_image)

Expand All @@ -460,8 +458,8 @@ def inpaint():
'cfg': cfg,
'negativePrompt': negative_prompt,
'steps': steps,
'height': default_height,
'width': default_width,
'height': height,
'width': width,
}
})

Expand Down

0 comments on commit e8b580a

Please sign in to comment.