Skip to content

Commit

Permalink
fix(api): handle blending of mismatched kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jun 17, 2023
1 parent f8d59ab commit 719b349
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ def fix_node_name(key: str):
return fixed_name


def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
return (
max(x, shape[2]),
max(y, shape[3]),
)


def blend_loras(
_conversion: ServerContext,
base_name: Union[str, ModelProto],
Expand Down Expand Up @@ -278,9 +285,12 @@ def blend_loras(

for w in range(kernel[0]):
for h in range(kernel[1]):
weights[:, :, w, h] = up_weight.squeeze(3).squeeze(
down_w, down_h = kernel_slice(w, h, down_weight.shape)
up_w, up_h = kernel_slice(w, h, up_weight.shape)

weights[:, :, w, h] = up_weight[:, :, up_w, up_h].squeeze(3).squeeze(
2
) @ down_weight.squeeze(3).squeeze(2)
) @ down_weight[:, :, down_w, down_h].squeeze(3).squeeze(2)

np_weights = weights.numpy() * (alpha / dim)
else:
Expand Down

0 comments on commit 719b349

Please sign in to comment.