Skip to content

Commit

Permalink
feat(api): support 3x3 kernels in LoRA and LoCONs
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 7, 2023
1 parent 25176fe commit 7e3ca8a
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def blend_loras(
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()

try:
if len(up_weight.size()) == 2:
if len(down_weight.size()) == 2:
# blend for nn.Linear
logger.trace(
"blending weights for Linear node: %s, %s, %s",
Expand All @@ -106,7 +106,10 @@ def blend_loras(
)
weights = up_weight @ down_weight
np_weights = weights.numpy() * (alpha / dim)
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
elif len(down_weight.size()) == 4 and down_weight.shape[-2:] == (
1,
1,
):
# blend for nn.Conv2d 1x1
logger.trace(
"blending weights for Conv 1x1 node: %s, %s, %s",
Expand All @@ -123,7 +126,10 @@ def blend_loras(
.unsqueeze(3)
)
np_weights = weights.numpy() * (alpha / dim)
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3):
elif len(down_weight.size()) == 4 and down_weight.shape[-2:] == (
3,
3,
):
# blend for nn.Conv2d 3x3
logger.trace(
"blending weights for Conv 3x3 node: %s, %s, %s",
Expand Down Expand Up @@ -199,8 +205,12 @@ def blend_loras(
base_weights.shape,
)

blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
blended = np.expand_dims(blended, (2, 3))
if base_weights.shape[-2:] == (1, 1):
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
blended = np.expand_dims(blended, (2, 3))
else:
blended = base_weights + weights

logger.trace("blended weight shape: %s", blended.shape)

# replace the original initializer
Expand Down

0 comments on commit 7e3ca8a

Please sign in to comment.