Skip to content

Commit

Permalink
feat: add method parameter for highres mode
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 1, 2023
1 parent 0f79f42 commit f451d8d
Show file tree
Hide file tree
Showing 15 changed files with 123 additions and 6 deletions.
32 changes: 29 additions & 3 deletions api/onnx_web/diffusers/run.py
Expand Up @@ -89,7 +89,11 @@ def run_txt2img_pipeline(
callback=progress,
)

for image, output in zip(result.images, outputs):
image_outputs = list(zip(result.images, outputs))
del result
del pipe

for image, output in image_outputs:
if highres.scale > 1:
highres_progress = ChainProgress.from_progress(progress)

Expand All @@ -99,7 +103,10 @@ def run_txt2img_pipeline(
StageParams(),
params,
image,
upscale=upscale,
upscale=upscale.with_args(
scale=1,
outscale=1,
),
callback=highres_progress,
)

Expand All @@ -116,7 +123,26 @@ def run_txt2img_pipeline(
)

def highres_tile(tile: Image.Image, dims):
tile = tile.resize((size.height, size.width))
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
tile = tile.resize((size.height, size.width), resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
tile = tile.resize((size.height, size.width), resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
tile = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale.with_args(
faces=False,
),
callback=highres_progress,
)

if params.lpw:
logger.debug("using LPW pipeline for highres")
rng = torch.manual_seed(params.seed)
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/params.py
Expand Up @@ -325,16 +325,19 @@ def __init__(
scale: int,
steps: int,
strength: float,
method: Literal["bilinear", "lanczos", "upscale"] = "lanczos",
):
self.scale = scale
self.steps = steps
self.strength = strength
self.method = method

def resize(self, size: Size) -> Size:
return Size(size.width * self.scale, size.height * self.scale)

def tojson(self):
return {
"method": self.method,
"scale": self.scale,
"steps": self.steps,
"strength": self.strength,
Expand Down
9 changes: 9 additions & 0 deletions api/onnx_web/server/load.py
Expand Up @@ -51,6 +51,11 @@
"gaussian-multiply": mask_filter_gaussian_multiply,
"gaussian-screen": mask_filter_gaussian_screen,
}
highres_methods = {
"bilinear": highres_method_bilinear,
"lanczos": highres_method_lanczos,
"upscale": highres_method_upscale,
}


# Available ORT providers
Expand Down Expand Up @@ -94,6 +99,10 @@ def get_extra_strings():
return extra_strings


def get_highres_methods():
return highres_methods


def get_mask_filters():
return mask_filters

Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/server/params.py
Expand Up @@ -19,6 +19,7 @@
get_available_platforms,
get_config_value,
get_correction_models,
get_highres_methods,
get_upscaling_models,
)
from .utils import get_model_path
Expand Down Expand Up @@ -179,6 +180,7 @@ def upscale_from_request() -> UpscaleParams:


def highres_from_request() -> HighresParams:
method = get_from_list(request.args, "highresMethod", get_highres_methods())
scale = get_and_clamp_int(request.args, "highresScale", 1, 4, 1)
steps = get_and_clamp_int(request.args, "highresSteps", 1, 200, 1)
strength = get_and_clamp_float(request.args, "highresStrength", 0.5, 1.0, 0.0)
Expand All @@ -187,4 +189,5 @@ def highres_from_request() -> HighresParams:
scale,
steps,
strength,
method=method,
)
1 change: 1 addition & 0 deletions api/onnx_web/worker/worker.py
Expand Up @@ -22,6 +22,7 @@
"hipErrorOutOfMemory",
"MIOPEN failure 7",
"out of memory",
"rocblas_status_memory_error",
]


