feat(api): experimental support for LoHA networks
ssube committed Apr 7, 2023
1 parent 7e3ca8a commit 35432f1
Showing 1 changed file with 56 additions and 4 deletions.
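
For context: a LoHA (Hadamard-product low-rank adaptation) module stores four factors per layer and composes its weight delta as (w1a @ w1b) * (w2a @ w2b), scaled by alpha / dim; that is the composition the new branch in the diff below blends. A minimal standalone sketch, with hypothetical shapes and the (out, rank) x (rank, in) factor orientation assumed:

import torch

# hypothetical shapes for one Linear layer: 64 in/out features, rank-4 factors
out_dim, in_dim, rank = 64, 64, 4
w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
alpha, dim = 4.0, rank  # arbitrary scaling values; the real code reads alpha from the model

# elementwise (Hadamard) product of two low-rank reconstructions, scaled by alpha / dim
delta = (w1a @ w1b) * (w2a @ w2b) * (alpha / dim)
assert delta.shape == (out_dim, in_dim)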
api/onnx_web/convert/diffusion/lora.py
@@ -80,30 +80,80 @@ def blend_loras(
             continue
 
         for key in lora_model.keys():
-            if ".lora_down" in key and lora_prefix in key:
+            if ".hada_w1_a" in key and lora_prefix in key:
+                # LoHA
+                base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
+
+                w1b_key = key.replace("hada_w1_a", "hada_w1_b")
+                w2a_key = key.replace("hada_w1_a", "hada_w2_a")
+                w2b_key = key.replace("hada_w1_a", "hada_w2_b")
+                alpha_key = key[: key.index("hada_w1_a")] + "alpha"
+                logger.trace(
+                    "blending weights for LoHA keys: %s, %s, %s, %s, %s",
+                    key,
+                    w1b_key,
+                    w2a_key,
+                    w2b_key,
+                    alpha_key,
+                )
+
+                w1a_weight = lora_model[key].to(dtype=dtype)
+                w1b_weight = lora_model[w1b_key].to(dtype=dtype)
+                w2a_weight = lora_model[w2a_key].to(dtype=dtype)
+                w2b_weight = lora_model[w2b_key].to(dtype=dtype)
+
+                dim = w1a_weight.size()[0]
+                alpha = lora_model.get(alpha_key, torch.tensor(dim)).to(dtype).numpy()
+
+                try:
+                    logger.trace(
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
w1a_weight,
w1b_weight,
w2a_weight,
w2b_weight,
)
+                    weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
+                    np_weights = weights.numpy() * (alpha / dim)
+
+                    np_weights *= lora_weight
+                    if base_key in blended:
+                        blended[base_key] += np_weights
+                    else:
+                        blended[base_key] = np_weights
+                except Exception:
+                    logger.exception("error blending weights for LoHA key %s", base_key)
+            elif ".lora_down" in key and lora_prefix in key:
+                # LoRA or LoCON
                 base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
 
+                mid_key = key.replace("lora_down", "lora_mid")
                 up_key = key.replace("lora_down", "lora_up")
                 alpha_key = key[: key.index("lora_down")] + "alpha"
                 logger.trace(
-                    "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
+                    "blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key
                 )
 
                 down_weight = lora_model[key].to(dtype=dtype)
                 up_weight = lora_model[up_key].to(dtype=dtype)
 
+                mid_weight = None
+                if mid_key in lora_model:
+                    mid_weight = lora_model[mid_key].to(dtype=dtype)
+
                 dim = down_weight.size()[0]
                 alpha = lora_model.get(alpha_key, torch.tensor(dim)).to(dtype).numpy()
 
                 try:
                     if len(down_weight.size()) == 2:
                         # blend for nn.Linear
                         logger.trace(
-                            "blending weights for Linear node: %s, %s, %s",
+                            "blending weights for Linear node: (%s @ %s) * %s",
                             down_weight.shape,
                             up_weight.shape,
                             alpha,
                         )
+                        # TODO: include mids
                         weights = up_weight @ down_weight
                         np_weights = weights.numpy() * (alpha / dim)
                     elif len(down_weight.size()) == 4 and down_weight.shape[-2:] == (
@@ -117,6 +167,7 @@ def blend_loras(
                             up_weight.shape,
                             alpha,
                         )
+                        # TODO: include mids
                         weights = (
                             (
                                 up_weight.squeeze(3).squeeze(2)
@@ -137,6 +188,7 @@ def blend_loras(
                             up_weight.shape,
                             alpha,
                         )
+                        # TODO: include mids
                         weights = torch.nn.functional.conv2d(
                             down_weight.permute(1, 0, 2, 3), up_weight
                         ).permute(1, 0, 2, 3)
@@ -156,7 +208,7 @@ def blend_loras(
                         blended[base_key] = np_weights
 
                 except Exception:
-                    logger.exception("error blending weights for key %s", base_key)
+                    logger.exception("error blending weights for LoRA key %s", base_key)
 
         logger.trace(
             "updating %s of %s initializers: %s",
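For reference, the squeeze-and-matmul path and the conv2d path used in the LoRA/LoCON branch above are equivalent for 1x1 kernels; a standalone check (not part of the commit, shapes hypothetical):

import torch

# lora_down and lora_up factors for a hypothetical 1x1 conv, rank 4
rank, c_in, c_out = 4, 8, 16
down = torch.randn(rank, c_in, 1, 1)
up = torch.randn(c_out, rank, 1, 1)

# squeeze the unit spatial dims, matmul, then restore them (the 1x1 branch)
delta = (up.squeeze(3).squeeze(2) @ down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)

# compose the two convolutions directly (the general conv branch)
ref = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)

assert torch.allclose(delta, ref, atol=1e-5)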
