Skip to content

Commit

Permalink
fix(api): blend LoHA and LoRA weights for 1x1 kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 21, 2023
1 parent f963b12 commit f0109d3
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ def blend_loras(

np_weights *= lora_weight
if base_key in blended:
blended_weights = blended[base_key]
logger.trace("summing LoHA weights: %s + %s", blended_weights.shape, np_weights.shape)

if blended_weights.shape != np_weights.shape and kernel == (1, 1):
logger.debug("expanding mismatched weights for 1x1 kernel")
blended[base_key] = np.expand_dims(blended_weights, axis=(2, 3))

blended[base_key] += np_weights
else:
blended[base_key] = np_weights
Expand Down Expand Up @@ -258,6 +265,13 @@ def blend_loras(

np_weights *= lora_weight
if base_key in blended:
blended_weights = blended[base_key]
logger.trace("summing weights: %s + %s", blended_weights.shape, np_weights.shape)

if blended_weights.shape != np_weights.shape and kernel == (1, 1):
logger.debug("expanding mismatched weights for 1x1 kernel")
blended[base_key] = np.expand_dims(blended_weights, axis=(2, 3))

blended[base_key] += np_weights
else:
blended[base_key] = np_weights
Expand Down

0 comments on commit f0109d3

Please sign in to comment.