Skip to content

Commit

Permalink
feat(api): convert Textual Inversion weights
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 25, 2023
1 parent 947a1bf commit a31f7b9
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 14 deletions.
5 changes: 5 additions & 0 deletions api/extras.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
{
"diffusion": [
{
"name": "diffusion-ugly-sonic",
"source": "runwayml/stable-diffusion-v1-5",
"inversion": "sd-concepts-library/ugly-sonic"
},
{
"name": "diffusion-knollingcase",
"source": "Aybeeceedee/knollingcase"
Expand Down
10 changes: 7 additions & 3 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from yaml import safe_load

from .correction_gfpgan import convert_correction_gfpgan
from .diffusion_original import convert_diffusion_original
from .diffusion_stable import convert_diffusion_stable
from .diffusion.original import convert_diffusion_original
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
from .upscale_resrgan import convert_upscale_resrgan
from .utils import (
ConversionContext,
Expand Down Expand Up @@ -216,14 +217,17 @@ def convert_models(ctx: ConversionContext, args, models: Models):
ctx, name, model["source"], model_format=model_format
)

if "inversion" in model:
convert_diffusion_textual_inversion(ctx, source, model["inversion"])

if model_format in model_formats_original:
convert_diffusion_original(
ctx,
model,
source,
)
else:
convert_diffusion_stable(
convert_diffusion_diffusers(
ctx,
model,
source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@
from onnx import load, save_model
from torch.onnx import export

from onnx_web.diffusion.load import optimize_pipeline

from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
from ...diffusion.load import optimize_pipeline
from ...diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from .utils import ConversionContext
from ..utils import ConversionContext

logger = getLogger(__name__)

Expand Down Expand Up @@ -63,7 +62,7 @@ def onnx_export(


@torch.no_grad()
def convert_diffusion_stable(
def convert_diffusion_diffusers(
ctx: ConversionContext,
model: Dict,
source: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
CLIPVisionConfig,
)

from .diffusion_stable import convert_diffusion_stable
from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
from .diffusers import convert_diffusion_diffusers
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name

logger = getLogger(__name__)

Expand Down Expand Up @@ -1428,5 +1428,5 @@ def convert_diffusion_original(
if "vae" in model:
del model["vae"]

convert_diffusion_stable(ctx, model, working_name)
convert_diffusion_diffusers(ctx, model, working_name)
logger.info("ONNX pipeline saved to %s", name)
88 changes: 88 additions & 0 deletions api/onnx_web/convert/diffusion/textual_inversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from os import mkdir, path
from huggingface_hub.file_download import hf_hub_download
from transformers import CLIPTokenizer, CLIPTextModel
from torch.onnx import export
from sys import argv
from logging import getLogger

from ..utils import ConversionContext, sanitize_name

import torch

logger = getLogger(__name__)


def convert_diffusion_textual_inversion(context: ConversionContext, base_model: str, inversion: str):
cache_path = path.join(context.cache_path, f"inversion-{sanitize_name(inversion)}")
logger.info("converting textual inversion: %s -> %s", inversion, cache_path)

if not path.exists(cache_path):
mkdir(cache_path)

embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")

with open(token_file, "r") as f:
token = f.read()

tokenizer = CLIPTokenizer.from_pretrained(
base_model,
subfolder="tokenizer",
)
text_encoder = CLIPTextModel.from_pretrained(
base_model,
subfolder="text_encoder",
)

loaded_embeds = torch.load(embeds_file, map_location=context.map_location)

# separate token and the embeds
trained_token = list(loaded_embeds.keys())[0]
embeds = loaded_embeds[trained_token]

# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype
embeds.to(dtype)

# add the token in tokenizer
num_added_tokens = tokenizer.add_tokens(token)
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
)

# resize the token embeddings
text_encoder.resize_token_embeddings(len(tokenizer))

# get the id for the token and assign the embeds
token_id = tokenizer.convert_tokens_to_ids(token)
text_encoder.get_input_embeddings().weight.data[token_id] = embeds

# conversion stuff
text_input = tokenizer(
"A sample prompt",
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)

export(
text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
(
text_input.input_ids.to(device=context.training_device, dtype=torch.int32)
),
f=path.join(cache_path, "text_encoder", "model.onnx"),
input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
do_constant_folding=True,
opset_version=context.opset,
)

if __name__ == "__main__":
context = ConversionContext.from_environ()
convert_diffusion_textual_inversion(context, argv[1], argv[2])
15 changes: 12 additions & 3 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@
available_platforms: List[DeviceParams] = []

# loaded from model_path
diffusion_models: List[str] = []
correction_models: List[str] = []
diffusion_models: List[str] = []
inversion_models: List[str] = []
upscaling_models: List[str] = []


Expand Down Expand Up @@ -301,8 +302,9 @@ def get_model_name(model: str) -> str:


def load_models(context: ServerContext) -> None:
global diffusion_models
global correction_models
global diffusion_models
global inversion_models
global upscaling_models

diffusion_models = [
Expand All @@ -323,6 +325,12 @@ def load_models(context: ServerContext) -> None:
correction_models = list(set(correction_models))
correction_models.sort()

inversion_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
]
inversion_models = list(set(inversion_models))
inversion_models.sort()

upscaling_models = [
get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
]
Expand Down Expand Up @@ -496,8 +504,9 @@ def list_mask_filters():
def list_models():
return jsonify(
{
"diffusion": diffusion_models,
"correction": correction_models,
"diffusion": diffusion_models,
"inversion": inversion_models,
"upscaling": upscaling_models,
}
)
Expand Down

0 comments on commit a31f7b9

Please sign in to comment.