Skip to content

Commit

Permalink
fix(api): write tests for embedding/inversion blending
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Oct 7, 2023
1 parent ebdfa78 commit e9b1375
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 149 deletions.
301 changes: 156 additions & 145 deletions api/onnx_web/convert/diffusion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,155 @@
logger = getLogger(__name__)


def detect_embedding_format(loaded_embeds) -> Optional[str]:
    """
    Detect which Textual Inversion file format the loaded tensors use.

    Inspects the top-level keys of the loaded state dict:
    - a single "<token>" style key indicates a concept embedding
    - an "emb_params" key indicates parameter embeddings
    - "string_to_token" plus "string_to_param" indicates token embeddings

    Returns one of "concept", "parameters", or "embeddings", or None when
    no recognized keys are present (the unknown format is also logged as
    an error).
    """
    keys: List[str] = list(loaded_embeds.keys())
    if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
        logger.debug("detected Textual Inversion concept: %s", keys)
        return "concept"
    elif "emb_params" in keys:
        logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
        return "parameters"
    elif "string_to_token" in keys and "string_to_param" in keys:
        logger.debug("detected Textual Inversion token embeddings: %s", keys)
        return "embeddings"
    else:
        logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
        return None


def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
    """
    Blend a single-token "concept" Textual Inversion into the embeds dict.

    The concept file holds exactly one entry mapping the trained token to
    its embedding tensor. The weighted layer is accumulated under both the
    caller-provided base token and the trained token itself; entries that
    already exist are added to in place.
    """
    # the concept format has exactly one key: the trained token
    trained_token, tensor = next(iter(loaded_embeds.items()))

    layer = tensor.numpy().astype(dtype) * weight

    # accumulate under the base token first, then the trained token,
    # matching the original update order (both share the same array when
    # neither key exists yet)
    for key in (base_token, trained_token):
        if key in embeds:
            embeds[key] += layer
        else:
            embeds[key] = layer


def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
    """
    Blend a parameter-style ("emb_params") Textual Inversion into the
    embeds dict.

    Each row of the emb_params matrix becomes its own "{base_token}-{i}"
    entry, and the running sum of all weighted rows is accumulated under
    both base_token and "{base_token}-all". Existing entries are added to
    in place.
    """
    emb_params = loaded_embeds["emb_params"]

    num_tokens = emb_params.shape[0]
    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)

    sum_layer = np.zeros(emb_params[0, :].shape)

    for idx in range(num_tokens):
        layer = emb_params[idx, :].numpy().astype(dtype) * weight
        sum_layer += layer

        layer_token = f"{base_token}-{idx}"
        if layer_token in embeds:
            embeds[layer_token] += layer
        else:
            embeds[layer_token] = layer

    # add base and sum tokens to embeds (sharing the same array when
    # neither key exists yet, as before)
    for agg_token in (base_token, f"{base_token}-all"):
        if agg_token in embeds:
            embeds[agg_token] += sum_layer
        else:
            embeds[agg_token] = sum_layer


def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
    """
    Blend a token-embeddings style ("string_to_token"/"string_to_param")
    Textual Inversion into the embeds dict.

    The trained token is looked up from string_to_token and its layer
    stack from string_to_param. Each layer becomes its own
    "{base_token}-{i}" entry, and the running sum of all weighted layers
    is accumulated under both base_token and "{base_token}-all". Existing
    entries are added to in place.
    """
    string_to_token = loaded_embeds["string_to_token"]
    string_to_param = loaded_embeds["string_to_param"]

    # the mapping holds a single trained token and its layer stack
    trained_token = list(string_to_token.keys())[0]
    trained_embeds = string_to_param[trained_token]

    num_tokens = trained_embeds.shape[0]
    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)

    sum_layer = np.zeros(trained_embeds[0, :].shape)

    for idx in range(num_tokens):
        layer = trained_embeds[idx, :].numpy().astype(dtype) * weight
        sum_layer += layer

        layer_token = f"{base_token}-{idx}"
        if layer_token in embeds:
            embeds[layer_token] += layer
        else:
            embeds[layer_token] = layer

    # add base and sum tokens to embeds (sharing the same array when
    # neither key exists yet, as before)
    for agg_token in (base_token, f"{base_token}-all"):
        if agg_token in embeds:
            embeds[agg_token] += sum_layer
        else:
            embeds[agg_token] = sum_layer


