Skip to content

Commit

Permalink
feat: add a way to select textual inversions
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 25, 2023
1 parent 45f5fca commit 2e7de16
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 2 deletions.
1 change: 1 addition & 0 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def blend_img2img(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
if params.lpw:
logger.debug("using LPW pipeline for img2img")
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)

if params.lpw:
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def source_txt2img(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)

if params.lpw:
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
if params.lpw:
logger.debug("using LPW pipeline for inpaint")
Expand Down
16 changes: 15 additions & 1 deletion api/onnx_web/diffusion/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
OnnxRuntimeModel,
)

try:
Expand Down Expand Up @@ -138,8 +139,9 @@ def load_pipeline(
scheduler_type: Any,
device: DeviceParams,
lpw: bool,
inversion: Optional[str],
):
pipe_key = (pipeline, model, device.device, device.provider, lpw)
pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion)
scheduler_key = (scheduler_type, model)

cache_pipe = server.cache.get("diffusion", pipe_key)
Expand Down Expand Up @@ -182,6 +184,17 @@ def load_pipeline(
sess_options=device.sess_options(),
subfolder="scheduler",
)

text_encoder = None
if inversion is not None:
logger.debug("loading text encoder from %s", inversion)
text_encoder = OnnxRuntimeModel.from_pretrained(
inversion,
provider=device.ort_provider(),
sess_options=device.sess_options(),
subfolder="text_encoder",
)

pipe = pipeline.from_pretrained(
model,
custom_pipeline=custom_pipeline,
Expand All @@ -190,6 +203,7 @@ def load_pipeline(
revision="onnx",
safety_checker=None,
scheduler=scheduler,
text_encoder=text_encoder,
)

if not server.show_progress:
Expand Down
2 changes: 2 additions & 0 deletions api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def run_txt2img_pipeline(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
progress = job.get_progress_callback()

Expand Down Expand Up @@ -109,6 +110,7 @@ def run_img2img_pipeline(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
progress = job.get_progress_callback()
if params.lpw:
Expand Down
8 changes: 7 additions & 1 deletion api/onnx_web/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
def hash_value(sha, param: Param):
if param is None:
return
elif isinstance(param, bool):
sha.update(bytearray(pack("!B", param)))
elif isinstance(param, float):
sha.update(bytearray(pack("!f", param)))
elif isinstance(param, int):
Expand Down Expand Up @@ -73,8 +75,12 @@ def make_output_name(
hash_value(sha, params.prompt)
hash_value(sha, params.negative_prompt)
hash_value(sha, params.cfg)
hash_value(sha, params.steps)
hash_value(sha, params.seed)
hash_value(sha, params.steps)
hash_value(sha, params.lpw)
hash_value(sha, params.eta)
hash_value(sha, params.batch)
hash_value(sha, params.inversion)
hash_value(sha, size.width)
hash_value(sha, size.height)

Expand Down
4 changes: 4 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(
lpw: bool = False,
eta: float = 0.0,
batch: int = 1,
inversion: str = None,
) -> None:
self.model = model
self.scheduler = scheduler
Expand All @@ -168,6 +169,7 @@ def __init__(
self.lpw = lpw or False
self.eta = eta
self.batch = batch
self.inversion = inversion

def tojson(self) -> Dict[str, Optional[Param]]:
return {
Expand All @@ -181,6 +183,7 @@ def tojson(self) -> Dict[str, Optional[Param]]:
"lpw": self.lpw,
"eta": self.eta,
"batch": self.batch,
"inversion": self.inversion,
}

def with_args(self, **kwargs):
Expand All @@ -195,6 +198,7 @@ def with_args(self, **kwargs):
kwargs.get("lpw", self.lpw),
kwargs.get("eta", self.eta),
kwargs.get("batch", self.batch),
kwargs.get("inversion", self.inversion),
)


Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
scheduler = get_from_map(
request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
)
inversion = get_not_empty(request.args, "inversion", get_config_value("inversion"))
inversion_path = get_model_path(inversion)

# image params
prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
Expand Down Expand Up @@ -240,6 +242,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
lpw=lpw,
negative_prompt=negative_prompt,
batch=batch,
inversion=inversion_path,
)
size = Size(width, height)
return (device, params, size)
Expand Down
4 changes: 4 additions & 0 deletions gui/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ export interface ModelParams {
* Use the long prompt weighting pipeline.
*/
lpw: boolean;

inversion: string;
}

/**
Expand Down Expand Up @@ -183,6 +185,7 @@ export interface ReadyResponse {
export interface ModelsResponse {
diffusion: Array<string>;
correction: Array<string>;
inversion: Array<string>;
upscaling: Array<string>;
}

Expand Down Expand Up @@ -325,6 +328,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
url.searchParams.append('upscaling', params.upscaling);
url.searchParams.append('correction', params.correction);
url.searchParams.append('lpw', String(params.lpw));
url.searchParams.append('inversion', params.inversion);
}

/**
Expand Down
15 changes: 15 additions & 0 deletions gui/src/components/control/ModelControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ export function ModelControl() {
});
}}
/>
<QueryList
id='inversion'
labels={MODEL_LABELS}
name='Textual Inversion'
query={{
result: models,
selector: (result) => result.inversion,
}}
value={params.inversion}
onChange={(inversion) => {
setModel({
inversion,
});
}}
/>
<QueryList
id='upscaling'
labels={MODEL_LABELS}
Expand Down
1 change: 1 addition & 0 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ export function createStateSlices(server: ServerParams) {
platform: server.platform.default,
upscaling: server.upscaling.default,
correction: server.correction.default,
inversion: server.inversion.default,
lpw: false,
},
setModel(params) {
Expand Down

0 comments on commit 2e7de16

Please sign in to comment.