Skip to content

Commit

Permalink
fix(api): make request parsing consistent between JSON and forms
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 13, 2023
1 parent 8a5e211 commit a33c88e
Show file tree
Hide file tree
Showing 17 changed files with 295 additions and 104 deletions.
8 changes: 4 additions & 4 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ def stage(self, callback: BaseStage, params: StageParams, **kwargs):

def steps(self, params: ImageParams, size: Size):
steps = 0
for callback, _params, _kwargs in self.stages:
steps += callback.steps(params, size)
for callback, _params, kwargs in self.stages:
steps += callback.steps(kwargs.get("params", params), size)

return steps

def outputs(self, params: ImageParams, sources: int):
outputs = sources
for callback, _params, _kwargs in self.stages:
outputs += callback.outputs(params, outputs)
for callback, _params, kwargs in self.stages:
outputs += callback.outputs(kwargs.get("params", params), outputs)

return outputs

Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/chain/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def run(
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
raise NotImplementedError()
raise NotImplementedError() # noqa

def steps(
self,
params: ImageParams,
size: Size,
) -> int:
return 1
return 1 # noqa

def outputs(
self,
Expand Down
70 changes: 37 additions & 33 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from ..diffusers.utils import replace_wildcards
from ..output import json_params, make_output_name
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
from ..params import Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline
from ..utils import (
base_join,
Expand Down Expand Up @@ -50,10 +50,11 @@
get_wildcard_data,
)
from .params import (
border_from_request,
highres_from_request,
build_border,
build_highres,
build_upscale,
pipeline_from_json,
pipeline_from_request,
upscale_from_request,
)
from .utils import wrap_route

Expand Down Expand Up @@ -168,8 +169,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
size = Size(source.width, source.height)

device, params, _size = pipeline_from_request(server, "img2img")
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()
source_filter = get_from_list(
request.args, "sourceFilter", list(get_source_filters().keys())
)
Expand Down Expand Up @@ -217,8 +218,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):

def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "txt2img")
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()

replace_wildcards(params, get_wildcard_data())

Expand Down Expand Up @@ -271,9 +272,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
)

device, params, _size = pipeline_from_request(server, "inpaint")
expand = border_from_request()
upscale = upscale_from_request()
highres = highres_from_request()
expand = build_border()
upscale = build_upscale()
highres = build_highres()

fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
Expand Down Expand Up @@ -341,8 +342,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB")

device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()

replace_wildcards(params, get_wildcard_data())

Expand All @@ -367,6 +368,10 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))


# keys that are specially parsed by params and should not show up in with_args
CHAIN_POP_KEYS = ["model", "control"]


def chain(server: ServerContext, pool: DevicePoolExecutor):
if request.is_json:
logger.debug("chain pipeline request with JSON body")
Expand All @@ -386,9 +391,8 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema)

# get defaults from the regular parameters
device, base_params, base_size = pipeline_from_request(
server, data=data.get("defaults", None)
device, base_params, base_size = pipeline_from_json(
server, data=data.get("defaults")
)

# start building the pipeline
Expand All @@ -399,32 +403,32 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)

# TODO: combine base params with stage params
_device, params, size = pipeline_from_request(server, data=kwargs)
_device, params, size = pipeline_from_json(server, data=kwargs)
replace_wildcards(params, get_wildcard_data())

if "model" in kwargs:
kwargs.pop("model")

if "control" in kwargs:
logger.warning("TODO: resolve controlnet model")
kwargs.pop("control")
# remove parsed keys, like model names (which become paths)
for pop_key in CHAIN_POP_KEYS:
if pop_key in kwargs:
kwargs.pop(pop_key)

# replace kwargs with parsed versions
kwargs["params"] = params
kwargs["size"] = size

border = build_border(kwargs)
kwargs["border"] = border

upscale = build_upscale(kwargs)
kwargs["upscale"] = upscale

# prepare the stage metadata
stage = StageParams(
stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")),
tile_size=get_size(kwargs.get("tiles")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)

if "border" in kwargs:
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border

if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale

# load any images related to this stage
stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name)

Expand Down Expand Up @@ -494,7 +498,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
sources.append(source)

device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
upscale = build_upscale()

output = make_output_name(server, "upscale", params, size)
job_name = output[0]
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/server/model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def set(self, tag: str, key: Any, value: Any) -> None:
return

for i in range(len(cache)):
t, k, v = cache[i]
t, k, _v = cache[i]
if tag == t and key != k:
logger.debug("updating model cache: %s %s", tag, key)
cache[i] = (tag, key, value)
Expand Down

0 comments on commit a33c88e

Please sign in to comment.