Skip to content

Commit

Permalink
feat(api): blend Textual Inversions from prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 15, 2023
1 parent 973ad0f commit 506cf9f
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 125 deletions.
45 changes: 20 additions & 25 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,19 @@
from onnxruntime import InferenceSession, OrtValue, SessionOptions
from safetensors.torch import load_file

from onnx_web.convert.utils import ConversionContext
from ..utils import ConversionContext

logger = getLogger(__name__)


###
# everything in this file is still super experimental and may not produce valid ONNX models
###


def buffer_external_data_tensors(
model: ModelProto,
) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
external_data = []
for tensor in model.graph.initializer:
name = tensor.name

logger.info("externalizing tensor: %s", name)
logger.debug("externalizing tensor: %s", name)
if tensor.HasField("raw_data"):
npt = numpy_helper.to_array(tensor)
orv = OrtValue.ortvalue_from_numpy(npt)
Expand All @@ -59,13 +54,13 @@ def fix_node_name(key: str):
return fixed_name


def merge_lora(
def blend_loras(
base_name: str,
lora_names: List[str],
dest_type: Literal["text_encoder", "unet"],
lora_weights: "np.NDArray[np.float64]" = None,
):
base_model = load(base_name)
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
lora_models = [load_file(name) for name in lora_names]
lora_count = len(lora_models)
lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
Expand All @@ -86,7 +81,7 @@ def merge_lora(

up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
logger.info(
logger.debug(
"blending weights for keys: %s, %s, %s", key, up_key, alpha_key
)

Expand All @@ -99,7 +94,7 @@ def merge_lora(
try:
if len(up_weight.size()) == 2:
# blend for nn.Linear
logger.info(
logger.debug(
"blending weights for Linear node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
Expand All @@ -109,7 +104,7 @@ def merge_lora(
np_weights = weights.numpy() * (alpha / dim)
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
# blend for nn.Conv2d 1x1
logger.info(
logger.debug(
"blending weights for Conv node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
Expand Down Expand Up @@ -161,7 +156,7 @@ def merge_lora(
conv_key = base_key + "_Conv"
matmul_key = base_key + "_MatMul"

logger.info(
logger.debug(
"key %s has conv: %s, matmul: %s",
base_key,
conv_key in fixed_node_names,
Expand All @@ -171,28 +166,28 @@ def merge_lora(
if conv_key in fixed_node_names:
conv_idx = fixed_node_names.index(conv_key)
conv_node = base_model.graph.node[conv_idx]
logger.info("found conv node: %s", conv_node.name)
logger.debug("found conv node: %s", conv_node.name)

# find weight initializer
logger.info("conv inputs: %s", conv_node.input)
logger.debug("conv inputs: %s", conv_node.input)
weight_name = [n for n in conv_node.input if ".weight" in n][0]
weight_name = fix_initializer_name(weight_name)

weight_idx = fixed_initializer_names.index(weight_name)
weight_node = base_model.graph.initializer[weight_idx]
logger.info("found weight initializer: %s", weight_node.name)
logger.debug("found weight initializer: %s", weight_node.name)

# blending
base_weights = numpy_helper.to_array(weight_node)
logger.info(
logger.debug(
"found blended weights for conv: %s, %s",
weights.shape,
base_weights.shape,
)

blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
blended = np.expand_dims(blended, (2, 3))
logger.info("blended weight shape: %s", blended.shape)
logger.debug("blended weight shape: %s", blended.shape)

# replace the original initializer
updated_node = numpy_helper.from_array(blended, weight_node.name)
Expand All @@ -201,33 +196,33 @@ def merge_lora(
elif matmul_key in fixed_node_names:
weight_idx = fixed_node_names.index(matmul_key)
weight_node = base_model.graph.node[weight_idx]
logger.info("found matmul node: %s", weight_node.name)
logger.debug("found matmul node: %s", weight_node.name)

# find the MatMul initializer
logger.info("matmul inputs: %s", weight_node.input)
logger.debug("matmul inputs: %s", weight_node.input)
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

matmul_idx = fixed_initializer_names.index(matmul_name)
matmul_node = base_model.graph.initializer[matmul_idx]
logger.info("found matmul initializer: %s", matmul_node.name)
logger.debug("found matmul initializer: %s", matmul_node.name)

# blending
base_weights = numpy_helper.to_array(matmul_node)
logger.info(
logger.debug(
"found blended weights for matmul: %s, %s",
weights.shape,
base_weights.shape,
)

blended = base_weights + weights.transpose()
logger.info("blended weight shape: %s", blended.shape)
logger.debug("blended weight shape: %s", blended.shape)

# replace the original initializer
updated_node = numpy_helper.from_array(blended, matmul_node.name)
del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, updated_node)
else:
logger.info("could not find any nodes for %s", base_key)
logger.warning("could not find any nodes for %s", base_key)

logger.info(
"node counts: %s -> %s, %s -> %s",
Expand Down Expand Up @@ -256,7 +251,7 @@ def merge_lora(
args.lora_weights,
)

blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
if args.dest is None or args.dest == "" or args.dest == "ort":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
Expand Down
195 changes: 110 additions & 85 deletions api/onnx_web/convert/diffusion/textual_inversion.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,115 @@
from logging import getLogger
from os import makedirs, path
from typing import Optional
from typing import List, Optional, Tuple

import numpy as np
import torch
from huggingface_hub.file_download import hf_hub_download
from onnx import ModelProto, load_model, numpy_helper, save_model
from torch.onnx import export
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTokenizer

from ...server.context import ServerContext
from ..utils import ConversionContext

logger = getLogger(__name__)


@torch.no_grad()
def blend_textual_inversions(
    context: ServerContext,
    text_encoder: Optional[ModelProto],
    tokenizer: Optional[CLIPTokenizer],
    inversion_names: List[str],
    inversion_formats: List[str],
    inversion_weights: Optional[List[float]] = None,
    base_tokens: Optional[List[str]] = None,
) -> Tuple[ModelProto, CLIPTokenizer]:
    """Blend one or more Textual Inversions into an ONNX text encoder.

    Each inversion is loaded, scaled by its weight, and accumulated per token.
    The resulting tokens are added to the tokenizer, the token embedding
    initializer of the text encoder is extended with one row per added token,
    and the blended vectors are written into the rows of their token ids.

    Args:
        context: Server context. Not referenced by this function at present.
        text_encoder: ONNX model whose graph contains the
            ``text_model.embeddings.token_embedding.weight`` initializer.
            It is modified in place. Although annotated as optional, it is
            dereferenced unconditionally.
        tokenizer: Tokenizer that receives the new tokens. It is modified in
            place. Although annotated as optional, it is dereferenced
            unconditionally.
        inversion_names: Hub repo id (``concept`` format) or file path
            (``embeddings`` format) for each inversion.
        inversion_formats: Format of each inversion, ``"concept"`` or
            ``"embeddings"``, parallel to ``inversion_names``.
        inversion_weights: Scale factor for each inversion, parallel to
            ``inversion_names``. Defaults to ``1.0`` for every inversion.
        base_tokens: Token name to use for each inversion, parallel to
            ``inversion_names``. Defaults to the inversion names.

    Returns:
        The ``(text_encoder, tokenizer)`` pair that was passed in, after
        modification.

    Raises:
        ValueError: If an inversion format is unknown, or if the tokenizer
            did not add any of the blended tokens.
    """
    # np.float was removed in NumPy 1.24; it was an alias of the builtin float,
    # which NumPy treats as float64, so this keeps the previous precision.
    dtype = np.float64  # TODO: fixed type, which one?
    # prev: text_encoder.get_input_embeddings().weight.dtype
    embedding_name = "text_model.embeddings.token_embedding.weight"
    embeds = {}

    if inversion_weights is None:
        # zip() cannot iterate None; fall back to unscaled inversions
        inversion_weights = [1.0] * len(inversion_names)

    for name, inversion_format, weight, base_token in zip(
        inversion_names,
        inversion_formats,
        inversion_weights,
        base_tokens or inversion_names,
    ):
        logger.info("blending Textual Inversion %s with weight of %s", name, weight)
        if inversion_format == "concept":
            embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
            token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")

            with open(token_file, "r") as f:
                token = base_token or f.read()

            # NOTE(review): torch.load unpickles the file, and this one is
            # downloaded from the hub - only use repos that are trusted.
            loaded_embeds = torch.load(embeds_file)

            # separate token and the embeds
            trained_token = list(loaded_embeds.keys())[0]

            layer = loaded_embeds[trained_token].cpu().numpy().astype(dtype)
            layer *= weight
            # accumulate under the key that is actually stored (token), not
            # the key found in the file (trained_token)
            if token in embeds:
                embeds[token] += layer
            else:
                embeds[token] = layer
        elif inversion_format == "embeddings":
            # NOTE(review): torch.load unpickles the file - only load
            # embeddings from trusted sources.
            loaded_embeds = torch.load(name)

            string_to_token = loaded_embeds["string_to_token"]
            string_to_param = loaded_embeds["string_to_param"]

            # separate token and embeds
            trained_token = list(string_to_token.keys())[0]
            trained_embeds = string_to_param[trained_token]

            num_tokens = trained_embeds.shape[0]
            logger.debug("generating %s layer tokens", num_tokens)

            # each row of the trained embeds becomes its own numbered token
            for i in range(num_tokens):
                token = f"{base_token or name}-{i}"
                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
                layer *= weight
                if token in embeds:
                    embeds[token] += layer
                else:
                    embeds[token] = layer
        else:
            raise ValueError(f"unknown Textual Inversion format: {inversion_format}")

    # add the tokens to the tokenizer
    tokens = list(embeds.keys())
    logger.info("found embeddings for %s tokens: %s", len(embeds), embeds.keys())
    num_added_tokens = tokenizer.add_tokens(tokens)
    if num_added_tokens == 0:
        # report every requested token; the loop variable may be stale or unset
        raise ValueError(
            f"The tokenizer already contains the token {tokens}. Please pass a different `token` that is not already in the tokenizer."
        )

    logger.debug("added %s tokens", num_added_tokens)

    # resize the token embeddings
    # text_encoder.resize_token_embeddings(len(tokenizer))
    embedding_node = [
        n for n in text_encoder.graph.initializer if n.name == embedding_name
    ][0]
    embedding_weights = numpy_helper.to_array(embedding_node)

    # append one zero row per newly added token, then fill in the blended rows
    weights_dim = embedding_weights.shape[1]
    zero_weights = np.zeros((num_added_tokens, weights_dim))
    embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)

    for token, weights in embeds.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        logger.debug("embedding %s weights for token %s", weights.shape, token)
        embedding_weights[token_id] = weights

    # replace embedding_node
    new_initializer = numpy_helper.from_array(
        embedding_weights.astype(np.float32), embedding_node.name
    )
    for i, initializer in enumerate(text_encoder.graph.initializer):
        if initializer.name == embedding_name:
            logger.debug("new initializer data type: %s", new_initializer.data_type)
            del text_encoder.graph.initializer[i]
            text_encoder.graph.initializer.insert(i, new_initializer)
            # stop after the swap: the list was just mutated, and initializer
            # names are expected to be unique within a graph
            break

    return (text_encoder, tokenizer)


@torch.no_grad()
def convert_diffusion_textual_inversion(
context: ConversionContext,
Expand Down Expand Up @@ -40,101 +138,28 @@ def convert_diffusion_textual_inversion(

makedirs(encoder_path, exist_ok=True)

if format == "concept":
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")

with open(token_file, "r") as f:
token = base_token or f.read()

loaded_embeds = torch.load(embeds_file, map_location=context.map_location)

# separate token and the embeds
trained_token = list(loaded_embeds.keys())[0]
embeds = loaded_embeds[trained_token]
elif format == "embeddings":
loaded_embeds = torch.load(inversion, map_location=context.map_location)

string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"]

# separate token and embeds
trained_token = list(string_to_token.keys())[0]
embeds = string_to_param[trained_token]

num_tokens = embeds.shape[0]
logger.info("generating %s layer tokens", num_tokens)
token = [f"{base_token or name}-{i}" for i in range(num_tokens)]
else:
raise ValueError(f"unknown textual inversion format: {format}")

logger.info("found embeddings for token %s: %s", token, embeds.shape)

text_encoder = load_model(path.join(base_model, "text_encoder", "model.onnx"))
tokenizer = CLIPTokenizer.from_pretrained(
base_model,
subfolder="tokenizer",
)
text_encoder = CLIPTextModel.from_pretrained(
base_model,
subfolder="text_encoder",
)

# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype
embeds.to(dtype)

# add the token in tokenizer
num_added_tokens = tokenizer.add_tokens(token)
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
)

logger.info("added %s tokens", num_added_tokens)

# resize the token embeddings
text_encoder.resize_token_embeddings(len(tokenizer))

if len(embeds.shape) == 2:
# multiple vectors in embeds
for i in range(embeds.shape[0]):
layer_embeds = embeds[i]
layer_token = token[i]
logger.debug(
"embedding %s vector for layer %s", layer_embeds.shape, layer_token
)
token_id = tokenizer.convert_tokens_to_ids(layer_token)
text_encoder.get_input_embeddings().weight.data[token_id] = layer_embeds
else:
# get the id for the token and assign the embeds
token_id = tokenizer.convert_tokens_to_ids(token)
text_encoder.get_input_embeddings().weight.data[token_id] = embeds

# conversion stuff
text_input = tokenizer(
"A sample prompt",
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
text_encoder, tokenizer = blend_textual_inversions(
context,
text_encoder,
tokenizer,
[inversion],
[format],
[1.0],
base_token=(base_token or name),
)

logger.info("saving tokenizer for textual inversion")
tokenizer.save_pretrained(tokenizer_path)

logger.info("saving text encoder for textual inversion")
export(
save_model(
text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
(text_input.input_ids.to(dtype=torch.int32)),
f=encoder_model,
input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
do_constant_folding=True,
opset_version=context.opset,
)

logger.info("textual inversion saved to %s", dest_path)

0 comments on commit 506cf9f

Please sign in to comment.