Skip to content

Commit

Permalink
feat: add eta parameter (fixes #194)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 20, 2023
1 parent f8cfc18 commit c1189aa
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 18 deletions.
2 changes: 2 additions & 0 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
else:
Expand All @@ -106,6 +107,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)

Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,6 @@ def upscale_stable_diffusion(
source,
generator=generator,
num_inference_steps=params.steps,
eta=params.eta,
callback=callback,
).images[0]
4 changes: 4 additions & 0 deletions api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def run_txt2img_pipeline(
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)
else:
Expand All @@ -64,6 +65,7 @@ def run_txt2img_pipeline(
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)

Expand Down Expand Up @@ -119,6 +121,7 @@ def run_img2img_pipeline(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
eta=params.eta,
callback=progress,
)
else:
Expand All @@ -131,6 +134,7 @@ def run_img2img_pipeline(
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
eta=params.eta,
callback=progress,
)

Expand Down
6 changes: 5 additions & 1 deletion api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def __init__(
steps: int,
seed: int,
negative_prompt: Optional[str] = None,
lpw: Optional[bool] = False,
lpw: bool = False,
eta: float = 0.0,
) -> None:
self.model = model
self.scheduler = scheduler
Expand All @@ -164,6 +165,7 @@ def __init__(
self.seed = seed
self.steps = steps
self.lpw = lpw or False
self.eta = eta

def tojson(self) -> Dict[str, Optional[Param]]:
return {
Expand All @@ -175,6 +177,7 @@ def tojson(self) -> Dict[str, Optional[Param]]:
"seed": self.seed,
"steps": self.steps,
"lpw": self.lpw,
"eta": self.eta,
}

def with_args(self, **kwargs):
Expand All @@ -187,6 +190,7 @@ def with_args(self, **kwargs):
kwargs.get("seed", self.seed),
kwargs.get("negative_prompt", self.negative_prompt),
kwargs.get("lpw", self.lpw),
kwargs.get("eta", self.eta),
)


Expand Down
8 changes: 8 additions & 0 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
eta = get_and_clamp_float(
request.args,
"eta",
get_config_value("eta"),
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
steps = get_and_clamp_int(
request.args,
"steps",
Expand Down Expand Up @@ -220,6 +227,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
cfg,
steps,
seed,
eta=eta,
lpw=lpw,
negative_prompt=negative_prompt,
)
Expand Down
6 changes: 6 additions & 0 deletions api/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
"max": 1,
"step": 0.1
},
"eta": {
"default": 0.0,
"min": 0,
"max": 1,
"step": 0.1
},
"faceOutscale": {
"default": 1,
"min": 1,
Expand Down
2 changes: 2 additions & 0 deletions gui/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export interface BaseImgParams {
cfg: number;
steps: number;
seed: number;
eta: number;
}

/**
Expand Down Expand Up @@ -279,6 +280,7 @@ export function makeApiUrl(root: string, ...path: Array<string>) {
export function makeImageURL(root: string, type: string, params: BaseImgParams): URL {
const url = makeApiUrl(root, type);
url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
url.searchParams.append('eta', params.eta.toFixed(FIXED_FLOAT));
url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));

if (doesExist(params.scheduler)) {
Expand Down
52 changes: 35 additions & 17 deletions gui/src/components/control/ImageControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,41 @@ export function ImageControl(props: ImageControlProps) {
});

return <Stack spacing={2}>
<QueryList
id='schedulers'
labels={SCHEDULER_LABELS}
name='Scheduler'
query={{
result: schedulers,
}}
value={mustDefault(controlState.scheduler, '')}
onChange={(value) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
scheduler: value,
});
}
}}
/>
<Stack direction='row' spacing={4}>
<QueryList
id='schedulers'
labels={SCHEDULER_LABELS}
name='Scheduler'
query={{
result: schedulers,
}}
value={mustDefault(controlState.scheduler, '')}
onChange={(value) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
scheduler: value,
});
}
}}
/>
<NumericField
decimal
label='Eta'
min={params.eta.min}
max={params.eta.max}
step={params.eta.step}
value={controlState.eta}
onChange={(eta) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
eta,
});
}
}}
/>
</Stack>
<Stack direction='row' spacing={4}>
<NumericField
decimal
Expand Down
1 change: 1 addition & 0 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ export const DEFAULT_HISTORY = {
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
return {
cfg: defaults.cfg.default,
eta: defaults.eta.default,
negativePrompt: defaults.negativePrompt.default,
prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default,
Expand Down

0 comments on commit c1189aa

Please sign in to comment.