Expand Down
10 changes: 9 additions & 1 deletion api/params.json
Expand Up @@ -60,6 +60,14 @@
"max": 1024,
"step": 8
},
"highresMethods": {
"default": "lanczos",
"keys": [
"bilinear",
"lanczos",
"upscale"
]
},
"highresScale": {
"default": 1,
"min": 1,
Expand All @@ -76,7 +84,7 @@
"default": 0.5,
"min": 0,
"max": 1,
"step": 0.1
"step": 0.01
},
"inversion": {
"default": "",
Expand Down
28 changes: 27 additions & 1 deletion gui/examples/config.json
Expand Up @@ -64,6 +64,32 @@
"max": 1024,
"step": 8
},
"highresMethods": {
"default": "lanczos",
"keys": [
"bilinear",
"lanczos",
"upscale"
]
},
"highresScale": {
"default": 1,
"min": 1,
"max": 4,
"step": 1
},
"highresSteps": {
"default": 0,
"min": 1,
"max": 200,
"step": 1
},
"highresStrength": {
"default": 0.5,
"min": 0,
"max": 1,
"step": 0.01
},
"inversion": {
"default": "",
"keys": []
Expand Down Expand Up @@ -166,4 +192,4 @@
"step": 8
}
}
}
}
2 changes: 2 additions & 0 deletions gui/src/client/api.ts
Expand Up @@ -147,6 +147,7 @@ export interface BlendParams {
export interface HighresParams {
enabled: boolean;

highresMethod: string;
highresScale: number;
highresSteps: number;
highresStrength: number;
Expand Down Expand Up @@ -503,6 +504,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
}

if (doesExist(highres) && highres.enabled) {
url.searchParams.append('highresMethod', highres.highresMethod);
url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER));
url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER));
url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT));
Expand Down
17 changes: 17 additions & 0 deletions gui/src/components/control/HighresControl.tsx
Expand Up @@ -69,5 +69,22 @@ export function HighresControl() {
});
}}
/>
<FormControl>
<InputLabel id={'highres-method'}>{t('parameter.highres.method')}</InputLabel>
<Select
labelId={'highres-method'}
label={t('parameter.highres.method')}
value={highres.highresMethod}
onChange={(e) => {
setHighres({
highresMethod: e.target.value,
});
}}
>
{Object.entries(params.highresMethod.keys).map(([key, name]) =>
<MenuItem key={key} value={name}>{t(`highresMethod.${name}`)}</MenuItem>)
}
</Select>
</FormControl>
</Stack>;
}
2 changes: 1 addition & 1 deletion gui/src/components/control/UpscaleControl.tsx
Expand Up @@ -109,7 +109,7 @@ export function UpscaleControl() {
}}
/>
<FormControl>
<InputLabel id={'upscale-order'}>Upscale Order</InputLabel>
<InputLabel id={'upscale-order'}>{t('parameter.upscale.order')}</InputLabel>
<Select
labelId={'upscale-order'}
label={t('parameter.upscale.order')}
Expand Down
2 changes: 2 additions & 0 deletions gui/src/state.ts
Expand Up @@ -433,6 +433,7 @@ export function createStateSlices(server: ServerParams) {
const createHighresSlice: Slice<HighresSlice> = (set) => ({
highres: {
enabled: false,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
Expand All @@ -449,6 +450,7 @@ export function createStateSlices(server: ServerParams) {
set({
highres: {
enabled: false,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
Expand Down
5 changes: 5 additions & 0 deletions gui/src/strings/de.ts
Expand Up @@ -12,6 +12,11 @@ export const I18N_STRINGS_DE = {
},
},
generate: 'Erzeugen',
highresMethod: {
bilinear: '',
lanczos: '',
upscale: '',
},
history: {
empty: 'Keine neuere Geschichte. Drücken Sie Generieren, um ein Bild zu erstellen.',
},
Expand Down
5 changes: 5 additions & 0 deletions gui/src/strings/en.ts
Expand Up @@ -7,6 +7,11 @@ export const I18N_STRINGS_EN = {
},
},
generate: 'Generate',
highresMethod: {
bilinear: 'Bilinear',
lanczos: 'Lanczos',
upscale: 'Upscaling',
},
history: {
empty: 'No recent history. Press Generate to create an image.',
},
Expand Down
5 changes: 5 additions & 0 deletions gui/src/strings/es.ts
Expand Up @@ -15,6 +15,11 @@ export const I18N_STRINGS_ES = {
history: {
empty: 'Sin antecedentes recientes. Presiona generar para crear una nueva imagen.',
},
highresMethod: {
bilinear: '',
lanczos: '',
upscale: '',
},
input: {
image: {
empty: 'Por favor, seleccione una imagen.',
Expand Down
5 changes: 5 additions & 0 deletions gui/src/strings/fr.ts
Expand Up @@ -12,6 +12,11 @@ export const I18N_STRINGS_FR = {
},
},
generate: 'générer',
highresMethod: {
bilinear: '',
lanczos: '',
upscale: '',
},
history: {
empty: 'pas d\'histoire récente. appuyez sur générer pour créer une image.',
},
Expand Down

0 comments on commit f451d8d

Please sign in to comment.