Skip to content

Commit

Permalink
feat(gui): add highres parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 1, 2023
1 parent f462d80 commit ba09748
Show file tree
Hide file tree
Showing 12 changed files with 199 additions and 34 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
run_upscale_pipeline,
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import run_upscale_correction
from .image import (
expand_image,
mask_filter_gaussian_multiply,
Expand Down Expand Up @@ -48,7 +49,6 @@
apply_patch_facexlib,
apply_patches,
)
from .upscale import run_upscale_correction
from .utils import (
base_join,
get_and_clamp_float,
Expand Down
35 changes: 15 additions & 20 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from ..chain import blend_mask, upscale_outpaint
from ..chain.base import ChainProgress
from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams, TileOrder, UpscaleParams
from ..params import Border, HighresParams, ImageParams, Size, StageParams, TileOrder, UpscaleParams
from ..server import ServerContext
from ..upscale import run_upscale_correction
from ..utils import run_gc
from ..worker import WorkerContext
from .load import get_latents_from_seed, load_pipeline
from .upscale import run_upscale_correction
from .utils import get_inversions_from_prompt, get_loras_from_prompt

logger = getLogger(__name__)
Expand All @@ -29,13 +29,8 @@ def run_txt2img_pipeline(
size: Size,
outputs: List[str],
upscale: UpscaleParams,
highres: HighresParams,
) -> None:
# TODO: add to params
highres_scale = 4
highres_steps = 25
highres_strength = 0.2
highres_steps_post = int((params.steps - highres_steps) / highres_strength)

latents = get_latents_from_seed(params.seed, size, batch=params.batch)

(prompt, loras) = get_loras_from_prompt(params.prompt)
Expand Down Expand Up @@ -66,7 +61,7 @@ def run_txt2img_pipeline(
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=highres_steps,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)
Expand All @@ -81,13 +76,13 @@ def run_txt2img_pipeline(
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=highres_steps,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)

for image, output in zip(result.images, outputs):
if highres_scale > 1:
if highres.scale > 1:
highres_progress = ChainProgress.from_progress(progress)

image = run_upscale_correction(
Expand Down Expand Up @@ -115,7 +110,7 @@ def run_txt2img_pipeline(
def highres(tile: Image.Image, dims):
tile = tile.resize((size.height, size.width))
if params.lpw:
logger.debug("using LPW pipeline for img2img")
logger.debug("using LPW pipeline for highres")
rng = torch.manual_seed(params.seed)
result = highres_pipe.img2img(
tile,
Expand All @@ -124,8 +119,8 @@ def highres(tile: Image.Image, dims):
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres_steps_post,
strength=highres_strength,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=highres_progress,
)
Expand All @@ -139,19 +134,19 @@ def highres(tile: Image.Image, dims):
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres_steps_post,
strength=highres_strength,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=highres_progress,
)
return result.images[0]

logger.info("running highres fix for %s tiles", highres_scale)
logger.info("running highres fix for %s tiles", highres.scale)
image = process_tile_order(
TileOrder.grid,
image,
size.height // highres_scale,
highres_scale,
size.height // highres.scale,
highres.scale,
[highres],
)

Expand All @@ -166,7 +161,7 @@ def highres(tile: Image.Image, dims):
)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
save_params(server, output, params, size, upscale=upscale, highres=highres)

run_gc([job.get_device()])

Expand Down
8 changes: 4 additions & 4 deletions api/onnx_web/upscale.py → api/onnx_web/diffusers/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@

from PIL import Image

from .chain import (
from ..chain import (
ChainPipeline,
correct_codeformer,
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
)
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .server import ServerContext
from .worker import ProgressCallback, WorkerContext
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext

logger = getLogger(__name__)

Expand Down
10 changes: 8 additions & 2 deletions api/onnx_web/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from PIL import Image

from .params import Border, ImageParams, Param, Size, UpscaleParams
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
from .server import ServerContext
from .utils import base_join

Expand Down Expand Up @@ -36,6 +36,7 @@ def json_params(
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
) -> Any:
json = {
"outputs": outputs,
Expand All @@ -49,6 +50,10 @@ def json_params(
json["border"] = border.tojson()
size = size.add_border(border)

if highres is not None:
json["highres"] = highres.tojson()
size = highres.resize(size)

if upscale is not None:
json["upscale"] = upscale.tojson()
size = upscale.resize(size)
Expand Down Expand Up @@ -106,9 +111,10 @@ def save_params(
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
) -> str:
path = base_join(ctx.output_path, f"{output}.json")
json = json_params(output, params, size, upscale=upscale, border=border)
json = json_params(output, params, size, upscale=upscale, border=border, highres=highres)
with open(path, "w") as f:
f.write(dumps(json))
logger.debug("saved image params to: %s", path)
Expand Down
22 changes: 22 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,25 @@ def with_args(self, **kwargs):
kwargs.get("tile_pad", self.tile_pad),
kwargs.get("upscale_order", self.upscale_order),
)


class HighresParams:
def __init__(
self,
scale: int,
steps: int,
strength: float,
):
self.scale = scale
self.steps = steps
self.strength = strength

def resize(self, size: Size) -> Size:
return Size(size.width * self.scale, size.height * self.scale)

def tojson(self):
return {
"scale": self.scale,
"steps": self.steps,
"strength": self.strength,
}
6 changes: 4 additions & 2 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
get_noise_sources,
get_upscaling_models,
)
from .params import border_from_request, pipeline_from_request, upscale_from_request
from .params import border_from_request, highres_from_request, pipeline_from_request, upscale_from_request
from .utils import wrap_route

logger = getLogger(__name__)
Expand Down Expand Up @@ -174,6 +174,7 @@ def img2img(context: ServerContext, pool: DevicePoolExecutor):
def txt2img(context: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(context)
upscale = upscale_from_request()
highres = highres_from_request()

output = make_output_name(context, "txt2img", params, size)
job_name = output[0]
Expand All @@ -187,10 +188,11 @@ def txt2img(context: ServerContext, pool: DevicePoolExecutor):
size,
output,
upscale,
highres,
needs_device=device,
)

return jsonify(json_params(output, params, size, upscale=upscale))
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))


