
LORA #213

Closed · ForserX opened this issue Mar 5, 2023 · 12 comments
Labels: model/diffusion, model/lora, status/fixed (issues that have been fixed and released), type/feature (new features)

ForserX (Contributor) commented Mar 5, 2023

# LoRA magic

There is a way to merge LoRA weights into Diffusers models on the fly. I would be glad if you could figure out how to do something like this in ONNX.
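For reference, the on-the-fly merge amounts to adding the low-rank product into each target weight in PyTorch (a minimal sketch, not the Diffusers or sd-scripts code; the shapes follow the text encoder fc1 dump later in this thread):
import torch

def merge_lora_weight(weight: torch.Tensor, up: torch.Tensor, down: torch.Tensor, alpha: float, scale: float = 1.0) -> torch.Tensor:
    # LoRA update: W' = W + scale * (alpha / rank) * (up @ down)
    rank = down.shape[0]
    return weight + (up @ down) * (alpha / rank) * scale

# example with the fc1 shapes: weight (3072, 768), down (128, 768), up (3072, 128)
fc1 = torch.nn.Linear(768, 3072)
up = torch.zeros(3072, 128)
down = torch.zeros(128, 768)
fc1.weight.data = merge_lora_weight(fc1.weight.data, up, down, alpha=128.0)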

ssube (Owner) commented Mar 5, 2023

So I've been looking into this, and in theory, it is possible: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py#L123

The difficulty seems to be in building a valid ONNX protobuf of the new model.

That code is not hooked up yet because I haven't been able to make it reliably emit valid models, but I'm not sure if that is a bug with the merge code or with the model checker. Adding too many initializers to the graph at once can cause a segfault in onnx or protobuf as well, but looping seemed to help with that (https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py#L85).
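For what it's worth, the looping workaround amounts to swapping initializers one at a time instead of rebuilding the whole list in a single call (a sketch against the onnx Python API; blended is a hypothetical dict of already-merged numpy arrays):
import onnx
from onnx.numpy_helper import from_array

def replace_initializers(model: onnx.ModelProto, blended: dict) -> onnx.ModelProto:
    # blended maps initializer name -> merged numpy array (hypothetical input)
    for idx, init in enumerate(model.graph.initializer):
        if init.name in blended:
            # replace each initializer in place, one per iteration, rather than
            # clearing the list and extending it with every tensor at once
            model.graph.initializer[idx].CopyFrom(from_array(blended[init.name], init.name))
    return model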

ssube added the status/planned (issues that have been planned but not started), type/feature (new features), and model/diffusion labels on Mar 5, 2023
ssube added this to the v0.9 milestone on Mar 11, 2023
ssube (Owner) commented Mar 12, 2023

I've made some progress and done some research that may or may not be worth writing down, and see two issues so far:

  • I think I have the UNet blended, or at least it emits a valid model, but the text_encoder/model.onnx initializers are mostly .bias nodes and are missing a lot of the .weight nodes I would expect. I'm not sure if that is a problem with the models, the conversion, or just the naming: there are a bunch of onnx::MatMul nodes that are linked to some of the nodes with "missing" weights, often as inputs. I'm wondering if the ONNX converter is losing the name on those nodes and I can follow the bias, or maybe this is something else entirely?
  • the ONNX runtime's InferenceSession takes a model by filename/path or from a bytes object, but model.SerializeToString() hits the 2GB protobuf limit and won't serialize the UNet to bytes in its whole form. Calling convert_model_to_external_data/write_external_data_tensors will write the extra data into an external protobuf file, but there is no way to write it into memory (bytes or BytesIO). If you have a modern SSD or some Optane, that might not be a problem, but I need to raise issues/PRs with ONNX and ORT (https://onnx.ai/onnx/_modules/onnx/external_data_helper.html#write_external_data_tensors). A sketch of the on-disk workaround is below.
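For the second point, the on-disk workaround would look roughly like this (a minimal sketch; the paths are placeholders and the blended model would normally come from the merge code rather than disk):
import onnx
from onnx.external_data_helper import convert_model_to_external_data

# placeholder: in practice this would be the ModelProto produced by the merge code
blended_unet = onnx.load_model("/tmp/blended-unet/model.onnx")

# mark the large tensors as external so the main protobuf stays under 2GB
convert_model_to_external_data(
    blended_unet,
    all_tensors_to_one_file=True,
    location="weights.pb",
    size_threshold=1024,
)

# writes the small graph protobuf plus weights.pb alongside it
onnx.save_model(blended_unet, "/tmp/blended-unet/model-external.onnx")

# ORT can then take the path instead of SerializeToString() bytes
# from onnxruntime import InferenceSession
# session = InferenceSession("/tmp/blended-unet/model-external.onnx")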
On the first point, some of those node names:
lora_te_text_model_encoder_layers_6_self_attn_k_proj.lora_down.weight
lora_te_text_model_encoder_layers_6_self_attn_out_proj.lora_down.weight
lora_te_text_model_encoder_layers_6_self_attn_q_proj.lora_down.weight
lora_te_text_model_encoder_layers_6_self_attn_v_proj.lora_down.weight
lora_te_text_model_encoder_layers_7_mlp_fc1.lora_down.weight
lora_te_text_model_encoder_layers_7_mlp_fc2.lora_down.weight
lora_te_text_model_encoder_layers_7_self_attn_k_proj.lora_down.weight
lora_te_text_model_encoder_layers_7_self_attn_out_proj.lora_down.weight
lora_te_text_model_encoder_layers_7_self_attn_q_proj.lora_down.weight

# LoRA has:
>>> [k for k in lm.keys() if "lora_te_" in k and "layers_5" in k]
[
'lora_te_text_model_encoder_layers_5_mlp_fc1.alpha', 
'lora_te_text_model_encoder_layers_5_mlp_fc1.lora_down.weight', 
'lora_te_text_model_encoder_layers_5_mlp_fc1.lora_up.weight',
'lora_te_text_model_encoder_layers_5_mlp_fc2.alpha', 
'lora_te_text_model_encoder_layers_5_mlp_fc2.lora_down.weight', 
'lora_te_text_model_encoder_layers_5_mlp_fc2.lora_up.weight', 
'lora_te_text_model_encoder_layers_5_self_attn_k_proj.alpha', 
'lora_te_text_model_encoder_layers_5_self_attn_k_proj.lora_down.weight', 
'lora_te_text_model_encoder_layers_5_self_attn_k_proj.lora_up.weight', 
'lora_te_text_model_encoder_layers_5_self_attn_out_proj.alpha', 
'lora_te_text_model_encoder_layers_5_self_attn_out_proj.lora_down.weight', 
'lora_te_text_model_encoder_layers_5_self_attn_out_proj.lora_up.weight', 
'lora_te_text_model_encoder_layers_5_self_attn_q_proj.alpha', 
'lora_te_text_model_encoder_layers_5_self_attn_q_proj.lora_down.weight', 
'lora_te_text_model_encoder_layers_5_self_attn_q_proj.lora_up.weight', 
'lora_te_text_model_encoder_layers_5_self_attn_v_proj.alpha', 
'lora_te_text_model_encoder_layers_5_self_attn_v_proj.lora_down.weight', 
'lora_te_text_model_encoder_layers_5_self_attn_v_proj.lora_up.weight'
]

>>> [(k, lm[k].shape) for k in lm.keys() if "lora_te_" in k and "layers_5" in k and "mlp" in k]
[
('lora_te_text_model_encoder_layers_5_mlp_fc1.alpha', torch.Size([])), 
('lora_te_text_model_encoder_layers_5_mlp_fc1.lora_down.weight', torch.Size([128, 768])), 
('lora_te_text_model_encoder_layers_5_mlp_fc1.lora_up.weight', torch.Size([3072, 128])), 
('lora_te_text_model_encoder_layers_5_mlp_fc2.alpha', torch.Size([])), 
('lora_te_text_model_encoder_layers_5_mlp_fc2.lora_down.weight', torch.Size([128, 3072])), 
('lora_te_text_model_encoder_layers_5_mlp_fc2.lora_up.weight', torch.Size([768, 128]))
]

# ONNX has:
>>> [(n.name, onnx.numpy_helper.to_array(n).shape) for n in om_nn]
[
('text_model.encoder.layers.5.self_attn.k_proj.bias', (768,)), 
('text_model.encoder.layers.5.self_attn.v_proj.bias', (768,)), 
('text_model.encoder.layers.5.self_attn.q_proj.bias', (768,)), 
('text_model.encoder.layers.5.self_attn.out_proj.bias', (768,)), 
('text_model.encoder.layers.5.layer_norm1.weight', (768,)), 
('text_model.encoder.layers.5.layer_norm1.bias', (768,)), 
('text_model.encoder.layers.5.mlp.fc1.bias', (3072,)), 
('text_model.encoder.layers.5.mlp.fc2.bias', (768,)),
('text_model.encoder.layers.5.layer_norm2.weight', (768,)), 
('text_model.encoder.layers.5.layer_norm2.bias', (768,))
]

# related names:
/text_model/encoder/layers.5/mlp/fc1/Add
/text_model/encoder/layers.5/mlp/fc1/MatMul
Operator type sanity check script:
import onnx.numpy_helper
import torch.nn as nn
import torch.onnx

# make a net with single Conv2d
conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=4, stride=1, padding=1, dilation=1, groups=1, bias=True)
dummy_input = torch.randn(10, 3, 224, 224)
torch.onnx.export(conv, dummy_input, f="/tmp/onnx-conv2d.pb")

# load it back
conv2d_onnx = onnx.load("/tmp/onnx-conv2d.pb")
print("conv2d init:", [(n.name, len(n.raw_data)) for n in conv2d_onnx.graph.initializer])
print("conv2d node:", [(n.name, n.input) for n in conv2d_onnx.graph.node])
print("conv2d output:", conv2d_onnx.graph.output)


# make a net with a single Linear
linear = nn.Linear(224, 6720, bias=True)
torch.onnx.export(linear, dummy_input, f="/tmp/onnx-linear.pb")

# load it back
lin_onnx = onnx.load("/tmp/onnx-linear.pb")
print("linear init:", [(n.name, len(n.raw_data)) for n in lin_onnx.graph.initializer])
print("linear node:", [(n.name, n.input) for n in lin_onnx.graph.node])
print("linear output:", lin_onnx.graph.output)


# shapes
print("conv2d shapes:", [(n.name, onnx.numpy_helper.to_array(n).shape) for n in conv2d_onnx.graph.initializer])
print("linear shapes:", [(n.name, onnx.numpy_helper.to_array(n).shape) for n in lin_onnx.graph.initializer])
Operator script output:
conv2d init: [('weight', 576), ('bias', 12)]                                                                      
conv2d node: [('/Conv', ['input', 'weight', 'bias'])]                                                             
conv2d output: [name: "3"                                                                                         
type {                                                                                                            
  tensor_type {                                                                                                   
    elem_type: 1                                                                                                  
    shape {                                                                                                                                                                                                                         
      dim {                                                                                                       
        dim_value: 10                                                                                             
      }                                                                                                           
      dim {                                                                                                       
        dim_value: 3                                     
      }                                                                                                           
      dim {                                                                                                       
        dim_value: 223                                                                                            
      }                                                                                                           
      dim {                                              
        dim_value: 223                                                                                            
      }                                                                                                           
    }                                                                                                                                                                                                                               
  }                                                                                                               
}                                                                                                                 
]                                                        
linear init: [('bias', 26880), ('onnx::MatMul_6', 6021120)]                                                       
linear node: [('/MatMul', ['onnx::MatMul_0', 'onnx::MatMul_6']), ('/Add', ['bias', '/MatMul_output_0'])]       
linear output: [name: "5"                                                                                                                                                                                                           
type {                                                                                                            
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 10
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 224
      }
      dim {
        dim_value: 6720
      }
    }
  }
}
]
conv2d shapes: [('weight', (3, 3, 4, 4)), ('bias', (3,))]
linear shapes: [('bias', (6720,)), ('onnx::MatMul_6', (224, 6720))]

tl;dr: nn.Conv2d becomes [('weight', 576), ('bias', 12)] but nn.Linear becomes [('bias', 26880), ('onnx::MatMul_6', 6021120)], and I think that onnx::MatMul initializer is what needs to be found and adjusted.
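To make that concrete, adjusting one of those onnx::MatMul initializers for a Linear layer would look something like this (a sketch; matmul_name, up, down, and alpha are assumed to come from matching the LoRA keys to the graph):
import numpy as np
import onnx
from onnx.numpy_helper import from_array, to_array

def blend_linear_initializer(model: onnx.ModelProto, matmul_name: str, up: np.ndarray, down: np.ndarray, alpha: float) -> None:
    # the exported Linear keeps its weight in an initializer like "onnx::MatMul_6",
    # stored as (in_features, out_features), i.e. transposed from nn.Linear's (out, in)
    for idx, init in enumerate(model.graph.initializer):
        if init.name == matmul_name:
            base = to_array(init)
            rank = down.shape[0]
            update = (up @ down) * (alpha / rank)   # (out, in)
            blended = base + update.transpose()     # match the (in, out) MatMul layout
            model.graph.initializer[idx].CopyFrom(from_array(blended.astype(base.dtype), init.name))
            return
    raise KeyError(matmul_name)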

ssube (Owner) commented Mar 13, 2023

Looking at this further, with a little bit of progress:

It looks like ORT does offer a way to load external data from memory, so blending and even converting models without ever writing them to disk should be possible. For most of the large base models, saving them will still make sense, but that should save some SSD wear for LoRAs.

I was able to blend some models by looking up the MatMul nodes and write the result out as a valid ONNX model, at least valid enough to load and run inference, but the output comes out as random colored spots. Rather than guess at what the nodes mean, I'm writing a script to diff the ONNX models and working backwards from there.

diff script:
from logging import getLogger, basicConfig, DEBUG
from onnx import load_model, ModelProto
from onnx.numpy_helper import to_array
from sys import argv, stdout


basicConfig(stream=stdout, level=DEBUG)

logger = getLogger(__name__)

def diff_models(ref_model: ModelProto, cmp_model: ModelProto):
  if len(ref_model.graph.initializer) != len(cmp_model.graph.initializer):
    logger.warning("different number of initializers: %s vs %s", len(ref_model.graph.initializer), len(cmp_model.graph.initializer))
  else:
    for (ref_init, cmp_init) in zip(ref_model.graph.initializer, cmp_model.graph.initializer):
      if ref_init.name != cmp_init.name:
        logger.info("different node names: %s vs %s", ref_init.name, cmp_init.name)
      elif ref_init.data_location != cmp_init.data_location:
        logger.info("different data locations: %s vs %s", ref_init.data_location, cmp_init.data_location)
      elif ref_init.data_type != cmp_init.data_type:
        logger.info("different data types: %s vs %s", ref_init.data_type, cmp_init.data_type)
      elif len(ref_init.raw_data) != len(cmp_init.raw_data):
        logger.info("different raw data size: %s vs %s", len(ref_init.raw_data), len(cmp_init.raw_data))
      elif len(ref_init.raw_data) > 0 and len(cmp_init.raw_data) > 0:
        ref_data = to_array(ref_init)
        cmp_data = to_array(cmp_init)
        data_diff = ref_data - cmp_data
        if data_diff.max() > 0:
          logger.info("raw data differs: %s", data_diff)
      else:
        logger.info("initializers are identical in all checked fields: %s", ref_init.name)


if __name__ == "__main__":
  ref_path = argv[1]
  cmp_paths = argv[2:]

  logger.info("loading reference model from %s", ref_path)
  ref_model = load_model(ref_path)

  for cmp_path in cmp_paths:
    logger.info("loading comparison model from %s", cmp_path)
    cmp_model = load_model(cmp_path)
    diff_models(ref_model, cmp_model)

My initial comparison of two text_encoders shows that all of the changes to initializers are in those onnx::MatMul_* nodes, so my blending must be wrong somewhere.

INFO:__main__:raw data differs for onnx::MatMul_2214: [[-5.97857405e-04 -8.29659402e-05  2.55594961e-04 ... -3.97935510e-05
   1.16263516e-04 -2.64257425e-04]
 [-6.45183027e-05  1.43377110e-05 -1.45569444e-04 ... -2.29005702e-04
   3.29062605e-04 -6.15114346e-04]
 [-2.21131369e-04 -2.17929482e-05  2.85109505e-04 ... -1.59675255e-04
  -1.37356110e-04 -4.39988682e-04]
 ...
 [ 2.69189477e-05 -1.48676336e-05 -2.39353627e-04 ... -9.81168123e-05
   1.52554829e-04 -1.76237896e-04]
 [ 1.43175945e-04  6.34291209e-05  2.50307843e-04 ...  2.84195878e-04
   4.65737656e-04  3.63643281e-04]
 [ 1.87259167e-04  1.60088763e-04  1.67272985e-04 ... -1.12581765e-04
  -2.32110731e-04  3.32263298e-04]]
INFO:__main__:raw data differs for onnx::MatMul_2221: [[ 5.9778802e-05  1.2901332e-04 -2.7800910e-05 ... -5.8663264e-04
  -3.4572277e-04 -2.7506612e-05]
 [ 2.1521933e-04  6.2489137e-04  5.4099597e-05 ... -1.2307428e-05
  -8.6032785e-05  3.8439641e-05]
 [ 1.4275452e-04  5.2705873e-04  4.9114600e-04 ...  6.3313171e-05
   3.5880134e-05  3.1778496e-04]
 ...
 [ 1.3154559e-04 -3.6602188e-04  2.6111957e-06 ... -1.4451891e-04
  -5.4034032e-04 -1.6938546e-04]
 [ 5.1567703e-04  1.9872934e-04  6.5570464e-04 ...  7.1686227e-05
   2.2495119e-04 -2.4915300e-04]
 [-1.3470743e-04 -6.4091664e-04 -5.4577552e-04 ...  8.0384768e-04

ForserX (Contributor, Author) commented Mar 14, 2023

I want to say one thing: even with Diffusers, LoRA models sometimes produce strange artifacts.

ssube (Owner) commented Mar 14, 2023

I think I figured out the problem with the artifacts, or at least the cause: the script I have so far is converting the nn.Linear nodes, but the nn.Conv2d nodes still show a difference in the output value.

(onnx_env) ssube@compute-infer-1:/opt/onnx-web/api$ python3 onnx-diff.py /opt/onnx-web/models/diffusion-lora-buffy/unet/model.onnx /tmp/lora-unet.onnx  | grep differs
INFO:__main__:raw data differs for down_blocks.0.attentions.0.proj_in.weight: 0.80306447
INFO:__main__:raw data differs for down_blocks.0.attentions.0.proj_out.weight: 0.97295666
INFO:__main__:raw data differs for down_blocks.0.attentions.1.proj_in.weight: 0.75665176
INFO:__main__:raw data differs for down_blocks.0.attentions.1.proj_out.weight: 0.90466636
INFO:__main__:raw data differs for down_blocks.1.attentions.0.proj_in.weight: 0.7046069
INFO:__main__:raw data differs for down_blocks.1.attentions.0.proj_out.weight: 0.7026885
INFO:__main__:raw data differs for down_blocks.1.attentions.1.proj_in.weight: 0.7333188
INFO:__main__:raw data differs for down_blocks.1.attentions.1.proj_out.weight: 0.9321379
INFO:__main__:raw data differs for down_blocks.2.attentions.0.proj_in.weight: 0.593909
INFO:__main__:raw data differs for down_blocks.2.attentions.0.proj_out.weight: 0.5531869
INFO:__main__:raw data differs for down_blocks.2.attentions.1.proj_in.weight: 0.58815855
INFO:__main__:raw data differs for down_blocks.2.attentions.1.proj_out.weight: 0.71980715
INFO:__main__:raw data differs for up_blocks.1.attentions.0.proj_in.weight: 0.7294675
INFO:__main__:raw data differs for up_blocks.1.attentions.0.proj_out.weight: 0.86747974
INFO:__main__:raw data differs for up_blocks.1.attentions.1.proj_in.weight: 0.7914733
INFO:__main__:raw data differs for up_blocks.1.attentions.1.proj_out.weight: 0.9515898
INFO:__main__:raw data differs for up_blocks.1.attentions.2.proj_in.weight: 0.8609139
INFO:__main__:raw data differs for up_blocks.1.attentions.2.proj_out.weight: 1.0044274
INFO:__main__:raw data differs for up_blocks.2.attentions.0.proj_in.weight: 1.0911213
INFO:__main__:raw data differs for up_blocks.2.attentions.0.proj_out.weight: 1.110176
INFO:__main__:raw data differs for up_blocks.2.attentions.1.proj_in.weight: 1.0298531
INFO:__main__:raw data differs for up_blocks.2.attentions.1.proj_out.weight: 1.1285026
INFO:__main__:raw data differs for up_blocks.2.attentions.2.proj_in.weight: 1.0006539
INFO:__main__:raw data differs for up_blocks.2.attentions.2.proj_out.weight: 1.2107546
INFO:__main__:raw data differs for up_blocks.3.attentions.0.proj_in.weight: 1.2077166
INFO:__main__:raw data differs for up_blocks.3.attentions.0.proj_out.weight: 1.2001367
INFO:__main__:raw data differs for up_blocks.3.attentions.1.proj_in.weight: 0.8941159
INFO:__main__:raw data differs for up_blocks.3.attentions.1.proj_out.weight: 1.0703257
INFO:__main__:raw data differs for up_blocks.3.attentions.2.proj_in.weight: 0.7948036
INFO:__main__:raw data differs for up_blocks.3.attentions.2.proj_out.weight: 0.88075924
INFO:__main__:raw data differs for mid_block.attentions.0.proj_in.weight: 0.579389
INFO:__main__:raw data differs for mid_block.attentions.0.proj_out.weight: 1.0566761
INFO:__main__:raw data differs for onnx::MatMul_9037: 2.9802322e-08
INFO:__main__:raw data differs for onnx::MatMul_9046: 7.450581e-09
INFO:__main__:raw data differs for onnx::MatMul_9047: 7.450581e-09
INFO:__main__:raw data differs for onnx::MatMul_9056: 7.450581e-09
INFO:__main__:raw data differs for onnx::MatMul_9057: 1.4901161e-08
INFO:__main__:raw data differs for onnx::MatMul_9058: 1.4901161e-08
INFO:__main__:raw data differs for onnx::MatMul_9065: 1.4901161e-08
INFO:__main__:raw data differs for onnx::MatMul_9074: 1.4901161e-08
INFO:__main__:raw data differs for onnx::MatMul_9075: 7.450581e-09

The rest of the onnx::MatMul nodes are all off by e-08 or e-09, and there's a bf16 to fp32 conversion that could explain that.
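A guess at what the fix for those proj_in/proj_out kernels needs to look like, since in this UNet they are likely 1x1 convolutions (a sketch, not the code from the branch; up and down carry the trailing 1x1 dims that sd-scripts LoRAs use for conv layers):
import numpy as np

def blend_conv1x1_kernel(kernel: np.ndarray, up: np.ndarray, down: np.ndarray, alpha: float) -> np.ndarray:
    # kernel: (out_ch, in_ch, 1, 1); up: (out_ch, rank, 1, 1); down: (rank, in_ch, 1, 1)
    rank = down.shape[0]
    update = up[:, :, 0, 0] @ down[:, :, 0, 0]   # collapse the 1x1 dims -> (out_ch, in_ch)
    update = update * (alpha / rank)
    return kernel + update[:, :, None, None]     # restore (out_ch, in_ch, 1, 1)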

diff script:
from logging import getLogger, basicConfig, DEBUG
from numpy import maximum
from onnx import load_model, ModelProto
from onnx.numpy_helper import to_array
from sys import argv, stdout


basicConfig(stream=stdout, level=DEBUG)

logger = getLogger(__name__)

def diff_models(ref_model: ModelProto, cmp_model: ModelProto):
  if len(ref_model.graph.initializer) != len(cmp_model.graph.initializer):
    logger.warning("different number of initializers: %s vs %s", len(ref_model.graph.initializer), len(cmp_model.graph.initializer))
  else:
    for (ref_init, cmp_init) in zip(ref_model.graph.initializer, cmp_model.graph.initializer):
      if ref_init.name != cmp_init.name:
        logger.info("different node names: %s vs %s", ref_init.name, cmp_init.name)
      elif ref_init.data_location != cmp_init.data_location:
        logger.info("different data locations: %s vs %s", ref_init.data_location, cmp_init.data_location)
      elif ref_init.data_type != cmp_init.data_type:
        logger.info("different data types: %s vs %s", ref_init.data_type, cmp_init.data_type)
      elif len(ref_init.raw_data) != len(cmp_init.raw_data):
        logger.info("different raw data size: %s vs %s", len(ref_init.raw_data), len(cmp_init.raw_data))
      elif len(ref_init.raw_data) > 0 and len(cmp_init.raw_data) > 0:
        ref_data = to_array(ref_init)
        cmp_data = to_array(cmp_init)
        data_diff = ref_data - cmp_data
        if data_diff.max() > 0:
          logger.info("raw data differs for %s: %s\n%s", ref_init.name, data_diff.max(), data_diff)
      else:
        logger.info("initializers are identical in all checked fields: %s", ref_init.name)

  if len(ref_model.graph.node) != len(cmp_model.graph.node):
    logger.warning("different number of nodes: %s vs %s", len(ref_model.graph.node), len(cmp_model.graph.node))
  else:
    for (ref_node, cmp_node) in zip(ref_model.graph.node, cmp_model.graph.node):
      if ref_node.name != cmp_node.name:
        logger.info("different node names: %s vs %s", ref_node.name, cmp_node.name)
      elif ref_node.input != cmp_node.input:
        logger.info("different inputs: %s vs %s", ref_node.input, cmp_node.input)
      elif ref_node.output != cmp_node.output:
        logger.info("different outputs: %s vs %s", ref_node.output, cmp_node.output)
      elif ref_node.op_type != cmp_node.op_type:
        logger.info("different op type: %s vs %s", ref_node.op_type, cmp_node.op_type)
      else:
        logger.info("nodes are identical in all checked fields: %s", ref_init.name)


if __name__ == "__main__":
  ref_path = argv[1]
  cmp_paths = argv[2:]

  logger.info("loading reference model from %s", ref_path)
  ref_model = load_model(ref_path)

  for cmp_path in cmp_paths:
    logger.info("loading comparison model from %s", cmp_path)
    cmp_model = load_model(cmp_path)
    diff_models(ref_model, cmp_model)

ssube (Owner) commented Mar 15, 2023

I have something that seems to be working, #243, with a few caveats:

  • it only supports nn.Linear and nn.Conv2d 1x1 nodes (3x3 was added while I was working on this)
  • changing the LoRA names and/or weights forces the model to be reloaded, but caching works otherwise
  • you cannot add LoRAs to extras.json and convert/optimize them ahead of time
  • optimizations may break this, I need to test that

What it does support so far:

  • loading LoRAs from .safetensors in onnx-web/models/lora/
  • blending LoRAs with the base ONNX model without ever writing it to disk
  • blending with your selected diffusion model in the onnx-web UI
  • parsing LoRA names from the prompt, using the <lora:name:weight> syntax

All of the good stuff is in https://github.com/ssube/onnx-web/blob/feat/213-lora/api/onnx_web/convert/diffusion/lora.py

Important parts, for my own reference and anyone else who finds them useful:

[screenshot of the relevant blending code from lora.py]
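As a side note on the prompt syntax, pulling <lora:name:weight> tags out of a prompt only takes a small regex (an illustrative sketch, not the actual onnx-web parser):
import re

LORA_TAG = re.compile(r"<lora:([^:>]+):([-+]?\d*\.?\d+)>")

def parse_lora_tags(prompt: str):
    # returns the prompt with the tags stripped, plus (name, weight) pairs
    loras = [(name, float(weight)) for name, weight in LORA_TAG.findall(prompt)]
    return LORA_TAG.sub("", prompt).strip(), loras

# example
print(parse_lora_tags("a portrait, highly detailed <lora:buffy:0.8>"))
# ('a portrait, highly detailed', [('buffy', 0.8)])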

ssube added the status/progress (issues that are in progress and have a branch) label and removed the status/planned label on Mar 17, 2023
ssube (Owner) commented Mar 19, 2023

I have this working pretty well for LoRAs produced by the sd-scripts repo and most Textual Inversions, but it doesn't support the cloneofsimo LoRAs yet (#206). Let me know how this works for you, @ForserX, and if you or @Amblyopius have any info about getting the other networks working (hypernetworks, etc.), I would be very interested.

I'm going to release what I have and the new ORT optimization stuff with 6/8GB support (#241) as v0.9.

ForserX (Contributor, Author) commented Mar 23, 2023

I hope to see it before the end of the month. Meanwhile, I'm dying at work...

ForserX (Contributor, Author) commented Mar 29, 2023

So, I lost touch with the world for a little while...
Do I need to use merge_to_sd_model, or unpack first and then blend_loras?

ForserX (Contributor, Author) commented Mar 29, 2023

getting the other networks working (hypernetworks, etc)

Regarding hypernetworks: I have only seen the auto1111 implementation. But to someone as far from Python as me, that code is a personal hell in which the only thing I understand is "None".

ssube added the status/fixed (issues that have been fixed and released) label and removed the status/progress label on Mar 29, 2023
ssube (Owner) commented Mar 29, 2023

Do I need to use merge_to_sd_model, or unpack first and then blend_loras?

No need to merge. You can load the base model with onnx.load_model or pass the filename to blend_loras, and it will load the LoRA weights from the tensors directly: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py#L67

    base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
    lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]

The nodes need to have the right names, and running some of the more aggressive ORT optimization scripts will break that.

The logic is pretty much normal, same as yours or sd-scripts, up until https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py#L170
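The trickiest part of that matching is translating the flattened LoRA key names into the dotted names the ONNX exporter uses. A simplified guess at that mapping, not the actual code from lora.py, and the token list only covers the names seen in this thread (Python 3.9+ for removeprefix):
def lora_key_to_onnx_prefix(key: str) -> str:
    # e.g. "lora_te_text_model_encoder_layers_5_mlp_fc1.lora_down.weight"
    #   -> "text_model.encoder.layers.5.mlp.fc1"
    base = key.split(".")[0].removeprefix("lora_te_").removeprefix("lora_unet_")
    name = base.replace("_", ".")
    # restore the compound tokens that legitimately contain underscores
    for token in ("text_model", "self_attn", "k_proj", "q_proj", "v_proj", "out_proj",
                  "proj_in", "proj_out", "down_blocks", "up_blocks", "mid_block"):
        name = name.replace(token.replace("_", "."), token)
    return name

print(lora_key_to_onnx_prefix("lora_te_text_model_encoder_layers_5_self_attn_k_proj.lora_down.weight"))
# text_model.encoder.layers.5.self_attn.k_proj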

ForserX (Contributor, Author) commented Apr 5, 2023

Good job! I tested it myself - it works great!

ForserX added a commit to ForserX/StableDiffusionUI that referenced this issue Apr 5, 2023
https://github.com/ssube/onnx-web/issues/213
Co-Authored-By: Sean Sube <seansube@gmail.com>