def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
    """
    Write the blended token weights into the text encoder's ONNX graph.

    Extends the token embedding initializer with num_added_tokens zero
    rows, copies each blended token's weights into the row for its
    tokenizer id, then swaps the rebuilt initializer back into the graph
    at the same position.
    """
    weight_name = "text_model.embeddings.token_embedding.weight"

    # NOTE: text_encoder here is an ONNX ModelProto, so the usual
    # text_encoder.resize_token_embeddings(len(tokenizer)) is not
    # available; the initializer is resized by hand instead
    embedding_node = [
        n for n in text_encoder.graph.initializer if n.name == weight_name
    ][0]
    base_weights = numpy_helper.to_array(embedding_node)

    # append one zero row per newly added token
    zero_rows = np.zeros((num_added_tokens, base_weights.shape[1]))
    embedding_weights = np.concatenate((base_weights, zero_rows), axis=0)

    for token, weights in embeds.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        logger.trace("embedding %s weights for token %s", weights.shape, token)
        embedding_weights[token_id] = weights

    # swap the rebuilt initializer into the graph in place
    for i in range(len(text_encoder.graph.initializer)):
        if text_encoder.graph.initializer[i].name == weight_name:
            new_initializer = numpy_helper.from_array(
                embedding_weights.astype(base_weights.dtype), embedding_node.name
            )
            logger.trace("new initializer data type: %s", new_initializer.data_type)
            del text_encoder.graph.initializer[i]
            text_encoder.graph.initializer.insert(i, new_initializer)


@torch.no_grad()
def blend_textual_inversions(
server: ServerContext,
text_encoder: ModelProto,
tokenizer: CLIPTokenizer,
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
# always load to CPU for blending
device = torch.device("cpu")
dtype = np.float32
embeds = {}

for name, weight, base_token, inversion_format in inversions:
for name, weight, base_token, format in embeddings:
if base_token is None:
logger.debug("no base token provided, using name: %s", name)
base_token = name
Expand All @@ -43,153 +179,28 @@ def blend_textual_inversions(
logger.warning("unable to load tensor")
continue

if inversion_format is None:
keys: List[str] = list(loaded_embeds.keys())
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
logger.debug("detected Textual Inversion concept: %s", keys)
inversion_format = "concept"
elif "emb_params" in keys:
logger.debug(
"detected Textual Inversion parameter embeddings: %s", keys
)
inversion_format = "parameters"
elif "string_to_token" in keys and "string_to_param" in keys:
logger.debug("detected Textual Inversion token embeddings: %s", keys)
inversion_format = "embeddings"
else:
logger.error(
"unknown Textual Inversion format, no recognized keys: %s", keys
)
continue

if inversion_format == "concept":
# separate token and the embeds
token = list(loaded_embeds.keys())[0]

layer = loaded_embeds[token].numpy().astype(dtype)
layer *= weight

if base_token in embeds:
embeds[base_token] += layer
else:
embeds[base_token] = layer

if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer
elif inversion_format == "parameters":
emb_params = loaded_embeds["emb_params"]

num_tokens = emb_params.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)

sum_layer = np.zeros(emb_params[0, :].shape)

for i in range(num_tokens):
token = f"{base_token}-{i}"
layer = emb_params[i, :].numpy().astype(dtype)
layer *= weight

sum_layer += layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer

# add base and sum tokens to embeds
if base_token in embeds:
embeds[base_token] += sum_layer
else:
embeds[base_token] = sum_layer

sum_token = f"{base_token}-all"
if sum_token in embeds:
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
elif inversion_format == "embeddings":
string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"]

# separate token and embeds
token = list(string_to_token.keys())[0]
trained_embeds = string_to_param[token]

num_tokens = trained_embeds.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)

