Skip to content

Commit

Permalink
feat: make enabling highres a parameter of its own
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 3, 2023
1 parent 3edf5e6 commit 99c91a3
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 54 deletions.
53 changes: 29 additions & 24 deletions api/onnx_web/chain/highres.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,36 +22,41 @@ def stage_highres(
if chain is None:
chain = ChainPipeline()

if not highres.enabled:
logger.debug("highres not enabled, skipping")
return chain

if highres.iterations < 1:
logger.debug("no highres iterations, skipping")
return chain

if highres.method == "upscale":
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
stage,
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
)
else:
logger.debug("using simple upscaling for highres")
for _i in range(highres.iterations):
if highres.method == "upscale":
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
stage,
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
)
else:
logger.debug("using simple upscaling for highres")
chain.stage(
UpscaleSimpleStage(),
stage,
method=highres.method,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
)

chain.stage(
UpscaleSimpleStage(),
BlendImg2ImgStage(),
stage,
method=highres.method,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
overlap=params.overlap,
strength=highres.strength,
)

chain.stage(
BlendImg2ImgStage(),
stage,
overlap=params.overlap,
strength=highres.strength,
)

return chain
66 changes: 36 additions & 30 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,13 @@ def run_txt2img_pipeline(
)

# apply highres
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)

# apply upscaling and correction, after highres
stage_upscale_correction(
Expand Down Expand Up @@ -152,14 +151,13 @@ def run_img2img_pipeline(
)

# highres, if selected
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)

# apply upscaling and correction, after highres
stage_upscale_correction(
Expand Down Expand Up @@ -234,21 +232,30 @@ def run_inpaint_pipeline(
noise_source=noise_source,
)

# apply highres
for _i in range(highres.iterations):
stage_highres(
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
stage,
params,
highres,
upscale,
upscale=first_upscale,
chain=chain,
)

# apply highres
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)

# apply upscaling and correction
stage_upscale_correction(
stage,
params,
upscale=upscale,
upscale=after_upscale,
chain=chain,
)

Expand Down Expand Up @@ -303,14 +310,13 @@ def run_upscale_pipeline(
)

# apply highres
for _i in range(highres.iterations):
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)
stage_highres(
stage,
params,
highres,
upscale,
chain=chain,
)

# apply upscaling and correction, after highres
stage_upscale_correction(
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,14 @@ def with_args(self, **kwargs):
class HighresParams:
def __init__(
self,
enabled: bool,
scale: int,
steps: int,
strength: float,
method: Literal["bilinear", "lanczos", "upscale"] = "lanczos",
iterations: int = 1,
):
self.enabled = enabled
self.scale = scale
self.steps = steps
self.strength = strength
Expand All @@ -441,6 +443,7 @@ def resize(self, size: Size) -> Size:

def tojson(self):
return {
"enabled": self.enabled,
"iterations": self.iterations,
"method": self.method,
"scale": self.scale,
Expand Down
2 changes: 2 additions & 0 deletions api/onnx_web/server/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def upscale_from_request() -> UpscaleParams:


def highres_from_request() -> HighresParams:
enabled = get_boolean(request.args, "highres", get_config_value("highres"))
iterations = get_and_clamp_int(
request.args,
"highresIterations",
Expand Down Expand Up @@ -313,6 +314,7 @@ def highres_from_request() -> HighresParams:
get_config_value("highresStrength", "min"),
)
return HighresParams(
enabled,
scale,
steps,
strength,
Expand Down
3 changes: 3 additions & 0 deletions api/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@
"max": 8192,
"step": 8
},
"highres": {
"default": false
},
"highresIterations": {
"default": 1,
"min": 1,
Expand Down
1 change: 1 addition & 0 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {

export function appendHighresToURL(url: URL, highres: HighresParams) {
if (highres.enabled) {
url.searchParams.append('highres', String(highres.enabled));
url.searchParams.append('highresIterations', highres.highresIterations.toFixed(FIXED_INTEGER));
url.searchParams.append('highresMethod', highres.highresMethod);
url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER));
Expand Down
3 changes: 3 additions & 0 deletions gui/src/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
"max": 8192,
"step": 8
},
"highres": {
"default": false
},
"highresMethod": {
"default": "lanczos",
"keys": [
Expand Down

0 comments on commit 99c91a3

Please sign in to comment.