Skip to content

Commit

Permalink
feat(api): blend LoRAs and Textual Inversions from extras file
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 18, 2023
1 parent 1d44f98 commit 84bd852
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 11 deletions.
72 changes: 65 additions & 7 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
from urllib.parse import urlparse

from jsonschema import ValidationError, validate
from onnx import load_model, save_model
from transformers import CLIPTokenizer
from yaml import safe_load

from .correction_gfpgan import convert_correction_gfpgan
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.lora import blend_loras
from .diffusion.original import convert_diffusion_original
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
from .diffusion.textual_inversion import blend_textual_inversions
from .upscale_resrgan import convert_upscale_resrgan
from .utils import (
ConversionContext,
Expand Down Expand Up @@ -229,22 +232,77 @@ def convert_models(ctx: ConversionContext, args, models: Models):
source,
)

# keep track of which models have been blended
blend_models = {}

for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))

if "tokenizer" not in blend_models:
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(path.join(ctx.model_path, model), subfolder="tokenizer")

inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_format = inversion.get("format", "embeddings")
inversion_source = fetch_model(
ctx, f"{name}-inversion-{inversion_name}", inversion_source
)
convert_diffusion_textual_inversion(
inversion_token = inversion.get("token", inversion_name)
inversion_weight = inversion.get("weight", 1.0)

blend_textual_inversions(
ctx,
inversion_name,
model["source"],
inversion_source,
inversion_format,
base_token=inversion.get("token"),
blend_models["text_encoder"],
blend_models["tokenizer"],
[inversion_source],
[inversion_format],
base_token=inversion_token,
inversion_weights=[inversion_weight],
)

for lora in model.get("loras", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))

        if "unet" not in blend_models:
            blend_models["unet"] = load_model(path.join(ctx.model_path, model, "unet", "model.onnx"))

# load models if not loaded yet
lora_name = lora["name"]
lora_source = lora["source"]
lora_source = fetch_model(
ctx, f"{name}-lora-{lora_name}", lora_source
)
lora_weight = lora.get("weight", 1.0)

blend_loras(
ctx,
blend_models["text_encoder"],
[lora_name],
[lora_source],
"text_encoder",
lora_weights=[lora_weight],
)

if "tokenizer" in blend_models:
dest_path = path.join(ctx.model_path, model, "tokenizer")
logger.debug("saving blended tokenizer to %s", dest_path)
blend_models["tokenizer"].save_pretrained(dest_path)

for name in ["text_encoder", "unet"]:
if name in blend_models:
dest_path = path.join(ctx.model_path, model, name, "model.onnx")
logger.debug("saving blended %s model to %s", name, dest_path)
save_model(
blend_models[name],
dest_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
)


except Exception:
logger.exception(
"error converting diffusion model %s",
Expand Down
3 changes: 2 additions & 1 deletion api/onnx_web/convert/diffusion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def convert_diffusion_textual_inversion(
inversion: str,
format: str,
base_token: Optional[str] = None,
weight: Optional[float] = 1.0,
):
dest_path = path.join(context.model_path, f"inversion-{name}")
logger.info(
Expand Down Expand Up @@ -161,7 +162,7 @@ def convert_diffusion_textual_inversion(
tokenizer,
[inversion],
[format],
[1.0],
[weight],
base_token=(base_token or name),
)

Expand Down
7 changes: 4 additions & 3 deletions docs/user-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,8 @@ You can blend extra networks with the diffusion model using `<type:name:weight>`

#### LoRA tokens

You can blend one or more [LoRA embeddings](https://arxiv.org/abs/2106.09685) with the ONNX diffusion model using a `lora` token:
You can blend one or more [LoRA embeddings](https://arxiv.org/abs/2106.09685) with the ONNX diffusion model using a
`lora` token:

```none
<lora:name:0.5>
Expand All @@ -341,8 +342,8 @@ contain any special characters.

#### Textual Inversion tokens

You can blend one or more [Textual Inversions](https://textual-inversion.github.io/) with the ONNX diffusion model using the `inversion`
token:
You can blend one or more [Textual Inversions](https://textual-inversion.github.io/) with the ONNX diffusion model
using the `inversion` token:

```none
<inversion:autumn:1.0>
Expand Down

0 comments on commit 84bd852

Please sign in to comment.