Skip to content

Commit

Permalink
fix(api): validate request params better, esp model path
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 22, 2023
1 parent ce11165 commit 876b54a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 34 deletions.
71 changes: 37 additions & 34 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
get_and_clamp_int,
get_from_list,
get_from_map,
get_not_empty,
make_output_name,
safer_join,
BaseParams,
Expand Down Expand Up @@ -112,6 +113,10 @@
upscaling_models = []


def get_config_value(key: str, subkey: str = 'default'):
return config_params.get(key).get(subkey)


def url_from_rule(rule) -> str:
options = {}
for arg in rule.arguments:
Expand All @@ -124,64 +129,60 @@ def pipeline_from_request() -> Tuple[BaseParams, Size]:
user = request.remote_addr

# pipeline stuff
model = get_model_path(request.args.get(
'model', config_params.get('model').get('default')))
model = get_not_empty(request.args, 'model', get_config_value('model'))
model_path = get_model_path(model)
provider = get_from_map(request.args, 'platform',
platform_providers, config_params.get('platform').get('default'))
platform_providers, get_config_value('platform'))
scheduler = get_from_map(request.args, 'scheduler',
pipeline_schedulers, config_params.get('scheduler').get('default'))
pipeline_schedulers, get_config_value('scheduler'))

# image params
prompt = request.args.get(
'prompt', config_params.get('prompt').get('default'))
prompt = get_not_empty(request.args,
'prompt', get_config_value('prompt'))
negative_prompt = request.args.get('negativePrompt', None)

if negative_prompt is not None and negative_prompt.strip() == '':
negative_prompt = None

cfg = get_and_clamp_float(
request.args, 'cfg',
config_params.get('cfg').get('default'),
config_params.get('cfg').get('max'),
config_params.get('cfg').get('min'))
get_config_value('cfg'),
get_config_value('cfg', 'max'),
get_config_value('cfg', 'min'))
steps = get_and_clamp_int(
request.args, 'steps',
config_params.get('steps').get('default'),
config_params.get('steps').get('max'),
config_params.get('steps').get('min'))
get_config_value('steps'),
get_config_value('steps', 'max'),
get_config_value('steps', 'min'))
height = get_and_clamp_int(
request.args, 'height',
config_params.get('height').get('default'),
config_params.get('height').get('max'),
config_params.get('height').get('min'))
get_config_value('height'),
get_config_value('height', 'max'),
get_config_value('height', 'min'))
width = get_and_clamp_int(
request.args, 'width',
config_params.get('width').get('default'),
config_params.get('width').get('max'),
config_params.get('width').get('min'))
get_config_value('width'),
get_config_value('width', 'max'),
get_config_value('width', 'min'))

seed = int(request.args.get('seed', -1))
if seed == -1:
seed = np.random.randint(np.iinfo(np.int32).max)

print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
(user, steps, scheduler.__name__, model, provider, width, height, cfg, seed, prompt))
(user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))

params = BaseParams(model, provider, scheduler, prompt,
params = BaseParams(model_path, provider, scheduler, prompt,
negative_prompt, cfg, steps, seed)
size = Size(width, height)
return (params, size)


def border_from_request() -> Border:
left = get_and_clamp_int(request.args, 'left', 0,
config_params.get('width').get('max'), 0)
right = get_and_clamp_int(request.args, 'right',
0, config_params.get('width').get('max'), 0)
top = get_and_clamp_int(request.args, 'top', 0,
config_params.get('height').get('max'), 0)
bottom = get_and_clamp_int(
request.args, 'bottom', 0, config_params.get('height').get('max'), 0)
left = get_and_clamp_int(request.args, 'left', 0, get_config_value('width', 'max'), 0)
right = get_and_clamp_int(request.args, 'right', 0, get_config_value('width', 'max'), 0)
top = get_and_clamp_int(request.args, 'top', 0, get_config_value('height', 'max'), 0)
bottom = get_and_clamp_int(request.args, 'bottom', 0, get_config_value('height', 'max'), 0)

return Border(left, right, top, bottom)

Expand All @@ -192,7 +193,7 @@ def upscale_from_request() -> UpscaleParams:
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
correction = get_from_list(request.args, 'correction', correction_models)
faces = request.args.get('faces', 'false') == 'true'
faces = get_not_empty(request.args, 'faces', 'false') == 'true'
face_strength = get_and_clamp_float(
request.args, 'faceStrength', 0.5, 1.0, 0.0)

Expand Down Expand Up @@ -359,8 +360,9 @@ def img2img():
strength = get_and_clamp_float(
request.args,
'strength',
config_params.get('strength').get('default'),
config_params.get('strength').get('max'))
get_config_value('strength'),
get_config_value('strength', 'max'),
get_config_value('strength', 'min'))

output = make_output_name(
'img2img',
Expand Down Expand Up @@ -413,15 +415,16 @@ def inpaint():
expand = border_from_request()
upscale = upscale_from_request()

fill_color = request.args.get('fillColor', 'white')
fill_color = get_not_empty(request.args, 'fillColor', 'white')
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
noise_source = get_from_map(
request.args, 'noise', noise_sources, 'histogram')
strength = get_and_clamp_float(
request.args,
'strength',
config_params.get('strength').get('default'),
config_params.get('strength').get('max'))
get_config_value('strength'),
get_config_value('strength', 'max'),
get_config_value('strength', 'min'))

output = make_output_name(
'inpaint',
Expand Down
9 changes: 9 additions & 0 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any):
return values[default]


def get_not_empty(args: Any, key: str, default: Any):
val = args.get(key, default)

if val is None or len(val) == 0:
val = default

return val


def hash_value(sha, param: Param):
if param is None:
return
Expand Down

0 comments on commit 876b54a

Please sign in to comment.