fix(api): match SDXL keys per LoRA
ssube committed Nov 24, 2023
1 parent 8d44103 commit 74832fc
Showing 1 changed file with 32 additions and 36 deletions.
68 changes: 32 additions & 36 deletions api/onnx_web/convert/diffusion/lora.py
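In outline, the change gives each LoRA its own blended dict, rewrites SDXL node names per dict, and only then flattens everything into one weights map, summing any overlapping keys. A minimal sketch of that flow, with toy stand-ins for the module's sum_weights and fix_xl_names helpers (the key rewrite and the flatten_layers name here are illustrative, not the module's API):

from typing import Dict, List

import numpy as np


def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # stand-in: the real helper reconciles shapes before adding
    return a + b


def fix_xl_names(layer: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    # stand-in: the real helper matches keys against the ONNX graph nodes
    return {k.replace("lora_unet_", "unet."): v for k, v in layer.items()}


def flatten_layers(layers: List[Dict[str, np.ndarray]], xl: bool) -> Dict[str, np.ndarray]:
    # fix names per LoRA dict (as this commit does), then merge,
    # summing any keys that resolve to the same node name
    weights: Dict[str, np.ndarray] = {}
    for layer in layers:
        if xl:
            layer = fix_xl_names(layer)
        for key, value in layer.items():
            weights[key] = sum_weights(weights[key], value) if key in weights else value
    return weights

The merged dict produced this way is what the initializer-update loop further down would consume.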
@@ -452,72 +452,68 @@ def blend_loras(
     else:
         lora_prefix = f"lora_{model_type}_"

-    blended: Dict[str, np.ndarray] = {}
+    layers = []
     for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
         logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         if lora_model is None:
             logger.warning("unable to load tensor for LoRA")
             continue

+        blended: Dict[str, np.ndarray] = {}
+        layers.append(blended)
+
         for key in lora_model.keys():
             if ".hada_w1_a" in key and lora_prefix in key:
                 # LoHA
                 base_key, np_weights = blend_weights_loha(
                     key, lora_prefix, lora_model, dtype
                 )
                 np_weights = np_weights * lora_weight
-                logger.trace(
-                    "adding LoHA weights: %s",
-                    np_weights.shape,
-                )
-                blended[base_key] = np_weights
+                if base_key in blended:
+                    logger.trace(
+                        "summing LoHA weights: %s + %s",
+                        blended[base_key].shape,
+                        np_weights.shape,
+                    )
+                    blended[base_key] = sum_weights(blended[base_key], np_weights)
+                else:
+                    logger.trace(
+                        "adding LoHA weights: %s",
+                        np_weights.shape,
+                    )
+                    blended[base_key] = np_weights
             elif ".lora_down" in key and lora_prefix in key:
                 # LoRA or LoCON
                 base_key, np_weights = blend_weights_lora(
                     key, lora_prefix, lora_model, dtype
                 )
                 np_weights = np_weights * lora_weight
-                logger.trace(
-                    "adding LoRA weights: %s",
-                    np_weights.shape,
-                )
-                blended[base_key] = np_weights
+                if base_key in blended:
+                    logger.trace(
+                        "summing LoRA weights: %s + %s",
+                        blended[base_key].shape,
+                        np_weights.shape,
+                    )
+                    blended[base_key] = sum_weights(blended[base_key], np_weights)
+                else:
+                    logger.trace(
+                        "adding LoRA weights: %s",
+                        np_weights.shape,
+                    )
+                    blended[base_key] = np_weights

+    # rewrite node names for XL and flatten layers
+    weights: Dict[str, np.ndarray] = {}
+
+    for blended in layers:
+        if xl:
+            nodes = list(base_model.graph.node)
+            blended = fix_xl_names(blended, nodes)
+
+        for key, value in blended.items():
+            if key in weights:
+                weights[key] = sum_weights(weights[key], value)
+            else:
+                weights[key] = value
+
     # fix node names once
     fixed_initializer_names = [
         fix_initializer_name(node.name) for node in base_model.graph.initializer
     ]
     fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]

-    # rewrite node names for XL
-    if xl:
-        nodes = list(base_model.graph.node)
-        blended = fix_xl_names(blended, nodes)
-
     logger.debug(
         "updating %s of %s initializers",
-        len(blended.keys()),
+        len(weights.keys()),
         len(base_model.graph.initializer),
     )

     unmatched_keys = []
-    for base_key, weights in blended.items():
+    for base_key, weights in weights.items():
         conv_key = base_key + "_Conv"
         gemm_key = base_key + "_Gemm"
         matmul_key = base_key + "_MatMul"
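The body of the matching loop is collapsed in this view; judging by the suffixes built above, each blended key is tried against Conv, Gemm, and MatMul node names, and keys that match nothing fall through to unmatched_keys. A simplified sketch of that idea (the matching rule and names below are assumptions, not the module's exact logic):

from typing import List, Optional


def find_node_name(base_key: str, node_names: List[str]) -> Optional[str]:
    # try the op-specific suffixes constructed in the loop above
    for suffix in ("_Conv", "_Gemm", "_MatMul"):
        candidate = base_key + suffix
        if candidate in node_names:
            return candidate
    return None


node_names = ["unet.down_blocks_0_Conv", "text_encoder.layers_0_MatMul"]
print(find_node_name("unet.down_blocks_0", node_names))  # unet.down_blocks_0_Conv
print(find_node_name("missing.key", node_names))         # None -> unmatched_keys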
@@ -579,7 +575,7 @@ def blend_loras(
         else:
             unmatched_keys.append(base_key)

-    logger.debug(
+    logger.trace(
         "node counts: %s -> %s, %s -> %s",
         len(fixed_initializer_names),
         len(base_model.graph.initializer),
@@ -588,7 +584,7 @@
     )

     if len(unmatched_keys) > 0:
-        logger.warning("could not find nodes for some keys: %s", unmatched_keys)
+        logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)

     return base_model
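The visible behavioral fix: the old code kept one shared blended dict, so a later LoRA's delta for a given base key simply overwrote an earlier one's; the new code sums overlapping keys when flattening. A toy comparison, independent of the module:

import numpy as np

first = {"unet.mid_block": np.full((2, 2), 0.5)}
second = {"unet.mid_block": np.full((2, 2), 0.25)}

# old behavior: one shared dict, plain assignment
old = {}
for layer in (first, second):
    old.update(layer)  # last LoRA wins

# new behavior: per-LoRA dicts flattened with summing
new = {}
for layer in (first, second):
    for key, value in layer.items():
        new[key] = new[key] + value if key in new else value

print(old["unet.mid_block"][0, 0])  # 0.25 -- second LoRA clobbered the first
print(new["unet.mid_block"][0, 0])  # 0.75 -- deltas combined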