def inpaint(context: ServerContext, pool: DevicePoolExecutor):
Expand Down
14 changes: 13 additions & 1 deletion api/onnx_web/server/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from flask import request

from ..diffusers.load import pipeline_schedulers
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
from ..params import Border, DeviceParams, ImageParams, HighresParams, Size, UpscaleParams
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
from .context import ServerContext
from .load import (
Expand Down Expand Up @@ -169,3 +169,15 @@ def upscale_from_request() -> UpscaleParams:
scale=scale,
upscale_order=upscale_order,
)


def highres_from_request() -> HighresParams:
scale = get_and_clamp_int(request.args, "highresScale", 1, 4, 1)
steps = get_and_clamp_int(request.args, "highresSteps", 1, 4, 1)
strength = get_and_clamp_float(request.args, "highresStrength", 0.5, 1.0, 0.0)

return HighresParams(
scale,
steps,
strength,
)
19 changes: 17 additions & 2 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ export interface BlendParams {
mask: Blob;
}

export interface HighresParams {
enabled: boolean;

highresScale: number;
highresSteps: number;
highresStrength: number;
}

/**
* Output image data within the response.
*/
Expand Down Expand Up @@ -201,6 +209,7 @@ export type RetryParams = {
model: ModelParams;
params: Txt2ImgParams;
upscale?: UpscaleParams;
highres?: HighresParams;
} | {
type: 'img2img';
model: ModelParams;
Expand Down Expand Up @@ -274,7 +283,7 @@ export interface ApiClient {
/**
* Start a txt2img pipeline.
*/
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;

/**
* Start an im2img pipeline.
Expand Down Expand Up @@ -477,7 +486,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
},
};
},
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model);

Expand All @@ -493,6 +502,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
appendUpscaleToURL(url, upscale);
}

if (doesExist(highres) && highres.enabled) {
url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER));
url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER));
url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT));
}

const image = await parseRequest(url, {
method: 'POST',
});
Expand Down
73 changes: 73 additions & 0 deletions gui/src/components/control/HighresControl.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import { mustExist } from '@apextoaster/js-utils';
import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';

import { ConfigContext, StateContext } from '../../state.js';
import { NumericField } from '../input/NumericField.js';

export function UpscaleControl() {
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
const highres = useStore(state, (s) => s.highres);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHighres = useStore(state, (s) => s.setHighres);
const { t } = useTranslation();

return <Stack direction='row' spacing={4}>
<FormControlLabel
label={t('parameter.highres.label')}
control={<Checkbox
checked={highres.enabled}
value='check'
onChange={(event) => {
setHighres({
enabled: highres.enabled === false,
});
}}
/>}
/>
<NumericField
label={t('parameter.highres.steps')}
decimal
disabled={highres.enabled === false}
min={params.denoise.min}
max={params.denoise.max}
step={params.denoise.step}
value={highres.highresSteps}
onChange={(steps) => {
setHighres({
highresSteps: steps,
});
}}
/>
<NumericField
label={t('parameter.highres.scale')}
disabled={highres.enabled === false}
min={params.scale.min}
max={params.scale.max}
step={params.scale.step}
value={highres.highresScale}
onChange={(scale) => {
setHighres({
highresScale: scale,
});
}}
/>
<NumericField
label={t('parameter.highres.strength')}
disabled={highres.enabled === false}
min={params.strength.min}
max={params.strength.max}
step={params.outscale.step}
value={highres.highresStrength}
onChange={(strength) => {
setHighres({
highresStrength: strength,
});
}}
/>
</Stack>;
}

0 comments on commit ba09748

Please sign in to comment.