Skip to content

Commit

Permalink
fix(api): apply fp16 optimizations to LoRA and Textual Inversion blen…
Browse files Browse the repository at this point in the history
…ding
  • Loading branch information
ssube committed Mar 22, 2023
1 parent 4f6574c commit 0315a8c
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
3 changes: 1 addition & 2 deletions api/onnx_web/convert/diffusion/diffusers.py
Expand Up @@ -97,8 +97,7 @@ def convert_diffusion_diffusers(
single_vae = model.get("single_vae")
replace_vae = model.get("vae")

torch_half = "torch-fp16" in ctx.optimizations
torch_dtype = torch.float16 if torch_half else torch.float32
torch_dtype = ctx.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", torch_dtype)

dest_path = path.join(ctx.model_path, name)
Expand Down
11 changes: 6 additions & 5 deletions api/onnx_web/convert/diffusion/lora.py
Expand Up @@ -62,6 +62,7 @@ def blend_loras(
):
# always load to CPU for blending
device = torch.device("cpu")
dtype = context.torch_dtype()

base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
Expand All @@ -88,11 +89,11 @@ def blend_loras(
"blending weights for keys: %s, %s, %s", key, up_key, alpha_key
)

down_weight = lora_model[key].to(dtype=torch.float32)
up_weight = lora_model[up_key].to(dtype=torch.float32)
down_weight = lora_model[key].to(dtype=dtype)
up_weight = lora_model[up_key].to(dtype=dtype)

dim = down_weight.size()[0]
alpha = lora_model.get(alpha_key, dim).to(torch.float32).numpy()
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()

try:
if len(up_weight.size()) == 2:
Expand Down Expand Up @@ -203,7 +204,7 @@ def blend_loras(
logger.trace("blended weight shape: %s", blended.shape)

# replace the original initializer
updated_node = numpy_helper.from_array(blended, weight_node.name)
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
del base_model.graph.initializer[weight_idx]
base_model.graph.initializer.insert(weight_idx, updated_node)
elif matmul_key in fixed_node_names:
Expand Down Expand Up @@ -232,7 +233,7 @@ def blend_loras(
logger.trace("blended weight shape: %s", blended.shape)

# replace the original initializer
updated_node = numpy_helper.from_array(blended, matmul_node.name)
updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, updated_node)
else:
Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/convert/diffusion/textual_inversion.py
Expand Up @@ -22,7 +22,7 @@ def blend_textual_inversions(
) -> Tuple[ModelProto, CLIPTokenizer]:
# always load to CPU for blending
device = torch.device("cpu")
dtype = np.float
dtype = context.numpy_dtype()
embeds = {}

for name, weight, base_token, inversion_format in inversions:
Expand Down Expand Up @@ -149,7 +149,7 @@ def blend_textual_inversions(
== "text_model.embeddings.token_embedding.weight"
):
new_initializer = numpy_helper.from_array(
embedding_weights.astype(np.float32), embedding_node.name
embedding_weights.astype(dtype), embedding_node.name
)
logger.trace("new initializer data type: %s", new_initializer.data_type)
del text_encoder.graph.initializer[i]
Expand Down
15 changes: 15 additions & 0 deletions api/onnx_web/server/context.py
Expand Up @@ -2,6 +2,9 @@
from os import environ, path
from typing import List, Optional

import torch
import numpy as np

from ..utils import get_boolean
from .model_cache import ModelCache

Expand Down Expand Up @@ -77,3 +80,15 @@ def from_environ(cls):
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
memory_limit=memory_limit,
)

def torch_dtype(self):
if "torch-fp16" in self.optimizations:
return torch.float16
else:
return torch.float32

def numpy_dtype(self):
if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
return np.float16
else:
return np.float32

0 comments on commit 0315a8c

Please sign in to comment.