sum_layer = np.zeros(trained_embeds[0, :].shape)

for i in range(num_tokens):
token = f"{base_token}-{i}"
layer = trained_embeds[i, :].numpy().astype(dtype)
layer *= weight

sum_layer += layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer

# add base and sum tokens to embeds
if base_token in embeds:
embeds[base_token] += sum_layer
else:
embeds[base_token] = sum_layer

sum_token = f"{base_token}-all"
if sum_token in embeds:
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
if format is None:
format = detect_embedding_format()

if format == "concept":
blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
elif format == "parameters":
blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
elif format == "embeddings":
blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
else:
raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
raise ValueError(f"unknown Textual Inversion format: {format}")

# add the tokens to the tokenizer
logger.debug(
"found embeddings for %s tokens: %s",
len(embeds.keys()),
list(embeds.keys()),
# add the tokens to the tokenizer
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
if num_added_tokens == 0:
raise ValueError(
"The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
)
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
)

logger.trace("added %s tokens", num_added_tokens)

# resize the token embeddings
# text_encoder.resize_token_embeddings(len(tokenizer))
embedding_node = [
n
for n in text_encoder.graph.initializer
if n.name == "text_model.embeddings.token_embedding.weight"
][0]
base_weights = numpy_helper.to_array(embedding_node)

weights_dim = base_weights.shape[1]
zero_weights = np.zeros((num_added_tokens, weights_dim))
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)

for token, weights in embeds.items():
token_id = tokenizer.convert_tokens_to_ids(token)
logger.trace("embedding %s weights for token %s", weights.shape, token)
embedding_weights[token_id] = weights

# replace embedding_node
for i in range(len(text_encoder.graph.initializer)):
if (
text_encoder.graph.initializer[i].name
== "text_model.embeddings.token_embedding.weight"
):
new_initializer = numpy_helper.from_array(
embedding_weights.astype(base_weights.dtype), embedding_node.name
)
logger.trace("new initializer data type: %s", new_initializer.data_type)
del text_encoder.graph.initializer[i]
text_encoder.graph.initializer.insert(i, new_initializer)
logger.trace("added %s tokens", num_added_tokens)

blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)

return (text_encoder, tokenizer)

Expand Down
4 changes: 3 additions & 1 deletion api/onnx_web/image/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple

from PIL import Image, ImageChops

from ..params import Border, Size
Expand All @@ -13,7 +15,7 @@ def expand_image(
fill="white",
noise_source=noise_source_histogram,
mask_filter=mask_filter_none,
):
) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
size = Size(*source.size).add_border(expand)
size = tuple(size)
origin = (expand.left, expand.top)
Expand Down
8 changes: 8 additions & 0 deletions api/onnx_web/server/hacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@
from functools import partial
from logging import getLogger
from os import path
from pathlib import Path
from typing import Dict, Optional, Union
from urllib.parse import urlparse

from optimum.onnxruntime.modeling_diffusion import (
ORTModel,
ORTStableDiffusionPipelineBase,
)

from ..torch_before_ort import SessionOptions
from ..utils import run_gc
from .context import ServerContext

Expand Down
8 changes: 6 additions & 2 deletions api/onnx_web/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@
]


def worker_main(worker: WorkerContext, server: ServerContext, *args):
apply_patches(server)
def worker_main(
worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
):
setproctitle("onnx-web worker: %s" % (worker.device.device))

if patch:
apply_patches(server)

logger.trace(
"checking in from worker with providers: %s", get_available_providers()
)
Expand Down

0 comments on commit e9b1375

Please sign in to comment.